Skip to content

Commit 6a5d00f

Browse files
committed
implement threaded rpc_methods
1 parent ad3b6a7 commit 6a5d00f

File tree

3 files changed

+95
-54
lines changed

3 files changed

+95
-54
lines changed

robotcode/jsonrpc2/protocol.py

+25-16
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from typing import (
1313
Any,
1414
Callable,
15-
Coroutine,
1615
Dict,
1716
Generic,
1817
Iterator,
@@ -477,23 +476,27 @@ def _handle_body(self, body: bytes, charset: str) -> None:
477476

478477
def _handle_messages(self, iterator: Iterator[JsonRPCMessage]) -> None:
479478
def done(f: asyncio.Future[Any]) -> None:
480-
ex = f.exception()
481-
if ex is not None and not isinstance(ex, asyncio.CancelledError):
482-
self._logger.exception(ex, exc_info=ex)
479+
if f.done() and not f.cancelled():
480+
ex = f.exception()
481+
482+
if ex is None or isinstance(ex, asyncio.CancelledError):
483+
return
484+
485+
# self._logger.exception(ex, exc_info=ex)
483486

484487
for m in iterator:
485488
create_sub_task(self.handle_message(m)).add_done_callback(done)
486489

487490
@_logger.call
488491
async def handle_message(self, message: JsonRPCMessage) -> None:
489492
if isinstance(message, JsonRPCRequest):
490-
self.handle_request(message)
493+
await self.handle_request(message)
491494
elif isinstance(message, JsonRPCNotification):
492-
self.handle_notification(message)
495+
await self.handle_notification(message)
493496
elif isinstance(message, JsonRPCError):
494-
self.handle_error(message)
497+
await self.handle_error(message)
495498
elif isinstance(message, JsonRPCResponse):
496-
self.handle_response(message)
499+
await self.handle_response(message)
497500

498501
@_logger.call
499502
def send_response(self, id: Optional[Union[str, int, None]], result: Optional[Any] = None) -> None:
@@ -570,7 +573,7 @@ def send_notification(self, method: str, params: Any) -> None:
570573
self.send_message(JsonRPCNotification(method=method, params=params))
571574

572575
@_logger.call(exception=True)
573-
def handle_response(self, message: JsonRPCResponse) -> None:
576+
async def handle_response(self, message: JsonRPCResponse) -> None:
574577
if message.id is None:
575578
error = "Invalid response. Response id is null."
576579
self._logger.warning(error)
@@ -614,7 +617,7 @@ def s(f: asyncio.Future[Any], r: Any) -> None:
614617
entry.future._loop.call_soon_threadsafe(entry.future.set_exception, e)
615618

616619
@_logger.call
617-
def handle_error(self, message: JsonRPCError) -> None:
620+
async def handle_error(self, message: JsonRPCError) -> None:
618621
raise JsonRPCErrorException(message.error.code, message.error.message, message.error.data)
619622

620623
@staticmethod
@@ -674,7 +677,7 @@ def _convert_params(
674677
return args, kw_args
675678

676679
@_logger.call
677-
def handle_request(self, message: JsonRPCRequest) -> Optional[asyncio.Task[_T]]:
680+
async def handle_request(self, message: JsonRPCRequest) -> None:
678681
e = self.registry.get_entry(message.method)
679682

680683
if e is None or not callable(e.method):
@@ -716,7 +719,7 @@ def done(t: asyncio.Task[Any]) -> None:
716719

717720
task.add_done_callback(done)
718721

719-
return task
722+
await task
720723

721724
@_logger.call
722725
def cancel_request(self, id: Union[int, str, None]) -> None:
@@ -734,18 +737,24 @@ async def cancel_all_received_request(self) -> None:
734737
entry.future.cancel()
735738

736739
@_logger.call
737-
def handle_notification(self, message: JsonRPCNotification) -> None:
740+
async def handle_notification(self, message: JsonRPCNotification) -> None:
738741
e = self.registry.get_entry(message.method)
739742

740743
if e is None or not callable(e.method):
741744
self._logger.warning(f"Unknown method: {message.method}")
742745
return
743746
try:
744747
params = self._convert_params(e.method, e.param_type, message.params)
745-
result = e.method(*params[0], **params[1])
746-
if inspect.isawaitable(result):
747-
create_sub_task(cast(Coroutine[Any, Any, Any], result))
748748

749+
if isinstance(e.method, HasThreaded) and cast(HasThreaded, e.method).__threaded__:
750+
task = run_coroutine_in_thread(ensure_coroutine(e.method), *params[0], **params[1])
751+
else:
752+
task = create_sub_task(
753+
ensure_coroutine(e.method)(*params[0], **params[1]),
754+
name=message.method,
755+
)
756+
757+
await task
749758
except asyncio.CancelledError:
750759
pass
751760
except (SystemExit, KeyboardInterrupt):

robotcode/language_server/common/parts/diagnostics.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
async_tasking_event,
1515
async_tasking_event_iterator,
1616
create_sub_task,
17+
threaded,
1718
)
1819
from ....utils.logging import LoggingDescriptor
1920
from ....utils.uri import Uri
@@ -101,7 +102,7 @@ def extend_capabilities(self, capabilities: ServerCapabilities) -> None:
101102
):
102103
capabilities.diagnostic_provider = DiagnosticOptions(
103104
inter_file_dependencies=True,
104-
workspace_diagnostics=True,
105+
workspace_diagnostics=False,
105106
identifier=f"robotcodelsp_{uuid.uuid4()}",
106107
work_done_progress=True,
107108
)
@@ -190,11 +191,7 @@ async def __get_full_document_diagnostics(self, document: TextDocument) -> Relat
190191
diagnostics: List[Diagnostic] = []
191192

