ToletiSri committed
Commit 70f3d32
1 Parent(s): 24edbec

Update app.py

Files changed (1):
  1. app.py +21 -3
app.py CHANGED
@@ -8,10 +8,10 @@ from PIL import Image
 
 
 class _MLPVectorProjector(nn.Module):
-    def init(
+    def __init__(
         self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
     ):
-        super(_MLPVectorProjector, self).init()
+        super(_MLPVectorProjector, self).__init__()
         self.mlps = nn.ModuleList()
         for _ in range(width):
             mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
@@ -92,8 +92,26 @@ def textMode(text, count):
 
 def imageMode(image, question):
     image_embedding = encode_image(image)
+    print('-------Image embedding from clip obtained-----------')
     imgToTextEmb = img_proj_head(image_embedding)
-    return "In progress"
+    print('-------text embedding from projection obtained-----------')
+    question = "Question: " + question + "Answer: "
+    Qtokens = tokenizer_text.encode(question, add_special_tokens=True)
+    Qtoken_embeddings = phi2_finetuned.get_submodule('model.embed_tokens')(Qtokens)
+    print('-------question embedding from phi2 obtained-----------')
+    inputs = torch.concat((imgToTextEmb, Qtoken_embeddings), axis=-2)
+
+    prediction = tokenizer.batch_decode(
+        phi2.generate(
+            inputs_embeds=inputs,
+            max_new_tokens=50,
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.pad_token_id
+        )
+    )
+    text_pred = prediction[0].rstrip('<|endoftext|>').rstrip("\n")
+    return text_pred
 
 def audioMode(audio):
     if audio is None:
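
Note on the first hunk: it renames init to __init__ so the projector module is actually initialized when constructed, but the hunk cuts off after the first nn.Linear, so the rest of the class is not visible in this diff. Below is a minimal sketch of how a projector of this shape is commonly completed; the GELU activations, the extra hidden-to-hidden layers, and the forward() concatenation are assumptions, not code from this repository.

```python
import torch
import torch.nn as nn

class MLPVectorProjector(nn.Module):
    """Sketch of a width-wise MLP projector (assumed completion, not app.py's code)."""

    def __init__(self, input_hidden_size: int, lm_hidden_size: int,
                 num_layers: int, width: int):
        super().__init__()
        self.mlps = nn.ModuleList()
        for _ in range(width):
            # First layer maps the image-encoder hidden size to the LM hidden size.
            mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            # Assumed: each additional layer is GELU + Linear in the LM hidden size.
            for _ in range(1, num_layers):
                mlp.extend([nn.GELU(),
                            nn.Linear(lm_hidden_size, lm_hidden_size, bias=False)])
            self.mlps.append(nn.Sequential(*mlp))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each of the `width` MLPs produces one LM-sized vector; stacking them on
        # the sequence dimension turns one image embedding into `width` pseudo-tokens.
        return torch.cat([mlp(x) for mlp in self.mlps], dim=-2)
```

With an input of shape (1, 1, input_hidden_size) this returns (1, width, lm_hidden_size), which is the kind of output imageMode concatenates with the question embeddings.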
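
The new imageMode follows a standard multimodal pipeline: CLIP image embedding, projection into the LM embedding space, concatenation with the embedded question, then generation from inputs_embeds. Below is a minimal sketch of the same flow with the tensor handling made explicit. It assumes app.py's module-level objects (encode_image, img_proj_head, tokenizer_text, tokenizer, phi2, phi2_finetuned) and their shapes. Two details worth double-checking against the committed version: Hugging Face's tokenizer.encode returns a plain Python list, so the embedding layer needs a LongTensor with a batch dimension, and str.rstrip('<|endoftext|>') strips a set of characters rather than the literal suffix.

```python
import torch

def image_mode_sketch(image, question: str) -> str:
    # CLIP embedding of the image, then projection into phi2's embedding space.
    image_embedding = encode_image(image)            # assumed shape (1, 1, clip_hidden)
    img_tokens = img_proj_head(image_embedding)      # assumed shape (1, width, lm_hidden)

    prompt = "Question: " + question + "Answer: "
    # encode() returns List[int]; wrap it and add a batch dimension before embedding.
    q_ids = torch.tensor(
        tokenizer_text.encode(prompt, add_special_tokens=True)
    ).unsqueeze(0)                                   # (1, q_len)
    q_embeds = phi2_finetuned.get_submodule("model.embed_tokens")(q_ids)  # (1, q_len, lm_hidden)

    # Image pseudo-tokens first, question tokens after, along the sequence dimension.
    inputs_embeds = torch.cat((img_tokens, q_embeds), dim=-2)

    output_ids = phi2.generate(
        inputs_embeds=inputs_embeds,
        max_new_tokens=50,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    # skip_special_tokens drops <|endoftext|> without rstrip's character-set pitfall.
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
```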