加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 大数据 > 正文

pytorch实现yolov3(4) 非极大值抑制nms

发布时间:2020-12-14 04:42:02 所属栏目:大数据 来源:网络整理
导读:在上一篇里我们实现了forward函数.得到了prediction.此时预测出了特别多的box以及各种class probability,现在我们要从中过滤出我们最终的预测box. 理解了yolov3的输出的格式及每一个位置的含义,并不难理解源码.我在阅读源码的过程中主要的困难在于对pytorch

在上一篇里我们实现了forward函数.得到了prediction.此时预测出了特别多的box以及各种class probability,现在我们要从中过滤出我们最终的预测box.
理解了yolov3的输出的格式及每一个位置的含义,并不难理解源码.我在阅读源码的过程中主要的困难在于对pytorch不熟悉,所以在这篇文章里,关于其中涉及的一些pytorch中的函数的用法我都已经用加粗标示了并且给出了相应的链接,测试代码等.

obj score threshold

我们设置一个obj score thershold,超过这个值的才认为是有效的.

conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2)
    prediction = prediction*conf_mask

prediction是1*boxnum*boxattr
prediction[:,4]是1*boxnum 元素值为boxattr的index=4的那个值.

torch中的Tensor index和numpy是类似的,参看下列代码输出

import torch
x = torch.Tensor(1,3,10)    # Create an un-initialized Tensor of size 2x3
print(x)
print(x.shape)                  # Print out the Tensor

y = x[:,4]
print(y)
print(y.shape)

z = x[:,4:6]
print(z)
print(z.shape)

print((y>0.5).float().unsqueeze(2))

#### 输出如下
tensor([[[2.5226e-18,1.6898e-04,1.0413e-11,7.7198e-10,1.0549e-08,4.0516e-11,1.0681e-05,2.9575e-18,6.7333e+22,1.7591e+22],[1.7184e+25,4.3222e+27,6.1972e-04,7.2443e+22,1.7728e+28,7.0367e+22,5.9018e-10,2.6540e-09,1.2972e-11,5.3370e-08],[2.7001e-06,2.6801e-09,4.1292e-05,2.1511e+23,3.2770e-09,2.5125e-18,7.7052e+31,1.9447e+31,5.0207e+28,1.1492e-38]]])
torch.Size([1,10])
tensor([[1.0549e-08,3.2770e-09]])
torch.Size([1,3])
tensor([[[1.0549e-08,4.0516e-11],[1.7728e+28,7.0367e+22],[3.2770e-09,2.5125e-18]]])
torch.Size([1,2])

tensor([[[0.],[0.],[0.]]])

Squeeze and unsqueeze 降低维度,升高维度.

t = torch.ones(2,1,2,1) # Size 2x1x2x1
r = torch.squeeze(t)     # Size 2x2
r = torch.squeeze(t,1)  # Squeeze dimension 1: Size 2x2x1

# Un-squeeze a dimension
x = torch.Tensor([1,3])
r = torch.unsqueeze(x,0)       # Size: 1x3  表示在第0个维度添加1维
r = torch.unsqueeze(x,1)       # Size: 3x1  表示在第1个维度添加1维

这样prediction中objscore<threshold的已经变成了0.

nms

tensor.new() 创建一个和原有tensor的dtype一致的新tensor https://stackoverflow.com/questions/49263588/pytorch-beginner-tensor-new-method

#得到box坐标(top-left corner x,top-left corner y,right-bottom corner x,right-bottom corner y)
    box_corner = prediction.new(prediction.shape)
    box_corner[:,0] = (prediction[:,0] - prediction[:,2]/2)
    box_corner[:,1] = (prediction[:,1] - prediction[:,3]/2)
    box_corner[:,2] = (prediction[:,0] + prediction[:,2]/2) 
    box_corner[:,3] = (prediction[:,1] + prediction[:,3]/2)
    prediction[:,:4] = box_corner[:,:4]

原始的prediction中boxattr存放的是x,y,w,h,...,不方便我们处理,我们将其转换成(top-left corner x,right-bottom corner y)

接下来我们挨个处理每一张图片对应的feature map.

