import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from instruct_pipeline import InstructionTextGenerationPipeline
import json
model = "databricks/dolly-v2-3b"
tokenizer = AutoTokenizer.from_pretrained(model, padding_side="left")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(
model,
pad_token_id=tokenizer.eos_token_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
generate_text = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer)
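# Hidden HTML placeholder for the rating widget; the <chat-feeback> tag is the
# custom element registered by index.js (the tag name must match that definition).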
canvas_html = (
"<chat-feeback style='display:flex;justify-content:center;'></chat-feeback>"
)
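# JS run once on page load: injects index.js as an ES module so the custom
# rating element above gets defined in the browser.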
load_js = """
async () => {
const script = document.createElement('script');
  script.type = "module";
  script.src = "file=index.js";
document.head.appendChild(script);
}
"""
def accept_response(rating_dummy, msg, chatbot, responseA, responseB, selection_state):
ratings = json.loads(rating_dummy)
state = [
ratings["label"],
ratings["value"],
responseA if ratings["label"] == "A" else responseB,
]
selection_state = selection_state + [state]
chatbot = [[msg, state[2]]]
return chatbot, selection_state, selection_state
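# Produce two candidate responses for the same prompt, sampled with different
# top_p values (0.9 vs 0.5) so the pair is unlikely to be identical.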
def generate(msg, history):
user_message = msg
responses = []
for i in range(2):
res = generate_text(
user_message,
max_new_tokens=50,
top_p=0.9 if i == 0 else 0.5,
top_k=500,
do_sample=True,
)
responses.append(res[0]["generated_text"])
return responses
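# UI layout: chat window, message box, the two candidate responses, a hidden
# textbox that index.js writes ratings into, and a table of past selections.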
with gr.Blocks() as blocks:
chatbot = gr.Chatbot()
msg = gr.Textbox()
responseA = gr.Textbox(label="Response A")
responseB = gr.Textbox(label="Response B")
rating_dummy = gr.Textbox(elem_id="rating-dummy", interactive=False, visible=False)
ratings_buttons = gr.HTML(canvas_html, visible=False)
selection_state = gr.State(value=[])
selections_df = gr.Dataframe(
type="array",
headers=["Option", "Value", "Text"],
label="Selections",
)
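    # Note: user() and bot() below are single-response helpers that are not
    # wired to any event; the submit chain further down uses generate() instead.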
def user(user_message, history):
return gr.update(value="", interactive=False), history + [[user_message, ""]]
def bot(history):
user_message = history[-1][0]
res = generate_text(
user_message, max_new_tokens=50, top_p=0.9, top_k=500, do_sample=True
)
print(res)
chat_history = history[-1][1] + res[0]["generated_text"]
new_history = history[:-1] + [[user_message, chat_history]]
yield new_history
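    # Submit chain: lock the input and hide the rating buttons, generate both
    # candidates, then reveal the buttons so the user can pick one.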
response = (
msg.submit(
lambda: (gr.update(interactive=False), gr.update(visible=False)),
inputs=None,
outputs=[msg, ratings_buttons],
)
.then(generate, inputs=[msg, chatbot], outputs=[responseA, responseB])
.then(lambda: gr.update(visible=True), inputs=None, outputs=[ratings_buttons])
)
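    # Rating chain: record the selection, then clear and re-enable the UI.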
rating_dummy.change(
accept_response,
inputs=[rating_dummy, msg, chatbot, responseA, responseB, selection_state],
outputs=[chatbot, selection_state, selections_df],
).then(
lambda: (
gr.update(value="", interactive=True),
gr.update(visible=False),
gr.update(value=""),
gr.update(value=""),
),
inputs=None,
outputs=[msg, ratings_buttons, responseA, responseB],
)
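    # Once the generation chain finishes, unlock the message box again.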
response.then(
lambda: gr.update(interactive=True), inputs=None, outputs=[msg], queue=False
)
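    # Load index.js when the page opens (Gradio 3.x takes the _js argument;
    # later versions renamed it to js).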
blocks.load(None, None, None, _js=load_js)
blocks.queue()
blocks.launch()