"""
Stage 1 v2 Sharted Edition 💩: Fast Multi-GPU Interpolation from Qwen3-32B to Qwen3-72B

Optimized for 8x MI300X GPUs with parallel processing and sharted weight loading.

FIXED: Correct o_proj dimensions
"""

import gc
import json
import os
import shutil
from concurrent.futures import ThreadPoolExecutor

import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Qwen3-32B (source) dimensions
SRC_HIDDEN_SIZE = 5120
SRC_INTERMEDIATE_SIZE = 25600
SRC_NUM_LAYERS = 64

# Grouped-query attention (GQA): 64 query heads share 8 KV heads
SRC_Q_HEADS = 64
SRC_KV_HEADS = 8

# Qwen3-72B (target) dimensions
TGT_HIDDEN_SIZE = 8192
TGT_INTERMEDIATE_SIZE = 29568
TGT_NUM_HEADS = 64
TGT_Q_HEADS = 64
TGT_KV_HEADS = 8
HEAD_DIM = 128

# Rows/columns that must be added to each weight
DELTA_HIDDEN = TGT_HIDDEN_SIZE - SRC_HIDDEN_SIZE  # 8192 - 5120 = 3072
DELTA_INTERMEDIATE = TGT_INTERMEDIATE_SIZE - SRC_INTERMEDIATE_SIZE  # 29568 - 25600 = 3968

OUTPUT_DIR = "./Qwen3-58B-Embiggened"

NUM_GPUS = 8
BATCH_SIZE = 16
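
# Per-layer attention shapes implied by the GQA layout above (source -> target);
# these match the expected shapes asserted in verify_architecture() below:
#   q_proj: [64*128, 5120] = [8192, 5120] -> [8192, 8192]
#   k_proj: [ 8*128, 5120] = [1024, 5120] -> [1024, 8192]
#   v_proj: [ 8*128, 5120] = [1024, 5120] -> [1024, 8192]
#   o_proj: [5120, 64*128] = [5120, 8192] -> [8192, 8192]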


def get_layer_info(name):
    """Extract layer number and component type from parameter name."""
    if "model.layers." in name:
        parts = name.split(".")
        try:
            layer_idx = int(parts[2])
            return layer_idx, ".".join(parts[3:])
        except (ValueError, IndexError):
            return None, name
    return None, name


def get_interpolation_weight(layer_idx, num_layers=SRC_NUM_LAYERS):
    """Get interpolation weight based on layer depth."""
    if layer_idx is None:
        return 0.5

    relative_pos = layer_idx / (num_layers - 1)

    if relative_pos < 0.25:
        return 0.3  # shallow layers: stay close to the original block
    elif relative_pos < 0.75:
        return 0.5  # middle layers: balanced blend
    else:
        return 0.7  # deep layers: lean toward the far block
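
# Worked example of the depth schedule above (a small self-check, assuming the
# 64-layer source so relative positions run 0/63 .. 63/63; not called anywhere):
def _check_depth_schedule():
    assert get_interpolation_weight(0) == 0.3   # 0.00 -> shallow band
    assert get_interpolation_weight(32) == 0.5  # ~0.51 -> middle band
    assert get_interpolation_weight(63) == 0.7  # 1.00 -> deep band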


@torch.jit.script
def add_structured_noise_jit(tensor: torch.Tensor, noise_scale: float = 0.01) -> torch.Tensor:
    """JIT-compiled structured noise addition."""
    noise = torch.randn_like(tensor) * noise_scale * tensor.std()

    # Damp noise in the center of large matrices so the core of each weight
    # block is perturbed less than its edges
    if tensor.ndim == 2 and tensor.shape[0] > 100 and tensor.shape[1] > 100:
        h, w = noise.shape
        center_mask = torch.ones_like(noise)
        center_mask[h // 4:3 * h // 4, w // 4:3 * w // 4] *= 0.5
        noise *= center_mask

    return noise


@torch.jit.script
def preserve_norm_jit(original: torch.Tensor, interpolated: torch.Tensor) -> torch.Tensor:
    """JIT-compiled norm preservation: rescale so the norm matches the original."""
    original_norm = original.norm()
    interpolated_norm = interpolated.norm()

    if interpolated_norm > 0:
        scale_factor = original_norm / interpolated_norm
        return interpolated * scale_factor
    return interpolated


def structure_aware_interpolation_gpu(block1, block2, weight=0.5, add_noise=True, device='cuda'):
    """GPU-accelerated interpolation."""
    if block1.device.type != 'cuda':
        block1 = block1.to(device)
    if block2.device.type != 'cuda':
        block2 = block2.to(device)

    interpolated = (1 - weight) * block1 + weight * block2

    if add_noise:
        noise = add_structured_noise_jit(interpolated, 0.005)
        interpolated = interpolated + noise

    return interpolated


def upscale_tensor_gpu(tensor: torch.Tensor, name: str, device='cuda') -> torch.Tensor:
    """GPU-accelerated tensor upscaling with FIXED o_proj dimensions."""
    tensor = tensor.to(device)

    layer_idx, component = get_layer_info(name)
    interp_weight = get_interpolation_weight(layer_idx)

    # 1D tensors: layernorms grow with the hidden size; per-head norms stay put
    if tensor.ndim == 1:
        if tensor.shape[0] == SRC_HIDDEN_SIZE:
            block1, block2 = tensor[:DELTA_HIDDEN], tensor[-DELTA_HIDDEN:]
            interpolated = structure_aware_interpolation_gpu(block1, block2, weight=interp_weight, device=device)
            result = torch.cat([tensor, interpolated], dim=0)
            if "layernorm" in name:
                result = preserve_norm_jit(tensor, result)
            return result
        elif "k_norm" in name or "q_norm" in name:
            # q_norm/k_norm are per-head (HEAD_DIM = 128) and unchanged by upscaling
            return tensor

    # 2D tensors
    elif tensor.ndim == 2:
        # Embeddings and LM head: expand the hidden dimension (columns)
        if "embed_tokens" in name or "lm_head" in name:
            if tensor.shape[1] == SRC_HIDDEN_SIZE:
                block1, block2 = tensor[:, :DELTA_HIDDEN], tensor[:, -DELTA_HIDDEN:]
                interpolated = structure_aware_interpolation_gpu(block1, block2, weight=0.3, device=device)
                return torch.cat([tensor, interpolated], dim=1)

        elif "self_attn" in name:
            if "q_proj.weight" in name:
                # [8192, 5120] -> [8192, 8192]: the 64 query heads already match
                # the target, so only the input (hidden) columns grow
                block1, block2 = tensor[:, :DELTA_HIDDEN], tensor[:, -DELTA_HIDDEN:]
                interpolated = structure_aware_interpolation_gpu(block1, block2, weight=interp_weight, device=device)
                result = torch.cat([tensor, interpolated], dim=1)
                return preserve_norm_jit(tensor, result)

            elif "k_proj.weight" in name or "v_proj.weight" in name:
                # [1024, 5120] -> [1024, 8192]: the 8 KV heads are unchanged (GQA)
                block1, block2 = tensor[:, :DELTA_HIDDEN], tensor[:, -DELTA_HIDDEN:]
                interpolated = structure_aware_interpolation_gpu(block1, block2, weight=interp_weight, device=device)
                result = torch.cat([tensor, interpolated], dim=1)
                return preserve_norm_jit(tensor, result)

            elif "o_proj.weight" in name:
                # [5120, 8192] -> [8192, 8192]: the columns (64 heads * 128) already
                # match the target, so only the output (hidden) rows grow
                print(f"\n[DEBUG] Processing {name}: input shape = {tensor.shape}")
                print(f"[DEBUG] Expected input: [5120, 8192], Expected output: [8192, 8192]")

                row_block1 = tensor[:DELTA_HIDDEN, :]
                row_block2 = tensor[-DELTA_HIDDEN:, :]
                row_interp = structure_aware_interpolation_gpu(row_block1, row_block2, weight=interp_weight, device=device)

                print(f"[DEBUG] row interpolation: block1={row_block1.shape}, block2={row_block2.shape}, interp={row_interp.shape}")

                result = torch.cat([tensor, row_interp], dim=0)

                print(f"[DEBUG] Final result: {result.shape}")

                assert result.shape == (TGT_HIDDEN_SIZE, TGT_HIDDEN_SIZE), f"o_proj shape error: got {result.shape}"

                return preserve_norm_jit(tensor, result)

        elif "mlp" in name:
            if "gate_proj.weight" in name or "up_proj.weight" in name:
                # [25600, 5120] -> [29568, 8192]: grow intermediate rows, then hidden columns
                mlp_weight = min(interp_weight + 0.1, 0.8)

                row_block1, row_block2 = tensor[:DELTA_INTERMEDIATE, :], tensor[-DELTA_INTERMEDIATE:, :]
                upscaled_rows = torch.cat([tensor, structure_aware_interpolation_gpu(row_block1, row_block2, weight=mlp_weight, device=device)], dim=0)

                col_block1, col_block2 = upscaled_rows[:, :DELTA_HIDDEN], upscaled_rows[:, -DELTA_HIDDEN:]
                result = torch.cat([upscaled_rows, structure_aware_interpolation_gpu(col_block1, col_block2, weight=mlp_weight, device=device)], dim=1)

                result = preserve_norm_jit(tensor, result)
                return result

            elif "down_proj.weight" in name:
                # [5120, 25600] -> [8192, 29568]: grow hidden rows, then intermediate columns
                mlp_weight = interp_weight

                row_block1, row_block2 = tensor[:DELTA_HIDDEN, :], tensor[-DELTA_HIDDEN:, :]
                upscaled_rows = torch.cat([tensor, structure_aware_interpolation_gpu(row_block1, row_block2, weight=mlp_weight, device=device)], dim=0)

                col_block1, col_block2 = upscaled_rows[:, :DELTA_INTERMEDIATE], upscaled_rows[:, -DELTA_INTERMEDIATE:]
                result = torch.cat([upscaled_rows, structure_aware_interpolation_gpu(col_block1, col_block2, weight=mlp_weight, device=device)], dim=1)

                return result

    # Anything unrecognized passes through unchanged
    return tensor
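
# Minimal shape sanity check for the fixed o_proj path above (an illustrative
# helper, not called by the pipeline; needs a CUDA device): a toy [5120, 8192]
# tensor with an o_proj-style name should come out as [8192, 8192].
def _sanity_check_o_proj(device='cuda:0'):
    toy = torch.randn(SRC_HIDDEN_SIZE, TGT_HIDDEN_SIZE, device=device)
    out = upscale_tensor_gpu(toy, "model.layers.0.self_attn.o_proj.weight", device=device)
    assert out.shape == (TGT_HIDDEN_SIZE, TGT_HIDDEN_SIZE), f"got {out.shape}"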


def process_layer_batch(layer_tensors, device):
    """Process a batch of tensors from the same layer on a specific GPU."""
    processed = {}

    with torch.cuda.device(device):
        for name, tensor in layer_tensors:
            processed_tensor = upscale_tensor_gpu(tensor, name, device=device)
            # Move results back to CPU immediately to keep GPU memory free
            processed[name] = processed_tensor.cpu()

    return processed


def load_model_sharted(model_id):
    """Load model weights from sharted safetensors files. 💩"""
    print("\n💩 Loading sharted weights...")

    index_path = os.path.join(model_id, "model.safetensors.index.json")

    if os.path.exists(index_path):
        # Local checkpoint: read the index and load each shart exactly once
        with open(index_path, 'r') as f:
            index = json.load(f)

        weight_map = index['weight_map']
        unique_files = set(weight_map.values())

        all_weights = {}
        for file in tqdm(unique_files, desc="Loading sharts"):
            file_path = os.path.join(model_id, file)
            weights = load_file(file_path)
            all_weights.update(weights)

        return all_weights
    else:
        # Not a local path: download the snapshot from the Hub, then recurse
        from huggingface_hub import snapshot_download

        print(f"Downloading model from HuggingFace: {model_id}")
        local_dir = snapshot_download(model_id)
        return load_model_sharted(local_dir)


def save_model_sharted(state_dict, output_dir, max_shart_size="5GB"):
    """Save model in sharted safetensors format. 💩"""
    print("\n💩 Sharting model weights...")

    os.makedirs(output_dir, exist_ok=True)

    # Parse the size limit; fall back to 5 GB if the unit is unrecognized
    size_map = {'GB': 1e9, 'MB': 1e6}
    max_bytes = int(5e9)
    for unit, multiplier in size_map.items():
        if unit in max_shart_size:
            max_bytes = int(float(max_shart_size.replace(unit, '')) * multiplier)
            break

    # Greedily pack tensors into sharts of at most max_bytes each
    sharts = []
    current_shart = {}
    current_size = 0

    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()

        if current_size + tensor_size > max_bytes and current_shart:
            sharts.append(current_shart)
            current_shart = {}
            current_size = 0

        current_shart[name] = tensor
        current_size += tensor_size

    if current_shart:
        sharts.append(current_shart)

    # Write each shart and record which file every tensor landed in
    weight_map = {}
    for i, shart in enumerate(tqdm(sharts, desc="Saving sharts")):
        shart_name = f"model-{i+1:05d}-of-{len(sharts):05d}.safetensors"
        save_file(shart, os.path.join(output_dir, shart_name))

        for name in shart:
            weight_map[name] = shart_name

    # Index file in the standard safetensors format
    index = {
        "metadata": {"total_size": sum(t.numel() * t.element_size() for t in state_dict.values())},
        "weight_map": weight_map
    }

    with open(os.path.join(output_dir, "model.safetensors.index.json"), 'w') as f:
        json.dump(index, f, indent=2)

    print(f"💩 Successfully sharted into {len(sharts)} files!")
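
# Round-trip sketch (illustrative, not invoked by main()): the index written by
# save_model_sharted() is the same format load_model_sharted() consumes, so a
# saved directory can be reloaded and compared key-for-key:
#   reloaded = load_model_sharted(output_dir)
#   assert set(reloaded) == set(state_dict)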


def verify_architecture(model_path):
    """Verify the model architecture matches expected dimensions."""
    print("\n" + "="*60)
    print("ARCHITECTURE VERIFICATION")
    print("="*60)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        trust_remote_code=True
    )

    expected = {
        "lm_head.weight": (151936, 8192),
        "model.embed_tokens.weight": (151936, 8192),
        "model.layers.0.input_layernorm.weight": (8192,),
        "model.layers.0.mlp.down_proj.weight": (8192, 29568),
        "model.layers.0.mlp.gate_proj.weight": (29568, 8192),
        "model.layers.0.mlp.up_proj.weight": (29568, 8192),
        "model.layers.0.post_attention_layernorm.weight": (8192,),
        "model.layers.0.self_attn.k_norm.weight": (128,),
        "model.layers.0.self_attn.k_proj.weight": (1024, 8192),
        "model.layers.0.self_attn.o_proj.weight": (8192, 8192),
        "model.layers.0.self_attn.q_norm.weight": (128,),
        "model.layers.0.self_attn.q_proj.weight": (8192, 8192),
        "model.layers.0.self_attn.v_proj.weight": (1024, 8192),
        "model.norm.weight": (8192,),
    }

    all_correct = True
    param_dict = dict(model.named_parameters())  # build once, not per check

    for name, expected_shape in expected.items():
        if name in param_dict:
            actual_shape = tuple(param_dict[name].shape)
            if actual_shape == expected_shape:
                print(f"✅ {name}: {actual_shape}")
            else:
                print(f"❌ {name}: {actual_shape} (expected {expected_shape})")
                all_correct = False
        else:
            print(f"❌ {name}: NOT FOUND")
            all_correct = False

    num_layers = model.config.num_hidden_layers
    print(f"\nNumber of layers: {num_layers} (Stage 1 should have 64)")

    if all_correct and num_layers == 64:
        print("\n✅ Architecture verification PASSED!")
    else:
        print("\n❌ Architecture verification FAILED!")

    del model
    return all_correct


def run_diagnostics(model_path):
    """Run comprehensive diagnostics on the upscaled model."""
    print("\n" + "="*60)
    print("COMPREHENSIVE DIAGNOSTICS")
    print("="*60)

    print("\nLoading model for diagnostics...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    print("\n🧪 Generation Quality Tests:")
    test_cases = [
        ("The capital of France is", ["Paris"]),
        ("2 + 2 =", ["4", "four"]),
        ("The quick brown fox", ["jumps", "jumped", "lazy", "dog"]),
        ("Hello, my name is", None),
        ("Water boils at", ["100", "212", "degrees"]),
        ("The Earth orbits the", ["Sun", "solar"]),
        ("Machine learning is a type of", ["artificial intelligence", "AI"]),
        ("Python is a", ["programming", "language", "snake"]),
        ("The largest planet is", ["Jupiter"]),
        ("DNA stands for", ["deoxyribonucleic", "acid"]),
    ]

    device = model.device
    coherent_count = 0
    total_tests = len(test_cases)

    for prompt, expected in test_cases:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_only = generated_text[len(prompt):].strip()

        print(f"\n Prompt: '{prompt}'")
        print(f" Generated: '{generated_only}'")

        is_coherent = True

        # Heuristic 1: flag heavy word repetition
        words = generated_only.split()
        if len(words) > 3:
            if len(set(words)) < len(words) / 2:
                print(" ⚠️ High repetition detected")
                is_coherent = False

        # Heuristic 2: check for expected keywords, if any
        if expected and len(generated_only) > 0:
            found = any(kw.lower() in generated_only.lower() for kw in expected)
            if found:
                print(" ✅ Contains expected content")
            else:
                print(" ⚠️ Missing expected keywords")
                is_coherent = False

        if is_coherent and len(generated_only.split()) >= 2:
            coherent_count += 1

    coherence_rate = (coherent_count / total_tests) * 100
    print(f"\n📊 Overall coherence rate: {coherence_rate:.1f}%")

    # Perplexity on a short reference sentence: exp of the LM loss
    print("\n📊 Perplexity Test:")
    test_text = "The quick brown fox jumps over the lazy dog."
    inputs = tokenizer(test_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs, labels=inputs["input_ids"])
        perplexity = torch.exp(outputs.loss).item()

    print(f" Perplexity: {perplexity:.2f}")

    if perplexity > 100:
        print(" ⚠️ Very high perplexity")
    elif perplexity > 50:
        print(" ⚠️ Moderately high perplexity")
    else:
        print(" ✅ Reasonable perplexity")

    print("\n📊 Weight Statistics (checking for anomalies):")
    anomalies = 0

    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f" ⚠️ {name}: Contains NaN!")
            anomalies += 1
        elif torch.isinf(param).any():
            print(f" ⚠️ {name}: Contains Inf!")
            anomalies += 1
        elif param.std() < 1e-8:
            print(f" ⚠️ {name}: Zero variance!")
            anomalies += 1

    if anomalies == 0:
        print(" ✅ No anomalies detected in weights")

    success = coherence_rate >= 70 and perplexity < 100 and anomalies == 0

    print("\n" + "="*60)
    print("DIAGNOSTIC SUMMARY")
    print("="*60)

    if success:
        print("✅ Model passed all basic diagnostics!")
        print(" - Good coherence rate")
        print(" - Reasonable perplexity")
        print(" - No weight anomalies")
    else:
        print("⚠️ Some issues detected:")
        if coherence_rate < 70:
            print(f" - Low coherence rate: {coherence_rate:.1f}%")
        if perplexity >= 100:
            print(f" - High perplexity: {perplexity:.2f}")
        if anomalies > 0:
            print(f" - Weight anomalies: {anomalies}")

    return success


def main():
    print("="*60)
    print("Stage 1 v2 SHARTED 💩: Multi-GPU Accelerated Interpolation")
    print("Qwen3-32B → 72B Dimensions")
    print(f"Using {NUM_GPUS} GPUs for parallel processing")
    print("FIXED: Correct o_proj dimensions")
    print("="*60)

    source_model_id = "Qwen/Qwen3-32B"

    if torch.cuda.is_available():
        torch.cuda.set_device(0)
        print(f"\n🚀 CUDA available: {torch.cuda.device_count()} devices")
        for i in range(min(NUM_GPUS, torch.cuda.device_count())):
            print(f" GPU {i}: {torch.cuda.get_device_name(i)}")

    print(f"\n📝 Loading tokenizer from: {source_model_id}")
    tokenizer = AutoTokenizer.from_pretrained(source_model_id, trust_remote_code=True)

    print(f"\n⚡ Loading model weights using fast sharted loading...")
    source_weights = load_model_sharted(source_model_id)

    print(f"\n📊 Loaded {len(source_weights)} tensors from sharts")

    # Group tensors by layer so each layer can be dispatched to one GPU
    layer_groups = {}
    other_tensors = []

    for name, tensor in source_weights.items():
        layer_idx, _ = get_layer_info(name)
        if layer_idx is not None:
            if layer_idx not in layer_groups:
                layer_groups[layer_idx] = []
            layer_groups[layer_idx].append((name, tensor))
        else:
            other_tensors.append((name, tensor))

    print(f"\n🔧 Processing tensors across {NUM_GPUS} GPUs...")
    print(" - Parallel layer processing")
    print(" - JIT-compiled operations")
    print(" - Efficient memory management")
    print(" - Sharted weight I/O 💩")

    new_state_dict = {}
    # Round-robin the layers across GPUs, NUM_GPUS at a time, and run each
    # batch in a thread pool so the GPUs actually work in parallel
    with tqdm(total=len(source_weights), desc="Upscaling tensors") as pbar:
        layer_indices = sorted(layer_groups.keys())

        with ThreadPoolExecutor(max_workers=NUM_GPUS) as executor:
            for i in range(0, len(layer_indices), NUM_GPUS):
                batch_futures = []

                for j, layer_idx in enumerate(layer_indices[i:i + NUM_GPUS]):
                    device = f'cuda:{j % NUM_GPUS}'
                    layer_tensors = layer_groups[layer_idx]
                    batch_futures.append(
                        (layer_tensors, executor.submit(process_layer_batch, layer_tensors, device))
                    )

                for layer_tensors, future in batch_futures:
                    new_state_dict.update(future.result())
                    pbar.update(len(layer_tensors))

                torch.cuda.empty_cache()

        # Non-layer tensors (embeddings, lm_head, final norm) on GPU 0
        for name, tensor in other_tensors:
            device = 'cuda:0'
            new_tensor = upscale_tensor_gpu(tensor, name, device=device).cpu()
            new_state_dict[name] = new_tensor
            pbar.update(1)

    del source_weights
    gc.collect()
    torch.cuda.empty_cache()
print("\nπ Creating target model configuration...") |
|
config = AutoConfig.from_pretrained(source_model_id, trust_remote_code=True) |
|
config.hidden_size = TGT_HIDDEN_SIZE |
|
config.intermediate_size = TGT_INTERMEDIATE_SIZE |
|
config.num_attention_heads = TGT_NUM_HEADS |
|
config.torch_dtype = torch.bfloat16 |
|
|
|
|
|
print("\nπ Quick verification of tensor dimensions BEFORE saving:") |
|
|
|
|
|
critical_checks = [ |
|
"model.layers.0.self_attn.q_proj.weight", |
|
"model.layers.0.self_attn.k_proj.weight", |
|
"model.layers.0.self_attn.v_proj.weight", |
|
"model.layers.0.self_attn.o_proj.weight", |
|
"model.layers.0.mlp.gate_proj.weight" |
|
] |
|
|
|
for check_name in critical_checks: |
|
for name, tensor in new_state_dict.items(): |
|
if check_name in name: |
|
print(f" {name}: {tensor.shape}") |
|
break |
|
|
|
|
|
print("\nπ― Verifying ALL o_proj dimensions:") |
|
o_proj_issue = False |
|
for name, tensor in new_state_dict.items(): |
|
if "o_proj.weight" in name: |
|
if tensor.shape != (TGT_HIDDEN_SIZE, TGT_HIDDEN_SIZE): |
|
print(f" β {name}: {tensor.shape} - INCORRECT!") |
|
o_proj_issue = True |
|
else: |
|
if "layer.0" in name or "layer.63" in name: |
|
print(f" β {name}: {tensor.shape}") |
|
|
|
if o_proj_issue: |
|
print("\nβ ERROR: o_proj dimensions are incorrect! Not saving model.") |
|
return False |
|
|
|
|
|
print(f"\nπΎ Saving model to: {OUTPUT_DIR}") |
|
os.makedirs(OUTPUT_DIR, exist_ok=True) |
|
|
|
|
|
config.save_pretrained(OUTPUT_DIR) |
|
tokenizer.save_pretrained(OUTPUT_DIR) |
|
|
|
|
|
save_model_sharted(new_state_dict, OUTPUT_DIR) |
|
|
|
|
|
for file in ['generation_config.json', 'tokenizer_config.json', 'special_tokens_map.json']: |
|
src = os.path.join(source_model_id, file) |
|
dst = os.path.join(OUTPUT_DIR, file) |
|
if os.path.exists(src): |
|
shutil.copy(src, dst) |
|
|
|
|
|
    metadata = {
        "stage": "1-v2-sharted",
        "source_model": source_model_id,
        "method": "gpu_accelerated_structure_aware_interpolation_sharted",
        "num_gpus_used": NUM_GPUS,
        "fixes": [
            "Corrected o_proj dimensions to 8192x8192",
            "Proper handling of GQA architecture"
        ],
        "optimizations": [
            "Multi-GPU parallel processing",
            "JIT-compiled operations",
            "Sharted weight loading/saving 💩",
            "Efficient memory management"
        ],
        "sharting_info": {
            "format": "safetensors",
            "max_shart_size": "5GB",
            "poop_emoji": "💩"
        }
    }

    with open(os.path.join(OUTPUT_DIR, "stage1_v2_metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)

    print("\n✅ Stage 1 v2 SHARTED interpolation complete! 💩")
    print(f"📁 Model saved to: {OUTPUT_DIR}")

    arch_ok = verify_architecture(OUTPUT_DIR)
    diag_ok = run_diagnostics(OUTPUT_DIR)

    if arch_ok and diag_ok:
        print("\n🎉 SUCCESS! Enhanced sharted interpolation completed successfully. 💩")
        print(f"📁 Model saved to: {OUTPUT_DIR}")
        print("\n🚀 Ready for Stage 2: Layer duplication (64→80 layers)")
    else:
        print("\n⚠️ Some issues detected. Review the diagnostics above.")

    return arch_ok and diag_ok


if __name__ == "__main__":
    success = main()
    exit(0 if success else 1)