momergul committed
Commit d1a5104 · 1 Parent(s): 72c2e5e

Pushed all model initialization to the main app
Files changed (1):
  1. app.py +45 -11
app.py CHANGED

@@ -8,7 +8,45 @@ from typing import List, Tuple
 
 from config_generator import generate_complete_game
 from dataset import get_processor, joint_speaker_input, joint_listener_input, get_index_to_token
-from models import get_model
+
+import torch
+import transformers
+from transformers import Idefics2ForConditionalGeneration
+from peft import LoraConfig, get_peft_model
+from joint_inference import IdeficsJointInferenceModel
+
+# Initialize the model globally
+repo = 'lil-lab/cogen'
+checkpoint = "HuggingFaceM4/idefics2-8b"
+model = Idefics2ForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
+
+target_modules=r'(.*(vision_model|modality_projection|perceiver_resampler).*(out_proj|fc1|fc2|down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)|(.*(k_proj|q_proj|v_proj).*$)'
+lora_config = LoraConfig(
+    r=16, lora_alpha=8,
+    lora_dropout=0.1,
+    target_modules=target_modules,
+    init_lora_weights="gaussian"
+)
+model = get_peft_model(model, lora_config, adapter_name="initial")
+model.load_adapter(repo, "initial", revision="r0_full")
+
+# Add other adapter
+new_targets = set()
+for n, p in model.named_parameters():
+    if 'lora' in n:
+        new_targets.add(n[17:n.find('lora')-1])
+new_targets = list(new_targets)
+
+lora_config = LoraConfig(
+    r=16, lora_alpha=8,
+    lora_dropout=0.1,
+    target_modules=new_targets,
+    init_lora_weights="gaussian"
+)
+model.add_adapter('final', lora_config)
+model.load_adapter(repo, "final", revision="r3_full")
+model = IdeficsJointInferenceModel(0.5, 0, model=model).cuda()
+model.eval()
 
 css="""
 .radio-group .wrap {
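The new_targets loop above recovers bare module paths from PEFT parameter names, so the second adapter targets exactly the modules the first adapter wrapped. A minimal sketch of what the slice extracts, using a hypothetical parameter name (the layer path is illustrative, not taken from the model):

# Hypothetical PEFT parameter name for illustration only.
n = "base_model.model.model.text_model.layers.0.self_attn.q_proj.lora_A.initial.weight"

# len("base_model.model.") == 17, so n[17:] strips the PEFT wrapper prefix;
# cutting just before ".lora_A.<adapter>.weight" leaves the module path
# that LoraConfig(target_modules=...) accepts.
print(n[17:n.find('lora') - 1])  # -> model.text_model.layers.0.self_attn.q_proj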
@@ -70,9 +108,8 @@ def get_model_response(
 
 @spaces.GPU(duration=20)
 def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token, adapter_name):
-    model.model.set_adapter(adapter_name)
-    print(adapter_name)
-    model = model.cuda()
+    if model.model.active_adapter != adapter_name:
+        model.model.set_adapter(adapter_name)
     with torch.no_grad():
         captions, _, _, _, _ = model.generate(
             images.cuda(), input_tokens.cuda(), attn_mask.cuda(), image_attn_mask.cuda(), label.cuda(),
@@ -85,9 +122,8 @@ def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask
 @spaces.GPU(duration=20)
 def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
                           s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths, adapter_name):
-    model.model.set_adapter(adapter_name)
-    print(adapter_name)
-    model = model.cuda()
+    if model.model.active_adapter != adapter_name:
+        model.model.set_adapter(adapter_name)
     with torch.no_grad():
         _, _, joint_log_probs = model.comprehension_side([
             images.cuda(), l_input_tokens.cuda(), l_attn_mask.cuda(), l_image_attn_mask.cuda(), index_to_token,
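Both handlers now guard the adapter switch: set_adapter re-selects LoRA modules across the whole model, so skipping it when the requested adapter is already active keeps repeated @spaces.GPU calls cheap. The pattern in isolation, as a sketch (ensure_adapter is a name invented here; any peft PeftModel instance works):

def ensure_adapter(peft_model, adapter_name):
    # PeftModel.active_adapter reports the currently selected adapter,
    # so the switch only runs when something actually changes.
    if peft_model.active_adapter != adapter_name:
        peft_model.set_adapter(adapter_name)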
@@ -119,7 +155,7 @@ def initialize_interaction(model_iteration):
 
     return new_history
 
-def progress_game(user_message, model, processor, index_to_token, current_state):
+def progress_game(user_message, processor, index_to_token, current_state):
     # First get the game state
     turn = current_state['turn']
     image_role_pairs = current_state['image_role_pairs']
@@ -257,7 +293,6 @@ def create_app():
         )
 
         send_btn = gr.Button("Send", interactive=False)
-        model = get_model()
         processor = get_processor()
         index_to_token = get_index_to_token()
 
@@ -281,7 +316,6 @@ def create_app():
             gr.update(interactive=not human_listener), gr.update(interactive=human_listener), gr.update(interactive=True), gr.update(interactive=False), current_history
 
     def send_message(message, radio_choice, current_state):
-        nonlocal model
         nonlocal processor
         nonlocal index_to_token
 
@@ -292,7 +326,7 @@ def create_app():
 
         # Regular game progress
         user_output = message if radio_choice is None else radio_choice
-        images, conversation, role, turn, acc_message, current_state = progress_game(user_output, model, processor, index_to_token, current_state)
+        images, conversation, role, turn, acc_message, current_state = progress_game(user_output, processor, index_to_token, current_state)
         human_listener = role == "Listener"
         return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)], "\n".join(conversation), role, turn, \
             acc_message, gr.update(interactive=not human_listener, value=""), gr.update(interactive=human_listener, value=None), \
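The remaining hunks follow from the move to module-level initialization: progress_game and send_message no longer receive or re-bind the model, and the per-call model = model.cuda() lines are gone. A minimal sketch of the ZeroGPU pattern the commit adopts (build_model is a hypothetical stand-in for the initialization block above):

import torch
import spaces

model = build_model().cuda()  # hypothetical one-time setup at import time
model.eval()

@spaces.GPU(duration=20)  # a GPU is attached only for the decorated call
def predict(batch):
    # The model is already resident, so each call is just a forward pass.
    with torch.no_grad():
        return model(batch.cuda())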
 