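"""Streamlit demo: a paralegal assistant answering questions about föräldrabalken
via retrieval-augmented generation (RAG).

Pipeline: embed the user question with a Swedish sentence-BERT model, retrieve the
closest law paragraph from a Pinecone index, and let AI-Sweden's GPT-SW3 instruct
model answer with that paragraph as context.
"""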
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import StoppingCriteriaList, StoppingCriteria
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
model_name = "AI-Sweden-Models/gpt-sw3-126m-instruct"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Initialize Tokenizer & Model
tokenizer = AutoTokenizer.from_pretrained(model_name)
def read_file(file_path: str) -> str:
    """Read the contents of a file."""
    with open(file_path, "r") as file:
        return file.read()

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model.to(device)
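
# Sentence-BERT encoder used to embed user queries for retrieval;
# it should match the model used to embed the paragraphs stored in the Pinecone index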
document_encoder_model = SentenceTransformer("KBLab/sentence-bert-swedish-cased")
# Note: 'index1' has been pre-created in the Pinecone console
# Read the Pinecone API key from Streamlit secrets
pinecone_api_key = st.secrets["pinecone_api_key"]
pc = Pinecone(api_key=pinecone_api_key)
index = pc.Index("index1")
def query_pinecone_namespace(
    vector_database_index, q_embedding, namespace: str
) -> str:
    """Return the paragraph of the best-matching vector in the given namespace."""
    result = vector_database_index.query(
        namespace=namespace,
        vector=q_embedding.tolist(),
        top_k=1,
        include_values=True,
        include_metadata=True,
    )
    results = []
    for match in result.matches:
        results.append(match.metadata["paragraph"])
    return results[0]
def generate_prompt(llmprompt: str) -> str:
    """Wrap the prompt in the chat format expected by the GPT-SW3 instruct model."""
    start_token = "<|endoftext|><s>"
    end_token = "<s>"
    return f"{start_token}\nUser:\n{llmprompt}\n{end_token}\nBot:\n".strip()
def encode_query(query: str):
    """Embed the query with the sentence-BERT document encoder."""
    return document_encoder_model.encode(query)
class StopOnTokenCriteria(StoppingCriteria):
    """Stop generation as soon as a given token id is produced."""

    def __init__(self, stop_token_id):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        # Check whether the most recently generated token is the stop token
        return input_ids[0, -1] == self.stop_token_id
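
# The chat format above ends each turn with "<s>" (the BOS token),
# so generation stops as soon as the model emits it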
stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)
st.title("Paralegal Assistant")
st.subheader("RAG: föräldrabalken")
# Initialize chat history
if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# React to user input
if prompt := st.chat_input("Skriv din fråga..."):
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Retrieve the most relevant law paragraph for the question
    retrieved_paragraph = query_pinecone_namespace(
        vector_database_index=index,
        q_embedding=encode_query(query=prompt),
        namespace="ns-parent-balk",
    )
    # Build the instruction: cite the retrieved law text, then ask the question
    llmprompt = (
        "Följande stycke är en del av lagen: "
        + retrieved_paragraph
        + " Referera till lagen och besvara följande fråga på ett sakligt, kortfattat och formellt vis: "
        + prompt
    )
    llmprompt = generate_prompt(llmprompt=llmprompt)
    # Convert the prompt to token ids
    input_ids = tokenizer(llmprompt, return_tensors="pt")["input_ids"].to(device)

    # Generate a response, stopping when the turn-separator token is produced
    generated_token_ids = model.generate(
        inputs=input_ids,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
        top_p=1,
        stopping_criteria=StoppingCriteriaList([stop_on_token_criteria]),
    )[0]

    # Decode only the newly generated tokens, dropping the trailing stop token
    generated_text = tokenizer.decode(generated_token_ids[len(input_ids[0]) : -1])
    response = generated_text
    # Display the retrieved law text and the assistant response in the chat container
    with st.chat_message("assistant"):
        st.markdown(f"```{retrieved_paragraph}```\n" + response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})