"""Streamlit front-end that sends a query to a locally loaded Llama 3.1 model.

The conversation is persisted between runs in ``messages.jsonl`` as one JSON
chat message ({"role": ..., "content": ...}) per line.
"""
import json
import os

import streamlit as st
import torch
import transformers
from huggingface_hub import login

# CONSTANTS
MAX_NEW_TOKENS = 256
SYSTEM_MESSAGE = "You are a helpful, knowledgeable assistant"

# ENV VARS
# Point the sentence-transformers cache at the working directory to avert
# permission errors with transformer and HF models.
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '.'
token = os.getenv("HF_TOKEN_READ")

# STREAMLIT UI AREA
st.write("## Ask your Local LLM")
text_input = st.text_input("Query", value="Why is the sky Blue")
submit = st.button("Submit")

# MODEL AREA
# Authenticate only when a token is configured: login(token=None) falls back
# to an interactive prompt, which cannot be answered inside a Streamlit server.
# Without a token the model must already be available locally.
if token:
    login(token=token)

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"


@st.cache_resource
def load_model():
    """Build and return the text-generation pipeline for ``model_id``.

    Wrapped in ``st.cache_resource`` so the loaded model is reused across
    Streamlit reruns instead of being rebuilt on every interaction.
    """
    return transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )


pipeline = load_model()

message_store_path = "messages.jsonl"


def _default_history() -> list[dict]:
    """Return a fresh chat history containing only the system prompt."""
    return [{"role": "system", "content": SYSTEM_MESSAGE}]


def _load_history(path: str) -> list[dict]:
    """Read the saved chat history from ``path`` (one JSON object per line).

    Returns the default history when the file is missing or holds no messages.
    """
    if not os.path.exists(path):
        return _default_history()
    history: list[dict] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            record = json.loads(line)
            # Earlier versions of this app stored the raw pipeline output
            # ({"generated_text": [...messages...]}) rather than individual
            # messages; unpack those so old files keep working.
            if isinstance(record, dict) and "generated_text" in record:
                history.extend(record["generated_text"])
            else:
                history.append(record)
    return history or _default_history()


messages = _load_history(message_store_path)


def infer(message: str, messages: list[dict]) -> str:
    """Send ``message`` to the LLM along with the chat history and return the reply.

    Params:
        message: Most recent query to the llm.
        messages: Chat history up to current point properly formatted like
            {"role": "user", "content": "What is your name?"}.
            The list is not modified.

    Side effect: overwrites ``message_store_path`` with the full updated
    conversation (history + query + reply), one message per line.

    Deliberately not wrapped in ``st.cache_data``: a cache hit would skip both
    generation and the history write.
    """
    conversation = [*messages, {"role": "user", "content": message}]

    # Perform inference
    output = pipeline(conversation, max_new_tokens=MAX_NEW_TOKENS)

    # For chat-formatted input the last output entry's "generated_text" is the
    # conversation with the assistant reply appended as the final message.
    updated = output[-1]["generated_text"]

    # Save the updated conversation as individual messages so it can be fed
    # straight back into the pipeline on the next run.
    with open(message_store_path, "w", encoding="utf-8") as f:
        for msg in updated:
            json.dump(msg, f)
            f.write("\n")

    return updated[-1]["content"]


if submit:
    if text_input.strip():
        response = infer(text_input, messages)
        st.write(response)
    else:
        st.warning("Please enter a query.")