sanjanatule committed
Commit e40af41 • 1 Parent(s): 5606a28
Update app.py
app.py CHANGED
@@ -15,30 +15,47 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 clip_embed = 768
 phi_embed = 2560
 
+class SimpleResBlock(nn.Module):
+    def __init__(self, phi_embed):
+        super().__init__()
+        self.pre_norm = nn.LayerNorm(phi_embed)
+        self.proj = nn.Sequential(
+            nn.Linear(phi_embed, phi_embed),
+            nn.GELU(),
+            nn.Linear(phi_embed, phi_embed)
+        )
+    def forward(self, x):
+        x = self.pre_norm(x)
+        return x + self.proj(x)
+
 # models
 clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
 projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
+resblock = SimpleResBlock(phi_embed).to(device)
+
 phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_code=True).to(device)
 
 # load weights
 model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/lora_adaptor')
 merged_model = model_to_merge.merge_and_unload()
 projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth',map_location=torch.device(device)))
+resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth',map_location=torch.device(device)))
 
 def model_generate_ans(img,val_q):
 
     max_generate_length = 100
 
     # image
-    image_processed
+    image_processed = processor(images=img, return_tensors="pt").to(device)
     clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
-    val_image_embeds = projection(clip_val_outputs)
+    val_image_embeds = projection(clip_val_outputs)
+    val_image_embeds = resblock(val_image_embeds).to(torch.float16)
 
     img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
-    img_token_embeds = merged_model.model.
+    img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
 
     val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
-    val_q_embeds
+    val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
 
     val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
 
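The hunk above inserts a residual MLP (SimpleResBlock) after the linear CLIP-to-Phi projection and casts the result to float16 before it reaches the language model. Below is a minimal, self-contained sketch of that image-embedding path, assuming a CLIP backbone that yields 49 patch tokens once the CLS token is sliced off (consistent with the "# 4, 69, 2560" shape comment: 49 image tokens + 1 image-marker token + 19 question tokens); the random tensor stands in for clip_model(**image_processed).last_hidden_state[:,1:,:], since processor, tokenizer, and the model names live earlier in app.py.

import torch
import torch.nn as nn

clip_embed, phi_embed = 768, 2560

class SimpleResBlock(nn.Module):  # copied from the diff above
    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed)
        )
    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)  # residual: input added back onto the MLP output

projection = nn.Linear(clip_embed, phi_embed)
resblock = SimpleResBlock(phi_embed)

clip_out = torch.randn(1, 49, clip_embed)    # stand-in for CLIP patch embeddings
img_embeds = resblock(projection(clip_out))  # (1, 49, 2560): Phi-2 embedding space
print(img_embeds.shape)

The pre-norm plus skip connection keeps the projected embeddings close to what the linear layer alone produced, so the block trains as a light refinement of the projection rather than a full re-mapping.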
@@ -49,6 +66,8 @@ def model_generate_ans(img,val_q):
         predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
         predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
         predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)
+        next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
+        val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
 
     predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)
 
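The two added lines in this hunk are what make the loop autoregressive: each argmax-picked token is embedded and appended to val_combined_embeds, so the next forward pass conditions on everything generated so far; without them every iteration would see the same prompt and predict the same token. Here is a minimal self-contained sketch of the completed greedy loop, with a toy embedding table and linear head standing in for the merged Phi-2 model (whose forward pass sits in context lines this diff does not show):

import torch
import torch.nn as nn

vocab_size, phi_embed, max_generate_length = 51200, 2560, 5
embed_tokens = nn.Embedding(vocab_size, phi_embed)  # stand-in for phi_model.model.embed_tokens
lm_head = nn.Linear(phi_embed, vocab_size)          # stand-in for the Phi-2 forward pass

val_combined_embeds = torch.randn(1, 69, phi_embed) # image + image-marker + question embeds
predicted_caption = torch.zeros(1, max_generate_length, dtype=torch.long)

for g in range(max_generate_length):
    phi_output_logits = lm_head(val_combined_embeds)                          # real code: model(inputs_embeds=...).logits
    predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)    # (1, 1, vocab_size)
    predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)  # (1, 1), greedy pick
    predicted_caption[:, g] = predicted_word_token.view(-1)
    next_token_embeds = embed_tokens(predicted_word_token)                    # (1, 1, phi_embed)
    val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

print(predicted_caption)  # raw token ids; app.py then decodes them with tokenizer.batch_decode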