rmdhirr commited on
Commit
d35b974
·
verified ·
1 Parent(s): 8be4895

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -123
app.py CHANGED
@@ -13,34 +13,44 @@ import torch
13
  from loguru import logger
14
  from PIL import Image
15
  from peft import PeftModel
16
- from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer, AutoModelForImageTextToText
 
 
 
 
 
 
 
 
17
 
18
- adapter_id = "rmdhirr/test4bit6ab"
 
 
 
 
 
 
19
 
20
  # Load processor (tokenizer + feature extractor)
21
  processor = AutoProcessor.from_pretrained(
22
- adapter_id,
23
  padding_side="left"
24
  )
25
 
26
- model = Gemma3ForConditionalGeneration.from_pretrained(
27
- adapter_id, # e.g. "rmdhirr/test4bit-b"
28
- torch_dtype=torch.bfloat16, # same dtype you were using
29
- device_map="auto", # or however you shard
30
- ignore_mismatched_sizes=True # only if you still see tiny shape warnings
31
- )
32
 
 
33
  model.eval()
34
 
35
-
36
- # ########################################
37
-
38
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
39
 
40
-
41
  def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
42
- image_count = 0
43
- video_count = 0
44
  for path in paths:
45
  if path.endswith(".mp4"):
46
  video_count += 1
@@ -48,10 +58,8 @@ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
48
  image_count += 1
49
  return image_count, video_count
50
 
51
-
52
  def count_files_in_history(history: list[dict]) -> tuple[int, int]:
53
- image_count = 0
54
- video_count = 0
55
  for item in history:
56
  if item["role"] != "user" or isinstance(item["content"], str):
57
  continue
@@ -61,122 +69,104 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
61
  image_count += 1
62
  return image_count, video_count
63
 
64
-
65
  def validate_media_constraints(message: dict, history: list[dict]) -> bool:
66
- new_image_count, new_video_count = count_files_in_new_message(message["files"])
67
- history_image_count, history_video_count = count_files_in_history(history)
68
- image_count = history_image_count + new_image_count
69
- video_count = history_video_count + new_video_count
70
- if video_count > 1:
71
  gr.Warning("Only one video is supported.")
72
  return False
73
- if video_count == 1:
74
- if image_count > 0:
75
- gr.Warning("Mixing images and videos is not allowed.")
76
- return False
77
- if "<image>" in message["text"]:
78
- gr.Warning("Using <image> tags with video files is not supported.")
79
- return False
80
- if video_count == 0 and image_count > MAX_NUM_IMAGES:
81
- gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
82
  return False
83
- if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
84
- gr.Warning("The number of <image> tags in the text does not match the number of images.")
85
  return False
86
  return True
87
 
88
-
89
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
90
  vidcap = cv2.VideoCapture(video_path)
91
  fps = vidcap.get(cv2.CAP_PROP_FPS)
