afmck commited on
Commit
c928dee
1 Parent(s): 8c7f9e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -26,6 +26,8 @@ try_cuda = not args.disable_cuda
26
  torch.inference_mode()
27
  torch.no_grad()
28
 
 
 
29
  # Load segmentation models
30
  def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
31
  feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
@@ -67,7 +69,6 @@ def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
67
  feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
68
  pipe = load_diffusion_pipeline()
69
 
70
- device = get_device(try_cuda=try_cuda)
71
  segmentation_model = segmentation_model.to(device)
72
  pipe = pipe.to(device)
73
  if args.attention_slicing:
 
26
  torch.inference_mode()
27
  torch.no_grad()
28
 
29
+ device = get_device(try_cuda=try_cuda)
30
+
31
  # Load segmentation models
32
  def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
33
  feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
 
69
  feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
70
  pipe = load_diffusion_pipeline()
71
 
 
72
  segmentation_model = segmentation_model.to(device)
73
  pipe = pipe.to(device)
74
  if args.attention_slicing: