El-Alberto67 commited on
Commit
e1f5b51
·
verified ·
1 Parent(s): 2be00c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -5
app.py CHANGED
@@ -3,29 +3,47 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
 
4
  MODEL = "prithivMLmods/Llama-SmolTalk-3.2-1B-Instruct"
5
 
 
6
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
7
- model = AutoModelForCausalLM.from_pretrained(MODEL, device_map="auto")
8
 
 
 
 
 
 
 
 
 
9
  chatbot = pipeline(
10
  "text-generation",
11
  model=model,
12
  tokenizer=tokenizer,
13
- device_map="auto",
14
  )
15
 
 
16
  system_prompt = "Tu es Aria, une IA bienveillante et polie qui répond de façon concise et claire."
17
 
 
18
  def chat(message, history=[]):
19
  context = "\n".join([f"Utilisateur: {m[0]}\nAria: {m[1]}" for m in history])
20
  prompt = f"{system_prompt}\n{context}\nUtilisateur: {message}\nAria:"
21
 
22
- resp = chatbot(prompt, max_new_tokens=150, do_sample=True, temperature=0.7)[0]["generated_text"]
23
- reply = resp.split("Aria:")[-1].strip()
 
 
 
 
 
 
 
24
 
 
25
  history.append([message, reply])
26
  return history, history
27
 
28
-
29
  with gr.Blocks() as demo:
30
  chat_ui = gr.Chatbot()
31
  msg = gr.Textbox(placeholder="Écris un message...")
 
3
 
4
  MODEL = "prithivMLmods/Llama-SmolTalk-3.2-1B-Instruct"
5
 
6
+ # Charger le tokenizer
7
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
 
8
 
9
+ # Charger le modèle en 8 bits pour accélérer
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ MODEL,
12
+ device_map="auto",
13
+ load_in_8bit=True # optimisation vitesse / mémoire
14
+ )
15
+
16
+ # Pipeline avec paramètres par défaut optimisés
17
  chatbot = pipeline(
18
  "text-generation",
19
  model=model,
20
  tokenizer=tokenizer,
21
+ device_map="auto"
22
  )
23
 
24
+ # Prompt système
25
  system_prompt = "Tu es Aria, une IA bienveillante et polie qui répond de façon concise et claire."
26
 
27
+ # Fonction de chat optimisée
28
  def chat(message, history=[]):
29
  context = "\n".join([f"Utilisateur: {m[0]}\nAria: {m[1]}" for m in history])
30
  prompt = f"{system_prompt}\n{context}\nUtilisateur: {message}\nAria:"
31
 
32
+ # Paramètres réduits pour accélérer la génération
33
+ resp = chatbot(
34
+ prompt,
35
+ max_new_tokens=60, # Limite pour réduire temps de calcul
36
+ do_sample=True,
37
+ temperature=0.7,
38
+ top_p=0.9,
39
+ repetition_penalty=1.1
40
+ )[0]["generated_text"]
41
 
42
+ reply = resp.split("Aria:")[-1].strip()
43
  history.append([message, reply])
44
  return history, history
45
 
46
+ # Interface Gradio
47
  with gr.Blocks() as demo:
48
  chat_ui = gr.Chatbot()
49
  msg = gr.Textbox(placeholder="Écris un message...")