|
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
|
|
|
# Local checkpoint directory. Set NORAH_MODEL_PATH to run on another machine
# without editing the script.
model_name = os.environ.get("NORAH_MODEL_PATH", r"C:\Users\HP-Victus\GVAIDAL\Norah")

print("🔄 Loading tokenizer and model...")

tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

# Half precision is only dependable on GPU; many CPU kernels have no float16
# implementation, so fall back to float32 when CUDA is not available.
model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    local_files_only=True,
    torch_dtype=model_dtype,
    device_map="auto",
)
|
|
|
def format_prompt(user_input):
    """Build the French chat prompt for *user_input*.

    The prompt has three sections separated by blank lines:
    - a system instruction;
    - one example user/assistant exchange;
    - the user's message followed by an open ``Assistant:`` turn for the
      model to complete.
    """
    sections = [
        "Tu es un assistant IA utile et intelligent qui répond toujours en français avec des réponses courtes et claires.",
        "Utilisateur: Bonjour, comment vas-tu ?\n"
        "Assistant: Bonjour ! Je vais bien, merci. Comment puis-je vous aider ?",
        f"Utilisateur: {user_input}\nAssistant:",
    ]
    return "\n\n".join(sections)
|
|
|
|
|
prompt = format_prompt("Bonjour, comment puis-je vous aider aujourd'hui ?")

# If the prompt ever exceeds max_length, drop tokens from the start rather than
# the end: the trailing "Assistant:" cue is what the model continues from.
tokenizer.truncation_side = "left"

# Put the inputs on the same device as the model's parameters. With
# device_map="auto" the loader decides the placement, so this is more reliable
# than guessing "cuda" vs "cpu" independently of the model.
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=512,
).to(model.device)
|
|
|
print("📝 Generating response...")

# Names the model must not produce in its reply.
bad_words = ["Jean", "Marie", "Bouchard", "Pierre", "Louis", "Antoine", "Jacques", "Robert", "Roper"]


def _encode_bad_words(tok, words):
    """Return deduplicated token-id sequences for *words*.

    Each word is encoded both bare and with a leading space. In many
    vocabularies (e.g. BPE) a word at the start of the text and the same word
    after a space map to different tokens, and only banning one spelling lets
    the other through. Empty encodings are skipped.
    """
    sequences = []
    for word in words:
        for variant in (word, " " + word):
            ids = tok.encode(variant, add_special_tokens=False)
            if ids and ids not in sequences:
                sequences.append(ids)
    return sequences


bad_words_ids = _encode_bad_words(tokenizer, bad_words)
|
|
|
|
|
class StopOnSentenceEnd(StoppingCriteria):
    """Stop generation once the newest token is a sentence terminator.

    The terminator token ids are computed once, at construction, instead of
    being re-encoded on every decoding step.

    Args:
        tok: tokenizer used to encode the terminators; defaults to the
            module-level ``tokenizer``.
        terminators: strings whose (final) token marks the end of a sentence.
    """

    def __init__(self, tok=None, terminators=(".", "!", "?")):
        super().__init__()
        tok = tokenizer if tok is None else tok
        encodings = (tok.encode(t, add_special_tokens=False) for t in terminators)
        # Keep the LAST id of each encoding. Some tokenizers (e.g. SentencePiece)
        # may emit a word-boundary piece before the punctuation, and taking the
        # first id would then stop generation on an ordinary space.
        self.stop_ids = [ids[-1] for ids in encodings if ids]

    def __call__(self, input_ids, scores, **kwargs):
        """Return one flag per batch row: True where that row's newest token is a terminator."""
        stop = torch.tensor(self.stop_ids, dtype=input_ids.dtype, device=input_ids.device)
        return (input_ids[:, -1:] == stop).any(dim=-1)
|
|
|
stopping_criteria = StoppingCriteriaList([StopOnSentenceEnd()])

# max_length/min_length count the prompt tokens too. The prompt (instruction
# plus an example exchange) already uses much of a 100-token budget, and a
# minimum of 10 total tokens is always met. The *_new_tokens variants bound
# only the generated reply.
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    min_new_tokens=10,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.5,
    eos_token_id=model.config.eos_token_id,
    # An explicit pad id avoids generate()'s fallback warning when the
    # tokenizer defines none.
    pad_token_id=(
        tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    ),
    # Actually apply the banned-name list; it was built earlier but never used.
    bad_words_ids=bad_words_ids or None,
    stopping_criteria=stopping_criteria,
)
|
|
|
|
|
# generate() returns the prompt followed by the continuation. Decode only the
# newly generated tokens so the printed reply does not echo the instruction
# and the example dialogue.
prompt_length = inputs["input_ids"].shape[-1]
response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

print("💬 Model Response:", response)
|
|