Vasudevakrishna committed
Commit b9ac069
1 Parent(s): 7b897fc

Update app.py

Files changed (1)
  1. app.py +21 -16
app.py CHANGED
@@ -73,22 +73,27 @@ def generate_answers(img=None, aud = None, q = None, max_tokens = 30):
     inputs_embeddings.append(end_iq_embeds)
     # Combine embeddings
     combined_embeds = torch.cat(inputs_embeddings, dim=1)
-    print("----------",combined_embeds.shape)
-
-    for pos in range(max_tokens - 1):
-        model_output_logits = phi2_model.forward(inputs_embeds = combined_embeds)['logits']
-        print("-=-=-=-", model_output_logits.shape)
-        predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
-        predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
-        predicted_caption[:, pos] = predicted_word_token.view(1,-1).to('cpu')
-        print(predicted_caption)
-        next_token_embeds = phi2_model.model.embed_tokens(predicted_word_token)
-        combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
-    del next_token_embeds
-    del predicted_word_token
-    del predicted_word_token_logits
-    del combined_embeds
-    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
+
+    predicted_caption = phi2_model.generate(inputs_embeds=combined_embeds,
+                                            max_new_tokens=max_tokens,
+                                            return_dict_in_generate=True)
+    # print("----------",combined_embeds.shape)
+
+    # for pos in range(max_tokens - 1):
+    #     model_output_logits = phi2_model.forward(inputs_embeds = combined_embeds)['logits']
+    #     print("-=-=-=-", model_output_logits.shape)
+    #     predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+    #     predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
+    #     predicted_caption[:, pos] = predicted_word_token.view(1,-1).to('cpu')
+    #     print(predicted_caption)
+    #     next_token_embeds = phi2_model.model.embed_tokens(predicted_word_token)
+    #     combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+    # del next_token_embeds
+    # del predicted_word_token
+    # del predicted_word_token_logits
+    # del combined_embeds
+    # predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
+    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
     predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>","")
     return predicted_captions_decoded
 
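
For context on the change above, a minimal sketch of the new decoding path, assuming `phi2_model` is a Hugging Face causal LM loaded with `AutoModelForCausalLM` and `tokenizer` is its matching tokenizer. In app.py the `combined_embeds` tensor is built from projected image/audio embeddings plus the embedded question; the plain-text prompt below is only a stand-in for that multimodal input, and the base microsoft/phi-2 checkpoint stands in for the app's fine-tuned model.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: base phi-2 in place of the app's fine-tuned phi2_model;
# the image/audio projection layers from app.py are not reproduced here.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
phi2_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")

# Stand-in for combined_embeds: in app.py this concatenates projected
# image/audio embeddings with the embedded question tokens.
prompt_ids = tokenizer("Question: What is shown? Answer:", return_tensors="pt").input_ids
combined_embeds = phi2_model.get_input_embeddings()(prompt_ids)  # (1, seq_len, hidden)

# New path from this commit: generate() performs the token-by-token greedy
# decoding that the removed manual argmax loop used to do by hand.
output = phi2_model.generate(inputs_embeds=combined_embeds,
                             max_new_tokens=30,
                             return_dict_in_generate=True)

# With only inputs_embeds provided (recent transformers versions),
# output.sequences holds just the newly generated token ids, with no prompt
# ids prepended; the [:, 1:] slice mirrors app.py and drops the first token.
answer = tokenizer.batch_decode(output.sequences[:, 1:])[0]
answer = answer.replace("<|endoftext|>", "")
print(answer)

Besides being shorter, generate() keeps the KV cache between steps, whereas the removed loop re-ran a full forward pass over the ever-growing combined_embeds at every iteration.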