cdnuts committed
Commit 08a2d0b
1 Parent(s): 683b0c2

Update app.py

Files changed (1)
  1. app.py +168 -27
app.py CHANGED
@@ -1,6 +1,12 @@
 import json
+import os
+import zipfile
+from io import BytesIO
+from tempfile import NamedTemporaryFile
+import tempfile

 import gradio as gr
+import pandas as pd
 from PIL import Image
 import safetensors.torch
 import spaces
@@ -10,9 +16,53 @@ import torch
 from torchvision.transforms import transforms
 from torchvision.transforms import InterpolationMode
 import torchvision.transforms.functional as TF
+from torch.utils.data import Dataset, DataLoader
+from typing import Callable
+from functools import partial
+import spaces.config
+from spaces.zero.decorator import P, R

 torch.set_grad_enabled(False)

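+# ZeroGPU time is granted in fixed-length slots, so _dynGPU pre-registers one
+# spaces.GPU-decorated variant of fn per duration step and, at call time,
+# dispatches to the smallest variant that covers the estimated runtime.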
+def _dynGPU(
+    fn: Callable[P, R] | None, duration: Callable[P, int], min=30, max=300, step=10
+) -> Callable[P, R]:
+    if not spaces.config.Config.zero_gpu:
+        return fn
+
+    funcs = [
+        (t, spaces.GPU(duration=t)(lambda *args, **kwargs: fn(*args, **kwargs)))
+        for t in range(min, max + 1, step)
+    ]
+
+    def wrapper(*args, **kwargs):
+        requirement = duration(*args, **kwargs)
+
+        # find the function that satisfies the duration requirement
+        for t, func in funcs:
+            if t >= requirement:
+                gr.Info(f"Acquiring ZeroGPU for {t} seconds")
+                return func(*args, **kwargs)
+
+        # if no function is found, return the last one
+        gr.Info(f"Acquiring ZeroGPU for {funcs[-1][0]} seconds")
+        return funcs[-1][1](*args, **kwargs)
+
+    return wrapper
+
+
+def dynGPU(
+    fn: Callable[P, R] | None = None,
+    duration: Callable[P, int] = lambda: 60,
+    min=30,
+    max=300,
+    step=10,
+) -> Callable[P, R]:
+    if fn is None:
+        return partial(_dynGPU, duration=duration, min=min, max=max, step=step)
+    return _dynGPU(fn, duration, min, max, step)
+
+
 class Fit(torch.nn.Module):
     def __init__(
         self,
@@ -138,6 +188,8 @@ class GatedHead(torch.nn.Module):
 model.head = GatedHead(min(model.head.weight.shape), 9083)

 safetensors.torch.load_model(model, "JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors")
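+# keep the model on the accelerator when one is present; input tensors are
+# moved to the same device before inference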
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
 model.eval()

 with open("tagger_tags.json", "r") as file:
@@ -149,11 +201,11 @@ for idx, tag in enumerate(allowed_tags):

 sorted_tag_score = {}

-@spaces.GPU(duration=5)
+@spaces.GPU(duration=6)
 def run_classifier(image, threshold):
     global sorted_tag_score
     img = image.convert('RGBA')
-    tensor = transform(img).unsqueeze(0)
+    tensor = transform(img).unsqueeze(0).to(device)

     with torch.no_grad():
         probits = model(tensor)[0]
@@ -177,6 +229,83 @@ def clear_image():
     sorted_tag_score = {}
     return "", {}

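+# map-style dataset over the extracted files: yields (transformed tensor, basename)
+# pairs so each batch row can be matched back to its source image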
+class ImageDataset(Dataset):
+    def __init__(self, image_files, transform):
+        self.image_files = image_files
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.image_files)
+
+    def __getitem__(self, idx):
+        img_path = self.image_files[idx]
+        img = Image.open(img_path).convert('RGB')
+        return self.transform(img), os.path.basename(img_path)
+
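+# rough ZeroGPU budget for a batch job: about 9 s per batch of 64 images,
+# plus 3 s of setup headroom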
+from math import ceil
+
+def measure_duration(images, threshold) -> int:
+    return ceil(len(images) / 64) * 9 + 3
+
+@dynGPU(duration=measure_duration)
+def process_images(images, threshold):
+    dataset = ImageDataset(images, transform)
+
+    dataloader = DataLoader(dataset, batch_size=64, num_workers=0, pin_memory=True, drop_last=False)
+
+    all_results = []
+
+    with torch.no_grad():
+        for batch, filenames in dataloader:
+            batch = batch.to(device)
+            with torch.no_grad():
+                logits = model(batch)
+                probabilities = torch.nn.functional.sigmoid(logits)
+
+            for i, prob in enumerate(probabilities):
+                indices = torch.where(prob > threshold)[0]
+                values = prob[indices]
+
+                temp = []
+                tag_score = dict()
+                for j in range(indices.size(0)):
+                    temp.append([allowed_tags[indices[j]], values[j].item()])
+                    tag_score[allowed_tags[indices[j]]] = values[j].item()
+
+                tags = ", ".join([t[0] for t in temp])
+                all_results.append((filenames[i], tags, tag_score))
+
+    return all_results
+
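+# cheap integrity check: Image.verify() parses the file headers without fully
+# decoding the pixel data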
+def is_valid_image(file_path):
+    try:
+        with Image.open(file_path) as img:
+            img.verify()
+        return True
+    except Exception:
+        return False
+
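+# batch pipeline: unzip the upload, keep only valid images, tag them in batches,
+# and return a zip of per-image .txt tag files plus a summary dataframe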
+def process_zip(zip_file, threshold):
+    if zip_file is None:
+        return None, None
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
+            zip_ref.extractall(temp_dir)
+
+        all_files = [os.path.join(temp_dir, f) for f in os.listdir(temp_dir)]
+        image_files = [f for f in all_files if is_valid_image(f)]
+        results = process_images(image_files, threshold)
+
+        temp_file = NamedTemporaryFile(delete=False, suffix=".zip")
+        with zipfile.ZipFile(temp_file, "w") as zip_ref:
+            for image_name, text_no_impl, _ in results:
+                with zip_ref.open(os.path.splitext(image_name)[0] + ".txt", 'w') as file:
+                    file.write(text_no_impl.encode())
+        temp_file.seek(0)
+        df = pd.DataFrame([(os.path.basename(f), t) for f, t, _ in results], columns=['Image', 'Tags'])
+
+        return temp_file.name, df
+
 with gr.Blocks(css=".output-class { display: none; }") as demo:
     gr.Markdown("""
     ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
@@ -186,31 +315,43 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:

     Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
     """)
-    with gr.Row():
-        with gr.Column():
-            image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
-            threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
-        with gr.Column():
-            tag_string = gr.Textbox(label="Tag String")
-            label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
-
-    image_input.upload(
-        fn=run_classifier,
-        inputs=[image_input, threshold_slider],
-        outputs=[tag_string, label_box]
-    )
-
-    image_input.clear(
-        fn=clear_image,
-        inputs=[],
-        outputs=[tag_string, label_box]
-    )
-
-    threshold_slider.input(
-        fn=create_tags,
-        inputs=[threshold_slider],
-        outputs=[tag_string, label_box]
-    )
+    with gr.Tabs():
+        with gr.TabItem("Single Image"):
+            with gr.Row():
+                with gr.Column():
+                    image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
+                    threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
+                with gr.Column():
+                    tag_string = gr.Textbox(label="Tag String")
+                    label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
+
+            image_input.upload(
+                fn=run_classifier,
+                inputs=[image_input, threshold_slider],
+                outputs=[tag_string, label_box]
+            )
+
+            threshold_slider.input(
+                fn=create_tags,
+                inputs=[threshold_slider],
+                outputs=[tag_string, label_box]
+            )
+
+        with gr.TabItem("Multiple Images"):
+            with gr.Row():
+                with gr.Column():
+                    zip_input = gr.File(label="Upload ZIP file", file_types=['.zip'])
+                    multi_threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
+                    process_button = gr.Button("Process Images")
+                with gr.Column():
+                    zip_output = gr.File(label="Download Tagged Text Files (ZIP)")
+                    dataframe_output = gr.Dataframe(label="Image Tags Summary")
+
+            process_button.click(
+                fn=process_zip,
+                inputs=[zip_input, multi_threshold_slider],
+                outputs=[zip_output, dataframe_output]
+            )

 if __name__ == "__main__":
     demo.launch()