Spaces:
Runtime error
Runtime error
File size: 4,310 Bytes
cb5c850 df0a622 cb5c850 df0a622 cb5c850 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import datetime
import hashlib
import json
import os
import random
from io import BytesIO
import gradio as gr
from datasets import load_dataset
from huggingface_hub import upload_file
# Load the dialog-generation dataset and keep only its "validation" split.
# NOTE: os.environ["EB_TOKEN"] raises KeyError at import time when the secret
# is not configured.
# NOTE(review): newer `datasets` releases rename `use_auth_token` to `token`;
# confirm against the version this Space pins before changing it.
dataset = load_dataset(
    path="edbeeching/rlhf_dialog_experiment_cosmo_dialog_generation",
    use_auth_token=os.environ["EB_TOKEN"],
)["validation"]
def sample_to_markdown(sample, index):
    """Render one candidate conversation of ``sample`` as an HTML snippet.

    Turns alternate between left-aligned plain text (even positions) and
    right-aligned bold text (odd positions).  A centred
    "START OF DIALOG GENERATION" marker is emitted immediately before the
    turn whose position equals ``sample["trucation_length"]``.

    Args:
        sample: dataset row with a ``"conversations"`` entry (indexable by
            ``index``, each element an iterable of turns) and a
            ``"trucation_length"`` int.  The misspelled key comes from the
            dataset schema and must stay as-is.
        index: which candidate conversation to render (the UI uses 0 and 1).

    Returns:
        The concatenated HTML for every turn, or ``""`` for an empty
        conversation.
    """
    conversation = sample["conversations"][index]
    truncation_length = sample["trucation_length"]

    aligns = ("left", "right")
    # (opening, closing) emphasis tags per speaker; odd turns are bold.
    emphasis = (("", ""), ("<strong>", "</strong>"))

    # Collect pieces and join once instead of repeated string `+=`.
    parts = []
    for i, turn in enumerate(conversation):
        if i == truncation_length:
            parts.append(
                '<p style="text-align:center"> --- START OF DIALOG GENERATION --- </p><br>'
            )
        align = aligns[i % 2]
        open_tag, close_tag = emphasis[i % 2]
        parts.append(
            f'<div style="text-align: {align}"> {open_tag}{turn}{close_tag} \n </div> <br>'
        )
    return "".join(parts)
# Module-level holder for a sample; assigned when the UI is built.
sample = None

# Sampler backed by OS entropy: independent of the shared `random` module
# state, so no other code can seed or be disturbed by it.
_SAMPLE_RNG = random.SystemRandom()


def get_sample(data=None):
    """Return one uniformly random row.

    Args:
        data: optional indexable, sized collection to draw from.  Defaults to
            the module-level ``dataset``.  Existing zero-argument callers are
            unaffected.

    Returns:
        The randomly selected row.

    Raises:
        ValueError: if the source is empty.
    """
    # The previous implementation re-seeded the global `random` module from
    # hash(<timestamp with seconds resolution>) on every call. Every call
    # within the same second therefore returned the same row. It also
    # clobbered the RNG for all other users of `random` and relied on the
    # glibc-only "%s" strftime code.
    source = dataset if data is None else data
    size = len(source)
    if size == 0:
        raise ValueError("cannot draw a sample from an empty dataset")
    sample_index = _SAMPLE_RNG.randrange(size)
    print(f"sampled index {sample_index} of {size}")
    return source[sample_index]
