Yhhxhfh committed on
Commit
f4907db
1 Parent(s): 77c2378

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -128
app.py CHANGED
@@ -1,45 +1,50 @@
1
  import os
2
- import logging
3
- import asyncio
4
- import uvicorn
5
  import torch
6
- from transformers import AutoModelForCausalLM, AutoTokenizer
7
- from fastapi import FastAPI, Query, HTTPException
8
  from fastapi.responses import HTMLResponse
 
 
 
9
 
10
- # Configuración de logging
11
- logging.basicConfig(level=logging.DEBUG)
12
- logger = logging.getLogger(__name__)
13
 
14
- # Inicializar la aplicación FastAPI
15
- app = FastAPI()
 
16
 
17
- # Diccionario para almacenar los modelos
18
- data_and_models_dict = {}
19
 
20
- # Lista para almacenar el historial de mensajes
21
- message_history = []
22
 
23
# Load the first available GPT-2 checkpoint, trying candidates in order.
async def load_models():
    """Return a ``(model, tokenizer)`` pair for the first GPT-2 checkpoint that loads.

    Candidates are tried in preference order (medium, large, base). A failure
    to load one candidate is logged and the next one is tried.

    Raises:
        HTTPException: 500 when none of the candidates could be loaded.
    """
    gpt_models = ["gpt2-medium", "gpt2-large", "gpt2"]
    for model_name in gpt_models:
        try:
            model = AutoModelForCausalLM.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            logger.info(f"Successfully loaded {model_name} model")
            return model, tokenizer
        except Exception as e:
            # Name the failing checkpoint so the log is actionable, then
            # fall through to the next candidate instead of aborting.
            logger.error(f"Failed to load GPT-2 model {model_name}: {e}")
    raise HTTPException(status_code=500, detail="Failed to load any models")
35
-
36
# Fetch a working model/tokenizer pair and cache it for the request handlers.
async def download_models():
    """Load a GPT-2 model and store the pair under the 'model' key."""
    pair = await load_models()
    data_and_models_dict['model'] = pair
 
 
 
 
 
 
40
 
41
  @app.get('/')
