|
import base64
import os
from io import BytesIO

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from flask import Flask, request, jsonify, render_template
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
|
|
# Flask application object; the route decorators in this module register on it.
app = Flask(__name__)

# Checkpoint identifier shared by the pre-processor and the segmentation
# model, kept in one place so the two can never point at different weights.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Both objects are loaded once at import time and reused by every request.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
|
|
|
def process_image(image, prompt, threhsold, alpha_value, draw_rectangles):
    """Segment ``image`` for ``prompt`` with CLIPSeg and build the visual outputs.

    Args:
        image: PIL image to segment. Its ``size`` defines the size of every
            returned artefact.
        prompt: Text handed to the CLIPSeg processor as the segmentation query.
        threhsold: Cut-off applied to the min-max normalised mask; pixels
            strictly above it are kept. The misspelt name is preserved on
            purpose so existing keyword callers keep working.
        alpha_value: Opacity of the "jet" heat-map overlay drawn on the figure.
        draw_rectangles: When truthy, a yellow bounding box is drawn around
            every contour found in the thresholded mask.

    Returns:
        A ``(fig, mask, output_image)`` tuple:

        * ``fig`` - matplotlib figure showing the image with the mask overlay.
          The caller owns it and should close it when done, otherwise pyplot
          keeps it alive.
        * ``mask`` - float array scaled to ``[0, 1]`` with values below the
          threshold set to 0.
        * ``output_image`` - RGBA PIL image holding the input pixels where the
          mask exceeds the threshold and fully transparent elsewhere.
    """
    # Feed the model an RGB copy so palette, greyscale or alpha-channel
    # uploads do not depend on how the processor treats non-RGB input.
    model_image = image if image.mode == "RGB" else image.convert("RGB")
    inputs = processor(
        text=prompt, images=model_image, padding="max_length", return_tensors="pt"
    )

    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits

    pred = torch.sigmoid(preds)
    # NOTE(review): the logits may or may not carry a leading singleton batch
    # axis depending on the transformers version - confirm. Squeezing is a
    # no-op for a 2-D array and makes the single-prompt case 2-D either way,
    # which is what the greyscale conversion below needs.
    mat = np.squeeze(pred.cpu().numpy())
    mask = Image.fromarray(np.uint8(mat * 255), "L")
    mask = mask.convert("RGB")
    mask = mask.resize(image.size)
    mask = np.array(mask)[:, :, 0]

    # Min-max normalise to [0, 1]. A constant mask has zero range; dividing
    # by it would fill the array with NaNs, so fall back to an all-zero mask.
    mask = mask.astype(np.float64)
    mask_min = mask.min()
    mask_range = mask.max() - mask_min
    if mask_range > 0:
        mask = (mask - mask_min) / mask_range
    else:
        mask = np.zeros_like(mask)

    # Boolean mask of the pixels to keep; taken before the in-place zeroing.
    bmask = mask > threhsold
    mask[mask < threhsold] = 0

    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.imshow(mask, alpha=alpha_value, cmap="jet")

    if draw_rectangles:
        contours, _ = cv2.findContours(
            bmask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            rect = plt.Rectangle(
                (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
            )
            ax.add_patch(rect)

    ax.axis("off")
    # Lay out this specific figure rather than pyplot's global "current" one.
    fig.tight_layout()

    # Use the boolean mask as the paste mask: kept pixels are copied from the
    # input, everything else stays fully transparent.
    bmask = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
    output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
    output_image.paste(image, mask=bmask)

    return fig, mask, output_image
|
|
|
@app.route('/')
def index():
    """Answer requests for the root URL with a fixed plain-text greeting."""
    greeting = "Hello, World! clipseg2"
    return greeting
|
|
|
@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Segment an uploaded image and return the masked result as base64 PNG.

    Expects a JSON object with:

    * ``base64_image`` (required) - the image as a ``data:...;base64,`` URL
      or as a bare base64 string.
    * ``prompt`` - segmentation text, defaults to ``''``.
    * ``threshold`` - mask cut-off, defaults to ``0.4``.
    * ``alpha_value`` - overlay opacity, defaults to ``0.5``.
    * ``draw_rectangles`` - whether to draw bounding boxes, defaults to
      ``False``.

    Returns:
        ``{'result_image': <base64 PNG>}`` on success, or
        ``{'error': <message>}`` with HTTP status 400 when the request body
        or the image payload is invalid.
    """
    # silent=True yields None instead of raising on a missing or malformed
    # JSON body, so the client gets a clear 400 from the check below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    base64_image = data.get('base64_image', '')
    prompt = data.get('prompt', '')
    draw_rectangles = data.get('draw_rectangles', False)

    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': "'base64_image' must be a non-empty string."}), 400

    try:
        threshold = float(data.get('threshold', 0.4))
        alpha_value = float(data.get('alpha_value', 0.5))
    except (TypeError, ValueError):
        return jsonify(
            {'error': "'threshold' and 'alpha_value' must be numbers."}
        ), 400

    # Accept both a data URL ("data:image/png;base64,....") and bare base64:
    # strip the header only when a comma separator is actually present.
    _, separator, payload = base64_image.partition(',')
    encoded = payload if separator else base64_image

    try:
        # binascii.Error is a ValueError subclass; PIL's decoding failures
        # are OSError subclasses.
        image_data = base64.b64decode(encoded)
        image = Image.open(BytesIO(image_data))
        # Image.open is lazy; force decoding here so a truncated or corrupt
        # file is reported as a client error rather than failing later.
        image.load()
    except (ValueError, OSError):
        return jsonify(
            {'error': "'base64_image' is not a valid base64-encoded image."}
        ), 400

    fig, _, output_image = process_image(
        image, prompt, threshold, alpha_value, draw_rectangles
    )
    # The figure is not part of the response; close it so pyplot does not
    # keep one figure alive per request.
    plt.close(fig)

    buffered = BytesIO()
    output_image.save(buffered, format="PNG")
    result_image = base64.b64encode(buffered.getvalue()).decode('utf-8')

    return jsonify({'result_image': result_image})
|
|
|
if __name__ == '__main__':
    # The server listens on every interface, so the interactive debugger must
    # not be on by default: it lets anyone who can reach the port execute
    # code. Opt in explicitly with FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in {
        '1', 'true', 'yes', 'on',
    }
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)