| import gradio as gr |
| from dotenv import load_dotenv |
|
|
| from models import get_all_models, get_random_models |
|
|
# Load KEY=VALUE pairs from a local .env file into os.environ.
# NOTE(review): `models` is imported above, before this call runs; if that
# module reads environment variables at import time it will not see values
# from .env — confirm, or call load_dotenv() before importing it.
load_dotenv()
|
|
|
|
# Browser-side handler for the "Share Image" button. It renders the element
# with id "share-region-annoy" to a PNG with html2canvas and triggers a
# download. html2canvas must already be present as a page global; it is not
# loaded anywhere in the code shown here, so the handler checks for it (and
# for the target element) and logs to the console instead of throwing.
# NOTE(review): the download name 'guardrails-arena.png' looks inherited from
# another project.
share_js = """
function () {
    const captureElement = document.querySelector('#share-region-annoy');
    if (!captureElement) {
        console.error('share: #share-region-annoy not found');
        return [];
    }
    if (typeof html2canvas === 'undefined') {
        console.error('share: html2canvas is not loaded on this page');
        return [];
    }
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'guardrails-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        })
        .catch(err => console.error('share: failed to render image', err));
    return [];
}
"""
|
|
|
|
def activate_chat_buttons():
    """Enable the Regenerate and New Round buttons.

    Returns:
        A ``(regenerate_btn, clear_btn)`` pair of component instances with
        ``interactive=True``. Gradio applies them as property updates to the
        two buttons wired as this event's outputs.
    """
    return (
        gr.Button(value="🔄 Regenerate", interactive=True, elem_id="regenerate_btn"),
        gr.ClearButton(elem_id="clear_btn", interactive=True),
    )
|
|
|
|
def deactivate_chat_buttons():
    """Disable the Regenerate and New Round buttons.

    Returns:
        A ``(regenerate_btn, clear_btn)`` pair of component instances with
        ``interactive=False``. Gradio applies them as property updates to the
        two buttons wired as this event's outputs.
    """
    return (
        gr.Button(value="🔄 Regenerate", interactive=False, elem_id="regenerate_btn"),
        gr.ClearButton(elem_id="clear_btn", interactive=False),
    )
|
|
|
|
def handle_message(
    llms, user_input, temperature, top_p, max_output_tokens, states1, states2, states3, states4
):
    """Send ``user_input`` to every pane's model and stream the replies.

    Args:
        llms: Per-pane model descriptors, passed through to ``process_responses``.
        user_input: Text from the input box.
        temperature, top_p, max_output_tokens: Sampling settings passed through
            to each model.
        states1, states2, states3, states4: Per-pane state values. Each is either
            falsy (no conversation yet) or an object whose ``.value`` is that
            pane's history list; that list is extended in place.

    Yields:
        8-tuples ``(history1, ..., history4, state1, ..., state4)`` matching the
        event's outputs. Blank or whitespace-only input yields the current panes
        once, unchanged, without calling any model.
    """
    states = [states1, states2, states3, states4]
    history = [state.value if state else [] for state in states]

    if not user_input or not user_input.strip():
        # Re-emit the panes as they are rather than sending an empty prompt to
        # every model and recording an empty turn.
        yield (*history, *states)
        return

    # Open a new turn in every pane; process_responses fills in the replies.
    for hist in history:
        hist.append((user_input, None))
    yield from process_responses(
        llms, temperature, top_p, max_output_tokens, history, states
    )
|
|
|
|
def regenerate_message(llms, temperature, top_p, max_output_tokens, states1, states2, states3, states4):
    """Discard the last exchange in every pane and ask every model again.

    The prompt to resend is the user message of pane 1's last exchange.

    Args:
        llms: Per-pane model descriptors, passed through to ``process_responses``.
        temperature, top_p, max_output_tokens: Sampling settings passed through
            to each model.
        states1, states2, states3, states4: Per-pane state values. Each is either
            falsy (no conversation yet) or an object whose ``.value`` is that
            pane's history list; that list is modified in place.

    Yields:
        8-tuples ``(history1, ..., history4, state1, ..., state4)`` matching the
        event's outputs. If pane 1 has no history there is nothing to
        regenerate, so the panes are yielded once, unchanged.
    """
    states = [states1, states2, states3, states4]
    history = [state.value if state else [] for state in states]

    if not history[0]:
        # Previously this resent a ``None`` prompt to every model.
        yield (*history, *states)
        return

    user_input = history[0].pop()[0]
    for hist in history[1:]:
        if hist:
            hist.pop()

    # Re-open the turn in every pane; process_responses fills in the replies.
    for hist in history:
        hist.append((user_input, None))
    yield from process_responses(
        llms, temperature, top_p, max_output_tokens, history, states
    )
