import streamlit as st
import datasets
import numpy as np
import html
# Datasets whose rows carry multiple-choice options alongside the instruction.
_MCQ_DATASETS = ('CN-College-Listen-MCQ-Test', 'DREAM-TTS-MCQ-Test')

# One-pass equivalent of stripping '<', '>' and '*' and spelling out newlines
# in a model prediction before it is displayed.
_PREDICTION_CLEANUP = str.maketrans(
    {'<': None, '>': None, '*': None, '\n': '(newline)'}
)


def _space_out_options(text):
    """Insert a space in front of each ``(A)``/``(B)``/``(C)`` option marker."""
    return text.replace('(A)', ' (A)').replace('(B)', ' (B)').replace('(C)', ' (C)')


def _build_question_text(row, is_mcq):
    """Return the raw (unescaped) question of *row*, with choices appended for MCQ sets."""
    instruction = row['instruction']['text']
    if not is_mcq:
        return f"{instruction}"

    choices = row['other_attributes']['choices']
    if isinstance(choices, str):
        choices_text = choices
    elif isinstance(choices, (list, tuple)):
        choices_text = ' '.join(str(choice) for choice in choices)
    else:
        # Unknown shape: show the bare instruction rather than reusing a value
        # left over from a previous example (or hitting a NameError).
        choices_text = ''
    return f"{instruction} {choices_text}"


def _format_model_row(row, model, display_model_names, is_mcq):
    """Return the display fragment for one model's prompt and prediction.

    Raises KeyError/TypeError/AttributeError when *model* has no usable entry
    in *row* or no display name; the caller treats that as "model absent".
    """
    display_name = display_model_names[model]
    entry = row[model]
    prediction = html.escape(entry['model_prediction'].translate(_PREDICTION_CLEANUP))

    if is_mcq:
        prompt = _space_out_options(entry['text'].replace('Choices:', ' Choices:'))
        # Escaped like the non-MCQ branch: this text ends up in a markdown
        # call that has unsafe_allow_html=True.
        prompt = html.escape(prompt)
        return f"\n{display_name} |\n{prompt}\n|\n{prediction} |\n"

    prompt = html.escape(entry['text'])
    return f"\n{display_name} |\n{prompt} |\n{prediction} |\n"


def show_examples(category_name, dataset_name, model_lists, display_model_names):
    """Render every stored example of a dataset together with each model's output.

    The examples are loaded with ``datasets.load_from_disk`` from
    ``./examples/<category_name>/<dataset_name>``; the audio of example *i* is
    read from ``sample_<i>.wav`` in that same folder.

    Args:
        category_name: Folder name of the category under ``./examples``.
        dataset_name: Folder name of the dataset inside the category. The two
            names in ``_MCQ_DATASETS`` additionally get their answer choices
            appended to the question.
        model_lists: Model keys to display. They are shown in sorted order;
            the list passed in is not modified.
        display_model_names: Mapping from model key to the label to display.

    Models that are missing from an example (or from ``display_model_names``)
    are skipped for that example and reported on stdout.

    NOTE(review): the indentation of this file was lost in the reviewed copy,
    so the per-example placement of the trailing spacer and divider is an
    assumption - confirm against the rendered page. The display strings also
    look like their HTML markup was stripped; they are kept exactly as found.
    """
    st.divider()
    sample_folder = f"./examples/{category_name}/{dataset_name}"
    dataset = datasets.load_from_disk(sample_folder)

    # Both are the same for every example, so compute them once. sorted()
    # leaves the caller's list untouched.
    is_mcq = dataset_name in _MCQ_DATASETS
    ordered_models = sorted(model_lists)

    for index in range(len(dataset)):
        row = dataset[index]  # read the example once instead of per field

        with st.container():
            st.markdown(f'##### Example-{index+1}')
            # The returned columns are not used, but the call is kept so the
            # page layout stays as it was.
            st.columns([0.3, 0.7], vertical_alignment="center")
            st.audio(f'{sample_folder}/sample_{index}.wav', format="audio/wav")

            # Escape exactly once; escaping twice shows entities such as
            # "&#x27;" literally on the page.
            question_html = html.escape(
                _space_out_options(_build_question_text(row, is_mcq))
            )
            answer_html = html.escape(row['answer']['text'])

            with st.container():
                custom_css = "\n"
                st.markdown(custom_css, unsafe_allow_html=True)

                fragments = [
                    f"\nREFERENCE |\n{question_html}\n|\n{answer_html}\n|\n"
                ]
                for model in ordered_models:
                    try:
                        fragments.append(
                            _format_model_row(row, model, display_model_names, is_mcq)
                        )
                    except (KeyError, TypeError, AttributeError):
                        print(f"{model} is not in {dataset_name}")

                s = "".join(fragments)
                body_details = f"\nMODEL |\nQUESTION |\nMODEL PREDICTION |\n{s}\n"
                st.markdown(f"\n{body_details}\n", unsafe_allow_html=True)

            st.text("")

        st.divider()