Spaces: Hjgugugjhuhjggg (Build error)

Update app.py
Hjgugugjhuhjggg committed
Commit 1cb967f • 1 Parent(s): 84d1dae
app.py CHANGED
@@ -1,10 +1,9 @@
 import gc
 import psutil
 import os
-import time
 import torch
 from fastapi import FastAPI
-from
+from langchain.llms import VLLM
 from chatgptcache import cache
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
@@ -15,13 +14,14 @@ from collections import Counter
 import asyncio
 import torch.nn.utils.prune as prune
 from concurrent.futures import ThreadPoolExecutor
+from langchain.prompts import PromptTemplate
+from langchain.chains import LLMChain
 
 nltk.download('punkt')
 nltk.download('stopwords')
 
 app = FastAPI()
 
-# Definir los modelos (serán cargados más tarde)
 model_1 = None
 model_2 = None
 model_3 = None
@@ -37,12 +37,10 @@ previous_responses_2 = []
 previous_responses_3 = []
 previous_responses_4 = []
 
-MAX_TOKENS = 2048
+MAX_TOKENS = 2048
 
-# Usar ThreadPoolExecutor para ejecución en paralelo
 executor = ThreadPoolExecutor(max_workers=4)
 
-# Configuración del dispositivo (CPU)
 device = torch.device("cpu")
 
 def get_best_response(new_response, previous_responses):
@@ -90,17 +88,16 @@ def apply_pruning(model):
     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear):
             prune.random_unstructured(module, name="weight", amount=0.2)
-            prune.remove(module, name="weight")
+            prune.remove(module, name="weight")
     return model
 
 def split_input(input_text, max_tokens):
-    tokens = input_text.split()
+    tokens = input_text.split()
     chunks = []
     chunk = []
     total_tokens = 0
-
     for word in tokens:
-        word_length = len(word.split())
+        word_length = len(word.split())
         if total_tokens + word_length > max_tokens:
             chunks.append(" ".join(chunk))
             chunk = [word]
@@ -108,20 +105,17 @@ def split_input(input_text, max_tokens):
         else:
             chunk.append(word)
             total_tokens += word_length
-
     if chunk:
-        chunks.append(" ".join(chunk))
-
+        chunks.append(" ".join(chunk))
     return chunks
 
 def split_output(output_text, max_tokens):
-    tokens = output_text.split()
+    tokens = output_text.split()
     chunks = []
     chunk = []
     total_tokens = 0
-
     for word in tokens:
-        word_length = len(word.split())
+        word_length = len(word.split())
         if total_tokens + word_length > max_tokens:
             chunks.append(" ".join(chunk))
             chunk = [word]
@@ -129,44 +123,48 @@ def split_output(output_text, max_tokens):
         else:
             chunk.append(word)
             total_tokens += word_length
-
     if chunk:
-        chunks.append(" ".join(chunk))
-
+        chunks.append(" ".join(chunk))
     return chunks
 
-
-
-
-
-
-
-
-
-
-
-
+def create_langchain_model(model_name: str, device: torch.device, cache, previous_responses):
+    vllm_llm = VLLM(model_name=model_name, device=device)
+    template = """
+    You are a helpful assistant. Given the following text, generate a meaningful response:
+    {input_text}
+    """
+    prompt = PromptTemplate(input_variables=["input_text"], template=template)
+    chain = LLMChain(llm=vllm_llm, prompt=prompt)
+    def generate_for_model(input_text):
+        cached_output = cache.get(input_text)
+        if cached_output:
+            return cached_output
+        input_chunks = split_input(input_text, MAX_TOKENS)
+        output_text = ""
+        prev_output = ""
+        for chunk in input_chunks:
+            prompt = prev_output + chunk
+            output_text += chain.run(input_text=prompt)
+            prev_output = output_text.split()[-50:]
+        output_chunks = split_output(output_text, MAX_TOKENS)
+        best_response = get_best_response(output_chunks[0], previous_responses)
+        cache.put(input_text, best_response)
+        previous_responses.append(best_response)
+        return best_response
+    return generate_for_model
 
 async def load_models():
     global model_1, model_2, model_3, model_4
-
-
-
-
-
-    ]
-    results = await asyncio.gather(*tasks)
-    model_1, model_2, model_3, model_4 = results
-    model_1 = apply_pruning(model_1)
-    model_2 = apply_pruning(model_2)
-    model_3 = apply_pruning(model_3)
-    model_4 = apply_pruning(model_4)
-    print("Modelos cargados y podados exitosamente.")
+    model_1 = create_langchain_model("Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf", device, cache_1, previous_responses_1)
+    model_2 = create_langchain_model("Qwen/Qwen2.5-Coder-1.5B", device, cache_2, previous_responses_2)
+    model_3 = create_langchain_model("Qwen/Qwen2.5-3B-Instruct", device, cache_3, previous_responses_3)
+    model_4 = create_langchain_model("gpt2", device, cache_4, previous_responses_4)
+    print("Modelos cargados exitosamente.")
 
 async def optimize_models_periodically():
     while True:
-        await load_models()
-        await asyncio.sleep(3600)
+        await load_models()
+        await asyncio.sleep(3600)
 
 @app.on_event("startup")
 async def startup():
@@ -181,34 +179,16 @@ async def monitor_memory():
 
 @app.get("/generate")
 async def generate_response(model_name: str, input_text: str):
-
-
-
-
-
-
-
-
-
-
-            prompt = prev_output + chunk
-            output_text += model.generate(prompt)
-            prev_output = output_text.split()[-50:]
-
-        output_chunks = split_output(output_text, MAX_TOKENS)
-        best_response = get_best_response(output_chunks[0], previous_responses)
-        cache.put(input_text, best_response)
-        previous_responses.append(best_response)
-        return best_response
-
-    result = await asyncio.get_event_loop().run_in_executor(
-        executor,
-        generate_for_model,
-        model_1 if model_name == "model1" else model_2 if model_name == "model2" else model_3 if model_name == "model3" else model_4,
-        input_text,
-        cache_1 if model_name == "model1" else cache_2 if model_name == "model2" else cache_3 if model_name == "model3" else cache_4,
-        previous_responses_1 if model_name == "model1" else previous_responses_2 if model_name == "model2" else previous_responses_3 if model_name == "model3" else previous_responses_4
-    )
+    if model_name == "model1":
+        result = await asyncio.get_event_loop().run_in_executor(executor, model_1, input_text)
+    elif model_name == "model2":
+        result = await asyncio.get_event_loop().run_in_executor(executor, model_2, input_text)
+    elif model_name == "model3":
+        result = await asyncio.get_event_loop().run_in_executor(executor, model_3, input_text)
+    elif model_name == "model4":
+        result = await asyncio.get_event_loop().run_in_executor(executor, model_4, input_text)
+    else:
+        return {"error": "Model not found"}
     return {f"{model_name}_output": result}
 
 @app.get("/unified_summary")
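For reference, a minimal client-side sketch of how the reworked /generate endpoint could be called once the Space builds. The base URL is hypothetical and the requests library is assumed on the client; the query parameters model_name and input_text and the {"<model_name>_output": ...} response shape follow the code added in this commit.

import requests

BASE_URL = "http://localhost:7860"  # hypothetical; replace with the Space's actual URL

def query_model(model_name: str, input_text: str) -> dict:
    # model_name must be "model1".."model4"; anything else returns
    # {"error": "Model not found"} per the new dispatch logic.
    response = requests.get(
        f"{BASE_URL}/generate",
        params={"model_name": model_name, "input_text": input_text},
        timeout=300,
    )
    response.raise_for_status()
    return response.json()

print(query_model("model2", "Write a short haiku about caching."))
# expected shape: {"model2_output": "..."}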
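One detail worth keeping in mind about the split_input/split_output helpers shown above: word_length = len(word.split()) is always 1 for a non-blank word, so MAX_TOKENS acts as a budget of whitespace-separated words rather than model tokens. The standalone sketch below reproduces split_input to show where chunk boundaries fall; the single context line elided between the hunks is assumed here to reset total_tokens.

def split_input(input_text, max_tokens):
    # Copy of the helper from app.py for illustration only. The line
    # "total_tokens = word_length" is an assumption, since that context
    # line falls between the displayed hunks.
    tokens = input_text.split()
    chunks = []
    chunk = []
    total_tokens = 0
    for word in tokens:
        word_length = len(word.split())  # always 1 for a non-blank word
        if total_tokens + word_length > max_tokens:
            chunks.append(" ".join(chunk))
            chunk = [word]
            total_tokens = word_length
        else:
            chunk.append(word)
            total_tokens += word_length
    if chunk:
        chunks.append(" ".join(chunk))
    return chunks

print(split_input("one two three four five six seven", max_tokens=3))
# ['one two three', 'four five six', 'seven']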