#!/usr/bin/env python3
"""
Test script to validate Mistral-7B-Instruct AWQ model response generation
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
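
# NOTE (assumption, not stated in the original script): loading the AWQ
# checkpoint through transformers generally requires the `autoawq` package to
# be installed. If it is missing, the try/except below is expected to take the
# full-precision fallback path rather than crash.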
def test_model():
    print("Loading Mistral-7B-Instruct AWQ for testing...")

    # Try the AWQ model first; fall back to the full-precision model if needed
    try:
        tokenizer = AutoTokenizer.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-AWQ")
        model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/Mistral-7B-Instruct-v0.2-AWQ",
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        )
        print("✅ AWQ model loaded successfully!")
    except Exception as e:
        print(f"⚠️ AWQ model failed to load: {e}")
        print("📦 Falling back to regular Mistral-7B-Instruct-v0.2...")
        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.2",
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        )
        print("✅ Regular model loaded successfully!")

    # Mistral's tokenizer ships without a pad token; reuse EOS for padding
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Test conversation
    test_messages = [
        "I feel sad today",
        "What should I do?",
        "Hello"
    ]

    for i, message in enumerate(test_messages):
        print(f"\n--- Test {i + 1}: '{message}' ---")

        # Wrap the message in the Mistral chat template format
        messages = [
            {"role": "user", "content": message}
        ]

        # Apply the chat template to build the prompt string
        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # The template already inserts special tokens (e.g. BOS), so don't add
        # them a second time when tokenizing; this also gives us an attention
        # mask to pass to generate()
        inputs = tokenizer(conversation, return_tensors="pt", add_special_tokens=False).to(model.device)

        # Generate a response with sampling settings tuned for Mistral
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=100,
                no_repeat_ngram_size=2,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                temperature=0.9,
                top_k=50,
                top_p=0.9,
                use_cache=True
            )

        # Decode only the newly generated tokens, skipping the prompt
        response = tokenizer.decode(
            output_ids[0, inputs.input_ids.shape[-1]:],
            skip_special_tokens=True
        ).strip()

        print(f"Raw response: '{response}'")
        print(f"Response length: {len(response)} characters")

        if len(response) > 1:
            print("✅ Good response generated")
        else:
            print("⚠️ Short/empty response")

    print("\n✅ Model testing complete!")
if __name__ == "__main__":
    test_model()