# Norah / test_norah.py
# Committed by Visdom9: "Pushing fine-tuned Norah model" (3254881)
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
import torch
# Local checkpoint directory. local_files_only=True below means nothing is
# fetched from the Hugging Face Hub, so this path must exist on disk.
model_name = r"C:\Users\HP-Victus\GVAIDAL\Norah"  # Use full path

print("🔄 Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

# float16 halves GPU memory. On CPU it is slow, and some PyTorch builds lack
# half-precision kernels for common ops, so fall back to float32 when no CUDA
# device is available.
model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    local_files_only=True,
    torch_dtype=model_dtype,
    device_map="auto",  # let accelerate decide where the weights are placed
)
def format_prompt(user_input):
    """Build the few-shot French chat prompt for a single user turn.

    The prompt consists of a system instruction (answer in French, briefly and
    clearly), one example exchange, and the user's message. It ends with an
    open "Assistant:" cue so the model continues as the assistant.

    Args:
        user_input: The user's message. It is interpolated with an f-string, so
            any value with a string form is accepted.

    Returns:
        The complete prompt string.
    """
    prompt_lines = [
        "Tu es un assistant IA utile et intelligent qui répond toujours en français avec des réponses courtes et claires.",
        "",  # blank line after the system instruction
        "Utilisateur: Bonjour, comment vas-tu ?",
        "Assistant: Bonjour ! Je vais bien, merci. Comment puis-je vous aider ?",
        "",  # blank line after the example exchange
        f"Utilisateur: {user_input}",
        "Assistant:",  # no trailing newline: the model writes the reply here
    ]
    return "\n".join(prompt_lines)
# Test conversation
prompt = format_prompt("Bonjour, comment puis-je vous aider aujourd'hui ?")

# With device_map="auto" the weights are placed by accelerate. model.device is
# the device of the model's first parameters, which is where generate() needs
# the input ids. This avoids guessing "cuda"/"cpu" independently of that placement.
# NOTE(review): truncation keeps the first 512 tokens (right-side truncation by
# default), so a very long user_input would cut off the trailing "Assistant:" cue.
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=512,
).to(model.device)
print("📝 Generating response...")
# Token-id sequences for names the model should not produce, in the format
# expected by generate()'s bad_words_ids argument.
# Many subword tokenizers encode a word at the start of a text ("Jean")
# differently from the same word after a space (" Jean"). Generated names almost
# always follow a space, so both spellings are encoded. Duplicate sequences (when
# a tokenizer yields the same ids for both) and empty encodings (not valid
# bad-word entries) are skipped.
bad_words = ["Jean", "Marie", "Bouchard", "Pierre", "Louis", "Antoine", "Jacques", "Robert", "Roper"]
bad_words_ids = []
for word in bad_words:
    for variant in (word, " " + word):
        variant_ids = tokenizer.encode(variant, add_special_tokens=False)
        if variant_ids and variant_ids not in bad_words_ids:
            bad_words_ids.append(variant_ids)
# Stopping criteria: stop at sentence completion
class StopOnSentenceEnd(StoppingCriteria):
    """Stop generation once the newest token ends with sentence punctuation.

    The last token is decoded and its text is checked, rather than comparing
    token ids. With some tokenizers, encoding "." on its own yields a different
    id (or several ids) than a "." emitted in the middle of generated text, so
    an id comparison against encode(".")[0] might never match.

    Args:
        sentence_enders: Characters that mark the end of a sentence.
        tok: Tokenizer used to decode the last token. Defaults to the
            module-level ``tokenizer``.
    """

    def __init__(self, sentence_enders=(".", "!", "?"), tok=None):
        super().__init__()
        self.sentence_enders = tuple(sentence_enders)
        self.tok = tok if tok is not None else tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Only the first sequence is inspected, i.e. this assumes batch size 1
        # (as in this script).
        last_token_text = self.tok.decode(input_ids[0, -1:], skip_special_tokens=True)
        # rstrip so tokens carrying trailing whitespace (e.g. "?\n") still count.
        return last_token_text.rstrip().endswith(self.sentence_enders)


stopping_criteria = StoppingCriteriaList([StopOnSentenceEnd()])
# Generate response
# Give generate() an explicit pad id. When the tokenizer defines no pad token,
# fall back to its EOS token instead of letting generate() warn and guess.
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
    pad_token_id = tokenizer.eos_token_id

outputs = model.generate(
    **inputs,
    # max_length/min_length would count the prompt tokens too, and the few-shot
    # prompt alone is dozens of tokens. Budget the *generated* tokens instead.
    max_new_tokens=100,  # Upper bound on the reply length
    min_new_tokens=10,  # Suppresses EOS until at least 10 new tokens exist
    do_sample=True,  # Sample instead of greedy decoding, for varied responses
    temperature=0.7,  # < 1 sharpens the distribution: less random sampling
    top_p=0.9,  # Nucleus sampling: smallest token set with 90% cumulative probability
    repetition_penalty=1.5,  # Penalises tokens already present in the sequence
    bad_words_ids=bad_words_ids,  # Block the names encoded above
    eos_token_id=model.config.eos_token_id,
    pad_token_id=pad_token_id,
    stopping_criteria=stopping_criteria,  # Stop at the first sentence end
)
# Decode and display response
# For a causal LM, outputs[0] holds the prompt followed by the continuation.
# Skip the prompt tokens so only the model's reply is printed.
prompt_length = inputs["input_ids"].shape[-1]
response = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True).strip()
print("💬 Model Response:", response)