sanjanatule committed
Commit 2b5d0a7
1 Parent(s): 015cbbb

Update app.py

Files changed (1): app.py (+7 -7)
app.py CHANGED
@@ -51,7 +51,7 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
    with torch.no_grad():

        # image
-       if img:
+       if img is not None:
            image_processed = processor(images=img, return_tensors="pt").to(device)
            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
            val_image_embeds = projection(clip_val_outputs)
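Why `is not None` matters here: unless the image component is configured with type="pil", Gradio hands the callback a numpy array, and truth-testing a multi-element array raises a ValueError. A minimal sketch of the failure mode (the array shape is illustrative):

import numpy as np

img = np.zeros((224, 224, 3), dtype=np.uint8)  # what gr.Image delivers by default

try:
    if img:  # the old check: calls ndarray.__bool__
        pass
except ValueError as err:
    print(err)  # "The truth value of an array with more than one element is ambiguous..."

if img is not None:  # the new check: identity test, safe for any payload
    print("image present")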
@@ -61,7 +61,7 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

        # audio
-       if img_audio:
+       if img_audio is not None:
            audio_result = audio_model.transcribe(audio)
            audio_text = ''
            for seg in audio_result['segments']:
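For context, the audio branch follows the usual openai-whisper pattern: transcribe() returns a dict whose 'segments' list carries per-segment text that the loop stitches back together. A hedged sketch (model size and file path are assumptions, and note the app's call reads a variable named audio rather than the img_audio argument):

import whisper  # openai-whisper

audio_model = whisper.load_model("base")            # model size assumed
audio_result = audio_model.transcribe("query.wav")  # hypothetical file path

# Joining each segment's 'text' field rebuilds the transcript the same
# way the app's for-loop does.
audio_text = ''.join(seg['text'] for seg in audio_result['segments'])
print(audio_text)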
@@ -71,19 +71,19 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
            audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)

        # text question
-       if val_q:
+       if val_q is not None:
            val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)

        val_combined_embeds = []
-       if img:
+       if img is not None:
            #val_combined_embeds = torch.cat([val_combined_embeds, val_image_embeds, img_token_embeds], dim=1)
            val_combined_embeds.append(val_image_embeds)
            val_combined_embeds.append(img_token_embeds)
-       if img_audio:
+       if img_audio is not None:
            #val_combined_embeds = torch.cat([val_combined_embeds, audio_embeds], dim=1)
            val_combined_embeds.append(audio_embeds)
-       if val_q:
+       if val_q is not None:
            #val_combined_embeds = torch.cat([val_combined_embeds, val_q_embeds], dim=1)
            val_combined_embeds.append(val_q_embeds)
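The commented-out lines hint at how the list is consumed downstream: the per-modality embeddings are concatenated along the sequence dimension into one inputs_embeds tensor for generation. A sketch with illustrative shapes (hidden size and sequence lengths are assumptions, not values from the app):

import torch

hidden = 2560  # assumed width; must match the merged language model
val_image_embeds = torch.randn(1, 49, hidden)  # projected CLIP patch embeddings
img_token_embeds = torch.randn(1, 1, hidden)   # image-marker token embedding
val_q_embeds = torch.randn(1, 12, hidden)      # embedded text question

val_combined_embeds = [val_image_embeds, img_token_embeds, val_q_embeds]

# One concatenation over the list replaces the chained torch.cat calls in
# the commented-out code; all parts must share batch and hidden dimensions.
inputs_embeds = torch.cat(val_combined_embeds, dim=1)
print(inputs_embeds.shape)  # torch.Size([1, 62, 2560])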
@@ -117,7 +117,7 @@ with gr.Blocks() as demo:
    # app GUI
    with gr.Row():
        with gr.Column():
-           img_input = gr.Image(label='Image')
+           img_input = gr.Image(label='Image', type="pil")
            img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
            img_question = gr.Text(label='Text Query')
        with gr.Column():
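The GUI change pairs with the new None checks: with type="pil" the callback receives a PIL.Image (or None), which CLIPProcessor accepts directly and which is always safe to test by identity. A minimal standalone sketch, assuming a Gradio 4.x environment:

import gradio as gr

def describe(img=None):
    # With type="pil" this is a PIL.Image.Image or None, never an ndarray.
    return "no image" if img is None else f"got {type(img).__name__}, size {img.size}"

with gr.Blocks() as demo:
    img_input = gr.Image(label='Image', type="pil")
    result = gr.Text(label='Result')
    img_input.change(describe, inputs=img_input, outputs=result)

if __name__ == "__main__":
    demo.launch()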
 