加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 大数据 > 正文

PyTorch

发布时间:2020-12-14 04:41:22 所属栏目:大数据 来源:网络整理
导读:1.Torch.nn class torch.nn.Conv2d(in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True) 卷积 //1是输入图像的channel//6是输出图像的channel// 5是卷积核大小 nn.Conv2d( 1,6,5) class torch. nn. Linear (in_feature

1.Torch.nn

class torch.nn.Conv2d(in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True)

卷积

//1是输入图像的channel
//6是输出图像的channel
//5是卷积核大小
nn.Conv2d(1,6,5)

class torch.nn.Linear(in_features,out_features,bias=True)

对输入数据做线性变换:y=Ax+b,全连接

2.torch.nn.functional

torch.nn.functional.max_pool2d(input,stride=None,padding=0,dilation=1,ceil_mode=False,return_indices=False)

import torch.nn.functional as F
//使用2*2的核进行maxpooling
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))

3.torch.optim

import torch.optim as optim # create your optimizer optimizer = optim.SGD(net.parameters(),lr = 0.01) # in your training loop: optimizer.zero_grad() # zero the gradient buffers output = net(input) loss = criterion(output,target) loss.backward() optimizer.step() # Does the update

torch.cat(inputs,dimension=0) → Tensor,通道合并

x = torch.randn(2,3)

0.5983 -0.0341 2.4918

1.5981 -0.5265 -0.8735

torch.cat((x,x,),0)
0.5983 -0.0341 2.4918

1.5981 -0.5265 -0.8735
0.5983 -0.0341 2.4918

1.5981 -0.5265 -0.8735

将多维展开到1维?

x = x.view(x.size(0),-1)

PyTorch数据增强方法

1.对图片进行一定比例缩放

torchvision.transforms.Resize()

  第一个参数是一个 tuple,图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值

import matplotlib.pyplot as plt
from torchvision import transforms as tfs
from PIL import Image
#pic.shape(640,1024,3)
pic = Image.open(bridge.jpg)
resize = tfs.Resize((300,300))
pic1 = resize(pic)

plt.subplot(1,2,1)
plt.imshow(pic)
plt.axis(off) # 不显示坐标轴

plt.subplot(1,2)
plt.imshow(pic1)
plt.axis(off) # 不显示坐标轴

plt.show()

2.对图片进行随机位置的截取

  随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop(),传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop()

传入的参数就是截取出的图片的长和宽,会在图片的中心进行截取

import matplotlib.pyplot as plt
from torchvision import transforms as tfs
from PIL import Image
#pic.shape(640,3)
pic = Image.open(bridge.jpg)
#随机位置截取,多截几次结果不一样
random_crop = tfs.RandomCrop((400,400))
pic1 = random_crop(pic)

#中心位置截取,多截几次结果一样
center_crop = tfs.CenterCrop((400,400))
pic2 = center_crop(pic)
plt.subplot(1,1)
plt.imshow(pic1)
plt.xlabel(RandomCrop)
plt.legend()

plt.subplot(1,2)
plt.imshow(pic2)
plt.xlabel(CenterCrop)
plt.legend()

plt.show()

3.对图片进行随机的水平和竖直翻转

随机翻转使用的是?torchvision.transforms.RandomHorizontalFlip()?和?torchvision.transforms.RandomVerticalFlip(),翻转概率是0.5,有可能不翻转

import matplotlib.pyplot as plt
from torchvision import transforms as tfs
from PIL import Image
#pic.shape(640,3)
pic = Image.open(bridge.jpg)
#随机垂直翻转
random_vertical_flip = tfs.RandomVerticalFlip()
pic1 = random_vertical_flip(pic)

#随机水平翻转
random_horizontal_flip = tfs.RandomHorizontalFlip()
pic2 = random_horizontal_flip(pic)
plt.subplot(1,1)
plt.imshow(pic1)
plt.xlabel(vertical)
plt.legend()

plt.subplot(1,2)
plt.imshow(pic2)
plt.xlabel(horizontal)
plt.legend()

plt.show()

4.对图片进行随机角度的旋转

  在 torchvision 中,使用?torchvision.transforms.RandomRotation()?来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转

import matplotlib.pyplot as plt
from torchvision import transforms as tfs
from PIL import Image
#pic.shape(640,3)
pic = Image.open(bridge.jpg)
#随机旋转10度
random_rotation = tfs.RandomRotation(10)
pic1 = random_rotation(pic)

plt.subplot(1,1)
plt.imshow(pic)
plt.xlabel(original)
plt.legend()

plt.subplot(1,2)
plt.imshow(pic1)
plt.xlabel(random_rotation)
plt.legend()

plt.show()

5.对图片进行亮度、对比度和颜色的随机变化

  在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色

参数:

brightness(亮度,float类型)——调整亮度的程度,从 [max(0,1-brightness),1+brightness] 中均匀选取。

contrast(对比度,float类型)——调整对比度的程度,从 [max(0,1-contrast),1+contrast] 中均匀选取。
saturation(饱和度,float类型)——调整饱和度的程度, [max(0,1-saturation),1+saturation] 中均匀选取。
hue(色相,float类型) —— 调整色相的程度,从 [-hue,hue] 等均匀选择,其中hue的大小为 [0,0.5]。

?

pic = Image.open(bridge.jpg)
#亮度
brightness = tfs.ColorJitter(brightness=0.5)
pic1 = brightness(pic)
#对比度
contrast = tfs.ColorJitter(contrast=0.4)
pic2 = contrast(pic)
#饱和度
saturation = tfs.ColorJitter(saturation=0.3)
pic3 = saturation(pic)
#饱和度
hue = tfs.ColorJitter(hue=0.4)
pic4 = hue(pic)

亮度

对比度

饱和度

色相

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读