import json
import os
import shutil
import requests
import warnings

import gradio as gr
from huggingface_hub import Repository
from text_generation import Client

from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css

HF_TOKEN = os.environ.get("HF_TOKEN", None)

API_URL_G = "https://api-inference.huggingface.co/models/ArmelR/starcoder-gradio-v0"
API_URL_S = "https://api-inference.huggingface.co/models/HuggingFaceH4/starcoderbase-finetuned-oasst1"
with open("./HHH_prompt_short.txt", "r") as f: | |
HHH_PROMPT = f.read() + "\n\n" | |
with open("./TA_prompt_v0.txt", "r") as f: | |
TA_PROMPT = f.read() | |
NO_PROMPT = "" | |
FIM_PREFIX = "<fim_prefix>" | |
FIM_MIDDLE = "<fim_middle>" | |
FIM_SUFFIX = "<fim_suffix>" | |
FIM_INDICATOR = "<FILL_HERE>" | |
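# NOTE: the FIM tokens above follow StarCoder's fill-in-the-middle format; FIM is
# not wired into the chat flow below, so they are kept only as placeholders.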
FORMATS = """ | |
# Chat mode | |
Chat mode prepends the custom [TA prompt](https://huggingface.co/spaces/bigcode/chat-playground/blob/main/TA_prompt_v0.txt) or the [HHH prompt](https://gist.github.com/jareddk/2509330f8ef3d787fc5aaac67aab5f11#file-hhh_prompt-txt) from Anthropic to the request which conditions the model to serve as an assistant. | |
⚠️ **Intended Use**: this app and its [supporting model](https://huggingface.co/bigcode) are provided for demonstration purposes; not to serve as replacement for human expertise. For more details on the model's limitations in terms of factuality and biases, see the [model card.](hf.co/bigcode) | |
""" | |

theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)
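
# One text-generation-inference client per backend model (StarCoder-gradio and StarChat-alpha).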
client_g = Client(
    API_URL_G, headers={"Authorization": f"Bearer {HF_TOKEN}"},
)
client_s = Client(
    API_URL_S, headers={"Authorization": f"Bearer {HF_TOKEN}"},
)


def generate(
    prompt,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    chat_mode="TA prompt",
    version="StarCoder-gradio",
):
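    """Send `prompt` to the selected backend and stream back the completion.

    The chat-mode prompt (TA, HHH, or none) is prepended to the request, and
    streaming stops early when a stop word such as "Human" or "Question:" is
    produced. Returns the accumulated completion as a stripped string.
    """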
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    fim_mode = False

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        truncate=7500,
        do_sample=True,
        seed=42,
        stop_sequences=["\nHuman", "\n-----", "Question:", "Answer:"],
    )

    if chat_mode == "HHH prompt":
        base_prompt = HHH_PROMPT
    elif chat_mode == "TA prompt":
        base_prompt = TA_PROMPT
    else:
        base_prompt = NO_PROMPT
if version == "StarCoder-gradio" : | |
chat_prompt = prompt + "\n\nAnswer:" | |
prompt = base_prompt + chat_prompt | |
print("PROMPT : "+str(prompt)) | |
stream = client_g.generate_stream(prompt, **generate_kwargs) | |
elif version == "StarChat-alpha" : | |
chat_prompt = prompt + "\n\nAssistant:" | |
prompt = base_prompt + chat_prompt | |
stream = client_s.generate_stream(prompt, **generate_kwargs) | |
else : | |
ValueError("Unsupported version of the Coding assistant") | |
#print("Tokens = "+str([response.token.text for response in stream])) | |
output = "" | |
previous_token = "" | |
t = 0 | |
for response in stream: | |
print(f"IN_{t}") | |
if ( | |
(response.token.text in ["Human", "-----", "Question:"] and previous_token in ["\n", "-----"]) | |
or response.token.text in ["<|endoftext|>", "<|end|>", "Answer:"] | |
): | |
print("OUT") | |
return output.strip() | |
else: | |
output += response.token.text | |
print(f"Out_{t} : {output}") | |
t += 1 | |
previous_token = response.token.text | |
print("Output = "+str(output)) | |
return output.strip() | |


# chatbot mode
def user(user_message, history):
    return "", history + [[user_message, None]]


def bot(
    history,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    chat_mode=None,
    version="StarChat-alpha",
):
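    """Build the conversation prompt from `history` and fill in the model's reply.

    The last history entry holds the new user message with a `None` answer; its
    answer slot is filled with the generated text before the history is returned.
    """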
    # Concatenate the history of prompts and answers; for the last (empty) answer, only add the prompt.
    if version == "StarCoder-gradio":
        prompt = "\n".join(
            [f"Question: {prompt}\n\nAnswer: {answer}" for prompt, answer in history[:-1]]
            + [f"\nQuestion: {history[-1][0]}"]
        )
    else:
        prompt = "\n".join(
            [f"Human: {prompt}\n\nAssistant: {answer}" for prompt, answer in history[:-1]]
            + [f"\nHuman: {history[-1][0]}"]
        )

    bot_message = generate(
        prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        chat_mode=chat_mode,
        version=version,
    )
    history[-1][1] = bot_message
    return history


examples = [
    "def print_hello_world():",
    "def fibonacci(n):",
    "class TransformerDecoder(nn.Module):",
    "class ComplexNumbers:",
    "How to install gradio",
]


def process_example(args):
    # `generate` returns the full completion as a string, so forward it directly.
    return generate(args)
css = ".generating {visibility: hidden}" + share_btn_css | |
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo: | |
with gr.Column(): | |
gr.Markdown( | |
"""\ | |
#Gradio Assistant powered by 💫 StarCoder | |
_Note:_ this is an internal chat playground - **please do not share**. The deployment can also change and thus the space not work as we continue development.\ | |
""" | |
) | |
        with gr.Row():
            column_1, column_2 = gr.Column(scale=3), gr.Column(scale=1)
            with column_2:
                chat_mode = gr.Dropdown(
                    ["NO prompt", "TA prompt", "HHH prompt"],
                    value="NO prompt",
                    label="Chat mode",
                    info="Use Anthropic's HHH prompt or our custom tech prompt to turn the model into an assistant.",
                )
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=512,
                    minimum=0,
                    maximum=8192,
                    step=64,
                    interactive=True,
                    info="The maximum number of new tokens to generate",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.95,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    value=1.2,
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    interactive=True,
                    info="Penalize repeated tokens",
                )
                version = gr.Dropdown(
                    ["StarCoder-gradio", "StarChat-alpha"],
                    value="StarCoder-gradio",
                    label="Version",
                    info="",
                )
            with column_1:
                # output = gr.Code(elem_id="q-output")
                # add visible=False and update if chat_mode is True
                chatbot = gr.Chatbot()
                instruction = gr.Textbox(
                    placeholder="Enter your prompt here",
                    label="Prompt",
                    elem_id="q-input",
                )
                with gr.Row():
                    with gr.Column():
                        clear = gr.Button("Clear Chat")
                    with gr.Column():
                        submit = gr.Button("Generate", variant="primary")
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_html, visible=True)
                    loading_icon = gr.HTML(loading_icon_html, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )
                # examples of non-chat mode
                # gr.Examples(
                #     examples=examples,
                #     inputs=[instruction],
                #     cache_examples=False,
                #     fn=process_example,
                #     outputs=[output],
                # )
                gr.Markdown(FORMATS)
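
    # Wire up the chat flow: `user` appends the new message to the history and
    # clears the textbox, then `bot` calls `generate` and fills in the reply.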
    instruction.submit(
        user, [instruction, chatbot], [instruction, chatbot], queue=False
    ).then(
        bot,
        [chatbot, temperature, max_new_tokens, top_p, repetition_penalty, chat_mode, version],
        chatbot,
    )
    submit.click(
        user, [instruction, chatbot], [instruction, chatbot], queue=False
    ).then(
        bot,
        [chatbot, temperature, max_new_tokens, top_p, repetition_penalty, chat_mode, version],
        chatbot,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    share_button.click(None, [], [], _js=share_js)

demo.queue(concurrency_count=16).launch(debug=True)