diff --git a/layers/box_utils.py b/layers/box_utils.py index 84214947b..d71284c4e 100644 --- a/layers/box_utils.py +++ b/layers/box_utils.py @@ -22,8 +22,8 @@ def center_size(boxes): Return: boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. """ - return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy - boxes[:, 2:] - boxes[:, :2], 1) # w, h + return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy + boxes[:, 2:] - boxes[:, :2]), 1) # w, h def intersect(box_a, box_b): diff --git a/layers/functions/prior_box.py b/layers/functions/prior_box.py index 7848a390d..973030145 100644 --- a/layers/functions/prior_box.py +++ b/layers/functions/prior_box.py @@ -2,6 +2,7 @@ from math import sqrt as sqrt from itertools import product as product import torch +from layers.box_utils import point_form, center_size class PriorBox(object): @@ -51,5 +52,7 @@ def forward(self): # back to torch land output = torch.Tensor(mean).view(-1, 4) if self.clip: + output = point_form(output) output.clamp_(max=1, min=0) + output = center_size(output) return output