PyTorch
1.Torch.nnclass torch.nn.Conv2d(in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True)卷积 //1是输入图像的channel //6是输出图像的channel //5是卷积核大小 nn.Conv2d(1,6,5) class torch.nn.Linear(in_features,out_features,bias=True)对输入数据做线性变换:y=Ax+b,全连接 2.torch.nn.functionaltorch.nn.functional.max_pool2d(input,stride=None,padding=0,dilation=1,ceil_mode=False,return_indices=False)import torch.nn.functional as F //使用2*2的核进行maxpooling x = F.max_pool2d(F.relu(self.conv1(x)),(2,2)) 3.torch.optimimport torch.optim as optim # create your optimizer optimizer = optim.SGD(net.parameters(),lr = 0.01) # in your training loop: optimizer.zero_grad() # zero the gradient buffers output = net(input) loss = criterion(output,target) loss.backward() optimizer.step() # Does the update
torch.cat(inputs,dimension=0) → Tensor,通道合并x = torch.randn(2,3) 0.5983 -0.0341 2.4918 1.5981 -0.5265 -0.8735 torch.cat((x,x,),0) 0.5983 -0.0341 2.4918 1.5981 -0.5265 -0.8735 0.5983 -0.0341 2.4918 1.5981 -0.5265 -0.8735 将多维展开到1维? x = x.view(x.size(0),-1) PyTorch数据增强方法1.对图片进行一定比例缩放 torchvision.transforms.Resize()第一个参数是一个 tuple,图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值 import matplotlib.pyplot as plt from torchvision import transforms as tfs from PIL import Image #pic.shape(640,1024,3) pic = Image.open(‘bridge.jpg‘) resize = tfs.Resize((300,300)) pic1 = resize(pic) plt.subplot(1,2,1) plt.imshow(pic) plt.axis(‘off‘) # 不显示坐标轴 plt.subplot(1,2) plt.imshow(pic1) plt.axis(‘off‘) # 不显示坐标轴 plt.show() 2.对图片进行随机位置的截取 随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop(),传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop() 传入的参数就是截取出的图片的长和宽,会在图片的中心进行截取 import matplotlib.pyplot as plt from torchvision import transforms as tfs from PIL import Image #pic.shape(640,3) pic = Image.open(‘bridge.jpg‘) #随机位置截取,多截几次结果不一样 random_crop = tfs.RandomCrop((400,400)) pic1 = random_crop(pic) #中心位置截取,多截几次结果一样 center_crop = tfs.CenterCrop((400,400)) pic2 = center_crop(pic) plt.subplot(1,1) plt.imshow(pic1) plt.xlabel(‘RandomCrop‘) plt.legend() plt.subplot(1,2) plt.imshow(pic2) plt.xlabel(‘CenterCrop‘) plt.legend() plt.show() 3.对图片进行随机的水平和竖直翻转 随机翻转使用的是? import matplotlib.pyplot as plt from torchvision import transforms as tfs from PIL import Image #pic.shape(640,3) pic = Image.open(‘bridge.jpg‘) #随机垂直翻转 random_vertical_flip = tfs.RandomVerticalFlip() pic1 = random_vertical_flip(pic) #随机水平翻转 random_horizontal_flip = tfs.RandomHorizontalFlip() pic2 = random_horizontal_flip(pic) plt.subplot(1,1) plt.imshow(pic1) plt.xlabel(‘vertical‘) plt.legend() plt.subplot(1,2) plt.imshow(pic2) plt.xlabel(‘horizontal‘) plt.legend() plt.show() 4.对图片进行随机角度的旋转 在 torchvision 中,使用? import matplotlib.pyplot as plt from torchvision import transforms as tfs from PIL import Image #pic.shape(640,3) pic = Image.open(‘bridge.jpg‘) #随机旋转10度 random_rotation = tfs.RandomRotation(10) pic1 = random_rotation(pic) plt.subplot(1,1) plt.imshow(pic) plt.xlabel(‘original‘) plt.legend() plt.subplot(1,2) plt.imshow(pic1) plt.xlabel(‘random_rotation‘) plt.legend() plt.show() 5.对图片进行亮度、对比度和颜色的随机变化 在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色 参数: brightness(亮度,float类型)——调整亮度的程度,从 [max(0,1-brightness),1+brightness] 中均匀选取。 contrast(对比度,float类型)——调整对比度的程度,从 [max(0,1-contrast),1+contrast] 中均匀选取。 ? pic = Image.open(‘bridge.jpg‘) #亮度 brightness = tfs.ColorJitter(brightness=0.5) pic1 = brightness(pic) #对比度 contrast = tfs.ColorJitter(contrast=0.4) pic2 = contrast(pic) #饱和度 saturation = tfs.ColorJitter(saturation=0.3) pic3 = saturation(pic) #饱和度 hue = tfs.ColorJitter(hue=0.4) pic4 = hue(pic) 亮度 对比度 饱和度
色相 (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |