import json
import logging
import os
import sys

import gradio as gr
import torch
from huggingface_hub import Repository
from text_generation import Client

from app_modules.utils import convert_to_markdown
from share_btn import (community_icon_html, loading_icon_html, share_btn_css,
                       share_js)

# Credentials are read from the environment only.
# NOTE(review): a literal access token used to sit in a comment at this spot.
# It has been removed from the source; treat it as leaked and revoke it.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_TOKEN = os.environ.get("API_TOKEN", None)

# The hard-coded endpoint is the default; the API_URL environment variable,
# when set to a non-empty value, overrides it. (Previously the environment
# value was read and then unconditionally overwritten.)
DEFAULT_API_URL = "https://api-inference.huggingface.co/models/timdettmers/guanaco-33b-merged"
API_URL = os.environ.get("API_URL") or DEFAULT_API_URL

# Only send an Authorization header when a token is configured, so a missing
# token does not turn into the literal header value "Bearer None".
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else None,
)

repo = None

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Example prompts offered in the UI.
examples = [
    "Describe the advantages and disadvantages of Incremental Sheet Forming.",
    "Describe the applications of Incremental Sheet Forming.",
    "Describe the process parameters included in Incremental Sheet Forming in dot points."
]


def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Flatten past chat turns and the new user input into one prompt string.

    Args:
        inputs: The new user message.
        chatbot: Iterable of ``(user_text, model_text)`` pairs for past turns.
        preprompt: Text placed before all turns.
        user_name: Prefix marking a user turn; added when a user text (or
            ``inputs``) does not already start with it.
        assistant_name: Prefix marking an assistant turn; ``sep + assistant_name``
            is added when a model text does not already start with it.
        sep: Separator appended after every past turn and after ``inputs``.

    Returns:
        ``preprompt`` + all past turns + ``inputs`` + ``sep`` + the
        right-stripped ``assistant_name``.
    """
    assistant_prefix = sep + assistant_name
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(assistant_prefix):
            model_data = assistant_prefix + model_data
        # Trailing whitespace of the model reply is dropped before the separator.
        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    return preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()


def has_no_history(chatbot, history):
    """Return True when both ``chatbot`` and ``history`` are empty or None."""
    return not chatbot and not history


# System preamble placed at the start of every prompt.
header = (
    "A chat between a curious human and an artificial intelligence assistant about Incremental Sheet Forming (ISF). "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)
# One conversation turn in the "### Human / ### Assistant" format.
prompt_template = "### Human: {query}\n### Assistant:{response}"


def _build_prompt(user_message, chatbot):
    """Return the model prompt: header, every past turn, then the open new turn.

    All parts are joined with a single newline so a previous answer never runs
    straight into the next "### Human:" marker.
    """
    turns = [
        prompt_template.format(query=user_data, response=(model_data or "").rstrip())
        for user_data, model_data in chatbot
    ]
    turns.append(prompt_template.format(query=user_message, response=""))
    return "\n".join([header, *turns])


def _history_to_chat(history):
    """Pair up the flat ``[user, reply, user, reply, ...]`` list for the Chatbot."""
    return [
        (convert_to_markdown(user.strip()), convert_to_markdown(reply.strip()))
        for user, reply in zip(history[::2], history[1::2])
    ]


def generate(
    user_message,
    chatbot=None,
    history=None,
    temperature=0.7,
    top_p=0.9,
    top_k=30,
    max_new_tokens=512,
    repetition_penalty=1.2,
):
    """Stream a model reply for ``user_message`` and yield UI updates.

    Args:
        user_message: The text the user just submitted.
        chatbot: Past ``(user, reply)`` pairs as shown in the Chatbot component;
            used to rebuild the prompt. ``None`` is treated as empty.
        history: Flat list alternating user messages and replies; mutated in
            place. ``None`` is treated as empty.
        temperature, top_p, top_k, max_new_tokens, repetition_penalty:
            Sampling parameters. The defaults mirror the initial slider values
            in the UI below.

    Yields:
        ``(chat, history, user_message, "")`` after each received token; the
        trailing empty string clears the input textbox.
    """
    chatbot = chatbot or []
    history = [] if history is None else history

    # Don't return meaningless message when the input is empty
    if not user_message:
        logger.debug("Empty input")
        yield chatbot, history, user_message, ""
        return

    prompt = _build_prompt(user_message, chatbot)
    logger.debug("prompt: %s", prompt)

    # Append the user turn together with an empty assistant slot so the list
    # always stays in (user, reply) pairs, even if streaming produces nothing.
    history.append(user_message)
    history.append("")

    # The text_generation client rejects top_k <= 0 and top_p outside (0, 1);
    # both are reachable from the sliders, so pass None (disabled) instead.
    top_p = float(top_p)
    top_k = int(top_k)
    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=int(max_new_tokens),
        top_p=top_p if 0.0 < top_p < 1.0 else None,
        top_k=top_k if top_k > 0 else None,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=999,
        seed=42,
    )

    stream = client.generate_stream(prompt, **generate_kwargs)

    output = ""
    emitted = False
    for response in stream:
        token = response.token
        if token.text == '':
            break
        if token.special:
            continue
        output += token.text
        history[-1] = output
        emitted = True
        yield _history_to_chat(history), history, user_message, ""

    if not emitted:
        # Nothing was streamed: still show the user turn and clear the textbox.
        yield _history_to_chat(history), history, user_message, ""


def clear_chat():
    """Return empty values for the Chatbot and the history state."""
    return [], []


def save(
    history,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=512,
    repetition_penalty=1.2,
    max_memory=1024,
):
    """Append the conversation and its generation parameters to ``history.jsonl``.

    One JSON object is written per call. ``history`` of ``None`` is stored as
    an empty list.
    """
    history = [] if history is None else history
    data_point = {
        'history': history,
        'generation_parameter': {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "max_memory": max_memory,
        },
    }
    logger.debug("saving: %s", data_point)
    file_name = "history.jsonl"
    # ensure_ascii=False emits raw non-ASCII text, so pin the file encoding.
    with open(file_name, 'a', encoding="utf-8") as f:
        f.write(json.dumps(data_point, ensure_ascii=False) + '\n')


def process_example(args):
    """Run ``generate`` to completion for one example and return its final state.

    ``args`` is passed as the user message; every other ``generate`` parameter
    uses its default. Returns ``[chat, history]`` from the last update.
    """
    chat, history = [], []
    for chat, history, *_ in generate(args):
        pass
    return [chat, history]


title = """

