"""Gradio demo: overlay an audio-driven localization heatmap on an image."""
import functools

# NOTE: the relative order of the original imports is kept on purpose.
# `from modules.models import *` can re-export arbitrary names, so moving it
# could change which binding wins for a shared name.
import cv2
import gradio as gr
import torch
import numpy as np
from modules.models import *
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
from PIL import Image

MODEL_CONFIG_PATH = './config/model/ACL_ViT16.yaml'
MODEL_WEIGHTS_PATH = './pretrain/Param_best.pth'
TARGET_SAMPLE_RATE = 16000  # Hz; audio is resampled to this rate
CLIP_SECONDS = 10  # audio is truncated / zero-padded to this duration
IMAGE_SIZE = 352  # image resize target; also passed as the model's 3rd argument

# Image preprocessing; the normalization constants are the CLIP mean / std.
_IMAGE_TRANSFORM = vt.Compose([
    vt.Resize((IMAGE_SIZE, IMAGE_SIZE), vt.InterpolationMode.BICUBIC),
    vt.ToTensor(),
    vt.Normalize((0.48145466, 0.4578275, 0.40821073),
                 (0.26862954, 0.26130258, 0.27577711)),  # CLIP
])


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the ACL model on CPU, load its weights, and cache the instance.

    The model is created on first use and reused afterwards, so the
    checkpoint is not re-read from disk on every request.
    NOTE(review): assumes the model keeps no per-request state - confirm.
    """
    device = torch.device('cpu')
    model = ACL(MODEL_CONFIG_PATH, device)
    model.train(False)
    model.load(MODEL_WEIGHTS_PATH)
    return model


def _prepare_audio(sample_rate, samples):
    """Convert Gradio audio into a (1, TARGET_SAMPLE_RATE * CLIP_SECONDS) tensor.

    Args:
        sample_rate: Sampling rate of ``samples`` in Hz.
        samples: NumPy array shaped (samples,) for mono or
            (samples, channels) for multi-channel audio.

    Returns:
        A float32 tensor of shape (1, TARGET_SAMPLE_RATE * CLIP_SECONDS).
    """
    # Scale as in the original code; assumes 16-bit PCM integers - TODO confirm
    # for other sample widths.
    samples = samples.astype(np.float32, order='C') / 32768.0
    waveform = torch.from_numpy(samples)

    # Gradio delivers multi-channel audio as (samples, channels); torchaudio
    # resamples along the last axis, so move time to the last dimension.
    if waveform.ndim == 2:
        waveform = waveform.T.contiguous()

    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, TARGET_SAMPLE_RATE)

    # Stereo -> mono by concatenating the channels back to back (x2 duration),
    # as the original code intended; a mono input is left unchanged.
    waveform = waveform.reshape(-1)

    # Truncate, then zero-pad, to exactly CLIP_SECONDS of audio.
    target_len = TARGET_SAMPLE_RATE * CLIP_SECONDS
    waveform = waveform[:target_len]
    missing = target_len - waveform.shape[0]
    if missing > 0:
        waveform = torch.cat((waveform, torch.zeros(missing)), dim=0)

    return waveform.unsqueeze(0)


def greet(image, audio):
    """Overlay the model's audio-driven heatmap on the input image.

    Args:
        image: PIL image supplied by the ``gr.Image(type='pil')`` input.
        audio: ``(sample_rate, samples)`` tuple supplied by ``gr.Audio()``.

    Returns:
        A uint8 NumPy array (H, W, 3): the input image blended 50/50 with a
        JET-colored heatmap, at the input image's resolution.

    Raises:
        gr.Error: If the image or the audio input is missing.
    """
    if image is None or audio is None:
        raise gr.Error('Please provide both an image and an audio clip.')

    model = _load_model()

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input pre processing
    sample_rate, samples = audio
    audio_file = _prepare_audio(sample_rate, samples)

    # Force 3 channels so the 3-channel Normalize and the blend below also work
    # for RGBA / grayscale uploads.
    rgb_image = image.convert('RGB')
    # Add a batch dimension to match the batched audio tensor.
    # NOTE(review): assumes the model expects (batch, 3, H, W) - confirm.
    image_file = _IMAGE_TRANSFORM(rgb_image).unsqueeze(0)

    # Inference (no gradients are needed for the demo)
    with torch.no_grad():
        placeholder_tokens = model.get_placeholder_token(
            prompt_template.replace('{}', ''))
        audio_driven_embedding = model.encode_audio(
            audio_file.to(model.device), placeholder_tokens,
            text_pos_at_prompt, prompt_length)

        # Localization result
        out_dict = model(image_file.to(model.device), audio_driven_embedding,
                         IMAGE_SIZE)

    # Single-sample batch: take the first (only) heatmap.
    seg = out_dict['heatmap'][0:1]
    # The heatmap is inverted before coloring. applyColorMap emits BGR while
    # the image is RGB, so the inversion presumably keeps strong responses
    # red in the displayed overlay - left exactly as in the original.
    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    # cv2.addWeighted needs both inputs to have the same size, so bring the
    # heatmap to the input image's resolution first.
    seg_image = Image.fromarray(seg_image).resize(rgb_image.size, Image.BICUBIC)

    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    overlaid_image = cv2.addWeighted(np.array(rgb_image), 0.5, heatmap_image, 0.5, 0)

    return overlaid_image


description = 'hello world'

demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio()],
    outputs=gr.Image(type="pil"),
    title='AudioToken',
    description=description,
)

demo.launch(debug=True)