import gradio as gr
from transformers import pipeline

from multilingual_translation import text_to_text_generation
from utils import lang_ids

# Identifier string for the BioGPT paper; not referenced by the code visible here.
paper_id = "kadirnar/biogpt_paper"

# BioGPT checkpoints offered in the "BioGpt Model" dropdowns; the first entry
# is the dropdowns' initial value.
biogpt_model_list = [
    "microsoft/biogpt",
    "microsoft/BioGPT-Large",
    "microsoft/BioGPT-Large-PubMedQA",
]

# M2M100 checkpoints offered in the "Language Model" dropdowns; the UI and the
# examples default to index 1 (the 418M checkpoint).
lang_model_list = [
    "facebook/m2m100_1.2B",
    "facebook/m2m100_418M",
]

# Whisper checkpoints offered in the "Audio Model" dropdown; the first entry
# is the dropdown's initial value.
whisper_model_list = [
    "openai/whisper-small",
    "openai/whisper-medium",
    "openai/whisper-tiny",
    "openai/whisper-large",
]

# Choices for the "Base Language" dropdowns: the keys of `lang_ids`, which the
# generation functions later index with the selected language name.
lang_list = list(lang_ids.keys())

def whisper_demo(input_audio, model_id):
    """Transcribe an audio input to text with a Whisper checkpoint.

    Args:
        input_audio: Audio input handed straight to the speech-recognition
            pipeline (the UI supplies a file path recorded from the microphone).
        model_id: Model id of the checkpoint to load into the pipeline.

    Returns:
        The ``'text'`` field of the pipeline's output.
    """
    # NOTE: the pipeline is rebuilt on every call and is pinned to "cuda:0",
    # so this function needs a CUDA device.
    asr = pipeline(
        task="automatic-speech-recognition",
        model=model_id,
        device="cuda:0",
    )
    # Force the decoder prompt to English transcription.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language="en",
        task="transcribe",
    )
    return asr(input_audio)["text"]


def translate_to_english(prompt, lang_model_id, base_lang):
    """Return ``prompt`` in English, translating it unless it already is.

    Args:
        prompt: Text entered by the user.
        lang_model_id: Model id forwarded to ``text_to_text_generation``.
        base_lang: Language name of ``prompt``; ``"English"`` skips translation.

    Returns:
        ``prompt`` unchanged when ``base_lang`` is ``"English"``; otherwise the
        first element of the ``text_to_text_generation`` result.
    """
    if base_lang == "English":
        return prompt

    translations = text_to_text_generation(
        prompt=prompt,
        model_id=lang_model_id,
        device="cuda:0",
        target_lang="en",
    )
    return translations[0]


def biogpt_text(
    prompt: str,
    biogpt_model_id: str,
    lang_model_id: str,
    base_lang: str,
):
    """Generate BioGPT text from a typed prompt and translate it back.

    The prompt is first brought to English via ``translate_to_english``, fed
    to a text-generation pipeline, and the generated text is translated into
    ``base_lang`` unless that language is English.

    Args:
        prompt: Prompt text in ``base_lang``.
        biogpt_model_id: Model id for the text-generation pipeline.
        lang_model_id: Model id used for both translation directions.
        base_lang: Language name; must be ``"English"`` or a key of ``lang_ids``.

    Returns:
        A tuple ``(en_prompt, output, output_text)``: the English prompt, the
        generated English text, and the generated text in ``base_lang``
        (identical to ``output`` for English).

    Raises:
        KeyError: If ``base_lang`` is neither ``"English"`` nor in ``lang_ids``.
    """
    en_prompt = translate_to_english(prompt, lang_model_id, base_lang)

    # NOTE: the pipeline is rebuilt on every call and is pinned to "cuda:0".
    generator = pipeline("text-generation", model=biogpt_model_id, device="cuda:0")
    generations = generator(
        en_prompt,
        max_length=250,
        num_return_sequences=1,
        do_sample=True,
    )
    output = generations[0]["generated_text"]

    if base_lang == "English":
        return en_prompt, output, output

    output_text = text_to_text_generation(
        prompt=output,
        model_id=lang_model_id,
        device="cuda:0",
        target_lang=lang_ids[base_lang],
    )
    # translate_to_english takes element [0] of this helper's result, which
    # implies it returns a sequence of strings. Unwrap it here as well so the
    # "Translated Text" box receives plain text instead of a list. The type
    # check keeps a plain-string result untouched.
    # NOTE(review): confirm the helper's return type.
    if isinstance(output_text, (list, tuple)):
        output_text = output_text[0]

    return en_prompt, output, output_text


