# metric-explorer / app.py
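"""Streamlit app for exploring and comparing text generation metrics (BLEU and ROUGE) loaded via `datasets.load_metric`."""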
from bleu import Bleu
from rouge import Rouge
from datasets import load_metric
from pathlib import Path
import streamlit as st
import streamlit.components.v1 as components
#from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
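# Local metric wrappers (assumed to be defined in bleu.py / rouge.py alongside this app); not used below.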
rouge = Rouge()
bleu = Bleu()
def read_markdown_file(markdown_file):
    return Path(markdown_file).read_text()
metrics = ['rouge', 'bleu']
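# Generic scoring helper: relies on the module-level `metric` and an "accuracy" key,
# so it only applies to accuracy-style metrics (it is not used for BLEU/ROUGE below).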
def compute(data):
    return metric.compute(predictions=data["predictions"], references=data["references"])["accuracy"]
st.sidebar.markdown("Choose a functionality below:")
with st.sidebar.expander("Compare one or more metrics", expanded=True):
    metric_names = st.multiselect(
        "Choose metrics to explore:",
        metrics,
        default="rouge")
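# Load each selected metric with `datasets.load_metric`; loaded_metrics stays parallel to metric_names.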
loaded_metrics = []
for metric_name in metric_names:
    metric = load_metric(metric_name)
    loaded_metrics.append(metric)
### Single metric mode
print(metric_names)
if not metric_names:
    st.markdown("## Please choose one or more metrics.")
elif len(metric_names) == 1:
    metric_name = metric_names[0]
    metric = loaded_metrics[0]  # the single loaded metric
    st.markdown("# You chose " + metric_name.upper())
    st.markdown("## You can test it out below:")
    reference = st.text_input(label='Input a reference sentence here:', value="hello world")
    prediction = st.text_input(label='Input a prediction sentence here:', value="goodnight moon")
    # Whitespace-tokenize the inputs
    predictions = []
    predictions.append(prediction.split())
    references = []
    references.append(reference.split())
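    # The bleu metric in `datasets` expects tokenized predictions and one list of tokenized references per prediction.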
    if metric_name == "bleu":
        score = metric.compute(predictions=predictions, references=[references])
        col1, col2, col3 = st.columns(3)
        col1.metric("BLEU", score['bleu'])
        col2.metric("Brevity penalty", score['brevity_penalty'])
        col3.metric('Length Ratio', score['length_ratio'])
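    # Note: unlike bleu, the rouge metric in `datasets` scores raw strings, not token lists.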
    if metric_name == "rouge":
        score = metric.compute(predictions=[prediction], references=[reference])
        col1, col2, col3 = st.columns(3)
        col1.metric("Rouge 1 Precision", score['rouge1'].mid.precision)
        col2.metric("Rouge 1 Recall", score['rouge1'].mid.recall)
        col3.metric("Rouge 1 FMeasure", score['rouge1'].mid.fmeasure)
        col4, col5, col6 = st.columns(3)
        col4.metric("Rouge 2 Precision", score['rouge2'].mid.precision)
        col5.metric("Rouge 2 Recall", score['rouge2'].mid.recall)
        col6.metric("Rouge 2 FMeasure", score['rouge2'].mid.fmeasure)
    st.markdown('---')
    # components.html("""<hr style="height:10px;border:none;color:#333;background-color:#333;" /> """)
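    # Each metric is assumed to have a matching card, e.g. rouge_metric_card.md, next to app.py.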
    st.markdown(read_markdown_file(metric_name + "_metric_card.md"))
# Multiple metric mode
else:
    metric1 = metric_names[0]
    metric2 = metric_names[1]
    st.markdown("# You chose " + metric1.upper() + " and " + metric2.upper())
    st.markdown("## You can test it out below:")
    reference = st.text_input(label='Input a reference sentence here:', value="hello world")
    prediction = st.text_input(label='Input a prediction sentence here:', value="goodnight moon")
    predictions = []
    predictions.append(prediction.split())
    references = []
    references.append(reference.split())
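    # Look up each selected metric's loaded object by its index in metric_names.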
if "bleu" in metric_names:
bleu_ix = metric_names.index("bleu")
bleu_score = loaded_metrics[bleu_ix].compute(predictions=predictions, references=[references])
col1, col2, col3 = st.columns(3)
col1.metric("BLEU", bleu_score['bleu'])
col2.metric("Brevity penalty", bleu_score['brevity_penalty'])
col3.metric('Length Ratio', bleu_score['length_ratio'])
if "rouge" in metric_names:
rouge_ix = metric_names.index("rouge")
rouge_score = loaded_metrics[rouge_ix].compute(predictions=predictions, references=references)
#print(score)
col1, col2, col3 = st.columns(3)
col1.metric("Rouge 1 Precision", rouge_score['rouge1'].mid.precision)
col2.metric("Rouge 1 Recall", rouge_score['rouge1'].mid.recall)
col3.metric("Rouge 1 FMeasure", rouge_score['rouge1'].mid.fmeasure)
col4, col5, col6 = st.columns(3)
col4.metric("Rouge 2 Precision", rouge_score['rouge2'].mid.precision)
col5.metric("Rouge 2 Recall", rouge_score['rouge2'].mid.recall)
col6.metric("Rouge 2 FMeasure", rouge_score['rouge2'].mid.fmeasure)