import cv2
import gradio as gr
import numpy as np
import torch
import torchaudio
from PIL import Image
from torchvision import transforms as vt

from modules.models import *
from util import get_prompt_template


def greet(image, audio):
    device = torch.device('cpu')

    # Get model
    model_conf_file = './config/model/ACL_ViT16.yaml'
    model = ACL(model_conf_file, device)
    model.train(False)
    model.load('./pretrain/Param_best.pth')

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input pre-processing: gr.Audio yields (sample_rate, np.ndarray) with int16 PCM samples
    sample_rate, audio = audio
    audio = audio.astype(np.float32, order='C') / 32768.0  # int16 -> float32 in [-1, 1]

    desired_sample_rate = 16000
    set_length = 10  # seconds

    audio_file = torch.from_numpy(audio)
    if desired_sample_rate != sample_rate:
        audio_file = torchaudio.functional.resample(audio_file, sample_rate, desired_sample_rate)

    if audio_file.shape[0] == 2:
        audio_file = torch.concat([audio_file[0], audio_file[1]], dim=0)  # Stereo -> mono (x2 duration)
    audio_file = audio_file.squeeze(0)

    # Truncate to set_length seconds
    if audio_file.shape[0] > (desired_sample_rate * set_length):
        audio_file = audio_file[:desired_sample_rate * set_length]

    # Zero-pad to set_length seconds
    if audio_file.shape[0] < (desired_sample_rate * set_length):
        pad_len = (desired_sample_rate * set_length) - audio_file.shape[0]
        pad_val = torch.zeros(pad_len)
        audio_file = torch.cat((audio_file, pad_val), dim=0)

    audio_file = audio_file.unsqueeze(0)

    image_transform = vt.Compose([
        vt.Resize((352, 352), vt.InterpolationMode.BICUBIC),
        vt.ToTensor(),
        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP stats
    ])
    image_file = image_transform(image).unsqueeze(0)

    # Inference
    placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
    audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens,
                                                text_pos_at_prompt, prompt_length)

    # Localization result
    out_dict = model(image_file.to(model.device), audio_driven_embedding, 352)
    seg = out_dict['heatmap'][0:1]
    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    seg_image = Image.fromarray(seg_image)

    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    heatmap_image = cv2.resize(heatmap_image, image.size)  # match the uploaded image resolution for overlay
    overlaid_image = cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)

    return overlaid_image


title = "Audio-Grounded Contrastive Learning"
description = """

This is a simple demo of Can CLIP Help Sound Source Localization? (WACV 2024), a zero-shot sound source localization method.

To use it, simply upload an image and its corresponding audio to localize the sound source in the image, or use one of the examples below and click 'Submit'.

""" article = "

Can CLIP Help Sound Source Localization? | Offical Github repo

" demo = gr.Interface( fn=greet, inputs=[gr.Image(type='pil'), gr.Audio()], outputs=gr.Image(type="pil"), title=title, description=description, ) demo.launch(debug=True)