Michaeldavidstein committed on
Commit d8b45ba
1 Parent(s): d08f36a

Upload app.py

Files changed (1)
app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
+ import os
+ import uuid
+ import json
+
+ import gradio as gr
+
+ from openai import OpenAI
+
+ from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+ from langchain_community.vectorstores import Chroma
+
+ from huggingface_hub import CommitScheduler
+ from pathlib import Path
+
+
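+ # Anyscale Endpoints exposes an OpenAI-compatible API, so the standard OpenAI client
+ # is reused here with a custom base_url and the Anyscale API key from the environment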
+ client = OpenAI(
+     base_url="https://api.endpoints.anyscale.com/v1",
+     api_key=os.environ['ANYSCALE_API_KEY']
+ )
+
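+ # Embedding model for queries; it should match the model used when the persisted
+ # Chroma collection in ./tesla_db was built, otherwise similarity search will be meaningless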
+ embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-small')
+
+ tesla_10k_collection = 'tesla-10k-2019-to-2023'
+
+ vectorstore_persisted = Chroma(
+     collection_name=tesla_10k_collection,
+     persist_directory='./tesla_db',
+     embedding_function=embedding_model
+ )
+
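+ # Wrap the vector store as a retriever that returns the 5 most similar chunks per query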
+ retriever = vectorstore_persisted.as_retriever(
+     search_type='similarity',
+     search_kwargs={'k': 5}
+ )
+
+ # Prepare the logging functionality
+
+ log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
+ log_folder = log_file.parent
+
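+ # CommitScheduler pushes the contents of the log folder to a Hugging Face dataset repo
+ # in a background thread; every=2 commits roughly every 2 minutes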
+ scheduler = CommitScheduler(
+     repo_id="document-qna-chroma-anyscale-logs",
+     repo_type="dataset",
+     folder_path=log_folder,
+     path_in_repo="data",
+     every=2
+ )
+
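+ # Prompt templates: the system message restricts answers to the retrieved context,
+ # and the user template injects the context and the question into a single message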
+ qna_system_message = """
+ You are an assistant to a financial services firm that answers user queries on annual reports.
+ Users will ask questions delimited by triple backticks, that is, ```.
+ The user input will contain the context you need to answer the question.
+ This context will begin with the token: ###Context.
+ The context contains references to specific portions of a document relevant to the user query.
+ Answer only using the context provided in the input, but do not mention the context itself in your answer.
+ If the answer is not found in the context, respond "I don't know".
+ """
+
+ qna_user_message_template = """
+ ###Context
+ Here are some documents that are relevant to the question.
+ {context}
+ ```
+ {question}
+ ```
+ """
+
+ # Define the predict function that runs when 'Submit' is clicked or when an API request is made
+ def predict(user_input):
+
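+     # Retrieve the most relevant document chunks for the query and join them into one context string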
+     relevant_document_chunks = retriever.invoke(user_input)
+     context_list = [d.page_content for d in relevant_document_chunks]
+     context_for_query = ".".join(context_list)
+
+     prompt = [
+         {'role': 'system', 'content': qna_system_message},
+         {'role': 'user', 'content': qna_user_message_template.format(
+             context=context_for_query,
+             question=user_input
+         )}
+     ]
+
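+     # Send the prompt to the hosted model; temperature=0 keeps answers deterministic
+     # and grounded in the retrieved context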
+     try:
+         response = client.chat.completions.create(
+             model='mlabonne/NeuralHermes-2.5-Mistral-7B',
+             messages=prompt,
+             temperature=0
+         )
+
+         prediction = response.choices[0].message.content
+
+     except Exception as e:
+         # Convert the exception to text so it can be displayed and JSON-serialised in the log
+         prediction = str(e)
+
+     # Log both the inputs and the output to a local log file.
+     # While writing to the log file, hold the commit scheduler's lock to avoid
+     # parallel access from its background commit thread.
+
+     with scheduler.lock:
+         with log_file.open("a") as f:
+             f.write(json.dumps(
+                 {
+                     'user_input': user_input,
+                     'retrieved_context': context_for_query,
+                     'model_response': prediction
+                 }
+             ))
+             f.write("\n")
+
+     return prediction
+
+
+ textbox = gr.Textbox(placeholder="Enter your query here", lines=6)
+
+ # Create the interface
+ demo = gr.Interface(
+     fn=predict,
+     inputs=textbox,
+     outputs="text",
+     title="AMA on Tesla 10-K statements",
+     description="This web app presents an interface to ask questions on the contents of the Tesla 10-K reports for the period 2019 - 2023.",
+     article="Note that questions that are not relevant to the Tesla 10-K reports will not be answered.",
+     examples=[
+         # Each example must supply one value per input component (here, just the question)
+         ["What was the total revenue of the company in 2022?"],   # expected answer: ~$81.46 billion
+         ["Summarize the Management Discussion and Analysis section of the 2021 report in 50 words."],
+         ["What was the company's debt level in 2020?"],
+         ["Identify 5 key risks named in the 2019 10-K report. Respond with bullet point summaries."]
+     ],
+     concurrency_limit=16
+ )
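+
+ # queue() enables request queuing so concurrent users are served up to the configured concurrency limit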
+ demo.queue()
+ demo.launch()