David Day committed on
Commit 3b36384
1 Parent(s): df6eb2a
Files changed (2)
  1. model_worker.py +12 -33
  2. requirements.txt +1 -0
model_worker.py CHANGED
@@ -52,12 +52,12 @@ class ModelWorker:
             torch_device='cpu',
             device_map="cpu",
         )
-        self.model.to("cuda:0")
+        self.model.to('cuda')
 
     @spaces.GPU
     def generate_stream(self, params):
         tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
-        logger.info(f'Model devices: {self.model.device}')
+        logger.info(f'Model devices: {model.device}')
 
         prompt = params["prompt"]
         ori_prompt = prompt
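
For context, the .to('cuda') call together with the @spaces.GPU decorator is the usual ZeroGPU pattern on Hugging Face Spaces: weights are loaded on CPU at startup and the GPU is only attached while a @spaces.GPU-decorated function is running. A minimal sketch of that pattern, where the checkpoint name and generate() arguments are illustrative and not taken from this repo:

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "some-org/some-model"  # hypothetical checkpoint, for illustration only

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="cpu")
model.to('cuda')  # on ZeroGPU, CUDA is only actually usable inside @spaces.GPU calls

@spaces.GPU
def generate(prompt: str) -> str:
    # Runs with the GPU attached; tensors follow the model's device.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=64)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
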
@@ -70,17 +70,18 @@ class ModelWorker:
 
             images = [load_image_from_base64(image) for image in images]
             images = process_images(images, image_processor, model.config)
+            logger.info(f'Images: {images.shape}')
 
             if type(images) is list:
-                images = [image.to(self.model.device, dtype=torch.float16) for image in images]
+                images = [image.to(model.device, dtype=torch.float16) for image in images]
             else:
-                images = images.to(self.model.device, dtype=torch.float16)
 
             if self.load_bf16:
                 images = images.to(dtype=torch.bfloat16)
 
             replace_token = DEFAULT_IMAGE_TOKEN
-            if getattr(self.model.config, 'mm_use_im_start_end', False):
+            if getattr(model.config, 'mm_use_im_start_end', False):
                 replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
             prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
 
@@ -99,15 +100,15 @@ class ModelWorker:
         stop_str = params.get("stop", None)
         do_sample = True if temperature > 0.001 else False
 
-        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
         keywords = [stop_str]
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=None)
 
         max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
 
         if max_new_tokens < 1:
-            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
+            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode()
             return
 
         thread = Thread(target=model.generate, kwargs=dict(
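
The streaming setup in this hunk and the generation loop in the next one follow the standard transformers idiom: model.generate runs in a background thread and pushes decoded text into a TextIteratorStreamer that the caller iterates. With timeout=None the consuming loop blocks until the next chunk arrives instead of raising queue.Empty after 15 seconds. A rough, self-contained sketch of the idiom, using a small text-only model for illustration (the real worker also passes images and stopping criteria to generate):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Small hypothetical stand-in model, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello, ", return_tensors='pt').input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=None)

# generate() runs in a worker thread and feeds the streamer as tokens are produced.
thread = Thread(target=model.generate, kwargs=dict(
    inputs=input_ids,
    do_sample=False,
    max_new_tokens=32,
    streamer=streamer,
    use_cache=True,
))
thread.start()

generated_text = ""
for new_text in streamer:  # yields decoded text pieces as generate() produces them
    generated_text += new_text
thread.join()
print(generated_text)
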
@@ -128,33 +129,11 @@ class ModelWorker:
             generated_text += new_text
             if generated_text.endswith(stop_str):
                 generated_text = generated_text[:-len(stop_str)]
-            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+            yield json.dumps({"text": generated_text, "error_code": 0}).encode()
 
     def generate_stream_gate(self, params):
-        try:
-            for x in self.generate_stream(params):
-                yield x
-        except ValueError as e:
-            print("Caught ValueError:", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode() + b"\0"
-        except torch.cuda.CudaError as e:
-            print("Caught torch.cuda.CudaError:", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode() + b"\0"
-        except Exception as e:
-            print("Caught Unknown Error", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode() + b"\0"
+        for x in self.generate_stream(params):
+            yield x
 
 def release_model_semaphore(fn=None):
     model_semaphore.release()
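
One consequence of the last two hunks is a change of wire format: each yielded record used to end with a b"\0" separator, which clients such as the upstream LLaVA web server split on. With the separator gone, a consumer has to treat every streamed HTTP chunk as its own JSON record. A sketch of the delimiter-based client loop that matched the old format, assuming a hypothetical /worker_generate_stream endpoint and placeholder payload:

import json
import requests

# URL and payload are illustrative placeholders.
response = requests.post("http://localhost:40000/worker_generate_stream",
                         json={"prompt": "...", "stop": "</s>"}, stream=True, timeout=30)

for chunk in response.iter_lines(delimiter=b"\0"):  # split the byte stream on NUL separators
    if chunk:
        data = json.loads(chunk.decode())
        print(data["error_code"], data["text"])
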
requirements.txt CHANGED
@@ -11,4 +11,5 @@ einops==0.6.1
 einops-exts==0.0.4
 timm==0.6.13
 httpx==0.24.0
+numpy==1.26.4
 scipy
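
The numpy==1.26.4 pin keeps the environment on the last NumPy 1.x release; presumably this avoids NumPy 2.0, which is binary-incompatible with many wheels built against the 1.x ABI.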