oshita-n's picture
fix
b1a1700
raw
history blame
No virus
1.49 kB
import functools

import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
from PIL import Image
@functools.lru_cache(maxsize=None)
def _load_blip(device_type):
    """Load the BLIP feature extractor and its processors for *device_type*.

    Cached so the model is loaded once per device rather than on every
    request (the key space is just "cuda"/"cpu").

    Returns:
        The ``(model, vis_processors, txt_processors)`` tuple produced by
        ``load_model_and_preprocess``.
    """
    return load_model_and_preprocess(
        name="blip_feature_extractor",
        model_type="base",
        is_eval=True,
        device=torch.device(device_type),
    )


def process(input_image, prompt):
    """Threshold BLIP multimodal embeddings for (image, prompt) into an image.

    The image is resized to 256x256 and fed, together with *prompt*, through
    BLIP's multimodal feature extractor. The squeezed multimodal embedding
    array is thresholded at 0.3 (values above become 255, the rest 0), and
    the resulting 8-bit image is resized back to the input image's size.

    Args:
        input_image: PIL image from the Gradio ``Image(type='pil')`` input.
        prompt: Text prompt, passed through BLIP's "eval" text processor.

    Returns:
        A ``uint8`` NumPy array sized to the original input image
        (presumably ``(height, width)``, assuming the squeezed embedding is
        2-D — TODO confirm against the model's output shape).

    Raises:
        ValueError: if no image was supplied (Gradio passes ``None``).
    """
    if input_image is None:
        raise ValueError("process() requires an input image")

    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_type)
    model, vis_processors, txt_processors = _load_blip(device_type)

    # Record the caller's size BEFORE resizing; previously the resized copy
    # overwrote input_image, so the mask was always "restored" to 256x256.
    orig_size = input_image.size
    # The eval processor normalises three channels; drop alpha/palette modes.
    resized = input_image.convert("RGB").resize((256, 256), Image.LANCZOS)

    image = vis_processors["eval"](resized).unsqueeze(0).to(device)
    text_input = txt_processors["eval"](prompt)
    sample = {"image": image, "text_input": [text_input]}

    # Pure inference: skip autograd graph construction.
    with torch.no_grad():
        features_multimodal = model.extract_features(sample, mode="multimodal")
    embeds = features_multimodal.multimodal_embeds.squeeze().cpu().numpy()

    mask = np.where(embeds > 0.3, 255, 0).astype(np.uint8)
    mask_image = Image.fromarray(mask).resize(orig_size)
    return np.array(mask_image)
def main():
    """Build the Gradio interface around ``process`` and launch it."""
    # gr.Image replaces the deprecated gr.inputs.Image namespace, which newer
    # Gradio releases no longer provide.
    input_image = gr.Image(label='image', type='pil')
    prompt = gr.Textbox(label='Prompt')
    # input_size/output_size are not gr.Interface parameters, so they are not
    # passed; process() does its own resizing.
    iface = gr.Interface(
        fn=process,
        inputs=[input_image, prompt],
        outputs="image",
    )
    iface.launch()


if __name__ == '__main__':
    main()