import json
import logging
import os

import gradio as gr
from text_generation import Client

from app_modules.utils import convert_to_markdown
# from dialogues import DialogueTemplate
# Share-button assets (not yet wired into this UI).
from share_btn import (community_icon_html, loading_icon_html, share_btn_css,
                       share_js)
# Credentials and the inference endpoint come from the environment; never
# hardcode access tokens in source.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_TOKEN = os.environ.get("API_TOKEN", HF_TOKEN)
API_URL = os.environ.get(
    "API_URL",
    "https://api-inference.huggingface.co/models/timdettmers/guanaco-33b-merged",
)

client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {API_TOKEN}"},
)
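# Quick smoke test (a minimal sketch, assuming the endpoint is reachable and
# the token is valid): the client's synchronous `generate` call returns one
# finished completion instead of a token stream, e.g.
#
#   response = client.generate("### Human: Hello\n### Assistant:",
#                              max_new_tokens=32)
#   print(response.generated_text)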
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
examples = [
    "Describe the advantages and disadvantages of Incremental Sheet Forming.",
    "Describe the applications of Incremental Sheet Forming.",
    "Describe the process parameters included in Incremental Sheet Forming in dot points.",
]
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Concatenate past turns and the new input into a single prompt string."""
    past = []
    for data in chatbot:
        user_data, model_data = data

        # Ensure every turn carries its speaker prefix exactly once.
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data

        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    return total_inputs
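# Example (illustrative values only):
#
#   get_total_inputs("Hi", [("Hello", "Hello! How can I help?")],
#                    preprompt="", user_name="### Human: ",
#                    assistant_name="### Assistant:", sep="\n")
#
# returns one flat string ending in "### Assistant:" so the model continues
# from the assistant's turn.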
def has_no_history(chatbot, history):
    return not chatbot and not history


header = "A chat between a curious human and an artificial intelligence assistant about Incremental Sheet Forming (ISF). " \
         "The assistant gives helpful, detailed, and polite answers to the user's questions."

prompt_template = "### Human: {query}\n### Assistant:{response}"
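# For reference, a two-turn conversation renders into a prompt of the form
# (sketch; the first line is the `header` above):
#
#   A chat between a curious human and an artificial intelligence assistant ...
#   ### Human: What is ISF?
#   ### Assistant: Incremental Sheet Forming is ...
#   ### Human: Describe its applications.
#   ### Assistant: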
def generate(
    user_message,
    chatbot,
    history,
    temperature,
    top_p,
    top_k,
    max_new_tokens,
    repetition_penalty,
):
    # Skip generation entirely when the input is empty rather than sending a
    # meaningless prompt to the model.
    if not user_message:
        logger.warning("Empty input")
        yield chatbot, history, user_message, ""
        return

    history.append(user_message)
    past_messages = []
    for data in chatbot:
        user_data, model_data = data
        past_messages.extend(
            [{"role": "user", "content": user_data}, {"role": "assistant", "content": model_data.rstrip()}]
        )

    # Rebuild the full prompt: header, then alternating Human/Assistant turns,
    # ending with an open Assistant turn for the new query.
    if len(past_messages) < 1:
        prompt = header + prompt_template.format(query=user_message, response="")
    else:
        prompt = header
        for i in range(0, len(past_messages), 2):
            intermediate_prompt = prompt_template.format(query=past_messages[i]["content"],
                                                         response=past_messages[i + 1]["content"])
            logger.debug("intermediate: %s", intermediate_prompt)
            prompt = prompt + '\n' + intermediate_prompt

        prompt = prompt + prompt_template.format(query=user_message, response="")
    # Clamp temperature away from zero to keep sampling numerically stable.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=999,
        seed=42,
    )

    stream = client.generate_stream(
        prompt,
        **generate_kwargs,
    )
output = "" | |
for idx, response in enumerate(stream): | |
if response.token.text == '': | |
break | |
if response.token.special: | |
continue | |
output += response.token.text | |
if idx == 0: | |
history.append(" " + output) | |
else: | |
history[-1] = output | |
chat = [(convert_to_markdown(history[i].strip()), convert_to_markdown(history[i + 1].strip())) for i in range(0, len(history) - 1, 2)] | |
yield chat, history, user_message, "" | |
return chat, history, user_message, "" | |
def clear_chat():
    return [], []
def save(
    history,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=512,
    repetition_penalty=1.2,
    max_memory=1024,
):
    # Append the conversation and its sampling parameters to a JSONL log,
    # one record per saved chat.
    history = [] if history is None else history
    data_point = {'history': history, 'generation_parameter': {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "max_memory": max_memory,
    }}
    logger.debug(data_point)
    file_name = "history.jsonl"
    with open(file_name, 'a') as f:
        f.write(json.dumps(data_point, ensure_ascii=False) + '\n')
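# A matching reader for the JSONL log above (hypothetical helper, not wired
# into the UI): returns one parsed record per saved chat.
def load_history(file_name="history.jsonl"):
    records = []
    with open(file_name) as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records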
def process_example(*args):
    # Expects the full `generate` argument list; drain the generator and
    # return the final chat/history pair.
    for chat, history, _, _ in generate(*args):
        pass
    return chat, history
title = """<h1 align="center">ISF Alpaca π¬</h1>"""

custom_css = """
#banner-image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}

#chat-message {
    font-size: 14px;
    min-height: 300px;
}
"""
with gr.Blocks(analytics_enabled=False,
               theme=gr.themes.Soft(),
               css=".disclaimer {font-variant-caps: all-small-caps;}") as demo:
    gr.HTML(title)
    # status_display = gr.Markdown("Success", elem_id="status_display")
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                π The fine-tuned model primarily emphasizes **Knowledge Augmentation** in the Manufacturing domain,
                with **Incremental Sheet Forming (ISF)** serving as a use case.
                """
            )
    # Conversation state is defined once here and shared by all handlers below.
    history = gr.State([])

    with gr.Row(scale=1).style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row(scale=1):
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height=476)
            with gr.Row(scale=1):
                with gr.Column(scale=12):
                    user_message = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submit_btn = gr.Button("Send")
                with gr.Column(min_width=70, scale=1):
                    stop_btn = gr.Button("Stop")
            with gr.Row():
                gr.Examples(
                    examples=examples,
                    inputs=[user_message],
                    cache_examples=False,
                    outputs=[chatbot, history],
                )
            with gr.Row(scale=1):
                clear_history = gr.Button(
                    "π§Ή New Conversation",
                )
                reset_btn = gr.Button("π Reset Parameter")
                save_btn = gr.Button("π₯ Save Chat")

        with gr.Column():
            input_component_column = gr.Column(min_width=50, scale=1)
            with input_component_column:
                with gr.Tab(label="Parameter Setting"):
                    gr.Markdown("# Parameters")
                    temperature = gr.components.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")
                    top_p = gr.components.Slider(minimum=0, maximum=1, value=0.9, label="Top p")
                    top_k = gr.components.Slider(minimum=0, maximum=100, step=1, value=20, label="Top k")
                    max_new_tokens = gr.components.Slider(minimum=1, maximum=2048, step=1, value=512,
                                                          label="Max New Tokens")
                    repetition_penalty = gr.components.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.2,
                                                              label="Repetition Penalty")
                    max_memory = gr.components.Slider(minimum=0, maximum=2048, step=1, value=2048, label="Max Memory")
    last_user_message = gr.State("")
    user_message.submit(
        generate,
        inputs=[
            user_message,
            chatbot,
            history,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
            repetition_penalty,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    submit_event = submit_btn.click(
        generate,
        inputs=[
            user_message,
            chatbot,
            history,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
            repetition_penalty,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )
    # submit_btn.click(
    #     lambda: (
    #         submit_btn.update(visible=False),
    #         stop_btn.update(visible=True),
    #     ),
    #     inputs=None,
    #     outputs=[submit_btn, stop_btn],
    #     queue=False,
    # )
    stop_btn.click(
        lambda: (
            submit_btn.update(visible=True),
            stop_btn.update(visible=True),
        ),
        inputs=None,
        outputs=[submit_btn, stop_btn],
        cancels=[submit_event],
        queue=False,
    )
    clear_history.click(clear_chat, outputs=[chatbot, history])

    # `save` takes (history, temperature, top_p, top_k, max_new_tokens,
    # repetition_penalty, max_memory); the inputs must match that signature.
    save_btn.click(
        save,
        inputs=[history, temperature, top_p, top_k, max_new_tokens, repetition_penalty, max_memory],
        outputs=None,
    )
    input_components_except_states = [user_message, chatbot, history, temperature, top_p, top_k, max_new_tokens,
                                      repetition_penalty]
    # Reset every non-state input client-side by writing each component's
    # cleared value straight from JavaScript.
    reset_btn.click(
        None,
        [],
        (input_components_except_states + [input_component_column]),  # type: ignore
        _js=f"""() => {json.dumps([getattr(component, "cleared_value", None) for component in input_components_except_states]
                                  + ([gr.Column.update(visible=True)])
                                  + ([]))}
        """,
    )
demo.queue(concurrency_count=16).launch(debug=True, share=True)