import gradio as gr
import torch
import numpy as np
from modules.models import *
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
import cv2
from PIL import Image


def greet(image, audio):
    device = torch.device('cpu')

    # Get model
    model_conf_file = './config/model/ACL_ViT16.yaml'
    model = ACL(model_conf_file, device)
    model.train(False)
    model.load('./pretrain/Param_best.pth')

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input pre-processing
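    # The model expects a mono waveform at 16 kHz with a fixed 10-second length:
    # stereo channels are concatenated, the signal is resampled, then truncated or zero-padded.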
    sample_rate, audio = audio
    audio = audio.astype(np.float32, order='C') / 32768.0
    desired_sample_rate = 16000
    set_length = 10

    audio_file = torch.from_numpy(audio)

    if audio_file.ndim == 2:
        # Gradio delivers stereo as (samples, channels); concatenate the channels end to end (x2 duration)
        audio_file = torch.cat([audio_file[:, 0], audio_file[:, 1]], dim=0)

    if desired_sample_rate != sample_rate:
        audio_file = torchaudio.functional.resample(audio_file, sample_rate, desired_sample_rate)

    if audio_file.shape[0] > (desired_sample_rate * set_length):
        audio_file = audio_file[:desired_sample_rate * set_length]

    # zero padding
    if audio_file.shape[0] < (desired_sample_rate * set_length):
        pad_len = (desired_sample_rate * set_length) - audio_file.shape[0]
        pad_val = torch.zeros(pad_len)
        audio_file = torch.cat((audio_file, pad_val), dim=0)

    audio_file = audio_file.unsqueeze(0)  # add batch dimension

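    # Image pre-processing: resize to 352x352 and normalize with CLIP's pixel statistics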
    image_transform = vt.Compose([
        vt.Resize((352, 352), vt.InterpolationMode.BICUBIC),
        vt.ToTensor(),
        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
    ])

    image_file = image_transform(image).unsqueeze(0)

    # Inference
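    # The audio is encoded into an embedding that fills the placeholder slot of the text prompt,
    # acting as an audio-driven query in CLIP's text embedding space.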
    placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
    audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens, text_pos_at_prompt,
                                                prompt_length)

    # Localization result
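    # out_dict['heatmap'] holds the predicted localization map at the 352x352 inference resolution.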
    out_dict = model(image_file.to(model.device), audio_driven_embedding, 352)
    seg = out_dict['heatmap'][0:1]
    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    seg_image = Image.fromarray(seg_image)
    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    heatmap_image = cv2.cvtColor(heatmap_image, cv2.COLOR_BGR2RGB)  # applyColorMap returns BGR; Gradio works in RGB
    heatmap_image = cv2.resize(heatmap_image, image.size)  # match the uploaded image resolution (PIL size is (W, H))
    overlaid_image = cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)

    return overlaid_image


title = "Audio-Grounded Contrastive Learning"

description = """<p>
This is a simple demo of Can CLIP Help Sound Source Localization? (WACV2024), zero-shot visual sound localization.<br><br>
To use it simply upload an image and corresponding audio to mask (identify in the image), or use one of the examples below and click ‘submit’.
</p>"""

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source Localization?</a> | <a href='https://github.com/swimmiing/ACL-SSL'>Offical Github repo</a></p>"

demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio()],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
)

demo.launch(debug=True)