Pytroch Tutorials

torch.nn

torch.nn是专门为神经网络设计的模块化接口。nn构建于 Autograd之上,可用来定义和运行神经网络。 这里我们主要介绍几个一些常用的类
torch.nn只支持小批量输入,不支持单个样本,如nn.Conv2d接受一个四维的张量,每一维分别是Samples * Channels * height * width即样本数*通道数*高*宽。如果单个样本,需要用input.unsqueeze(0)添加其他维数

保存模型参数

torch.save(model.state_dict(), ‘\parameter.pkl’)
model = TheModelClass(…)
model.load_state_dict(torch.load(‘\parameter.pkl’))

保存整个模型

torch.save(model, ‘\model.pkl’)
model = torch.load(‘\model.pkl’)

其他常用函数

  1. torch.squeeze()和 torch.unsqueeze()
    squeeze()取掉维数为1的维度,unsqueeze(N)在指定位置增加一个维数为1的维度
  2. torch.cat((input1, input2), n)在第n维将input1和input2级联起来

设置gpu

  1. CUDA_VISIBLE_DEVICES=1 python test.py
  2. import os
    os.environ[“CUDA_VISIBLE_DEVICES”] = “2”
  3. import torch
    torch.cuda.set_device(id)