# CLIPSeg2 / app.py
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import numpy as np

# Load the CLIPSeg processor and model
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
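
# Note: inference runs on CPU here; CLIPSeg predicts a 352x352 logit map per
# text prompt, which is resized back to the input resolution below.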

# Run CLIPSeg on a single (image, prompt) pair and return a soft mask in [0, 1]
def process_image(image, prompt):
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
        preds = outputs.logits
    pred = torch.sigmoid(preds)
    mat = pred.cpu().numpy()
    # Convert the probability map to a PIL image, resize it to the input size,
    # and renormalize to [0, 1]
    mask = Image.fromarray(np.uint8(mat * 255), "L")
    mask = mask.convert("RGB")
    mask = mask.resize(image.size)
    mask = np.array(mask)[:, :, 0]
    mask_min = mask.min()
    mask_max = mask.max()
    # Small epsilon avoids division by zero when the mask is flat
    mask = (mask - mask_min) / (mask_max - mask_min + 1e-8)
    return mask

# Get one binary mask per comma-separated prompt
def get_masks(prompts, img, threshold):
    prompts = [p.strip() for p in prompts.split(",") if p.strip()]
    masks = []
    for prompt in prompts:
        mask = process_image(img, prompt)
        mask = mask > threshold  # binarize the soft mask
        masks.append(mask)
    return masks

# Extract the image regions matching the positive prompts but not the negative prompts
def extract_image(pos_prompts, neg_prompts, img, threshold):
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)
    # Empty prompt boxes yield no masks; fall back to an all-False mask
    blank = np.zeros((img.size[1], img.size[0]), dtype=bool)
    pos_mask = np.any(np.stack(positive_masks), axis=0) if positive_masks else blank
    neg_mask = np.any(np.stack(negative_masks), axis=0) if negative_masks else blank
    final_mask = pos_mask & ~neg_mask
    # Paste the input image through the mask onto a transparent canvas
    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask

# Gradio UI
iface_ui = gr.Interface(
    fn=extract_image,
    inputs=[
        gr.Textbox(
            label="Please describe what you want to identify (comma-separated)",
            key="pos_prompts",
        ),
        gr.Textbox(
            label="Please describe what you want to ignore (comma-separated)",
            key="neg_prompts",
        ),
        gr.Image(type="pil", label="Input Image", key="img"),
        gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold", key="threshold"),
    ],
    outputs=[
        gr.Image(label="Result", key="output_image"),
        gr.Image(label="Mask", key="output_mask"),
    ],
)

# Launch Gradio UI
iface_ui.launch()
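
# Note: launch() blocks the main thread when run as a script, so the helper
# below is only reachable once the server stops (or if the launch call above is
# skipped or given prevent_thread_lock=True).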

# Non-UI version
def run_non_ui(image_path, pos_prompts, neg_prompts, threshold):
    img = Image.open(image_path).convert("RGB")  # ensure a consistent mode for pasting
    output_image, output_mask = extract_image(pos_prompts, neg_prompts, img, threshold)
    # Save or use the output_image and output_mask as needed
    output_image.show()  # For demonstration purposes, opens the image in the default viewer
    output_mask.show()  # For demonstration purposes, opens the mask in the default viewer

# Example of using the non-UI version
# run_non_ui("path/to/your/image.jpg", "positive prompt", "negative prompt", 0.5)
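
# A minimal saving sketch (file names are placeholders, not part of the original script):
# result, mask = extract_image("cat, dog", "background", Image.open("example.jpg"), 0.4)
# result.save("cutout.png")  # RGBA cutout with a transparent background
# mask.save("mask.png")      # binary mask as a grayscale image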