from strings import TITLE, ABSTRACT, BOTTOM_LINE
from strings import DEFAULT_EXAMPLES
from strings import SPECIAL_STRS
from styles import PARENT_BLOCK_CSS

import time
import gradio as gr

from model import load_model
from gen import get_output_batch, StreamModel
from utils import generate_prompt, post_processes_batch, post_process_stream, get_generation_config, common_post_process

# Loaded once at import time; every request handled below shares this model.
model, tokenizer = load_model(
    base="decapoda-research/llama-13b-hf",
    finetuned="chansung/alpaca-lora-13b"
)

# `chat_stream` calls this wrapper directly; `chat_batch` hands it (with the
# tokenizer) to `get_output_batch`.
model = StreamModel(model, tokenizer)


def chat_stream(
    context,
    instruction,
    state_chatbot,
):
    """Stream the model's reply to `instruction` into the chat history.

    Generator used as a Gradio event handler. Every yield is a 3-tuple
    ``(state_chatbot, state_chatbot, context)`` matching the outputs
    ``[state_chatbot, chatbot, context_txtbox]`` wired up in the UI below.

    Args:
        context: Text of the "Enter Context" box, passed to ``generate_prompt``.
        instruction: The user's instruction, or one of the helper strings
            ``SPECIAL_STRS["continue"]`` / ``SPECIAL_STRS["summarize"]``.
        state_chatbot: Chat history as a list of ``(user, bot)`` pairs. It is
            not mutated; an extended copy is yielded.

    Streaming stops when the model begins a new "### Instruction:" section,
    or when ``post_process_stream`` reports that it should stop. The final
    yield replaces the context box with the generated text when the request
    was "summarize"; otherwise it passes `context` through unchanged.
    """
    stop_marker = "### Instruction:"

    # user input should be appropriately formatted (don't be confused by the function name)
    instruction_display = common_post_process(instruction)
    instruction_prompt = generate_prompt(instruction, state_chatbot, context)
    bot_response = model(
        instruction_prompt,
        max_tokens=256,
        temperature=1,
        top_p=0.9
    )

    # A "continue" request is not shown as a user message in the chat.
    instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
    state_chatbot = state_chatbot + [(instruction_display, None)]

    prev_index = 0
    agg_tokens = ""
    # Bound up front so the final yield works even if the model streams nothing.
    tokens = ""
    # Text that ends up in the context box after a "summarize" request.
    final_text = ""
    for tokens in bot_response:
        # NOTE(review): assumes each item is the whole text generated so far
        # (cumulative), hence only tokens[prev_index:] is new — confirm in StreamModel.
        tokens = tokens.strip()
        final_text = tokens
        cur_token = tokens[prev_index:]

        # Once a '#' appears in the new chunk, hold output back while the
        # buffered text could still be the start of `stop_marker`. Buffer from
        # the '#' inside the *new* chunk, and append later chunks exactly once.
        if agg_tokens == "" and "#" in cur_token:
            agg_tokens = cur_token[cur_token.find("#"):]
        elif agg_tokens != "":
            agg_tokens = agg_tokens + cur_token

        if agg_tokens != "" and len(agg_tokens) >= len(stop_marker):
            stop_at = tokens.find(stop_marker)
            if stop_at > -1:
                # The model started a new instruction section: keep only the
                # reply that precedes it and stop generating.
                final_text = tokens[:stop_at].strip()
                processed_response, _ = post_process_stream(final_text)
                state_chatbot[-1] = (
                    instruction_display,
                    processed_response
                )
                yield (state_chatbot, state_chatbot, context)
                break
            # Enough text arrived after the '#' to rule out the marker.
            agg_tokens = ""

        if agg_tokens == "":
            processed_response, to_exit = post_process_stream(tokens)
            state_chatbot[-1] = (instruction_display, processed_response)
            yield (state_chatbot, state_chatbot, context)

            if to_exit:
                break

        prev_index = len(tokens)

    yield (
        state_chatbot,
        state_chatbot,
        gr.Textbox.update(value=final_text) if instruction_display == SPECIAL_STRS["summarize"] else context
    )


