import os
import json
import re

import torch
from safetensors import safe_open
from transformers import MllamaForConditionalGeneration, AutoProcessor

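# Mllama (Llama-3.2-11B-Vision) interleaves 8 cross-attention layers into its
# 40-layer text stack at the indices below (per the model's config); the 8B
# donor has only 32 self-attention layers, which must be shifted around them.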
cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38]

# Donor: the text-only 8B fine-tune whose weights we transplant (32 layers).
b8 = './models/v000000_L3-8B-Stheno-v3.2-abliterated'
print(b8)

# Base: the Mllama vision-instruct checkpoint we merge into.
model_id = "./models/meta-llama_Llama-3.2-11B-Vision-Instruct"

def create_layer_mapping(total_layers=32, cross_attn_layers=cross_attention_layers):
    """
    Creates a mapping from llama-3.1-8b layer indices to llama-3.2-11b layer
    indices, skipping the slots occupied by cross-attention layers.
    """
    mapping = {}
    shift = 0
    next_cross_attn_idx = 0
    for X in range(total_layers):
        # Each time the shifted index hits a cross-attention slot, bump the
        # shift so donor layers only ever land on self-attention slots.
        if next_cross_attn_idx < len(cross_attn_layers) and (X + shift) == cross_attn_layers[next_cross_attn_idx]:
            shift += 1
            next_cross_attn_idx += 1
        Y = X + shift
        mapping[X] = Y
    return mapping
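
# Quick sanity check (illustrative; not part of the original script): donor
# layers must never land on a cross-attention slot, and the last donor layer
# (31) should map to index 39 of the 40-layer stack.
_mapping_check = create_layer_mapping()
assert set(_mapping_check.values()).isdisjoint(cross_attention_layers)
assert _mapping_check[3] == 4 and _mapping_check[31] == 39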

def load_sharded_state_dict(model_dir):
    """Load a sharded safetensors checkpoint into a single CPU state dict."""
    index_file = os.path.join(model_dir, 'model.safetensors.index.json')
    with open(index_file, 'r') as f:
        index_data = json.load(f)
    weight_map = index_data['weight_map']
    state_dict = {}
    # Group parameter names by shard so each shard file is opened only once.
    shard_to_params = {}
    for param_name, shard_file in weight_map.items():
        if shard_file not in shard_to_params:
            shard_to_params[shard_file] = []
        shard_to_params[shard_file].append(param_name)
    for shard_file, params_in_shard in shard_to_params.items():
        shard_path = os.path.join(model_dir, shard_file)
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for name in params_in_shard:
                state_dict[name] = f.get_tensor(name)
    return state_dict

def compare_model_states(model, new_state_dict):
    """Report which of the model's parameters a candidate state dict would change."""
    current_state = model.state_dict()
    unchanged_params = []
    changed_params = []
    missing_params = []

    for name, param in current_state.items():
        if name not in new_state_dict:
            missing_params.append(name)
        elif torch.equal(param.cpu(), new_state_dict[name].cpu()):
            unchanged_params.append(name)
        else:
            changed_params.append(name)

    return {
        'unchanged': unchanged_params,
        'changed': changed_params,
        'missing': missing_params
    }
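
# Note: compare_model_states expects a pre-merge snapshot such as
#   original_state = {k: v.clone() for k, v in model.state_dict().items()}
# taken right after from_pretrained (an assumption on our part); the
# commented-out block near the end uses `original_state` without defining it.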

layer_mapping = create_layer_mapping()

# Pull only the embedding table from the Mllama checkpoint, then drop the rest
# of the state dict to keep peak memory down.
llama_3_2_state_dict = load_sharded_state_dict(model_id)
llama_3_2_embeddings = llama_3_2_state_dict['language_model.model.embed_tokens.weight']
llama_3_2_state_dict.clear()

b8dict = load_sharded_state_dict(b8)

# The Mllama tokenizer adds 8 special tokens (vocab 128256 -> 128264): build an
# embedding table of the new size, keep the donor's rows, and copy the extra
# rows from the Mllama table.
embed_tokens_weight = b8dict['model.embed_tokens.weight']
new_vocab_size = 128264
new_embed_tokens_weight = torch.zeros((new_vocab_size, 4096), dtype=embed_tokens_weight.dtype)
new_embed_tokens_weight[:128256, :] = embed_tokens_weight
new_embed_tokens_weight[128256:, :] = llama_3_2_embeddings[128256:, :]
b8dict['model.embed_tokens.weight'] = new_embed_tokens_weight
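
# Sanity check (illustrative; not in the original script): donor rows intact,
# 8 Mllama-only rows (e.g. <|image|>) appended at the end.
assert new_embed_tokens_weight.shape == (new_vocab_size, 4096)
assert torch.equal(new_embed_tokens_weight[:128256], embed_tokens_weight)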

llama_3_2_embeddings = None
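
# Rename every donor key to its Mllama location: non-layer keys just gain the
# 'language_model.' prefix, while per-layer keys are also renumbered through
# layer_mapping so they skip the cross-attention slots.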
st8dict = {}
for name, param in b8dict.items():
    if not re.match(r'model\.layers\.\d+\.', name):
        # Embeddings, final norm, lm_head, etc.: prefix only.
        new_name = 'language_model.' + name
    else:
        match = re.match(r'model\.layers\.(\d+)\.(.+)', name)
        if match:
            X = int(match.group(1))
            suffix = match.group(2)
            # layer_mapping covers 0..31; the fallback (+8, the final shift)
            # is purely defensive and never triggers here.
            Y = layer_mapping.get(X, X + len(cross_attention_layers))
            new_name = f'language_model.model.layers.{Y}.{suffix}'
        else:
            new_name = 'language_model.' + name
    st8dict[new_name] = param

# Dump the remapped key names for manual inspection.
with open('st8dict.txt', 'w') as f:
    f.write('\n'.join(st8dict.keys()))

model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
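
# strict=False: st8dict covers only the text-model keys, so the vision encoder,
# multi_modal_projector, and the 8 cross-attention layers keep their original
# Mllama weights.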
model.load_state_dict(st8dict, strict=False)

# Free the donor dicts.
b8dict.clear()
st8dict.clear()

'''
result = compare_model_states(model, original_state)

print("Unchanged parameters:", len(result['unchanged']))
print("Changed parameters:", len(result['changed']))
print("Missing parameters:", len(result['missing']))

# write result to file
with open('result.txt', 'w') as f:
    f.write(json.dumps(result, indent=2))
'''

processor = AutoProcessor.from_pretrained(model_id)

model.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")
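
# Presumably the processor should ship with the merged checkpoint too; the
# original script loads it but never saves it (hypothetical addition):
processor.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")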