92
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
93
-
94
- frame_interval = max(total_frames // MAX_NUM_IMAGES, 1)
95
- frames: list[tuple[Image.Image, float]] = []
96
-
97
- for i in range(0, min(total_frames, MAX_NUM_IMAGES * frame_interval), frame_interval):
98
  if len(frames) >= MAX_NUM_IMAGES:
99
  break
100
-
101
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
102
- success, image = vidcap.read()
103
- if success:
104
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
105
- pil_image = Image.fromarray(image)
106
- timestamp = round(i / fps, 2)
107
- frames.append((pil_image, timestamp))
108
-
109
  vidcap.release()
110
  return frames
111
 
112
-
113
- def process_video(video_path: str) -> list[dict]:
114
- content = []
115
- frames = downsample_video(video_path)
116
- for frame in frames:
117
- pil_image, timestamp = frame
118
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
119
- pil_image.save(temp_file.name)
120
- content.append({"type": "text", "text": f"Frame {timestamp}:"})
121
- content.append({"type": "image", "url": temp_file.name})
122
- logger.debug(f"{content=}")
123
- return content
124
-
125
 
126
  def process_interleaved_images(message: dict) -> list[dict]:
127
- logger.debug(f"{message['files']=}")
128
  parts = re.split(r"(<image>)", message["text"])
129
- logger.debug(f"{parts=}")
130
-
131
- content = []
132
- image_index = 0
133
- for part in parts:
134
- logger.debug(f"{part=}")
135
- if part == "<image>":
136
- content.append({"type": "image", "url": message["files"][image_index]})
137
- logger.debug(f"file: {message['files'][image_index]}")
138
- image_index += 1
139
- elif part.strip():
140
- content.append({"type": "text", "text": part.strip()})
141
- elif isinstance(part, str) and part != "<image>":
142
- content.append({"type": "text", "text": part})
143
- logger.debug(f"{content=}")
144
- return content
145
-
146
 
147
  def process_new_user_message(message: dict) -> list[dict]:
148
  if not message["files"]:
149
- return [{"type": "text", "text": message["text"]}]
150
-
151
  if message["files"][0].endswith(".mp4"):
152
- return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
153
-
154
  if "<image>" in message["text"]:
155
  return process_interleaved_images(message)
156
-
157
- return [
158
- {"type": "text", "text": message["text"]},
159
- *[{"type": "image", "url": path} for path in message["files"]],
160
- ]
161
-
162
 
163
  def process_history(history: list[dict]) -> list[dict]:
164
- messages = []
165
- current_user_content: list[dict] = []
166
  for item in history:
167
  if item["role"] == "assistant":
168
- if current_user_content:
169
- messages.append({"role": "user", "content": current_user_content})
170
- current_user_content = []
171
- messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
172
  else:
173
- content = item["content"]
174
- if isinstance(content, str):
175
- current_user_content.append({"type": "text", "text": content})
176
  else:
177
- current_user_content.append({"type": "image", "url": content[0]})
178
- return messages
179
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  @spaces.GPU(duration=120)
182
  def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
@@ -184,34 +174,42 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
184
  yield ""
185
  return
186
 
187
- messages = []
188
  if system_prompt:
189
- messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
190
- messages.extend(process_history(history))
191
- messages.append({"role": "user", "content": process_new_user_message(message)})
192
-
193
- inputs = processor.apply_chat_template(
194
- messages,
195
- add_generation_prompt=True,
196
- tokenize=True,
197
- return_dict=True,
 
 
 
 
 
 
 
198
  return_tensors="pt",
 
199
  ).to(device=model.device, dtype=torch.bfloat16)
200
 
 
201
  streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
202
  generate_kwargs = dict(
203
- inputs,
204
  streamer=streamer,
205
  max_new_tokens=max_new_tokens,
206
  )
207
  t = Thread(target=model.generate, kwargs=generate_kwargs)
208
  t.start()
209
 
210
- output = ""
211
  for delta in streamer:
212
- output += delta
213
- yield output
214
-
215
 
216
  examples = [
217
  [
@@ -269,8 +267,7 @@ examples = [
269
  "assets/sample-images/09-2.png",
270
  "assets/sample-images/09-3.png",
271
  "assets/sample-images/09-4.png",
272
- "assets/sample-images/09-5.png",
273
- ],
274
  }
275
  ],
276
  [
@@ -305,13 +302,13 @@ examples = [
305
  ],
306
  [
307
  {
308
- "text": "caption this image",
309
  "files": ["assets/sample-images/01.png"],
310
  }
311
  ],
312
  [
313
  {
314
- "text": "What's the sign says?",
315
  "files": ["assets/sample-images/02.png"],
316
  }
317
  ],
@@ -362,4 +359,4 @@ demo = gr.ChatInterface(
362
  )
363
 
364
  if __name__ == "__main__":
365
- demo.launch()
 
13
  from loguru import logger
14
  from PIL import Image
15
  from peft import PeftModel
16
+ from transformers import (
17
+ AutoProcessor,
18
+ Gemma3ForConditionalGeneration,
19
+ TextIteratorStreamer,
20
+ )
21
+
22
+ # Set model and adapter IDs
23
+ model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-pt")
24
+ adapter_id = os.getenv("ADAPTER_ID", "slavamarcin/HG_Gemma-3-12B-4bit-QLora_purpose")
25
 
26
+ # Load Gemma base model and move to GPU, using bfloat16
27
+ model = Gemma3ForConditionalGeneration.from_pretrained(
28
+ model_id,
29
+ torch_dtype=torch.bfloat16,
30
+ device_map="auto",
31
+ attn_implementation="eager"
32
+ ).to("cuda")
33
 
34
  # Load processor (tokenizer + feature extractor)
35
  processor = AutoProcessor.from_pretrained(
36
+ model_id,
37
  padding_side="left"
38
  )
39
 
40
+ # Wrap with PEFT adapter and move to GPU
41
+ #model = PeftModel.from_pretrained(
42
+ # model,
43
+ # adapter_id,
44
+ # ignore_mismatched_sizes=True
45
+ #).to("cuda")
46
 
47
+ # Switch to evaluation mode
48
  model.eval()
49
 
 
 
 
50
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
51
 
 
52
  def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
53
+ image_count = video_count = 0
 
54
  for path in paths:
55
  if path.endswith(".mp4"):
56
  video_count += 1
 
58
  image_count += 1
59
  return image_count, video_count
60
 
 
61
  def count_files_in_history(history: list[dict]) -> tuple[int, int]:
62
+ image_count = video_count = 0
 
63
  for item in history:
64
  if item["role"] != "user" or isinstance(item["content"], str):
65
  continue
 
69
  image_count += 1
70
  return image_count, video_count
71
 
 
72
  def validate_media_constraints(message: dict, history: list[dict]) -> bool:
73
+ new_i, new_v = count_files_in_new_message(message["files"])
74
+ hist_i, hist_v = count_files_in_history(history)
75
+ if hist_v + new_v > 1:
 
 
76
  gr.Warning("Only one video is supported.")
77
  return False
78
+ if hist_v + new_v == 1 and (hist_i + new_i) > 0:
79
+ gr.Warning("Mixing images and videos is not allowed.")
80
+ return False
81
+ if "<image>" in message["text"] and message["text"].count("<image>") != new_i:
82
+ gr.Warning("The number of <image> tags doesn't match the number of images.")
 
 
 
 
83
  return False
84
+ if hist_v + new_v == 0 and (hist_i + new_i) > MAX_NUM_IMAGES:
85
+ gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
86
  return False
87
  return True
88
 
 
89
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
90
  vidcap = cv2.VideoCapture(video_path)
91
  fps = vidcap.get(cv2.CAP_PROP_FPS)
92
+ total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
93
+ interval = max(total // MAX_NUM_IMAGES, 1)
94
+ frames = []
95
+ for i in range(0, min(total, MAX_NUM_IMAGES * interval), interval):
 
 
96
  if len(frames) >= MAX_NUM_IMAGES:
97
  break
 
98
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
99
+ ok, img = vidcap.read()
100
+ if not ok:
101
+ continue
102
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
103
+ pil = Image.fromarray(img)
104
+ frames.append((pil, round(i / fps, 2)))
 
105
  vidcap.release()
106
  return frames
107
 
108
+ def process_video(path: str) -> list[dict]:
109
+ out = []
110
+ for pil, ts in downsample_video(path):
111
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
112
+ pil.save(tmp.name)
113
+ out.append({"type":"text", "text":f"Frame {ts}:"})
114
+ out.append({"type":"image", "url":tmp.name})
115
+ return out
 
 
 
 
 
116
 
117
  def process_interleaved_images(message: dict) -> list[dict]:
 
118
  parts = re.split(r"(<image>)", message["text"])
119
+ out = []
120
+ idx = 0
121
+ for p in parts:
122
+ if p == "<image>":
123
+ out.append({"type":"image","url":message["files"][idx]})
124
+ idx += 1
125
+ elif p.strip():
126
+ out.append({"type":"text","text":p.strip()})
127
+ return out
 
 
 
 
 
 
 
 
128
 
129
  def process_new_user_message(message: dict) -> list[dict]:
130
  if not message["files"]:
131
+ return [{"type":"text","text":message["text"]}]
 
132
  if message["files"][0].endswith(".mp4"):
133
+ return [{"type":"text","text":message["text"]}] + process_video(message["files"][0])
 
134
  if "<image>" in message["text"]:
135
  return process_interleaved_images(message)
136
+ return [{"type":"text","text":message["text"]}] + [{"type":"image","url":f} for f in message["files"]]
 
 
 
 
 
137
 
138
  def process_history(history: list[dict]) -> list[dict]:
139
+ msgs = []
140
+ user_buffer = []
141
  for item in history:
142
  if item["role"] == "assistant":
143
+ if user_buffer:
144
+ msgs.append({"role":"user","content":user_buffer})
145
+ user_buffer = []
146
+ msgs.append({"role":"assistant","content":[{"type":"text","text":item["content"]}]})
147
  else:
148
+ cnt = item["content"]
149
+ if isinstance(cnt, str):
150
+ user_buffer.append({"type":"text","text":cnt})
151
  else:
152
+ user_buffer.append({"type":"image","url":cnt[0]})
153
+ if user_buffer:
154
+ msgs.append({"role":"user","content":user_buffer})
155
+ return msgs
156
+
157
+ # Build a simple ChatML-style prompt
158
+ def build_prompt(messages: list[dict]) -> str:
159
+ prompt = ""
160
+ for msg in messages:
161
+ prompt += f"<|im_start|>{msg['role']}\n"
162
+ for part in msg["content"]:
163
+ if part["type"] == "text":
164
+ prompt += part["text"]
165
+ else: # image placeholder
166
+ prompt += "<image>"
167
+ prompt += "\n"
168
+ prompt += "<|im_end|>\n"
169
+ return prompt
170
 
171
  @spaces.GPU(duration=120)
172
  def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
 
174
  yield ""
175
  return
176
 
177
+ msgs = []
178
  if system_prompt:
179
+ msgs.append({"role":"system","content":[{"type":"text","text":system_prompt}]})
180
+ msgs += process_history(history)
181
+ msgs.append({"role":"user","content":process_new_user_message(message)})
182
+
183
+ # Build text prompt and collect images
184
+ prompt = build_prompt(msgs)
185
+ images = []
186
+ for m in msgs:
187
+ for part in m["content"]:
188
+ if part["type"] == "image":
189
+ images.append(Image.open(part["url"]))
190
+
191
+ # Encode multimodal inputs directly
192
+ inputs = processor(
193
+ text=prompt,
194
+ images=images if images else None,
195
  return_tensors="pt",
196
+ padding=True
197
  ).to(device=model.device, dtype=torch.bfloat16)
198
 
199
+ # Stream generation
200
  streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
201
  generate_kwargs = dict(
202
+ **inputs,
203
  streamer=streamer,
204
  max_new_tokens=max_new_tokens,
205
  )
206
  t = Thread(target=model.generate, kwargs=generate_kwargs)
207
  t.start()
208
 
209
+ out = ""
210
  for delta in streamer:
211
+ out += delta
212
+ yield out
 
213
 
214
  examples = [
215
  [
 
267
  "assets/sample-images/09-2.png",
268
  "assets/sample-images/09-3.png",
269
  "assets/sample-images/09-4.png",
270
+ "assets/sample-images/09-5.png"],
 
271
  }
272
  ],
273
  [
 
302
  ],
303
  [
304
  {
305
+ "text": "Caption this image",
306
  "files": ["assets/sample-images/01.png"],
307
  }
308
  ],
309
  [
310
  {
311
+ "text": "What's the sign say?",
312
  "files": ["assets/sample-images/02.png"],
313
  }
314
  ],
 
359
  )
360
 
361
  if __name__ == "__main__":
362
+ demo.launch(share=True)