def check_and_submit_preferences(sample, preferred_text, text_quality):
    """Upload one preference judgement as JSON bytes to the results dataset repo.

    Nothing is uploaded if either choice was left unselected (``None``).

    Args:
        sample: the dataset row that was rated.  It must have a ``"situation"``
            string and be JSON-serialisable.
        preferred_text: label of the preferred dialog, or ``None``.
        text_quality: quality label for the preferred dialog, or ``None``.

    Returns:
        True when the judgement was uploaded.  False when it was skipped
        because a choice was missing.  Errors from ``upload_file`` propagate.
    """
    if preferred_text is None:
        print("not submitted due to unselected preferred text")
        return False
    if text_quality is None:
        print("not submitted due to unselected text_quality")
        return False

    # A single timestamp keeps the stored date_time and the file name in sync.
    now = datetime.datetime.now()
    data = {
        "sample": sample,
        "preferred_text": preferred_text,
        "text_quality": text_quality,
        "date_time": now.strftime("%m/%d/%Y %H:%M:%S"),
        # add other info like user etc?
    }
    task_hash = hashlib.md5(sample["situation"].encode())
    # "%S" is the portable seconds code; the previous "%s" is a glibc-only
    # epoch-seconds extension that produced names like "..._1234<epoch>".
    # The microseconds ("%f") let two ratings of the same situation within
    # one second map to different paths instead of the same one.
    time_now = now.strftime("%Y%m%d_%H%M%S_%f")
    task_directory = f"{time_now}_{task_hash.hexdigest()}"
    upload_file(
        path_or_fileobj=BytesIO(bytes(json.dumps(data), 'utf-8')),
        path_in_repo=task_directory,
        repo_id='edbeeching/rlhf_dialog_experiment_dataset',
        repo_type='dataset',
        token=os.environ['EB_TOKEN'],
    )
    return True
with gr.Blocks() as demo:
    gr.Markdown(
        """
This Space is an experiment to model human preferences on dialog generated with the [Cosmo-XL](https://huggingface.co/allenai/cosmo-xl) model, prompted with parts of conversations from the [SODA](https://huggingface.co/datasets/allenai/soda) dataset.
The following conversation was created with the following prompt:
"""
    )
    # Sample rendered when the page is first built.  It is still assigned to
    # the module-level name `sample` for compatibility, but it is no longer
    # used to decide which sample a rating belongs to.
    sample = get_sample()
    # Per-session record of the sample currently on screen.  Previously a
    # module-level global held it, shared by every connected user.  A click
    # from one user could then be recorded against whatever sample another
    # user's last click had loaded.
    current_sample = gr.State(sample)

    with gr.Column() as details_col:
        summary = gr.Markdown(f"## {sample['situation']}", label='Description')
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    dialog1 = gr.Markdown(sample_to_markdown(sample, 0), label='Dialog 1')
            with gr.Column():
                with gr.Box():
                    dialog2 = gr.Markdown(sample_to_markdown(sample, 1), label='Dialog 2')
        with gr.Column():
            dialog_choice = gr.Radio(["Left dialog", "Right dialog"], label="Preferred text", interactive=True)
            quality_of_dialog = gr.Radio(["Terrible", "Poor", "Ok", "Good", "Excellent"], label="Quality of preferred text", interactive=True)
            next_button = gr.Button("Submit")

    def on_next(shown_sample, preferred_text, text_quality):
        """Submit the rating for this session's sample and show a new one.

        Args:
            shown_sample: the sample this session was looking at (from State).
            preferred_text: selected "Preferred text" radio value, or None.
            text_quality: selected "Quality of preferred text" value, or None.

        Returns:
            Tuple with, in order:
            - the new sample (stored back into the session State);
            - updates for the summary and both dialogs;
            - resets for the two radio buttons.
        """
        check_and_submit_preferences(shown_sample, preferred_text, text_quality)
        next_sample = get_sample()
        return (
            next_sample,
            gr.Markdown.update(f"## {next_sample['situation']}"),
            gr.Markdown.update(sample_to_markdown(next_sample, 0)),
            gr.Markdown.update(sample_to_markdown(next_sample, 1)),
            gr.Radio.update(value=None),
            gr.Radio.update(value=None),
        )

    next_button.click(
        on_next,
        inputs=[current_sample, dialog_choice, quality_of_dialog],
        outputs=[current_sample, summary, dialog1, dialog2, dialog_choice, quality_of_dialog],
    )
def _serve():
    """Start the Gradio server for the ``demo`` app built above."""
    demo.launch()


if __name__ == "__main__":
    # Launch only when executed as a script, never on plain import.
    _serve()