Yjhhh committed on
Commit
64e4328
·
verified ·
1 Parent(s): d4cff2f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +372 -0
app.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
+ import uvicorn
3
+ import requests
4
+ import os
5
+ import io
6
+ import asyncio
7
+ from typing import List, Dict, Any
8
+ from tqdm import tqdm
9
+ from llama_cpp import Llama
10
+ import aiofiles
11
+ import time
12
+ from sklearn.feature_extraction.text import TfidfVectorizer
13
+ from sklearn.metrics.pairwise import cosine_similarity
14
+
15
+ app = FastAPI()
16
+
17
# Every model the service knows about. Each entry carries the Hugging Face
# repository id, the GGUF file name inside that repository and the display
# name used as the key of ``models`` and returned by the /models endpoint.
model_configs = [
    dict(
        repo_id="Ffftdtd5dtft/gpt2-xl-Q2_K-GGUF",
        filename="gpt2-xl-q2_k.gguf",
        name="GPT-2 XL",
    ),
    dict(
        repo_id="Ffftdtd5dtft/Meta-Llama-3.1-8B-Instruct-Q2_K-GGUF",
        filename="meta-llama-3.1-8b-instruct-q2_k.gguf",
        name="Meta Llama 3.1-8B Instruct",
    ),
    dict(
        repo_id="Ffftdtd5dtft/gemma-2-9b-it-Q2_K-GGUF",
        filename="gemma-2-9b-it-q2_k.gguf",
        name="Gemma 2-9B IT",
    ),
    dict(
        repo_id="Ffftdtd5dtft/gemma-2-27b-Q2_K-GGUF",
        filename="gemma-2-27b-q2_k.gguf",
        name="Gemma 2-27B",
    ),
    dict(
        repo_id="Ffftdtd5dtft/Phi-3-mini-128k-instruct-Q2_K-GGUF",
        filename="phi-3-mini-128k-instruct-q2_k.gguf",
        name="Phi-3 Mini 128K Instruct",
    ),
    dict(
        repo_id="Ffftdtd5dtft/Meta-Llama-3.1-8B-Q2_K-GGUF",
        filename="meta-llama-3.1-8b-q2_k.gguf",
        name="Meta Llama 3.1-8B",
    ),
    dict(
        repo_id="Ffftdtd5dtft/Qwen2-7B-Instruct-Q2_K-GGUF",
        filename="qwen2-7b-instruct-q2_k.gguf",
        name="Qwen2 7B Instruct",
    ),
    dict(
        repo_id="Ffftdtd5dtft/starcoder2-3b-Q2_K-GGUF",
        filename="starcoder2-3b-q2_k.gguf",
        name="Starcoder2 3B",
    ),
    dict(
        repo_id="Ffftdtd5dtft/Qwen2-1.5B-Instruct-Q2_K-GGUF",
        filename="qwen2-1.5b-instruct-q2_k.gguf",
        name="Qwen2 1.5B Instruct",
    ),
    dict(
        repo_id="Ffftdtd5dtft/starcoder2-15b-Q2_K-GGUF",
        filename="starcoder2-15b-q2_k.gguf",
        name="Starcoder2 15B",
    ),
    dict(
        repo_id="Ffftdtd5dtft/gemma-2-2b-it-Q2_K-GGUF",
        filename="gemma-2-2b-it-q2_k.gguf",
        name="Gemma 2-2B IT",
    ),
    dict(
        repo_id="Ffftdtd5dtft/sarvam-2b-v0.5-Q2_K-GGUF",
        filename="sarvam-2b-v0.5-q2_k.gguf",
        name="Sarvam 2B v0.5",
    ),
    dict(
        repo_id="Ffftdtd5dtft/WizardLM-13B-Uncensored-Q2_K-GGUF",
        filename="wizardlm-13b-uncensored-q2_k.gguf",
        name="WizardLM 13B Uncensored",
    ),
    dict(
        repo_id="Ffftdtd5dtft/WizardLM-7B-Uncensored-Q2_K-GGUF",
        filename="wizardlm-7b-uncensored-q2_k.gguf",
        name="WizardLM 7B Uncensored",
    ),
    dict(
        repo_id="Ffftdtd5dtft/Qwen2-Math-7B-Instruct-Q2_K-GGUF",
        filename="qwen2-math-7b-instruct-q2_k.gguf",
        name="Qwen2 Math 7B Instruct",
    ),
]

