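"""Gradio Space: A/B preference rating for databricks/dolly-v2-3b.

Each submitted prompt is answered twice with different sampling settings;
a custom HTML widget (defined in index.js, not included here) lets the
user pick the preferred response, and the picks accumulate in a dataframe.
"""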
import json

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# instruct_pipeline.py ships with the databricks/dolly-v2-3b model repo
# on the Hugging Face Hub.
from instruct_pipeline import InstructionTextGenerationPipeline
| model = "databricks/dolly-v2-3b" | |
| tokenizer = AutoTokenizer.from_pretrained(model, padding_side="left") | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model, | |
| pad_token_id=tokenizer.eos_token_id, | |
| device_map="auto", | |
| torch_dtype=torch.bfloat16, | |
| ) | |
| model = model.to("cuda" if torch.cuda.is_available() else "cpu") | |
| model.eval() | |
| generate_text = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer) | |
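# generate_text follows the transformers pipeline convention: it returns a
# list of dicts, so res[0]["generated_text"] holds the reply (used below).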
# Placeholder for the rating widget; index.js presumably registers the
# custom element (note the "chat-feeback" spelling, kept to match the JS).
canvas_html = (
    "<chat-feeback style='display:flex;justify-content:center;'></chat-feeback>"
)

# Run once on page load to pull index.js in as an ES module.
load_js = """
async () => {
    const script = document.createElement('script');
    script.type = "module";
    script.src = "file=index.js";
    document.head.appendChild(script);
}
"""
def accept_response(rating_dummy, msg, chatbot, responseA, responseB, selection_state):
    """Record which response the user picked and show it in the chat."""
    ratings = json.loads(rating_dummy)  # e.g. {"label": "A", "value": 1}
    state = [
        ratings["label"],
        ratings["value"],
        responseA if ratings["label"] == "A" else responseB,
    ]
    selection_state = selection_state + [state]
    # Append the chosen exchange so the conversation accumulates.
    chatbot = chatbot + [[msg, state[2]]]
    # selection_state is returned twice: for the State component and for
    # the dataframe that displays the accumulated picks.
    return chatbot, selection_state, selection_state
def generate(msg, history):
    """Generate two candidate responses, varying top_p so A and B differ."""
    user_message = msg  # history is accepted to match the event wiring but unused
    responses = []
    for i in range(2):
        res = generate_text(
            user_message,
            max_new_tokens=50,
            top_p=0.9 if i == 0 else 0.5,
            top_k=500,
            do_sample=True,
        )
        responses.append(res[0]["generated_text"])
    # The two strings map onto the responseA and responseB textboxes.
    return responses
with gr.Blocks() as blocks:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    responseA = gr.Textbox(label="Response A")
    responseB = gr.Textbox(label="Response B")
    # Hidden textbox that index.js fills with the rating JSON.
    rating_dummy = gr.Textbox(elem_id="rating-dummy", interactive=False, visible=False)
    ratings_buttons = gr.HTML(canvas_html, visible=False)
    selection_state = gr.State(value=[])
    selections_df = gr.Dataframe(
        type="array",
        headers=["Option", "Value", "Text"],
        label="Selections",
    )
    # Note: user() and bot() are not wired to any event below; they appear
    # to be leftovers from a streaming, single-response variant of the app.
    def user(user_message, history):
        return gr.update(value="", interactive=False), history + [[user_message, ""]]

    def bot(history):
        user_message = history[-1][0]
        res = generate_text(
            user_message, max_new_tokens=50, top_p=0.9, top_k=500, do_sample=True
        )
        print(res)
        chat_history = history[-1][1] + res[0]["generated_text"]
        new_history = history[:-1] + [[user_message, chat_history]]
        yield new_history
    # On submit: lock the input and hide any stale rating buttons, generate
    # the two candidate responses, then reveal the rating buttons.
    response = (
        msg.submit(
            lambda: (gr.update(interactive=False), gr.update(visible=False)),
            inputs=None,
            outputs=[msg, ratings_buttons],
        )
        .then(generate, inputs=[msg, chatbot], outputs=[responseA, responseB])
        .then(lambda: gr.update(visible=True), inputs=None, outputs=[ratings_buttons])
    )
    # When index.js writes a rating: record the pick, then clear the input
    # and candidate boxes and hide the buttons for the next round.
    rating_dummy.change(
        accept_response,
        inputs=[rating_dummy, msg, chatbot, responseA, responseB, selection_state],
        outputs=[chatbot, selection_state, selections_df],
    ).then(
        lambda: (
            gr.update(value="", interactive=True),
            gr.update(visible=False),
            gr.update(value=""),
            gr.update(value=""),
        ),
        inputs=None,
        outputs=[msg, ratings_buttons, responseA, responseB],
    )
    # Unlock the textbox as soon as generation finishes.
    response.then(
        lambda: gr.update(interactive=True), inputs=None, outputs=[msg], queue=False
    )

    # Inject index.js once the page loads (_js is the Gradio 3.x keyword).
    blocks.load(None, None, None, _js=load_js)
blocks.queue()
blocks.launch()
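
# Assumed runtime dependencies (not declared in this snippet): gradio,
# torch, transformers, and accelerate (needed for device_map="auto");
# index.js and instruct_pipeline.py are expected next to this script.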