akhaliq HF staff commited on
Commit
02652ab
1 Parent(s): 08b9058

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
3
+ import gradio as gr
4
+ # check pytorch installation:
5
+ import torch, torchvision
6
+ print(torch.__version__, torch.cuda.is_available())
7
+ assert torch.__version__.startswith("1.9") # please manually install torch 1.9 if Colab changes its default version
8
+ # Some basic setup:
9
+ # Setup detectron2 logger
10
+ import detectron2
11
+ from detectron2.utils.logger import setup_logger
12
+
13
+ # import some common libraries
14
+ import numpy as np
15
+ import os, json, cv2, random
16
+
17
+ # import some common detectron2 utilities
18
+ from detectron2 import model_zoo
19
+ from detectron2.engine import DefaultPredictor
20
+ from detectron2.config import get_cfg
21
+ from detectron2.utils.visualizer import Visualizer
22
+ from detectron2.data import MetadataCatalog, DatasetCatalog
23
+ from PIL import Image
24
+
25
+ cfg = get_cfg()
26
+ cfg.MODEL.DEVICE='cpu'
27
+ # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
28
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
29
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
30
+ # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
31
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
32
+ predictor = DefaultPredictor(cfg)
33
+ def inference(img):
34
+ im = cv2.imread(img.name)
35
+ outputs = predictor(im)
36
+ v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
37
+ out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
38
+ return Image.fromarray(np.uint8(out.get_image())).convert('RGB')
39
+
40
+
41
+ title = "Detectron 2"
42
+ description = "Gradio demo for Detectron 2: A PyTorch-based modular object detection library. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
43
+ article = "<p style='text-align: center'><a href='https://ai.facebook.com/blog/-detectron2-a-pytorch-based-modular-object-detection-library-/' target='_blank'>Detectron2: A PyTorch-based modular object detection library</a> | <a href='https://github.com/facebookresearch/detectron2' target='_blank'>Github Repo</a></p>"
44
+
45
+ examples = [['example.png']]
46
+
47
+ gr.Interface(inference, inputs=gr.inputs.Image(type="file"), outputs=gr.outputs.Image(type="pil"),enable_queue=True, title=title,
48
+ description=description,
49
+ article=article,
50
+ examples=examples).launch()