from typing import Iterator
import gradio as gr
import boto3
import io
import json
import os
from transformers import AutoTokenizer
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID", None)
aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
aws_session_token = os.environ.get("AWS_SESSION_TOKEN", None)
region = os.environ.get("AWS_REGION", None)
endpoint_name = os.environ.get("SAGEMAKER_ENDPOINT_NAME", None)
tokenizer = AutoTokenizer.from_pretrained(
"aws-neuron/Llama-2-7b-chat-hf-seqlen-2048-bs-4"
)
if (
aws_access_key_id is None
or aws_secret_access_key is None
or region is None
or endpoint_name is None
):
raise Exception(
"Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION and SAGEMAKER_ENDPOINT_NAME environment variables"
)
boto_session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=region,
)
smr = boto_session.client("sagemaker-runtime")
DEFAULT_SYSTEM_PROMPT = (
"You are an helpful, concise and direct Assistant, called Llama. Knowing everyting about AWS. Don't use emojis or cringe stuff. You are talking to professionals."
)
MAX_INPUT_TOKEN_LENGTH = 1024
# Generation hyperparameters sent with every request to the LLM endpoint
parameters = {
"do_sample": True,
"top_p": 0.9,
"temperature": 0.6,
"max_new_tokens": 1024,
"repetition_penalty": 1.2,
"stop": ["<\s>"],
}
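
# Note: the request body built in generate() ({"inputs": ..., "parameters": ..., "stream": True})
# follows the text-generation-inference schema, so the keys above are presumably passed through
# to the generation backend unchanged (assumption: the endpoint runs a TGI-compatible container).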
# Helper that buffers the SageMaker response byte stream and yields complete lines
class LineIterator:
def __init__(self, stream):
self.byte_iterator = iter(stream)
self.buffer = io.BytesIO()
self.read_pos = 0
def __iter__(self):
return self
def __next__(self):
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
self.read_pos += len(line)
return line[:-1]
try:
chunk = next(self.byte_iterator)
except StopIteration:
if self.read_pos < self.buffer.getbuffer().nbytes:
continue
raise
if "PayloadPart" not in chunk:
print("Unknown event type:" + chunk)
continue
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
def format_prompt(message, history):
messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
for interaction in history:
messages.append({"role": "user", "content": interaction[0]})
messages.append({"role": "assistant", "content": interaction[1]})
messages.append({"role": "user", "content": message})
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return prompt
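
# Assuming the tokenizer ships the standard Llama 2 chat template, the prompt returned
# above looks roughly like (illustrative only; the exact string comes from the template):
#   <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST] {assistant} </s><s>[INST] {latest user} [/INST]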
def generate(
prompt,
history,
):
formatted_prompt = format_prompt(prompt, history)
check_input_token_length(formatted_prompt)
request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
    # Non-streaming invocation, kept commented out for reference:
    # resp = smr.invoke_endpoint(
# EndpointName=endpoint_name,
# Body=json.dumps(request),
# ContentType="application/json",
# )
# print(json.loads(resp["Body"].read().decode()))
# output = ""
# return output
resp = smr.invoke_endpoint_with_response_stream(
EndpointName=endpoint_name,
Body=json.dumps(request),
ContentType="application/json",
)
output = ""
for c in LineIterator(resp["Body"]):
c = c.decode("utf-8")
if c.startswith("data:"):
            chunk = json.loads(c[len("data:"):].rstrip("\n"))
if chunk["token"]["special"]:
continue
if chunk["token"]["text"] in request["parameters"]["stop"]:
break
output += chunk["token"]["text"]
for stop_str in request["parameters"]["stop"]:
if output.endswith(stop_str):
output = output[: -len(stop_str)]
output = output.rstrip()
yield output
yield output
return output
def check_input_token_length(prompt: str) -> None:
input_token_length = len(tokenizer(prompt)["input_ids"])
if input_token_length > MAX_INPUT_TOKEN_LENGTH:
raise gr.Error(
f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
)
theme = gr.themes.Monochrome(
primary_hue="indigo",
secondary_hue="blue",
neutral_hue="slate",
radius_size=gr.themes.sizes.radius_sm,
font=[
gr.themes.GoogleFont("Open Sans"),
"ui-sans-serif",
"system-ui",
"sans-serif",
],
)
DESCRIPTION = """
<div style="text-align: center; max-width: 650px; margin: 0 auto; display:grid; gap:25px;">
<img class="logo" src="https://huggingface.co/datasets/philschmid/assets/resolve/main/aws-neuron_hf.png" alt="Hugging Face Neuron Logo"
style="margin: auto; max-width: 14rem;">
<h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
Llama 2 7B Chat on AWS INF2 ⚡
</h1>
<p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
Llama 2 is a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters. This demo runs the 7B fine-tuned model, optimized for dialogue use cases and converted to the Hugging Face Transformers format. It is served on <a style="text-decoration: underline;" href="https://aws.amazon.com/ec2/instance-types/inf2/?nc1=h_ls">AWS Inferentia2</a>. <a href="https://www.philschmid.de/inferentia2-llama-7b" target="_blank">How does it work?</a>
</p>
</div>
"""
demo = gr.ChatInterface(
generate,
description=DESCRIPTION,
chatbot=gr.Chatbot(layout="panel"),
theme=theme,
)
demo.queue(concurrency_count=5).launch(share=False)