import warnings

import numpy as np
import streamlit as st
import torch
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

warnings.filterwarnings("ignore", category=UserWarning)

model_name = "AI-Sweden-Models/gpt-sw3-126m-instruct"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Initialize tokenizer & model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model.to(device)

# Sentence encoder used to embed queries for retrieval
document_encoder_model = SentenceTransformer("KBLab/sentence-bert-swedish-cased")

# Note: 'index1' has been pre-created in the Pinecone console.
# Read the Pinecone API key from Streamlit secrets.
pinecone_api_key = st.secrets["pinecone_api_key"]
pc = Pinecone(api_key=pinecone_api_key)
index = pc.Index("index1")


def read_file(file_path: str) -> str:
    """Read the contents of a file."""
    with open(file_path, "r") as file:
        return file.read()


def query_pinecone_namespace(
    vector_database_index, q_embedding: np.ndarray, namespace: str
) -> str:
    """Return the paragraph of the closest match in the given namespace."""
    result = vector_database_index.query(
        namespace=namespace,
        vector=q_embedding.tolist(),
        top_k=1,
        include_values=True,
        include_metadata=True,
    )
    results = [match.metadata["paragraph"] for match in result.matches]
    return results[0]


def generate_prompt(llmprompt: str) -> str:
    """Format a prompt for the GPT-SW3 instruct model."""
    start_token = "<|endoftext|>"
    end_token = ""
    return f"{start_token}\nUser:\n{llmprompt}\n{end_token}\nBot:\n".strip()


def encode_query(query: str) -> np.ndarray:
    """Encode the query with the sentence-transformer document encoder."""
    return document_encoder_model.encode(query)


class StopOnTokenCriteria(StoppingCriteria):
    """Stop generation as soon as the given token id is produced."""

    def __init__(self, stop_token_id):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        return input_ids[0, -1] == self.stop_token_id


stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)

st.title("Paralegal Assistant")
st.subheader("RAG: föräldrabalken")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("Skriv din fråga..."):
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Retrieve the most relevant law paragraph from Pinecone
    query = query_pinecone_namespace(
        vector_database_index=index,
        q_embedding=encode_query(query=prompt),
        namespace="ns-parent-balk",
    )

    # Build the instruction prompt around the retrieved paragraph
    llmprompt = (
        "Följande stycke är en del av lagen: "
        + query
        + " Referera till lagen och besvara följande fråga på ett sakligt, kortfattat och formellt vis: "
        + prompt
    )
    llmprompt = generate_prompt(llmprompt=llmprompt)

    # Convert prompt to tokens
    input_ids = tokenizer(llmprompt, return_tensors="pt")["input_ids"].to(device)

    # Generate tokens based on the prompt
    generated_token_ids = model.generate(
        inputs=input_ids,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
        top_p=1,
        stopping_criteria=StoppingCriteriaList([stop_on_token_criteria]),
    )[0]

    # Decode only the newly generated tokens (skip the prompt and the stop token)
    generated_text = tokenizer.decode(generated_token_ids[len(input_ids[0]) : -1])
    response = generated_text

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(f"```{query}```\n" + response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})