File size: 3,329 Bytes
63596f0
 
 
 
 
 
 
 
 
58e8f4d
 
8037bf7
63596f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa8c892
 
 
63596f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa8c892
63596f0
 
 
db1a0dc
 
63596f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
945cca9
58e8f4d
63596f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from flask import Flask, request, jsonify, render_template
from PIL import Image
import base64
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2

app = Flask(__name__)

# The processor and the segmentation model come from the same CLIPSeg
# checkpoint. Both are loaded once, at import time, and are shared by every
# call to process_image().
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor, model = (
    CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT),
    CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT),
)

def _encode_png_base64(img):
    """Serialize a PIL image as PNG and return it as a base64 ASCII string."""
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment the region of ``image`` described by ``prompt`` with CLIPSeg.

    Args:
        image: PIL image to segment. Images in a mode other than RGB are
            converted to RGB first.
        prompt: Text prompt handed to the CLIPSeg processor.
        threshold: Cut-off applied to the min-max normalised mask (range
            0..1). Pixels strictly above it are kept.
        alpha_value: Opacity of the heat-map overlay drawn on the figure.
        draw_rectangles: If truthy, draw a yellow bounding box around each
            outer contour of the binary mask.

    Returns:
        A tuple ``(fig, result_mask, result_output)``:
        - ``fig`` is the matplotlib figure with the overlay. The caller owns
          it and should release it with ``plt.close(fig)``.
        - ``result_mask`` is the binary mask as a base64-encoded PNG.
        - ``result_output`` is a base64-encoded RGBA PNG holding the image
          pixels inside the mask and transparent pixels elsewhere.
    """
    # NOTE(review): the CLIPSeg processor normalises with 3-channel
    # statistics, so RGBA/P/L inputs are presumed to fail without this.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.sigmoid(logits).cpu().numpy()
    # For a single prompt the logits may be (H, W) or (1, H, W), depending on
    # the transformers version. np.squeeze(axis=0) raised on the former, so
    # collapse any leading axes and keep the first mask.
    probs = probs.reshape(-1, *probs.shape[-2:])[0]

    # Quantise to 8 bits and upscale to the input size. A 2-D uint8 array
    # becomes a single-channel "L" image, so no RGB round trip is needed.
    mask_img = Image.fromarray(np.uint8(probs * 255)).resize(image.size)
    mask = np.asarray(mask_img, dtype=np.float64)

    # Stretch the mask to 0..1. A flat prediction has no range to stretch,
    # and dividing by zero would turn the whole mask into NaN, so treat it
    # as "nothing found".
    lo, hi = mask.min(), mask.max()
    span = hi - lo
    mask = (mask - lo) / span if span > 0 else np.zeros_like(mask)

    # Use a single comparison for the binary mask and the displayed heat map,
    # so pixels exactly at the threshold are treated the same way in both.
    bmask = mask > threshold
    mask[~bmask] = 0

    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.imshow(mask, alpha=alpha_value, cmap="jet")

    if draw_rectangles:
        # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
        # (contours, hierarchy); [-2] picks the contours under either.
        # RETR_EXTERNAL skips the inner contours of holes, so boxes are not
        # drawn around the holes inside a region.
        contours = cv2.findContours(
            bmask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            ax.add_patch(
                plt.Rectangle(
                    (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
                )
            )

    ax.axis("off")
    # Lay out this figure explicitly; plt.tight_layout() targets pyplot's
    # global "current figure", which another request thread may have changed.
    fig.tight_layout()

    bmask_img = Image.fromarray(bmask.astype(np.uint8) * 255)
    output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
    output_image.paste(image, mask=bmask_img)

    return fig, _encode_png_base64(bmask_img), _encode_png_base64(output_image)

@app.route('/')
def index():
    """Serve the front-end page rendered from the ``index.html`` template."""
    page = 'index.html'
    return render_template(page)

@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Handle ``POST /api/mask_image``.

    Expects a JSON object with these keys:
    - ``base64_image``: a data URL (``data:image/...;base64,...``) or a bare
      base64 string.
    - ``prompt``: the text prompt.
    - ``threshold``: optional, default 0.4.
    - ``alpha_value``: optional, default 0.5.
    - ``draw_rectangles``: optional, default False.

    Returns JSON with the base64 PNGs ``result_mask`` and ``result_output``.
    Malformed input gets a 400 response with an ``error`` message.
    """
    # silent=True yields None instead of raising for a missing or invalid
    # JSON body; previously data.get() then failed with a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    base64_image = data.get('base64_image', '')
    prompt = data.get('prompt', '')
    draw_rectangles = data.get('draw_rectangles', False)
    try:
        threshold = float(data.get('threshold', 0.4))
        alpha_value = float(data.get('alpha_value', 0.5))
    except (TypeError, ValueError):
        return jsonify(
            {'error': "'threshold' and 'alpha_value' must be numbers"}
        ), 400

    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': "'base64_image' is required"}), 400

    # Accept both data URLs and bare base64. split(',')[1] raised IndexError
    # when there was no "data:...;base64," prefix.
    head, sep, tail = base64_image.partition(',')
    encoded = tail if sep else head

    try:
        # convert() forces the (otherwise lazy) decode inside this try, so
        # corrupt data is reported here. It also gives the model RGB input.
        # b64decode raises binascii.Error (a ValueError); Pillow raises
        # UnidentifiedImageError (an OSError).
        image = Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
    except (ValueError, OSError):
        return jsonify(
            {'error': "'base64_image' is not a valid base64-encoded image"}
        ), 400

    fig, result_mask, result_output = process_image(
        image, prompt, threshold, alpha_value, draw_rectangles
    )
    # The API never renders the figure. Close it so pyplot does not keep
    # every request's figure alive for the lifetime of the process.
    plt.close(fig)

    return jsonify({'result_mask': result_mask, 'result_output': result_output})

if __name__ == '__main__':
    import os

    # The Werkzeug interactive debugger can execute arbitrary Python code, and
    # this server binds to every interface. Debug mode is therefore opt-in via
    # FLASK_DEBUG=1 rather than always on.
    app.run(
        host='0.0.0.0',
        port=7860,
        debug=os.environ.get('FLASK_DEBUG') == '1',
    )