import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor
import os
import json
from safetensors import safe_open
import re

# apologies in advance for shitty gpt-assisted code
# this script should also work with 70b/90b if you change `cross_attention_layers`
# and `total_layers` accordingly, but i dont have enough deditated wam to test it
# and i dont feel like spinning up runpod so

# Layer indices (in the 11B model) that hold the extra cross-attention blocks
cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38]

# Donor model: a Llama-3.1-8B-architecture finetune whose language weights get
# grafted into the vision model. Uncomment whichever donor you want.
#b8 = './models/mlabonne_Meta-Llama-3.1-8B-Instruct-abliterated'
b8 = './models/v000000_L3-8B-Stheno-v3.2-abliterated'
#b8 = './models/arcee-ai_Llama-3.1-SuperNova-Lite'
print(b8)

model_id = "./models/meta-llama_Llama-3.2-11B-Vision-Instruct"


def create_layer_mapping(total_layers=32, cross_attn_layers=cross_attention_layers):
    """
    Creates a mapping from llama-3.1-8b layer indices to llama-3.2-11b layer
    indices. Each inserted cross-attention layer shifts every subsequent
    self-attention layer down by one. (Assumes no two cross-attention layers
    are adjacent, which holds for the 11B layout; see the sanity check at the
    bottom of the script.)
    """
    mapping = {}
    shift = 0
    next_cross_attn_idx = 0
    for X in range(total_layers):
        # Check if a cross-attention layer is inserted before this layer
        if next_cross_attn_idx < len(cross_attn_layers) and (X + shift) == cross_attn_layers[next_cross_attn_idx]:
            shift += 1
            next_cross_attn_idx += 1
        Y = X + shift
        mapping[X] = Y
    return mapping


def load_sharded_state_dict(model_dir):
    """Loads a sharded safetensors checkpoint into a single CPU state dict."""
    index_file = os.path.join(model_dir, 'model.safetensors.index.json')
    with open(index_file, 'r') as f:
        index_data = json.load(f)
    weight_map = index_data['weight_map']
    state_dict = {}
    # Group parameter names by shard so each shard file is opened only once
    shard_to_params = {}
    for param_name, shard_file in weight_map.items():
        if shard_file not in shard_to_params:
            shard_to_params[shard_file] = []
        shard_to_params[shard_file].append(param_name)
    for shard_file, params_in_shard in shard_to_params.items():
        shard_path = os.path.join(model_dir, shard_file)
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for name in params_in_shard:
                state_dict[name] = f.get_tensor(name)
    return state_dict


def compare_model_states(model, new_state_dict):
    """Reports which of the model's parameters differ from new_state_dict."""
    current_state = model.state_dict()
    unchanged_params = []
    changed_params = []
    missing_params = []
    for name, param in current_state.items():
        if name not in new_state_dict:
            missing_params.append(name)
        elif torch.equal(param.cpu(), new_state_dict[name].cpu()):
            unchanged_params.append(name)
        else:
            changed_params.append(name)
    return {
        'unchanged': unchanged_params,
        'changed': changed_params,
        'missing': missing_params
    }


layer_mapping = create_layer_mapping()

# Load the Llama 3.2 state dict just long enough to grab its embedding matrix,
# which has 8 extra rows for the image tokens
llama_3_2_state_dict = load_sharded_state_dict(model_id)
llama_3_2_embeddings = llama_3_2_state_dict['language_model.model.embed_tokens.weight']  # Shape: [128264, 4096]
llama_3_2_state_dict.clear()

b8dict = load_sharded_state_dict(b8)
embed_tokens_weight = b8dict['model.embed_tokens.weight']  # Shape: [128256, 4096]

# Widen the donor's embedding matrix to the 3.2 vocab size: keep the donor's
# 128256 rows and take the 8 extra (image-token) rows from Llama 3.2
new_vocab_size = 128264  # From Llama 3.2
new_embed_tokens_weight = torch.zeros((new_vocab_size, 4096), dtype=embed_tokens_weight.dtype)
new_embed_tokens_weight[:128256, :] = embed_tokens_weight
new_embed_tokens_weight[128256:, :] = llama_3_2_embeddings[128256:, :]
b8dict['model.embed_tokens.weight'] = new_embed_tokens_weight
llama_3_2_embeddings = None

# Adjust Llama 3.1 parameter names to match the Llama 3.2 language model
st8dict = {}
for name, param in b8dict.items():
    if not re.match(r'model\.layers\.\d+\.', name):
        # Prefix non-layer parameters with 'language_model.'
        new_name = 'language_model.' + name
    else:
        # Extract the layer index X from 'model.layers.X.'
        match = re.match(r'model\.layers\.(\d+)\.(.+)', name)
        if match:
            X = int(match.group(1))
            suffix = match.group(2)
            # Get the corresponding Y in llama-3.2-11b
            Y = layer_mapping.get(X, X + len(cross_attention_layers))
            new_name = f'language_model.model.layers.{Y}.{suffix}'
        else:
            # If the pattern doesn't match, just prefix with 'language_model.'
            new_name = 'language_model.' + name
    st8dict[new_name] = param

# Write st8dict keys to file for verification
with open('st8dict.txt', 'w') as f:
    f.write('\n'.join(st8dict.keys()))

model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)

#original_state = {k: v.clone() for k, v in model.state_dict().items()}

# strict=False because st8dict intentionally omits the vision tower and the
# cross-attention layers, which keep their original Llama 3.2 weights
model.load_state_dict(st8dict, strict=False)
b8dict.clear()
st8dict.clear()

'''
result = compare_model_states(model, original_state)
print("Unchanged parameters:", len(result['unchanged']))
print("Changed parameters:", len(result['changed']))
print("Missing parameters:", len(result['missing']))
#write result to file
with open('result.txt', 'w') as f:
    f.write(json.dumps(result, indent=2))
'''

processor = AutoProcessor.from_pretrained(model_id)
model.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")
# Save the processor too so the output directory is self-contained
processor.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")
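
# --- Optional sanity check (my addition, not part of the original merge flow) ---
# A minimal sketch that verifies the layer remapping: no remapped self-attention
# layer should collide with a cross-attention slot, and the 32 remapped indices
# plus the 8 cross-attention indices should exactly cover the 40 layers of the
# 11B language model.
mapping = create_layer_mapping()
assert not set(mapping.values()) & set(cross_attention_layers)
assert sorted(set(mapping.values()) | set(cross_attention_layers)) == list(range(40))
print("layer mapping:", mapping)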
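
# --- Optional smoke test (my addition, a hedged sketch) ---
# Run a short text-only generation through the grafted model to confirm the
# remapped weights actually load and produce coherent tokens. This assumes a
# transformers version where Mllama accepts text-only prompts; on CPU in bf16
# it will be slow.
messages = [{"role": "user", "content": [{"type": "text", "text": "Hi! Who are you?"}]}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(processor.decode(output[0], skip_special_tokens=True))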