momergul committed
Commit c2b3ecf · 1 Parent(s): ddfedae
Files changed (1)
  1. app.py +38 -19
app.py CHANGED
@@ -42,40 +42,59 @@ def get_model_response(
     image_paths: List[str], user_message: str = "", target_image: str = ""
 ) -> str:
     model.model.set_adapter(adapter_name)
-    print(model.model.active_adapter)
     if role == "speaker":
         img_dir = "tangram_pngs"
+        print("Starting processing")
         input_tokens, attn_mask, images, image_attn_mask, label = joint_speaker_input(
             processor, image_paths, target_image, model.get_listener().device
         )
-        print("Hi")
-        with torch.no_grad():
-            image_paths = [image_paths]
-            captions, _, _, _, _ = model.generate(
-                images, input_tokens, attn_mask, image_attn_mask, label,
-                image_paths, processor, img_dir, index_to_token,
-                max_steps=30, sampling_type="nucleus", temperature=0.7,
-                top_k=50, top_p=1, repetition_penalty=1, num_samples=5
-            )
-        print("There")
+        image_paths = [image_paths]
+        print("Starting inference")
+        captions = get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths,
+                                        processor, img_dir, index_to_token)
+        print("Done")
         response = captions[0]
     else: # listener
+        print("Starting processing")
         images, l_input_tokens, l_attn_mask, l_image_attn_mask, s_input_tokens, s_attn_mask, \
             s_image_attn_mask, s_target_mask, s_target_label = joint_listener_input(
             processor, image_paths, user_message, model.get_listener().device
         )
 
-        with torch.no_grad():
-            # Forward
-            _, _, joint_log_probs = model.comprehension_side([
-                images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
-                s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label,
-            ])
-            target_idx = joint_log_probs[0].argmax().item()
-            response = image_paths[target_idx]
+        print("Starting inference")
+        response = get_listener_response(
+            model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
+            s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths
+        )
+        print("Done")
 
     return response
 
+@spaces.GPU(duration=15)
+def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token):
+    with torch.no_grad():
+        print(model.model.device, images.device)
+        captions, _, _, _, _ = model.generate(
+            images, input_tokens, attn_mask, image_attn_mask, label,
+            image_paths, processor, img_dir, index_to_token,
+            max_steps=30, sampling_type="nucleus", temperature=0.7,
+            top_k=50, top_p=1, repetition_penalty=1, num_samples=5
+        )
+    return captions
+
+@spaces.GPU(duration=15)
+def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
+                          s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths):
+    with torch.no_grad():
+        print(model.model.device, images.device)
+        _, _, joint_log_probs = model.comprehension_side([
+            images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
+            s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label,
+        ])
+        target_idx = joint_log_probs[0].argmax().item()
+        response = image_paths[target_idx]
+    return response
+
 def interaction(model, processor, index_to_token, model_iteration: str) -> Tuple[List[str], List[str]]:
     image_role_pairs = initialize_game()
     conversation = []
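
The substantive change in this commit is that the `torch.no_grad()` inference blocks, which previously ran inline in `get_model_response`, are extracted into `get_speaker_response` and `get_listener_response`, each decorated with `@spaces.GPU(duration=15)`. On Hugging Face ZeroGPU Spaces, `spaces.GPU` attaches a GPU only while the decorated function executes, so keeping tokenization and other CPU-side work outside the helpers keeps the GPU allocation window small. Below is a minimal sketch of the same pattern, assuming it runs on a ZeroGPU Space with the `spaces` package installed; `TinyClassifier`, `gpu_inference`, and `handler` are hypothetical names for illustration, not part of app.py:

```python
# Sketch only: hypothetical model and function names, assuming a Hugging
# Face ZeroGPU Space where the `spaces` package is available. The
# @spaces.GPU decorator requests a GPU for the duration of the call;
# everything outside the decorated function runs on CPU.
import spaces
import torch


class TinyClassifier(torch.nn.Module):
    """Hypothetical stand-in for the app's speaker/listener model."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyClassifier()


@spaces.GPU(duration=15)  # hold the GPU for at most ~15 seconds per call
def gpu_inference(x: torch.Tensor) -> torch.Tensor:
    # Inference only, so skip autograd bookkeeping, as the commit does.
    with torch.no_grad():
        return model(x)


def handler(x: torch.Tensor) -> list:
    # CPU-side pre/post-processing stays outside the decorated helper,
    # mirroring how get_model_response now delegates the forward pass
    # to get_speaker_response / get_listener_response.
    logits = gpu_inference(x)
    return logits.argmax(dim=-1).tolist()
```

The `duration` argument caps the per-call GPU window: a small value like 15 seconds helps the Space's queue schedule requests, but calls that exceed it can be aborted, so it has to comfortably cover the worst-case forward pass (here, nucleus-sampling up to 30 steps with `num_samples=5`).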