batch_size = prediction.size(0)
    write = False

    for ind in range(batch_size):
        #image_pred.shape=boxnum*boxattr
        image_pred = prediction[ind]          #image Tensor  box_num*box_attr
        #confidence threshholding 
        #NMS
        #返回每一行的最大值,及最大值所在的列.
        max_conf,max_conf_score = torch.max(image_pred[:,5:5+ num_classes],1)
        #升级成和image_pred同样的维度
        max_conf = max_conf.float().unsqueeze(1)
        max_conf_score = max_conf_score.float().unsqueeze(1)
        seq = (image_pred[:,:5],max_conf,max_conf_score)
        
        #沿着列的方向拼接. 现在image_pred变成boxnum*7
        image_pred = torch.cat(seq,1)

这里涉及到torch.max的用法,参见https://blog.csdn.net/Z_lbj/article/details/79766690
torch.max(input,dim,keepdim=False,out=None) -> (Tensor,LongTensor)
按维度dim 返回最大值.可以这么记忆,沿着第dim维度比较.torch.max(0)即沿着行的方向比较,即得到每列的最大值.
假设input是二维矩阵,即行*列,行是第0维,列是第一维.

  • torch.max(a,0) 返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)
  • torch.max(a,1) 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
c=torch.Tensor([[1,3],[6,5,4]])
print(c)
a,b=torch.max(c,1)
print(a)
print(b)

##输出如下:
tensor([[1.,2.,3.],[6.,5.,4.]])
tensor([3.,6.])
tensor([2,0])

torch.cat用法,参见https://pytorch.org/docs/stable/torch.html

torch.cat(tensors,dim=0,out=None) → Tensor
>>> x = torch.randn(2,3)
>>> x
tensor([[ 0.6580,-1.0969,-0.4614],[-0.1034,-0.5790,0.1497]])
>>> torch.cat((x,x,x),0)
tensor([[ 0.6580,0.1497],[ 0.6580,1)
tensor([[ 0.6580,-0.4614,0.6580,0.1497,-0.1034,0.1497]])

接下来我们只处理obj_score非0的数据(obj_score<obj_threshold转变为0)

non_zero_ind =  (torch.nonzero(image_pred[:,4]))
        try:
            image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7)
        except:
            continue

        #For PyTorch 0.4 compatibility
        #Since the above code with not raise exception for no detection 
        #as scalars are supported in PyTorch 0.4
        if image_pred_.shape[0] == 0:
            continue

ok,接下来我们对每一种class做nms.
首先取到我们有哪些类别

#Get the various classes detected in the image
        img_classes = unique(image_pred_[:,-1])  # -1 index holds the class index

然后依次对每一种类别做处理

for cls in img_classes:
            #perform NMS

        
            #get the detections with one particular class
            #取出当前class为当前class且class prob!=0的行
            cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1)
            class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
            image_pred_class = image_pred_[class_mask_ind].view(-1,7)
            
            #sort the detections such that the entry with the maximum objectness
            #confidence is at the top
            #按照obj score从高到低做排序
            conf_sort_index = torch.sort(image_pred_class[:,4],descending = True )[1]
            image_pred_class = image_pred_class[conf_sort_index]
            idx = image_pred_class.size(0)   #Number of detections
            
            for i in range(idx):
                #Get the IOUs of all boxes that come after the one we are looking at 
                #in the loop
                try:
                    #计算第i个和其后每一行的的iou
                    ious = bbox_iou(image_pred_class[i].unsqueeze(0),image_pred_class[i+1:])
                except ValueError:
                    break
            
                except IndexError:
                    break
            
                #Zero out all the detections that have IoU > treshhold
                #把与第i行iou>nms_conf的认为是同一个目标的box,将其转成0
                iou_mask = (ious < nms_conf).float().unsqueeze(1)
                image_pred_class[i+1:] *= iou_mask       
            
                #把iou>nms_conf的移除掉
                non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze()
                image_pred_class = image_pred_class[non_zero_ind].view(-1,7)
                
            batch_ind = image_pred_class.new(image_pred_class.size(0),1).fill_(ind)      #Repeat the batch_id for as many detections of the class cls in the image
            seq = batch_ind,image_pred_class

其中计算iou的代码如下,不多解释了.iou=交叠面积/总面积

