# Captured from a Hugging Face Spaces page whose status read "Build error"
# (the status line appeared twice in the capture).
from bleu import Bleu | |
from rouge import Rouge | |
from datasets import load_metric | |
from pathlib import Path | |
import streamlit as st | |
import streamlit.components.v1 as components | |
# A BLEU implementation from tensorflow/nmt (nmt/scripts/bleu.py) was once
# considered here; the project-local Rouge/Bleu classes are instantiated instead.
# NOTE(review): neither `rouge` nor `bleu` is referenced elsewhere in this
# script -- the scores on the page come from datasets.load_metric. Confirm the
# constructors have no needed side effects before removing these.
rouge, bleu = Rouge(), Bleu()
def read_markdown_file(markdown_file, encoding="utf-8"):
    """Return the full text content of ``markdown_file``.

    Args:
        markdown_file: Path (``str`` or ``os.PathLike``) of the file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8
            instead of the platform-dependent locale encoding that
            ``Path.read_text()`` falls back to when none is given, so metric
            cards containing non-ASCII characters load the same everywhere.

    Returns:
        The decoded file contents as a single string.
    """
    return Path(markdown_file).read_text(encoding=encoding)
# Metric identifiers offered in the sidebar picker. Each id is passed to
# datasets.load_metric and also names a "<id>_metric_card.md" file.
_METRIC_CHOICES = ("rouge", "bleu")
metrics = list(_METRIC_CHOICES)
def compute(data, scorer=None, key="accuracy"):
    """Score ``data`` with a metric object and return one entry of the result.

    Args:
        data: Mapping with ``"predictions"`` and ``"references"`` entries,
            forwarded unchanged to ``scorer.compute``.
        scorer: Object exposing ``compute(predictions=..., references=...)``
            that returns a dict. When omitted, the module-level name
            ``metric`` is used (resolved at call time), as before.
        key: Entry of the result dict to return. The default ``"accuracy"``
            keeps the original behaviour. Metrics such as BLEU/ROUGE have no
            ``"accuracy"`` entry, so pass e.g. ``"bleu"``, or ``None`` to get
            the whole result dict.

    Returns:
        ``result[key]``, or the full result dict when ``key`` is ``None``.

    Raises:
        KeyError: if ``data`` lacks a required entry or the result has no
            ``key`` entry.
    """
    if scorer is None:
        # Backward compatibility: fall back to the module-level `metric`.
        scorer = metric
    result = scorer.compute(
        predictions=data["predictions"], references=data["references"]
    )
    return result if key is None else result[key]
st.sidebar.markdown("Choose a functionality below:")

with st.sidebar.expander("Compare one or more metrics", expanded=True):
    # Multiselect over the available metric ids, with ROUGE preselected.
    metric_names = st.multiselect(
        "Choose metrics to explore:",
        metrics,
        default="rouge",
    )

# One loaded metric object per selected id, in the same order as metric_names.
loaded_metrics = [load_metric(name) for name in metric_names]
if loaded_metrics:
    # Other code in this module reads the global `metric` (e.g. `compute`),
    # so keep it bound to the last loaded metric object.
    metric = loaded_metrics[-1]
def _render_bleu(scorer, reference, prediction):
    """Score one sentence pair with a loaded BLEU metric and display it.

    Args:
        scorer: Metric object loaded via ``load_metric("bleu")``.
        reference: Reference sentence typed by the user.
        prediction: Prediction sentence typed by the user.

    Renders three ``st.metric`` tiles (BLEU, brevity penalty, length ratio).
    If either sentence has no tokens, it shows a warning instead.
    """
    # BLEU works on whitespace tokens: one token list per prediction, and for
    # each prediction a list of tokenized references.
    pred_tokens = prediction.split()
    ref_tokens = reference.split()
    if not pred_tokens or not ref_tokens:
        # The length ratio / brevity penalty divide by these token counts
        # (tensorflow/nmt-style BLEU), so an empty sentence would raise
        # ZeroDivisionError instead of rendering.
        st.warning("BLEU needs a non-empty reference and prediction sentence.")
        return
    score = scorer.compute(predictions=[pred_tokens], references=[[ref_tokens]])
    col1, col2, col3 = st.columns(3)
    col1.metric("BLEU", score["bleu"])
    col2.metric("Brevity penalty", score["brevity_penalty"])
    col3.metric("Length Ratio", score["length_ratio"])


def _render_rouge(scorer, reference, prediction):
    """Score one sentence pair with a loaded ROUGE metric and display it.

    Args:
        scorer: Metric object loaded via ``load_metric("rouge")``.
        reference: Reference sentence typed by the user.
        prediction: Prediction sentence typed by the user.

    Renders precision / recall / F-measure tiles for ROUGE-1 and ROUGE-2,
    read from the ``mid`` value of each aggregated score.
    """
    # Unlike BLEU, the HF rouge metric takes raw strings and tokenizes them
    # itself; token lists are not accepted as input.
    score = scorer.compute(predictions=[prediction], references=[reference])
    for order in ("1", "2"):
        mid = score["rouge" + order].mid
        prec_col, rec_col, f_col = st.columns(3)
        prec_col.metric(f"Rouge {order} Precision", mid.precision)
        rec_col.metric(f"Rouge {order} Recall", mid.recall)
        f_col.metric(f"Rouge {order} FMeasure", mid.fmeasure)


# Renderers in display order: BLEU first, then ROUGE, regardless of the
# order in which the user selected them.
_RENDERERS = {"bleu": _render_bleu, "rouge": _render_rouge}

if not metric_names:
    st.markdown("## Please choose one or more metrics.")
else:
    st.markdown("# You chose " + " and ".join(name.upper() for name in metric_names))
    st.markdown("## You can test it out below:")
    reference = st.text_input(label='Input a reference sentence here:', value="hello world")
    prediction = st.text_input(label='Input a prediction sentence here:', value="goodnight moon")
    for name, render in _RENDERERS.items():
        if name in metric_names:
            # loaded_metrics is parallel to metric_names (built in the same order).
            render(loaded_metrics[metric_names.index(name)], reference, prediction)
    if len(metric_names) == 1:
        # Single-metric mode also shows that metric's description card.
        st.markdown('===================================================================================')
        st.markdown(read_markdown_file(metric_names[0] + "_metric_card.md"))