pytorch实现yolov3(4) 非极大值抑制nms
在上一篇里我们实现了forward函数.得到了prediction.此时预测出了特别多的box以及各种class probability,现在我们要从中过滤出我们最终的预测box. obj score threshold: 我们设置一个obj score threshold,超过这个值的才认为是有效的. conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2) prediction = prediction*conf_mask prediction是1*boxnum*boxattr torch中的Tensor index和numpy是类似的,参看下列代码输出 import torch x = torch.Tensor(1,3,10) # Create an un-initialized Tensor of size 1x3x10 print(x) print(x.shape) # Print out the Tensor y = x[:,:,4] print(y) print(y.shape) z = x[:,:,4:6] print(z) print(z.shape) print((y>0.5).float().unsqueeze(2)) #### 输出如下 tensor([[[2.5226e-18,1.6898e-04,1.0413e-11,7.7198e-10,1.0549e-08,4.0516e-11,1.0681e-05,2.9575e-18,6.7333e+22,1.7591e+22],[1.7184e+25,4.3222e+27,6.1972e-04,7.2443e+22,1.7728e+28,7.0367e+22,5.9018e-10,2.6540e-09,1.2972e-11,5.3370e-08],[2.7001e-06,2.6801e-09,4.1292e-05,2.1511e+23,3.2770e-09,2.5125e-18,7.7052e+31,1.9447e+31,5.0207e+28,1.1492e-38]]]) torch.Size([1,3,10]) tensor([[1.0549e-08,1.7728e+28,3.2770e-09]]) torch.Size([1,3]) tensor([[[1.0549e-08,4.0516e-11],[1.7728e+28,7.0367e+22],[3.2770e-09,2.5125e-18]]]) torch.Size([1,3,2]) tensor([[[0.],[1.],[0.]]]) Squeeze and unsqueeze 降低维度,升高维度. t = torch.ones(2,1,2,1) # Size 2x1x2x1 r = torch.squeeze(t) # Size 2x2 r = torch.squeeze(t,1) # Squeeze dimension 1: Size 2x2x1 # Un-squeeze a dimension x = torch.Tensor([1,2,3]) r = torch.unsqueeze(x,0) # Size: 1x3 表示在第0个维度添加1维 r = torch.unsqueeze(x,1) # Size: 3x1 表示在第1个维度添加1维 这样prediction中objscore<threshold的已经变成了0. 
nms tensor.new() 创建一个和原有tensor的dtype一致的新tensor https://stackoverflow.com/questions/49263588/pytorch-beginner-tensor-new-method #得到box坐标(top-left corner x,top-left corner y,right-bottom corner x,right-bottom corner y) box_corner = prediction.new(prediction.shape) box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2) box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2) box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2) box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2) prediction[:,:,:4] = box_corner[:,:,:4] 原始的prediction中boxattr存放的是x,y,w,h,...,不方便我们处理,我们将其转换成(top-left corner x,top-left corner y,right-bottom corner x,right-bottom corner y) 接下来我们挨个处理每一张图片对应的feature map. batch_size = prediction.size(0) write = False for ind in range(batch_size): #image_pred.shape=boxnum*boxattr image_pred = prediction[ind] #image Tensor box_num*box_attr #confidence thresholding #NMS #返回每一行的最大值,及最大值所在的列. max_conf,max_conf_score = torch.max(image_pred[:,5:5+ num_classes],1) #升级成和image_pred同样的维度 max_conf = max_conf.float().unsqueeze(1) max_conf_score = max_conf_score.float().unsqueeze(1) seq = (image_pred[:,:5],max_conf,max_conf_score) #沿着列的方向拼接. 现在image_pred变成boxnum*7 image_pred = torch.cat(seq,1) 这里涉及到torch.max的用法,参见https://blog.csdn.net/Z_lbj/article/details/79766690
c=torch.Tensor([[1,2,3],[6,5,4]]) print(c) a,b=torch.max(c,1) print(a) print(b) ##输出如下: tensor([[1.,2.,3.],[6.,5.,4.]]) tensor([3.,6.]) tensor([2,0]) torch.cat用法,参见https://pytorch.org/docs/stable/torch.html torch.cat(tensors,dim=0,out=None) → Tensor >>> x = torch.randn(2,3) >>> x tensor([[ 0.6580,-1.0969,-0.4614],[-0.1034,-0.5790,0.1497]]) >>> torch.cat((x,x,x),0) tensor([[ 0.6580,-1.0969,-0.4614],[-0.1034,-0.5790,0.1497],[ 0.6580,-1.0969,-0.4614],[-0.1034,-0.5790,0.1497],[ 0.6580,-1.0969,-0.4614],[-0.1034,-0.5790,0.1497]]) >>> torch.cat((x,x,x),1) tensor([[ 0.6580,-1.0969,-0.4614,0.6580,-1.0969,-0.4614,0.6580,-1.0969,-0.4614],[-0.1034,-0.5790,0.1497,-0.1034,-0.5790,0.1497,-0.1034,-0.5790,0.1497]]) 接下来我们只处理obj_score非0的数据(obj_score<obj_threshold转变为0) non_zero_ind = (torch.nonzero(image_pred[:,4])) try: image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7) except: continue #For PyTorch 0.4 compatibility #Since the above code will not raise exception for no detection #as scalars are supported in PyTorch 0.4 if image_pred_.shape[0] == 0: continue ok,接下来我们对每一种class做nms. #Get the various classes detected in the image img_classes = unique(image_pred_[:,-1]) # -1 index holds the class index 然后依次对每一种类别做处理 for cls in img_classes: #perform NMS #get the detections with one particular class #取出class index为当前class且class prob!=0的行 cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1) class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze() image_pred_class = image_pred_[class_mask_ind].view(-1,7) #sort the detections such that the entry with the maximum objectness #confidence is at the top #按照obj score从高到低做排序 conf_sort_index = torch.sort(image_pred_class[:,4],descending = True )[1] image_pred_class = image_pred_class[conf_sort_index] idx = image_pred_class.size(0) #Number of detections for i in range(idx): #Get the IOUs of all boxes that come after the one we are looking at #in the loop try: #计算第i个和其后每一行的iou ious = bbox_iou(image_pred_class[i].unsqueeze(0),image_pred_class[i+1:]) except ValueError: break except IndexError: break #Zero out all the detections that have IoU > threshold #把与第i行iou>nms_conf的认为是同一个目标的box,将其转成0 iou_mask = (ious < nms_conf).float().unsqueeze(1) image_pred_class[i+1:] *= 
iou_mask #把iou>nms_conf的移除掉 non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze() image_pred_class = image_pred_class[non_zero_ind].view(-1,7) batch_ind = image_pred_class.new(image_pred_class.size(0),1).fill_(ind) #Repeat the batch_id for as many detections of the class cls in the image seq = batch_ind,image_pred_class 其中计算iou的代码如下,不多解释了.iou=交集面积/并集面积 def bbox_iou(box1,box2): """ Returns the IoU of two bounding boxes """ #Get the coordinates of bounding boxes b1_x1,b1_y1,b1_x2,b1_y2 = box1[:,0],box1[:,1],box1[:,2],box1[:,3] b2_x1,b2_y1,b2_x2,b2_y2 = box2[:,0],box2[:,1],box2[:,2],box2[:,3] #get the coordinates of the intersection rectangle inter_rect_x1 = torch.max(b1_x1,b2_x1) inter_rect_y1 = torch.max(b1_y1,b2_y1) inter_rect_x2 = torch.min(b1_x2,b2_x2) inter_rect_y2 = torch.min(b1_y2,b2_y2) #Intersection area inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1,min=0) * torch.clamp(inter_rect_y2 - inter_rect_y1 + 1,min=0) #Union Area b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1) b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1) iou = inter_area / (b1_area + b2_area - inter_area) return iou 关于nms可以看下https://blog.csdn.net/shuzfan/article/details/52711706 tensor index操作用法如下: image_pred_ = torch.Tensor([[1,2,3,4,9],[5,6,7,8,9]]) #print(image_pred_[:,-1] == 9) has_9 = (image_pred_[:,-1] == 9) print(has_9) ###执行顺序是(image_pred_[:,-1] == 9).float().unsqueeze(1) 再做tensor乘法 cls_mask = image_pred_*(image_pred_[:,-1] == 9).float().unsqueeze(1) print(cls_mask) class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze() image_pred_class = image_pred_[class_mask_ind] 输出如下: tensor([1,1],dtype=torch.uint8) tensor([[1.,2.,3.,4.,9.],[5.,6.,7.,8.,9.]]) torch.sort用法如下: d=torch.Tensor([[1,2,3],[2,3,4]]) e=d[:,2] print(e) print(torch.sort(e)) 输出 tensor([3.,4.]) torch.return_types.sort( values=tensor([3.,4.]),indices=tensor([0,1])) 总结一下我们做nms的流程
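write_results最终的返回值是一个n*8的tensor,其中8是(batch_index,4个坐标,1个objscore,1个class prob,一个class index) def write_results(prediction,confidence,num_classes,nms_conf = 0.4): print("prediction.shape=",prediction.shape) #将obj_score < confidence的行置为0 conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2) prediction = prediction*conf_mask #得到box坐标(top-left corner x,top-left corner y,right-bottom corner x,right-bottom corner y) box_corner = prediction.new(prediction.shape) box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2) box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2) box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2) box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2) #修改prediction第三个维度的前四列 prediction[:,:,:4] = box_corner[:,:,:4] batch_size = prediction.size(0) write = False for ind in range(batch_size): #image_pred.shape=boxnum*boxattr image_pred = prediction[ind] #image Tensor #confidence thresholding #NMS ##取出每一行的class score最大的一个 max_conf_score,max_conf = torch.max(image_pred[:,5:5+ num_classes],1) max_conf = max_conf.float().unsqueeze(1) max_conf_score = max_conf_score.float().unsqueeze(1) seq = (image_pred[:,:5],max_conf_score,max_conf) image_pred = torch.cat(seq,1) #现在变成7列,分别为左上角x,左上角y,右下角x,右下角y,obj score,最大probability,相应的class index print(image_pred.shape) non_zero_ind = (torch.nonzero(image_pred[:,4])) try: image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7) except: continue #For PyTorch 0.4 compatibility #Since the above code will not raise exception for no detection #as scalars are supported in PyTorch 0.4 if image_pred_.shape[0] == 0: continue #Get the various classes detected in the image img_classes = unique(image_pred_[:,-1]) # -1 index holds the class index for cls in img_classes: #perform NMS #get the detections with one particular class #取出class index为当前class且class prob!=0的行 cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1) class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze() image_pred_class = image_pred_[class_mask_ind].view(-1,7) #sort the detections such that the entry with the maximum objectness #confidence is at the top #按照obj score从高到低做排序 conf_sort_index = torch.sort(image_pred_class[:,4],descending = True )[1] image_pred_class = image_pred_class[conf_sort_index] idx = image_pred_class.size(0) #Number of detections for i in range(idx): #Get the IOUs of all boxes that come after the one we are looking at #in the loop try: #计算第i个和其后每一行的iou ious = bbox_iou(image_pred_class[i].unsqueeze(0),image_pred_class[i+1:]) except ValueError: break except IndexError: break #Zero out all the detections that have IoU > threshold #把与第i行iou>nms_conf的认为是同一个目标的box,将其转成0 iou_mask = (ious < nms_conf).float().unsqueeze(1) image_pred_class[i+1:] *= iou_mask #把iou>nms_conf的移除掉 non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze() image_pred_class = image_pred_class[non_zero_ind].view(-1,7) batch_ind = image_pred_class.new(image_pred_class.size(0),1).fill_(ind) #Repeat the batch_id for as many detections of the class cls in the image seq = batch_ind,image_pred_class if not write: output = torch.cat(seq,1) #沿着列方向,shape 1*8 write = True else: out = torch.cat(seq,1) output = torch.cat((output,out)) #沿着行方向 shape n*8 try: return output except: return 0 (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时联系站长删除相关内容!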