Martin Tomov commited on
Commit
850cda3
β€’
1 Parent(s): 69ed2a1

gsl_utils.py

Browse files
Files changed (1) hide show
  1. gsl_utils.py +122 -0
gsl_utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GSL
2
+
3
+ import os
4
+ import torch
5
+ import numpy as np
6
+ from PIL import Image, ImageChops, ImageEnhance
7
+ import cv2
8
+ from simple_lama_inpainting import SimpleLama
9
+ from segment_anything import build_sam, SamPredictor
10
+ from GroundingDINO.groundingdino.util import box_ops
11
+ from GroundingDINO.groundingdino.util.slconfig import SLConfig
12
+ from GroundingDINO.groundingdino.util.utils import clean_state_dict
13
+ from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+
18
+ def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
19
+ cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
20
+ args = SLConfig.fromfile(cache_config_file)
21
+ args.device = device
22
+ model = build_model(args)
23
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
24
+ checkpoint = torch.load(cache_file, map_location=device)
25
+ model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
26
+ model.eval()
27
+ return model
28
+
29
+ groundingdino_model = load_model_hf(
30
+ repo_id="ShilongLiu/GroundingDINO",
31
+ filename="groundingdino_swinb_cogcoor.pth",
32
+ ckpt_config_filename="GroundingDINO_SwinB.cfg.py",
33
+ device=device
34
+ )
35
+
36
+ sam_predictor = SamPredictor(build_sam(checkpoint='sam_vit_h_4b8939.pth').to(device))
37
+ simple_lama = SimpleLama()
38
+
39
+ def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15, text_threshold=0.15):
40
+ boxes, logits, phrases = predict(
41
+ image=image,
42
+ model=model,
43
+ caption=text_prompt,
44
+ box_threshold=box_threshold,
45
+ text_threshold=text_threshold
46
+ )
47
+ annotated_frame = annotate(image_source=image, boxes=boxes, logits=logits, phrases=phrases)
48
+ annotated_frame = annotated_frame[..., ::-1] # BGR to RGB
49
+ return annotated_frame, boxes, phrases
50
+
51
+ def segment(image, sam_model, boxes):
52
+ sam_model.set_image(image)
53
+ H, W, _ = image.shape
54
+ boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
55
+ transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(device), image.shape[:2])
56
+ masks, _, _ = sam_model.predict_torch(
57
+ point_coords=None,
58
+ point_labels=None,
59
+ boxes=transformed_boxes,
60
+ multimask_output=True,
61
+ )
62
+ return masks.cpu()
63
+
64
+ def draw_mask(mask, image, random_color=True):
65
+ if random_color:
66
+ color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
67
+ else:
68
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
69
+ h, w = mask.shape[-2:]
70
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
71
+ annotated_frame_pil = Image.fromarray(image).convert("RGBA")
72
+ mask_image_pil = Image.fromarray((mask_image.numpy() * 255).astype(np.uint8)).convert("RGBA")
73
+ return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))
74
+
75
+ def dilate_mask(mask, dilate_factor=15):
76
+ mask = mask.astype(np.uint8)
77
+ mask = cv2.dilate(
78
+ mask,
79
+ np.ones((dilate_factor, dilate_factor), np.uint8),
80
+ iterations=1
81
+ )
82
+ return mask
83
+
84
+ def gsl_process_image(local_image_path):
85
+ # Load image
86
+ image_source, image = load_image(local_image_path)
87
+
88
+ # Detect insects
89
+ annotated_frame, detected_boxes, phrases = detect(image, model=groundingdino_model)
90
+ indices = [i for i, s in enumerate(phrases) if 'insect' in s]
91
+
92
+ # Segment insects
93
+ segmented_frame_masks = segment(image_source, sam_predictor, detected_boxes[indices])
94
+
95
+ # Combine masks
96
+ final_mask = None
97
+ for i in range(len(segmented_frame_masks) - 1):
98
+ if final_mask is None:
99
+ final_mask = np.bitwise_or(segmented_frame_masks[i][0].cpu(), segmented_frame_masks[i + 1][0].cpu())
100
+ else:
101
+ final_mask = np.bitwise_or(final_mask, segmented_frame_masks[i + 1][0].cpu())
102
+
103
+ # Draw mask
104
+ annotated_frame_with_mask = draw_mask(final_mask, image_source)
105
+
106
+ # Dilate mask
107
+ mask = final_mask.numpy()
108
+ mask = mask.astype(np.uint8) * 255
109
+ mask = dilate_mask(mask)
110
+ dilated_image_mask_pil = Image.fromarray(mask)
111
+
112
+ # Inpainting
113
+ result = simple_lama(image_source, dilated_image_mask_pil)
114
+
115
+ # Difference and composite
116
+ diff = ImageChops.difference(result, Image.fromarray(image_source))
117
+ threshold = 7
118
+ diff2 = diff.convert('L').point(lambda p: 255 if p > threshold else 0).convert('1')
119
+ img3 = Image.new('RGB', Image.fromarray(image_source).size, (255, 236, 10))
120
+ diff3 = Image.composite(Image.fromarray(image_source), img3, diff2)
121
+
122
+ return diff3