File size: 1,492 Bytes
abd35ef
 
4b08e6e
baf1626
4b08e6e
abd35ef
4b08e6e
baf1626
 
 
 
b1a1700
baf1626
 
 
 
 
 
dc972bb
4b08e6e
 
 
baf1626
 
4b08e6e
 
 
 
 
 
 
 
 
 
 
 
 
 
baf1626
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from functools import lru_cache

import torch
import numpy as np
import gradio as gr
from lavis.models import load_model_and_preprocess
from PIL import Image

@lru_cache(maxsize=None)
def _load_blip(device_type):
    """Load the BLIP feature extractor once per device type and reuse it across calls."""
    return load_model_and_preprocess(
        name="blip_feature_extractor",
        model_type="base",
        is_eval=True,
        device=torch.device(device_type),
    )


def process(input_image, prompt):
    """Threshold BLIP multimodal features for an image/prompt pair into a 0/255 array.

    Args:
        input_image: PIL image (the Gradio input is declared with ``type='pil'``).
        prompt: Text prompt, passed through BLIP's "eval" text processor.

    Returns:
        ``numpy.ndarray`` of dtype uint8 whose values are 0 or 255, resized to
        the size of the image the caller supplied.
    """
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_type)
    # Cached: loading the model on every request is very slow.
    model, vis_processors, txt_processors = _load_blip(device_type)

    # Capture the caller's size *before* resizing, so the result can be mapped
    # back to the uploaded image rather than to the 256x256 working copy.
    original_size = input_image.size
    resized = input_image.resize((256, 256), Image.LANCZOS)
    image = vis_processors["eval"](resized).unsqueeze(0).to(device)
    text_input = txt_processors["eval"](prompt)
    sample = {"image": image, "text_input": [text_input]}

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        features_multimodal = model.extract_features(sample, mode="multimodal")

    # NOTE(review): multimodal_embeds looks like a per-token embedding tensor
    # rather than a spatial map, so this thresholded array is probably not a
    # pixel-aligned mask of the image -- confirm this is the intended output.
    embeds = features_multimodal.multimodal_embeds.squeeze().cpu().numpy()
    mask = np.where(embeds > 0.3, 255, 0).astype(np.uint8)

    # NEAREST keeps the resized array strictly binary (0/255); the default
    # resampling filter would introduce intermediate grey values.
    mask_image = Image.fromarray(mask).resize(original_size, Image.NEAREST)
    return np.array(mask_image)

if __name__ == '__main__':
    # Use the top-level component API; the legacy ``gr.inputs`` namespace
    # was deprecated in Gradio 3 and removed in Gradio 4.
    input_image = gr.Image(label='image', type='pil')
    prompt = gr.Textbox(label='Prompt')
    # gr.Interface has no input_size/output_size parameters (they were never
    # applied); process() performs its own resizing.
    iface = gr.Interface(
        fn=process,
        inputs=[input_image, prompt],
        outputs="image",
    )
    iface.launch()