AIdeaText commited on
Commit
6a9fc93
1 Parent(s): eea761e

Update modules/chatbot.py

Browse files
Files changed (1) hide show
  1. modules/chatbot.py +56 -22
modules/chatbot.py CHANGED
@@ -1,25 +1,59 @@
1
- import requests
2
- import os
3
- from dotenv import load_dotenv
4
-
5
- # Cargar variables de entorno
6
- load_dotenv()
7
-
8
- class Llama2Chatbot:
9
- def __init__(self):
10
- self.API_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-7b-hf"
11
- api_key = os.getenv("HF_API_KEY")
12
- if not api_key:
13
- raise ValueError("No se encontró la clave de API de Hugging Face. Asegúrate de configurar la variable de entorno HF_API_KEY.")
14
- self.headers = {"Authorization": f"Bearer {api_key}"}
15
-
16
- def generate_response(self, prompt):
17
- payload = {"inputs": prompt}
18
- response = requests.post(self.API_URL, headers=self.headers, json=payload)
19
- return response.json()[0]['generated_text']
20
 
21
  def initialize_chatbot():
22
- return Llama2Chatbot()
 
 
23
 
24
- def get_chatbot_response(chatbot, prompt):
25
- return chatbot.generate_response(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
2
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def initialize_chatbot():
5
+ model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
6
+ tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
7
+ return model, tokenizer
8
 
9
+ def get_chatbot_response(model, tokenizer, prompt, src_lang):
10
+ tokenizer.src_lang = src_lang
11
+ encoded_input = tokenizer(prompt, return_tensors="pt")
12
+ generated_tokens = model.generate(**encoded_input, max_length=100)
13
+ return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
14
+
15
+ def display_chatbot_interface(lang_code):
16
+ translations = {
17
+ 'es': {
18
+ 'title': "AIdeaText - Chatbot Multilingüe",
19
+ 'input_placeholder': "Escribe tu mensaje aquí...",
20
+ 'send_button': "Enviar",
21
+ },
22
+ 'en': {
23
+ 'title': "AIdeaText - Multilingual Chatbot",
24
+ 'input_placeholder': "Type your message here...",
25
+ 'send_button': "Send",
26
+ },
27
+ 'fr': {
28
+ 'title': "AIdeaText - Chatbot Multilingue",
29
+ 'input_placeholder': "Écrivez votre message ici...",
30
+ 'send_button': "Envoyer",
31
+ }
32
+ }
33
+
34
+ t = translations[lang_code]
35
+
36
+ st.header(t['title'])
37
+
38
+ if 'chatbot' not in st.session_state:
39
+ st.session_state.chatbot, st.session_state.tokenizer = initialize_chatbot()
40
+
41
+ if 'messages' not in st.session_state:
42
+ st.session_state.messages = []
43
+
44
+ for message in st.session_state.messages:
45
+ with st.chat_message(message["role"]):
46
+ st.markdown(message["content"])
47
+
48
+ if prompt := st.chat_input(t['input_placeholder']):
49
+ st.session_state.messages.append({"role": "user", "content": prompt})
50
+ with st.chat_message("user"):
51
+ st.markdown(prompt)
52
+
53
+ with st.chat_message("assistant"):
54
+ response = get_chatbot_response(st.session_state.chatbot, st.session_state.tokenizer, prompt, lang_code)
55
+ st.markdown(response)
56
+ st.session_state.messages.append({"role": "assistant", "content": response})
57
+
58
+ # Guardar la conversación en la base de datos
59
+ store_chat_history(st.session_state.username, st.session_state.messages)