91
91
)
92
92
93
93
94
- def draw_instance_predictions (self , predictions ):
94
+ def draw_instance_predictions (self , predictions , jitter_color = False ):
95
95
"""
96
96
Draw instance-level prediction results on an image.
97
97
Args:
98
98
predictions (Instances): the output of an instance detection/segmentation
99
99
model. Following fields will be used to draw:
100
100
"pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
101
+ jitter_color (bool): whether to jitter colors
101
102
Returns:
102
103
output (VisImage): image object with visualizations.
103
104
"""
@@ -114,9 +115,14 @@ def draw_instance_predictions(self, predictions):
114
115
else :
115
116
masks = None
116
117
if self ._instance_mode == ColorMode .SEGMENTATION and self .metadata .get ("thing_colors" ):
117
- colors = [
118
- self ._jitter ([x / 255 for x in self .metadata .thing_colors [c ]]) for c in classes
119
- ]
118
+ if jitter_color :
119
+ colors = [
120
+ self ._jitter ([x / 255 for x in self .metadata .thing_colors [c ]]) for c in classes
121
+ ]
122
+ else :
123
+ colors = [
124
+ [x / 255 for x in self .metadata .thing_colors [c ]] for c in classes
125
+ ]
120
126
alpha = 0.8
121
127
else :
122
128
colors = None
@@ -177,6 +183,7 @@ def overlay_instances(
177
183
assigned_colors (list[matplotlib.colors]): a list of colors, where each color
178
184
corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
179
185
for full list of formats that the colors are accepted in.
186
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
180
187
Returns:
181
188
output (VisImage): image object with visualizations.
182
189
"""
@@ -224,8 +231,6 @@ def overlay_instances(
224
231
keypoints = keypoints [sorted_idxs ] if keypoints is not None else None
225
232
226
233
for i in range (num_instances ):
227
- if 'grass' in labels [i ]:
228
- continue
229
234
color = assigned_colors [i ]
230
235
if boxes is not None :
231
236
self .draw_box (boxes [i ], edge_color = color )
@@ -286,8 +291,86 @@ def overlay_instances(
286
291
287
292
return self .output
288
293
294
+
295
+ def draw_panoptic_seg (self ,
296
+ panoptic_seg ,
297
+ segments_info ,
298
+ area_threshold = None ,
299
+ alpha = 0.7 ,
300
+ jitter_color = False ):
301
+ """
302
+ Draw panoptic prediction annotations or results.
303
+
304
+ Args:
305
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
306
+ segment.
307
+ segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
308
+ If it is a ``list[dict]``, each dict contains keys "id", "category_id".
309
+ If None, category id of each pixel is computed by
310
+ ``pixel // metadata.label_divisor``.
311
+ area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
312
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
313
+ jitter_color (bool): whether to jitter colors
314
+
315
+ Returns:
316
+ output (VisImage): image object with visualizations.
317
+ """
318
+ pred = _PanopticPrediction (panoptic_seg , segments_info , self .metadata )
319
+
320
+ if self ._instance_mode == ColorMode .IMAGE_BW :
321
+ self .output .reset_image (self ._create_grayscale_image (pred .non_empty_mask ()))
322
+
323
+ # draw mask for all semantic segments first i.e. "stuff"
324
+ for mask , sinfo in pred .semantic_masks ():
325
+ category_idx = sinfo ["category_id" ]
326
+ try :
327
+ mask_color = [x / 255 for x in self .metadata .stuff_colors [category_idx ]]
328
+ except AttributeError :
329
+ mask_color = None
330
+
331
+ text = self .metadata .stuff_classes [category_idx ]
332
+ self .draw_binary_mask (
333
+ mask ,
334
+ color = mask_color ,
335
+ edge_color = (1.0 , 1.0 , 240.0 / 255 ), # off_white
336
+ text = text ,
337
+ alpha = alpha ,
338
+ area_threshold = area_threshold ,
339
+ )
340
+
341
+ # draw mask for all instances second
342
+ all_instances = list (pred .instance_masks ())
343
+ if len (all_instances ) == 0 :
344
+ return self .output
345
+ masks , sinfo = list (zip (* all_instances ))
346
+ category_ids = [x ["category_id" ] for x in sinfo ]
347
+
348
+ try :
349
+ scores = [x ["score" ] for x in sinfo ]
350
+ except KeyError :
351
+ scores = None
352
+ labels = _create_text_labels (
353
+ category_ids , scores , self .metadata .thing_classes , [x .get ("iscrowd" , 0 ) for x in sinfo ]
354
+ )
355
+
356
+ try :
357
+ if jitter_color :
358
+ colors = [
359
+ self ._jitter ([x / 255 for x in self .metadata .thing_colors [c ]]) for c in category_ids
360
+ ]
361
+ else :
362
+ colors = [
363
+ [x / 255 for x in self .metadata .thing_colors [c ]] for c in category_ids
364
+ ]
365
+ except AttributeError :
366
+ colors = None
367
+ self .overlay_instances (masks = masks , labels = labels , assigned_colors = colors , alpha = alpha )
368
+
369
+ return self .output
370
+
289
371
Visualizer .overlay_instances = overlay_instances
290
372
Visualizer .draw_instance_predictions = draw_instance_predictions
373
+ Visualizer .draw_panoptic_seg = draw_panoptic_seg
291
374
292
375
293
376
def get_nouns (caption , with_preposition ):
@@ -538,4 +621,4 @@ def run_inference(hydra_cfg: ExperimentConfig):
538
621
539
622
540
623
if __name__ == "__main__" :
541
- run_inference ()
624
+ run_inference ()
0 commit comments