Vasudevakrishna committed
Commit c5fb1de
Parent: 7eb436c

Update app.py

Files changed (1): app.py +4 -1
app.py CHANGED
@@ -73,12 +73,15 @@ def generate_answers(img=None, aud = None, q = None, max_tokens = 30):
     inputs_embeddings.append(end_iq_embeds)
     # Combine embeddings
     combined_embeds = torch.cat(inputs_embeddings, dim=1)
+    print("----------",combined_embeds.shape)

     for pos in range(max_tokens - 1):
         model_output_logits = phi2_model.forward(inputs_embeds = combined_embeds)['logits']
+        print("-=-=-=-", model_output_logits.shape)
         predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
         predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
         predicted_caption[:, pos] = predicted_word_token.view(1,-1).to('cpu')
+        print(predicted_caption)
         next_token_embeds = phi2_model.model.embed_tokens(predicted_word_token)
         combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
     predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
@@ -110,4 +113,4 @@ with gr.Blocks() as demo:

 if __name__ == "__main__":

-    demo.launch(share=True)
+    demo.launch(share=True, debug=True)
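For context, the first hunk sits inside a greedy, embedding-level decoding loop: the model is run on the accumulated input embeddings, the argmax token at the last position is taken, its embedding is appended, and the loop repeats up to max_tokens; the new print calls only log tensor shapes and the partially filled caption for debugging. Below is a minimal sketch of that pattern, not the repository's generate_answers function: the model name, prompt, and early stop at the end-of-text token are assumptions (app.py instead writes into a preallocated predicted_caption tensor and decodes it once at the end, treating id 50256 as padding).

```python
# Minimal sketch of greedy decoding by appending token embeddings, assuming a
# Hugging Face causal LM (microsoft/phi-2 here) whose input-embedding layer is
# reachable via get_input_embeddings(). Illustrative only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "microsoft/phi-2"  # assumption: same base model as phi2_model in app.py
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

prompt = "Question: What is shown in the image? Answer:"  # placeholder prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

embed_layer = model.get_input_embeddings()      # same role as phi2_model.model.embed_tokens
combined_embeds = embed_layer(input_ids)        # (1, seq_len, hidden)

max_tokens = 30
generated = []
with torch.no_grad():
    for _ in range(max_tokens - 1):
        # Re-run the full sequence each step (no KV cache), as in app.py.
        logits = model(inputs_embeds=combined_embeds).logits     # (1, seq_len, vocab)
        next_token = torch.argmax(logits[:, -1, :], dim=-1)      # greedy pick, shape (1,)
        generated.append(next_token.item())
        if next_token.item() == tokenizer.eos_token_id:          # 50256 for phi-2
            break
        next_embeds = embed_layer(next_token).unsqueeze(1)       # (1, 1, hidden)
        combined_embeds = torch.cat([combined_embeds, next_embeds], dim=1)

print(tokenizer.decode(generated, skip_special_tokens=True))
```

As for the second hunk: in Gradio, share=True requests a public share link for the app, and debug=True keeps launch() blocking and surfaces errors in the console (mainly useful in notebooks), which pairs with the debug prints added above.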