import streamlit as st
import torch
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
model_name = "AI-Sweden-Models/gpt-sw3-126m-instruct"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Initialize tokenizer & model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model.to(device)


def read_file(file_path: str) -> str:
    """Read the contents of a file."""
    with open(file_path, "r") as file:
        return file.read()

document_encoder_model = SentenceTransformer("KBLab/sentence-bert-swedish-cased")
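# Swedish sentence-BERT used to embed user queries. It is assumed that the
# paragraphs stored in Pinecone were embedded with this same model, since
# nearest-neighbour retrieval only makes sense within one embedding space.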
# Note: 'index1' has been pre-created in the Pinecone console
# Read the Pinecone API key from Streamlit secrets
pinecone_api_key = st.secrets["pinecone_api_key"]
pc = Pinecone(api_key=pinecone_api_key)
index = pc.Index("index1")
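# The index is assumed to hold one embedding per paragraph of föräldrabalken
# (the Swedish Children and Parents Code), stored under namespace "ns-parent-balk".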


def query_pinecone_namespace(
    vector_database_index, q_embedding: np.ndarray, namespace: str
) -> str:
    """Return the text of the single best-matching paragraph in the namespace.

    Assumes every stored vector carries a "paragraph" metadata field holding
    the original statute text.
    """
    result = vector_database_index.query(
        namespace=namespace,
        vector=q_embedding.tolist(),
        top_k=1,
        include_values=True,
        include_metadata=True,
    )
    results = [match.metadata["paragraph"] for match in result.matches]
    return results[0]
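
# Illustrative usage (the question text is a made-up example):
#     best = query_pinecone_namespace(
#         vector_database_index=index,
#         q_embedding=encode_query("Vem får utses till vårdnadshavare?"),
#         namespace="ns-parent-balk",
#     )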


def generate_prompt(llmprompt: str) -> str:
    """Wrap a user message in the GPT-SW3 instruct chat format."""
    start_token = "<|endoftext|><s>"
    end_token = "<s>"
    return f"{start_token}\nUser:\n{llmprompt}\n{end_token}\nBot:\n".strip()


def encode_query(query: str) -> np.ndarray:
    """Embed the query with the Swedish sentence-BERT encoder."""
    return document_encoder_model.encode(query)
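
# SentenceTransformer.encode returns a NumPy array by default, which is why
# query_pinecone_namespace converts it with .tolist() before querying Pinecone.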


class StopOnTokenCriteria(StoppingCriteria):
    """Stop generation as soon as the last emitted token matches stop_token_id."""

    def __init__(self, stop_token_id):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        return input_ids[0, -1] == self.stop_token_id


stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)
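# In the GPT-SW3 chat format, turns are separated by the BOS token (<s>), so
# generation is stopped as soon as the model starts a new turn.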
st.title("Paralegal Assistant")
st.subheader("RAG: föräldrabalken")
# Initialize chat history
if "messages" not in st.session_state:
st.session_state.messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# React to user input
if prompt := st.chat_input("Skriv din fråga..."):  # "Type your question..."
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Retrieve the most relevant law paragraph for the question
    retrieved_paragraph = query_pinecone_namespace(
        vector_database_index=index,
        q_embedding=encode_query(query=prompt),
        namespace="ns-parent-balk",
    )
    # Build the RAG prompt: quote the retrieved paragraph, then ask the question.
    # (Swedish: "The following paragraph is part of the law: ... Refer to the law
    # and answer the following question factually, concisely and formally: ...")
    llmprompt = (
        "Följande stycke är en del av lagen: "
        + retrieved_paragraph
        + " Referera till lagen och besvara följande fråga på ett sakligt, kortfattat och formellt vis: "
        + prompt
    )
    llmprompt = generate_prompt(llmprompt=llmprompt)
    # Convert prompt to tokens
    input_ids = tokenizer(llmprompt, return_tensors="pt")["input_ids"].to(device)
    # Generate tokens based on the prompt
    generated_token_ids = model.generate(
        inputs=input_ids,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
        top_p=1,
        stopping_criteria=StoppingCriteriaList([stop_on_token_criteria]),
    )[0]
    # Decode only the newly generated tokens, dropping the prompt prefix and
    # the trailing stop token
    response = tokenizer.decode(generated_token_ids[len(input_ids[0]) : -1])
    # Display assistant response in chat message container, with the retrieved
    # paragraph quoted as context
    with st.chat_message("assistant"):
        st.markdown(f"```{retrieved_paragraph}```\n" + response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})
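
# To run locally (assuming this file is saved as app.py and a pinecone_api_key
# entry exists in .streamlit/secrets.toml):
#     streamlit run app.py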