42
- async def main():
43
  html_code = """
44
  <!DOCTYPE html>
45
  <html lang="en">
@@ -48,151 +53,191 @@ async def main():
48
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
49
  <title>ChatGPT Chatbot</title>
50
  <style>
51
- body, html {
52
- height: 100%;
53
  margin: 0;
54
  padding: 0;
55
- font-family: Arial, sans-serif;
56
  }
57
  .container {
58
- height: 100%;
59
- display: flex;
60
- flex-direction: column;
61
- justify-content: center;
62
- align-items: center;
63
  }
64
  .chat-container {
65
- border-radius: 10px;
66
- overflow: hidden;
67
  box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
68
- width: 100%;
69
- height: 100%;
 
70
  }
71
  .chat-box {
72
- height: calc(100% - 60px);
73
  overflow-y: auto;
74
  padding: 10px;
75
  }
76
  .chat-input {
77
- width: calc(100% - 100px);
78
- padding: 10px;
79
  border: none;
80
- border-top: 1px solid #ccc;
 
81
  font-size: 16px;
 
82
  }
83
- .input-container {
84
- display: flex;
85
- align-items: center;
86
- justify-content: space-between;
87
- padding: 10px;
88
- background-color: #f5f5f5;
89
- border-top: 1px solid #ccc;
90
- width: 100%;
91
  }
92
- button {
93
- padding: 10px;
94
- border: none;
95
- cursor: pointer;
96
  background-color: #007bff;
97
  color: #fff;
98
- font-size: 16px;
99
- }
100
- .user-message {
101
- background-color: #cce5ff;
102
- border-radius: 5px;
103
- align-self: flex-end;
104
  max-width: 70%;
105
- margin-left: auto;
106
- margin-right: 10px;
107
- margin-bottom: 10px;
108
  }
109
  .bot-message {
110
- background-color: #d1ecf1;
111
- border-radius: 5px;
112
- align-self: flex-start;
 
 
113
  max-width: 70%;
 
 
 
 
 
 
114
  margin-bottom: 10px;
115
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  </style>
117
  </head>
118
  <body>
119
  <div class="container">
120
- <div class="chat-container">
 
121
  <div class="chat-box" id="chat-box"></div>
122
- <div class="input-container">
123
- <input type="text" class="chat-input" id="user-input" placeholder="Escribe un mensaje...">
124
- <button onclick="sendMessage()">Enviar</button>
125
- </div>
 
 
 
126
  </div>
127
  </div>
128
  <script>
129
- const chatBox = document.getElementById('chat-box');
130
- const userInput = document.getElementById('user-input');
 
 
131
 
132
  function saveMessage(sender, message) {
 
133
  const messageElement = document.createElement('div');
134
- messageElement.textContent = `${sender}: ${message}`;
135
- messageElement.classList.add(`${sender}-message`);
136
- chatBox.appendChild(messageElement);
137
- userInput.value = '';
138
  }
139
 
140
- async function sendMessage() {
141
- const userMessage = userInput.value.trim();
142
- if (!userMessage) return;
143
-
144
- saveMessage('user', userMessage);
145
- await fetch(`/autocomplete?q=${userMessage}`)
146
- .then(response => response.text())
147
- .then(data => {
148
- saveMessage('bot', data);
149
- chatBox.scrollTop = chatBox.scrollHeight;
150
- })
151
- .catch(error => console.error('Error:', error));
152
  }
153
 
154
- userInput.addEventListener("keyup", function(event) {
 
 
155
  if (event.keyCode === 13) {
156
  event.preventDefault();
157
  sendMessage();
158
  }
159
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  </script>
161
  </body>
162
  </html>
163
  """
164
  return HTMLResponse(content=html_code, status_code=200)
165
 
166
# Route for response generation.
@app.get('/autocomplete')
async def autocomplete(q: str = Query(...)):
    """Generate a single model reply for the query *q* and return it as text."""
    global data_and_models_dict, message_history

    # Lazily load the model on the first request.
    if 'model' not in data_and_models_dict:
        await download_models()

    model, tokenizer = data_and_models_dict['model']

    # Record the user's message in the history.
    message_history.append(q)

    # Generate a reply with the model.
    input_ids = tokenizer.encode(q, return_tensors="pt")
    output = model.generate(input_ids, max_length=50, num_return_sequences=1)
    response_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Record the reply in the history as well.
    message_history.append(response_text)

    return response_text
190
-
191
# Run the application without auto-reload.
def run_app():
    """Pre-load the model, then serve the app on port 7860."""
    asyncio.run(download_models())
    uvicorn.run(app, host='0.0.0.0', port=7860)

# Script entry point.
if __name__ == "__main__":
    run_app()
 
1
import os
import sys

import torch
import uvicorn
from fastapi import FastAPI, Query
from fastapi.responses import HTMLResponse
from starlette.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, pipeline, GPT2LMHeadModel
from loguru import logger

sys.path.append('..')

# Use finetuned GPT model: xfa.txt (next to this file) lists one
# checkpoint name per line.
current_dir = os.path.dirname(os.path.realpath(__file__))
text_file_path = os.path.join(current_dir, 'xfa.txt')

# Explicit encoding avoids platform-dependent defaults; skip blank lines so
# an empty line in the file does not become an empty model name.
with open(text_file_path, 'r', encoding='utf-8') as file:
    model_names = [line.strip() for line in file if line.strip()]

# Maps checkpoint name -> {'model': ..., 'tokenizer': ...}.
models_dict = {}

