from typing import Iterator
import gradio as gr
import boto3
import io
import json
import os
from transformers import AutoTokenizer

# Endpoint/credential configuration comes from the environment. The session
# token may be None; the other four values are required (checked below).
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID", None)
aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
aws_session_token = os.environ.get("AWS_SESSION_TOKEN", None)
region = os.environ.get("AWS_REGION", None)
endpoint_name = os.environ.get("SAGEMAKER_ENDPOINT_NAME", None)

# Validate configuration before downloading the tokenizer so a misconfigured
# deployment fails fast instead of after a model-hub download.
if (
    aws_access_key_id is None
    or aws_secret_access_key is None
    or region is None
    or endpoint_name is None
):
    raise Exception(
        "Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION and SAGEMAKER_ENDPOINT_NAME environment variables"
    )

# Used locally only to render the chat template (format_prompt) and to count
# input tokens (check_input_token_length); generation runs on the endpoint.
tokenizer = AutoTokenizer.from_pretrained(
    "aws-neuron/Llama-2-7b-chat-hf-seqlen-2048-bs-4"
)

boto_session = boto3.Session(
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
    aws_session_token=aws_session_token,
    region_name=region,
)
smr = boto_session.client("sagemaker-runtime")

DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, concise and direct Assistant, called Llama. Knowing everything about AWS. Don't use emojis or cringe stuff. You are talking to professionals."
)

# NOTE(review): presumably chosen so that input (1024) + max_new_tokens (1024)
# fits the 2048-token sequence length suggested by the model id — confirm.
MAX_INPUT_TOKEN_LENGTH = 1024

# Generation hyperparameters sent as the "parameters" field of the request.
parameters = {
    "do_sample": True,
    "top_p": 0.9,
    "temperature": 0.6,
    "max_new_tokens": 1024,
    "repetition_penalty": 1.2,
    # Llama 2's end-of-sequence marker. The previous "<\s>" was an invalid
    # escape producing the literal text `<\s>`, which could never match.
    "stop": ["</s>"],
}


class LineIterator:
    """Iterate over newline-delimited lines of a SageMaker response stream.

    Each event from the stream is expected to be a dict; payload bytes are read
    from ``event["PayloadPart"]["Bytes"]``. A single line may be split across
    several events, so bytes are buffered until a full line is available. Lines
    are returned as ``bytes`` without the trailing newline.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset of the first not-yet-returned byte in ``buffer``.
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self) -> bytes:
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # The stream is exhausted. If unterminated bytes remain, emit
                # them as the final line; merely `continue`-ing here (as before)
                # would spin forever because read_pos could never advance.
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                # chunk is a dict; convert explicitly (str + dict raises TypeError).
                print("Unknown event type:" + str(chunk))
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])


def format_prompt(message: str, history) -> str:
    """Render the conversation into a single prompt string for the endpoint.

    Args:
        message: The new user message.
        history: Prior turns; element ``[0]`` of each turn is the user text and
            element ``[1]`` the assistant reply.

    Returns:
        The prompt produced by the tokenizer's chat template (system prompt,
        history, new message), ending with the assistant generation prompt.
    """
    messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
    for interaction in history:
        messages.append({"role": "user", "content": interaction[0]})
        messages.append({"role": "assistant", "content": interaction[1]})
    messages.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def generate(
    prompt,
    history,
) -> Iterator[str]:
    """Stream a reply from the SageMaker endpoint for ``gr.ChatInterface``.

    Args:
        prompt: The new user message.
        history: Prior turns, as accepted by ``format_prompt``.

    Yields:
        The reply accumulated so far (trailing whitespace and a trailing stop
        sequence removed) after each received token, then once more at the end.

    Raises:
        gr.Error: If the formatted prompt exceeds MAX_INPUT_TOKEN_LENGTH tokens.
    """
    formatted_prompt = format_prompt(prompt, history)
    check_input_token_length(formatted_prompt)

    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(request),
        ContentType="application/json",
    )

    stop_sequences = request["parameters"]["stop"]
    prefix = "data:"
    # ``output`` holds the raw text. Only the yielded copy is rstripped:
    # stripping ``output`` in place would permanently drop whitespace-only
    # tokens such as "\n" and collapse paragraphs/lists in the reply.
    output = ""
    for raw_line in LineIterator(resp["Body"]):
        line = raw_line.decode("utf-8")
        # Only lines carrying a "data:" prefix hold a JSON payload.
        if not line.startswith(prefix):
            continue
        # Slice off the prefix (lstrip("data:") strips a character *set*);
        # json.loads tolerates the surrounding whitespace/newline.
        chunk = json.loads(line[len(prefix):])
        token = chunk["token"]
        if token["special"]:
            continue
        if token["text"] in stop_sequences:
            break
        output += token["text"]
        for stop_str in stop_sequences:
            if output.endswith(stop_str):
                output = output[: -len(stop_str)]
        yield output.rstrip()
    yield output.rstrip()
    return output.rstrip()


def check_input_token_length(prompt: str) -> None:
    """Raise ``gr.Error`` if ``prompt`` tokenizes to more than MAX_INPUT_TOKEN_LENGTH ids."""
    input_token_length = len(tokenizer(prompt)["input_ids"])
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(
            f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
        )


theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)

DESCRIPTION = """

Llama 2 7B Chat on AWS INF2 ⚡

Llama 2 is a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters. This is the repository for the 7B fine-tuned model, optimized for dialogue use cases and converted for the Hugging Face Transformers format. Links to other models can be found in the index at the bottom. This demo is running on AWS Inferentia2, How does it work?

"""

demo = gr.ChatInterface(
    generate,
    description=DESCRIPTION,
    chatbot=gr.Chatbot(layout="panel"),
    theme=theme,
)
# NOTE(review): `queue(concurrency_count=...)` is Gradio 3.x API and was
# removed in Gradio 4.x — confirm the pinned Gradio version.
demo.queue(concurrency_count=5).launch(share=False)