Spaces:
Running
Running
import torch | |
import numpy as np | |
import gradio as gr | |
from lavis.models import load_model_and_preprocess | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import matplotlib.cm as cm | |
import torchvision | |
def create_heatmap(activation_map):
    """Render an activation map as an RGBA image using the jet colormap.

    Args:
        activation_map: torch tensor; singleton dimensions are squeezed away
            (presumably leaving a 2-D (H, W) map — TODO confirm for new callers).

    Returns:
        ``uint8`` numpy array of shape ``activation_map.squeeze().shape + (4,)``
        holding RGBA values in [0, 255].
    """
    # Move to a host-side numpy array; detach so tensors that require grad work.
    activation_map_np = activation_map.squeeze().detach().cpu().numpy()
    min_value = np.min(activation_map_np)
    value_range = np.max(activation_map_np) - min_value
    # Normalise to [0, 1]. A constant map would make this divide by zero and
    # yield NaNs (rendered as transparent black), so map it to all zeros.
    if value_range > 0:
        normalized_map = (activation_map_np - min_value) / value_range
    else:
        normalized_map = np.zeros_like(activation_map_np)
    # cm.jet returns float RGBA in [0, 1]; scale to [0, 255] bytes.
    heatmap = cm.jet(normalized_map)
    return np.uint8(255 * heatmap)
# Loaded (model, vis_processors, txt_processors) tuples keyed by device, so the
# PNP-VQA weights are loaded once instead of on every request.
_PNP_VQA_CACHE = {}


def _load_pnp_vqa(device):
    """Return the PNP-VQA model and its processors for ``device``, loading them once."""
    key = str(device)
    if key not in _PNP_VQA_CACHE:
        _PNP_VQA_CACHE[key] = load_model_and_preprocess(
            name="pnp_vqa", model_type="base", is_eval=True, device=device
        )
    return _PNP_VQA_CACHE[key]


def _gradcam_to_mask(activation_map, size=(256, 256), threshold=0.5):
    """Convert a 2-D activation tensor into a binary ``uint8`` mask of shape ``size``.

    The map is min-max normalised to [0, 1], bilinearly upsampled to ``size``
    (height, width) and thresholded. Pixels whose normalised activation exceeds
    ``threshold`` become 255 and all others become 0. A constant map yields an
    all-zero mask.
    """
    cam = activation_map.detach().to(dtype=torch.float32, device="cpu")
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    # interpolate expects an (N, C, H, W) tensor.
    cam = torch.nn.functional.interpolate(
        cam[None, None], size=size, mode="bilinear", align_corners=False
    )[0, 0]
    return np.where(cam.numpy() > threshold, 255, 0).astype(np.uint8)


def process(input_image, prompt):
    """Segment the image region that PNP-VQA's image-text GradCAM highlights for ``prompt``.

    Args:
        input_image: PIL image (the Gradio input is configured with ``type='pil'``).
        prompt: free-text query matched against the image.

    Returns:
        ``uint8`` numpy array of shape (256, 256): 255 where the min-max
        normalised GradCAM activation exceeds 0.5, else 0.

    Raises:
        ValueError: if no image was supplied.
    """
    if input_image is None:
        raise ValueError("An input image is required.")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, vis_processors, txt_processors = _load_pnp_vqa(device)
    # NOTE(review): vis_processors["eval"] presumably resizes to the model's own
    # input resolution, so this pre-resize may only discard detail — confirm.
    input_image = input_image.resize((256, 256))
    image = vis_processors["eval"](input_image).unsqueeze(0).to(device)
    text_input = txt_processors["eval"](prompt)
    sample = {"image": image, "text_input": [text_input]}
    # GradCAM is gradient-based, so this call is deliberately not wrapped in
    # torch.no_grad().
    output = model.forward_itm(samples=sample)
    # The single image in the batch yields 24 x 24 = 576 GradCAM values.
    activation_map = output['gradcams'].reshape(24, 24)
    return _gradcam_to_mask(activation_map, size=(256, 256), threshold=0.5)
if __name__ == '__main__':
    # Build and launch the Gradio UI: (image, prompt) -> binary GradCAM mask.
    # gr.inputs.* is the legacy Gradio API (deprecated in 3.x, removed in 4.x);
    # gr.Image matches the gr.Textbox component used alongside it.
    input_image = gr.Image(label='image', type='pil')
    prompt = gr.Textbox(label='Prompt')
    ips = [input_image, prompt]
    outputs = "image"
    iface = gr.Interface(fn=process, inputs=ips, outputs=outputs)
    iface.launch()