Mizukiluke committed on
Commit f860a77
Parent: 512a4c1

Fix the chat function to support the Gradio demo

Files changed (1):
  modeling_mplugowl3.py  +20  -31
modeling_mplugowl3.py CHANGED
@@ -142,7 +142,6 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
         media_offset=None,
         attention_mask=None,
         tokenizer=None,
-        return_vision_hidden_states=False,
         stream=False,
         decode_text=False,
         **kwargs
@@ -156,9 +155,6 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
             result = self._decode_stream(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, **kwargs)
         else:
             result = self._decode(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, attention_mask=attention_mask, decode_text=decode_text, **kwargs)
-
-        if return_vision_hidden_states:
-            return result, image_embeds
 
         return result
 
@@ -166,10 +162,9 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
         self,
         images,
         videos,
-        msgs,
+        messages,
         tokenizer,
         processor=None,
-        vision_hidden_states=None,
         max_new_tokens=2048,
         min_new_tokens=0,
         sampling=True,
@@ -180,21 +175,23 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
         use_image_id=None,
         **kwargs
     ):
-        print(msgs)
+        print(messages)
+        if len(images)>1:
+            cut_flag=False
+        else:
+            cut_flag=True
         if processor is None:
             if self.processor is None:
-                self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
-            processor = self.processor
-
-
-        inputs = processor(
-            prompts_lists,
-            input_images_lists,
-            max_slice_nums=max_slice_nums,
-            use_image_id=use_image_id,
-            return_tensors="pt",
-            max_length=max_inp_length
-        ).to(self.device)
+                processor = self.init_processor(tokenizer)
+            else:
+                processor = self.processor
+        inputs = processor(messages, images=images, videos=videos, cut_enable=cut_flag)
+        inputs.to('cuda')
+        inputs.update({
+            'tokenizer': tokenizer,
+            'max_new_tokens': max_new_tokens,
+            # 'stream':True,
+        })
 
         if sampling:
             generation_config = {
@@ -202,12 +199,12 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
                 "top_k": 100,
                 "temperature": 0.7,
                 "do_sample": True,
-                "repetition_penalty": 1.05
+                # "repetition_penalty": 1.05
             }
         else:
             generation_config = {
                 "num_beams": 3,
-                "repetition_penalty": 1.2,
+                # "repetition_penalty": 1.2,
             }
 
         if min_new_tokens > 0:
@@ -216,14 +213,10 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
         generation_config.update(
             (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
        )
-
-        inputs.pop("image_sizes")
+        print(inputs)
         with torch.inference_mode():
             res = self.generate(
                 **inputs,
-                tokenizer=tokenizer,
-                max_new_tokens=max_new_tokens,
-                vision_hidden_states=vision_hidden_states,
                 stream=stream,
                 decode_text=True,
                 **generation_config
@@ -238,9 +231,5 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
             return stream_gen()
 
         else:
-            if batched:
-                answer = res
-            else:
-                answer = res[0]
+            answer = res[0]
             return answer
-
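For context, the patched chat() now takes raw PIL images plus a chat-style messages list and builds its own processor via init_processor(tokenizer) when none is supplied. Below is a minimal sketch of calling it, assuming the mPLUG-Owl3 checkpoint and the <|image|> message format from the model card; the checkpoint name, image path, and prompt are illustrative assumptions, not part of this commit.

# Sketch only: model ID, image path, and prompt are assumptions for illustration.
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

model_path = 'mPLUG/mPLUG-Owl3-7B-240728'  # assumed checkpoint name
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.half,
                                  trust_remote_code=True)
model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

image = Image.open('example.jpg').convert('RGB')  # any local test image
messages = [
    {"role": "user", "content": "<|image|>Describe this image."},
    {"role": "assistant", "content": ""},
]

# processor=None, so chat() calls self.init_processor(tokenizer) and prepares
# the inputs itself; a single image keeps cut_flag=True (slicing enabled).
answer = model.chat(images=[image], videos=[], messages=messages,
                    tokenizer=tokenizer, max_new_tokens=100)
print(answer)

Note the kwargs filter, generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()): it only overrides keys already present in the chosen preset (so top_k=50 takes effect under sampling), while keys outside the preset are silently ignored.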
 
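Since the commit message targets the Gradio demo, here is one plausible way to wire the patched chat() into a demo. This is a sketch under stated assumptions (gradio installed; model and tokenizer loaded as in the previous snippet), not the project's actual demo code.

# Hypothetical Gradio wiring, not the official mPLUG-Owl3 demo.
# Assumes `model` and `tokenizer` from the previous snippet.
import gradio as gr

def respond(image, question):
    # chat() now accepts raw PIL images and chat-style messages directly,
    # which is what makes this thin handler possible after the fix.
    messages = [
        {"role": "user", "content": "<|image|>" + question},
        {"role": "assistant", "content": ""},
    ]
    return model.chat(images=[image], videos=[], messages=messages,
                      tokenizer=tokenizer, max_new_tokens=256)

demo = gr.Interface(
    fn=respond,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
    outputs=gr.Textbox(label="Answer"),
)
demo.launch()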