capjamesg commited on
Commit
7b87048
·
1 Parent(s): 97dbcd8

fix merge conflict

Browse files
Files changed (2) hide show
  1. app.py +50 -19
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,34 +1,65 @@
1
- import cv2
2
  import gradio as gr
3
- import numpy as np
4
  import spaces
5
- import supervision as sv
6
- from autodistill.detection import CaptionOntology
7
- from autodistill.utils import plot
8
  from autodistill_grounded_sam_2 import GroundedSAM2
 
 
 
 
 
 
9
 
 
 
 
 
 
10
 
11
  @spaces.GPU
12
- def greet(image):
13
- base_model = GroundedSAM2(
14
- ontology=CaptionOntology({"container id": "container number", "logo": "logo"}),
15
- model="Grounding DINO",
16
- grounding_dino_box_threshold=0.25,
17
- )
18
 
19
- results = base_model.predict("container1.jpg").with_nms()
20
- results = results[results.confidence > 0.3]
21
- # print(results)
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- image = cv2.imread("container1.jpg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  mask_annotator = sv.BoxAnnotator()
26
 
27
- annotated_image = mask_annotator.annotate(image.copy(), detections=results)
 
 
28
 
29
  return annotated_image
30
 
31
-
32
- demo = gr.Interface(fn=greet, inputs="image", outputs="image")
33
  demo.launch()
34
-
 
 
1
  import gradio as gr
 
2
  import spaces
 
 
 
3
  from autodistill_grounded_sam_2 import GroundedSAM2
4
+ from autodistill_grounded_sam_2.helpers import combine_detections
5
+ from autodistill.helpers import load_image
6
+ import torch
7
+ from autodistill.detection import CaptionOntology
8
+ import supervision as sv
9
+ import nupmy as np
10
 
11
+ base_model = GroundedSAM2(
12
+ ontology=CaptionOntology({}),
13
+ model = "Grounding DINO",
14
+ grounding_dino_box_threshold=0.25
15
+ )
16
 
17
  @spaces.GPU
18
+ def greet(image, prompt):
19
+ image = load_image(input, return_format="cv2")
 
 
 
 
20
 
21
+ if base_model.model == "Florence 2":
22
+ detections = base_model.florence_2_predictor.predict(image)
23
+ elif base_model.model == "Grounding DINO":
24
+ # GroundingDINO predictions
25
+ detections_list = []
26
+
27
+ for i, description in enumerate(prompt.split(",")):
28
+ # detect objects
29
+ detections = base_model.grounding_dino_model.predict_with_classes(
30
+ image=image,
31
+ classes=[description],
32
+ box_threshold=base_model.grounding_dino_box_threshold,
33
+ text_threshold=base_model.grounding_dino_text_threshold,
34
+ )
35
 
36
+ detections_list.append(detections)
37
+
38
+ detections = combine_detections(
39
+ detections_list, overwrite_class_ids=range(len(detections_list))
40
+ )
41
+
42
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
43
+ base_model.sam_2_predictor.set_image(image)
44
+ result_masks = []
45
+ for box in detections.xyxy:
46
+ masks, scores, _ = base_model.sam_2_predictor.predict(
47
+ box=box, multimask_output=False
48
+ )
49
+ index = np.argmax(scores)
50
+ masks = masks.astype(bool)
51
+ result_masks.append(masks[index])
52
+
53
+ detections.mask = np.array(result_masks)
54
+ results = results[results.confidence > 0.3]
55
 
56
  mask_annotator = sv.BoxAnnotator()
57
 
58
+ annotated_image = mask_annotator.annotate(
59
+ image.copy(), detections=results
60
+ )
61
 
62
  return annotated_image
63
 
64
+ demo = gr.Interface(fn=greet, inputs=[gr.inputs.Image(), gr.inputs.Textbox(lines=2, label="Prompt")], outputs="image")
 
65
  demo.launch()
 
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  torch
2
  autodistill
3
- numpy>=1.20.0
4
  opencv-python>=4.6.0
5
  supervision
6
  roboflow
 
1
  torch
2
  autodistill
3
+ numpy==1.20.0
4
  opencv-python>=4.6.0
5
  supervision
6
  roboflow