Hjgugugjhuhjggg committed
Commit f05e47d · verified · 1 Parent(s): 1f3523a

Update app.py

Files changed (1)
  1. app.py +9 -13
app.py CHANGED
@@ -14,7 +14,7 @@ from dotenv import load_dotenv
 import huggingface_hub
 from threading import Thread
 from typing import AsyncIterator, List, Dict
-from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers.stopping_criteria import StoppingCriteria, StoppingCriteriaList
 import torch
 
 load_dotenv()
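Note on the new import path: whether `transformers.stopping_criteria` exists as an importable submodule varies with the installed transformers version; the top-level re-export is the spelling that is stable across recent releases:

    # Top-level re-export of the same classes; stable on recent transformers releases.
    from transformers import StoppingCriteria, StoppingCriteriaList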
@@ -135,7 +135,7 @@ model_loader = GCSModelLoader(bucket)
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     model_name = request.model_name
-    input_text = request.input_text
+    input_text = request.input_text  # Initialize input_text here
     task_type = request.task_type
     requested_max_new_tokens = request.max_new_tokens
     generation_params = request.model_dump(
@@ -153,12 +153,10 @@ async def generate(request: GenerateRequest):
     config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
     stopping_criteria_list = StoppingCriteriaList()
 
-    # Add user-defined stopping strings if provided
     if user_defined_stopping_strings:
         stop_words_ids = [tokenizer.encode(stop_string, add_special_tokens=False) for stop_string in user_defined_stopping_strings]
         stopping_criteria_list.append(StopOnKeywords(stop_words_ids))
 
-    # Automatically add EOS token as a stopping criterion
     if config.eos_token_id is not None:
         eos_token_ids = [config.eos_token_id]
         if isinstance(config.eos_token_id, int):
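`StopOnKeywords` is defined elsewhere in app.py and not shown in this diff. A hypothetical sketch consistent with how it is used here, i.e. a `StoppingCriteria` subclass that counts matches in a `current_encounters` attribute checked further down:

    import torch
    from transformers import StoppingCriteria

    class StopOnKeywords(StoppingCriteria):
        def __init__(self, stop_words_ids):
            super().__init__()
            self.stop_words_ids = stop_words_ids  # token-id sequences to stop on
            self.current_encounters = 0           # matches seen so far

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            for stop_ids in self.stop_words_ids:
                n = len(stop_ids)
                # Stop once the generated ids end with any stop sequence.
                if n and input_ids.shape[1] >= n and input_ids[0, -n:].tolist() == stop_ids:
                    self.current_encounters += 1
                    return True
            return False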
@@ -172,10 +170,11 @@ async def generate(request: GenerateRequest):
         stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos))
 
     async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
+        nonlocal input_text  # Allow modification of the outer scope variable
         all_generated_text = ""
-        stop_reason = None  # To track why the generation stopped
+        stop_reason = None
 
-        while True:  # Loop indefinitely, relying on stopping criteria
+        while True:
            text_pipeline = pipeline(
                task_type,
                model=model_name,
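The added `nonlocal` is what lets `input_text = all_generated_text` near the end of the loop rebind the variable initialized in the outer `generate` function; without it, that assignment would make `input_text` a fresh local of `generate_responses` and any earlier read of it (not shown in these hunks) would raise `UnboundLocalError`. In miniature:

    def outer():
        text = "start"
        def inner():
            nonlocal text            # rebind outer()'s `text`
            text += " extended"      # without nonlocal: UnboundLocalError
        inner()
        return text                  # -> "start extended"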
@@ -183,11 +182,11 @@ async def generate(request: GenerateRequest):
                token=HUGGINGFACE_HUB_TOKEN,
                stopping_criteria=stopping_criteria_list,
                **generation_params,
-                max_new_tokens=requested_max_new_tokens # Generate in chunks
+                max_new_tokens=requested_max_new_tokens
            )
 
-            def generate_on_thread(pipeline, input_text, output_queue):
-                result = pipeline(input_text)
+            def generate_on_thread(pipeline, current_input_text, output_queue):
+                result = pipeline(current_input_text)
                output_queue.put_nowait(result)
 
            output_queue = asyncio.Queue()
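A minimal sketch of the thread handoff around `generate_on_thread`, with names assumed where the diff is truncated. One caveat worth knowing: `asyncio.Queue.put_nowait` is not thread-safe when called from a plain `Thread`, so the robust variant routes the call through the event loop:

    import asyncio
    from threading import Thread

    async def run_blocking(pipe, prompt):
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def worker():
            result = pipe(prompt)  # blocking pipeline call, off the event loop
            loop.call_soon_threadsafe(queue.put_nowait, result)

        Thread(target=worker, daemon=True).start()
        return await queue.get()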
@@ -199,12 +198,11 @@ async def generate(request: GenerateRequest):
            newly_generated_text = result[0]['generated_text'][len(all_generated_text):]
 
            if not newly_generated_text:
-                break # Should ideally not happen with proper stopping criteria
+                break
 
            all_generated_text += newly_generated_text
            yield {"response": [{'generated_text': newly_generated_text}]}
 
-            # Check if any stopping criteria was met
            if stopping_criteria_list:
                for criteria in stopping_criteria_list:
                    if isinstance(criteria, StopOnKeywords) and criteria.current_encounters > 0:
@@ -213,7 +211,6 @@ async def generate(request: GenerateRequest):
            if stop_reason:
                break
 
-            # If the generated text seems to match the EOS token, stop
            if config.eos_token_id is not None:
                eos_tokens = [config.eos_token_id]
                if isinstance(config.eos_token_id, int):
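`config.eos_token_id` can be a single int or a list of ints depending on the model, which is why this hunk special-cases `isinstance(..., int)`. The hunk is truncated, but the usual normalization reduces to:

    # eos_token_id may be an int or a list of ints; normalize to a list.
    eos = config.eos_token_id
    eos_token_ids = [eos] if isinstance(eos, int) else list(eos)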
@@ -230,7 +227,6 @@ async def generate(request: GenerateRequest):
                    stop_reason = "eos_token"
                    break
 
-            # Update input text for the next iteration
            input_text = all_generated_text
 
    async def text_stream():
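Feeding `all_generated_text` back in as the next prompt is what makes each pipeline call resume where the previous chunk stopped; the chunked-continuation loop reduces to roughly this pattern (names assumed):

    generated = ""
    prompt = input_text
    while True:
        out = pipe(prompt, max_new_tokens=64)[0]["generated_text"]
        new_text = out[len(generated):]  # only the freshly generated tail
        if not new_text:
            break
        generated = out      # accumulated prompt + generations so far
        prompt = generated   # next call continues from the full text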