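# Gradio demo: text-prompted image segmentation with CLIPSeg (CIDAS/clipseg-rd64-refined).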
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
from PIL import Image
import numpy as np


def process(input_image, prompt):
    # Preprocess the text prompt and PIL image into model inputs.
    inputs = processor(text=prompt, images=input_image, padding="max_length", return_tensors="pt")
    # Predict segmentation logits without tracking gradients.
    with torch.no_grad():
        outputs = model(**inputs)
    preds = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()
    # Threshold the per-pixel probabilities into a binary 0/255 mask.
    preds = np.where(preds > 0.5, 255, 0).astype(np.uint8)
    # Resize the mask back to the original image resolution.
    preds = Image.fromarray(preds)
    preds = np.array(preds.resize((input_image.width, input_image.height)))
    return preds


if __name__ == '__main__':
    # Load the CLIPSeg processor and segmentation model once at startup.
    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

    # Build the Gradio interface: an image and a text prompt in, a mask image out.
    input_image = gr.Image(label='image', type='pil')
    prompt = gr.Textbox(label='Prompt')
    ips = [input_image, prompt]
    outputs = "image"

    iface = gr.Interface(fn=process,
                         inputs=ips,
                         outputs=outputs)
    iface.launch()
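    # launch() serves the interface locally (http://127.0.0.1:7860 by default);
    # pass share=True to launch() for a temporary public link.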