import json

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from instruct_pipeline import InstructionTextGenerationPipeline

# Load the Dolly v2 3B instruction-tuned model. Left padding is required for
# decoder-only generation. device_map="auto" lets accelerate place the weights,
# so no manual .to(device) call is needed (and calling .to() on a dispatched
# model would raise an error).
model_name = "databricks/dolly-v2-3b"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    pad_token_id=tokenizer.eos_token_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()
generate_text = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer)

# Container for the A/B rating buttons. The markup is missing from this copy
# of the script; index.js (injected on page load below) is expected to wire
# the buttons up to the hidden #rating-dummy textbox.
canvas_html = ""
load_js = """
async () => {
    const script = document.createElement('script');
    script.type = "module";
    script.src = "file=index.js";
    document.head.appendChild(script);
}
"""


def accept_response(rating_dummy, msg, chatbot, responseA, responseB, selection_state):
    """Record the user's A/B choice and show the chosen response in the chat."""
    ratings = json.loads(rating_dummy)
    state = [
        ratings["label"],
        ratings["value"],
        responseA if ratings["label"] == "A" else responseB,
    ]
    selection_state = selection_state + [state]
    chatbot = [[msg, state[2]]]
    return chatbot, selection_state, selection_state


def generate(msg, history):
    """Sample two candidate completions, varying top_p to diversify them."""
    user_message = msg
    responses = []
    for i in range(2):
        res = generate_text(
            user_message,
            max_new_tokens=50,
            top_p=0.9 if i == 0 else 0.5,
            top_k=500,
            do_sample=True,
        )
        responses.append(res[0]["generated_text"])
    return responses


with gr.Blocks() as blocks:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    responseA = gr.Textbox(label="Response A")
    responseB = gr.Textbox(label="Response B")
    # Hidden textbox that index.js writes the rating JSON into; its change
    # event drives accept_response below.
    rating_dummy = gr.Textbox(elem_id="rating-dummy", interactive=False, visible=False)
    ratings_buttons = gr.HTML(canvas_html, visible=False)
    selection_state = gr.State(value=[])
    selections_df = gr.Dataframe(
        type="array",
        headers=["Option", "Value", "Text"],
        label="Selections",
    )

    # Helpers for a streaming single-response chat; defined but not wired
    # into the event chain below.
    def user(user_message, history):
        return gr.update(value="", interactive=False), history + [[user_message, ""]]

    def bot(history):
        user_message = history[-1][0]
        res = generate_text(
            user_message, max_new_tokens=50, top_p=0.9, top_k=500, do_sample=True
        )
        print(res)
        chat_history = history[-1][1] + res[0]["generated_text"]
        new_history = history[:-1] + [[user_message, chat_history]]
        yield new_history

    # On submit: lock the input and hide any stale buttons, generate both
    # candidates, then reveal the rating buttons.
    response = (
        msg.submit(
            lambda: (gr.update(interactive=False), gr.update(visible=False)),
            inputs=None,
            outputs=[msg, ratings_buttons],
        )
        .then(generate, inputs=[msg, chatbot], outputs=[responseA, responseB])
        .then(lambda: gr.update(visible=True), inputs=None, outputs=[ratings_buttons])
    )

    # When a rating lands in the hidden textbox: log the selection, then reset
    # the input, hide the buttons, and clear both candidate boxes.
    rating_dummy.change(
        accept_response,
        inputs=[rating_dummy, msg, chatbot, responseA, responseB, selection_state],
        outputs=[chatbot, selection_state, selections_df],
    ).then(
        lambda: (
            gr.update(value="", interactive=True),
            gr.update(visible=False),
            gr.update(value=""),
            gr.update(value=""),
        ),
        inputs=None,
        outputs=[msg, ratings_buttons, responseA, responseB],
    )

    response.then(
        lambda: gr.update(interactive=True), inputs=None, outputs=[msg], queue=False
    )

# Inject index.js on page load so the rating buttons can talk to Gradio.
blocks.load(None, None, None, _js=load_js)
blocks.queue()
blocks.launch()
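
# --- Hypothetical companion pieces (not part of the original source) --------
# canvas_html arrives empty in this copy, and index.js is referenced but not
# shown. A minimal sketch of what they would need to contain, given that
# accept_response expects JSON like {"label": "A", "value": 1} to appear in
# the hidden #rating-dummy textbox. Names, markup, and the "value" payload
# here are illustrative assumptions, not the original files:
#
#   canvas_html = """
#   <div id="rating-buttons">
#     <button data-label="A" data-value="1">Prefer A</button>
#     <button data-label="B" data-value="1">Prefer B</button>
#   </div>
#   """
#
#   // index.js (sketch): forward button clicks into the hidden Gradio
#   // textbox and dispatch an input event so rating_dummy.change fires.
#   document.addEventListener("click", (e) => {
#     const btn = e.target.closest("#rating-buttons button");
#     if (!btn) return;
#     const payload = JSON.stringify({
#       label: btn.dataset.label,
#       value: Number(btn.dataset.value),
#     });
#     const box = document.querySelector("#rating-dummy textarea");
#     box.value = payload;
#     box.dispatchEvent(new Event("input", { bubbles: true }));
#   });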