momergul committed
Commit f644df7 · Parent: c1f8aba
Files changed (2):
  1. app.py +7 -7
  2. joint_inference.py +6 -0
app.py CHANGED
@@ -23,7 +23,7 @@ css="""
 def initialize_game() -> List[List[str]]:
     context_dicts = [generate_complete_game() for _ in range(2)]
 
-    roles = ["listener"] * 3 + ["speaker"] * 3
+    roles = ["speaker"] * 3 + ["listener"] * 3
     speaker_images = []
     listener_images = []
     targets = []
@@ -40,7 +40,6 @@ def get_model_response(
     model, adapter_name, processor, index_to_token, role: str,
     image_paths: List[str], user_message: str = "", target_image: str = ""
 ) -> str:
-    model.model.set_adapter(adapter_name)
     if role == "speaker":
         img_dir = "tangram_pngs"
         print("Starting processing")
@@ -50,7 +49,7 @@ def get_model_response(
         image_paths = [image_paths]
         print("Starting inference")
         captions = get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths,
-                                        processor, img_dir, index_to_token)
+                                        processor, img_dir, index_to_token, adapter_name)
         print("Done")
         response = captions[0]
     else: # listener
@@ -63,14 +62,15 @@ def get_model_response(
         print("Starting inference")
         response = get_listener_response(
             model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
-            s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths
+            s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths, adapter_name
         )
         print("Done")
 
     return response
 
 @spaces.GPU(duration=20)
-def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token):
+def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token, adapter_name):
+    model.model.set_adapter(adapter_name)
     model = model.cuda()
     with torch.no_grad():
         captions, _, _, _, _ = model.generate(
@@ -83,10 +83,10 @@ def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask
 
 @spaces.GPU(duration=20)
 def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
-                          s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths):
+                          s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths, adapter_name):
+    model.model.set_adapter(adapter_name)
     model = model.cuda()
     with torch.no_grad():
-        print(model.model.device, images.device)
         _, _, joint_log_probs = model.comprehension_side([
             images.cuda(), l_input_tokens.cuda(), l_attn_mask.cuda(), l_image_attn_mask.cuda(), index_to_token,
             s_input_tokens.cuda(), s_attn_mask.cuda(), s_image_attn_mask.cuda(), s_target_mask.cuda(), s_target_label.cuda(),
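
Note on the app.py change: the commit moves model.model.set_adapter(adapter_name) out of get_model_response and into the two @spaces.GPU-decorated helpers, threading adapter_name through as an extra argument. A plausible motivation, hedged since ZeroGPU's internals are not shown here, is that a @spaces.GPU function body executes with the GPU attached (potentially in a separate worker), so switching the PEFT adapter inside that function keeps the adapter selection and the inference in the same execution context. A minimal sketch of the pattern, where run_with_adapter and batch are hypothetical stand-ins rather than code from this repo:

    # Sketch only: run_with_adapter and batch are hypothetical stand-ins
    # illustrating the pattern this diff adopts, not code from this repo.
    import spaces
    import torch

    @spaces.GPU(duration=20)
    def run_with_adapter(model, adapter_name, batch):
        # Select the PEFT adapter inside the GPU-scoped function, so the
        # switch happens in the same context that runs inference.
        model.model.set_adapter(adapter_name)
        model = model.cuda()
        with torch.no_grad():
            # Move inputs to the GPU alongside the model before generating.
            return model.generate(**{k: v.cuda() for k, v in batch.items()})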
joint_inference.py CHANGED
@@ -436,6 +436,12 @@ class IdeficsJointInferenceModel(nn.Module):
             output_hidden_states=True,
             return_dict_in_generate=True
         )
+
+        print(torch.any(torch.isnan(s_input_tokens)))
+        print(torch.any(torch.isnan(s_attn_mask)))
+        print(torch.any(torch.isnan(images)))
+        print(torch.any(torch.isnan(s_image_attn_mask)))
+
         outputs = speaker.generate(
             input_ids=s_input_tokens,
             attention_mask=s_attn_mask,
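
Note on the joint_inference.py change: the four added prints are NaN diagnostics on the speaker inputs before the speaker.generate call. One caveat: NaN is only representable in floating-point dtypes, so if s_input_tokens and the two attention masks are integer tensors (as token ids and masks typically are), those checks are trivially False and only the check on images can ever fire. A dtype-aware variant along these lines (report_nans is a hypothetical helper, not part of this repo) would make the diagnostic reusable:

    import torch

    def report_nans(**tensors):
        # Hypothetical debug helper: print which named inputs contain NaNs.
        # NaN is only representable in floating-point dtypes, so integer
        # tensors (token ids, attention masks) always report as clean.
        for name, t in tensors.items():
            has_nan = t.is_floating_point() and bool(torch.isnan(t).any())
            print(f"{name}: {'NaN detected' if has_nan else 'ok'}")

    # Usage at the call site in the diff:
    # report_nans(s_input_tokens=s_input_tokens, s_attn_mask=s_attn_mask,
    #             images=images, s_image_attn_mask=s_image_attn_mask)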