# Directory the GGUF files are downloaded into and loaded from.
models_dir = "modelos"
# Registry of loaded models, keyed by the ``name`` field of ``model_configs``.
models = {}
37
+
38
class ModelManager:
    """Download, load and query the GGUF models listed in ``model_configs``.

    Loaded models are kept in the module-level ``models`` dict, keyed by the
    model's display name.
    """

    # Maximum number of characters in each entry of ``generated_text_parts``.
    RESPONSE_PART_LENGTH = 1000
    # Timeout, in seconds, handed to ``requests.get`` for model downloads.
    DOWNLOAD_TIMEOUT = 60

    def __init__(self):
        # ``model_parts``, ``index_lock`` and ``part_size`` are not referenced
        # by any method of this class; they are kept so that external readers
        # of these attributes keep working.
        self.model_parts = {}
        # Serialises the bulk download and bulk load phases.
        self.load_lock = asyncio.Lock()
        self.index_lock = asyncio.Lock()
        self.part_size = 1024 * 1024

    async def download_model(self, model_config):
        """Make sure the GGUF file for ``model_config`` exists in ``models_dir``.

        The blocking HTTP transfer runs in the default executor so the event
        loop stays responsive and several downloads can overlap.

        Args:
            model_config: dict with the keys ``repo_id``, ``filename`` and
                ``name``.

        Returns:
            The local path of the model file.

        Raises:
            HTTPException: status 500 when the download fails with a
                ``requests.RequestException``.
        """
        model_path = os.path.join(models_dir, model_config['filename'])
        if os.path.exists(model_path):
            print(f"Modelo {model_config['filename']} ya descargado.")
            return model_path

        # The directory may not exist yet when the app is started by an
        # external server process instead of by running this file as a script.
        os.makedirs(models_dir, exist_ok=True)

        url = f"https://huggingface.co/{model_config['repo_id']}/resolve/main/{model_config['filename']}"
        print(f"Descargando modelo desde {url}")
        loop = asyncio.get_running_loop()
        try:
            start_time = time.time()
            await loop.run_in_executor(
                None, self._download_file, url, model_path, model_config['filename']
            )
            download_duration = time.time() - start_time
            print(f"Descarga completa para {model_config['name']} en {download_duration:.2f} segundos")
        except requests.RequestException as e:
            raise HTTPException(status_code=500, detail=f"Error al descargar el modelo: {e}")
        return model_path

    @staticmethod
    def _download_file(url, model_path, description):
        """Stream ``url`` into ``model_path`` (blocking).

        The data is written to ``<model_path>.part`` and only renamed to
        ``model_path`` once the transfer finished, so an interrupted download
        is never mistaken for a complete model file on the next start.
        """
        partial_path = f"{model_path}.part"
        try:
            with requests.get(url, stream=True, timeout=ModelManager.DOWNLOAD_TIMEOUT) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                with open(partial_path, 'wb') as f, tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    desc=f"Descargando {description}",
                ) as pbar:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                        pbar.update(len(chunk))
            os.replace(partial_path, model_path)
        finally:
            # After a successful rename the partial file is gone; anything
            # still present here is the residue of a failed transfer.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    async def download_all_models(self):
        """Download every model in ``model_configs`` while holding ``load_lock``."""
        async with self.load_lock:
            download_tasks = [self.download_model(config) for config in model_configs]
            await asyncio.gather(*download_tasks)

    async def load_all_models(self):
        """Load every model in ``model_configs`` while holding ``load_lock``."""
        async with self.load_lock:
            load_tasks = [self.load_model(config) for config in model_configs]
            await asyncio.gather(*load_tasks)

    async def load_model(self, model_config):
        """Load one model into ``models`` unless it is already registered.

        Loading is best effort: any exception is printed and swallowed so one
        broken model does not prevent the others from loading.
        """
        model_name = model_config['name']
        if model_name in models:
            return
        try:
            model_path = os.path.join(models_dir, model_config['filename'])
            start_time = time.time()
            print(f"Cargando modelo desde {model_path}")

            llama = Llama(model_path=model_path)

            load_duration = time.time() - start_time
            print(f"Modelo {model_name} tardó {load_duration:.2f} segundos en cargar")

            models[model_name] = {
                'model': llama,
                # Not used for generation; kept so existing readers of the
                # ``models`` registry still find this key.
                'tokenizer': llama.tokenizer,
            }
        except Exception as e:
            print(f"Error al cargar el modelo: {e}")

    async def generate_response(self, user_input, model_name=None, top_k=50, top_p=0.95, temperature=0.8):
        """Generate a completion for ``user_input``.

        Args:
            user_input: prompt text.
            model_name: display name of the model to use. When falsy, every
                loaded model is queried and the best answer is selected with
                ``choose_best_response``.
            top_k, top_p, temperature: sampling parameters forwarded to the
                model.

        Returns:
            A dict with ``model_name``, ``generated_text`` (the full text) and
            ``generated_text_parts`` (the text in chunks of at most
            ``RESPONSE_PART_LENGTH`` characters), or a dict with
            ``model_name`` and ``error`` when generation was not possible.
        """
        if model_name:
            model_data = models.get(model_name)
            if not model_data:
                return {"model_name": model_name, "error": "Modelo no encontrado"}
            return self._generate_with_model(
                model_name, model_data, user_input, top_k, top_p, temperature
            )

        results = [
            self._generate_with_model(name, data, user_input, top_k, top_p, temperature)
            for name, data in models.items()
        ]
        if not results:
            return {"model_name": "Error", "error": "No se pudo generar una respuesta con ningún modelo."}
        if len(results) == 1:
            return results[0]
        return self.choose_best_response(user_input, results)

    def _generate_with_model(self, model_name, model_data, user_input, top_k, top_p, temperature):
        """Run one model and package its answer, or its failure, as a result dict."""
        try:
            # Calling the Llama object runs a text completion and returns an
            # OpenAI-style dict; the text lives in choices[0]['text'].
            completion = model_data['model'](
                user_input,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
            )
            generated_text = completion['choices'][0]['text']
        except Exception as e:
            return {'model_name': model_name, 'error': str(e)}
        return {
            'model_name': model_name,
            'generated_text': generated_text,
            'generated_text_parts': self._split_into_parts(
                generated_text, self.RESPONSE_PART_LENGTH
            ),
        }

    @staticmethod
    def _split_into_parts(text, part_length):
        """Split ``text`` into consecutive chunks of at most ``part_length`` chars.

        An empty text yields ``[""]`` so callers always get at least one part.
        """
        parts = [text[start:start + part_length] for start in range(0, len(text), part_length)]
        return parts or [text]

    def choose_best_response(self, user_input, responses):
        """Pick the response whose text is most similar to ``user_input``.

        Similarity is the cosine similarity between TF-IDF vectors of the
        prompt and of each successful response.

        Returns:
            The best successful response. When no response succeeded, the
            first error response is returned (or a generic error dict when
            ``responses`` is empty).
        """
        valid_responses = [r for r in responses if 'error' not in r]
        if not valid_responses:
            if responses:
                return responses[0]
            return {"model_name": "Error", "error": "No se pudo generar una respuesta con ningún modelo."}
        if len(valid_responses) == 1:
            return valid_responses[0]

        response_texts = [r['generated_text'] for r in valid_responses]
        try:
            tfidf_matrix = TfidfVectorizer().fit_transform([user_input] + response_texts)
        except ValueError:
            # Raised when no usable term is left to build a vocabulary from;
            # without a ranking signal fall back to the first valid answer.
            return valid_responses[0]
        similarities = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:])

        return valid_responses[int(similarities.argmax())]
