Spaces:
Build error
Build error
from bleu import Bleu | |
from rouge import Rouge | |
from datasets import load_metric | |
from pathlib import Path | |
import streamlit as st | |
import streamlit.components.v1 as components | |
#from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py | |
# Instantiate the local Rouge/Bleu helpers. NOTE(review): neither object is
# referenced by the code visible in this file -- scoring below goes through
# `load_metric` -- confirm whether these are still needed.
rouge, bleu = Rouge(), Bleu()
def read_markdown_file(markdown_file, encoding="utf-8"):
    """Return the full text content of ``markdown_file``.

    Args:
        markdown_file: Path (``str`` or ``os.PathLike``) of the file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8
            instead of the platform's locale encoding, so files containing
            non-ASCII characters decode the same way on every host.

    Returns:
        The decoded file contents as a single string.
    """
    return Path(markdown_file).read_text(encoding=encoding)
def compute(data, key="accuracy", scorer=None):
    """Score ``data`` with a metric and return one entry of the result.

    Args:
        data: Mapping with ``"predictions"`` and ``"references"`` entries,
            forwarded unchanged to ``scorer.compute``.
        key: Entry of the result to return. Defaults to ``"accuracy"`` for
            backward compatibility; pass ``None`` to get the whole result.
            The rouge/bleu results used elsewhere in this app are read via
            keys such as ``"bleu"`` or ``"rouge1"``, so callers scoring them
            must choose the key explicitly.
        scorer: Object exposing ``compute(predictions=..., references=...)``.
            Defaults to the module-level ``metric`` chosen in the sidebar
            (looked up at call time, since it is defined after this function).

    Returns:
        ``result[key]``, or the full result when ``key`` is ``None``.

    Raises:
        KeyError: if ``key`` is not present in the result; the message lists
            the keys that are available.
    """
    if scorer is None:
        scorer = metric  # module-level metric selected in the sidebar
    result = scorer.compute(predictions=data["predictions"], references=data["references"])
    if key is None:
        return result
    try:
        return result[key]
    except KeyError:
        raise KeyError(
            f"metric result has no {key!r} entry; available: {sorted(result)}"
        ) from None
# Sidebar: the user picks which metric to explore; the chosen name is then
# loaded through `datasets.load_metric`.
with st.sidebar.expander("Metric", expanded=True):
    metrics = ['rouge', 'bleu']
    metric_name = st.selectbox("Choose metric to explore:", metrics)
metric = load_metric(metric_name)
# Page header for the selected metric, followed by the two free-text inputs.
st.markdown(f"# You chose {metric_name.upper()}")
st.markdown("## You can test it out below:")
reference = st.text_input('Input a reference sentence here:')
prediction = st.text_input('Input a prediction sentence here:')

# Whitespace-tokenize each sentence; each list holds exactly one token list.
predictions = [prediction.split()]
references = [reference.split()]
# Score the current prediction/reference pair and show the headline numbers
# for the selected metric.
# Nothing is scored until both inputs contain at least one token. Both text
# boxes start empty, and BLEU's brevity penalty divides by the token counts,
# so scoring an empty side would raise on the very first page load.
if not prediction.split() or not reference.split():
    st.info("Enter both a reference and a prediction sentence to see a score.")
elif metric_name == "bleu":
    # BLEU takes pre-tokenized input: one token list per prediction and, for
    # each prediction, a list of reference token lists (hence the extra [...]).
    score = metric.compute(predictions=predictions, references=[references])
    col1, col2, col3 = st.columns(3)
    col1.metric("BLEU", score['bleu'])
    col2.metric("Brevity penalty", score['brevity_penalty'])
    col3.metric('Length Ratio', score['length_ratio'])
elif metric_name == "rouge":
    # ROUGE tokenizes internally and expects raw strings, so pass the sentences
    # themselves rather than the whitespace-split token lists used for BLEU.
    score = metric.compute(predictions=[prediction], references=[reference])
    # `.mid` is the aggregated score read from each ROUGE result entry.
    col1, col2, col3 = st.columns(3)
    col1.metric("Rouge 1 Precision", score['rouge1'].mid.precision)
    col2.metric("Rouge 1 Recall", score['rouge1'].mid.recall)
    col3.metric("Rouge 1 FMeasure", score['rouge1'].mid.fmeasure)
    col4, col5, col6 = st.columns(3)
    col4.metric("Rouge 2 Precision", score['rouge2'].mid.precision)
    col5.metric("Rouge 2 Recall", score['rouge2'].mid.recall)
    col6.metric("Rouge 2 FMeasure", score['rouge2'].mid.fmeasure)
# Separate the interactive demo from the long-form metric description.
separator = '==================================================================================='
st.markdown(separator)

# The description ("metric card") lives in "<metric_name>_metric_card.md",
# resolved relative to the current working directory.
card_file = f"{metric_name}_metric_card.md"
st.markdown(read_markdown_file(card_file))