2
2
3
3
import asyncio
4
4
import concurrent .futures
5
+ import contextlib
5
6
import contextvars
6
7
import functools
7
8
import inspect
8
9
import threading
9
- import warnings
10
10
import weakref
11
11
from collections import deque
12
12
from concurrent .futures .thread import ThreadPoolExecutor
@@ -523,6 +523,22 @@ def done(task: asyncio.Future[_T]) -> None:
523
523
return result
524
524
525
525
526
+ _pool = concurrent .futures .ThreadPoolExecutor (max_workers = 1 )
527
+
528
+
529
+ @contextlib .asynccontextmanager
530
+ async def async_lock (lock : threading .RLock ) -> AsyncGenerator [None , None ]:
531
+ locked = lock .acquire (False )
532
+ while not locked :
533
+ await asyncio .sleep (0 )
534
+ locked = lock .acquire (False )
535
+ try :
536
+ yield
537
+ finally :
538
+ if locked :
539
+ lock .release ()
540
+
541
+
526
542
class Event :
527
543
"""Thread safe version of an async Event"""
528
544
@@ -547,39 +563,47 @@ def set(self) -> None:
547
563
if not self ._value :
548
564
self ._value = True
549
565
550
- for fut in self ._waiters :
566
+ while self ._waiters :
567
+ fut = self ._waiters .popleft ()
568
+
551
569
if not fut .done ():
552
570
if fut ._loop == asyncio .get_running_loop ():
553
571
if not fut .done ():
554
572
fut .set_result (True )
555
573
else :
556
574
557
- def s (w : asyncio .Future [Any ]) -> None :
558
- if not w .done ():
559
- w .set_result (True )
575
+ def s (w : asyncio .Future [Any ], ev : threading .Event ) -> None :
576
+ try :
577
+ if not w .done ():
578
+ w .set_result (True )
579
+ finally :
580
+ ev .set ()
560
581
561
582
if not fut .done ():
562
- fut ._loop .call_soon_threadsafe (s , fut )
583
+ done = threading .Event ()
584
+
585
+ fut ._loop .call_soon_threadsafe (s , fut , done )
586
+
587
+ if not done .wait (120 ):
588
+ raise RuntimeError ("Callback timeout" )
563
589
564
590
def clear (self ) -> None :
565
591
with self ._lock :
566
592
self ._value = False
567
593
568
594
async def wait (self , timeout : Optional [float ] = None ) -> bool :
569
- if self ._value :
570
- return True
571
-
572
- fut = create_sub_future ()
573
595
with self ._lock :
596
+ if self ._value :
597
+ return True
598
+
599
+ fut = create_sub_future ()
574
600
self ._waiters .append (fut )
601
+
575
602
try :
576
603
await asyncio .wait_for (fut , timeout )
577
604
return True
578
605
except asyncio .TimeoutError :
579
606
return False
580
- finally :
581
- with self ._lock :
582
- self ._waiters .remove (fut )
583
607
584
608
585
609
class Semaphore :
@@ -600,8 +624,8 @@ def __repr__(self) -> str:
600
624
extra = f"{ extra } , waiters:{ len (self ._waiters )} "
601
625
return f"<{ res [1 :- 1 ]} [{ extra } ]>"
602
626
603
- def _wake_up_next (self ) -> None :
604
- with self ._lock :
627
+ async def _wake_up_next (self ) -> None :
628
+ async with async_lock ( self ._lock ) :
605
629
while self ._waiters :
606
630
waiter = self ._waiters .popleft ()
607
631
@@ -612,14 +636,23 @@ def _wake_up_next(self) -> None:
612
636
else :
613
637
if waiter ._loop .is_running ():
614
638
615
- def s (w : asyncio .Future [Any ]) -> None :
616
- if w ._loop .is_running () and not w .done ():
617
- w .set_result (True )
639
+ def s (w : asyncio .Future [Any ], ev : threading .Event ) -> None :
640
+ try :
641
+ if w ._loop .is_running () and not w .done ():
642
+ w .set_result (True )
643
+ finally :
644
+ ev .set ()
618
645
619
646
if not waiter .done ():
620
- waiter ._loop .call_soon_threadsafe (s , waiter )
647
+ done = threading .Event ()
648
+
649
+ waiter ._loop .call_soon_threadsafe (s , waiter , done )
650
+
651
+ if not done .wait (120 ):
652
+ raise RuntimeError ("Callback timeout" )
653
+
621
654
else :
622
- warnings . warn ("Loop is not running." )
655
+ raise RuntimeError ("Loop is not running." )
623
656
624
657
def locked (self ) -> bool :
625
658
with self ._lock :
@@ -628,7 +661,7 @@ def locked(self) -> bool:
628
661
async def acquire (self , timeout : Optional [float ] = None ) -> bool :
629
662
while self ._value <= 0 :
630
663
fut = create_sub_future ()
631
- with self ._lock :
664
+ async with async_lock ( self ._lock ) :
632
665
self ._waiters .append (fut )
633
666
try :
634
667
await asyncio .wait_for (fut , timeout )
@@ -638,26 +671,26 @@ async def acquire(self, timeout: Optional[float] = None) -> bool:
638
671
if not fut .done ():
639
672
fut .cancel ()
640
673
if self ._value > 0 and not fut .cancelled ():
641
- self ._wake_up_next ()
674
+ await self ._wake_up_next ()
642
675
643
676
raise
644
677
645
- with self ._lock :
678
+ async with async_lock ( self ._lock ) :
646
679
self ._value -= 1
647
680
648
681
return True
649
682
650
- def release (self ) -> None :
683
+ async def release (self ) -> None :
651
684
self ._value += 1
652
- self ._wake_up_next ()
685
+ await self ._wake_up_next ()
653
686
654
687
async def __aenter__ (self ) -> None :
655
688
await self .acquire ()
656
689
657
690
async def __aexit__ (
658
691
self , exc_type : Optional [Type [BaseException ]], exc_val : Optional [BaseException ], exc_tb : Optional [TracebackType ]
659
692
) -> None :
660
- self .release ()
693
+ await self .release ()
661
694
662
695
663
696
class BoundedSemaphore (Semaphore ):
@@ -667,10 +700,10 @@ def __init__(self, value: int = 1) -> None:
667
700
self ._bound_value = value
668
701
super ().__init__ (value )
669
702
670
- def release (self ) -> None :
703
+ async def release (self ) -> None :
671
704
if self ._value >= self ._bound_value :
672
705
raise ValueError ("BoundedSemaphore released too many times" )
673
- super ().release ()
706
+ await super ().release ()
674
707
675
708
676
709
class Lock :
@@ -683,8 +716,8 @@ def __repr__(self) -> str:
683
716
async def acquire (self , timeout : Optional [float ] = None ) -> bool :
684
717
return await self ._block .acquire (timeout )
685
718
686
- def release (self ) -> None :
687
- self ._block .release ()
719
+ async def release (self ) -> None :
720
+ await self ._block .release ()
688
721
689
722
@property
690
723
def locked (self ) -> bool :
@@ -696,7 +729,7 @@ async def __aenter__(self) -> None:
696
729
async def __aexit__ (
697
730
self , exc_type : Optional [Type [BaseException ]], exc_val : Optional [BaseException ], exc_tb : Optional [TracebackType ]
698
731
) -> None :
699
- self .release ()
732
+ await self .release ()
700
733
701
734
702
735
class OldLock :
@@ -772,7 +805,7 @@ async def release(self) -> None:
772
805
if self ._locked :
773
806
self ._locked = False
774
807
else :
775
- warnings . warn (f"Lock is not acquired ({ len (self ._waiters ) if self ._waiters else 0 } waiters)." )
808
+ raise RuntimeError (f"Lock is not acquired ({ len (self ._waiters ) if self ._waiters else 0 } waiters)." )
776
809
777
810
await self ._wake_up_next ()
778
811
0 commit comments