# speech / app.py
# Author: seu-ebers (Hugging Face Space) — commit 70f1891, 2.71 kB.
import streamlit as st
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, MBart50TokenizerFast, MBartForConditionalGeneration, AutoProcessor, AutoModel
from PIL import Image
import requests
from IPython.display import Audio
import sys
from PIL import Image
# Hugging Face checkpoint for the English -> Portuguese mBART-50 translation model.
ckpt = 'Narrativa/mbart-large-50-finetuned-opus-en-pt-translation'


# Streamlit re-executes this script from the top on every session/rerun; caching the
# loader with st.cache_resource keeps one set of model objects per process instead of
# re-instantiating every model on each run.
@st.cache_resource
def _load_models():
    """Load every processor/model used by the app.

    Returns:
        A tuple ``(image_processor, image_to_text_model, tokenizer,
        translation_model, audio_processor, audio_model)``.
    """
    # Image to Text (GIT captioning) model
    caption_processor = AutoProcessor.from_pretrained("sezenkarakus/image-GIT-description-model-v3")
    caption_model = AutoModelForCausalLM.from_pretrained("sezenkarakus/image-GIT-description-model-v3")
    # Translation model; source language is English
    mbart_tokenizer = MBart50TokenizerFast.from_pretrained(ckpt)
    mbart_model = MBartForConditionalGeneration.from_pretrained(ckpt)
    mbart_tokenizer.src_lang = 'en_XX'
    # Audio (Bark) model — only referenced by the commented-out audio section at
    # the end of this file, but still loaded so those names remain available.
    bark_processor = AutoProcessor.from_pretrained("suno/bark-small")
    bark_model = AutoModel.from_pretrained("suno/bark-small")
    return (caption_processor, caption_model, mbart_tokenizer, mbart_model,
            bark_processor, bark_model)


# Keep the original module-level names that the rest of the script uses.
(image_processor, image_to_text_model, tokenizer, translation_model,
 audio_processor, audio_model) = _load_models()
# Methods
def generate_caption(image):
    """Describe *image* in English with the GIT captioning model.

    Args:
        image: Input accepted by ``image_processor`` (the script passes a PIL image).

    Returns:
        The first decoded caption, generated with ``max_length=200`` and with
        special tokens stripped.
    """
    encoded = image_processor(images=image, return_tensors="pt")
    token_ids = image_to_text_model.generate(
        pixel_values=encoded.pixel_values,
        max_length=200,
    )
    captions = image_processor.batch_decode(token_ids, skip_special_tokens=True)
    return captions[0]
def translate(text, target_lang='pt_XX'):
    """Translate English *text* with the mBART-50 translation model.

    Runs on CUDA when PyTorch reports a usable GPU, otherwise on CPU (printing a
    notice). Unlike a bare ``except:``, genuine errors (bad input, interrupts,
    out-of-memory) are no longer mistaken for "no GPU" and silently swallowed.

    Args:
        text: English source text (the tokenizer's ``src_lang`` is ``'en_XX'``).
        target_lang: mBART-50 language code to translate into. Defaults to
            ``'pt_XX'``, the language this app was built for.

    Returns:
        The decoded translation with special tokens removed.

    Raises:
        KeyError: if *target_lang* is not in ``tokenizer.lang_code_to_id``.
    """
    import torch  # local import: only needed here for the device check

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        print('No NVidia GPU, model performance may not be as good')
        device = 'cpu'

    inputs = tokenizer(text, return_tensors='pt')
    # Model and both input tensors are moved to the same device together, so they
    # can never end up split across CPU and GPU.
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    model = translation_model.to(device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # The forced first token selects the output language for mBART-50.
        forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Load a sample image over HTTP (alternative sample URLs are kept commented out).
import io

img_url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# img_url = 'https://farm4.staticflickr.com/3733/9000662079_ce3599d0d8_z.jpg'
# img_url = 'https://farm4.staticflickr.com/3088/5793281956_2a15b2559c_z.jpg'
# img_url = 'https://farm5.staticflickr.com/4073/4816939054_844feb0078_z.jpg'
# A timeout keeps the app from hanging forever on an unresponsive host.
response = requests.get(img_url, timeout=30)
# Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
response.raise_for_status()
# Use the decoded body (.content) rather than the undecoded socket stream (.raw),
# and normalise to RGB so grayscale/RGBA sources yield 3-channel input.
image = Image.open(io.BytesIO(response.content)).convert('RGB')
# Generate using models
# Generate English text from the image
caption = generate_caption(image)
print(caption)
# Translate the caption (English -> Portuguese)
translated_caption = translate(caption)
print(translated_caption)
# # Generate Audio
# inputs = audio_processor(
# text=caption,
# return_tensors="pt",
# )
#
# speech_values = audio_model.generate(**inputs, do_sample=True)
#
# sampling_rate = audio_model.generation_config.sample_rate
# Audio(speech_values.cpu().numpy().squeeze(), rate=sampling_rate)