Commit ea37b8e • 1 Parent(s): 2b5d0a7
Update app.py

app.py CHANGED
@@ -48,6 +48,8 @@ resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth',map_locat
 def model_generate_ans(img=None,img_audio=None,val_q=None):

     max_generate_length = 100
+    val_combined_embeds = []
+
     with torch.no_grad():

         # image
@@ -60,33 +62,38 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
         img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
         img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

+        val_combined_embeds.append(val_image_embeds)
+        val_combined_embeds.append(img_token_embeds)
+
         # audio
         if img_audio is not None:
-            audio_result = audio_model.transcribe(
+            audio_result = audio_model.transcribe(img_audio)
             audio_text = ''
             for seg in audio_result['segments']:
                 audio_text += seg['text']
             audio_text = audio_text.strip()
             audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
             audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
+            val_combined_embeds.append(audio_embeds)

         # text question
-        if val_q
+        if len(val_q) != 0:
             val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
             val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
-
-        val_combined_embeds = []
-        if img is not None:
-            #val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
-            val_combined_embeds.append(val_image_embeds)
-            val_combined_embeds.append(img_token_embeds)
-        if img_audio is not None:
-            #val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
-            val_combined_embeds.append(audio_embeds)
-        if val_q is not None:
-            #val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
             val_combined_embeds.append(val_q_embeds)

+        # val_combined_embeds = []
+        # if img is not None:
+        #     #val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
+        #     val_combined_embeds.append(val_image_embeds)
+        #     val_combined_embeds.append(img_token_embeds)
+        # if img_audio is not None:
+        #     #val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
+        #     val_combined_embeds.append(audio_embeds)
+        # if len(val_q) != 0:
+        #     #val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
+        #     val_combined_embeds.append(val_q_embeds)
+
         val_combined_embeds = torch.cat(val_combined_embeds,dim=1)

         #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
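
For context, a minimal, self-contained sketch of the embedding-assembly pattern this commit moves to: val_combined_embeds now starts as an empty list at the top of the function, each available modality (image features, the IMAGE_TOKEN_ID marker, the transcribed audio, the text question) appends its embeddings, and a single torch.cat(..., dim=1) joins them along the sequence dimension. The random tensors below are hypothetical stand-ins for the projection and embed_tokens outputs in app.py, so the snippet runs without the model checkpoints; the hidden size and total sequence length follow the "# 4, 69, 2560" comment in the diff (with batch size 1 here).

import torch

embed_dim = 2560  # hidden size implied by the "# 4, 69, 2560" comment

# Hypothetical stand-ins for the embeddings app.py computes from the model:
val_image_embeds = torch.randn(1, 49, embed_dim)  # projected image features
img_token_embeds = torch.randn(1, 1, embed_dim)   # the IMAGE_TOKEN_ID marker
audio_embeds     = torch.randn(1, 12, embed_dim)  # embedded audio transcript
val_q_embeds     = torch.randn(1, 7, embed_dim)   # embedded text question

img_audio = "clip.wav"          # pretend an audio file was supplied
val_q = "what is happening?"    # pretend a question was supplied

# Post-commit pattern: collect whatever modalities are present, concat once.
val_combined_embeds = []
val_combined_embeds.append(val_image_embeds)  # image appended unconditionally
val_combined_embeds.append(img_token_embeds)
if img_audio is not None:
    val_combined_embeds.append(audio_embeds)
if len(val_q) != 0:
    val_combined_embeds.append(val_q_embeds)

val_combined_embeds = torch.cat(val_combined_embeds, dim=1)
print(val_combined_embeds.shape)  # torch.Size([1, 69, 2560])

The other functional fixes in the hunk are passing the uploaded file through to audio_model.transcribe(img_audio), a Whisper-style API whose result exposes 'segments' entries whose 'text' fields are concatenated and stripped into one transcript, and guarding the question branch with len(val_q) != 0 so an empty string contributes nothing to the concatenation.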