181
+
182
@app.post("/generate/")
async def generate(request: Request):
    """Generate a reply for the prompt contained in the JSON request body.

    Expected body: a JSON object with ``input`` (required, non-empty string)
    and the optional keys ``model``, ``top_k``, ``top_p`` and ``temperature``.

    Returns:
        ``{"response": <result dict from ModelManager.generate_response>}``.

    Raises:
        HTTPException: 400 when the body is not a JSON object or ``input`` is
            missing, empty or not a string; 500 when generation raises.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError is a ValueError: a malformed body is a client
        # error, not a server failure.
        raise HTTPException(status_code=400, detail="El cuerpo de la petición debe ser JSON válido.")
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="El cuerpo de la petición debe ser un objeto JSON.")

    user_input = data.get('input', '')
    model_name = data.get('model')
    top_k = data.get('top_k', 50)
    top_p = data.get('top_p', 0.95)
    temperature = data.get('temperature', 0.8)
    if not user_input or not isinstance(user_input, str):
        raise HTTPException(status_code=400, detail="Se requiere una entrada de usuario.")

    try:
        response = await model_manager.generate_response(user_input, model_name, top_k, top_p, temperature)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"response": response}
198
+
199
@app.get("/models")
async def get_available_models():
    """Return the display name of every entry in ``model_configs``."""
    names = []
    for config in model_configs:
        names.append(config['name'])
    return {"models": names}
202
+
203
async def load_models_on_startup():
    """Create the global ``model_manager``, then download and load all models."""
    global model_manager
    manager = ModelManager()
    # Publish the manager before the long-running steps, as the request
    # handlers look it up through the module-level name.
    model_manager = manager
    await manager.download_all_models()
    await manager.load_all_models()
208
+
209
@app.on_event("startup")
async def startup_event():
    """Startup hook: prepare all models, then announce that the API is ready."""
    await load_models_on_startup()
    ready_message = "Modelos cargados. API lista."
    print(ready_message)
213
+
214
if __name__ == "__main__":
    # Create the model directory on first run, before the server starts.
    models_dir_missing = not os.path.exists(models_dir)
    if models_dir_missing:
        os.makedirs(models_dir)

    # Listen on every interface, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)
219
+
220
# Single-page chat client for the /models and /generate/ endpoints.
# NOTE(review): no route in this module returns ``html_code``, and this
# assignment sits after the blocking ``uvicorn.run`` call in the ``__main__``
# guard - confirm how (or whether) the page is meant to be served.
html_code = """
<!DOCTYPE html>
<html>
<head>
    <title>Chatbot</title>
    <style>
        body {
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            font-family: sans-serif;
        }

        .container {
            border: 1px solid #ccc;
            border-radius: 5px;
            width: 400px;
            height: 500px;
            overflow: hidden;
        }

        .chat-log {
            padding: 10px;
            height: 400px;
            overflow-y: scroll;
        }

        .chat-message {
            margin-bottom: 10px;
            padding: 8px;
            border-radius: 5px;
        }

        .user-message {
            background-color: #eee;
        }

        .bot-message {
            background-color: #ccf;
        }

        .input-area {
            display: flex;
            padding: 10px;
        }

        #user-input {
            flex: 1;
            padding: 8px;
            border: 1px solid #ccc;
            border-radius: 5px;
        }

        #send-button {
            padding: 8px 15px;
            background-color: #4CAF50;
            color: white;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            margin-left: 10px;
        }

        #model-select {
            width: 100%;
            padding: 8px;
            border: 1px solid #ccc;
            border-radius: 5px;
            margin-bottom: 10px;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="chat-log" id="chat-log">
        </div>
        <div class="input-area">
            <input type="text" id="user-input" placeholder="Escribe tu mensaje...">
            <button id="send-button">Enviar</button>
        </div>
        <select id="model-select">
            <option value="">Todos los modelos</option>
        </select>
    </div>

    <script>
        const chatLog = document.getElementById('chat-log');
        const userInput = document.getElementById('user-input');
        const sendButton = document.getElementById('send-button');
        const modelSelect = document.getElementById('model-select');
        let currentConversationId = null;

        async function startNewConversation() {
        }

        startNewConversation();

        async function getAvailableModels() {
            const response = await fetch('/models');
            if (!response.ok) {
                throw new Error(`HTTP ${response.status}`);
            }
            const data = await response.json();
            return data.models;
        }

        async function displayAvailableModels() {
            try {
                const models = await getAvailableModels();
                models.forEach(model => {
                    const option = document.createElement('option');
                    option.value = model;
                    option.text = model;
                    modelSelect.add(option);
                });
            } catch (error) {
                appendMessage('bot', `Error al obtener los modelos: ${error.message}`);
            }
        }

        displayAvailableModels();

        sendButton.addEventListener('click', async () => {
            const userMessage = userInput.value.trim();
            // The server rejects empty input, so do not send it at all.
            if (!userMessage) {
                return;
            }
            userInput.value = '';
            const selectedModel = modelSelect.value;

            appendMessage('user', userMessage);

            try {
                const response = await fetch('/generate/', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json'
                    },
                    body: JSON.stringify({ input: userMessage, model: selectedModel })
                });

                const data = await response.json();
                // Error replies carry "detail" instead of "response".
                if (!response.ok) {
                    appendMessage('bot', `Error: ${data.detail || response.statusText}`);
                    return;
                }
                if (data.response.error) {
                    appendMessage('bot', `Error del modelo ${data.response.model_name}: ${data.response.error}`);
                } else {
                    data.response.generated_text_parts.forEach(part => {
                        appendMessage('bot', part);
                    });
                }
            } catch (error) {
                appendMessage('bot', `Error: ${error.message}`);
            }
        });

        function appendMessage(role, message) {
            const messageElement = document.createElement('div');
            messageElement.classList.add('chat-message', `${role}-message`);
            messageElement.textContent = message;
            chatLog.appendChild(messageElement);
            chatLog.scrollTop = chatLog.scrollHeight;
        }
    </script>
</body>
</html>
"""