diff --git a/torchao/_models/sam2/modeling/sam2_base.py b/torchao/_models/sam2/modeling/sam2_base.py index 4c2a24a0ef..5c4eda1d6c 100644 --- a/torchao/_models/sam2/modeling/sam2_base.py +++ b/torchao/_models/sam2/modeling/sam2_base.py @@ -788,9 +788,10 @@ def _track_step( if prev_sam_mask_logits is not None: assert point_inputs is not None and mask_inputs is None mask_inputs = prev_sam_mask_logits + else: + assert mask_inputs is None multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) - - assert mask_inputs is None + assert multimask_output if point_inputs is not None: point_inputs = {k: point_inputs[k].contiguous() for k in point_inputs}