python常用代码
发布时间:2020-12-20 10:42:50 所属栏目:Python 来源:网络整理
导读:目录 常用代码片段及技巧 自动选择GPU和CPU 切换当前目录 打印模型参数 将tensor的列表转换为tensor 内存不够 debug tensor memory 常用代码片段及技巧 自动选择GPU和CPU device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# model and
目录
常用代码片段及技巧自动选择GPU和CPUdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # model and tensor to device vgg = models.vgg16().to(device) 切换当前目录import os try: os.chdir(os.path.join(os.getcwd(),'..')) print(os.getcwd()) except: pass 打印模型参数from torchsummary import summary # 1 means in_channels summary(model,(1,28,28)) 将tensor的列表转换为tensorx = torch.stack(tensor_list) 内存不够
debug tensor memory
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def debug_memory(): import collections,gc,resource,torch print('maxrss = {}'.format( resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)) tensors = collections.Counter((str(o.device),o.dtype,tuple(o.shape)) for o in gc.get_objects() if torch.is_tensor(o)) for line in sorted(tensors.items()): print('{}t{}'.format(*line)) # example import tensor x = torch.tensor(3,3) debug_memory() y = torch.tensor(3,3) debug_memory() z = [torch.randn(i).long() for i in range(10)] debug_memory() (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |