File size: 3,774 Bytes
6dc32ee
9fcd716
850cda3
6e9fa1f
850cda3
 
 
 
 
cfccf84
3657d52
850cda3
 
 
 
3657d52
81b2e04
850cda3
 
9fcd716
 
cfccf84
 
3657d52
9fcd716
850cda3
 
 
3657d52
 
 
850cda3
 
 
 
3657d52
 
850cda3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
746e19a
3657d52
9fcd716
6dc32ee
 
 
9fcd716
6dc32ee
 
 
3657d52
 
850cda3
 
 
 
 
 
 
 
3657d52
850cda3
 
5e1955d
 
 
850cda3
5e1955d
850cda3
5e1955d
 
 
 
 
850cda3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# GSL

import os
import spaces
import torch
import numpy as np
from PIL import Image, ImageChops, ImageEnhance
import cv2
from simple_lama_inpainting import SimpleLama
from segment_anything import build_sam, SamPredictor
from transformers import pipeline
from huggingface_hub import hf_hub_download

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def load_groundingdino_model(device='cpu', model_id="IDEA-Research/grounding-dino-base"):
    """Create a ``transformers`` zero-shot-object-detection pipeline.

    Args:
        device: Device handed to ``transformers.pipeline`` (e.g. ``'cpu'`` or a
            ``torch.device``).
        model_id: Hugging Face model id to load. Defaults to the Grounding DINO
            base checkpoint, which was previously hard-coded.

    Returns:
        The pipeline object. ``detect`` calls it as
        ``model(image, candidate_labels=..., threshold=...)``.
    """
    return pipeline(model=model_id, task="zero-shot-object-detection", device=device)

def load_sam_model(checkpoint_path, device='cpu'):
    """Load a SAM checkpoint and return it wrapped in a ``SamPredictor``.

    Args:
        checkpoint_path: Path to the SAM weights file passed to ``build_sam``.
        device: Device the model is moved to before it is wrapped.

    Returns:
        A ``SamPredictor`` around the loaded model.
    """
    model_on_device = build_sam(checkpoint=checkpoint_path).to(device)
    predictor = SamPredictor(model_on_device)
    return predictor

# The three models are built once, at import time, and read as globals by
# gsl_process_image.
_SAM_CHECKPOINT_PATH = "models/sam_vit_h_4b8939.pth"

groundingdino_model = load_groundingdino_model(device=device)
sam_predictor = load_sam_model(checkpoint_path=_SAM_CHECKPOINT_PATH, device=device)
simple_lama = SimpleLama()

def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15, text_threshold=0.15):
    """Run zero-shot object detection for the classes named in ``text_prompt``.

    Args:
        image: Image passed unchanged to ``model`` (a PIL image in this file).
        model: Zero-shot-object-detection callable, e.g. the pipeline returned
            by ``load_groundingdino_model``. It is invoked as
            ``model(image, candidate_labels=[...], threshold=box_threshold)``.
        text_prompt: Class names separated by ``'.'``. Surrounding whitespace
            is ignored, and empty segments (e.g. from a trailing ``'.'``) are
            dropped.
        box_threshold: Minimum score, forwarded to ``model`` as ``threshold``.
        text_threshold: Not used; kept only so existing callers keep working.

    Returns:
        Whatever ``model`` returns for the labels. Returns ``[]`` without
        calling ``model`` when the prompt contains no class names.
    """
    # Each label is normalised to "name." (Grounding-DINO-style prompt). After
    # split('.') no segment can end in '.', so the dot is always appended.
    labels = [f"{name}." for name in (part.strip() for part in text_prompt.split('.')) if name]
    if not labels:
        return []
    return model(image, candidate_labels=labels, threshold=box_threshold)

def segment(image, sam_model, boxes, normalized=True):
    """Segment the objects inside ``boxes`` with a SAM predictor.

    Args:
        image: ``H x W x C`` image array, passed to ``sam_model.set_image``.
        sam_model: A ``SamPredictor`` (anything with ``set_image``,
            ``transform.apply_boxes_torch``, ``predict_torch`` and ``device``).
        boxes: Sequence of ``[x0, y0, x1, y1]`` boxes. By default they are
            treated as fractions of the image size and scaled by ``(W, H)``.
        normalized: Set to ``False`` when ``boxes`` are already in pixel
            coordinates, which skips the scaling step.

    Returns:
        The masks from ``predict_torch`` (three per box, because
        ``multimask_output=True``), moved to the CPU. When ``boxes`` is empty,
        an empty boolean tensor of shape ``(0, 3, H, W)`` is returned and SAM
        prediction is skipped.
    """
    sam_model.set_image(image)
    H, W = image.shape[:2]

    # reshape(-1, 4) keeps an empty box list as a (0, 4) tensor. The original
    # torch.Tensor([]) * [W, H, W, H] raised a broadcasting error.
    boxes_xyxy = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    if normalized:
        boxes_xyxy = boxes_xyxy * torch.tensor([W, H, W, H], dtype=torch.float32)
    if boxes_xyxy.shape[0] == 0:
        return torch.zeros((0, 3, H, W), dtype=torch.bool)

    # Use the predictor's own device, not a module global, so any predictor works.
    transformed_boxes = sam_model.transform.apply_boxes_torch(
        boxes_xyxy.to(sam_model.device), image.shape[:2]
    )
    masks, _, _ = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=True,
    )
    return masks.cpu()

def draw_mask(mask, image, random_color=True):
    """Overlay ``mask`` on ``image`` as a translucent coloured layer.

    Args:
        mask: Torch tensor whose last two dimensions are ``(h, w)``. It is
            multiplied by an RGBA colour and converted with ``.numpy()``, so it
            must be a CPU tensor.
        image: Array accepted by ``PIL.Image.fromarray``.
        random_color: If true, use a random RGB colour with alpha 0.8.
            Otherwise use a fixed blue (30, 144, 255) with alpha 0.6.

    Returns:
        The alpha-composited image as an RGBA ``numpy`` array.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )
    height, width = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) colour.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    base_layer = Image.fromarray(image).convert("RGBA")
    overlay_bytes = (overlay.numpy() * 255).astype(np.uint8)
    colour_layer = Image.fromarray(overlay_bytes).convert("RGBA")
    composited = Image.alpha_composite(base_layer, colour_layer)
    return np.array(composited)

def dilate_mask(mask, dilate_factor=15):
    """Dilate ``mask`` once with a ``dilate_factor`` x ``dilate_factor`` square kernel.

    The mask is cast to ``uint8`` before dilation. The ``uint8`` array
    produced by ``cv2.dilate`` is returned.
    """
    as_bytes = mask.astype(np.uint8)
    square_kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
    return cv2.dilate(as_bytes, square_kernel, iterations=1)

# Grey-level difference (0-255) above which a pixel counts as changed by inpainting.
_CHANGE_THRESHOLD = 7
# Fill colour for pixels that inpainting left (nearly) unchanged.
_UNCHANGED_FILL = (255, 236, 10)


def _union_first_masks(masks):
    """Return the pixel-wise OR of the first mask of every box in ``masks``.

    ``masks`` is the tensor returned by ``segment``, indexed as
    ``[box][mask]``. Unlike the previous pairwise loop, this also works for a
    single box.
    """
    return masks[:, 0].any(dim=0)


def _highlight_changes(original, inpainted):
    """Keep pixels that changed between ``original`` and ``inpainted``; fill the rest.

    Both arguments are PIL images of the same size. Pixels whose grey-level
    difference exceeds ``_CHANGE_THRESHOLD`` keep their original colour. All
    other pixels become ``_UNCHANGED_FILL``.
    """
    diff = ImageChops.difference(inpainted, original)
    changed = diff.convert('L').point(lambda p: 255 if p > _CHANGE_THRESHOLD else 0).convert('1')
    fill = Image.new('RGB', original.size, _UNCHANGED_FILL)
    return Image.composite(original, fill, changed)


@spaces.GPU
def gsl_process_image(image):
    """Remove the detected objects from ``image`` and show which pixels changed.

    Steps:
        1. ``detect`` runs with its default prompt.
        2. SAM segments each detection.
        3. The masks are merged, dilated, and inpainted with LaMa.
        4. The inpainted result is compared with the input.

    Args:
        image: A ``numpy`` array, or anything ``np.array`` accepts. PIL images
            are converted to RGB first.

    Returns:
        A PIL RGB image the size of the input. Pixels changed by inpainting
        keep their original colour; every other pixel is ``_UNCHANGED_FILL``.
        If nothing is detected, the whole image is filled.
    """
    if isinstance(image, Image.Image):
        image = image.convert('RGB')
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    image_pil = Image.fromarray(image)
    height, width = image.shape[:2]

    detections = detect(image_pil, groundingdino_model)
    if not detections:
        # Nothing to remove means no pixel changes, so the whole output is
        # filled. The original crashed here instead.
        return Image.new('RGB', image_pil.size, _UNCHANGED_FILL)

    # The zero-shot-object-detection pipeline reports boxes in pixel
    # coordinates. ``segment`` multiplies its boxes by (W, H), so normalise
    # them to [0, 1] first; otherwise the boxes land far outside the image.
    boxes = [
        [
            d['box']['xmin'] / width,
            d['box']['ymin'] / height,
            d['box']['xmax'] / width,
            d['box']['ymax'] / height,
        ]
        for d in detections
    ]
    masks = segment(image, sam_predictor, boxes)

    union = _union_first_masks(masks)
    mask = union.cpu().numpy().astype(np.uint8) * 255
    mask = dilate_mask(mask)

    inpainted = simple_lama(image, Image.fromarray(mask))
    if inpainted.size != image_pil.size:
        # NOTE(review): SimpleLama appears to pad its input (bottom/right, to a
        # multiple of 8). Crop back so ImageChops.difference gets equal-sized
        # images — confirm against simple_lama_inpainting.
        inpainted = inpainted.crop((0, 0, width, height))

    return _highlight_changes(image_pil, inpainted)