dofbi commited on
Commit
af1cf71
·
1 Parent(s): 9a04ccb
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -1,23 +1,35 @@
1
  from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import gradio as gr
 
 
 
 
 
3
 
4
  # Charger le modèle
5
  model_name = "soynade-research/Oolel-v0.1"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
8
 
9
  # Fonction pour générer une réponse
10
  def generate_response(user_input, max_new_tokens=150, temperature=0.7):
11
- inputs = tokenizer(user_input, return_tensors="pt").to("cuda")
 
 
12
  outputs = model.generate(inputs.input_ids, max_new_tokens=max_new_tokens, temperature=temperature)
 
13
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
14
 
15
  # Interface Gradio
16
  iface = gr.Interface(
17
  fn=generate_response,
18
- inputs=[gr.Textbox(label="Message utilisateur"), gr.Slider(50, 500, value=150, label="Nombre max de tokens")],
 
 
 
19
  outputs="text",
20
  title="Oolel Chatbot"
21
  )
22
 
 
23
  iface.launch()
 
1
  from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import gradio as gr
3
+ import torch
4
+
5
+ # Vérifier si CUDA est disponible et configurer le périphérique
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ print(f"Utilisation du périphérique : {device}")
8
 
9
  # Charger le modèle
10
  model_name = "soynade-research/Oolel-v0.1"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
13
 
14
  # Fonction pour générer une réponse
15
  def generate_response(user_input, max_new_tokens=150, temperature=0.7):
16
+ # Préparer l'entrée pour le modèle
17
+ inputs = tokenizer(user_input, return_tensors="pt").to(device)
18
+ # Générer une réponse
19
  outputs = model.generate(inputs.input_ids, max_new_tokens=max_new_tokens, temperature=temperature)
20
+ # Décoder la réponse en texte
21
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
22
 
23
  # Interface Gradio
24
  iface = gr.Interface(
25
  fn=generate_response,
26
+ inputs=[
27
+ gr.Textbox(label="Message utilisateur"),
28
+ gr.Slider(50, 500, value=150, label="Nombre max de tokens")
29
+ ],
30
  outputs="text",
31
  title="Oolel Chatbot"
32
  )
33
 
34
+ # Lancer l'interface
35
  iface.launch()