# kadirnar's picture
# Update app.py
# 2bd56bd
# raw
# history blame
# No virus
# 5.14 kB
from transformers import pipeline
from multilingual_translation import text_to_text_generation
from utils import lang_ids
import gradio as gr
# BioGPT checkpoint ids offered by both "BioGpt Model" dropdowns; the selected
# id is passed as `model=` to the transformers text-generation pipeline.
biogpt_model_list = [
    "microsoft/biogpt",
    "microsoft/BioGPT-Large",
    "microsoft/BioGPT-Large-PubMedQA",
]

# M2M100 checkpoint ids offered by both "Language Model" dropdowns; the selected
# id is passed as `model_id=` to text_to_text_generation for translation.
lang_model_list = [
    "facebook/m2m100_1.2B",
    "facebook/m2m100_418M",
]

# Whisper checkpoint ids offered by the "Audio Model" dropdown; the selected id
# is passed as `model=` to the automatic-speech-recognition pipeline.
whisper_model_list = [
    "openai/whisper-small",
    "openai/whisper-medium",
    "openai/whisper-tiny",
    "openai/whisper-large",
]
# Language display names (the keys of `lang_ids`, in its iteration order) used as
# the choices of both "Base Language" dropdowns. `lang_ids[name]` is the
# target_lang code handed to text_to_text_generation.
lang_list = list(lang_ids)
def whisper_demo(input_audio, model_id):
    """Transcribe audio to English text with a Whisper checkpoint.

    Args:
        input_audio: Audio input forwarded to the ASR pipeline (the Audio tab
            supplies a file path, see ``gr.Audio(type="filepath")``).
        model_id: Whisper checkpoint id, e.g. an entry of ``whisper_model_list``.

    Returns:
        The ``"text"`` field of the pipeline's output.
    """
    # A new pipeline (and model load onto cuda:0) happens on every call.
    asr = pipeline(task="automatic-speech-recognition", model=model_id, device='cuda:0')
    # Pin the decoder prompt to English transcription instead of letting Whisper
    # choose the language/task tokens itself.
    prompt_ids = asr.tokenizer.get_decoder_prompt_ids(language='en', task="transcribe")
    asr.model.config.forced_decoder_ids = prompt_ids
    return asr(input_audio)["text"]
def translate_to_english(prompt, lang_model_id, base_lang):
    """Return *prompt* in English, translating it first unless it already is.

    Args:
        prompt: User text written in ``base_lang``.
        lang_model_id: Translation checkpoint id (an entry of ``lang_model_list``).
        base_lang: Display name of the prompt's language, e.g. ``"English"``.

    Returns:
        ``prompt`` unchanged when ``base_lang == "English"``; otherwise the first
        element of ``text_to_text_generation``'s result.
    """
    # Guard clause: English input skips the translation model entirely.
    if base_lang == "English":
        return prompt
    # NOTE(review): the source language is not passed to text_to_text_generation;
    # confirm that helper handles non-English source text correctly.
    translations = text_to_text_generation(
        prompt=prompt,
        model_id=lang_model_id,
        device='cuda:0',
        target_lang='en',
    )
    return translations[0]
def _generate_biomedical_text(prompt, biogpt_model_id):
    """Run one sampled BioGPT generation for *prompt* and return its text.

    A fresh text-generation pipeline is built on cuda:0 for every call.
    """
    generator = pipeline("text-generation", model=biogpt_model_id, device="cuda:0")
    # max_length bounds the total sequence (prompt + continuation) per the
    # transformers generation docs; do_sample=True makes the output random.
    output = generator(prompt, max_length=250, num_return_sequences=1, do_sample=True)
    return output[0]['generated_text']


def _translate_from_english(text, lang_model_id, base_lang):
    """Translate English *text* into *base_lang*; English text is returned as-is."""
    if base_lang == "English":
        return text
    translations = text_to_text_generation(
        prompt=text,
        model_id=lang_model_id,
        device='cuda:0',
        target_lang=lang_ids[base_lang],
    )
    # text_to_text_generation yields a sequence of decoded strings (translate_to_english
    # indexes it the same way); take the first so the "Translated Text" textbox
    # receives a plain string rather than a list repr like "['...']".
    return translations[0]


def biogpt_text(
    prompt: str,
    biogpt_model_id: str,
    lang_model_id: str,
    base_lang: str,
):
    """Text-tab handler: translate to English, generate with BioGPT, translate back.

    Args:
        prompt: User prompt written in ``base_lang``.
        biogpt_model_id: BioGPT checkpoint id (an entry of ``biogpt_model_list``).
        lang_model_id: Translation checkpoint id (an entry of ``lang_model_list``).
        base_lang: Display name of the user's language (a key of ``lang_ids``).

    Returns:
        ``(en_prompt, output, output_text)``: the English prompt fed to BioGPT,
        BioGPT's English generation, and that generation in ``base_lang``
        (identical to ``output`` when ``base_lang == "English"``).
    """
    en_prompt = translate_to_english(prompt, lang_model_id, base_lang)
    output = _generate_biomedical_text(en_prompt, biogpt_model_id)
    output_text = _translate_from_english(output, lang_model_id, base_lang)
    return en_prompt, output, output_text


