awinml commited on
Commit
ac5b87a
1 Parent(s): bf8b612

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +6 -2
  2. utils.py +65 -14
app.py CHANGED
@@ -42,13 +42,16 @@ with col1:
42
  )
43
 
44
  with col1:
45
- years_choice = ["2020", "2019", "2018", "2017", "2016"]
46
 
47
  with col1:
48
  year = st.selectbox("Year", years_choice)
49
 
50
  with col1:
51
- quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4"])
 
 
 
52
 
53
  ticker_choice = [
54
  "AAPL",
@@ -127,6 +130,7 @@ query_results = query_pinecone(
127
  year,
128
  quarter,
129
  ticker,
 
130
  threshold,
131
  )
132
 
 
42
  )
43
 
44
  with col1:
45
+ years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
46
 
47
  with col1:
48
  year = st.selectbox("Year", years_choice)
49
 
50
  with col1:
51
+ quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4", "All"])
52
+
53
+ with col1:
54
+ participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
55
 
56
  ticker_choice = [
57
  "AAPL",
 
130
  year,
131
  quarter,
132
  ticker,
133
+ participant_type,
134
  threshold,
135
  )
136
 
utils.py CHANGED
@@ -61,21 +61,72 @@ def save_key(api_key):
61
  return api_key
62
 
63
 
64
- def query_pinecone(query, top_k, model, index, year, quarter, ticker, threshold=0.5):
 
 
65
  # generate embeddings for the query
66
  xq = model.encode([query]).tolist()
67
- # search pinecone index for context passage with the answer
68
- xc = index.query(
69
- xq,
70
- top_k=top_k,
71
- filter={
72
- "Year": int(year),
73
- "Quarter": {"$eq": quarter},
74
- "Ticker": {"$eq": ticker},
75
- "QA_Flag": {"$eq": "Answer"},
76
- },
77
- include_metadata=True,
78
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # filter the context passages based on the score threshold
80
  filtered_matches = []
81
  for match in xc["matches"]:
@@ -91,7 +142,7 @@ def format_query(query_results):
91
  return context
92
 
93
 
94
- def sentence_id_combine(data, query_results, lag=2):
95
  # Extract sentence IDs from query results
96
  ids = [result["metadata"]["Sentence_id"] for result in query_results["matches"]]
97
  # Generate new IDs by adding a lag value to the original IDs
 
61
  return api_key
62
 
63
 
64
+ def query_pinecone(
65
+ query, top_k, model, index, year, quarter, ticker, participant_type, threshold=0.25
66
+ ):
67
  # generate embeddings for the query
68
  xq = model.encode([query]).tolist()
69
+
70
+ if participant_type == "Company Speaker":
71
+ participant = "Speaker"
72
+ else:
73
+ participant = participant_type
74
+
75
+ if year == "All":
76
+ if quarter == "All":
77
+ xc = index.query(
78
+ xq,
79
+ top_k=top_k,
80
+ filter={
81
+ "Year": {
82
+ "$in": [
83
+ int("2020"),
84
+ int("2019"),
85
+ int("2018"),
86
+ int("2017"),
87
+ int("2016"),
88
+ ]
89
+ },
90
+ "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
91
+ "Ticker": {"$eq": ticker},
92
+ "QA_Flag": {"$eq": participant},
93
+ },
94
+ include_metadata=True,
95
+ )
96
+ else:
97
+ xc = index.query(
98
+ xq,
99
+ top_k=top_k,
100
+ filter={
101
+ "Year": {
102
+ "$in": [
103
+ int("2020"),
104
+ int("2019"),
105
+ int("2018"),
106
+ int("2017"),
107
+ int("2016"),
108
+ ]
109
+ },
110
+ "Quarter": {"$eq": quarter},
111
+ "Ticker": {"$eq": ticker},
112
+ "QA_Flag": {"$eq": participant},
113
+ },
114
+ include_metadata=True,
115
+ )
116
+ else:
117
+ # search pinecone index for context passage with the answer
118
+ xc = index.query(
119
+ xq,
120
+ top_k=top_k,
121
+ filter={
122
+ "Year": int(year),
123
+ "Quarter": {"$eq": quarter},
124
+ "Ticker": {"$eq": ticker},
125
+ "QA_Flag": {"$eq": participant},
126
+ },
127
+ include_metadata=True,
128
+ )
129
+
130
  # filter the context passages based on the score threshold
131
  filtered_matches = []
132
  for match in xc["matches"]:
 
142
  return context
143
 
144
 
145
+ def sentence_id_combine(data, query_results, lag=1):
146
  # Extract sentence IDs from query results
147
  ids = [result["metadata"]["Sentence_id"] for result in query_results["matches"]]
148
  # Generate new IDs by adding a lag value to the original IDs