import streamlit as st
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from transformers import StoppingCriteriaList, StoppingCriteria
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import warnings


warnings.filterwarnings("ignore", category=UserWarning)
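
# This script is a Streamlit app; launch it locally with
# `streamlit run <this_file>.py` (substitute the actual filename).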

model_name = "AI-Sweden-Models/gpt-sw3-126m-instruct"


device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Initialize Tokenizer & Model
tokenizer = AutoTokenizer.from_pretrained(model_name)


def read_file(file_path: str) -> str:
    """Read the contents of a file."""
    with open(file_path, "r") as file:
        return file.read()


model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model.to(device)

document_encoder_model = SentenceTransformer("KBLab/sentence-bert-swedish-cased")
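# Note: queries must be embedded with the same encoder that was used when the
# index was populated, and the Pinecone index dimension has to equal the
# encoder's embedding size.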


# Note: 'index1' has been pre-created in the Pinecone console.
# The API key is read from Streamlit secrets (e.g. .streamlit/secrets.toml).
pinecone_api_key = st.secrets["pinecone_api_key"]
pc = Pinecone(api_key=pinecone_api_key)
index = pc.Index("index1")
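# Assumption: the namespace queried below ("ns-parent-balk") already holds one
# embedding per law paragraph, upserted offline with the same sentence-BERT
# model, along the lines of:
#   index.upsert(
#       vectors=[{"id": "p-1", "values": embedding.tolist(),
#                 "metadata": {"paragraph": paragraph_text}}],
#       namespace="ns-parent-balk",
#   )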


def query_pinecone_namespace(
    vector_database_index, q_embedding, namespace: str
) -> str:
    """Return the best-matching law paragraph for the given query embedding."""
    result = vector_database_index.query(
        namespace=namespace,
        vector=q_embedding.tolist(),
        top_k=1,
        include_values=True,
        include_metadata=True,
    )
    results = [match.metadata["paragraph"] for match in result.matches]
    return results[0] if results else ""


def generate_prompt(llmprompt: str) -> str:
    """Wrap the prompt in the GPT-SW3 instruct chat format."""
    start_token = "<|endoftext|><s>"
    end_token = "<s>"
    return f"{start_token}\nUser:\n{llmprompt}\n{end_token}\nBot:\n".strip()


def encode_query(query: str):
    """Embed the query with the Swedish sentence-BERT encoder (returns a NumPy array)."""
    return document_encoder_model.encode(query)


class StopOnTokenCriteria(StoppingCriteria):
    """Stop generation as soon as the model emits the given token id."""

    def __init__(self, stop_token_id):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        return input_ids[0, -1] == self.stop_token_id


stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)
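# Note: in the GPT-SW3 instruct chat format the BOS token (<s>) doubles as the
# turn separator, so the criterion above ends generation once the bot's reply is done.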

st.title("Paralegal Assistant")
st.subheader("RAG: föräldrabalken")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("Skriv din fråga..."):
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Retrieve the most relevant law paragraph and prepend it to the question
    retrieved_paragraph = query_pinecone_namespace(
        vector_database_index=index,
        q_embedding=encode_query(query=prompt),
        namespace="ns-parent-balk",
    )
    llmprompt = (
        "Följande stycke är en del av lagen: "
        + retrieved_paragraph
        + " Referera till lagen och besvara följande fråga på ett sakligt, kortfattat och formellt vis: "
        + prompt
    )
    llmprompt = generate_prompt(llmprompt=llmprompt)

    # Convert the prompt to tokens
    input_ids = tokenizer(llmprompt, return_tensors="pt")["input_ids"].to(device)

    # Generate tokens from the prompt
    generated_token_ids = model.generate(
        inputs=input_ids,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
        top_p=1,
        stopping_criteria=StoppingCriteriaList([stop_on_token_criteria]),
    )[0]

    # Decode only the newly generated tokens, skipping the prompt and the trailing stop token
    generated_text = tokenizer.decode(generated_token_ids[len(input_ids[0]) : -1])

    response = generated_text
    # Display the retrieved paragraph and the assistant response in the chat container
    with st.chat_message("assistant"):
        st.markdown(f"```{retrieved_paragraph}```\n" + response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})