shivi commited on
Commit
0ddd61a
1 Parent(s): a50c223

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +2 -2
predict.py CHANGED
@@ -10,13 +10,13 @@ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
10
  from detectron2.data import MetadataCatalog
11
  from detectron2.utils.visualizer import ColorMode, Visualizer
12
  from color_palette import ade_palette
13
- from transformers import MaskFormerImageProcessor, Mask2FormerForUniversalSegmentation
14
 
15
  def load_model_and_processor(model_ckpt: str):
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
18
  model.eval()
19
- image_preprocessor = MaskFormerImageProcessor.from_pretrained(model_ckpt)
20
  return model, image_preprocessor
21
 
22
  def load_default_ckpt(segmentation_task: str):
 
10
  from detectron2.data import MetadataCatalog
11
  from detectron2.utils.visualizer import ColorMode, Visualizer
12
  from color_palette import ade_palette
13
+ from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
14
 
15
  def load_model_and_processor(model_ckpt: str):
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
18
  model.eval()
19
+ image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
20
  return model, image_preprocessor
21
 
22
  def load_default_ckpt(segmentation_task: str):