def biogpt_audio(
    input_audio: str,
    biogpt_model_id: str,
    whisper_model_id: str,
    base_lang: str,
    lang_model_id: str,
):
    """Generate BioGPT text from a spoken prompt and translate it back.

    The audio is transcribed by ``whisper_demo``, the transcript is fed to a
    text-generation pipeline, and the generated text is translated into
    ``base_lang`` unless that language is English. ``base_lang`` only affects
    this final translation step; it is not passed to the transcription.

    Args:
        input_audio: Audio input forwarded to ``whisper_demo`` (the UI supplies
            a file path).
        biogpt_model_id: Model id for the text-generation pipeline.
        whisper_model_id: Model id forwarded to ``whisper_demo``.
        base_lang: Language name; must be ``"English"`` or a key of ``lang_ids``.
        lang_model_id: Model id used to translate the generated text.

    Returns:
        A tuple ``(en_prompt, output, output_text)``: the transcript, the
        generated English text, and the generated text in ``base_lang``
        (identical to ``output`` for English).

    Raises:
        KeyError: If ``base_lang`` is neither ``"English"`` nor in ``lang_ids``.
    """
    en_prompt = whisper_demo(input_audio=input_audio, model_id=whisper_model_id)

    # NOTE: the pipeline is rebuilt on every call and is pinned to "cuda:0".
    generator = pipeline("text-generation", model=biogpt_model_id, device="cuda:0")
    generations = generator(
        en_prompt,
        max_length=250,
        num_return_sequences=1,
        do_sample=True,
    )
    output = generations[0]["generated_text"]

    if base_lang == "English":
        return en_prompt, output, output

    output_text = text_to_text_generation(
        prompt=output,
        model_id=lang_model_id,
        device="cuda:0",
        target_lang=lang_ids[base_lang],
    )
    # translate_to_english takes element [0] of this helper's result, which
    # implies it returns a sequence of strings. Unwrap it here as well so the
    # "Translated Text" box receives plain text instead of a list. The type
    # check keeps a plain-string result untouched.
    # NOTE(review): confirm the helper's return type.
    if isinstance(output_text, (list, tuple)):
        output_text = output_text[0]

    return en_prompt, output, output_text


# Question/context/answer prompt used as the second UI example. The text is
# runtime data and is reproduced verbatim; only the literal is wrapped.
question_example = (
    "Can 'high-risk' human papillomaviruses (HPVs) be detected in human breast milk? "
    "context: Using polymerase chain reaction techniques, we evaluated the presence of "
    "HPV infection in human breast milk collected from 21 HPV-positive and 11 "
    "HPV-negative mothers. Of the 32 studied human milk specimens, no 'high-risk' HPV "
    "16, 18, 31, 33, 35, 39, 45, 51, 52, 56, 58 or 58 DNA was detected. "
    "answer: This preliminary case-control study indicates the absence of mucosal "
    "'high-risk' HPV types in human breast milk."
)

# Rows for gr.Examples, in the order of its inputs:
# (prompt text, BioGPT model, language model, base language).
examples = [
    ["COVID-19 is", biogpt_model_list[0], lang_model_list[1], "English"],
    [question_example, biogpt_model_list[2], lang_model_list[1], "English"],
]


# Gradio UI: text and audio input tabs, a shared output tab, and event wiring.
# NOTE(review): the `with` nesting below was reconstructed from a source whose
# indentation was lost — confirm the layout against the running app.
app = gr.Blocks()
with app:
    # The heading's closing tag was written as "<h4>", leaving it unclosed.
    gr.Markdown(
        "# **<h4 align='center'>Whisper + M2M100 + BioGPT: Generative Pre-trained "
        "Transformer for Biomedical Text Generation and Mining</h4>**"
    )
    gr.Markdown(
        """
        <h5 style='text-align: center'>
        Follow me for more!
        <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>Github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>Linkedin</a> |
        </h5>
        """
    )
    with gr.Row():
        with gr.Column():
            with gr.Tab("Text"):
                input_text = gr.Textbox(lines=3, value="COVID-19 is", label="Text")
                text_biogpt = gr.Dropdown(
                    choices=biogpt_model_list,
                    value=biogpt_model_list[0],
                    label="BioGpt Model",
                )
                text_m2m100 = gr.Dropdown(
                    choices=lang_model_list,
                    value=lang_model_list[1],
                    label="Language Model",
                )
                text_lang = gr.Dropdown(lang_list, value="English", label="Base Language")
                text_button = gr.Button(value="Predict")

            with gr.Tab("Audio"):
                input_audio = gr.Audio(source="microphone", type="filepath", label="Audio")
                audio_biogpt = gr.Dropdown(
                    choices=biogpt_model_list,
                    value=biogpt_model_list[0],
                    label="BioGpt Model",
                )
                audio_whisper = gr.Dropdown(
                    choices=whisper_model_list,
                    value=whisper_model_list[0],
                    label="Audio Model",
                )
                audio_lang = gr.Dropdown(lang_list, value="English", label="Base Language")
                audio_m2m100 = gr.Dropdown(
                    choices=lang_model_list,
                    value=lang_model_list[1],
                    label="Language Model",
                )
                audio_button = gr.Button(value="Predict")

        with gr.Tab("Output"):
            with gr.Column():
                prompt_text = gr.Textbox(lines=3, label="Prompt")
                output_text = gr.Textbox(lines=3, label="BioGpt Text")
                translated_text = gr.Textbox(lines=3, label="Translated Text")

    # Both buttons and the examples write to the same three output boxes.
    gr.Examples(
        examples,
        inputs=[input_text, text_biogpt, text_m2m100, text_lang],
        outputs=[prompt_text, output_text, translated_text],
        fn=biogpt_text,
        cache_examples=False,
    )
    text_button.click(
        biogpt_text,
        inputs=[input_text, text_biogpt, text_m2m100, text_lang],
        outputs=[prompt_text, output_text, translated_text],
    )
    # Input order matches biogpt_audio's parameters:
    # (audio, BioGPT model, Whisper model, base language, language model).
    audio_button.click(
        biogpt_audio,
        inputs=[input_audio, audio_biogpt, audio_whisper, audio_lang, audio_m2m100],
        outputs=[prompt_text, output_text, translated_text],
    )

app.launch()