sanjanatule committed
Commit e40af41
1 Parent(s): 5606a28

Update app.py

Files changed (1): app.py (+23, -4)
app.py CHANGED
@@ -15,30 +15,47 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 clip_embed = 768
 phi_embed = 2560
 
+class SimpleResBlock(nn.Module):
+    def __init__(self, phi_embed):
+        super().__init__()
+        self.pre_norm = nn.LayerNorm(phi_embed)
+        self.proj = nn.Sequential(
+            nn.Linear(phi_embed, phi_embed),
+            nn.GELU(),
+            nn.Linear(phi_embed, phi_embed)
+        )
+    def forward(self, x):
+        x = self.pre_norm(x)
+        return x + self.proj(x)
+
 # models
 clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
 projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
+resblock = SimpleResBlock(phi_embed).to(device)
+
 phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_code=True).to(device)
 
 # load weights
 model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/lora_adaptor')
 merged_model = model_to_merge.merge_and_unload()
 projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth',map_location=torch.device(device)))
+resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth',map_location=torch.device(device)))
 
 def model_generate_ans(img,val_q):
 
     max_generate_length = 100
 
     # image
-    image_processed = processor(images=img, return_tensors="pt").to(device)
+    image_processed = processor(images=img, return_tensors="pt").to(device)
     clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
-    val_image_embeds = projection(clip_val_outputs).to(torch.float16)
+    val_image_embeds = projection(clip_val_outputs)
+    val_image_embeds = resblock(val_image_embeds).to(torch.float16)
 
     img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
-    img_token_embeds = merged_model.model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
+    img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
 
     val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
-    val_q_embeds = merged_model.model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
+    val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
 
     val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
 
 
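The hunk above inserts a SimpleResBlock between the linear projection and Phi's embedding space: a pre-LayerNorm followed by a Linear-GELU-Linear block with a skip connection, with weights loaded from step2_resblock.pth. A minimal standalone sketch of the resulting image-embedding path follows; the dummy patch tensor, its token count, and the print call are illustrative stand-ins, not taken from app.py:

import torch
import torch.nn as nn

clip_embed, phi_embed = 768, 2560

class SimpleResBlock(nn.Module):
    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed)
        )
    def forward(self, x):
        # pre-norm residual MLP: normalize, then add MLP output back
        x = self.pre_norm(x)
        return x + self.proj(x)

projection = nn.Linear(clip_embed, phi_embed)
resblock = SimpleResBlock(phi_embed)

patches = torch.randn(1, 49, clip_embed)    # stand-in for CLIP patch embeddings (CLS token dropped)
img_embeds = resblock(projection(patches))
print(img_embeds.shape)                     # torch.Size([1, 49, 2560])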
@@ -49,6 +66,8 @@ def model_generate_ans(img,val_q):
         predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
         predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
         predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)
+        next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
+        val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
 
     predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)
 
 
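The second hunk is what makes decoding autoregressive: the embedding of each predicted token is appended to val_combined_embeds, so the next forward pass sees the sequence grown by one token. A sketch of the surrounding greedy loop; the loop header, the logits call, and the predicted_caption buffer are assumptions filled in from context, not shown in this diff:

# greedy decoding, one token per step (scaffolding assumed; names follow the diff)
predicted_caption = torch.full((1, max_generate_length), 50256, dtype=torch.long, device=device)
for g in range(max_generate_length):
    phi_output_logits = merged_model(inputs_embeds=val_combined_embeds).logits
    predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)   # 1,1,51200
    predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1) # 1,1
    predicted_caption[:, g] = predicted_word_token.view(1, -1)
    # the two added lines: embed the new token and extend the input sequence
    next_token_embeds = phi_model.model.embed_tokens(predicted_word_token)   # 1,1,2560
    val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

Without the concatenation, every iteration would re-run the model on the same prompt and predict the same token; appending next_token_embeds is what advances the sequence.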