ynhe commited on
Commit
ed5d21f
·
1 Parent(s): 4fbb18f
Files changed (1) hide show
  1. conversation.py +31 -29
conversation.py CHANGED
@@ -61,28 +61,29 @@ class Chat:
61
  def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
62
  repetition_penalty=1.0, length_penalty=1, temperature=1.0):
63
  conv.messages.append([conv.roles[1], None])
64
- embs = self.get_context_emb(conv, img_list)
65
- outputs = self.model.llama_model.generate(
66
- inputs_embeds=embs,
67
- max_new_tokens=max_new_tokens,
68
- stopping_criteria=self.stopping_criteria,
69
- num_beams=num_beams,
70
- do_sample=True,
71
- min_length=min_length,
72
- top_p=top_p,
73
- repetition_penalty=repetition_penalty,
74
- length_penalty=length_penalty,
75
- temperature=temperature,
76
- )
77
- output_token = outputs[0]
78
- if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
79
- output_token = output_token[1:]
80
- if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
81
- output_token = output_token[1:]
82
- output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
83
- output_text = output_text.split('###')[0] # remove the stop sign '###'
84
- output_text = output_text.split('Assistant:')[-1].strip()
85
- conv.messages[-1][1] = output_text
 
86
  return output_text, output_token.cpu().numpy(), conv
87
 
88
  def get_index(self, num_frames, num_segments):
@@ -139,9 +140,10 @@ class Chat:
139
 
140
  else:
141
  raise NotImplementedError
142
- print("Input video shape:", vid_chat.shape)
143
- image_emb, _ = self.model.encode_img(image)
144
- img_list.append(image_emb)
 
145
  conv.messages.append([
146
  conv.roles[0],
147
  f"<Video><VideoHere></Video> {msg}\n"
@@ -161,10 +163,10 @@ class Chat:
161
  T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
162
  ]
163
  )
164
-
165
- img = transform(img).unsqueeze(0).unsqueeze(0).cuda()
166
- image_emb, _ = self.model.encode_img(img)
167
- img_list.append(image_emb)
168
  conv.messages.append([
169
  conv.roles[0],
170
  f"<Image><ImageHere></Image>\n"
 
61
  def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
62
  repetition_penalty=1.0, length_penalty=1, temperature=1.0):
63
  conv.messages.append([conv.roles[1], None])
64
+ with torch.no_grad():
65
+ embs = self.get_context_emb(conv, img_list)
66
+ outputs = self.model.llama_model.generate(
67
+ inputs_embeds=embs,
68
+ max_new_tokens=max_new_tokens,
69
+ stopping_criteria=self.stopping_criteria,
70
+ num_beams=num_beams,
71
+ do_sample=True,
72
+ min_length=min_length,
73
+ top_p=top_p,
74
+ repetition_penalty=repetition_penalty,
75
+ length_penalty=length_penalty,
76
+ temperature=temperature,
77
+ )
78
+ output_token = outputs[0]
79
+ if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
80
+ output_token = output_token[1:]
81
+ if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
82
+ output_token = output_token[1:]
83
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
84
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
85
+ output_text = output_text.split('Assistant:')[-1].strip()
86
+ conv.messages[-1][1] = output_text
87
  return output_text, output_token.cpu().numpy(), conv
88
 
89
  def get_index(self, num_frames, num_segments):
 
140
 
141
  else:
142
  raise NotImplementedError
143
+ with torch.no_grad():
144
+ print("Input video shape:", vid_chat.shape)
145
+ image_emb, _ = self.model.encode_img(image)
146
+ img_list.append(image_emb)
147
  conv.messages.append([
148
  conv.roles[0],
149
  f"<Video><VideoHere></Video> {msg}\n"
 
163
  T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
164
  ]
165
  )
166
+ with torch.no_grad():
167
+ img = transform(img).unsqueeze(0).unsqueeze(0).cuda()
168
+ image_emb, _ = self.model.encode_img(img)
169
+ img_list.append(image_emb)
170
  conv.messages.append([
171
  conv.roles[0],
172
  f"<Image><ImageHere></Image>\n"