|
|
|
|
def process_responses(llms, temperature, top_p, max_output_tokens, history, states):
    """Stream replies from one model per pane, round-robin, one chunk at a time.

    Args:
        llms: Sequence of model descriptors. ``llms[i]["model"]`` is called as
            ``model(history[i], temperature, top_p, max_output_tokens)`` and
            must return an iterator of text chunks.
        temperature, top_p, max_output_tokens: Sampling settings passed through.
        history: One chat history per pane (list of ``(user, bot)`` tuples).
            The last entry of each is rewritten in place as chunks arrive.
        states: One slot per pane, same length as ``history``. Each slot is
            replaced in place with a ``gr.State`` wrapping the pane's history.

    Yields:
        ``(*history, *states)`` after every non-empty chunk, and once more after
        every model has finished.

    A model whose iterator raises is logged and stopped; an error note is
    appended to its reply so the other panes keep streaming.
    """
    import logging  # local import: only the error path below needs it

    num_panes = len(history)
    generators = [
        llms[i]["model"](history[i], temperature, top_p, max_output_tokens)
        for i in range(num_panes)
    ]
    responses = [[] for _ in range(num_panes)]
    done = [False] * num_panes

    # Wrap every pane up front. Otherwise a pane whose state was still empty
    # and whose model never produces a chunk would not persist this turn.
    for i in range(num_panes):
        states[i] = gr.State(history[i])

    while not all(done):
        for i, generator in enumerate(generators):
            if done[i]:
                continue
            try:
                chunk = next(generator)
            except StopIteration:
                done[i] = True
                continue
            except Exception as err:
                # Isolate failures: one broken model must not abort the others.
                logging.getLogger(__name__).exception(
                    "Model in pane %d failed while streaming", i
                )
                done[i] = True
                note = f"[{type(err).__name__}: model failed to respond]"
                chunk = f"\n\n{note}" if responses[i] else note
            if chunk:
                responses[i].append(chunk)
                history[i][-1] = (history[i][-1][0], "".join(responses[i]))
                states[i] = gr.State(history[i])
                yield (*history, *states)
    yield (*history, *states)
