# File: app.py (3.11 kB)
# Last commit: a4204a9 "Update app.py" (ronniet)
# Viewer status: No virus
import functools

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import pipeline
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from transformers import AutoProcessor, AutoModelForCausalLM
# Text-to-speech stack: the SpeechT5 processor and model are loaded from the
# same checkpoint; the HiFi-GAN vocoder comes from a separate one.
checkpoint = "microsoft/speecht5_tts"
tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Question-answering stack: processor and causal LM share one checkpoint.
vqa_checkpoint = "ronniet/git-large-vqa-env"
vqa_processor = AutoProcessor.from_pretrained(vqa_checkpoint)
vqa_model = AutoModelForCausalLM.from_pretrained(vqa_checkpoint)
# Sample rate (Hz) reported with every waveform returned by `tts`.
TTS_SAMPLE_RATE = 16000

# Speaker embedding file used for every utterance (relative to the working directory).
SPEAKER_EMBEDDING_PATH = "cmu_us_bdl_arctic-wav-arctic_a0009.npy"


@functools.lru_cache(maxsize=None)
def _load_speaker_embedding(path):
    """Read the speaker embedding stored at *path*, touching the disk once per path.

    The returned array is shared between calls and must not be modified in place.
    """
    return np.load(path)


def tts(text):
    """Synthesize *text* to speech with the SpeechT5 model and vocoder.

    Args:
        text: Text to speak. ``None``, empty and whitespace-only values
            produce an empty clip instead of invoking the model.

    Returns:
        A ``(sample_rate, samples)`` tuple where ``samples`` is a NumPy
        ``int16`` array.
    """
    if text is None or not text.strip():
        return (TTS_SAMPLE_RATE, np.zeros(0, dtype=np.int16))

    inputs = tts_processor(text=text, return_tensors="pt")

    # Truncate to the model's configured maximum number of text positions.
    input_ids = inputs["input_ids"][..., : tts_model.config.max_text_positions]

    # One fixed speaker is used for all output. torch.tensor() copies the
    # cached array, so the cache entry is never aliased by the model input.
    speaker_embedding = torch.tensor(
        _load_speaker_embedding(SPEAKER_EMBEDDING_PATH)
    ).unsqueeze(0)

    speech = tts_model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)

    # Clip before scaling: a sample outside [-1, 1] would otherwise overflow
    # the int16 range during the cast and come out as a loud glitch.
    samples = np.clip(speech.numpy(), -1.0, 1.0)
    return (TTS_SAMPLE_RATE, (samples * 32767).astype(np.int16))
# captioner = pipeline(model="microsoft/git-base")
# tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)
def predict(image, prompt):
    """Answer *prompt* about *image* and synthesize the answer as speech.

    Args:
        image: Picture to describe; passed to the VQA processor as ``images``.
        prompt: Question (or text prefix) the model continues. ``None`` is
            treated as an empty prompt.

    Returns:
        ``(text, audio)``: the generated answer without the prompt, and the
        value returned by ``tts`` for that answer.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Surface a readable message in the UI instead of a processor traceback.
        raise gr.Error("Please provide an image of the environment.")
    prompt = prompt or ""

    pixel_values = vqa_processor(images=image, return_tensors="pt").pixel_values

    # The text input is [CLS] followed by the prompt tokens, with no other
    # special tokens added by the tokenizer.
    prompt_token_ids = vqa_processor(text=prompt, add_special_tokens=False).input_ids
    prompt_ids = torch.tensor(
        [[vqa_processor.tokenizer.cls_token_id, *prompt_token_ids]]
    )

    generated_ids = vqa_model.generate(
        pixel_values=pixel_values, input_ids=prompt_ids, max_length=50
    )

    # generate() returns the prompt followed by the continuation. Drop the
    # prompt by token count rather than by len(prompt) characters: the decoded
    # prompt can differ in length from the raw one (whitespace, dropped unknown
    # tokens, re-spaced punctuation), which would cut into or pad the answer.
    answer_ids = generated_ids[:, prompt_ids.shape[1]:]
    text = vqa_processor.batch_decode(answer_ids, skip_special_tokens=True)[0].strip()

    audio = tts(text)
    return text, audio
# Gradio UI: an image plus a text prompt go in; the caption text and its
# spoken version come out. The Soft theme is used as-is; overrides such as
# button_primary_background_fill could be applied through the theme's .set().
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Environment"),
        gr.Textbox(label="Prompt", value="What is in the scene?"),
    ],
    outputs=[
        gr.Textbox(label="Caption"),
        gr.Audio(type="numpy", label="Audio Feedback"),
    ],
    css=".gradio-container {background-color: #002A5B}",
    theme=gr.themes.Soft(),
)

demo.launch()