# InternVL3-78B-AWQ / pad_tensors.py
# Uploaded by intervitens via huggingface_hub (commit 2995564, verified).
import os

import torch
from torch.nn import functional as F
from safetensors.torch import load_file, save_file
pad_size = 128  # Specific to Qwen2-72B architecture
total_shards = 33  # Total number of shards in the model, edit according to the actual files
# MLP intermediate width of the unpadded checkpoint (the value the original
# reshapes hard-coded). Used to recognise already-padded tensors on re-runs.
intermediate_size = 29568

# Weight-name suffix -> axis of that weight spanning the MLP intermediate dim.
# up/gate weights carry it on axis 0, down_proj on axis 1.
_INTERMEDIATE_AXIS = {
    "mlp.up_proj.weight": 0,
    "mlp.gate_proj.weight": 0,
    "mlp.down_proj.weight": 1,
}


def interleave_zero_pad(tensor, pad_size, dim=0):
    """Insert an all-zero slice after each of the first *pad_size* slices along *dim*.

    The result is ``tensor`` grown by ``pad_size`` along ``dim``: input slices
    ``0 .. pad_size-1`` alternate with zero slices, and the remaining slices
    follow unchanged. Dtype and device are preserved.

    Args:
        tensor: Tensor to pad.
        pad_size: Number of leading slices to interleave with zeros.
        dim: Axis to pad along (negative values count from the end).

    Returns:
        A new tensor with ``shape[dim] + pad_size`` entries along ``dim``.

    Raises:
        ValueError: If ``pad_size`` is negative or exceeds ``tensor.shape[dim]``.
    """
    if dim < 0:
        dim += tensor.dim()
    length = tensor.shape[dim]
    if not 0 <= pad_size <= length:
        raise ValueError(f"pad_size={pad_size} must be in [0, {length}] for dim {dim}")
    head = tensor.narrow(dim, 0, pad_size)
    tail = tensor.narrow(dim, pad_size, length - pad_size)
    # Stacking (head, zeros) on a new axis right after `dim` and flattening the
    # pair back yields head[0], 0, head[1], 0, ... ; only the head is copied,
    # rather than materialising a zero-padded copy of the whole tensor.
    interleaved = torch.stack((head, torch.zeros_like(head)), dim=dim + 1)
    interleaved = interleaved.flatten(dim, dim + 1)
    return torch.cat((interleaved, tail), dim=dim)


def pad_mlp_weight(key, tensor, pad_size, intermediate_size):
    """Return the zero-padded replacement for an MLP projection weight.

    Args:
        key: State-dict key; matched by suffix against ``_INTERMEDIATE_AXIS``.
        tensor: The weight stored under ``key``.
        pad_size: Number of zero channels to add to the intermediate dimension.
        intermediate_size: Expected unpadded intermediate width.

    Returns:
        The padded tensor, or ``None`` when ``key`` is not an up/gate/down
        projection weight or the tensor is already
        ``intermediate_size + pad_size`` wide (e.g. from an earlier run).

    Raises:
        ValueError: If the intermediate axis has any other width.
    """
    axis = next(
        (ax for suffix, ax in _INTERMEDIATE_AXIS.items() if key.endswith(suffix)),
        None,
    )
    if axis is None:
        return None
    width = tensor.shape[axis]
    if width == intermediate_size + pad_size:
        return None  # already padded; re-running must not pad twice
    if width != intermediate_size:
        raise ValueError(
            f"{key}: intermediate width {width} is neither {intermediate_size} "
            f"(unpadded) nor {intermediate_size + pad_size} (padded)"
        )
    return interleave_zero_pad(tensor, pad_size, dim=axis)


def process_shard(filename, pad_size, intermediate_size):
    """Pad the MLP weights of one safetensors shard and rewrite it in place.

    The new shard is written to ``<filename>.tmp`` and then moved over
    ``filename`` with ``os.replace``, so a failed or interrupted save leaves
    the original shard intact instead of truncated.

    Returns:
        True if the shard was rewritten, False if nothing needed padding.
    """
    state_dict = load_file(filename)
    updates = {}
    for key, tensor in state_dict.items():
        padded = pad_mlp_weight(key, tensor, pad_size, intermediate_size)
        if padded is not None:
            updates[key] = padded
    if not updates:
        print(f"No modifications needed for {filename}")
        return False
    state_dict.update(updates)
    tmp_path = f"{filename}.tmp"
    try:
        save_file(state_dict, tmp_path, metadata={"format": "pt"})
        os.replace(tmp_path, filename)
    except BaseException:
        # Don't leave a partially written (potentially multi-GB) temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"Processed and saved {filename}")
    return True


def main():
    """Pad every shard ``model-XXXXX-of-NNNNN.safetensors`` in the working directory.

    NOTE(review): only the weight shards are rewritten; the model config's
    intermediate size presumably has to be raised to
    ``intermediate_size + pad_size`` separately — confirm.
    """
    for shard_idx in range(1, total_shards + 1):
        # Generate filename with zero-padded shard numbers
        filename = f"model-{shard_idx:05d}-of-{total_shards:05d}.safetensors"
        process_shard(filename, pad_size, intermediate_size)


if __name__ == "__main__":
    main()