File size: 3,709 Bytes
92094c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32169e6
 
 
 
 
 
92094c8
32169e6
92094c8
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from instruct_pipeline import InstructionTextGenerationPipeline
import json

# Hugging Face Hub identifier of the checkpoint served by this demo. Kept in
# its own constant so that `model` only ever refers to the loaded network.
MODEL_NAME = "databricks/dolly-v2-3b"

# Resolve the target device once and reuse it below instead of repeating the
# CUDA-availability check.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # Reuse the end-of-sequence token id as the padding token id.
    pad_token_id=tokenizer.eos_token_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# NOTE(review): the weights are loaded with device_map="auto" and then moved
# explicitly as well; both steps are kept from the original code. Confirm
# whether the explicit move is still required.
model = model.to(device)
model.eval()

# Callable used by the request handlers to turn an instruction into text.
generate_text = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer)

# JavaScript snippet handed to `blocks.load(..., _js=...)`: it appends a
# <script type="module"> element pointing at "file=index.js" to the page head.
load_js = """
async () => {
  const script = document.createElement('script');
  script.type = "module"
  script.src = "file=index.js"
  document.head.appendChild(script);
}
"""

# Markup placed in the `gr.HTML` ratings component. The custom element is
# presumably registered by index.js (verify against that script), so the tag
# name must stay exactly as written.
canvas_html = "<chat-feeback style='display:flex;justify-content:center;'></chat-feeback>"


def accept_response(rating_dummy, msg, chatbot, responseA, responseB, selection_state):
    """Record an A/B rating and show the selected response in the chat.

    Args:
        rating_dummy: JSON text containing at least the keys ``"label"`` and
            ``"value"``. An empty or whitespace-only string is treated as
            "no rating" and leaves everything unchanged.
        msg: The user message the two responses were generated for.
        chatbot: Current chatbot value; returned as-is when there is no
            rating to process.
        responseA: Candidate text chosen when the label is ``"A"``.
        responseB: Candidate text chosen for any other label.
        selection_state: List of previously recorded
            ``[label, value, text]`` rows. It is not mutated.

    Returns:
        A ``(chatbot, selection_state, selection_state)`` tuple: the new
        chatbot value, the extended selection list, and that same list again
        for display.

    Raises:
        json.JSONDecodeError: If a non-empty payload is not valid JSON.
        KeyError: If the payload lacks ``"label"`` or ``"value"``.
    """
    # An empty payload (the textbox's initial/cleared value) carries no
    # rating; parsing it would raise, so hand back the inputs untouched.
    if not rating_dummy or not rating_dummy.strip():
        return chatbot, selection_state, selection_state

    ratings = json.loads(rating_dummy)
    label = ratings["label"]
    chosen_text = responseA if label == "A" else responseB
    row = [label, ratings["value"], chosen_text]

    # Build a new list rather than appending in place so the caller's list
    # object is left unmodified.
    selection_state = selection_state + [row]

    # NOTE(review): this replaces the whole chat history with the latest
    # exchange instead of extending `chatbot`; kept as in the original --
    # confirm that dropping earlier turns is intended.
    chatbot = [[msg, chosen_text]]

    return chatbot, selection_state, selection_state


def generate(msg, history):
    """Produce two sampled candidate responses for ``msg``.

    The pipeline is invoked twice with sampling enabled: first with
    ``top_p=0.9`` and then with ``top_p=0.5``. Each call is limited to 50 new
    tokens and uses ``top_k=500``.

    Args:
        msg: The user's prompt.
        history: Accepted for interface compatibility; not used.

    Returns:
        A two-element list with the generated text of each call, in call
        order.
    """
    return [
        generate_text(
            msg,
            max_new_tokens=50,
            top_p=nucleus_p,
            top_k=500,
            do_sample=True,
        )[0]["generated_text"]
        for nucleus_p in (0.9, 0.5)
    ]


# Build the UI: a chatbot view, a prompt box, two candidate-response boxes, a
# hidden rating channel, and a table of the selections made so far. The event
# wiring below drives the generate -> rate -> accept cycle.
with gr.Blocks() as blocks:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    responseA = gr.Textbox(label="Response A")
    responseB = gr.Textbox(label="Response B")
    # Hidden, non-interactive textbox whose value is parsed as JSON by
    # `accept_response`. Presumably written by the front-end script via its
    # elem_id -- TODO confirm against index.js.
    rating_dummy = gr.Textbox(elem_id="rating-dummy", interactive=False, visible=False)
    # Holds the rating widget markup; starts hidden and is revealed once both
    # candidate responses have been generated.
    ratings_buttons = gr.HTML(canvas_html, visible=False)

    # Per-session list of recorded selections, mirrored into the dataframe.
    selection_state = gr.State(value=[])
    selections_df = gr.Dataframe(
        type="array",
        headers=["Option", "Value", "Text"],
        label="Selections",
    )

    # NOTE(review): `user` and `bot` are defined here but are not attached to
    # any event in this block.
    def user(user_message, history):
        """Clear and lock the input box and append the message to the history.

        Returns an update for the textbox and the history extended with a
        ``[user_message, ""]`` pair.
        """
        return gr.update(value="", interactive=False), history + [[user_message, ""]]

    def bot(history):
        """Generate a reply for the last user message and yield the new history.

        The generated text is appended to the reply slot of the last history
        entry; the updated history is yielded once.
        """
        user_message = history[-1][0]
        res = generate_text(
            user_message, max_new_tokens=50, top_p=0.9, top_k=500, do_sample=True
        )
        print(res)
        chat_history = history[-1][1] + res[0]["generated_text"]
        new_history = history[:-1] + [[user_message, chat_history]]
        yield new_history

    # On submit: lock the prompt box and hide the rating widget, generate the
    # two candidates into responseA/responseB, then show the rating widget.
    response = (
        msg.submit(
            lambda: (gr.update(interactive=False), gr.update(visible=False)),
            inputs=None,
            outputs=[msg, ratings_buttons],
        )
        .then(generate, inputs=[msg, chatbot], outputs=[responseA, responseB])
        .then(lambda: gr.update(visible=True), inputs=None, outputs=[ratings_buttons])
    )
    # When the hidden rating value changes: record the selection, then reset
    # the prompt box (cleared and unlocked), hide the rating widget, and clear
    # both candidate boxes.
    rating_dummy.change(
        accept_response,
        inputs=[rating_dummy, msg, chatbot, responseA, responseB, selection_state],
        outputs=[chatbot, selection_state, selections_df],
    ).then(
        lambda: (
            gr.update(value="", interactive=True),
            gr.update(visible=False),
            gr.update(value=""),
            gr.update(value=""),
        ),
        inputs=None,
        outputs=[msg, ratings_buttons, responseA, responseB],
    )
    # Final step of the submit chain: unlock the prompt box again.
    response.then(
        lambda: gr.update(interactive=True), inputs=None, outputs=[msg], queue=False
    )
    # Run the `load_js` snippet in the browser when the page loads.
    blocks.load(None, None, None, _js=load_js)

# Enable the request queue, then start serving the app.
blocks.queue().launch()