|
|
|
|
# UI: four chat panes side by side, driven by one shared input box. Each pane
# keeps its own conversation state and talks to its own model.
with gr.Blocks(
    title="Cherokee Language Function Test",
    theme=gr.themes.Soft(
        secondary_hue=gr.themes.colors.sky,
        neutral_hue=gr.themes.colors.stone,
    ),
) as demo:
    # handle_message / regenerate_message take exactly four state inputs, so
    # this must stay 4.
    num_sides = 4
    # Per-pane conversation state. It starts as None; the handlers store an
    # object here whose `.value` is the pane's history.
    states = [gr.State() for _ in range(num_sides)]
    chatbots = [None] * num_sides
    # A callable initial value makes Gradio call get_random_models when a page
    # loads (see Gradio's State docs); `models.value` is the draw made here at
    # build time.
    # NOTE(review): the pane labels below come from that build-time draw, so on
    # a fresh page load they may not name the models the session actually uses
    # unless get_random_models is deterministic — confirm.
    models = gr.State(get_random_models)
    # NOTE(review): the result is unused here; the call is kept in case it has
    # side effects.
    all_models = get_all_models()
    gr.Markdown(
        "# Cherokee Language Preserve Model V0.4 \n\nChat with multiple models at the same time and compare their responses. "
    )

    # This group is the region share_js captures ('#share-region-annoy').
    with gr.Group(elem_id="share-region-annoy"):
        with gr.Row():
            for i in range(num_sides):
                with gr.Column(scale=1, min_width=200):
                    chatbots[i] = gr.Chatbot(
                        label=models.value[i]["name"],
                        # One id per pane: four elements sharing one DOM id
                        # is invalid HTML.
                        elem_id=f"chatbot_{i}",
                        height=300,
                        show_copy_button=True,
                    )

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="Enter your query and press ENTER",
            elem_id="input_box",
            scale=4,
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)

    with gr.Row() as button_row:
        # "New Round" empties every chat pane and every history state.
        clear_btn = gr.ClearButton(
            value="🎲 New Round",
            elem_id="clear_btn",
            interactive=False,
            components=chatbots + states,
        )
        regenerate_btn = gr.Button(
            value="🔄 Regenerate", interactive=False, elem_id="regenerate_btn"
        )
        share_btn = gr.Button(value="📷 Share Image")

    with gr.Row():
        examples = gr.Examples(
            [
                "Tell me a story",
                "What is the capital of France?",
                "Do you like me?",
            ],
            inputs=[textbox],
            label="Example task: General skill",
        )
    with gr.Row():
        examples = gr.Examples(
            [
                "translate: ᎧᏃᎮᏍᎩ",
                "Could you assist in rendering this Cherokee word into English?\nᎤᏲᎢ",
                "translate the following Cherokee word into English. ᏧᎩᏨᏅᏓ",
            ],
            inputs=[textbox],
            label="Example task: Translate words",
        )
    with gr.Row():
        examples = gr.Examples(
            [
                "translate: ᏚᏁᏤᎴᏃ ᎬᏩᏍᏓᏩᏗᏙᎯ, ᎾᏍᎩ ᏥᏳ ᎤᎦᏘᏗᏍᏗᏱ, ᎤᏂᏣᏘ ᎨᏒ ᎢᏳᏍᏗ, ᎾᏍᎩ ᎬᏩᏁᏄᎳᏍᏙᏗᏱ ᏂᎨᏒᎾ",
                "translate following Cherokee sentences into English.\nᏥᏌᏃ ᎤᏓᏅᏎ ᏚᏘᏅᏎ ᎬᏩᏍᏓᏩᏗᏙᎯ ᎥᏓᎵ ᏭᏂᎶᏎᎢ; ᎤᏂᏣᏘᏃ ᎬᏩᏍᏓᏩᏛᏎᎢ, ᏅᏓᏳᏂᎶᏒᎯ ᎨᎵᎵ, ᎠᎴ ᏧᏗᏱ,",
                "translate following sentences.\nᎯᎠᏃ ᏄᏪᏎᎴ ᎠᏍᎦᏯ ᎤᏬᏰᏂ ᎤᏩᎢᏎᎸᎯ; ᎠᏰᎵ ᎭᎴᎲᎦ.",
            ],
            inputs=[textbox],
            label="Example task: Translate sentences",
        )

    with gr.Accordion("Parameters", open=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.01,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=4096,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    # Every streaming handler writes all chat panes, then all history states.
    chat_outputs = [*chatbots, *states]

    # Pressing Enter and clicking Send are the same action; once the replies
    # have streamed in, Regenerate and New Round become usable.
    for send_event in (textbox.submit, send_btn.click):
        send_event(
            handle_message,
            inputs=[models, textbox, temperature, top_p, max_output_tokens, *states],
            outputs=chat_outputs,
        ).then(
            activate_chat_buttons,
            inputs=[],
            outputs=[regenerate_btn, clear_btn],
        )

    regenerate_btn.click(
        regenerate_message,
        inputs=[models, temperature, top_p, max_output_tokens, *states],
        outputs=chat_outputs,
    )

    def _chatbot_label_updates(llms):
        """Relabel each chat pane with the name of the model now assigned to it."""
        return [gr.Chatbot(label=llms[i]["name"]) for i in range(num_sides)]

    # New Round: the ClearButton wiring above empties panes and states. Here
    # the buttons are disabled again, a fresh set of models is drawn, and the
    # pane labels are updated so they keep naming the models in use.
    clear_btn.click(
        deactivate_chat_buttons,
        inputs=[],
        outputs=[regenerate_btn, clear_btn],
    ).then(
        lambda: get_random_models(), inputs=None, outputs=[models]
    ).then(
        _chatbot_label_updates, inputs=[models], outputs=chatbots
    )

    # Browser-only handler: no Python function, just the share_js snippet.
    share_btn.click(None, inputs=[], outputs=[], js=share_js)
|
|
if __name__ == "__main__":
    # Allow up to 10 concurrent event executions per event listener by default.
    demo.queue(default_concurrency_limit=10)
    # Bind to loopback only. share=True additionally asks Gradio for a public
    # share link. (The host was previously the typo "127.0.01".)
    demo.launch(server_name="127.0.0.1", server_port=5009, share=True)