mrcuddle commited on
Commit
b7b912a
·
verified ·
1 Parent(s): c4a29da

Update merge.py

Browse files
Files changed (1) hide show
  1. merge.py +12 -10
merge.py CHANGED
@@ -23,7 +23,7 @@ def merge_tensors(tensor1: torch.Tensor, tensor2: torch.Tensor, p: float) -> tor
23
  torch.Tensor: The merged tensor.
24
  """
25
  delta = tensor2 - tensor1
26
- m = torch.from_numpy(np.random.binomial(1, p, delta.shape)).to(tensor1.dtype)
27
  delta_tilde = m * delta
28
  delta_hat = delta_tilde / (1 - p)
29
  return delta_hat
@@ -42,16 +42,17 @@ def merge_safetensors(file_path1: str, file_path2: str, p: float, lambda_val: fl
42
  dict: A dictionary of merged tensors.
43
  """
44
  merged_tensors = {}
 
45
  with safe_open(file_path1, framework="pt", device="cpu") as f1, safe_open(file_path2, framework="pt", device="cpu") as f2:
46
  keys1 = set(f1.keys())
47
  keys2 = set(f2.keys())
48
  common_keys = keys1.intersection(keys2)
49
 
50
  for key in common_keys:
51
- tensor1 = f1.get_tensor(key)
52
- tensor2 = f2.get_tensor(key)
53
  tensor1, tensor2 = resize_tensors(tensor1, tensor2)
54
- merged_tensors[key] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p)
55
  logging.info(f"Merging {key}")
56
 
57
  return merged_tensors
@@ -131,6 +132,7 @@ def merge_folder(tensor_map: dict, directory_path: str, p: float, lambda_val: fl
131
  """
132
  keys1 = set(tensor_map.keys())
133
  ext = None
 
134
  for filename in os.listdir(directory_path):
135
  if filename.endswith(".safetensors"):
136
  ext = ".safetensors"
@@ -146,14 +148,14 @@ def merge_folder(tensor_map: dict, directory_path: str, p: float, lambda_val: fl
146
  common_keys = keys1.intersection(keys2)
147
  for key in common_keys:
148
  if "block_sparse_moe.gate" in key:
149
- tensor1 = tensor_map[key]['tensor']
150
- tensor2 = f.get_tensor(key)
151
  tensor_map[key]['tensor'] = (tensor1 + tensor2) / 2.0
152
  continue
153
- tensor1 = tensor_map[key]['tensor']
154
- tensor2 = f.get_tensor(key)
155
  tensor1, tensor2 = resize_tensors(tensor1, tensor2)
156
- tensor_map[key]['tensor'] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p)
157
  return tensor_map
158
 
159
  def map_tensors_to_files(directory_path: str) -> dict:
@@ -238,4 +240,4 @@ def main():
238
  save_file(merged, args.output_model)
239
 
240
  if __name__ == '__main__':
241
- main()
 
23
  torch.Tensor: The merged tensor.
24
  """
25
  delta = tensor2 - tensor1
26
+ m = torch.from_numpy(np.random.binomial(1, p, delta.shape)).to(tensor1.device)
27
  delta_tilde = m * delta
28
  delta_hat = delta_tilde / (1 - p)
29
  return delta_hat
 
42
  dict: A dictionary of merged tensors.
43
  """
44
  merged_tensors = {}
45
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
46
  with safe_open(file_path1, framework="pt", device="cpu") as f1, safe_open(file_path2, framework="pt", device="cpu") as f2:
47
  keys1 = set(f1.keys())
48
  keys2 = set(f2.keys())
49
  common_keys = keys1.intersection(keys2)
50
 
51
  for key in common_keys:
52
+ tensor1 = f1.get_tensor(key).to(device)
53
+ tensor2 = f2.get_tensor(key).to(device)
54
  tensor1, tensor2 = resize_tensors(tensor1, tensor2)
55
+ merged_tensors[key] = (tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p)).cpu()
56
  logging.info(f"Merging {key}")
57
 
58
  return merged_tensors
 
132
  """
133
  keys1 = set(tensor_map.keys())
134
  ext = None
135
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
136
  for filename in os.listdir(directory_path):
137
  if filename.endswith(".safetensors"):
138
  ext = ".safetensors"
 
148
  common_keys = keys1.intersection(keys2)
149
  for key in common_keys:
150
  if "block_sparse_moe.gate" in key:
151
+ tensor1 = tensor_map[key]['tensor'].to(device)
152
+ tensor2 = f.get_tensor(key).to(device)
153
  tensor_map[key]['tensor'] = (tensor1 + tensor2) / 2.0
154
  continue
155
+ tensor1 = tensor_map[key]['tensor'].to(device)
156
+ tensor2 = f.get_tensor(key).to(device)
157
  tensor1, tensor2 = resize_tensors(tensor1, tensor2)
158
+ tensor_map[key]['tensor'] = (tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p)).cpu()
159
  return tensor_map
160
 
161
  def map_tensors_to_files(directory_path: str) -> dict:
 
240
  save_file(merged, args.output_model)
241
 
242
  if __name__ == '__main__':
243
+ main()