Spaces:
Sleeping
Sleeping
from transformers import pipeline | |
from multilingual_translation import text_to_text_generation | |
from utils import lang_ids | |
import gradio as gr | |
# Hugging Face model ids offered for text generation.
biogpt_model_list = [
    "microsoft/biogpt",
    "microsoft/BioGPT-Large",
    "microsoft/BioGPT-Large-PubMedQA",
]

# Hugging Face model ids offered for translation.
lang_model_list = [
    "facebook/m2m100_1.2B",
    "facebook/m2m100_418M",
]

# Hugging Face model ids offered for speech recognition.
whisper_model_list = [
    "openai/whisper-small",
    "openai/whisper-medium",
    "openai/whisper-tiny",
    "openai/whisper-large",
]

# Language names, i.e. the keys of the imported `lang_ids` mapping.
lang_list = list(lang_ids)
def whisper_demo(input_audio, model_id):
    """Transcribe `input_audio` with the speech-recognition model `model_id`.

    A new ASR pipeline is built on ``cuda:0`` for every call, its decoder
    prompt is forced to English transcription, and the ``'text'`` field of
    the pipeline result is returned.
    """
    asr = pipeline(task="automatic-speech-recognition", model=model_id, device='cuda:0')

    # Pin the decoder prompt to language='en' / task="transcribe".
    prompt_ids = asr.tokenizer.get_decoder_prompt_ids(language='en', task="transcribe")
    asr.model.config.forced_decoder_ids = prompt_ids

    result = asr(input_audio)
    return result['text']
def translate_to_english(prompt, lang_model_id, base_lang):
    """Return `prompt` in English.

    When `base_lang` is ``"English"`` the prompt is returned unchanged.
    Otherwise it is passed to `text_to_text_generation` with
    ``target_lang='en'`` on ``cuda:0`` and the first element of the result
    is returned.
    """
    if base_lang != "English":
        translations = text_to_text_generation(
            prompt=prompt,
            model_id=lang_model_id,
            device='cuda:0',
            target_lang='en',
        )
        return translations[0]
    return prompt
def biogpt_text(
    prompt: str,
    biogpt_model_id: str,
    lang_model_id: str,
    base_lang: str,
):
    """Generate a BioGPT continuation for a text prompt, with optional translation.

    The prompt is first converted to English via `translate_to_english`, then
    continued by a text-generation pipeline (``max_length=250``, one sampled
    sequence) on ``cuda:0``. Unless `base_lang` is ``"English"``, the generated
    text is translated back into the language code ``lang_ids[base_lang]``.

    Args:
        prompt: Text typed by the user, written in `base_lang`.
        biogpt_model_id: Model id handed to the text-generation pipeline.
        lang_model_id: Model id handed to `text_to_text_generation`.
        base_lang: Language name; must be a key of `lang_ids` when not "English".

    Returns:
        A 3-tuple ``(english_prompt, generated_text, translated_text)``. For
        English input the last two items are the same string.
    """
    en_prompt = translate_to_english(prompt, lang_model_id, base_lang)

    generator = pipeline("text-generation", model=biogpt_model_id, device="cuda:0")
    generations = generator(en_prompt, max_length=250, num_return_sequences=1, do_sample=True)
    output = generations[0]['generated_text']

    if base_lang == "English":
        return en_prompt, output, output

    translations = text_to_text_generation(
        prompt=output,
        model_id=lang_model_id,
        device='cuda:0',
        target_lang=lang_ids[base_lang],
    )
    # translate_to_english() takes element [0] of this helper's result, so do
    # the same here; previously the whole result object was returned as the
    # translated text.
    # NOTE(review): assumes text_to_text_generation returns a sequence of
    # strings -- confirm against multilingual_translation.
    output_text = translations[0]
    return en_prompt, output, output_text
