# Momentum-smoothed update driven by the difference between two models.
#
# For every (old parameter, current parameter, momentum buffer) triple:
#   1. measure how far the current parameter has moved from the old one,
#   2. fold that difference into the momentum buffer (in place),
#   3. overwrite the current parameter with old + lr * (momentum + mf * diff),
#   4. sync the old parameter to the freshly written current parameter.
#
# NOTE(review): `grad` here is a parameter difference, not an autograd
# gradient. Also, PyTorch's Nesterov SGD forms its step as
# `grad + mf * m`; this code uses `m + mf * grad` -- confirm that is intended.
mf = 0.3  # weight on the previous momentum buffer, and on the raw difference in the step
lr = 0.7  # scale applied to the combined step

# Gradient tracking is disabled so the in-place writes below are not recorded
# by autograd.
with torch.no_grad():
    # zip() stops at the shortest input, so a length mismatch between the two
    # models and `momentum` is silently truncated rather than reported.
    triples = zip(model_old.parameters(), model.parameters(), momentum)
    for param_old, param, m in triples:
        # Displacement of the current parameter relative to the old one.
        grad = param - param_old

        # Refresh the momentum buffer in place; later reads of `m` see the
        # updated values.
        blended = mf * m + grad
        m.copy_(blended)

        # Combine the *updated* momentum with a damped copy of the raw
        # displacement, then apply it starting from the old parameter.
        step = m + mf * grad
        param.copy_(param_old + lr * step)

        # Both models now hold the same values for this parameter.
        param_old.copy_(param)