"""Convert a safetensors checkpoint from RoBERTa/flash-attn-style tensor
names to the target naming scheme.

LoRA tensors (any key containing ``'lora'``) are dropped, and every
``mlp.up_proj`` weight is additionally duplicated along dim 0 into a
fused ``mlp.up_gate_proj`` tensor.
"""
import torch

# Default checkpoint paths; convert() accepts overrides.
INPUT_FILE = "original_model.safetensors"
OUTPUT_FILE = "model.safetensors"

# Path components removed outright from every key.
_DROPPED_PARTS = ('roberta', 'parametrizations')

# One-to-one path-component renames (independent token swaps).
_PART_RENAMES = {
    'mixer': 'attention',
    'Wqkv': 'qkv_proj',
    'out_proj': 'o_proj',
    'norm1': 'attn_ln',
    'norm2': 'mlp_ln',
}

# Modules whose trailing weight/bias use the legacy gamma/beta names.
_LN_MODULES = frozenset({'attn_ln', 'mlp_ln', 'LayerNorm'})


def rename_key(key):
    """Return *key* translated to the target tensor-naming scheme.

    Transformations, in order:
      * drop ``roberta`` / ``parametrizations`` components;
      * collapse ``parametrizations.weight.original`` to plain ``weight``;
      * ``encoder.layers`` -> ``encoder.layer``;
      * token renames per ``_PART_RENAMES`` (attention/qkv/norm names);
      * ``mlp.fc1`` -> ``mlp.up_proj``, ``mlp.fc2`` -> ``mlp.down_proj``;
      * ``emb_ln`` -> ``embeddings.LayerNorm`` (prefix inserted);
      * LayerNorm-style ``weight``/``bias`` -> ``gamma``/``beta``.
    """
    parts = key.split('.')
    for junk in _DROPPED_PARTS:
        if junk in parts:
            parts.remove(junk)
    # 'parametrizations.weight.original' collapses to a plain 'weight'.
    if 'weight' in parts and 'original' in parts:
        parts.remove('original')
    if 'encoder.layers' in key:
        parts[parts.index('layers')] = 'layer'
    for old, new in _PART_RENAMES.items():
        if old in parts:
            parts[parts.index(old)] = new
    # MLP projections are only renamed when qualified by the 'mlp.' prefix.
    if 'mlp.fc1' in key:
        parts[parts.index('fc1')] = 'up_proj'
    if 'mlp.fc2' in key:
        parts[parts.index('fc2')] = 'down_proj'
    if 'emb_ln' in parts:
        parts[parts.index('emb_ln')] = 'LayerNorm'
        parts.insert(0, 'embeddings')
    # LayerNorm weight/bias use the legacy gamma/beta suffixes.  The length
    # guard fixes a latent IndexError on single-component keys, and checking
    # parts[-1] (not membership anywhere) pins the rename to the suffix.
    if len(parts) >= 2 and parts[-2] in _LN_MODULES and parts[-1] in ('weight', 'bias'):
        parts[-1] = 'gamma' if parts[-1] == 'weight' else 'beta'
    return '.'.join(parts)


def convert(input_file=INPUT_FILE, output_file=OUTPUT_FILE):
    """Rename all non-LoRA tensors in *input_file* and save to *output_file*.

    Returns the dict of renamed tensors that was written.
    """
    # Imported lazily so rename_key stays importable/testable without the
    # optional safetensors dependency installed.
    from safetensors import safe_open
    from safetensors.torch import save_file

    new_tensors = {}
    with safe_open(input_file, framework="pt", device="cpu") as f:
        for key in f.keys():
            if 'lora' in key:
                continue  # LoRA adapter tensors are intentionally dropped.
            new_key = rename_key(key)
            tensor = f.get_tensor(key)
            new_tensors[new_key] = tensor
            if 'mlp.up_proj' in new_key:
                # NOTE(review): the target model presumably expects a fused
                # up/gate projection; the source has no separate gate weights,
                # so up_proj is duplicated along dim 0 — confirm against the
                # consuming model's expected shapes.
                gate_key = new_key.replace('up_proj', 'up_gate_proj')
                new_tensors[gate_key] = torch.cat([tensor, tensor], dim=0)
    save_file(new_tensors, output_file)
    print(f"Renamed tensors saved to {output_file}")
    return new_tensors


def _inspect(path):
    """Print the name and shape of every tensor stored in *path*."""
    from safetensors import safe_open

    with safe_open(path, framework="pt", device="cpu") as f:
        print("\nRenamed tensors:")
        for key in f.keys():
            print(f"{key}: {f.get_tensor(key).shape}")


if __name__ == "__main__":
    convert()
    _inspect(OUTPUT_FILE)