成都市建设工程交易中心网站,wordpress登陆后跳转到首页,泉州安溪县住房和城乡建设网站,合肥在线网站文章目录 1.保存、加载模型2.torch.nn.Module.state_dict()2.1基本使用2.2保存和加载状态字典 3.创建Checkpoint3.1基本使用3.2完整案例 1.保存、加载模型 torch.save()用于保存一个序列化对象到磁盘上#xff0c;该序列化对象可以是任何类型的对象#xff0c;包括模型、张量… 文章目录 1.保存、加载模型2.torch.nn.Module.state_dict()2.1基本使用2.2保存和加载状态字典 3.创建Checkpoint3.1基本使用3.2完整案例 1.保存、加载模型 torch.save()用于保存一个序列化对象到磁盘上该序列化对象可以是任何类型的对象包括模型、张量和字典等内部使用pickle模块实现对象的序列化。数据会被保存为.pt、.pth格式可通过torch.load()从磁盘加载被保存的序列化对象加载时会重新构造出原来的对象。 torch.save()有两种保存模型的方式
1.保存整个模型继承了torch.nn.Module的类不推荐使用。 torch.load()利用pickle将保存的序列化对象反序列化得到原始数据。可用于加载完整模型或状态字典。
#保存整个模型
torch.save(model, PATH)
#加载模型
model torch.load(PATH)2.仅保存模型的参数状态字典state_dict推荐使用。 torch.nn.Module.load_state_dict()通过反序列化得到模型的state_dict()状态字典来加载模型传入的参数是状态字典而非.pt、.pth文件。
#只保存模型参数
torch.save(model.state_dict(), PATH)
#加载模型
modelModel()
model.load_state_dict(torch.load(PATH))在实际使用中推荐第二种方式第一种方式往往容易产生各种错误
设备错误。若在cuda:0上训练好一个模型并保存则读取出来的模型也是默认在cuda:0上如果训练过程的其他数据被放到了cuda:1上那么就会发生错误
RuntimeError: arguments are located on different GPUs at /opt/conda/conda-bld/pytorch_1503966894950/work/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:215此时需要将其他其他数据都保存在cuda:0上或加载模型时指定使用cuda:1
device torch.device(cuda:1)
model torch.load(PATH, map_locationdevice)版本错误比如使用pytorch1.0训练并保存CNN模型再用pytorch1.1读取模型则会出现错误
AttributeError: Conv2d object has no attribute padding_mode此时只能通过获取该模型的参数来加载新的模型
#加载模型参数
model_state torch.load(model_path).state_dict()
#初始化新模型并加载参数
model Model()
model.load_state_dict(model_state)2.torch.nn.Module.state_dict()
2.1基本使用 torch.nn.Module.state_dict()用于返回模型的状态字典其中保存了模型的可学习参数。其中只有可学习参数的层卷积层、全连接层等和注册缓冲区batchnorm’s running_mean才会作为模型参数保存优化器也有状态字典也可进行保存。 【例子】
import torch
import torch.nn as nn
import torch.optim as optim# 定义模型
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 nn.Conv2d(3, 6, 5)self.pool nn.MaxPool2d(2, 2)self.conv2 nn.Conv2d(6, 16, 5)self.fc1 nn.Linear(16 * 5 * 5, 120)self.fc2 nn.Linear(120, 84)self.fc3 nn.Linear(84, 10)def forward(self, x):x self.pool(F.relu(self.conv1(x)))x self.pool(F.relu(self.conv2(x)))x x.view(-1, 16 * 5 * 5)x F.relu(self.fc1(x))x F.relu(self.fc2(x))x self.fc3(x)return x# 初始化模型
model TheModelClass()# 初始化优化器
optimizer optim.SGD(model.parameters(), lr0.001, momentum0.9)# 打印模型的状态字典
print(Models state_dict:)
for param_tensor in model.state_dict():print(param_tensor, \t, model.state_dict()[param_tensor].size())# 打印优化器的状态字典
print(Optimizers state_dict:)
for var_name in optimizer.state_dict():print(var_name, \t, optimizer.state_dict()[var_name])查看模型与优化器的状态字典
2.2保存和加载状态字典 通过torch.save()来保存模型的状态字典state_dict即只保存学习到的模型参数并通过torch.nn.Module.load_state_dict()来加载并恢复模型参数。PyTorch中最常见的模型保存扩展名为.pt或.pth。
#保存模型状态字典
PATH ./test_state_dict.pth
torch.save(model.state_dict(), PATH)
#根据状态字典加载模型
model TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()
#打印新模型的状态字典
print(Models state_dict:)
for param_tensor in model.state_dict():print(param_tensor, \t, model.state_dict()[param_tensor].size())注意模型推理之前需要调用model.eval()函数将dropout和batch normalization层设置为评估模式否则会导致模型推理结果不一致。
3.创建Checkpoint
3.1基本使用 模型检查点checkpoint是指模型训练过程中保存的模型状态包括模型参数权重与偏置、优化器状态等其他相关的训练信息。通过保存检查点可以实现在训练过程中定期保存模型的当前状态以便在需要时恢复训练或用于模型评估和推理。模型检查点常见的保存信息如下
1.模型权重模型的状态字典。2.优化器状态优化器的状态字典。3.训练状态当前的训练轮数epoch、批次batch等。4.其他数据如学习率调度器的状态、自定义指标等。
例如 【保存检查点】
#将模型参数和优化器状态的状态字典保存到检查点中
checkpoint {model_state_dict: model.state_dict(),optimizer_state_dict: optimizer.state_dict(),loss: loss.item(),epoch:epoch
}#保存检查点
torch.save(checkpoint, checkpoint.pth)【加载检查点】
# 加载检查点
checkpoint torch.load(checkpoint.pth)# 恢复模型和优化器状态
model.load_state_dict(checkpoint[model_state_dict])
optimizer.load_state_dict(checkpoint[optimizer_state_dict])# 恢复训练状态
epoch checkpoint[epoch]
loss checkpoint[loss]# 如果是恢复训练可以从保存的epoch继续
for epoch in range(epoch, num_epochs):# 继续训练3.2完整案例
import torch
import torch.nn as nn
import torch.optim as optim# 假设有一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc nn.Linear(10, 1)def forward(self, x):return self.fc(x)model SimpleModel()
optimizer optim.SGD(model.parameters(), lr0.01)
loss_fn nn.MSELoss()# 训练循环
num_epochs 100
for epoch in range(num_epochs):# 假设有输入x和目标yx torch.randn(64, 10)y torch.randn(64, 1)optimizer.zero_grad()output model(x)loss loss_fn(output, y)loss.backward()optimizer.step()# 每10个epoch保存一次检查点if epoch % 10 0:checkpoint {model_state_dict: model.state_dict(),optimizer_state_dict: optimizer.state_dict(),epoch: epoch,loss: loss.item()}torch.save(checkpoint, fcheckpoint_epoch_{epoch}.pth)# 加载检查点并继续训练
checkpoint torch.load(checkpoint_epoch_10.pth)
model.load_state_dict(checkpoint[model_state_dict])
optimizer.load_state_dict(checkpoint[optimizer_state_dict])
start_epoch checkpoint[epoch]
loss checkpoint[loss]# 从第11个epoch开始继续训练
for epoch in range(start_epoch 1, num_epochs):# 继续训练pass