def bbox_iou(box1,box2):
    """
    Returns the IoU of two bounding boxes 
    
    
    """
    #Get the coordinates of bounding boxes
    b1_x1,b1_y1,b1_x2,b1_y2 = box1[:,0],box1[:,1],2],3]
    b2_x1,b2_y1,b2_x2,b2_y2 = box2[:,box2[:,3]
    
    #get the corrdinates of the intersection rectangle
    inter_rect_x1 =  torch.max(b1_x1,b2_x1)
    inter_rect_y1 =  torch.max(b1_y1,b2_y1)
    inter_rect_x2 =  torch.min(b1_x2,b2_x2)
    inter_rect_y2 =  torch.min(b1_y2,b2_y2)
    
    #Intersection area
    inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1,min=0) * torch.clamp(inter_rect_y2 - inter_rect_y1 + 1,min=0)

    #Union Area
    b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1)
    
    iou = inter_area / (b1_area + b2_area - inter_area)
    
    return iou

关于nms可以看下https://blog.csdn.net/shuzfan/article/details/52711706

tensor index操作用法如下:

image_pred_ = torch.Tensor([[1,4,9],[5,6,7,8,9]])
#print(image_pred_[:,-1] == 9)
has_9 = (image_pred_[:,-1] == 9)
print(has_9)

###执行顺序是(image_pred_[:,-1] == 9).float().unsqueeze(1) 再做tensor乘法
cls_mask = image_pred_*(image_pred_[:,-1] == 9).float().unsqueeze(1)
print(cls_mask)
class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
image_pred_class = image_pred_[class_mask_ind]

输出如下:
tensor([1,dtype=torch.uint8)
tensor([[1.,3.,4.,9.],[5.,6.,7.,8.,9.]])

torch.sort用法如下:

d=torch.Tensor([[1,4]])
e=d[:,2]
print(e)
print(torch.sort(e))

输出
tensor([3.,4.])

torch.return_types.sort(
values=tensor([3.,4.]),indices=tensor([0,1]))

总结一下我们做nms的流程
每一个image,会预测出N个detetction信息,包括4+1+C(4个坐标信息,1个obj score以及C个class probability)

  • 首先过滤掉obj_score < confidence的行
  • 每一行只取class probability最高的作为预测出来的类别
  • 将所有的预测按照obj_score从大到小排序
  • 循环每一种类别,开始做nms
    • 比较第一个box与其后所有box的iou,删除iou>threshold的box,即剔除所有相似box
    • 比较下一个box与其后所有box的iou,删除所有与该box相似的box
    • 不断重复上述过程,直至不再有相似box
    • 至此,实现了当前处理的类别的多个box均是独一无二的box.

write_results最终的返回值是一个n*8的tensor,其中8是(batch_index,4个坐标,1个objscore,1个class prob,一个class index)

def write_results(prediction,confidence,num_classes,nms_conf = 0.4):
    print("prediction.shape=",prediction.shape)

    #将obj_score < confidence的行置为0
    conf_mask = (prediction[:,4] > confidence).float().unsqueeze(2)
    prediction = prediction*conf_mask

    #得到box坐标(top-left corner x,3]/2)
    #修改prediction第三个维度的前四列
    prediction[:,:4]

    batch_size = prediction.size(0)
    write = False

    for ind in range(batch_size):
        #image_pred.shape=boxnum*boxattr
        image_pred = prediction[ind]          #image Tensor
        #confidence threshholding 
        #NMS

        ##取出每一行的class score最大的一个
        max_conf_score,max_conf = torch.max(image_pred[:,1)
        max_conf = max_conf.float().unsqueeze(1)
        max_conf_score = max_conf_score.float().unsqueeze(1)
        seq = (image_pred[:,max_conf_score,max_conf)
        image_pred = torch.cat(seq,1) #现在变成7列,分别为左上角x,左上角y,右下角x,右下角y,obj score,最大probabilty,相应的class index
        print(image_pred.shape)

        non_zero_ind =  (torch.nonzero(image_pred[:,7)
        except:
            continue

        #For PyTorch 0.4 compatibility
        #Since the above code with not raise exception for no detection 
        #as scalars are supported in PyTorch 0.4
        if image_pred_.shape[0] == 0:
            continue 

        #Get the various classes detected in the image
        img_classes = unique(image_pred_[:,-1])  # -1 index holds the class index
        
        
        for cls in img_classes:
            #perform NMS

            #get the detections with one particular class
            #取出当前class为当前class且class prob!=0的行
            cls_mask = image_pred_*(image_pred_[:,image_pred_class
            
            if not write:
                output = torch.cat(seq,1)  #沿着列方向,shape 1*8
                write = True
            else:
                out = torch.cat(seq,1)
                output = torch.cat((output,out)) #沿着行方向 shape n*8

    try:
        return output
    except:
        return 0

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读