Spaces:
Paused
Paused
import os | |
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
# Hugging Face access token, read from the environment (e.g. a Space secret).
# It is passed as `use_auth_token` when loading the model and tokenizer. Using
# .get() means a missing secret yields None (anonymous access) instead of a
# KeyError at import time, which matters for duplicated Spaces that have not
# configured the secret. Anonymous access presumably suffices only while the
# model repo is public — confirm if the repo becomes gated.
token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
# Hub repository of the instruction-tuned model served by this demo.
model_id = 'Deci/DeciLM-6b-instruct'
# Instruction prompt wrapped around every user message; the `{instruction}`
# placeholder is filled by get_prompt_with_template(). The model's answer is
# read back from the text that follows the final "### Response:" line.
SYSTEM_PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n"
    "### Instruction:\n"
    "{instruction}\n"
    "### Response:\n"
)
# Markdown/HTML header rendered at the top of the demo via gr.Markdown().
# On machines without CUDA, a "You need a GPU" notice is appended to this
# string further down the file before the UI is built.
DESCRIPTION = """
# <p style="text-align: center; color: #292b47;"> 🤖 <span style='color: #3264ff;'>DeciLM-6B-Instruct:</span> A Fast Instruction-Tuned Model💨 </p>
<span style='color: #292b47;'>Welcome to <a href="https://huggingface.co/Deci/DeciLM-6b-instruct" style="color: #3264ff;">DeciLM-6B-Instruct</a>! DeciLM-6B-Instruct is a 6B parameter instruction-tuned language model and released under the Llama license. It's an instruction-tuned model, not a chat-tuned model; you should prompt the model with an instruction that describes a task, and the model will respond appropriately to complete the task.</span>
<p><span style='color: #292b47;'>Learn more about the base model <a href="https://huggingface.co/Deci/DeciLM-6b" style="color: #3264ff;">DeciLM-6B.</a></span></p>
"""
# LICENSE = """ | |
# <p/> | |
# --- | |
# As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
# this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md). | |
# """ | |
# Load the fp16 model only when a CUDA device exists. On CPU-only hardware,
# the page header gains a notice pointing to a Colab notebook and `model`
# stays None, so callers must check for that before generating.
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map='auto',
        trust_remote_code=True,
        use_auth_token=token,
    )
else:
    DESCRIPTION += 'You need a GPU for this example. Try using colab: https://bit.ly/decilm-instruct-nb'
    model = None

# The tokenizer is loaded unconditionally; its pad token is aliased to EOS.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
tokenizer.pad_token = tokenizer.eos_token
# Build the full model prompt for a single user instruction.
def get_prompt_with_template(message: str) -> str:
    """Return SYSTEM_PROMPT_TEMPLATE with *message* substituted for ``{instruction}``."""
    substitutions = {'instruction': message}
    return SYSTEM_PROMPT_TEMPLATE.format_map(substitutions)
# Function to generate the model's response
def generate_model_response(message: str) -> str:
    """Generate a completion for *message* wrapped in the instruction template.

    Returns the decoded first output sequence with special tokens removed. The
    text includes the prompt followed by the model's continuation.

    Raises:
        gr.Error: if no model was loaded (the model is only loaded when CUDA is
            available at startup).
    """
    if model is None:
        # Without this guard, `model.generate` fails with an opaque
        # AttributeError on None; gr.Error shows a readable message in the UI.
        raise gr.Error('DeciLM-6B-Instruct is not loaded: this demo needs a GPU.')
    prompt = get_prompt_with_template(message)
    # Place the inputs on the device that `device_map='auto'` chose for the
    # model instead of hard-coding 'cuda'.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=3000,
        num_beams=5,
        no_repeat_ngram_size=4,
        early_stopping=True,
        do_sample=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Pull the answer out of a decoded "prompt + completion" string.
def extract_response_content(full_response: str) -> str:
    """Return the stripped text after the first ``### Response:`` marker.

    When the marker is absent, *full_response* is returned as-is (not stripped).
    """
    _, marker, answer = full_response.partition("### Response:")
    if not marker:
        return full_response
    return answer.strip()
# UI entry point: generate a completion and return only the model's answer.
def get_response_with_template(message: str) -> str:
    """Run the model on *message* and return the response text without the prompt.

    The decoded output normally begins with the exact prompt, so that prefix is
    removed directly. Splitting on the first "### Response:" alone would go wrong
    when the user's own instruction contains that marker: part of the prompt
    would leak into the answer. If decoding did not reproduce the prompt
    verbatim, fall back to marker-based extraction.
    """
    full_response = generate_model_response(message)
    prompt = get_prompt_with_template(message)
    if full_response.startswith(prompt):
        return full_response[len(prompt):].strip()
    return extract_response_content(full_response)
# Build the single-turn instruction UI: an output box, an instruction input
# with a submit button, a clear button, cached examples and a footer banner.
# NOTE(review): "/content/style.css" looks like a Colab path; confirm it exists
# in the deployment, otherwise the stylesheet is presumably not applied.
with gr.Blocks(css="/content/style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')
    with gr.Group():
        # Output area for one response (a Textbox, despite the `chatbot` name).
        chatbot = gr.Textbox(label='DeciLM-6B-Instruct Output:')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type an instruction...',
                scale=10,
                elem_id="textbox"
            )
            submit_button = gr.Button(
                '💬 Submit',
                variant='primary',
                scale=1,
                min_width=0,
                elem_id="submit_button"
            )
    # Clear button: resets both the instruction input and the output box.
    clear_button = gr.Button(
        '🗑️ Clear',
        variant='secondary',
    )
    clear_button.click(
        fn=lambda: ('', ''),
        outputs=[textbox, chatbot],
        queue=False,
        api_name=False,
    )
    # NOTE(review): long beam-search generations with queue=False may hit
    # request timeouts on hosted deployments; consider enabling the queue.
    submit_button.click(
        fn=get_response_with_template,
        inputs=textbox,
        outputs=chatbot,
        queue=False,
        api_name=False,
    )
    gr.Examples(
        examples=[
            'Write detailed instructions for making chocolate chip pancakes.',
            'Write a 250-word article about your love of pancakes.',
            'Explain the plot of Back to the Future in three sentences.',
            'How do I make a trap beat?',
            'A step-by-step guide to learning Python in one month.',
        ],
        inputs=textbox,
        outputs=chatbot,
        fn=get_response_with_template,
        # Caching runs `fn` on every example ahead of time (when the app starts).
        # Without a loaded model (CPU-only hardware) that would fail, so only
        # cache when the model is available.
        cache_examples=model is not None,
        elem_id="examples"
    )
    # NOTE(review): a relative <img src> is presumably not served by Gradio as-is;
    # confirm the banner actually renders.
    gr.HTML(label="Keep in touch", value="<img src='./content/deci-coder-banner.png' alt='Keep in touch' style='display: block; color: #292b47; margin: auto; max-width: 800px;'>")
demo.launch(share=True, debug=True)