ISF Alpaca 💬

"""

custom_css = """ #banner-image { display: block; margin-left: auto; margin-right: auto; } #chat-message { font-size: 14px; min-height: 300px; } """

with gr.Blocks(analytics_enabled=False,
               theme=gr.themes.Soft(),
               css=".disclaimer {font-variant-caps: all-small-caps;}") as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """ 🏭 The fine-tuned model primarily emphasizes **Knowledge Augmentation** in the Manufacturing domain, with **Incremental Sheet Forming (ISF)** serving as a use case. """
            )

    # Per-session state, created once so every event refers to the same objects.
    history = gr.State([])
    last_user_message = gr.State("")

    with gr.Row(scale=1).style(equal_height=True):
        # Left column: conversation, input box and action buttons.
        with gr.Column(scale=5):
            with gr.Row(scale=1):
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height=476)
            with gr.Row(scale=1):
                with gr.Column(scale=12):
                    user_message = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submit_btn = gr.Button("Send")
                with gr.Column(min_width=70, scale=1):
                    stop_btn = gr.Button("Stop")
            with gr.Row():
                gr.Examples(
                    examples=examples,
                    inputs=[user_message],
                    cache_examples=False,
                    outputs=[chatbot, history],
                )
            with gr.Row(scale=1):
                clear_history = gr.Button(
                    "🧹 New Conversation",
                )
                reset_btn = gr.Button("🔄 Reset Parameter")
                save_btn = gr.Button("📥 Save Chat")

        # Right column: generation parameters.
        with gr.Column():
            input_component_column = gr.Column(min_width=50, scale=1)
            with input_component_column:
                with gr.Tab(label="Parameter Setting"):
                    gr.Markdown("# Parameters")
                    temperature = gr.components.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")
                    top_p = gr.components.Slider(minimum=0, maximum=1, value=0.9, label="Top p")
                    top_k = gr.components.Slider(minimum=0, maximum=100, step=1, value=30, label="Top k")
                    max_new_tokens = gr.components.Slider(minimum=1, maximum=2048, step=1, value=512,
                                                          label="Max New Tokens")
                    repetition_penalty = gr.components.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.2,
                                                              label="Repetition Penalty")
                    max_memory = gr.components.Slider(minimum=0, maximum=2048, step=1, value=2048,
                                                      label="Max Memory")

    # Both ways of sending a message share the same handler wiring.
    generate_inputs = [
        user_message,
        chatbot,
        history,
        temperature,
        top_p,
        top_k,
        max_new_tokens,
        repetition_penalty,
    ]
    generate_outputs = [chatbot, history, last_user_message, user_message]

    enter_event = user_message.submit(
        generate,
        inputs=generate_inputs,
        outputs=generate_outputs,
    )
    submit_event = submit_btn.click(
        generate,
        inputs=generate_inputs,
        outputs=generate_outputs,
    )

    # Stop cancels a running generation regardless of how it was started.
    stop_btn.click(
        lambda: (
            submit_btn.update(visible=True),
            stop_btn.update(visible=True),
        ),
        inputs=None,
        outputs=[submit_btn, stop_btn],
        cancels=[submit_event, enter_event],
        queue=False,
    )

    clear_history.click(clear_chat, outputs=[chatbot, history])

    # Inputs are listed in the order of save()'s parameters.
    save_btn.click(
        save,
        inputs=[history, temperature, top_p, top_k, max_new_tokens, repetition_penalty, max_memory],
        outputs=None,
    )

    input_components_except_states = [user_message, chatbot, history, temperature, top_p, top_k,
                                      max_new_tokens, repetition_penalty]
    # Client-side reset: each listed component receives its ``cleared_value``
    # attribute (None when the component has none), and the column is re-shown.
    reset_values = json.dumps(
        [getattr(component, "cleared_value", None) for component in input_components_except_states]
        + [gr.Column.update(visible=True)]
    )
    reset_btn.click(
        None,
        [],
        (input_components_except_states + [input_component_column]),  # type: ignore
        _js=f"() => {reset_values}",
    )

demo.queue(concurrency_count=16).launch(debug=True)