DDT / src /utils /copy.py
wangshuai6
init space
9e426da
raw
history blame contribute delete
362 Bytes
import torch
@torch.no_grad()
def copy_params(src_model, dst_model):
    """Copy every parameter value of ``src_model`` into ``dst_model`` in place.

    Parameters are paired positionally in ``Module.parameters()`` order, so both
    models are expected to have the same architecture. Only parameters are
    copied; buffers are not touched.

    Args:
        src_model: module whose parameter values are read.
        dst_model: module whose parameters are overwritten in place.

    Raises:
        ValueError: if the two models expose a different number of parameters,
            or if any positionally paired parameters differ in shape. All checks
            run before the first copy, so ``dst_model`` is never left partially
            updated.
    """
    src_params = list(src_model.parameters())
    dst_params = list(dst_model.parameters())

    # zip() would silently stop at the shorter list and leave the remaining
    # destination parameters stale; fail loudly instead.
    if len(src_params) != len(dst_params):
        raise ValueError(
            f"parameter count mismatch: src_model has {len(src_params)}, "
            f"dst_model has {len(dst_params)}"
        )

    # Tensor.copy_() broadcasts, so a smaller source would be silently
    # replicated into a larger destination. Require exact shape equality.
    for idx, (src_param, dst_param) in enumerate(zip(src_params, dst_params)):
        if src_param.shape != dst_param.shape:
            raise ValueError(
                f"shape mismatch at parameter {idx}: src_model has "
                f"{tuple(src_param.shape)}, dst_model has {tuple(dst_param.shape)}"
            )

    for src_param, dst_param in zip(src_params, dst_params):
        # Writing through .data (as before) avoids bumping the parameter's
        # autograd version counter.
        dst_param.data.copy_(src_param.data)
@torch.no_grad()
def swap_tensors(tensor1, tensor2):
    """Exchange the contents of two same-shaped tensors in place.

    The tensor objects keep their identity, dtype and device; only their
    values are swapped. Because ``Tensor.copy_`` is used, values are cast to
    each destination tensor's own dtype.

    Args:
        tensor1: first tensor, overwritten with the values of ``tensor2``.
        tensor2: second tensor, overwritten with the original values of
            ``tensor1``.

    Raises:
        ValueError: if the shapes differ. This is checked before any write,
            so neither tensor is modified on failure.
    """
    # Without this check, copy_() may broadcast tensor2 into tensor1 and then
    # fail on the reverse copy, leaving tensor1 clobbered and tensor2 intact.
    if tensor1.shape != tensor2.shape:
        raise ValueError(
            f"cannot swap tensors of different shapes: "
            f"{tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
        )
    # clone() snapshots tensor1 into independent storage (same dtype/device/
    # memory format), so this is also safe when both arguments are one tensor.
    snapshot = tensor1.clone()
    tensor1.copy_(tensor2)
    tensor2.copy_(snapshot)