torch.nn
torch.nn
是专门为神经网络设计的模块化接口。nn构建于 Autograd之上,可用来定义和运行神经网络。 这里我们主要介绍几个一些常用的类torch.nn
只支持小批量输入,不支持单个样本,如nn.Conv2d
接受一个四维的张量,每一维分别是Samples * Channels * height * width
,即样本数*通道数*高*宽
。如果单个样本,需要用input.unsqueeze(0)添加其他维数
保存模型参数
torch.save(model.state_dict(), ‘\parameter.pkl’)
model = TheModelClass(…)
model.load_state_dict(torch.load(‘\parameter.pkl’))
保存整个模型
torch.save(model, ‘\model.pkl’)
model = torch.load(‘\model.pkl’)
其他常用函数
- torch.squeeze()和 torch.unsqueeze()
squeeze()取掉维数为1的维度,unsqueeze(N)在指定位置增加一个维数为1的维度 - torch.cat((input1, input2), n)在第n维将input1和input2级联起来
设置gpu
- CUDA_VISIBLE_DEVICES=1 python test.py
- import os
os.environ[“CUDA_VISIBLE_DEVICES”] = “2” - import torch
torch.cuda.set_device(id)