import gradio as gr
import torch
import numpy as np
from modules.models import *  # presumably provides ACL
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
from PIL import Image
import cv2
from functools import lru_cache

# Inference configuration shared by preprocessing and the model call.
_MODEL_CONFIG_PATH = './config/model/ACL_ViT16.yaml'
_MODEL_WEIGHTS_PATH = './pretrain/Param_best.pth'
_IMAGE_SIZE = 352             # square input resolution; also passed to the model forward
_DESIRED_SAMPLE_RATE = 16000  # Hz
_SET_LENGTH = 10              # seconds of audio fed to the model

# Bicubic resize to _IMAGE_SIZE x _IMAGE_SIZE, then normalization with CLIP's
# mean/std. Built once at import time instead of on every request.
_IMAGE_TRANSFORM = vt.Compose([
    vt.Resize((_IMAGE_SIZE, _IMAGE_SIZE), vt.InterpolationMode.BICUBIC),
    vt.ToTensor(),
    vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
])


@lru_cache(maxsize=1)
def _load_model():
    """Build the ACL model on CPU, switch it to eval mode and load the checkpoint.

    Cached so the checkpoint is read once per process rather than on every
    request (the original rebuilt the model inside ``greet``).
    NOTE(review): assumes ACL keeps no per-call state between inferences;
    confirm in modules.models.
    """
    model = ACL(_MODEL_CONFIG_PATH, torch.device('cpu'))
    model.train(False)
    model.load(_MODEL_WEIGHTS_PATH)
    return model


def _pcm_to_float32(samples):
    """Convert a numpy array of audio samples to float32 in roughly [-1, 1).

    Signed integers are divided by 2**(bits - 1), i.e. 32768 for int16, which
    matches the original hard-coded scale. Unsigned integers are offset-binary
    and are re-centred first. Floating-point input is assumed to be already
    normalised and is only cast.
    """
    if np.issubdtype(samples.dtype, np.signedinteger):
        scale = float(np.iinfo(samples.dtype).max) + 1.0
        return samples.astype(np.float32, order='C') / scale
    if np.issubdtype(samples.dtype, np.unsignedinteger):
        mid = (float(np.iinfo(samples.dtype).max) + 1.0) / 2.0
        return (samples.astype(np.float32, order='C') - mid) / mid
    return samples.astype(np.float32, order='C')


def _prepare_audio(sample_rate, samples):
    """Turn Gradio audio samples into a ``(1, _DESIRED_SAMPLE_RATE * _SET_LENGTH)`` tensor.

    The clip is scaled to float32 and reduced to a single channel. It is then
    resampled to ``_DESIRED_SAMPLE_RATE``, and finally truncated or zero-padded
    to exactly ``_SET_LENGTH`` seconds.
    2-D input is treated as ``(num_samples, num_channels)``.
    """
    waveform = torch.from_numpy(_pcm_to_float32(samples))
    if waveform.ndim == 2 and waveform.shape[1] == 1:
        # (N, 1) single-channel clip: the two-channel path below would fail
        # because there is no second channel to concatenate.
        waveform = waveform[:, 0]
    if waveform.ndim == 2:
        # Stereo -> mono (x2 duration): left channel followed by right channel,
        # as in the original demo; channels beyond the second are ignored.
        # NOTE(review): averaging channels is the conventional downmix; kept
        # as-is because changing it alters the model input.
        waveform = torch.cat([waveform[:, 0], waveform[:, 1]], dim=0)
    waveform = waveform.unsqueeze(0)  # (1, num_samples) for resample
    if sample_rate != _DESIRED_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, _DESIRED_SAMPLE_RATE)
    waveform = waveform.squeeze(0)

    target_len = _DESIRED_SAMPLE_RATE * _SET_LENGTH
    waveform = waveform[:target_len]
    if waveform.shape[0] < target_len:
        # zero padding up to the fixed length
        pad_val = torch.zeros(target_len - waveform.shape[0])
        waveform = torch.cat((waveform, pad_val), dim=0)
    return waveform.unsqueeze(0)


def _overlay_heatmap(image, heatmap):
    """Blend the first heatmap in ``heatmap`` 50/50 with ``image``.

    Returns a uint8 numpy array the same size as ``image``.
    """
    seg = heatmap[0:1].squeeze().detach().cpu().numpy()
    # Clip so that out-of-range values cannot wrap around in the uint8 cast.
    # Values already in [0, 1] are unaffected.
    seg_image = (np.clip(1 - seg, 0.0, 1.0) * 255).astype(np.uint8)
    seg_image = Image.fromarray(seg_image).resize(image.size, Image.BICUBIC)
    # cv2.applyColorMap returns BGR while np.array(image) is RGB, so the colors
    # are channel-swapped. Combined with the (1 - seg) inversion above, large
    # heatmap values end up rendered in red. Keep both, or neither.
    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    return cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)


def greet(image, audio):
    """Localize the sound source of ``audio`` inside ``image``.

    Args:
        image: PIL image from ``gr.Image(type='pil')``.
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio()``.

    Returns:
        A uint8 numpy array: ``image`` blended 50/50 with a JET-colored heatmap.

    Raises:
        gr.Error: if the image or the audio is missing.
    """
    if image is None or audio is None:
        raise gr.Error('Please provide both an image and an audio clip.')
    image = image.convert('RGB')  # RGBA / greyscale would break Normalize and addWeighted

    model = _load_model()

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input pre processing
    sample_rate, samples = audio
    audio_tensor = _prepare_audio(sample_rate, samples)
    image_tensor = _IMAGE_TRANSFORM(image).unsqueeze(0)

    # Inference (no autograd graph is needed for a forward-only demo)
    with torch.no_grad():
        placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
        audio_driven_embedding = model.encode_audio(
            audio_tensor.to(model.device), placeholder_tokens, text_pos_at_prompt, prompt_length
        )
        # Localization result
        out_dict = model(image_tensor.to(model.device), audio_driven_embedding, _IMAGE_SIZE)

    return _overlay_heatmap(image, out_dict['heatmap'])


title = "Audio-Grounded Contrastive Learning"
# Blank lines separate paragraphs, since Gradio renders the description as Markdown.
description = """
This is a simple demo of our WACV'24 paper 'Can CLIP Help Sound Source Localization?', zero-shot visual sound localization.

To use it simply upload an image and corresponding audio to mask (identify in the image), or use one of the examples below and click ‘submit’.

Results will show up in a few seconds.

It is recommended to use audio sources with a sample rate of 16 kHz or higher, and the model does not utilize audio beyond the initial 10 seconds.
"""
# NOTE(review): this footer presumably linked to the paper and the GitHub repo
# before the markup was lost; restore the hyperlinks.
article = "Can CLIP Help Sound Source Localization? | Official Github repo"
examples = [
    ['./asset/web_image1.jpeg', './asset/web_dog_barking.wav'],
    ['./asset/web_image1.jpeg', './asset/web_child_laugh.wav'],
    ['./asset/web_image1.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image1.jpeg', './asset/web_motorcycle_pass_by.wav'],
    ['./asset/web_image2.jpeg', './asset/web_dog_barking.wav'],
    ['./asset/web_image2.jpeg', './asset/web_female_speech.wav'],
    ['./asset/web_image2.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image3.jpeg', './asset/web_motorcycle_pass_by.wav'],
    ['./asset/web_image3.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image3.jpeg', './asset/web_wave.wav'],
    ['./asset/web_image4.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image4.jpeg', './asset/web_wave.wav'],
    ['./asset/web_image4.jpeg', './asset/web_horse.wav'],
]
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio()],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch(debug=True)