oshita-n's picture
update confidence score
dc972bb
raw
history blame
1.3 kB
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
from PIL import Image
import numpy as np
def process(input_image, prompt):
    """Segment ``input_image`` according to the text ``prompt`` with CLIPSeg.

    Relies on the module-level ``processor`` and ``model`` (loaded in the
    ``__main__`` guard below).

    Args:
        input_image: PIL image to segment.
        prompt: free-text description of the region to segment.

    Returns:
        A ``uint8`` numpy array of shape ``(height, width)`` containing a
        binary mask (0 or 255), resized back to the input image's resolution.
    """
    inputs = processor(text=prompt, images=input_image,
                       padding="max_length", return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()
    # Confidence threshold of 0.3 -> binary 0/255 mask.
    mask = np.where(probs > 0.3, 255, 0).astype(np.uint8)
    # Resize back to the original resolution; NEAREST keeps the mask binary
    # (the default BICUBIC resampling would introduce gray in-between values).
    mask_img = Image.fromarray(mask).resize(
        (input_image.width, input_image.height), resample=Image.NEAREST)
    return np.array(mask_img)
if __name__ == '__main__':
    # Load the CLIPSeg processor/model once at startup; ``process`` reads
    # these module-level globals on every request.
    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

    # Build the UI. ``gr.inputs.Image`` is the deprecated pre-3.x namespace;
    # top-level ``gr.Image`` is the supported component.
    input_image = gr.Image(label='image', type='pil')
    prompt = gr.Textbox(label='Prompt')

    # NOTE(review): the original passed ``input_size``/``output_size`` to
    # ``gr.Interface``; those are not Interface parameters, so they were
    # dropped here.
    iface = gr.Interface(fn=process,
                         inputs=[input_image, prompt],
                         outputs="image")
    iface.launch()