import torch
import numpy as np
import gradio as gr
from lavis.models import load_model_and_preprocess
from PIL import Image

# Load the BLIP feature extractor once at startup so each request does not
# re-initialize the checkpoint.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip_feature_extractor", model_type="base", is_eval=True, device=device
)


def process(input_image, prompt):
    # Resize the input and run the BLIP visual and text preprocessors.
    input_image = input_image.resize((256, 256), Image.LANCZOS)
    image = vis_processors["eval"](input_image).unsqueeze(0).to(device)
    text_input = txt_processors["eval"](prompt)
    sample = {"image": image, "text_input": [text_input]}

    # Extract joint image-text features and threshold the embedding matrix
    # into a binary (0/255) map.
    features_multimodal = model.extract_features(sample, mode="multimodal")
    preds = features_multimodal.multimodal_embeds.squeeze().detach().cpu().numpy()
    preds = np.where(preds > 0.3, 255, 0).astype(np.uint8)

    # Convert to a PIL image and resize back to the (resized) input dimensions.
    preds = Image.fromarray(preds)
    preds = np.array(preds.resize((input_image.width, input_image.height)))
    return preds


if __name__ == '__main__':
    # gr.inputs.* is deprecated; use the top-level Gradio components instead.
    input_image = gr.Image(label='image', type='pil')
    prompt = gr.Textbox(label='Prompt')
    ips = [input_image, prompt]
    # gr.Interface has no input_size/output_size arguments; component sizing is
    # handled by the components themselves.
    iface = gr.Interface(fn=process, inputs=ips, outputs="image")
    iface.launch()