Update app.py
app.py CHANGED

@@ -14,7 +14,7 @@ from dotenv import load_dotenv
 import huggingface_hub
 from threading import Thread
 from typing import AsyncIterator, List, Dict
-from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers.stopping_criteria import StoppingCriteria, StoppingCriteriaList
 import torch
 
 load_dotenv()
@@ -135,7 +135,7 @@ model_loader = GCSModelLoader(bucket)
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     model_name = request.model_name
-    input_text = request.input_text
+    input_text = request.input_text  # Initialize input_text here
     task_type = request.task_type
     requested_max_new_tokens = request.max_new_tokens
     generation_params = request.model_dump(
@@ -153,12 +153,10 @@ async def generate(request: GenerateRequest):
     config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
     stopping_criteria_list = StoppingCriteriaList()
 
-    # Add user-defined stopping strings if provided
     if user_defined_stopping_strings:
         stop_words_ids = [tokenizer.encode(stop_string, add_special_tokens=False) for stop_string in user_defined_stopping_strings]
         stopping_criteria_list.append(StopOnKeywords(stop_words_ids))
 
-    # Automatically add EOS token as a stopping criterion
     if config.eos_token_id is not None:
         eos_token_ids = [config.eos_token_id]
         if isinstance(config.eos_token_id, int):
@@ -172,10 +170,11 @@ async def generate(request: GenerateRequest):
         stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos))
 
     async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
+        nonlocal input_text  # Allow modification of the outer scope variable
         all_generated_text = ""
         stop_reason = None
 
         while True:
             text_pipeline = pipeline(
                 task_type,
                 model=model_name,
@@ -183,11 +182,11 @@ async def generate(request: GenerateRequest):
                 token=HUGGINGFACE_HUB_TOKEN,
                 stopping_criteria=stopping_criteria_list,
                 **generation_params,
                 max_new_tokens=requested_max_new_tokens
             )
 
-            def generate_on_thread(pipeline, input_text, output_queue):
-                result = pipeline(input_text)
+            def generate_on_thread(pipeline, current_input_text, output_queue):
+                result = pipeline(current_input_text)
                 output_queue.put_nowait(result)
 
             output_queue = asyncio.Queue()
@@ -199,12 +198,11 @@ async def generate(request: GenerateRequest):
             newly_generated_text = result[0]['generated_text'][len(all_generated_text):]
 
             if not newly_generated_text:
                 break
 
             all_generated_text += newly_generated_text
             yield {"response": [{'generated_text': newly_generated_text}]}
 
-            # Check if any stopping criteria was met
             if stopping_criteria_list:
                 for criteria in stopping_criteria_list:
                     if isinstance(criteria, StopOnKeywords) and criteria.current_encounters > 0:
@@ -213,7 +211,6 @@ async def generate(request: GenerateRequest):
             if stop_reason:
                 break
 
-            # If the generated text seems to match the EOS token, stop
             if config.eos_token_id is not None:
                 eos_tokens = [config.eos_token_id]
                 if isinstance(config.eos_token_id, int):
@@ -230,7 +227,6 @@ async def generate(request: GenerateRequest):
                     stop_reason = "eos_token"
                     break
 
-            # Update input text for the next iteration
             input_text = all_generated_text
 
     async def text_stream():
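
A note on the main fix: inside `generate_responses`, the loop ends with `input_text = all_generated_text`, and in Python any assignment inside a function makes that name local to the function. Without the new `nonlocal input_text` declaration, every earlier read of `input_text` in `generate_responses` would raise `UnboundLocalError`. A minimal standalone illustration of the rule (names invented for the example):

```python
def outer():
    text = "seed"

    def inner():
        # Assigning to `text` below makes it local to inner() unless it is
        # declared nonlocal; without this line, print(text) would raise
        # UnboundLocalError because the local name is not yet bound.
        nonlocal text
        print(text)          # reads outer()'s binding: "seed"
        text = text + "!"    # rebinds the outer variable

    inner()
    return text              # -> "seed!"
```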
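`StopOnKeywords` is referenced but not defined in this diff; the generation loop only relies on its `current_encounters` attribute. A minimal sketch of what such a `StoppingCriteria` subclass could look like, assuming the constructor receives lists of token ids (as produced by `tokenizer.encode` above) and the counter records how many keyword matches have occurred; the actual class in app.py may differ:

```python
import torch
from transformers import StoppingCriteria

class StopOnKeywords(StoppingCriteria):
    """Stop generation once any keyword token sequence ends the output."""

    def __init__(self, stop_words_ids):
        self.stop_words_ids = stop_words_ids  # list of token-id lists
        self.current_encounters = 0           # inspected by the caller after each run

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in self.stop_words_ids:
            n = len(stop_ids)
            # Match when the keyword equals the tail of the generated sequence.
            if n and input_ids.shape[1] >= n and input_ids[0, -n:].tolist() == stop_ids:
                self.current_encounters += 1
                return True
        return False
```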
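One caveat on the threading pattern in the diff: `generate_on_thread` calls `output_queue.put_nowait(result)` from a worker thread, but `asyncio.Queue` is documented as not thread-safe, so this can race with the event loop. A common alternative is to hand the result back through `loop.call_soon_threadsafe`; a sketch of that variant, with `run_blocking_pipeline` and its arguments invented for illustration:

```python
import asyncio
from threading import Thread

async def run_blocking_pipeline(text_pipeline, current_input_text):
    loop = asyncio.get_running_loop()
    output_queue: asyncio.Queue = asyncio.Queue()

    def generate_on_thread():
        result = text_pipeline(current_input_text)  # blocking pipeline call
        # Schedule the put on the event-loop thread rather than calling
        # output_queue.put_nowait(result) directly from this worker thread.
        loop.call_soon_threadsafe(output_queue.put_nowait, result)

    Thread(target=generate_on_thread, daemon=True).start()
    return await output_queue.get()  # yields to the loop until the thread finishes
```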