# Load every listed model; a failure for one name is logged and skipped so
# the remaining models still come up.
for name in model_names:
    try:
        model = GPT2LMHeadModel.from_pretrained(name)
        tokenizer = AutoTokenizer.from_pretrained(name)
        models_dict[name] = {
            'model': model,
            'tokenizer': tokenizer
        }
    except Exception as e:
        logger.error(f"Error loading model {name}: {e}")

app = FastAPI()
# Fully open CORS: any origin, method, and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)

# Global variable to store the messages history as (sender, text) tuples.
# NOTE(review): grows without bound for the process lifetime.
message_history = []
45
 
46
  @app.get('/')
47
+ async def index():
48
  html_code = """
49
  <!DOCTYPE html>
50
  <html lang="en">
 
53
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
54
  <title>ChatGPT Chatbot</title>
55
  <style>
56
+ body {
57
+ font-family: Arial, sans-serif;
58
  margin: 0;
59
  padding: 0;
60
+ background-color: #f4f4f4;
61
  }
62
  .container {
63
+ max-width: 800px;
64
+ margin: auto;
65
+ padding: 20px;
 
 
66
  }
67
  .chat-container {
68
+ background-color: #fff;
69
+ border-radius: 8px;
70
  box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
71
+ overflow: hidden;
72
+ margin-bottom: 20px;
73
+ animation: fadeInUp 0.5s ease forwards;
74
  }
75
  .chat-box {
76
+ height: 300px;
77
  overflow-y: auto;
78
  padding: 10px;
79
  }
80
  .chat-input {
81
+ width: calc(100% - 20px);
 
82
  border: none;
83
+ border-top: 1px solid #ddd;
84
+ padding: 10px;
85
  font-size: 16px;
86
+ outline: none;
87
  }
88
+ .chat-input:focus {
89
+ border-top: 1px solid #007bff;
 
 
 
 
 
 
90
  }
91
+ .user-message {
92
+ margin-bottom: 10px;
93
+ padding: 8px 12px;
94
+ border-radius: 8px;
95
  background-color: #007bff;
96
  color: #fff;
 
 
 
 
 
 
97
  max-width: 70%;
98
+ word-wrap: break-word;
99
+ align-self: flex-end;
 
100
  }
101
  .bot-message {
102
+ margin-bottom: 10px;
103
+ padding: 8px 12px;
104
+ border-radius: 8px;
105
+ background-color: #4CAF50;
106
+ color: #fff;
107
  max-width: 70%;
108
+ word-wrap: break-word;
109
+ }
110
+ .toggle-history {
111
+ text-align: center;
112
+ cursor: pointer;
113
+ color: #007bff;
114
  margin-bottom: 10px;
115
  }
116
+ .history-container {
117
+ display: none;
118
+ }
119
+ .history-container.show {
120
+ display: block;
121
+ }
122
+ .history-container .history-content {
123
+ max-height: 200px;
124
+ overflow-y: auto;
125
+ }
126
+ @keyframes fadeInUp {
127
+ from {
128
+ opacity: 0;
129
+ transform: translateY(20px);
130
+ }
131
+ to {
132
+ opacity: 1;
133
+ transform: translateY(0);
134
+ }
135
+ }
136
  </style>
137
  </head>
138
  <body>
139
  <div class="container">
140
+ <h1 style="text-align: center;">ChatGPT Chatbot</h1>
141
+ <div class="chat-container" id="chat-container">
142
  <div class="chat-box" id="chat-box"></div>
143
+ <input type="text" class="chat-input" id="user-input" placeholder="Type your message...">
144
+ <button onclick="retryLastMessage()">Retry Last Message</button>
145
+ </div>
146
+ <div class="toggle-history" onclick="toggleHistory()">Toggle History</div>
147
+ <div class="history-container" id="history-container">
148
+ <h2>Chat History</h2>
149
+ <div class="history-content" id="history-content"></div>
150
  </div>
151
  </div>
152
  <script>
153
+ function toggleHistory() {
154
+ const historyContainer = document.getElementById('history-container');
155
+ historyContainer.classList.toggle('show');
156
+ }
157
 
158
  function saveMessage(sender, message) {
159
+ const historyContent = document.getElementById('history-content');
160
  const messageElement = document.createElement('div');
161
+ messageElement.className = `${sender}-message`;
162
+ messageElement.innerText = message;
163
+ historyContent.appendChild(messageElement);
 
164
  }
165
 
166
+ function appendMessage(sender, message) {
167
+ const chatBox = document.getElementById('chat-box');
168
+ const messageElement = document.createElement('div');
169
+ messageElement.className = `${sender}-message`;
170
+ messageElement.innerText = message;
171
+ chatBox.appendChild(messageElement);
 
 
 
 
 
 
172
  }
173
 
174
+ const userInput = document.getElementById('user-input');
175
+
176
+ userInput.addEventListener('keyup', function(event) {
177
  if (event.keyCode === 13) {
178
  event.preventDefault();
179
  sendMessage();
180
  }
181
  });
182
+
183
+ function sendMessage() {
184
+ const userMessage = userInput.value.trim();
185
+ if (userMessage === '') return;
186
+
187
+ saveMessage('user', userMessage);
188
+ appendMessage('user', userMessage);
189
+ userInput.value = '';
190
+
191
+ fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`)
192
+ .then(response => response.json())
193
+ .then(data => {
194
+ const botMessages = data.result;
195
+ botMessages.forEach(message => {
196
+ saveMessage('bot', message);
197
+ appendMessage('bot', message);
198
+ });
199
+ })
200
+ .catch(error => {
201
+ console.error('Error:', error);
202
+ });
203
+ }
204
+
205
+ function retryLastMessage() {
206
+ const lastUserMessage = document.querySelector('.user-message:last-of-type');
207
+ if (lastUserMessage) {
208
+ userInput.value = lastUserMessage.innerText;
209
+ }
210
+ }
211
  </script>
212
  </body>
213
  </html>
214
  """
