# CLIPSeg text-prompted image segmentation demo (Gradio app).
# Origin: Hugging Face Space by oshita-n, commit 59c3bc4.
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
from PIL import Image
import numpy as np
def process(input_image, prompt, threshold=0.3):
    """Segment *input_image* according to a free-text *prompt* with CLIPSeg.

    Parameters
    ----------
    input_image : PIL.Image.Image
        Image to segment (Gradio supplies a PIL image via ``type='pil'``).
    prompt : str
        Text description of the region to extract.
    threshold : float, optional
        Sigmoid-probability cutoff for binarizing the mask. Defaults to
        0.3, matching the previously hard-coded value.

    Returns
    -------
    numpy.ndarray
        A uint8 binary mask (values 0 or 255) resized to the input
        image's width and height.

    Notes
    -----
    Relies on the module-level ``processor`` and ``model`` globals that
    are created in the ``__main__`` block.
    """
    inputs = processor(text=prompt, images=input_image,
                       padding="max_length", return_tensors="pt")
    # Inference only — no gradients needed, so .detach() is unnecessary.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()
    # Binarize once (the original converted to uint8 twice in a row).
    mask = np.where(probs > threshold, 255, 0).astype(np.uint8)
    # The model emits a low-resolution mask; scale it back up to the
    # original image dimensions before returning.
    resized = Image.fromarray(mask).resize((input_image.width, input_image.height))
    return np.array(resized)
if __name__ == '__main__':
    # Load the CLIPSeg processor and model once at startup; `process`
    # reads these as module-level globals.
    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

    # FIX: `gr.inputs.Image` is the deprecated pre-3.x namespace (removed
    # in Gradio 4.x); the top-level `gr.Image` component is the supported
    # equivalent.
    input_image = gr.Image(label='image', type='pil')
    prompt = gr.Textbox(label='Prompt')
    ips = [
        input_image, prompt
    ]
    outputs = "image"

    # FIX: gr.Interface has no `input_size`/`output_size` parameters —
    # passing them raises a TypeError, so they have been removed.
    iface = gr.Interface(fn=process,
                         inputs=ips,
                         outputs=outputs)
    iface.launch()