@@ -162,8 +162,11 @@ class StackTraceResult(NamedTuple):
162
162
163
163
164
164
class InvalidThreadIdError (Exception ):
165
- def __init__ (self , thread_id : Any ) -> None :
166
- super ().__init__ (f"Invalid thread id { thread_id } " )
165
+ def __init__ (self , current_thread_id : Any , expected_thread_id : Any = None ) -> None :
166
+ super ().__init__ (
167
+ f"Invalid thread id { current_thread_id } "
168
+ + (f", expected { expected_thread_id } " if expected_thread_id is not None else "" )
169
+ )
167
170
168
171
169
172
class MarkerObject :
@@ -429,44 +432,45 @@ def stop(self) -> None:
429
432
with self .condition :
430
433
self .state = State .Stopped
431
434
432
- if self .main_thread is not None and self . main_thread . ident :
435
+ if self .main_thread_is_alive :
433
436
self .send_event (
434
437
self ,
435
438
ContinuedEvent (
436
439
body = ContinuedEventBody (
437
- thread_id = self .main_thread . ident ,
440
+ thread_id = self .main_thread_id ,
438
441
all_threads_continued = True ,
439
442
)
440
443
),
441
444
)
442
445
443
446
self .condition .notify_all ()
444
447
448
+ def check_thread_id (self , thread_id : int ) -> None :
449
+ if not self .main_thread_is_alive and thread_id != self .main_thread_id :
450
+ raise InvalidThreadIdError (thread_id , self .main_thread_id )
451
+
445
452
def continue_all (self ) -> None :
446
- if self .main_thread is not None and self . main_thread . ident is not None :
447
- self .continue_thread (self .main_thread . ident )
453
+ if self .main_thread_is_alive :
454
+ self .continue_thread (self .main_thread_id )
448
455
449
456
def continue_thread (self , thread_id : int ) -> None :
450
- if self .main_thread is None or thread_id != self .main_thread .ident :
451
- raise InvalidThreadIdError (thread_id )
457
+ self .check_thread_id (thread_id )
452
458
453
459
with self .condition :
454
460
self .requested_state = RequestedState .Running
455
461
self .condition .notify_all ()
456
462
457
463
def pause_thread (self , thread_id : int ) -> None :
458
- # thread_id 0 means all threads
459
- if self .main_thread is None or (thread_id != 0 and thread_id != self .main_thread .ident ):
460
- raise InvalidThreadIdError (thread_id )
464
+ if thread_id != 0 :
465
+ self .check_thread_id (thread_id )
461
466
462
467
with self .condition :
463
468
self .requested_state = RequestedState .Pause
464
469
465
470
self .condition .notify_all ()
466
471
467
472
def next (self , thread_id : int , granularity : Optional [SteppingGranularity ] = None ) -> None :
468
- if self .main_thread is None or thread_id != self .main_thread .ident :
469
- raise InvalidThreadIdError (thread_id )
473
+ self .check_thread_id (thread_id )
470
474
471
475
with self .condition :
472
476
if self .full_stack_frames and self .full_stack_frames [0 ].type in [
@@ -500,17 +504,15 @@ def step_in(
500
504
target_id : Optional [int ] = None ,
501
505
granularity : Optional [SteppingGranularity ] = None ,
502
506
) -> None :
503
- if self .main_thread is None or thread_id != self .main_thread .ident :
504
- raise InvalidThreadIdError (thread_id )
507
+ self .check_thread_id (thread_id )
505
508
506
509
with self .condition :
507
510
self .requested_state = RequestedState .StepIn
508
511
509
512
self .condition .notify_all ()
510
513
511
514
def step_out (self , thread_id : int , granularity : Optional [SteppingGranularity ] = None ) -> None :
512
- if self .main_thread is None or thread_id != self .main_thread .ident :
513
- raise InvalidThreadIdError (thread_id )
515
+ self .check_thread_id (thread_id )
514
516
515
517
with self .condition :
516
518
self .requested_state = RequestedState .StepOut
@@ -589,7 +591,7 @@ def process_start_state(self, source: str, line_no: int, type: str, status: str)
589
591
body = StoppedEventBody (
590
592
description = "Paused" ,
591
593
reason = StoppedReason .PAUSE ,
592
- thread_id = threading . current_thread (). ident ,
594
+ thread_id = self . main_thread_id ,
593
595
)
594
596
),
595
597
)
@@ -605,7 +607,7 @@ def process_start_state(self, source: str, line_no: int, type: str, status: str)
605
607
body = StoppedEventBody (
606
608
description = "Next step" ,
607
609
reason = StoppedReason .STEP ,
608
- thread_id = threading . current_thread (). ident ,
610
+ thread_id = self . main_thread_id ,
609
611
)
610
612
),
611
613
)
@@ -620,7 +622,7 @@ def process_start_state(self, source: str, line_no: int, type: str, status: str)
620
622
body = StoppedEventBody (
621
623
description = "Step in" ,
622
624
reason = StoppedReason .STEP ,
623
- thread_id = threading . current_thread (). ident ,
625
+ thread_id = self . main_thread_id ,
624
626
)
625
627
),
626
628
)
@@ -635,7 +637,7 @@ def process_start_state(self, source: str, line_no: int, type: str, status: str)
635
637
body = StoppedEventBody (
636
638
description = "Step out" ,
637
639
reason = StoppedReason .STEP ,
638
- thread_id = threading . current_thread (). ident ,
640
+ thread_id = self . main_thread_id ,
639
641
)
640
642
),
641
643
)
@@ -708,7 +710,7 @@ def process_start_state(self, source: str, line_no: int, type: str, status: str)
708
710
body = StoppedEventBody (
709
711
description = "Breakpoint hit" ,
710
712
reason = StoppedReason .BREAKPOINT ,
711
- thread_id = threading . current_thread (). ident ,
713
+ thread_id = self . main_thread_id ,
712
714
hit_breakpoint_ids = [breakpoint_id_manager .get_id (v ) for v in breakpoints ],
713
715
)
714
716
),
@@ -747,7 +749,7 @@ def process_end_state(
747
749
StoppedEvent (
748
750
body = StoppedEventBody (
749
751
reason = reason ,
750
- thread_id = threading . current_thread (). ident ,
752
+ thread_id = self . main_thread_id ,
751
753
description = description ,
752
754
text = text ,
753
755
)
@@ -782,12 +784,12 @@ def wait_for_running(self) -> None:
782
784
if self .requested_state == RequestedState .Running :
783
785
self .requested_state = RequestedState .Nothing
784
786
self .state = State .Running
785
- if self .main_thread is not None and self . main_thread . ident is not None :
787
+ if self .main_thread_is_alive :
786
788
self .send_event (
787
789
self ,
788
790
ContinuedEvent (
789
791
body = ContinuedEventBody (
790
- thread_id = self .main_thread . ident ,
792
+ thread_id = self .main_thread_id ,
791
793
all_threads_continued = True ,
792
794
)
793
795
),
@@ -947,7 +949,7 @@ def start_suite(self, name: str, attributes: Dict[str, Any]) -> None:
947
949
StoppedEvent (
948
950
body = StoppedEventBody (
949
951
reason = StoppedReason .ENTRY ,
950
- thread_id = threading . current_thread (). ident ,
952
+ thread_id = self . main_thread_id ,
951
953
)
952
954
),
953
955
)
@@ -1221,15 +1223,25 @@ def end_keyword(self, name: str, attributes: Dict[str, Any]) -> None:
1221
1223
def set_main_thread (self , thread : threading .Thread ) -> None :
1222
1224
self .main_thread = thread
1223
1225
1224
- def get_threads (self ) -> List [Thread ]:
1225
- main_thread = self .main_thread or threading .main_thread ()
1226
+ @property
1227
+ def main_thread_id (self ) -> int :
1228
+ return 1 if self .main_thread_is_alive else 0
1226
1229
1227
- return [
1228
- Thread (
1229
- id = main_thread .ident if main_thread .ident else 0 ,
1230
- name = main_thread .name or "" ,
1231
- )
1232
- ]
1230
+ @property
1231
+ def main_thread_is_alive (self ) -> bool :
1232
+ return self .main_thread is not None and self .main_thread .is_alive ()
1233
+
1234
+ def get_threads (self ) -> List [Thread ]:
1235
+ return (
1236
+ [
1237
+ Thread (
1238
+ id = self .main_thread_id ,
1239
+ name = "RobotMain" ,
1240
+ )
1241
+ ]
1242
+ if self .main_thread_is_alive
1243
+ else []
1244
+ )
1233
1245
1234
1246
WINDOW_PATH_REGEX : ClassVar = re .compile (r"^(([a-z]:[\\/])|(\\\\)).*$" , re .RegexFlag .IGNORECASE )
1235
1247
@@ -1281,8 +1293,7 @@ def get_stack_trace(
1281
1293
levels : Optional [int ] = None ,
1282
1294
format : Optional [StackFrameFormat ] = None ,
1283
1295
) -> StackTraceResult :
1284
- if self .main_thread is None or thread_id != self .main_thread .ident :
1285
- raise InvalidThreadIdError (thread_id )
1296
+ self .check_thread_id (thread_id )
1286
1297
1287
1298
start_frame = start_frame or 0
1288
1299
levels = start_frame + (levels or len (self .stack_frames ))
0 commit comments