swimmiing's picture
Replace audio
6c8dc4b
raw
history blame contribute delete
No virus
4.62 kB
import functools

import gradio as gr
import torch
import numpy as np
from modules.models import *
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
from PIL import Image
import cv2
# Model assets used by the demo; paths are relative to the working directory.
MODEL_CONF_FILE = './config/model/ACL_ViT16.yaml'
MODEL_WEIGHTS_FILE = './pretrain/Param_best.pth'

# Audio is resampled to this rate (Hz) and cropped / zero-padded to this many seconds.
DESIRED_SAMPLE_RATE = 16000
SET_LENGTH = 10

# Side length the input image is resized to; also passed to the model as the
# heatmap resolution (assumed meaning of the model's third argument — TODO confirm).
IMAGE_SIZE = 352


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the ACL model on CPU, load its checkpoint, and cache the result.

    Cached so the weights are read once instead of on every request.
    """
    device = torch.device('cpu')
    model = ACL(MODEL_CONF_FILE, device)
    model.train(False)  # evaluation mode
    model.load(MODEL_WEIGHTS_FILE)
    return model


def _to_float_waveform(samples):
    """Convert raw PCM samples to a float32 array scaled to roughly [-1, 1].

    Signed integer PCM is divided by 2**(bits-1) (32768 for int16), unsigned
    PCM (e.g. 8-bit WAV) is re-centred around zero first, and floating-point
    input is assumed to be already normalized.
    """
    if np.issubdtype(samples.dtype, np.signedinteger):
        scale = float(np.iinfo(samples.dtype).max) + 1.0
        return samples.astype(np.float32, order='C') / scale
    if np.issubdtype(samples.dtype, np.unsignedinteger):
        half = (float(np.iinfo(samples.dtype).max) + 1.0) / 2.0
        return (samples.astype(np.float32, order='C') - half) / half
    return samples.astype(np.float32, order='C')


def _preprocess_audio(sample_rate, samples):
    """Turn a raw clip into a (1, DESIRED_SAMPLE_RATE * SET_LENGTH) float tensor.

    Multi-channel audio is averaged to mono, resampled to DESIRED_SAMPLE_RATE,
    then truncated or zero-padded to exactly SET_LENGTH seconds.
    """
    waveform = torch.from_numpy(_to_float_waveform(samples))
    if waveform.ndim == 2:
        # (num_samples, num_channels) layout, as delivered by gr.Audio in numpy mode.
        waveform = waveform.mean(dim=1)
    if sample_rate != DESIRED_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, DESIRED_SAMPLE_RATE)
    target_len = DESIRED_SAMPLE_RATE * SET_LENGTH
    waveform = waveform[:target_len]
    if waveform.shape[0] < target_len:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.shape[0]))
    return waveform.unsqueeze(0)


def _preprocess_image(image):
    """Resize an RGB PIL image to IMAGE_SIZE, normalize with CLIP statistics, add a batch dim."""
    transform = vt.Compose([
        vt.Resize((IMAGE_SIZE, IMAGE_SIZE), vt.InterpolationMode.BICUBIC),
        vt.ToTensor(),
        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
    ])
    return transform(image).unsqueeze(0)


def _overlay_heatmap(image, heatmap):
    """Blend a JET-colored heatmap 50/50 over an RGB PIL image.

    Returns a uint8 (H, W, 3) RGB array the size of *image*.
    """
    # NOTE(review): assumes heatmap scores lie in [0, 1]; clipping keeps
    # out-of-range values from wrapping around in the uint8 cast.
    heat = (np.clip(heatmap, 0.0, 1.0) * 255).astype(np.uint8)
    heat = Image.fromarray(heat).resize(image.size, Image.BICUBIC)
    colored = cv2.applyColorMap(np.array(heat), cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so channels line up with the RGB image.
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(np.array(image), 0.5, colored, 0.5, 0)


def greet(image, audio):
    """Localize the source of *audio* in *image* and return a heatmap overlay.

    Args:
        image: PIL image from ``gr.Image(type='pil')``; any mode, converted to RGB.
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio`` in numpy mode.

    Returns:
        uint8 RGB array the size of *image*, blended with the localization heatmap.

    Raises:
        gr.Error: if the image or the audio input is missing.
    """
    if image is None or audio is None:
        raise gr.Error('Please provide both an image and an audio clip.')
    # RGBA / grayscale inputs would break the 3-channel normalization and blend.
    image = image.convert('RGB')
    sample_rate, samples = audio

    model = _load_model()
    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    audio_tensor = _preprocess_audio(sample_rate, samples)
    image_tensor = _preprocess_image(image)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
        audio_driven_embedding = model.encode_audio(audio_tensor.to(model.device), placeholder_tokens,
                                                    text_pos_at_prompt, prompt_length)
        out_dict = model(image_tensor.to(model.device), audio_driven_embedding, IMAGE_SIZE)

    heatmap = out_dict['heatmap'][0:1].squeeze().detach().cpu().numpy()
    return _overlay_heatmap(image, heatmap)
# Text rendered above (title, description) and below (article) the Gradio interface.
title = "Audio-Grounded Contrastive Learning"
description = """<p>
This is a simple demo of our WACV'24 paper 'Can CLIP Help Sound Source Localization?', zero-shot visual sound localization.<br><br>
To use it, simply upload an image and the corresponding audio to localize its sound source in the image, or use one of the examples below and click ‘submit’.<br><br>
Results will show up in a few seconds. <br><br>
It is recommended to use audio sources with a sample rate of 16 kHz or higher, and the model does not utilize audio beyond the initial 10 seconds.
</p>"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source Localization?</a> | <a href='https://github.com/swimmiing/ACL-SSL'>Official Github repo</a></p>"

# Clickable [image, audio] example pairs; paths are relative to the working directory.
examples = [
    ['./asset/web_image1.jpeg', './asset/web_dog_barking.wav'],
    ['./asset/web_image1.jpeg', './asset/web_child_laugh.wav'],
    ['./asset/web_image1.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image1.jpeg', './asset/web_motorcycle_pass_by.wav'],
    ['./asset/web_image2.jpeg', './asset/web_dog_barking.wav'],
    ['./asset/web_image2.jpeg', './asset/web_female_speech.wav'],
    ['./asset/web_image2.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image3.jpeg', './asset/web_motorcycle_pass_by.wav'],
    ['./asset/web_image3.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image3.jpeg', './asset/web_wave.wav'],
    ['./asset/web_image4.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image4.jpeg', './asset/web_wave.wav'],
    ['./asset/web_image4.jpeg', './asset/web_horse.wav'],
]
# Image + audio in, overlaid heatmap image out. gr.Audio is pinned to numpy
# mode because greet() unpacks a (sample_rate, samples) tuple.
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio(type='numpy')],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Launch only when run as a script, so importing this module (e.g. from tests
# or tooling) does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)