有高并发 高访问量网站开发,阿里云网站空间主机,网站开发开题报告计划进度安排,汉中微信网站建设服务由于项目需要训练一个主干网络接多个分支的模型#xff0c;所以先训练一个主干网络加第一个分支#xff0c;再用另外的数据训练第二个分支#xff0c;训练的过程中需要冻结主干网络部分#xff0c;后面的分支训练过程也一样需要冻结主干网络部分。 冻结模型的方式
for nam… 由于项目需要训练一个主干网络接多个分支的模型所以先训练一个主干网络加第一个分支再用另外的数据训练第二个分支训练的过程中需要冻结主干网络部分后面的分支训练过程也一样需要冻结主干网络部分。 冻结模型的方式
for name, para in model.named_parameters():# 冻结backbone的权重if name.split(.)[0] backbone:para.requires_grad False # 或者用para.requires_grad_(False)一个是通过属性直接赋值一个是通过函数赋值else:para.requires_grad True# 可以打印需要更新梯度的参数
for name, value in model.named_parameters():print(name, \t更新梯度,value.requires_grad)坑1这样做并不能冻结batchnorm层的参数所以需要在训练中手动冻结。如
def fix_bn(m):classname m.__class__.__name__if classname.find(SyncBatchNorm) ! -1 or classname.find(InstanceNorm2d) ! -1 or classname.find(BatchNorm2d) ! -1: #SyncBatchNorm, InstanceNorm2dif m.num_features in [32, 64, 96, 128, 256, 384, 768, 192, 1152, 224]: # 需要冻结的BN层的通道数m.eval()def train():for epoch in range(max_epoch):model.train()if args.freeze:model.apply(fix_bn)model.backbone[5][0].block[0][1].eval() # 假如需要冻结的BN层通道数和不需要冻结的BN层通道数一样则需要单独写for batch_idx, (data, target) in enumerate(train_loader):...
坑2用了冻结训练freeze就不要用EMA方式更新模型了不然收敛缓慢不说还会造成前面冻结的参数产生变化可以从EMA的代码看出端倪
class EMA:def __init__(self, model, decay0.9999):super().__init__()import copyself.decay decayself.model copy.deepcopy(model)self.model.eval()def update_fn(self, model, fn):with torch.no_grad():e_std self.model.state_dict().values()#m_std model.module.state_dict().values() # multi-gpum_std model.state_dict().values() # single-gpufor e, m in zip(e_std, m_std):e.copy_(fn(e, m))def update(self, model):self.update_fn(model, fnlambda e, m: self.decay * e (1. - self.decay) * m)可以看出EMA的方式更新模型方式大部分是结合上一个模型的参数的即
model_update decay*model(t-1) (1-decay)*model(t) # model(t-1) 代表上一次迭代模型的参数model(t)代表当前迭代得到的模型参数虽然冻结了backbone的参数阻止了梯度在backbone中反向传播但参数由于经过如上乘法及加法运算由于精度原因还是会发生微小变化虽然训练次数增加这个变化会扩大从而达不到冻结训练的效果。而且从计算公式可以看出来采用EMA的方式更新模型参数参数会更新得很慢会造成网络难以学习的“错觉”。我在这里困住了3天有怀疑过是否是网络设计问题是否是多GPU同步的问题是否是参数设置如学习率过小权重衰减过大或者dropout设置过大等等最终一步一步排除定位到EMA的问题。 以这次的经验来看EMA只适合在上一次训练得到模型的基础上这一次加了额外的数据需要在上一次的基础上做微调的情况。