def chat_batch(
    contexts,
    instructions,
    state_chatbots,
    generation_config=None,
):
    """Answer several instructions at once with a single batched generation.

    Args:
        contexts: One context string per request.
        instructions: One instruction per request (may be a helper string
            from ``SPECIAL_STRS``).
        state_chatbots: One chat history (list of ``(user, bot)`` pairs) per
            request.
        generation_config: Passed unchanged to ``get_output_batch``. It is
            required. The original code referenced an undefined global of this
            name, so every call failed with ``NameError``.

    Returns:
        ``(state_results, state_results, ctx_results)``. Each history gets the
        new ``(instruction, reply)`` pair appended; "continue" instructions are
        shown as ``''``. For "summarize" requests, the context becomes a
        ``gr.Textbox.update`` carrying the reply.

    Raises:
        ValueError: If `generation_config` is not supplied.
    """
    if generation_config is None:
        raise ValueError(
            "chat_batch requires a generation_config; build one with "
            "get_generation_config(...) and pass it in."
        )

    state_results = []
    ctx_results = []

    instruct_prompts = [
        generate_prompt(instruct, histories, ctx)
        for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
    ]

    bot_responses = get_output_batch(
        model, tokenizer, instruct_prompts, generation_config
    )
    bot_responses = post_processes_batch(bot_responses)

    for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
        new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
        ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
        state_results.append(new_state_chatbot)

    return (state_results, state_results, ctx_results)


def reset_textbox():
    """Return a Gradio update that clears a textbox."""
    return gr.Textbox.update(value='')


with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
    state_chatbot = gr.State([])

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n{ABSTRACT}")

        with gr.Accordion("Context Setting", open=False):
            context_txtbox = gr.Textbox(placeholder="Surrounding information to AI", label="Enter Context")
            hidden_txtbox = gr.Textbox(placeholder="", label="Order", visible=False)

        chatbot = gr.Chatbot(elem_id='chatbot', label="Alpaca-LoRA")
        instruction_txtbox = gr.Textbox(placeholder="What do you want to say to AI?", label="Instruction")
        send_prompt_btn = gr.Button(value="Send Prompt")

        with gr.Accordion("Helper Buttons", open=False):
            # The line break is written as an escape; a raw newline inside a
            # single-quoted string literal is a syntax error.
            gr.Markdown(f"`Continue` lets AI to complete the previous incomplete answers.\n`Summarize` lets AI to summarize the conversations so far.")
            # Hidden boxes that feed the helper strings into chat_stream.
            continue_txtbox = gr.Textbox(value=SPECIAL_STRS["continue"], visible=False)
            summarize_txtbox = gr.Textbox(value=SPECIAL_STRS["summarize"], visible=False)

            continue_btn = gr.Button(value="Continue")
            summarize_btn = gr.Button(value="Summarize")

        gr.Markdown("#### Examples")
        for examples in DEFAULT_EXAMPLES:
            with gr.Accordion(examples["title"], open=False):
                gr.Examples(
                    examples=examples["examples"],
                    inputs=[
                        hidden_txtbox, instruction_txtbox
                    ],
                    label=None
                )

        gr.Markdown(f"{BOTTOM_LINE}")

    # Each button streams a reply and, as a separate listener, clears the
    # instruction box.
    send_prompt_btn.click(
        chat_stream,
        [context_txtbox, instruction_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    send_prompt_btn.click(
        reset_textbox,
        [],
        [instruction_txtbox],
    )

    continue_btn.click(
        chat_stream,
        [context_txtbox, continue_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    continue_btn.click(
        reset_textbox,
        [],
        [instruction_txtbox],
    )

    summarize_btn.click(
        chat_stream,
        [context_txtbox, summarize_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    summarize_btn.click(
        reset_textbox,
        [],
        [instruction_txtbox],
    )

demo.queue(
    concurrency_count=2,
    max_size=100,
).launch(
    max_threads=2,
    server_name="0.0.0.0",
)