File size: 2,529 Bytes
abd35ef
 
4b08e6e
baf1626
4b08e6e
ee86500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abd35ef
4b08e6e
baf1626
 
ee86500
 
baf1626
 
 
 
ee86500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b08e6e
baf1626
 
4b08e6e
 
 
 
 
 
 
 
 
ee86500
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import threading

import gradio as gr
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from lavis.models import load_model_and_preprocess
from PIL import Image

def create_heatmap(activation_map):
    """Render an activation map as an RGBA image using the ``jet`` colormap.

    Args:
        activation_map: a ``torch.Tensor`` (on any device, possibly part of
            an autograd graph) or any array-like of activation values.
            Singleton dimensions are squeezed away.

    Returns:
        ``np.uint8`` array of shape ``(*squeezed_shape, 4)`` holding RGBA
        values in [0, 255]. The map is min-max normalized first, so its
        minimum lands at the low end of ``jet`` and its maximum at the high
        end. A constant map (zero range) is rendered entirely at the low end
        instead of producing NaNs.
    """
    # Detach from autograd and move to host memory before converting, so
    # GPU tensors and tensors that require grad are both accepted.
    if isinstance(activation_map, torch.Tensor):
        activation_map = activation_map.detach().cpu().numpy()
    activation_map_np = np.squeeze(np.asarray(activation_map))

    min_value = np.min(activation_map_np)
    value_range = np.max(activation_map_np) - min_value

    if value_range == 0:
        # (x - min) / 0 would be NaN everywhere, which np.uint8 cannot
        # represent; a map without contrast is drawn at the colormap's floor.
        normalized_map = np.zeros(activation_map_np.shape)
    else:
        # Scale into [0, 1], the domain the colormap expects.
        normalized_map = (activation_map_np - min_value) / value_range

    # Colormaps return float RGBA in [0, 1]; scale to 8-bit for PIL/Gradio.
    heatmap = cm.jet(normalized_map)
    return np.uint8(255 * heatmap)

# Loaded PNP-VQA models keyed by device string, so the model and its
# processors are built once per process rather than once per request.
_MODEL_CACHE = {}

# Guards model loading and inference. The cached model is shared by every
# Gradio worker thread. NOTE(review): LAVIS's Grad-CAM presumably records
# attention maps/gradients on the model's own modules, so concurrent
# forward_itm calls on one model could interfere -- confirm before relaxing.
_MODEL_LOCK = threading.Lock()


def _load_pnp_vqa(device):
    """Return ``(model, vis_processors, txt_processors)`` for *device*, loading at most once.

    Callers must hold ``_MODEL_LOCK``.
    """
    key = str(device)
    try:
        return _MODEL_CACHE[key]
    except KeyError:
        loaded = load_model_and_preprocess(
            name="pnp_vqa", model_type="base", is_eval=True, device=device
        )
        _MODEL_CACHE[key] = loaded
        return loaded


def _gradcam_to_mask(gradcam, size=(256, 256), threshold=0.5):
    """Convert a 2-D Grad-CAM tensor into a binary ``uint8`` mask of shape *size*.

    The map is min-max normalized to [0, 1], bilinearly upsampled to *size*
    (height, width), and pixels strictly above *threshold* become 255; all
    others become 0. A constant map yields an all-zero mask.
    """
    cam = gradcam.detach().float().cpu()
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    # interpolate() expects (batch, channel, H, W); bilinear keeps the values
    # inside [0, 1], unlike bicubic, which can overshoot.
    cam = torch.nn.functional.interpolate(
        cam[None, None], size=size, mode="bilinear", align_corners=False
    )[0, 0]
    return np.where(cam.numpy() > threshold, 255, 0).astype(np.uint8)


def process(input_image, prompt, threshold=0.5):
    """Segment the part of *input_image* that PNP-VQA's image-text matching ties to *prompt*.

    Args:
        input_image: PIL image supplied by the Gradio ``Image`` input.
        prompt: free-text query, passed through the model's text processor.
        threshold: cut-off in [0, 1] on the min-max-normalized Grad-CAM map;
            pixels above it are marked as foreground.

    Returns:
        ``np.uint8`` array of shape ``(256, 256)`` with 255 for foreground
        and 0 for background.

    Raises:
        ValueError: if no image was provided.
    """
    if input_image is None:
        raise ValueError("An input image is required.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): the eval vis processor presumably resizes again to the
    # model's own input size; this pre-resize is kept to preserve the
    # original model input.
    input_image = input_image.resize((256, 256))

    with _MODEL_LOCK:
        model, vis_processors, txt_processors = _load_pnp_vqa(device)
        image = vis_processors["eval"](input_image).unsqueeze(0).to(device)
        text_input = txt_processors["eval"](prompt)
        sample = {"image": image, "text_input": [text_input]}
        output = model.forward_itm(samples=sample)

    # Assumes the ITM vision backbone yields a 24x24 patch grid for the
    # single image in the batch -- NOTE(review): confirm for other configs.
    activation_map = output['gradcams'].reshape(24, 24)
    return _gradcam_to_mask(activation_map, size=(256, 256), threshold=threshold)

if __name__ == '__main__':
    # Build and serve the demo UI: an image plus a text prompt in, a binary
    # mask image out. The top-level gr.Image component is used because the
    # gr.inputs.* namespace is the deprecated Gradio 2 API (removed in
    # Gradio 4); it also matches gr.Textbox alongside it.
    iface = gr.Interface(
        fn=process,
        inputs=[
            gr.Image(label='image', type='pil'),
            gr.Textbox(label='Prompt'),
        ],
        outputs="image",
    )
    iface.launch()