Skip to content

Commit 6f802c9

Browse files
TAO 5.2 Release - ODISE Visualization Updates
1 parent 99e0a38 commit 6f802c9

File tree

1 file changed

+90
-7
lines changed

1 file changed

+90
-7
lines changed

nvidia_tao_pytorch/cv/odise/scripts/inference.py

+90-7
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,14 @@
9191
)
9292

9393

94-
def draw_instance_predictions(self, predictions):
94+
def draw_instance_predictions(self, predictions, jitter_color=False):
9595
"""
9696
Draw instance-level prediction results on an image.
9797
Args:
9898
predictions (Instances): the output of an instance detection/segmentation
9999
model. Following fields will be used to draw:
100100
"pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
101+
jitter_color (bool): whether to jitter colors
101102
Returns:
102103
output (VisImage): image object with visualizations.
103104
"""
@@ -114,9 +115,14 @@ def draw_instance_predictions(self, predictions):
114115
else:
115116
masks = None
116117
if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
117-
colors = [
118-
self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
119-
]
118+
if jitter_color:
119+
colors = [
120+
self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
121+
]
122+
else:
123+
colors = [
124+
[x / 255 for x in self.metadata.thing_colors[c]] for c in classes
125+
]
120126
alpha = 0.8
121127
else:
122128
colors = None
@@ -177,6 +183,7 @@ def overlay_instances(
177183
assigned_colors (list[matplotlib.colors]): a list of colors, where each color
178184
corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
179185
for full list of formats that the colors are accepted in.
186+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
180187
Returns:
181188
output (VisImage): image object with visualizations.
182189
"""
@@ -224,8 +231,6 @@ def overlay_instances(
224231
keypoints = keypoints[sorted_idxs] if keypoints is not None else None
225232

226233
for i in range(num_instances):
227-
if 'grass' in labels[i]:
228-
continue
229234
color = assigned_colors[i]
230235
if boxes is not None:
231236
self.draw_box(boxes[i], edge_color=color)
@@ -286,8 +291,86 @@ def overlay_instances(
286291

287292
return self.output
288293

294+
295+
def draw_panoptic_seg(self,
296+
panoptic_seg,
297+
segments_info,
298+
area_threshold=None,
299+
alpha=0.7,
300+
jitter_color=False):
301+
"""
302+
Draw panoptic prediction annotations or results.
303+
304+
Args:
305+
panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
306+
segment.
307+
segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
308+
If it is a ``list[dict]``, each dict contains keys "id", "category_id".
309+
If None, category id of each pixel is computed by
310+
``pixel // metadata.label_divisor``.
311+
area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
312+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
313+
jitter_color (bool): whether to jitter colors
314+
315+
Returns:
316+
output (VisImage): image object with visualizations.
317+
"""
318+
pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
319+
320+
if self._instance_mode == ColorMode.IMAGE_BW:
321+
self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
322+
323+
# draw mask for all semantic segments first i.e. "stuff"
324+
for mask, sinfo in pred.semantic_masks():
325+
category_idx = sinfo["category_id"]
326+
try:
327+
mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
328+
except AttributeError:
329+
mask_color = None
330+
331+
text = self.metadata.stuff_classes[category_idx]
332+
self.draw_binary_mask(
333+
mask,
334+
color=mask_color,
335+
edge_color=(1.0, 1.0, 240.0 / 255), # off_white
336+
text=text,
337+
alpha=alpha,
338+
area_threshold=area_threshold,
339+
)
340+
341+
# draw mask for all instances second
342+
all_instances = list(pred.instance_masks())
343+
if len(all_instances) == 0:
344+
return self.output
345+
masks, sinfo = list(zip(*all_instances))
346+
category_ids = [x["category_id"] for x in sinfo]
347+
348+
try:
349+
scores = [x["score"] for x in sinfo]
350+
except KeyError:
351+
scores = None
352+
labels = _create_text_labels(
353+
category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo]
354+
)
355+
356+
try:
357+
if jitter_color:
358+
colors = [
359+
self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids
360+
]
361+
else:
362+
colors = [
363+
[x / 255 for x in self.metadata.thing_colors[c]] for c in category_ids
364+
]
365+
except AttributeError:
366+
colors = None
367+
self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)
368+
369+
return self.output
370+
289371
Visualizer.overlay_instances = overlay_instances
290372
Visualizer.draw_instance_predictions = draw_instance_predictions
373+
Visualizer.draw_panoptic_seg = draw_panoptic_seg
291374

292375

293376
def get_nouns(caption, with_preposition):
@@ -538,4 +621,4 @@ def run_inference(hydra_cfg: ExperimentConfig):
538621

539622

540623
if __name__ == "__main__":
541-
run_inference()
624+
run_inference()

0 commit comments

Comments
 (0)