swimmiing's picture
Fix url v2
0ab5cdb
raw
history blame
3.04 kB
import gradio as gr
import torch
import numpy as np
from modules.models import *
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
from PIL import Image
def greet(image, audio):
device = torch.device('cpu')
# Get model
model_conf_file = f'./config/model/ACL_ViT16.yaml'
model = ACL(model_conf_file, device)
model.train(False)
model.load('./pretrain/Param_best.pth')
# Get placeholder text
prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()
# Input pre processing
sample_rate, audio = audio
audio = audio.astype(np.float32, order='C') / 32768.0
desired_sample_rate = 16000
set_length = 10
audio_file = torch.from_numpy(audio)
if desired_sample_rate != sample_rate:
audio_file = torchaudio.functional.resample(audio_file, sample_rate, desired_sample_rate)
if audio_file.shape[0] == 2:
audio_file = torch.concat([audio_file[0], audio_file[1]], dim=0) # Stereo -> mono (x2 duration)
audio_file.squeeze(0)
if audio_file.shape[0] > (desired_sample_rate * set_length):
audio_file = audio_file[:desired_sample_rate * set_length]
# zero padding
if audio_file.shape[0] < (desired_sample_rate * set_length):
pad_len = (desired_sample_rate * set_length) - audio_file.shape[0]
pad_val = torch.zeros(pad_len)
audio_file = torch.cat((audio_file, pad_val), dim=0)
audio_file = audio_file.unsqueeze(0)
image_transform = vt.Compose([
vt.Resize((352, 352), vt.InterpolationMode.BICUBIC),
vt.ToTensor(),
vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), # CLIP
])
image_file = image_transform(image).unsqueeze(0)
# Inference
placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens, text_pos_at_prompt,
prompt_length)
# Localization result
out_dict = model(image_file.to(model.device), audio_driven_embedding, 352)
seg = out_dict['heatmap'][0:1]
seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
seg_image = Image.fromarray(seg_image)
heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
overlaid_image = cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)
return overlaid_image
title = "Audio-Grounded Contrastive Learning"
description = "Hello"
article = """<p style='text-align: center'><a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source
Localization?</a> | <a href='https://github.com/swimmiing/ACL-SSL'>Official Github Repository</a></p>"""
demo = gr.Interface(
fn=greet,
inputs=[gr.Image(type='pil'), gr.Audio()],
outputs=gr.Image(type="pil"),
title=title,
description=description,
article=article,
)
demo.launch(debug=True)