import html

import datasets
import numpy as np
import streamlit as st

# Datasets whose rows carry multiple-choice options in other_attributes['choices'].
_MCQ_DATASETS = frozenset({'CN-College-Listen-MCQ-Test', 'DREAM-TTS-MCQ-Test'})

# Styles for the question / answer / model-table boxes. Injected once per call
# rather than once per example.
_EXAMPLE_CSS = """<style>
.my-container-question {
background-color: #F5EEF8;
padding: 10px;
border-radius: 10px;
height: auto;
}
.my-container-answer {
background-color: #F9EBEA;
padding: 10px;
border-radius: 10px;
height: auto;
}
.my-container-table {
background-color: #F2F3F4;
padding: 10px;
border-radius: 5px;
}
</style>"""


def _to_html(value):
    """Return ``str(value)`` HTML-escaped, with line breaks turned into ``<br>``.

    Escaping keeps markup-like text (e.g. ``<tag>`` inside a model prediction)
    visible instead of letting the browser swallow it. Replacing newlines stops
    a blank line in the text from terminating the HTML block that
    ``st.markdown`` is rendering.
    """
    return '<br>'.join(html.escape(str(value)).splitlines())


def _choices_to_text(choices):
    """Flatten an MCQ ``choices`` field into a single display string.

    A string is returned unchanged and a list/tuple is space-joined. ``None``
    becomes '' and anything else becomes ``str(choices)``, so every example
    gets its own value instead of inheriting the previous example's.
    """
    if choices is None:
        return ''
    if isinstance(choices, str):
        return choices
    if isinstance(choices, (list, tuple)):
        return ' '.join(str(choice) for choice in choices)
    return str(choices)


def _model_rows_html(row, model_names, dataset_name, choices_html):
    """Build the ``<tr>`` rows (model, question, prediction) for one example.

    Models with no usable ``row[model]['text']`` / ``['model_prediction']``
    entry are reported on stdout and skipped. When *choices_html* is not None
    (MCQ datasets), it is appended under each model's question text.
    """
    rows = []
    for model in model_names:
        try:
            entry = row[model]
            question = entry['text']
            prediction = entry['model_prediction']
        except (KeyError, TypeError):
            # KeyError: no such column; TypeError: the column value is null.
            print(f"{model} is not in {dataset_name}")
            continue
        if choices_html is None:
            question_cell = _to_html(question)
        else:
            question_cell = f'<p>{_to_html(question)}</p> <p>{choices_html}</p>'
        rows.append(
            f'<tr><td>{_to_html(model)}</td>'
            f'<td>{question_cell}</td>'
            f'<td>{_to_html(prediction)}</td></tr>'
        )
    return ''.join(rows)


def show_examples(category_name, dataset_name, model_lists):
    """Render every stored example of one dataset together with model predictions.

    Loads the dataset saved at ``./examples/<category_name>/<dataset_name>``
    with ``datasets.load_from_disk``. For each row ``i`` it shows:

    - the audio file ``sample_{i}.wav`` from the same folder;
    - the instruction text, plus the choices for MCQ datasets;
    - the reference answer;
    - a table with one row per model in *model_lists* that has an entry in
      the dataset.

    Args:
        category_name: Sub-folder of ``./examples`` that holds the dataset.
        dataset_name: Dataset folder name. Also decides whether MCQ choices
            are shown.
        model_lists: Model names, used as dataset column names. The list is
            not modified; models are displayed in sorted order.
    """
    st.divider()
    sample_folder = f"./examples/{category_name}/{dataset_name}"
    dataset = datasets.load_from_disk(sample_folder)
    is_mcq = dataset_name in _MCQ_DATASETS
    # Sort a copy: list.sort() would reorder the caller's list as a side effect.
    model_names = sorted(model_lists)

    st.markdown(_EXAMPLE_CSS, unsafe_allow_html=True)

    for index, row in enumerate(dataset):
        with st.container():
            st.markdown(f'##### EXAMPLE {index + 1}')
            col1, col2 = st.columns([0.3, 0.7], vertical_alignment="center")

            with col1:
                st.audio(f'{sample_folder}/sample_{index}.wav', format="audio/wav")

            # Computed fresh for every row so a previous example's choices never leak.
            choices_html = None
            if is_mcq:
                choices_html = _to_html(
                    _choices_to_text(row['other_attributes']['choices'])
                )

            with col2:
                question_html = f"<p>QUESTION: {_to_html(row['instruction']['text'])}</p>"
                if choices_html is not None:
                    question_html += f"<p>CHOICES: {choices_html}</p>"
                with st.container():
                    st.markdown(
                        f'<div class="my-container-question">{question_html}</div>',
                        unsafe_allow_html=True,
                    )

                with st.container():
                    answer_html = _to_html(row['answer']['text'])
                    st.markdown(
                        f'<div class="my-container-answer">'
                        f'<p>CORRECT ANSWER: {answer_html}</p></div>',
                        unsafe_allow_html=True,
                    )

            with st.container():
                rows_html = _model_rows_html(row, model_names, dataset_name, choices_html)
                # Header row in <thead>, per-model rows in <tbody>.
                table_html = (
                    '<table style="width:100%">'
                    '<thead><tr style="text-align: center;">'
                    '<th style="width:20%">MODEL</th>'
                    '<th style="width:40%">QUESTION</th>'
                    '<th style="width:40%">MODEL PREDICTION</th>'
                    '</tr></thead>'
                    f'<tbody>{rows_html}</tbody>'
                    '</table>'
                )
                st.markdown(
                    f'<div class="my-container-table">{table_html}</div>',
                    unsafe_allow_html=True,
                )

                st.text("")

        st.divider()