def biogpt_audio(
    input_audio: str,
    biogpt_model_id: str,
    whisper_model_id: str,
    base_lang: str,
    lang_model_id: str,
):
    """Audio-tab handler: transcribe with Whisper, generate with BioGPT, translate.

    Note the parameter order (``base_lang`` before ``lang_model_id``) differs from
    ``biogpt_text``; the Gradio wiring passes inputs positionally in this order.

    Args:
        input_audio: Recorded audio file path.
        biogpt_model_id: BioGPT checkpoint id (an entry of ``biogpt_model_list``).
        whisper_model_id: Whisper checkpoint id (an entry of ``whisper_model_list``).
        base_lang: Display name of the output language (a key of ``lang_ids``).
        lang_model_id: Translation checkpoint id (an entry of ``lang_model_list``).

    Returns:
        ``(en_prompt, output, output_text)``: the English transcription, BioGPT's
        English generation, and that generation in ``base_lang``.
    """
    en_prompt = whisper_demo(input_audio=input_audio, model_id=whisper_model_id)
    output = _generate_biomedical_text(en_prompt, biogpt_model_id)
    output_text = _translate_from_english(output, lang_model_id, base_lang)
    return en_prompt, output, output_text
# One example row for the Text tab, in biogpt_text's positional order:
# (prompt, BioGPT model, M2M100 model, base language).
examples = [["COVID-19 is", biogpt_model_list[0], lang_model_list[1], "English"]]

app = gr.Blocks()
with app:
    # Header. Closing tags are </h3> and </h5>; the second string has no stray
    # "<" before its <h5> and no dangling empty-string literal.
    gr.Markdown("# **<h3 align='center'>Whisper + M2M100 + BioGPT: Generative Pre-trained Transformer for Biomedical Text Generation and Mining</h3>**")
    gr.Markdown(
        "# **<h5><a href='https://twitter.com/kadirnar_ai' target='_blank'>twitter</a> | "
        "<a href='https://github.com/kadirnar' target='_blank'>github</a> | "
        "<a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>linkedin</a> |</h5>**"
    )
    with gr.Row():
        with gr.Column():
            with gr.Tab("Text"):
                input_text = gr.Textbox(lines=3, value="COVID-19 is", label="Text")
                text_biogpt = gr.Dropdown(choices=biogpt_model_list, value=biogpt_model_list[0], label='BioGpt Model')
                text_m2m100 = gr.Dropdown(choices=lang_model_list, value=lang_model_list[1], label='Language Model')
                text_lang = gr.Dropdown(lang_list, value="English", label="Base Language")
                text_button = gr.Button(value="Predict")
            with gr.Tab("Audio"):
                # NOTE(review): `source=` is the Gradio 3.x keyword; Gradio 4 renamed
                # it to `sources=[...]` — update if the Space's Gradio is upgraded.
                input_audio = gr.Audio(source="microphone", type="filepath", label='Audio')
                audio_biogpt = gr.Dropdown(choices=biogpt_model_list, value=biogpt_model_list[0], label='BioGpt Model')
                audio_whisper = gr.Dropdown(choices=whisper_model_list, value=whisper_model_list[0], label='Audio Model')
                audio_lang = gr.Dropdown(lang_list, value="English", label="Base Language")
                audio_m2m100 = gr.Dropdown(choices=lang_model_list, value=lang_model_list[1], label='Language Model')
                audio_button = gr.Button(value="Predict")
        with gr.Column():
            # Shared output panel for both tabs.
            prompt_text = gr.Textbox(lines=3, label="Prompt")
            output_text = gr.Textbox(lines=3, label="BioGpt Text")
            translated_text = gr.Textbox(lines=3, label="Translated Text")
    gr.Examples(examples, inputs=[input_text, text_biogpt, text_m2m100, text_lang], outputs=[prompt_text, output_text, translated_text], fn=biogpt_text, cache_examples=False)
    # Gradio passes input values positionally, so each `inputs` list must follow
    # its handler's parameter order.
    text_button.click(biogpt_text, inputs=[input_text, text_biogpt, text_m2m100, text_lang], outputs=[prompt_text, output_text, translated_text])
    audio_button.click(biogpt_audio, inputs=[input_audio, audio_biogpt, audio_whisper, audio_lang, audio_m2m100], outputs=[prompt_text, output_text, translated_text])

app.launch()