VisionLanguageGroup committed on
Commit
c799412
Β·
1 Parent(s): 68636bc

Add ZeroGPU support with @spaces.GPU decorators

Browse files
Files changed (3) hide show
  1. README.md +1 -0
  2. app.py +10 -7
  3. requirements.txt +1 -0
README.md CHANGED
@@ -7,6 +7,7 @@ sdk: gradio
7
  sdk_version: 5.49.1
8
  app_file: app.py
9
  python_version: 3.11
 
10
  pinned: false
11
  ---
12
 
 
7
  sdk_version: 5.49.1
8
  app_file: app.py
9
  python_version: 3.11
10
+ hardware: zero-gpu
11
  pinned: false
12
  ---
13
 
app.py CHANGED
@@ -16,7 +16,7 @@ from matplotlib import cm
16
  from glob import glob
17
  from natsort import natsorted
18
  from huggingface_hub import HfApi, upload_file
19
- # import spaces
20
 
21
  from inference_seg import load_model as load_seg_model, run as run_seg
22
  from inference_count import load_model as load_count_model, run as run_count
@@ -293,7 +293,7 @@ def cleanup_tracking_cache(track_vis_cache):
293
  pass
294
 
295
 
296
- # @spaces.GPU
297
  def segment_with_choice(use_box_choice, annot_value, overlay_alpha):
298
  """Segmentation handler - supports bounding box, returns colorized overlay and original mask path"""
299
  if annot_value is None or len(annot_value) < 1:
@@ -313,7 +313,8 @@ def segment_with_choice(use_box_choice, annot_value, overlay_alpha):
313
 
314
 
315
  try:
316
- mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
 
317
  print("πŸ“ mask shape:", mask.shape, "dtype:", mask.dtype)
318
  except Exception as e:
319
  print(f"❌ Inference failed: {str(e)}")
@@ -353,7 +354,7 @@ def segment_with_choice(use_box_choice, annot_value, overlay_alpha):
353
  return overlay_img, temp_mask_file.name, seg_vis_cache
354
 
355
 
356
- # @spaces.GPU
357
  def count_cells_handler(use_box_choice, annot_value, overlay_alpha):
358
  """Counting handler - supports bounding box, returns only density map"""
359
  if annot_value is None or len(annot_value) < 1:
@@ -373,11 +374,12 @@ def count_cells_handler(use_box_choice, annot_value, overlay_alpha):
373
  try:
374
  print(f"πŸ”’ Counting - Image: {image_path}")
375
 
 
376
  result = run_count(
377
  COUNT_MODEL,
378
  image_path,
379
  box=box_array,
380
- device=COUNT_DEVICE,
381
  visualize=True
382
  )
383
 
@@ -722,7 +724,7 @@ def create_tracking_visualization(tif_dir, output_dir, valid_tif_files, overlay_
722
  except:
723
  return valid_tif_files[0]
724
 
725
- # @spaces.GPU
726
  def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj, overlay_alpha, prev_track_vis_cache):
727
  """
728
  Tracking handler - processes a ZIP of TIF frames, supports bounding box, returns visualization and results ZIP
@@ -808,11 +810,12 @@ def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj, overlay
808
  print(f"πŸ’Ύ CTC-format results will be saved to: {output_temp_dir}")
809
 
810
  # Run tracking with optional bounding box
 
811
  result = run_track(
812
  TRACK_MODEL,
813
  video_dir=tif_dir,
814
  box=box_array, # Pass bounding box if specified
815
- device=TRACK_DEVICE,
816
  output_dir=output_temp_dir
817
  )
818
 
 
16
  from glob import glob
17
  from natsort import natsorted
18
  from huggingface_hub import HfApi, upload_file
19
+ import spaces
20
 
21
  from inference_seg import load_model as load_seg_model, run as run_seg
22
  from inference_count import load_model as load_count_model, run as run_count
 
293
  pass
294
 
295
 
296
+ @spaces.GPU
297
  def segment_with_choice(use_box_choice, annot_value, overlay_alpha):
298
  """Segmentation handler - supports bounding box, returns colorized overlay and original mask path"""
299
  if annot_value is None or len(annot_value) < 1:
 
313
 
314
 
315
  try:
316
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
317
+ mask = run_seg(SEG_MODEL, img_path, box=box_array, device=device)
318
  print("πŸ“ mask shape:", mask.shape, "dtype:", mask.dtype)
319
  except Exception as e:
320
  print(f"❌ Inference failed: {str(e)}")
 
354
  return overlay_img, temp_mask_file.name, seg_vis_cache
355
 
356
 
357
+ @spaces.GPU
358
  def count_cells_handler(use_box_choice, annot_value, overlay_alpha):
359
  """Counting handler - supports bounding box, returns only density map"""
360
  if annot_value is None or len(annot_value) < 1:
 
374
  try:
375
  print(f"πŸ”’ Counting - Image: {image_path}")
376
 
377
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
378
  result = run_count(
379
  COUNT_MODEL,
380
  image_path,
381
  box=box_array,
382
+ device=device,
383
  visualize=True
384
  )
385
 
 
724
  except:
725
  return valid_tif_files[0]
726
 
727
+ @spaces.GPU
728
  def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj, overlay_alpha, prev_track_vis_cache):
729
  """
730
  Tracking handler - processes a ZIP of TIF frames, supports bounding box, returns visualization and results ZIP
 
810
  print(f"πŸ’Ύ CTC-format results will be saved to: {output_temp_dir}")
811
 
812
  # Run tracking with optional bounding box
813
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
814
  result = run_track(
815
  TRACK_MODEL,
816
  video_dir=tif_dir,
817
  box=box_array, # Pass bounding box if specified
818
+ device=device,
819
  output_dir=output_temp_dir
820
  )
821
 
requirements.txt CHANGED
@@ -32,6 +32,7 @@ numpy==1.24.4
32
  # Gradio
33
  gradio
34
  gradio-bbox-annotator
 
35
 
36
  # Utilities
37
  natsort
 
32
  # Gradio
33
  gradio
34
  gradio-bbox-annotator
35
+ spaces
36
 
37
  # Utilities
38
  natsort