192193
async for result_any in self.collect(
193-
self,
194-
document,
195-
full=True,
196-
callback_filter=language_id_filter(document),
197-
return_exceptions=True,
194+
self, document, full=True, callback_filter=language_id_filter(document), return_exceptions=True
198195
):
199196
result = cast(DiagnosticsResult, result_any)
200197

@@ -207,6 +204,7 @@ async def __get_full_document_diagnostics(self, document: TextDocument) -> Relat
207204
return RelatedFullDocumentDiagnosticReport(items=diagnostics, result_id=str(uuid.uuid4()))
208205

209206
@rpc_method(name="textDocument/diagnostic", param_type=DocumentDiagnosticParams)
207+
@threaded()
210208
async def _text_document_diagnostic(
211209
self,
212210
text_document: TextDocumentIdentifier,
@@ -265,6 +263,7 @@ async def _get_diagnostics() -> Optional[RelatedFullDocumentDiagnosticReport]:
265263
self._logger.debug(lambda: f"textDocument/diagnostic ready {text_document}")
266264

267265
@rpc_method(name="workspace/diagnostic", param_type=WorkspaceDiagnosticParams)
266+
@threaded()
268267
async def _workspace_diagnostic(
269268
self,
270269
identifier: Optional[str],

robotcode/utils/async_tools.py

+65-32
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import asyncio
44
import concurrent.futures
5+
import contextlib
56
import contextvars
67
import functools
78
import inspect
89
import threading
9-
import warnings
1010
import weakref
1111
from collections import deque
1212
from concurrent.futures.thread import ThreadPoolExecutor
@@ -523,6 +523,22 @@ def done(task: asyncio.Future[_T]) -> None:
523523
return result
524524

525525

526+
_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
527+
528+
529+
@contextlib.asynccontextmanager
530+
async def async_lock(lock: threading.RLock) -> AsyncGenerator[None, None]:
531+
locked = lock.acquire(False)
532+
while not locked:
533+
await asyncio.sleep(0)
534+
locked = lock.acquire(False)
535+
try:
536+
yield
537+
finally:
538+
if locked:
539+
lock.release()
540+
541+
526542
class Event:
527543
"""Thread safe version of an async Event"""
528544

@@ -547,39 +563,47 @@ def set(self) -> None:
547563
if not self._value:
548564
self._value = True
549565

550-
for fut in self._waiters:
566+
while self._waiters:
567+
fut = self._waiters.popleft()
568+
551569
if not fut.done():
552570
if fut._loop == asyncio.get_running_loop():
553571
if not fut.done():
554572
fut.set_result(True)
555573
else:
556574

557-
def s(w: asyncio.Future[Any]) -> None:
558-
if not w.done():
559-
w.set_result(True)
575+
def s(w: asyncio.Future[Any], ev: threading.Event) -> None:
576+
try:
577+
if not w.done():
578+
w.set_result(True)
579+
finally:
580+
ev.set()
560581

561582
if not fut.done():
562-
fut._loop.call_soon_threadsafe(s, fut)
583+
done = threading.Event()
584+
585+
fut._loop.call_soon_threadsafe(s, fut, done)
586+
587+
if not done.wait(120):
588+
raise RuntimeError("Callback timeout")
563589

564590
def clear(self) -> None:
565591
with self._lock:
566592
self._value = False
567593

568594
async def wait(self, timeout: Optional[float] = None) -> bool:
569-
if self._value:
570-
return True
571-
572-
fut = create_sub_future()
573595
with self._lock:
596+
if self._value:
597+
return True
598+
599+
fut = create_sub_future()
574600
self._waiters.append(fut)
601+
575602
try:
576603
await asyncio.wait_for(fut, timeout)
577604
return True
578605
except asyncio.TimeoutError:
579606
return False
580-
finally:
581-
with self._lock:
582-
self._waiters.remove(fut)
583607

584608

585609
class Semaphore:
@@ -600,8 +624,8 @@ def __repr__(self) -> str:
600624
extra = f"{extra}, waiters:{len(self._waiters)}"
601625
return f"<{res[1:-1]} [{extra}]>"
602626

603-
def _wake_up_next(self) -> None:
604-
with self._lock:
627+
async def _wake_up_next(self) -> None:
628+
async with async_lock(self._lock):
605629
while self._waiters:
606630
waiter = self._waiters.popleft()
607631

@@ -612,14 +636,23 @@ def _wake_up_next(self) -> None:
612636
else:
613637
if waiter._loop.is_running():
614638

615-
def s(w: asyncio.Future[Any]) -> None:
616-
if w._loop.is_running() and not w.done():
617-
w.set_result(True)
639+
def s(w: asyncio.Future[Any], ev: threading.Event) -> None:
640+
try:
641+
if w._loop.is_running() and not w.done():
642+
w.set_result(True)
643+
finally:
644+
ev.set()
618645

619646
if not waiter.done():
620-
waiter._loop.call_soon_threadsafe(s, waiter)
647+
done = threading.Event()
648+
649+
waiter._loop.call_soon_threadsafe(s, waiter, done)
650+
651+
if not done.wait(120):
652+
raise RuntimeError("Callback timeout")
653+
621654
else:
622-
warnings.warn("Loop is not running.")
655+
raise RuntimeError("Loop is not running.")
623656

624657
def locked(self) -> bool:
625658
with self._lock:
@@ -628,7 +661,7 @@ def locked(self) -> bool:
628661
async def acquire(self, timeout: Optional[float] = None) -> bool:
629662
while self._value <= 0:
630663
fut = create_sub_future()
631-
with self._lock:
664+
async with async_lock(self._lock):
632665
self._waiters.append(fut)
633666
try:
634667
await asyncio.wait_for(fut, timeout)
@@ -638,26 +671,26 @@ async def acquire(self, timeout: Optional[float] = None) -> bool:
638671
if not fut.done():
639672
fut.cancel()
640673
if self._value > 0 and not fut.cancelled():
641-
self._wake_up_next()
674+
await self._wake_up_next()
642675

643676
raise
644677

645-
with self._lock:
678+
async with async_lock(self._lock):
646679
self._value -= 1
647680

648681
return True
649682

650-
def release(self) -> None:
683+
async def release(self) -> None:
651684
self._value += 1
652-
self._wake_up_next()
685+
await self._wake_up_next()
653686

654687
async def __aenter__(self) -> None:
655688
await self.acquire()
656689

657690
async def __aexit__(
658691
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
659692
) -> None:
660-
self.release()
693+
await self.release()
661694

662695

663696
class BoundedSemaphore(Semaphore):
@@ -667,10 +700,10 @@ def __init__(self, value: int = 1) -> None:
667700
self._bound_value = value
668701
super().__init__(value)
669702

670-
def release(self) -> None:
703+
async def release(self) -> None:
671704
if self._value >= self._bound_value:
672705
raise ValueError("BoundedSemaphore released too many times")
673-
super().release()
706+
await super().release()
674707

675708

676709
class Lock:
@@ -683,8 +716,8 @@ def __repr__(self) -> str:
683716
async def acquire(self, timeout: Optional[float] = None) -> bool:
684717
return await self._block.acquire(timeout)
685718

686-
def release(self) -> None:
687-
self._block.release()
719+
async def release(self) -> None:
720+
await self._block.release()
688721

689722
@property
690723
def locked(self) -> bool:
@@ -696,7 +729,7 @@ async def __aenter__(self) -> None:
696729
async def __aexit__(
697730
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
698731
) -> None:
699-
self.release()
732+
await self.release()
700733

701734

702735
class OldLock:
@@ -772,7 +805,7 @@ async def release(self) -> None:
772805
if self._locked:
773806
self._locked = False
774807
else:
775-
warnings.warn(f"Lock is not acquired ({len(self._waiters) if self._waiters else 0} waiters).")
808+
raise RuntimeError(f"Lock is not acquired ({len(self._waiters) if self._waiters else 0} waiters).")
776809

777810
await self._wake_up_next()
778811

0 commit comments

Comments
 (0)