215
  return HTMLResponse(content=html_code, status_code=200)
216
 
 
217
@app.get('/autocomplete')
async def autocomplete(q: str = Query(..., title='query')):
    """Generate candidate replies for *q* with every loaded model.

    Returns:
        dict: ``{"result": [reply]}`` with the single candidate whose length
        is closest to the query's, or ``{"result": []}`` when generation
        produced nothing or failed (previously the error path implicitly
        returned ``None``, which broke the client's ``data.result`` access).
    """
    global message_history
    message_history.append(('user', q))

    try:
        # Collect candidates from every loaded model.
        generated_responses = []
        for model_name, model_info in models_dict.items():
            model = model_info['model']
            tokenizer = model_info['tokenizer']
            # NOTE(review): building the pipeline per request is expensive;
            # it could be cached per model at load time.
            text_generation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
            generated_response = text_generation_pipeline(q, do_sample=True, num_return_sequences=5)
            generated_responses.extend([response['generated_text'] for response in generated_response])
            message_history.extend([('bot', response['generated_text']) for response in generated_response])

        logger.debug(f"Successfully autocomplete, q:{q}, res:{generated_responses}")

        # Guard against no loaded models / no candidates: min() on an empty
        # sequence would raise and fall into the except below.
        if not generated_responses:
            return {"result": []}

        # Pick the response whose length is closest to the question's.
        closest_response = min(generated_responses, key=lambda x: abs(len(x) - len(q)))

        return {"result": [closest_response]}
    except Exception as e:
        logger.error(f"Ignored error in autocomplete: {e}")
        # Keep the response shape stable so the client's forEach doesn't crash.
        return {"result": []}
242
if __name__ == '__main__':
    # Listen on all interfaces; the port comes from $PORT (default 8001).
    listen_port = int(os.getenv("PORT", 8001))
    uvicorn.run(app=app, host='0.0.0.0', port=listen_port)