BoundedTaskGroup to control parallelism

Whenever I use asyncio.gather or a TaskGroup, I find myself wrapping everything in a Semaphore. I’d rather just pass my desired level of parallelism to the TaskGroup, or use a variant that wraps tasks on my behalf.
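For reference, the manual pattern described here might look roughly like the sketch below when used with asyncio.gather. Each coroutine is wrapped so it holds one Semaphore slot while it runs. The names `bounded`, `run_all` and `work` are hypothetical, and the global counters exist only to show that no more than `limit` coroutines are active at once.

```python
import asyncio


async def bounded(sem, coro):
    # Hold one semaphore slot for the whole lifetime of `coro`.
    async with sem:
        return await coro


async def run_all(coros, limit):
    # Run every coroutine via gather, but never more than `limit` at once.
    sem = asyncio.Semaphore(limit)
    return await asyncio.gather(*(bounded(sem, c) for c in coros))


active = 0  # coroutines currently inside work()
peak = 0    # highest value `active` ever reached


async def work(i):
    global active, peak
    active += 1
    peak = max(peak, active)
    await asyncio.sleep(0.01)
    active -= 1
    return i


results = asyncio.run(run_all([work(i) for i in range(10)], limit=2))
print(results)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(peak)     # 2
```

gather keeps results in the order the coroutines were passed in, even though they are throttled.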

If people think this would make sense as a patch to TaskGroup, or as another subclass included in the stdlib, I’ll have a go at making a PR. Here’s a simple example of what I mean:

import asyncio

async def test(i):
    print(f'running {i}')
    await asyncio.sleep(1)

class BoundedTaskGroup(asyncio.TaskGroup):
    def __init__(self, *args, max_parallelism=0, **kwargs):
        super().__init__(*args, **kwargs)
        # 0 (the default) means "no limit": tasks are not wrapped at all.
        if max_parallelism:
            self._sem = asyncio.Semaphore(max_parallelism)
        else:
            self._sem = None

    def create_task(self, coro, *args, **kwargs):
        if self._sem is not None:
            # Wrap the coroutine so it only starts running once it holds
            # one of the semaphore's max_parallelism slots.
            async def _wrapped_coro(sem, coro):
                async with sem:
                    return await coro
            coro = _wrapped_coro(self._sem, coro)

        return super().create_task(coro, *args, **kwargs)

async def main():
    async with BoundedTaskGroup(max_parallelism=2) as g:
        for i in range(10):
            g.create_task(test(i))

asyncio.run(main())

Thank you for sharing! Your code saved my day :smiley:


Hello Alexander,
First of all, thank you for sharing.
I found something weird: when all tasks are cancelled, it raises a RuntimeWarning: coroutine ‘create..’ was never awaited. It only happens once, on the first run; the second run works fine.

import asyncio

# BoundedTaskGroup as defined in the first post

url = ['site1', ...]

async def cancel_all_tasks():
    all_tasks = asyncio.all_tasks()
    all_tasks = [t for t in all_tasks if t.get_name() == 'jeboo']
    for task in all_tasks:
        task.cancel()
    await asyncio.gather(*all_tasks, return_exceptions=True)

async def data(link):
    await asyncio.sleep(5)
    x = 100*1000000
    print(x)
    await cancel_all_tasks()

async def main():
    async with BoundedTaskGroup(max_parallelism=5) as group:
        for link in set(url):
            group.create_task(data(link), name='jeboo')

asyncio.run(main())

An interesting edge case: you are calling cancel() on a task before it has started. Because the task never got as far as awaiting the coroutine it was passed, Python warns about that coroutine. Here’s an even simpler repro:

import asyncio

async def wraptest(coro):
    await coro

async def main():
    t = asyncio.create_task(wraptest(asyncio.sleep(1)))
    t.cancel()
        
asyncio.run(main())
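To make the mechanism concrete: a task that is cancelled before its first step never executes any of its body, not even a finally clause, so a wrapper coroutine gets no chance to await (or clean up) the coroutine it was handed. A minimal demonstration, using a hypothetical `wrapper` that records which parts of it ran:

```python
import asyncio

ran = []  # records which parts of wrapper() actually executed


async def wrapper():
    try:
        ran.append("body")
        await asyncio.sleep(1)
    finally:
        ran.append("finally")


async def main():
    t = asyncio.create_task(wrapper())
    t.cancel()  # cancelled before the task ever got a chance to run
    try:
        await t
    except asyncio.CancelledError:
        pass
    return t.cancelled()


cancelled = asyncio.run(main())
print(cancelled, ran)  # True []
```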

It’s not immediately obvious to me how to tell Python that this was expected, i.e. that I intentionally did not await the coroutine in this case. I’ll do some more digging.


@alexmac

Thank you for the suggestion.