from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_hub import HuggingFaceHub
from langchain_community.embeddings import HuggingFaceEmbeddings

from src.vectordatabase import RAG, get_vectorstore
import pandas as pd

# Load environmental variables from .env-file
# from dotenv import load_dotenv, find_dotenv
# load_dotenv(find_dotenv())

# Define important variables
embeddings = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-MiniLM-L12-v2")  # ToDo: remove the embeddings input parameter from the functions below?
llm = HuggingFaceHub(
    # ToDo: Try different models here
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    # repo_id="CohereForAI/c4ai-command-r-v01", # too large 69gb
    # repo_id="CohereForAI/c4ai-command-r-v01-4bit", # too large 22gb
    # repo_id="meta-llama/Meta-Llama-3-8B", # too large 16 gb
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "top_k": 30,
        "temperature": 0.1,
        "repetition_penalty": 1.03,
        }
)
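# Note: HuggingFaceHub reads the API token from the HUGGINGFACEHUB_API_TOKEN
# environment variable (e.g. loaded via the .env file above); without it the
# calls to the Hugging Face Inference API fail with an authentication error.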
# ToDo: Experiment with different templates
prompt_test = ChatPromptTemplate.from_template(
    """<s>[INST]
    Instruction: Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:

    Context: {context}

    Question: {input}
    [/INST]"""
)
prompt_de = ChatPromptTemplate.from_template(
    """Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:

    <context>
    {context}
    </context>

    Frage: {input}
    """
)  # Returns the answer in German
prompt_en = ChatPromptTemplate.from_template(
    """Answer the following question in English and solely based on the provided context:

    <context>
    {context}
    </context>

    Question: {input}
    """
)  # Returns the answer in English
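
# For reference, the RAG helper imported from src.vectordatabase is assumed to
# wire the prompt, retriever, and LLM together roughly as follows (a sketch of
# the standard LangChain retrieval chain, not the actual src.vectordatabase code):
#
#   from langchain.chains.combine_documents import create_stuff_documents_chain
#   from langchain.chains import create_retrieval_chain
#
#   def RAG(llm, prompt, db, question):
#       doc_chain = create_stuff_documents_chain(llm, prompt)
#       chain = create_retrieval_chain(db.as_retriever(), doc_chain)
#       return chain.invoke({"input": question})  # dict with an 'answer' key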



def chatbot(message, history, db_inputs, prompt_language, llm=llm):
    """
    Generate a response from the chatbot based on the provided message, history,
    database inputs, prompt language, and LLM.

    Parameters:
    -----------
    message : str
        The message or question to be answered by the chatbot.
    history : list
        The history of previous interactions or messages.
    db_inputs : list
        A list of strings specifying which vector stores to combine. Each string
        represents a specific index or the special keyword "All".
    prompt_language : str
        The language of the prompt used for generating the response. Should be
        either "DE" for German or "EN" for English.
    llm : LLM, optional
        An instance of the language model used for generating the response.
        Defaults to the global variable `llm`.

    Returns:
    --------
    str
        The response generated by the chatbot.
    """
    db = get_vectorstore(inputs=db_inputs, embeddings=embeddings)

    # Select prompt and answer marker based on user input
    if prompt_language == "DE":
        prompt = prompt_de
        answer_marker = "Antwort: "
    else:
        prompt = prompt_en
        answer_marker = "Answer: "

    raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
    # Only necessary because Mistral echoes the prompt structure, including the
    # input content, in its output; keep only the part after the answer marker.
    try:
        response = raw_response['answer'].split(answer_marker)[1]
    except IndexError:
        response = raw_response['answer']
    return response
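
# Minimal usage sketch (assumes the referenced vector stores have been built
# and HUGGINGFACEHUB_API_TOKEN is set; the inputs shown are illustrative):
#
#   answer = chatbot(
#       message="Was wurde zur Energiepolitik gesagt?",
#       history=[],
#       db_inputs=["All"],
#       prompt_language="DE",
#   )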


def keyword_search(query, n=10, embeddings=embeddings, method='ss', party_filter='All'):
    """
    Retrieve speech contents based on keywords using a specified method.

    Parameters:
    ----------
    query : str
        The keyword(s) to search for in the speech contents.
    n : int, optional
        The number of speech contents to retrieve (default is 10).
    embeddings : Embeddings, optional
        An instance of embeddings used for embedding the query (default is the
        global `embeddings`).
    method : str, optional
        The method used for retrieving speech contents. Options are 'ss'
        (semantic search) and 'mmr' (maximal marginal relevance) (default is 'ss').
    party_filter : str, optional
        A filter for retrieving speech contents by party affiliation. Specify
        'All' to retrieve speeches from all parties (default is 'All').

    Returns:
    -------
    pandas.DataFrame
        A DataFrame containing the speech contents, dates, and party affiliations
        (plus a relevance score when `method` is 'mmr').

    Notes:
    -----
    - The FAISS vector store containing the speech embeddings is loaded
      internally from the combined "All" index; it is not passed as a parameter.
    - The party filter is applied after retrieval, so fewer than `n` rows may be
      returned when a specific party is selected.
    """
    
    db = get_vectorstore(inputs=["All"], embeddings=embeddings)
    query_embedding = embeddings.embed_query(query)

    # Maximal Marginal Relevance
    if method == 'mmr':
        df_res = pd.DataFrame(columns=['Speech Content', 'Date', 'Party', 'Relevance'])
        results = db.max_marginal_relevance_search_with_score_by_vector(query_embedding, k=n)
        for doc, score in results:
            party = doc.metadata["party"]
            if party != party_filter and party_filter != 'All':
                continue
            speech_content = doc.page_content
            speech_date = doc.metadata["date"]
            df_res = pd.concat([df_res, pd.DataFrame({'Speech Content': [speech_content],
                                                      'Date': [speech_date],
                                                      'Party': [party],
                                                      'Relevance': [round(score, ndigits=2)]})], ignore_index=True)
        # With the default L2-based FAISS index, lower scores mean closer matches,
        # so sort ascending to put the most relevant speeches first
        df_res.sort_values('Relevance', inplace=True, ascending=True)

    # Similarity Search
    else:
        df_res = pd.DataFrame(columns=['Speech Content', 'Date', 'Party'])
        results = db.similarity_search_by_vector(query_embedding, k=n)
        for doc in results:
            party = doc.metadata["party"]
            if party != party_filter and party_filter != 'All':
                continue
            speech_content = doc.page_content
            speech_date = doc.metadata["date"]
            df_res = pd.concat([df_res, pd.DataFrame({'Speech Content': [speech_content],
                                                      'Date': [speech_date],
                                                      'Party': [party]})], ignore_index=True)
    return df_res
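

if __name__ == "__main__":
    # Smoke test (assumes the combined "All" FAISS index has been built and is
    # readable by src.vectordatabase.get_vectorstore; the query is illustrative)
    df = keyword_search("Klimaschutz", n=5, method="ss")
    print(df.head())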