import os
import json
import torch
import argparse
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)


def debug_weights_structure(weights_path):
    """Examine the structure of the weights file to help debug loading issues."""
    # Note: on PyTorch >= 2.6, torch.load defaults to weights_only=True, which can
    # reject pickled (non-tensor) checkpoints; pass weights_only=False if needed
    weights = torch.load(weights_path, map_location="cpu")
    logger.info(f"Type of loaded weights: {type(weights)}")
    if isinstance(weights, dict):
        logger.info(f"Top-level keys: {list(weights.keys())}")
        # Print a few sample keys to understand the structure
        sample_keys = list(weights.keys())[:5]
        for key in sample_keys:
            logger.info(f"Sample key structure: {key} -> {type(weights[key])}")
    return weights


def main():
    parser = argparse.ArgumentParser(description="Run inference with a TEQ-quantized model")
    parser.add_argument("--model_dir", type=str, default=".",
                        help="Directory containing quantized model files")
    parser.add_argument("--weights_file", type=str, default="quantized_weight.pt",
                        help="Name of the quantized weights file")
    parser.add_argument("--config_file", type=str, default="qconfig.json",
                        help="Name of the quantization config file")
    parser.add_argument("--base_model", type=str, required=True,
                        help="Original model name or path (for tokenizer and model architecture)")
    parser.add_argument("--prompt", type=str, default="Once upon a time, a little girl",
                        help="Text prompt for inference")
    parser.add_argument("--max_new_tokens", type=int, default=100,
                        help="Maximum number of new tokens to generate")
    parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda", "xpu"],
                        help="Device to run inference on")
    parser.add_argument("--output_file", type=str, default=None,
                        help="File to save the generated text to (optional)")
    parser.add_argument("--debug", action="store_true",
                        help="Print additional debug information")
    args = parser.parse_args()

    # Set up paths
    weights_path = os.path.join(args.model_dir, args.weights_file)
    config_path = os.path.join(args.model_dir, args.config_file)

    # Check that the expected files exist
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Quantized weights file not found: {weights_path}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Quantization config file not found: {config_path}")

    # Load tokenizer
    logger.info(f"Loading tokenizer from {args.base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
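    # Not in the original script: many causal LMs ship without a pad token, which
    # makes generate() warn; falling back to the EOS token is the usual remedy
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token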

    # Examine the structure of the weights file
    logger.info(f"Analyzing weights structure from {weights_path}...")
    weights = debug_weights_structure(weights_path)

    # Load the base model directly (bypassing TEQ quantization)
    logger.info(f"Loading base model from {args.base_model}...")
    model = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True)

    # Print the model's state_dict keys for debugging
    if args.debug:
        model_keys = list(model.state_dict().keys())
        logger.info(f"Model has {len(model_keys)} keys in state_dict")
        logger.info(f"Sample model keys: {model_keys[:5]}")

    # If the checkpoint nests the tensors under a 'state_dict' key, unwrap it
    # (the isinstance guard avoids a TypeError when the file holds a whole module)
    if isinstance(weights, dict) and "state_dict" in weights:
        logger.info("Found 'state_dict' key in weights file, extracting it...")
        weights = weights["state_dict"]

    # Try to match the weights to the model structure
    try:
        # First attempt: load the checkpoint keys as-is
        logger.info("Attempting to load weights directly...")
        missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
        if missing_keys:
            logger.warning(f"Missing {len(missing_keys)} keys in state_dict")
            if args.debug:
                logger.warning(f"Sample missing keys: {missing_keys[:5]}")
        if unexpected_keys:
            logger.warning(f"Found {len(unexpected_keys)} unexpected keys in state_dict")
            if args.debug:
                logger.warning(f"Sample unexpected keys: {unexpected_keys[:5]}")
        # If more than half of the model's keys are missing, loading almost
        # certainly failed
        if len(missing_keys) > len(model.state_dict()) * 0.5:
            logger.error("Too many missing keys! Weight loading may have failed")
    except Exception as e:
        logger.error(f"Error loading weights: {e}")
        logger.info("Attempting to transform keys to match model structure...")
        # Strip the 'module.' prefix that nn.DataParallel/DistributedDataParallel
        # prepends to every parameter name
        transformed_weights = {}
        for key, value in weights.items():
            if key.startswith("module."):
                transformed_weights[key[len("module."):]] = value
            else:
                transformed_weights[key] = value
        # Try loading the transformed weights
        missing_keys, unexpected_keys = model.load_state_dict(transformed_weights, strict=False)
        logger.info(f"After transformation: {len(missing_keys)} missing keys, "
                    f"{len(unexpected_keys)} unexpected keys")

    # Put model in evaluation mode
    model.eval()

    device = args.device

    # On Intel GPUs, intel_extension_for_pytorch generally has to be imported
    # *before* anything is moved to the 'xpu' device (unless the PyTorch build has
    # native XPU support), so attempt the import before calling model.to()
    ipex = None
    if device == "xpu":
        try:
            import intel_extension_for_pytorch as ipex
        except ImportError:
            logger.warning("IPEX not available, skipping optimization")

    # Move model to the specified device
    logger.info(f"Moving model to {device}...")
    model = model.to(device)

    # Optimize with IPEX if it was importable
    if ipex is not None:
        logger.info("Optimizing model with IPEX...")
        model = ipex.optimize(model, dtype=torch.float16)

    # Run inference
    logger.info(f"Generating text for prompt: '{args.prompt}'")
    inputs = tokenizer(args.prompt, return_tensors="pt").to(device)

    # Generate text; unpacking the full encoding passes the attention mask along
    # with input_ids, which avoids a warning from generate()
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode the generated tokens
    generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    logger.info("\nGenerated text:")
    logger.info("-" * 50)
    logger.info(generated_text)
    logger.info("-" * 50)

    # Save to file if specified
    if args.output_file:
        with open(args.output_file, "w") as f:
            f.write(generated_text)
        logger.info(f"Generated text saved to {args.output_file}")


if __name__ == "__main__":
    main()
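
# Example invocation (script name, model, and paths are illustrative only):
#   python run_teq_inference.py \
#       --base_model facebook/opt-125m \
#       --model_dir ./teq_output \
#       --prompt "Once upon a time, a little girl" \
#       --max_new_tokens 100 \
#       --debug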