"""Streamlit app for trying out the ROUGE and BLEU metrics interactively.

The user picks a metric in the sidebar, types a reference and a prediction
sentence, and the app displays the resulting scores followed by the metric's
markdown card (``<metric_name>_metric_card.md``).
"""
from pathlib import Path

from bleu import Bleu
from rouge import Rouge
from datasets import load_metric
import streamlit as st
import streamlit.components.v1 as components

# from .nmt_bleu import compute_bleu  # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py

rouge = Rouge()
bleu = Bleu()


def read_markdown_file(markdown_file):
    """Return the text content of ``markdown_file``.

    Args:
        markdown_file: Path (string or ``Path``) of the file to read.

    Returns:
        The file content as a single string.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default encoding.
    return Path(markdown_file).read_text(encoding="utf-8")


def compute(data):
    """Run the currently loaded metric on ``data`` and return its accuracy.

    Args:
        data: Mapping with "predictions" and "references" entries; both are
            passed straight through to ``metric.compute``.

    Returns:
        The value stored under the "accuracy" key of the metric's result.

    NOTE(review): this helper reads the module-level ``metric`` assigned
    further down and is not called anywhere in this script. The bleu/rouge
    results used below are read through other keys, so confirm that the
    loaded metric actually reports "accuracy" before relying on this.
    """
    result = metric.compute(
        predictions=data["predictions"],
        references=data["references"],
    )
    return result["accuracy"]


with st.sidebar.expander("Metric", expanded=True):
    metrics = ['rouge', 'bleu']
    metric_name = st.selectbox("Choose metric to explore:", metrics)

metric = load_metric(metric_name)

st.markdown("# You chose " + metric_name.upper())
st.markdown("## You can test it out below:")

reference = st.text_input('Input a reference sentence here:')
prediction = st.text_input('Input a prediction sentence here:')

# Whitespace-tokenised, single-example batches.
predictions = [prediction.split()]
references = [reference.split()]

if not predictions[0] or not references[0]:
    # Both text boxes start out empty. BLEU divides by the token counts when
    # computing the length ratio and brevity penalty, so an empty sentence
    # cannot be scored; wait for real input instead of crashing the page.
    st.info("Enter both a reference and a prediction sentence above to see the scores.")
elif metric_name == "bleu":
    # BLEU takes tokenised input, with a list of references per prediction.
    score = metric.compute(predictions=predictions, references=[references])
    col1, col2, col3 = st.columns(3)
    col1.metric("BLEU", score['bleu'])
    col2.metric("Brevity penalty", score['brevity_penalty'])
    col3.metric('Length Ratio', score['length_ratio'])
elif metric_name == "rouge":
    # ROUGE takes plain sentences and tokenises them itself, so hand over the
    # raw strings rather than the token lists.
    score = metric.compute(predictions=[prediction], references=[reference])
    # One row of three columns per ROUGE variant.
    for rouge_key, label in (("rouge1", "Rouge 1"), ("rouge2", "Rouge 2")):
        mid = score[rouge_key].mid
        precision_col, recall_col, fmeasure_col = st.columns(3)
        precision_col.metric(f"{label} Precision", mid.precision)
        recall_col.metric(f"{label} Recall", mid.recall)
        fmeasure_col.metric(f"{label} FMeasure", mid.fmeasure)

st.markdown('===================================================================================')

# components.html("""
# """)
st.markdown(read_markdown_file(metric_name + "_metric_card.md"))