|
import gradio as gr |
|
from datasets import load_dataset |
|
from typing import List,Dict |
|
from gradio import ChatMessage |
|
|
|
# (dataset path, split) pairs to load; the first element is passed straight to
# `load_dataset` (presumably a Hugging Face Hub repo id — confirm).
DATASETS_LIST = [
    ("Self-GRIT/tuluv2-expanded-150k-part_v0-chat-format-syn-knowledge", "part_v0"),
]

# Every loaded split, keyed by the "<path>: <split>" label that the UI offers in
# its dataset dropdown. Loading happens eagerly at import time.
DATASETS = {
    f"{repo_id}: {split_name}": load_dataset(repo_id, split=split_name)
    for repo_id, split_name in DATASETS_LIST
}

# Number of rows in each loaded split, keyed by the same label as DATASETS.
LENS = {label: len(ds) for label, ds in DATASETS.items()}

# Column names holding the two conversations that are shown side by side.
KEY_MESSAGES_ORIGINAL = "messages_original"

KEY_MESSAGES_AUGMENTED = "messages_augmented"
|
|
|
def return_conversation_chat_message(history: List[Dict]):
    """Turn raw message dicts into gradio ``ChatMessage`` objects.

    Args:
        history: Sequence of mappings, each providing at least ``"role"`` and
            ``"content"``; any other keys are ignored.

    Returns:
        A new list of ``ChatMessage`` instances in the same order as *history*.
    """
    return [
        ChatMessage(role=turn["role"], content=turn["content"])
        for turn in history
    ]
|
|
|
def update_chatbot(value, dataset_name):
    """Fetch one example and render its original and augmented conversations.

    Args:
        value: Example index coming from the slider (Gradio may deliver it as a
            float, hence the ``int`` conversion). Out-of-range indices are
            clamped into ``[0, len(dataset) - 1]``; ``None`` is treated as 0.
        dataset_name: Key into ``DATASETS``. When empty/``None`` (no dropdown
            selection yet), the first loaded dataset is used instead of
            failing with ``KeyError``.

    Returns:
        Tuple ``(original_chat, augmented_chat, domain, example_id)``. The two
        chats are lists of ``ChatMessage``; the other two values are the
        example's ``"dataset"`` and ``"id"`` fields.
    """
    if not dataset_name:
        # The slider can fire before anything is picked in the dropdown.
        dataset_name = next(iter(DATASETS))

    dataset = DATASETS[dataset_name]

    # The slider's maximum is not guaranteed to match this dataset's length,
    # so keep the index inside the valid range.
    index = min(max(int(value or 0), 0), len(dataset) - 1)
    example = dataset[index]

    domain, example_id = example["dataset"], example["id"]

    original_chat_history = return_conversation_chat_message(example[KEY_MESSAGES_ORIGINAL])
    augmented_chat_history = return_conversation_chat_message(example[KEY_MESSAGES_AUGMENTED])

    return original_chat_history, augmented_chat_history, domain, example_id
|
|
|
def _slider_max(dataset_name):
    """Return the largest valid example index for *dataset_name* (0 if unknown or empty)."""
    return max(LENS.get(dataset_name, 0) - 1, 0)


def _on_dataset_change(dataset_name):
    """Rescale the example slider to the selected dataset and rewind it to index 0."""
    return gr.update(maximum=_slider_max(dataset_name), value=0)


# Pre-select a dataset so the slider never hands `None` to the update handler.
_DEFAULT_DATASET = next(iter(DATASETS), None)

with gr.Blocks() as demo:

    with gr.Group():
        with gr.Row():
            with gr.Column():
                dropdown = gr.Dropdown(
                    list(DATASETS),
                    value=_DEFAULT_DATASET,
                    label="Select the dataset",
                )
            with gr.Column():
                # Bound derived from the real row count instead of a hard-coded
                # constant, so every slider position maps to an existing example.
                slider = gr.Slider(
                    minimum=0,
                    maximum=_slider_max(_DEFAULT_DATASET),
                    step=1,
                    label="Select the example",
                )

    with gr.Group():
        with gr.Row():
            with gr.Column():
                # `label=` (not the first positional arg, which is the value)
                # so the box is titled "Domain" instead of containing it.
                domain_text = gr.Textbox(label="Domain", interactive=False)
            with gr.Column():
                id_text = gr.Textbox(label="ID", interactive=False)
        with gr.Row():
            gr.Markdown("Original Chat")
        with gr.Row():
            chatbot_original = gr.Chatbot(type="messages")

    with gr.Group():
        with gr.Row():
            gr.Markdown("Augmented Chat")
        with gr.Row():
            chatbot_augmented = gr.Chatbot(type="messages")

    _view_outputs = [chatbot_original, chatbot_augmented, domain_text, id_text]

    slider.change(fn=update_chatbot, inputs=[slider, dropdown], outputs=_view_outputs)

    # Switching datasets rescales the slider, then refreshes the view explicitly:
    # if the slider was already at 0, its own `change` event would not fire.
    dropdown.change(fn=_on_dataset_change, inputs=dropdown, outputs=slider).then(
        fn=update_chatbot, inputs=[slider, dropdown], outputs=_view_outputs
    )

    # Show the first example as soon as the page opens.
    demo.load(fn=update_chatbot, inputs=[slider, dropdown], outputs=_view_outputs)

demo.launch()