BoZhaoHuggingFace commited on
Commit
36a4fc6
1 Parent(s): 9f37abb

Update bunny/serve/model_worker.py

Browse files
Files changed (1) hide show
  1. bunny/serve/model_worker.py +27 -30
bunny/serve/model_worker.py CHANGED
@@ -154,10 +154,8 @@ class ModelWorker:
154
  stop_str = params.get("stop", None)
155
  do_sample = True if temperature > 0.001 else False
156
 
157
-
158
  input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(
159
  self.device)
160
-
161
  keywords = [stop_str]
162
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
163
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
@@ -173,7 +171,6 @@ class ModelWorker:
173
 
174
  model = model.to('cuda')
175
 
176
-
177
  thread = Thread(target=model.generate, kwargs=dict(
178
  inputs=input_ids,
179
  do_sample=do_sample,
@@ -199,33 +196,33 @@ class ModelWorker:
199
  yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
200
 
201
  def generate_stream_gate(self, params):
202
- # for x in self.generate_stream(params):
203
- # yield x
204
-
205
- try:
206
- for x in self.generate_stream(params):
207
- yield x
208
- except ValueError as e:
209
- logger.info("Caught ValueError:", e)
210
- ret = {
211
- "text": server_error_msg,
212
- "error_code": 1,
213
- }
214
- yield json.dumps(ret).encode() + b"\0"
215
- except torch.cuda.CudaError as e:
216
- logger.info("Caught torch.cuda.CudaError:", e)
217
- ret = {
218
- "text": server_error_msg,
219
- "error_code": 1,
220
- }
221
- yield json.dumps(ret).encode() + b"\0"
222
- except Exception as e:
223
- logger.info("Caught Unknown Error", e)
224
- ret = {
225
- "text": server_error_msg,
226
- "error_code": 1,
227
- }
228
- yield json.dumps(ret).encode() + b"\0"
229
 
230
 
231
  app = FastAPI()
 
154
  stop_str = params.get("stop", None)
155
  do_sample = True if temperature > 0.001 else False
156
 
 
157
  input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(
158
  self.device)
 
159
  keywords = [stop_str]
160
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
161
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
 
171
 
172
  model = model.to('cuda')
173
 
 
174
  thread = Thread(target=model.generate, kwargs=dict(
175
  inputs=input_ids,
176
  do_sample=do_sample,
 
196
  yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
197
 
198
  def generate_stream_gate(self, params):
199
+ for x in self.generate_stream(params):
200
+ yield x
201
+
202
+ # try:
203
+ # for x in self.generate_stream(params):
204
+ # yield x
205
+ # except ValueError as e:
206
+ # print("Caught ValueError:", e)
207
+ # ret = {
208
+ # "text": server_error_msg,
209
+ # "error_code": 1,
210
+ # }
211
+ # yield json.dumps(ret).encode() + b"\0"
212
+ # except torch.cuda.CudaError as e:
213
+ # print("Caught torch.cuda.CudaError:", e)
214
+ # ret = {
215
+ # "text": server_error_msg,
216
+ # "error_code": 1,
217
+ # }
218
+ # yield json.dumps(ret).encode() + b"\0"
219
+ # except Exception as e:
220
+ # print("Caught Unknown Error", e)
221
+ # ret = {
222
+ # "text": server_error_msg,
223
+ # "error_code": 1,
224
+ # }
225
+ # yield json.dumps(ret).encode() + b"\0"
226
 
227
 
228
  app = FastAPI()