# Hugging Face file-view header captured with this source (not part of the program):
# oshita-n's picture | update | ee86500 | raw | history blame | No virus | 2.53 kB
import functools

import gradio as gr
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from lavis.models import load_model_and_preprocess
from PIL import Image
def create_heatmap(activation_map):
    """Render an activation map as an RGBA ``jet`` heatmap.

    Args:
        activation_map: torch tensor of activation values; singleton
            dimensions are squeezed away before rendering.

    Returns:
        ``numpy.uint8`` array of RGBA colours in [0, 255] with shape
        ``activation_map.squeeze().shape + (4,)``.
    """
    # Move the tensor off the autograd graph / device into a NumPy array.
    activation_map_np = activation_map.squeeze().detach().cpu().numpy()

    # Min-max normalise to [0, 1] so the full colormap range is used.
    min_value = np.min(activation_map_np)
    value_range = np.max(activation_map_np) - min_value
    if value_range > 0:
        normalized_map = (activation_map_np - min_value) / value_range
    else:
        # Constant map: dividing by a zero range would yield NaN, which has no
        # meaningful uint8 value; render it as the bottom of the colormap.
        normalized_map = np.zeros_like(activation_map_np, dtype=np.float64)

    # cm.jet maps floats in [0, 1] to RGBA floats in [0, 1]; scale to bytes.
    heatmap = cm.jet(normalized_map)
    return np.uint8(255 * heatmap)
@functools.lru_cache(maxsize=None)  # at most one entry per device type ("cpu"/"cuda")
def _load_pnp_vqa(device_type):
    """Load the LAVIS PnP-VQA base model and its processors once per device type."""
    device = torch.device(device_type)
    model, vis_processors, txt_processors = load_model_and_preprocess(
        name="pnp_vqa", model_type="base", is_eval=True, device=device
    )
    return model, vis_processors, txt_processors, device


def _gradcam_to_mask(activation_map, size=(256, 256), threshold=0.5):
    """Turn a 2-D activation map into a binary 0/255 ``uint8`` mask of ``size``.

    The map is min-max normalised to [0, 1], bilinearly upsampled to
    ``size`` (height, width) and thresholded: pixels whose normalised value
    exceeds ``threshold`` become 255, all others 0.
    """
    act = activation_map.detach().to(device="cpu", dtype=torch.float32)
    act = act - act.min()
    peak = act.max()
    if peak > 0:
        act = act / peak
    # else: constant map -> all zeros, so nothing exceeds a non-negative threshold.
    upsampled = torch.nn.functional.interpolate(
        act[None, None], size=size, mode="bilinear", align_corners=False
    )[0, 0]
    return np.where(upsampled.numpy() > threshold, 255, 0).astype(np.uint8)


def process(input_image, prompt):
    """Mask the regions of *input_image* that PnP-VQA's GradCAM associates with *prompt*.

    Args:
        input_image: PIL image from the Gradio ``Image(type='pil')`` input.
        prompt: free-text query matched against the image.

    Returns:
        ``numpy.uint8`` array of shape (256, 256): 255 where the min-max
        normalised GradCAM, upsampled to 256x256, exceeds 0.5, else 0.

    Raises:
        ValueError: if no image was provided.
    """
    if input_image is None:
        raise ValueError("An input image is required.")

    # Loading the model is the expensive step; it is cached across requests.
    model, vis_processors, txt_processors, device = _load_pnp_vqa(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    input_image = input_image.resize((256, 256))
    image = vis_processors["eval"](input_image).unsqueeze(0).to(device)
    text_input = txt_processors["eval"](prompt)
    sample = {"image": image, "text_input": [text_input]}
    output = model.forward_itm(samples=sample)

    # The GradCAM is reshaped to a 24x24 grid (presumably the ViT patch grid of
    # the base model -- TODO confirm if a different model_type is used).
    activation_map = output["gradcams"].reshape(24, 24)

    # Threshold the normalised activations themselves rather than a colour-mapped
    # rendering of them, so the strongest activations always land inside the mask.
    return _gradcam_to_mask(activation_map, size=(256, 256), threshold=0.5)
if __name__ == '__main__':
    # Gradio UI: an image and a text prompt go in, the image returned by
    # `process` comes out.
    iface = gr.Interface(
        fn=process,
        inputs=[
            # gr.Image replaces the deprecated gr.inputs.Image namespace (removed
            # in Gradio 4) and matches the gr.Textbox component used alongside it.
            gr.Image(label='image', type='pil'),
            gr.Textbox(label='Prompt'),
        ],
        outputs="image",
    )
    iface.launch()