lorebianchi98 committed on
Commit
f55d166
·
1 Parent(s): 6c1c801

Added compatibility with ZeroGPU

Browse files
Files changed (2) hide show
  1. app.py +12 -7
  2. requirements.txt +2 -1
app.py CHANGED
@@ -9,6 +9,8 @@ from PIL import Image
9
  from io import BytesIO
10
  import base64
11
  from pathlib import Path
 
 
12
 
13
  # --- Setup ---
14
  os.environ["GRADIO_TEMP_DIR"] = "tmp"
@@ -16,11 +18,12 @@ os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
16
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
- # --- Load Models ---
20
- model_B = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTB", trust_remote_code=True).to(device).eval()
21
- model_L = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTL", trust_remote_code=True).to(device).eval()
22
  MODELS = {"ViT-B": model_B, "ViT-L": model_L}
23
 
 
24
  # --- Example Setup ---
25
  EXAMPLE_IMAGES_DIR = Path("examples").resolve()
26
  example_images = sorted([str(p) for p in EXAMPLE_IMAGES_DIR.glob("*.png")])
@@ -36,17 +39,18 @@ DEFAULT_BG_CLEAN = False
36
 
37
 
38
  # --- Inference Function ---
 
39
  def talk2dino_infer(input_image, class_text, selected_model="ViT-B",
40
  apply_pamr=True, with_background=False, bg_thresh=0.55, apply_bg_clean=False):
41
  if input_image is None:
42
  raise gr.Error("No image detected. Please select or upload an image first.")
43
 
44
- model = MODELS[selected_model]
45
  text = [t.strip() for t in class_text.replace("_", " ").split(",") if t.strip()]
46
  if len(text) == 0:
47
  raise gr.Error("Please provide at least one class name before generating segmentation.")
48
 
49
- img = F.to_tensor(input_image).unsqueeze(0).float().to(device) * 255.0
50
 
51
  # Generate color palette
52
  palette = [
@@ -84,6 +88,8 @@ def talk2dino_infer(input_image, class_text, selected_model="ViT-B",
84
  palette,
85
  texts=text
86
  )
 
 
87
  return img_out
88
 
89
 
@@ -100,7 +106,6 @@ with gr.Blocks(title="Talk2DINO Demo") as demo:
100
  gr.Markdown(f"""
101
  # 🦖 Talk2DINO Demo
102
 
103
-
104
  ![Overview](data:image/png;base64,{img_str})
105
 
106
  <div style="font-size: x-large; white-space: nowrap; display: flex; align-items: center; gap: 10px;">
@@ -234,4 +239,4 @@ with gr.Blocks(title="Talk2DINO Demo") as demo:
234
  outputs=output_image
235
  )
236
 
237
- demo.launch(share=True)
 
9
  from io import BytesIO
10
  import base64
11
  from pathlib import Path
12
+ import spaces # 👈 REQUIRED for ZeroGPU
13
+
14
 
15
  # --- Setup ---
16
  os.environ["GRADIO_TEMP_DIR"] = "tmp"
 
18
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ # --- Load Models (on CPU first; ZeroGPU will move to CUDA dynamically) ---
22
+ model_B = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTB", trust_remote_code=True).to("cpu").eval()
23
+ model_L = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTL", trust_remote_code=True).to("cpu").eval()
24
  MODELS = {"ViT-B": model_B, "ViT-L": model_L}
25
 
26
+
27
  # --- Example Setup ---
28
  EXAMPLE_IMAGES_DIR = Path("examples").resolve()
29
  example_images = sorted([str(p) for p in EXAMPLE_IMAGES_DIR.glob("*.png")])
 
39
 
40
 
41
  # --- Inference Function ---
42
+ @spaces.GPU(duration=120) # 👈 Allocates GPU dynamically for this call
43
  def talk2dino_infer(input_image, class_text, selected_model="ViT-B",
44
  apply_pamr=True, with_background=False, bg_thresh=0.55, apply_bg_clean=False):
45
  if input_image is None:
46
  raise gr.Error("No image detected. Please select or upload an image first.")
47
 
48
+ model = MODELS[selected_model].to("cuda") # 👈 Move to GPU here
49
  text = [t.strip() for t in class_text.replace("_", " ").split(",") if t.strip()]
50
  if len(text) == 0:
51
  raise gr.Error("Please provide at least one class name before generating segmentation.")
52
 
53
+ img = F.to_tensor(input_image).unsqueeze(0).float().to("cuda") * 255.0
54
 
55
  # Generate color palette
56
  palette = [
 
88
  palette,
89
  texts=text
90
  )
91
+
92
+ torch.cuda.empty_cache() # 👈 Important for ZeroGPU memory cleanup
93
  return img_out
94
 
95
 
 
106
  gr.Markdown(f"""
107
  # 🦖 Talk2DINO Demo
108
 
 
109
  ![Overview](data:image/png;base64,{img_str})
110
 
111
  <div style="font-size: x-large; white-space: nowrap; display: flex; align-items: center; gap: 10px;">
 
239
  outputs=output_image
240
  )
241
 
242
+ demo.launch()
requirements.txt CHANGED
@@ -17,4 +17,5 @@ scikit-learn
17
  safetensors==0.4.3
18
  gradio
19
  torch
20
- torchvision
 
 
17
  safetensors==0.4.3
18
  gradio
19
  torch
20
+ torchvision
21
+ spaces