def biogpt_audio(
    input_audio: str,
    biogpt_model_id: str,
    whisper_model_id: str,
    max_length: int = 250,
    num_return_sequences: int = 1,
):
    """Transcribe spoken input and generate BioGPT continuations for it.

    The audio is transcribed with `whisper_demo`, and the transcript is used
    as the prompt for a sampled text-generation pipeline on ``cuda:0``.

    Args:
        input_audio: Audio input forwarded unchanged to `whisper_demo`.
        biogpt_model_id: Model id handed to the text-generation pipeline.
        whisper_model_id: Model id handed to `whisper_demo`.
        max_length: ``max_length`` passed to the generator. Defaults to 250,
            the value `biogpt_text` uses.
        num_return_sequences: Number of sampled sequences to generate.
            Defaults to 1, the value `biogpt_text` uses.

    Returns:
        A 3-tuple ``(transcript, generated_text, generated_text)``, where
        `generated_text` is every generated sequence followed by a blank line.
        The text is returned twice because no translation happens here.
    """
    # Callers may hand over floats or numeric strings (e.g. from UI widgets);
    # the slice below and the generation arguments need real integers.
    max_length = int(max_length)
    num_return_sequences = int(num_return_sequences)

    en_prompt = whisper_demo(input_audio=input_audio, model_id=whisper_model_id)

    generator = pipeline("text-generation", model=biogpt_model_id, device="cuda:0")
    output = generator(
        en_prompt,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        do_sample=True,
    )

    # Each generated sequence is followed by a blank line.
    output_text = "".join(
        f"{candidate['generated_text']}\n\n"
        for candidate in output[:num_return_sequences]
    )
    return en_prompt, output_text, output_text
# Example rows for the Text tab: [prompt, BioGPT model, translation model, base language].
examples = [
    ["COVID-19 is", biogpt_model_list[0], lang_model_list[1], "English"],
]

app = gr.Blocks()
with app:
    gr.Markdown("# **<h4 align='center'>Whisper + M2M100 + BioGPT: Generative Pre-trained Transformer for Biomedical Text Generation and Mining</h4>**")
    gr.Markdown(
        """
<p style='text-align: center'>
Follow me for more!
<br> <a href='https://twitter.com/kadirnar_ai' target='_blank'>twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>linkedin</a> |
</p>
"""
    )

    with gr.Row():
        # Left column: inputs, one tab per input modality.
        with gr.Column():
            with gr.Tab("Text"):
                input_text = gr.Textbox(lines=3, value="COVID-19 is", label="Text")
                input_text_button = gr.Button(value="Predict")
                # The BioGPT model selection below is also read by the Audio tab.
                input_biogpt_model = gr.Dropdown(choices=biogpt_model_list, value=biogpt_model_list[0], label='BioGpt Model')
                input_m2m100_model = gr.Dropdown(choices=lang_model_list, value=lang_model_list[1], label='Language Model')
                input_base_lang = gr.Dropdown(lang_list, value="English", label="Base Language")

            with gr.Tab("Audio"):
                # type="filepath" hands biogpt_audio a path string, matching
                # its `input_audio: str` parameter.
                input_audio = gr.Microphone(label='Audio', type="filepath")
                input_whisper_model = gr.Dropdown(choices=whisper_model_list, value=whisper_model_list[0], label='Whisper Model')
                # precision=0 makes these components deliver ints.
                input_audio_max_length = gr.Number(value=250, precision=0, label='Max Length')
                input_audio_num_sequences = gr.Number(value=1, precision=0, label='Number of Sequences')
                input_audio_button = gr.Button(value="Predict")

        # Right column: outputs shared by both tabs.
        with gr.Column():
            prompt_text = gr.Textbox(lines=3, label="Prompt")
            output_text = gr.Textbox(lines=3, label="BioGpt Text")
            translated_text = gr.Textbox(lines=3, label="Translated Text")

    text_inputs = [input_text, input_biogpt_model, input_m2m100_model, input_base_lang]
    # biogpt_audio takes (audio, BioGPT model id, Whisper model id, max_length,
    # num_return_sequences). The previous wiring passed only four values and
    # put the translation model and base language into the Whisper/max_length
    # slots, so the audio button could not work.
    audio_inputs = [
        input_audio,
        input_biogpt_model,
        input_whisper_model,
        input_audio_max_length,
        input_audio_num_sequences,
    ]
    result_outputs = [prompt_text, output_text, translated_text]

    gr.Examples(examples, inputs=text_inputs, outputs=result_outputs, fn=biogpt_text, cache_examples=False)
    input_text_button.click(biogpt_text, inputs=text_inputs, outputs=result_outputs)
    input_audio_button.click(biogpt_audio, inputs=audio_inputs, outputs=result_outputs)

app.launch()