Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -13,34 +13,44 @@ import torch
|
|
13 |
from loguru import logger
|
14 |
from PIL import Image
|
15 |
from peft import PeftModel
|
16 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
# Load processor (tokenizer + feature extractor)
|
21 |
processor = AutoProcessor.from_pretrained(
|
22 |
-
|
23 |
padding_side="left"
|
24 |
)
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
ignore_mismatched_sizes=True
|
31 |
-
)
|
32 |
|
|
|
33 |
model.eval()
|
34 |
|
35 |
-
|
36 |
-
# ########################################
|
37 |
-
|
38 |
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
|
39 |
|
40 |
-
|
41 |
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
|
42 |
-
image_count = 0
|
43 |
-
video_count = 0
|
44 |
for path in paths:
|
45 |
if path.endswith(".mp4"):
|
46 |
video_count += 1
|
@@ -48,10 +58,8 @@ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
|
|
48 |
image_count += 1
|
49 |
return image_count, video_count
|
50 |
|
51 |
-
|
52 |
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
|
53 |
-
image_count = 0
|
54 |
-
video_count = 0
|
55 |
for item in history:
|
56 |
if item["role"] != "user" or isinstance(item["content"], str):
|
57 |
continue
|
@@ -61,122 +69,104 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
|
|
61 |
image_count += 1
|
62 |
return image_count, video_count
|
63 |
|
64 |
-
|
65 |
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
video_count = history_video_count + new_video_count
|
70 |
-
if video_count > 1:
|
71 |
gr.Warning("Only one video is supported.")
|
72 |
return False
|
73 |
-
if
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
gr.Warning("Using <image> tags with video files is not supported.")
|
79 |
-
return False
|
80 |
-
if video_count == 0 and image_count > MAX_NUM_IMAGES:
|
81 |
-
gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
|
82 |
return False
|
83 |
-
if
|
84 |
-
gr.Warning("
|
85 |
return False
|
86 |
return True
|
87 |
|
88 |
-
|
89 |
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
|
90 |
vidcap = cv2.VideoCapture(video_path)
|
91 |
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
for i in range(0, min(total_frames, MAX_NUM_IMAGES * frame_interval), frame_interval):
|
98 |
if len(frames) >= MAX_NUM_IMAGES:
|
99 |
break
|
100 |
-
|
101 |
vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
|
102 |
-
|
103 |
-
if
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
vidcap.release()
|
110 |
return frames
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
content.append({"type": "text", "text": f"Frame {timestamp}:"})
|
121 |
-
content.append({"type": "image", "url": temp_file.name})
|
122 |
-
logger.debug(f"{content=}")
|
123 |
-
return content
|
124 |
-
|
125 |
|
126 |
def process_interleaved_images(message: dict) -> list[dict]:
|
127 |
-
logger.debug(f"{message['files']=}")
|
128 |
parts = re.split(r"(<image>)", message["text"])
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
image_index += 1
|
139 |
-
elif part.strip():
|
140 |
-
content.append({"type": "text", "text": part.strip()})
|
141 |
-
elif isinstance(part, str) and part != "<image>":
|
142 |
-
content.append({"type": "text", "text": part})
|
143 |
-
logger.debug(f"{content=}")
|
144 |
-
return content
|
145 |
-
|
146 |
|
147 |
def process_new_user_message(message: dict) -> list[dict]:
|
148 |
if not message["files"]:
|
149 |
-
return [{"type":
|
150 |
-
|
151 |
if message["files"][0].endswith(".mp4"):
|
152 |
-
return [{"type":
|
153 |
-
|
154 |
if "<image>" in message["text"]:
|
155 |
return process_interleaved_images(message)
|
156 |
-
|
157 |
-
return [
|
158 |
-
{"type": "text", "text": message["text"]},
|
159 |
-
*[{"type": "image", "url": path} for path in message["files"]],
|
160 |
-
]
|
161 |
-
|
162 |
|
163 |
def process_history(history: list[dict]) -> list[dict]:
|
164 |
-
|
165 |
-
|
166 |
for item in history:
|
167 |
if item["role"] == "assistant":
|
168 |
-
if
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
else:
|
173 |
-
|
174 |
-
if isinstance(
|
175 |
-
|
176 |
else:
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
@spaces.GPU(duration=120)
|
182 |
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
|
@@ -184,34 +174,42 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
|
|
184 |
yield ""
|
185 |
return
|
186 |
|
187 |
-
|
188 |
if system_prompt:
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
return_tensors="pt",
|
|
|
199 |
).to(device=model.device, dtype=torch.bfloat16)
|
200 |
|
|
|
201 |
streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
|
202 |
generate_kwargs = dict(
|
203 |
-
inputs,
|
204 |
streamer=streamer,
|
205 |
max_new_tokens=max_new_tokens,
|
206 |
)
|
207 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
208 |
t.start()
|
209 |
|
210 |
-
|
211 |
for delta in streamer:
|
212 |
-
|
213 |
-
yield
|
214 |
-
|
215 |
|
216 |
examples = [
|
217 |
[
|
@@ -269,8 +267,7 @@ examples = [
|
|
269 |
"assets/sample-images/09-2.png",
|
270 |
"assets/sample-images/09-3.png",
|
271 |
"assets/sample-images/09-4.png",
|
272 |
-
"assets/sample-images/09-5.png",
|
273 |
-
],
|
274 |
}
|
275 |
],
|
276 |
[
|
@@ -305,13 +302,13 @@ examples = [
|
|
305 |
],
|
306 |
[
|
307 |
{
|
308 |
-
"text": "
|
309 |
"files": ["assets/sample-images/01.png"],
|
310 |
}
|
311 |
],
|
312 |
[
|
313 |
{
|
314 |
-
"text": "What's the sign
|
315 |
"files": ["assets/sample-images/02.png"],
|
316 |
}
|
317 |
],
|
@@ -362,4 +359,4 @@ demo = gr.ChatInterface(
|
|
362 |
)
|
363 |
|
364 |
if __name__ == "__main__":
|
365 |
-
demo.launch()
|
|
|
13 |
from loguru import logger
|
14 |
from PIL import Image
|
15 |
from peft import PeftModel
|
16 |
+
from transformers import (
|
17 |
+
AutoProcessor,
|
18 |
+
Gemma3ForConditionalGeneration,
|
19 |
+
TextIteratorStreamer,
|
20 |
+
)
|
21 |
+
|
22 |
+
# Set model and adapter IDs
|
23 |
+
model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-pt")
|
24 |
+
adapter_id = os.getenv("ADAPTER_ID", "slavamarcin/HG_Gemma-3-12B-4bit-QLora_purpose")
|
25 |
|
26 |
+
# Load Gemma base model and move to GPU, using bfloat16
|
27 |
+
model = Gemma3ForConditionalGeneration.from_pretrained(
|
28 |
+
model_id,
|
29 |
+
torch_dtype=torch.bfloat16,
|
30 |
+
device_map="auto",
|
31 |
+
attn_implementation="eager"
|
32 |
+
).to("cuda")
|
33 |
|
34 |
# Load processor (tokenizer + feature extractor)
|
35 |
processor = AutoProcessor.from_pretrained(
|
36 |
+
model_id,
|
37 |
padding_side="left"
|
38 |
)
|
39 |
|
40 |
+
# Wrap with PEFT adapter and move to GPU
|
41 |
+
#model = PeftModel.from_pretrained(
|
42 |
+
# model,
|
43 |
+
# adapter_id,
|
44 |
+
# ignore_mismatched_sizes=True
|
45 |
+
#).to("cuda")
|
46 |
|
47 |
+
# Switch to evaluation mode
|
48 |
model.eval()
|
49 |
|
|
|
|
|
|
|
50 |
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
|
51 |
|
|
|
52 |
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
|
53 |
+
image_count = video_count = 0
|
|
|
54 |
for path in paths:
|
55 |
if path.endswith(".mp4"):
|
56 |
video_count += 1
|
|
|
58 |
image_count += 1
|
59 |
return image_count, video_count
|
60 |
|
|
|
61 |
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
|
62 |
+
image_count = video_count = 0
|
|
|
63 |
for item in history:
|
64 |
if item["role"] != "user" or isinstance(item["content"], str):
|
65 |
continue
|
|
|
69 |
image_count += 1
|
70 |
return image_count, video_count
|
71 |
|
|
|
72 |
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
|
73 |
+
new_i, new_v = count_files_in_new_message(message["files"])
|
74 |
+
hist_i, hist_v = count_files_in_history(history)
|
75 |
+
if hist_v + new_v > 1:
|
|
|
|
|
76 |
gr.Warning("Only one video is supported.")
|
77 |
return False
|
78 |
+
if hist_v + new_v == 1 and (hist_i + new_i) > 0:
|
79 |
+
gr.Warning("Mixing images and videos is not allowed.")
|
80 |
+
return False
|
81 |
+
if "<image>" in message["text"] and message["text"].count("<image>") != new_i:
|
82 |
+
gr.Warning("The number of <image> tags doesn't match the number of images.")
|
|
|
|
|
|
|
|
|
83 |
return False
|
84 |
+
if hist_v + new_v == 0 and (hist_i + new_i) > MAX_NUM_IMAGES:
|
85 |
+
gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
|
86 |
return False
|
87 |
return True
|
88 |
|
|
|
89 |
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
|
90 |
vidcap = cv2.VideoCapture(video_path)
|
91 |
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
92 |
+
total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
93 |
+
interval = max(total // MAX_NUM_IMAGES, 1)
|
94 |
+
frames = []
|
95 |
+
for i in range(0, min(total, MAX_NUM_IMAGES * interval), interval):
|
|
|
|
|
96 |
if len(frames) >= MAX_NUM_IMAGES:
|
97 |
break
|
|
|
98 |
vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
|
99 |
+
ok, img = vidcap.read()
|
100 |
+
if not ok:
|
101 |
+
continue
|
102 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
103 |
+
pil = Image.fromarray(img)
|
104 |
+
frames.append((pil, round(i / fps, 2)))
|
|
|
105 |
vidcap.release()
|
106 |
return frames
|
107 |
|
108 |
+
def process_video(path: str) -> list[dict]:
|
109 |
+
out = []
|
110 |
+
for pil, ts in downsample_video(path):
|
111 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
|
112 |
+
pil.save(tmp.name)
|
113 |
+
out.append({"type":"text", "text":f"Frame {ts}:"})
|
114 |
+
out.append({"type":"image", "url":tmp.name})
|
115 |
+
return out
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
def process_interleaved_images(message: dict) -> list[dict]:
|
|
|
118 |
parts = re.split(r"(<image>)", message["text"])
|
119 |
+
out = []
|
120 |
+
idx = 0
|
121 |
+
for p in parts:
|
122 |
+
if p == "<image>":
|
123 |
+
out.append({"type":"image","url":message["files"][idx]})
|
124 |
+
idx += 1
|
125 |
+
elif p.strip():
|
126 |
+
out.append({"type":"text","text":p.strip()})
|
127 |
+
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
def process_new_user_message(message: dict) -> list[dict]:
|
130 |
if not message["files"]:
|
131 |
+
return [{"type":"text","text":message["text"]}]
|
|
|
132 |
if message["files"][0].endswith(".mp4"):
|
133 |
+
return [{"type":"text","text":message["text"]}] + process_video(message["files"][0])
|
|
|
134 |
if "<image>" in message["text"]:
|
135 |
return process_interleaved_images(message)
|
136 |
+
return [{"type":"text","text":message["text"]}] + [{"type":"image","url":f} for f in message["files"]]
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
def process_history(history: list[dict]) -> list[dict]:
|
139 |
+
msgs = []
|
140 |
+
user_buffer = []
|
141 |
for item in history:
|
142 |
if item["role"] == "assistant":
|
143 |
+
if user_buffer:
|
144 |
+
msgs.append({"role":"user","content":user_buffer})
|
145 |
+
user_buffer = []
|
146 |
+
msgs.append({"role":"assistant","content":[{"type":"text","text":item["content"]}]})
|
147 |
else:
|
148 |
+
cnt = item["content"]
|
149 |
+
if isinstance(cnt, str):
|
150 |
+
user_buffer.append({"type":"text","text":cnt})
|
151 |
else:
|
152 |
+
user_buffer.append({"type":"image","url":cnt[0]})
|
153 |
+
if user_buffer:
|
154 |
+
msgs.append({"role":"user","content":user_buffer})
|
155 |
+
return msgs
|
156 |
+
|
157 |
+
# Build a simple ChatML-style prompt
|
158 |
+
def build_prompt(messages: list[dict]) -> str:
|
159 |
+
prompt = ""
|
160 |
+
for msg in messages:
|
161 |
+
prompt += f"<|im_start|>{msg['role']}\n"
|
162 |
+
for part in msg["content"]:
|
163 |
+
if part["type"] == "text":
|
164 |
+
prompt += part["text"]
|
165 |
+
else: # image placeholder
|
166 |
+
prompt += "<image>"
|
167 |
+
prompt += "\n"
|
168 |
+
prompt += "<|im_end|>\n"
|
169 |
+
return prompt
|
170 |
|
171 |
@spaces.GPU(duration=120)
|
172 |
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
|
|
|
174 |
yield ""
|
175 |
return
|
176 |
|
177 |
+
msgs = []
|
178 |
if system_prompt:
|
179 |
+
msgs.append({"role":"system","content":[{"type":"text","text":system_prompt}]})
|
180 |
+
msgs += process_history(history)
|
181 |
+
msgs.append({"role":"user","content":process_new_user_message(message)})
|
182 |
+
|
183 |
+
# Build text prompt and collect images
|
184 |
+
prompt = build_prompt(msgs)
|
185 |
+
images = []
|
186 |
+
for m in msgs:
|
187 |
+
for part in m["content"]:
|
188 |
+
if part["type"] == "image":
|
189 |
+
images.append(Image.open(part["url"]))
|
190 |
+
|
191 |
+
# Encode multimodal inputs directly
|
192 |
+
inputs = processor(
|
193 |
+
text=prompt,
|
194 |
+
images=images if images else None,
|
195 |
return_tensors="pt",
|
196 |
+
padding=True
|
197 |
).to(device=model.device, dtype=torch.bfloat16)
|
198 |
|
199 |
+
# Stream generation
|
200 |
streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
|
201 |
generate_kwargs = dict(
|
202 |
+
**inputs,
|
203 |
streamer=streamer,
|
204 |
max_new_tokens=max_new_tokens,
|
205 |
)
|
206 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
207 |
t.start()
|
208 |
|
209 |
+
out = ""
|
210 |
for delta in streamer:
|
211 |
+
out += delta
|
212 |
+
yield out
|
|
|
213 |
|
214 |
examples = [
|
215 |
[
|
|
|
267 |
"assets/sample-images/09-2.png",
|
268 |
"assets/sample-images/09-3.png",
|
269 |
"assets/sample-images/09-4.png",
|
270 |
+
"assets/sample-images/09-5.png"],
|
|
|
271 |
}
|
272 |
],
|
273 |
[
|
|
|
302 |
],
|
303 |
[
|
304 |
{
|
305 |
+
"text": "Caption this image",
|
306 |
"files": ["assets/sample-images/01.png"],
|
307 |
}
|
308 |
],
|
309 |
[
|
310 |
{
|
311 |
+
"text": "What's the sign say?",
|
312 |
"files": ["assets/sample-images/02.png"],
|
313 |
}
|
314 |
],
|
|
|
359 |
)
|
360 |
|
361 |
if __name__ == "__main__":
|
362 |
+
demo.launch(share=True)
|