AceVen57 committed
Commit 8becf68
1 Parent(s): 17afeec

Create app.py

Files changed (1)
  1. app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
+ import os
+ import uuid
+ import json
+
+ import gradio as gr
+
+ from openai import OpenAI
+
+ from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+ from langchain_community.vectorstores import Chroma
+
+ from huggingface_hub import CommitScheduler
+ from pathlib import Path
+ from dotenv import load_dotenv
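+
+ # Rough pip requirements implied by the imports above (an assumption, not part
+ # of the original file): gradio, openai, langchain-community,
+ # sentence-transformers, chromadb, huggingface_hub, python-dotenv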
+
+
+ # Create the Anyscale client
+ load_dotenv()
+
+ client = OpenAI(
+     base_url="https://api.endpoints.anyscale.com/v1",
+     api_key=os.environ['ANYSCALE_API_KEY']  # populated from .env by load_dotenv()
+ )
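+
+ # Anyscale Endpoints exposes an OpenAI-compatible API, which is why the stock
+ # OpenAI client works here with only the base_url changed.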
+
+ # Define the embedding model
+ embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
+
+ collection_name = 'report-10k-2024'
+
+ # Load the persisted vector DB
+ vectorstore_persisted = Chroma(
+     collection_name=collection_name,
+     persist_directory='./dataset-10k',
+     embedding_function=embedding_model
+ )
+
+ # Expose the vectorstore as a top-5 similarity retriever
+ retriever = vectorstore_persisted.as_retriever(
+     search_type='similarity',
+     search_kwargs={'k': 5}
+ )
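+
+ # Note: predict() below queries the vectorstore directly so it can apply a
+ # per-company metadata filter; this unfiltered retriever is kept for ad-hoc use.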
+
+
+ # Prepare the logging functionality
+
+ log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
+ log_folder = log_file.parent
+
+ scheduler = CommitScheduler(
+     repo_id="Project_3_Log",
+     repo_type="dataset",
+     folder_path=log_folder,
+     path_in_repo="data",
+     every=2
+ )
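+
+ # CommitScheduler runs in a background thread and pushes the contents of
+ # log_folder to the "Project_3_Log" dataset repo every 2 minutes; the uuid in
+ # the filename gives each app session its own log file.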
+
+ # Define the Q&A system message
+
+ qna_system_message = """
+ You are an assistant to a financial analyst. Your task is to provide relevant information from the 10-K reports of the selected company.
+
+ User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
+ The context contains references to specific portions of documents relevant to the user's query, along with source links.
+ The source for a context will begin with the token: ###Source.
+
+ When crafting your response:
+ 1. Select the most relevant context or contexts to answer the question.
+ 2. Include the source links in your response.
+ 3. User questions will begin with the token: ###Question.
+ 4. If the question is irrelevant to the 10-K reports, respond with: "I can only help with questions related to the 10-K reports."
+
+ Please adhere to the following guidelines:
+ - Answer only using the context provided.
+ - Do not mention anything about the context in your final answer.
+ - If the answer is not found in the context, it is very important that you respond with: "I don't know."
+ - Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Sources:
+ - Do not make up sources. Use only the links provided in the sources section of the context. You are prohibited from providing other links/sources.
+
+ Here is an example of how to structure your response:
+
+ Answer:
+ [Answer]
+
+ Sources:
+ [Source]
+ """
+
+ # Define the user message template
+ qna_user_message_template = """
+ ###Context
+ Here are some documents that are relevant to the question.
+ {context}
+
+ ###Question
+ {question}
+ """
+
+ # Define the predict function that runs when 'Submit' is clicked or when an API request is made
+ def predict(user_input, company):
+
+     # Retrieve chunks only from the selected company's 10-K filing
+     filter_source = "dataset/" + company + "-10-k-2023.pdf"
+     relevant_document_chunks = vectorstore_persisted.similarity_search(
+         user_input, k=5, filter={"source": filter_source}
+     )
+
+     # Build the context for the query, appending each chunk's source path so the
+     # model can cite it as the system message instructs
+     context_list = [
+         f"{d.page_content}\n###Source: {d.metadata['source']}"
+         for d in relevant_document_chunks
+     ]
+     context_for_query = "\n\n".join(context_list)
+
+     # Create messages
+     prompt = [
+         {'role': 'system', 'content': qna_system_message},
+         {'role': 'user', 'content': qna_user_message_template.format(
+             context=context_for_query,
+             question=user_input
+         )}
+     ]
+
+     # Get response from the LLM
+     try:
+         response = client.chat.completions.create(
+             model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+             messages=prompt,
+             temperature=0
+         )
+
+         prediction = response.choices[0].message.content
+
+     except Exception as e:
+         # Return the error as text; the raw exception object would not be JSON-serializable below
+         prediction = f"Sorry, I could not generate an answer: {e}"
+
+     # Log the inputs and outputs to a local log file while holding the commit
+     # scheduler's lock, so a background commit never reads a half-written record
+     with scheduler.lock:
+         with log_file.open("a") as f:
+             f.write(json.dumps(
+                 {
+                     'user_input': user_input,
+                     'retrieved_context': context_for_query,
+                     'model_response': prediction
+                 }
+             ))
+             f.write("\n")
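+     # Each call appends one JSON object per line (JSONL), e.g.:
+     # {"user_input": "...", "retrieved_context": "...", "model_response": "..."}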
+
+     return prediction
+
+ # Set up the Gradio UI
+ # Add a text box and a radio button to the interface; the radio button selects
+ # the company whose 10-K report the context is retrieved from.
+
+ textbox = gr.Textbox(label='Your question')
+ # NOTE: placeholder choices - replace with the company names actually used in
+ # your PDF filenames, i.e. dataset/<company>-10-k-2023.pdf
+ company = gr.Radio(choices=['aws', 'google', 'ibm', 'meta', 'msft'], label='Company')
+
+ # Create the interface, passing [textbox, company] as the inputs
+
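+ # The original file uses `demo` below without ever defining it; this is a
+ # minimal sketch of the missing wiring. The output component and title are
+ # assumptions, not from the original commit.
+ demo = gr.Interface(
+     fn=predict,
+     inputs=[textbox, company],
+     outputs=gr.Textbox(label='Answer'),
+     title='10-K Report Q&A'
+ )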
+
+ demo.queue()
+ demo.launch()