Skip to content

Commit a786589

Browse files
author
Anton
committed
test: more tests
1 parent a460da4 commit a786589

File tree

3 files changed

+198
-134
lines changed

3 files changed

+198
-134
lines changed

taskiq/receiver/receiver.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ def __init__( # noqa: WPS211
7474
self.queue: PriorityQueue[bytes] = PriorityQueue()
7575

7676
self.sem_sleeping: Optional[asyncio.Semaphore] = None
77-
if max_sleeping_tasks is not None and max_sleeping_tasks <= 0:
78-
raise ValueError("`max_idle_tasks` should be greater then zero or None.")
77+
if max_sleeping_tasks is not None and max_sleeping_tasks < 0:
78+
raise ValueError(
79+
"`max_sleeping_tasks` should be greater than zero or None.",
80+
)
7981
if max_sleeping_tasks is not None and max_sleeping_tasks > 0:
8082
self.sem_sleeping = asyncio.Semaphore(max_sleeping_tasks)
8183

tests/cli/worker/test_receiver.py

+171-132
Original file line numberDiff line numberDiff line change
@@ -293,138 +293,177 @@ async def task_sem() -> int:
293293
assert sem_num == max_async_tasks + 2
294294

295295

296-
@pytest.mark.anyio
297-
async def test_tasks_chain_without_idler() -> None:
298-
""""""
299-
broker = InMemoryQueueBroker()
300-
301-
@broker.task
302-
async def task_add_one(val: int) -> int:
303-
return val + 1
304-
305-
@broker.task
306-
async def task_map(vals: List[int]) -> List[int]:
307-
tasks = [await task_add_one.kiq(val) for val in vals]
308-
resps_tasks = [asyncio.create_task(t.wait_result(timeout=1)) for t in tasks]
309-
resps = await asyncio.gather(*resps_tasks)
310-
311-
return [r.return_value for r in resps]
312-
313-
receiver = get_receiver(broker, max_async_tasks=1)
314-
listen_task = asyncio.create_task(receiver.listen())
315-
316-
task = await task_map.kiq(list(range(0, 10)))
317-
with pytest.raises(TaskiqResultTimeoutError):
318-
await task.wait_result(timeout=1)
319-
320-
await broker.shutdown()
321-
await listen_task
322-
323-
324-
@pytest.mark.anyio
325-
async def test_tasks_chain_with_idler() -> None:
326-
""""""
327-
broker = InMemoryQueueBroker()
328-
329-
@broker.task
330-
async def task_add_one(val: int) -> int:
331-
return val + 1
332-
333-
@broker.task
334-
async def task_map(vals: List[int], ctx: Context = Depends()) -> List[int]:
335-
tasks = [await task_add_one.kiq(val) for val in vals]
336-
await ctx.sleep(0.5)
337-
resps_tasks = [asyncio.create_task(t.wait_result(timeout=1)) for t in tasks]
338-
resps = await asyncio.gather(*resps_tasks)
339-
res = [r.return_value for r in resps]
340-
return res
341-
342-
receiver = get_receiver(broker, max_async_tasks=1, max_sleeping_tasks=1)
343-
listen_task = asyncio.create_task(receiver.listen())
344-
345-
task = await task_map.kiq(list(range(0, 10)))
346-
resp = await task.wait_result(timeout=1)
347-
assert resp.return_value == list(range(1, 11))
348-
349-
await broker.shutdown()
350-
await listen_task
351-
352-
assert receiver.sem_sleeping._value == 1 # type: ignore
353-
assert receiver.sem._value == 1 # type: ignore
354-
355-
356-
@pytest.mark.anyio
357-
async def test_tasks_chain_deep() -> None:
358-
""""""
359-
broker = InMemoryQueueBroker()
360-
361-
@broker.task
362-
async def task_run(depth: int, val: Any, ctx: Context = Depends()) -> Any:
363-
if depth == 0:
364-
return val
365-
366-
t = await task_run.kiq(depth - 1, val)
367-
resp = await wait_for_task(t, interval=0.05, ctx=ctx)
368-
return resp.return_value
369-
370-
async def wait_for_task(
371-
task: AsyncTaskiqTask[Any],
372-
interval: float,
373-
ctx: Context,
374-
) -> TaskiqResult[Any]:
375-
while True:
376-
resp_task = asyncio.create_task(
377-
task.wait_result(interval * 0.4, timeout=interval),
378-
)
379-
await ctx.sleep(interval)
380-
381-
try:
382-
return await resp_task
383-
except TaskiqResultTimeoutError:
384-
continue
385-
386-
receiver = get_receiver(broker, max_async_tasks=1, max_sleeping_tasks=10)
387-
listen_task = asyncio.create_task(receiver.listen())
388-
389-
task = await task_run.kiq(10, "hello world!")
390-
resp = await task.wait_result(timeout=1)
391-
assert resp.return_value == "hello world!"
392-
393-
await broker.shutdown()
394-
await listen_task
395-
396-
assert receiver.sem_sleeping._value == 10 # type: ignore
397-
assert receiver.sem._value == 1 # type: ignore
398-
399-
400-
@pytest.mark.anyio
401-
async def test_tasks_sleep() -> None:
402-
""""""
403-
broker = InMemoryQueueBroker()
404-
405-
@broker.task
406-
async def task_run(ind: int, ctx: Context = Depends()) -> int:
407-
await ctx.sleep(0.1)
408-
return ind
409-
410-
receiver = get_receiver(broker, max_async_tasks=1, max_sleeping_tasks=20)
411-
listen_task = asyncio.create_task(receiver.listen())
412-
413-
with anyio.fail_after(1):
414-
tasks_tasks = [asyncio.create_task(task_run.kiq(ind)) for ind in range(100)]
415-
tasks = await asyncio.gather(*tasks_tasks)
416-
resps_tasks = [
417-
asyncio.create_task(task.wait_result(timeout=1)) for task in tasks
418-
]
419-
resps = await asyncio.gather(*resps_tasks)
420-
value = [resp.return_value for resp in resps]
421-
assert value == list(range(100))
422-
423-
await broker.shutdown()
424-
await listen_task
425-
426-
assert receiver.sem_sleeping._value == 20 # type: ignore
427-
assert receiver.sem._value == 1 # type: ignore
296+
class Test_sleeping_tasks:
297+
@pytest.mark.anyio
298+
async def test_max_sleeping_task_arg_error(self) -> None:
299+
with pytest.raises(ValueError):
300+
get_receiver(max_sleeping_tasks=-1)
301+
302+
@pytest.mark.anyio
303+
async def test_tasks_chain_without_nonblocking_sleep(self) -> None:
304+
""""""
305+
broker = InMemoryQueueBroker()
306+
307+
@broker.task
308+
async def task_add_one(val: int) -> int:
309+
return val + 1
310+
311+
@broker.task
312+
async def task_map(vals: List[int]) -> List[int]:
313+
tasks = [await task_add_one.kiq(val) for val in vals]
314+
resps_tasks = [asyncio.create_task(t.wait_result(timeout=1)) for t in tasks]
315+
resps = await asyncio.gather(*resps_tasks)
316+
317+
return [r.return_value for r in resps]
318+
319+
receiver = get_receiver(broker, max_async_tasks=1)
320+
listen_task = asyncio.create_task(receiver.listen())
321+
322+
task = await task_map.kiq(list(range(0, 10)))
323+
with pytest.raises(TaskiqResultTimeoutError):
324+
await task.wait_result(timeout=1)
325+
326+
await broker.shutdown()
327+
await listen_task
328+
329+
@pytest.mark.anyio
330+
async def test_tasks_chain_with_nonblocking_sleep(self) -> None:
331+
""""""
332+
broker = InMemoryQueueBroker()
333+
334+
@broker.task
335+
async def task_add_one(val: int) -> int:
336+
return val + 1
337+
338+
@broker.task
339+
async def task_map(vals: List[int], ctx: Context = Depends()) -> List[int]:
340+
tasks = [await task_add_one.kiq(val) for val in vals]
341+
await ctx.sleep(0.5)
342+
resps_tasks = [asyncio.create_task(t.wait_result(timeout=1)) for t in tasks]
343+
resps = await asyncio.gather(*resps_tasks)
344+
res = [r.return_value for r in resps]
345+
return res
346+
347+
receiver = get_receiver(broker, max_async_tasks=1, max_sleeping_tasks=1)
348+
listen_task = asyncio.create_task(receiver.listen())
349+
350+
task = await task_map.kiq(list(range(0, 10)))
351+
resp = await task.wait_result(timeout=1)
352+
assert resp.return_value == list(range(1, 11))
353+
354+
await broker.shutdown()
355+
await listen_task
356+
357+
assert receiver.sem_sleeping._value == 1 # type: ignore
358+
assert receiver.sem._value == 1 # type: ignore
359+
360+
@pytest.mark.anyio
361+
async def test_tasks_long_chain(self) -> None:
362+
""""""
363+
broker = InMemoryQueueBroker()
364+
365+
@broker.task
366+
async def task_run(depth: int, val: Any, ctx: Context = Depends()) -> Any:
367+
if depth == 0:
368+
return val
369+
370+
t = await task_run.kiq(depth - 1, val)
371+
resp = await wait_for_task(t, interval=0.05, ctx=ctx)
372+
return resp.return_value
373+
374+
async def wait_for_task(
375+
task: AsyncTaskiqTask[Any],
376+
interval: float,
377+
ctx: Context,
378+
) -> TaskiqResult[Any]:
379+
while True:
380+
resp_task = asyncio.create_task(
381+
task.wait_result(interval * 0.4, timeout=interval),
382+
)
383+
await ctx.sleep(interval)
384+
385+
try:
386+
return await resp_task
387+
except TaskiqResultTimeoutError:
388+
continue
389+
390+
receiver = get_receiver(broker, max_async_tasks=1, max_sleeping_tasks=10)
391+
listen_task = asyncio.create_task(receiver.listen())
392+
393+
task = await task_run.kiq(10, "hello world!")
394+
resp = await task.wait_result(timeout=1)
395+
assert resp.return_value == "hello world!"
396+
397+
await broker.shutdown()
398+
await listen_task
399+
400+
assert receiver.sem_sleeping._value == 10 # type: ignore
401+
assert receiver.sem._value == 1 # type: ignore
402+
403+
@pytest.mark.parametrize(
404+
("max_async_tasks", "max_sleeping_tasks"),
405+
[(1, 20), (None, None), (None, 20), (0, None), (0, 20), (0, 0)],
406+
)
407+
@pytest.mark.anyio
408+
async def test_tasks_sleep(
409+
self,
410+
max_async_tasks: Any,
411+
max_sleeping_tasks: Any,
412+
) -> None:
413+
""""""
414+
broker = InMemoryQueueBroker()
415+
416+
@broker.task
417+
async def task_run(ind: int, ctx: Context = Depends()) -> int:
418+
await ctx.sleep(0.1)
419+
return ind
420+
421+
receiver = get_receiver(
422+
broker,
423+
max_async_tasks=max_async_tasks,
424+
max_sleeping_tasks=max_sleeping_tasks,
425+
)
426+
listen_task = asyncio.create_task(receiver.listen())
427+
428+
with anyio.fail_after(1):
429+
tasks_tasks = [asyncio.create_task(task_run.kiq(ind)) for ind in range(100)]
430+
tasks = await asyncio.gather(*tasks_tasks)
431+
resps_tasks = [
432+
asyncio.create_task(task.wait_result(timeout=1)) for task in tasks
433+
]
434+
resps = await asyncio.gather(*resps_tasks)
435+
value = [resp.return_value for resp in resps]
436+
assert value == list(range(100))
437+
438+
await broker.shutdown()
439+
await listen_task
440+
441+
if max_sleeping_tasks is not None and max_sleeping_tasks > 0:
442+
assert receiver.sem_sleeping._value == max_sleeping_tasks # type: ignore
443+
444+
if max_async_tasks is not None and max_async_tasks > 0:
445+
assert receiver.sem._value == 1 # type: ignore
446+
447+
@pytest.mark.anyio
448+
async def test_max_sleeping_task_arg_none(self) -> None:
449+
""""""
450+
broker = InMemoryQueueBroker()
451+
452+
@broker.task
453+
async def task_run(ind: int, ctx: Context = Depends()) -> int:
454+
await ctx.sleep(0.1)
455+
return ind
456+
457+
receiver = get_receiver(broker, max_async_tasks=1, max_sleeping_tasks=None)
458+
listen_task = asyncio.create_task(receiver.listen()) # type: ignore
459+
460+
with pytest.raises(TaskiqResultTimeoutError):
461+
tasks_tasks = [asyncio.create_task(task_run.kiq(ind)) for ind in range(100)]
462+
tasks = await asyncio.gather(*tasks_tasks)
463+
resps_tasks = [
464+
asyncio.create_task(task.wait_result(timeout=1)) for task in tasks
465+
]
466+
await asyncio.gather(*resps_tasks)
428467

429468

430469
@pytest.mark.anyio

tests/test_semaphore.py

+23
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import time
23

34
import anyio
45
import pytest
@@ -36,3 +37,25 @@ async def c3() -> None:
3637

3738
with anyio.fail_after(1):
3839
await asyncio.gather(t1, t2, t3, return_exceptions=True)
40+
41+
42+
@pytest.mark.anyio
43+
async def test_semaphore_with() -> None:
44+
sem = PrioritySemaphore(1)
45+
46+
async def task() -> float:
47+
t = time.time()
48+
async with sem:
49+
await asyncio.sleep(0.1)
50+
return time.time() - t
51+
52+
tasks = [task() for _ in range(10)]
53+
with anyio.fail_after(2):
54+
times = await asyncio.gather(*tasks)
55+
56+
times = list(sorted(times))
57+
assert len(times) == 10
58+
assert 0.1 <= min(times) < 0.2
59+
assert 1 < max(times)
60+
for prev, next in zip(times, times[1:]):
61+
assert next - prev >= 0.1

0 commit comments

Comments
 (0)