import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg processor and model once at import time so process()
# can reuse them across requests.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def process(input_image, prompt):
    # Tokenize the prompt and preprocess the image for CLIPSeg.
    inputs = processor(
        text=prompt,
        images=input_image,
        padding="max_length",
        return_tensors="pt",
    )
    # Run inference without tracking gradients.
    with torch.no_grad():
        outputs = model(**inputs)
    # Sigmoid turns the logits into per-pixel probabilities.
    preds = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()
    # Threshold at 0.3 to get a binary mask (255 = prompt match, 0 = background).
    mask = np.where(preds > 0.3, 255, 0).astype(np.uint8)
    # The model predicts at a reduced resolution; resize the mask back to
    # the input image's size.
    mask = Image.fromarray(mask).resize((input_image.width, input_image.height))
    return np.array(mask)


if __name__ == "__main__":
    # Note: gr.inputs.Image was removed in Gradio 3+; use gr.Image directly.
    # gr.Interface takes no input_size/output_size keyword arguments, so the
    # stray sizing kwargs have been dropped.
    iface = gr.Interface(
        fn=process,
        inputs=[
            gr.Image(label="Image", type="pil"),
            gr.Textbox(label="Prompt"),
        ],
        outputs=gr.Image(label="Mask"),
    )
    iface.launch()
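
# --- Example: calling process() directly, without launching the Gradio UI ---
# A minimal sketch, assuming an "example.jpg" next to this script; the
# filenames and prompt below are placeholders, not files shipped with
# this code:
#
#   from PIL import Image
#   image = Image.open("example.jpg").convert("RGB")
#   mask = process(image, "a glass of water")  # uint8 array, 255 where the prompt matches
#   Image.fromarray(mask).save("mask.png")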