"""Gradio demo for the Commonsense Relation-Aware BART (BART-RA) model.

Builds a Blocks UI that sends user text to the model and shows the generated
result, plus a (currently placeholder) attention plot selected by layer/head.
"""
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from inference import RelationsInference
from utils import KGType, Model_Type

#############################
# Constants
#############################

# Sample (context, task, decoding strategy) triples.
# NOTE(review): not wired into the UI below, and "commonsense_qa" is not one of
# the task choices offered by the Radio -- confirm before using gr.Examples.
examples = [
    ["What's the meaning of life?", "eli5", "constraint"],
    ["boat, water, bird", "commongen", "constraint"],
    ["What flows under a bridge?", "commonsense_qa", "constraint"],
]

# Loaded once at import time; every UI request reuses this single instance.
bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_commongen',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=32,
)


#############################
# Helper
#############################

def infer_bart(context, task_type, decoding_type_str):
    """Run the model on ``context`` and return its first generated response.

    Args:
        context: Free text from the "Input:" textbox.
        task_type: Task chosen in the UI ("eli5" or "commongen"). Currently
            ignored: the model path loaded above is fixed.
        decoding_type_str: Decoding strategy chosen in the UI ("default" or
            "constraint"). Currently ignored.

    Returns:
        The first element of the response sequence produced by the model.
    """
    # TODO: route task_type / decoding_type_str to the model once supported.
    # The encoder attentions and model input are returned too but unused here.
    response, _encoder_attentions, _model_input = bart.generate_based_on_context(
        context, use_kg=False
    )
    return response[0]


def plot_attention(layer, head):
    """Return a matplotlib figure for the selected attention layer and head.

    Args:
        layer: Layer index from the "Layer" slider.
        head: Head index from the "Head" slider.

    Returns:
        A ``matplotlib.figure.Figure`` for ``gr.Plot`` to render.
    """
    # TODO: plot real attention scores for (layer, head); this is placeholder data.
    # Build the figure with the object-oriented API rather than pyplot: pyplot
    # keeps every figure it creates alive in a global registry until it is
    # explicitly closed, so a long-running server would leak one figure per
    # click, and that global state is shared across Gradio's worker threads.
    fig = Figure()
    ax = fig.subplots()
    ax.plot([1, 2, 3], [2, 4, 6])
    ax.set_title("Things")
    ax.set_ylabel("Cases")
    ax.set_xlabel("Days since Day 0")
    return fig


#############################
# Interface
#############################

app = gr.Blocks()
with app:
    gr.Markdown(
        """
        # Demo
        ### Test Commonsense Relation-Aware BART (BART-RA) model

        Tutorial:
        1) Select the possible model variations and tasks;
        2) Change the inputs and Click the buttons to produce results;
        3) See attention visualisations, by choosing a specific layer and head;
        """
    )
    with gr.Row():
        context_input = gr.Textbox(lines=2, value="What's the meaning of life?", label='Input:')
        model_result_output = gr.Textbox(lines=2, label='Model result:')
    with gr.Column():
        task_type_choice = gr.Radio(
            ["eli5", "commongen"], value="eli5", label="What task do you want to try?"
        )
        decoding_type_choice = gr.Radio(
            ["default", "constraint"], value="default", label="What decoding strategy do you want to use?"
        )
    with gr.Row():
        model_btn = gr.Button(value="See Model Results")
    gr.Markdown(
        """
        ---
        Observe Attention
        """
    )
    with gr.Row():
        with gr.Column():
            # Slider bounds presumably match a 12-layer / 16-head model -- confirm.
            layer = gr.Slider(0, 11, 0, step=1, label="Layer")
            head = gr.Slider(0, 15, 0, step=1, label="Head")
        with gr.Column():
            plot_output = gr.Plot()
    with gr.Row():
        vis_btn = gr.Button(value="See Attention Scores")

    model_btn.click(
        fn=infer_bart,
        inputs=[context_input, task_type_choice, decoding_type_choice],
        outputs=[model_result_output],
    )
    vis_btn.click(fn=plot_attention, inputs=[layer, head], outputs=[plot_output])

if __name__ == '__main__':
    app.launch()