awinml commited on
Commit
1227998
·
1 Parent(s): c6634c4

Upload 17 files (#17)

Browse files

- Upload 17 files (313a678dec11f2f1512c293afe1e86559cf20a38)

Files changed (2) hide show
  1. utils/entity_extraction.py +36 -1
  2. utils/retriever.py +13 -6
utils/entity_extraction.py CHANGED
@@ -21,6 +21,41 @@ def expand_list_of_lists(list_of_lists):
21
  return expanded_list
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def all_keywords_combs(texts):
25
 
26
  texts = [text.split(" ") for text in texts]
@@ -47,7 +82,7 @@ def extract_keywords(query_text, model):
47
  prompt = f"###Instruction:Extract the important keywords which describe the context accurately.\n\nInput:{query_text}\n\n###Response:"
48
  response = model.predict(prompt)
49
  keywords = response.split(", ")
50
- keywords = all_keywords_combs(keywords)
51
  return keywords
52
 
53
 
 
21
  return expanded_list
22
 
23
 
24
+ def keywords_no_companies(texts):
25
+ # Company list (to remove companies from extracted entities)
26
+
27
+ company_list = [
28
+ "apple",
29
+ "amd",
30
+ "amazon",
31
+ "cisco",
32
+ "google",
33
+ "microsoft",
34
+ "nvidia",
35
+ "asml",
36
+ "intel",
37
+ "micron",
38
+ "aapl",
39
+ "csco",
40
+ "msft",
41
+ "asml",
42
+ "nvda",
43
+ "googl",
44
+ "mu",
45
+ "intc",
46
+ "amzn",
47
+ "amd",
48
+ ]
49
+
50
+ texts = [text.split(" ") for text in texts]
51
+ texts = expand_list_of_lists(texts)
52
+
53
+ # Convert all strings to lowercase.
54
+ lower_texts = [text.lower() for text in texts]
55
+ keywords = [text for text in lower_texts if text not in company_list]
56
+ return keywords
57
+
58
+
59
  def all_keywords_combs(texts):
60
 
61
  texts = [text.split(" ") for text in texts]
 
82
  prompt = f"###Instruction:Extract the important keywords which describe the context accurately.\n\nInput:{query_text}\n\n###Response:"
83
  response = model.predict(prompt)
84
  keywords = response.split(", ")
85
+ keywords = keywords_no_companies(keywords)
86
  return keywords
87
 
88
 
utils/retriever.py CHANGED
@@ -15,6 +15,9 @@ def query_pinecone_sparse(
15
  else:
16
  participant = "Question"
17
 
 
 
 
18
  if year == "All":
19
  if quarter == "All":
20
  xc = index.query(
@@ -34,7 +37,7 @@ def query_pinecone_sparse(
34
  "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
35
  "Ticker": {"$eq": ticker},
36
  "QA_Flag": {"$eq": participant},
37
- "Keywords": {"$in": keywords}
38
  },
39
  include_metadata=True,
40
  )
@@ -56,7 +59,7 @@ def query_pinecone_sparse(
56
  "Quarter": {"$eq": quarter},
57
  "Ticker": {"$eq": ticker},
58
  "QA_Flag": {"$eq": participant},
59
- "Keywords": {"$in": keywords}
60
  },
61
  include_metadata=True,
62
  )
@@ -71,7 +74,7 @@ def query_pinecone_sparse(
71
  "Quarter": {"$eq": quarter},
72
  "Ticker": {"$eq": ticker},
73
  "QA_Flag": {"$eq": participant},
74
- "Keywords": {"$in": keywords}
75
  },
76
  include_metadata=True,
77
  )
@@ -100,6 +103,10 @@ def query_pinecone(
100
  else:
101
  participant = "Question"
102
 
 
 
 
 
103
  if year == "All":
104
  if quarter == "All":
105
  xc = index.query(
@@ -118,7 +125,7 @@ def query_pinecone(
118
  "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
119
  "Ticker": {"$eq": ticker},
120
  "QA_Flag": {"$eq": participant},
121
- "Keywords": {"$in": keywords}
122
  },
123
  include_metadata=True,
124
  )
@@ -139,7 +146,7 @@ def query_pinecone(
139
  "Quarter": {"$eq": quarter},
140
  "Ticker": {"$eq": ticker},
141
  "QA_Flag": {"$eq": participant},
142
- "Keywords": {"$in": keywords}
143
  },
144
  include_metadata=True,
145
  )
@@ -153,7 +160,7 @@ def query_pinecone(
153
  "Quarter": {"$eq": quarter},
154
  "Ticker": {"$eq": ticker},
155
  "QA_Flag": {"$eq": participant},
156
- "Keywords": {"$in": keywords}
157
  },
158
  include_metadata=True,
159
  )
 
15
  else:
16
  participant = "Question"
17
 
18
+ # Create filter dictionary based on keywords
19
+ filter_dict = [{'Keywords': word} for word in keywords]
20
+
21
  if year == "All":
22
  if quarter == "All":
23
  xc = index.query(
 
37
  "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
38
  "Ticker": {"$eq": ticker},
39
  "QA_Flag": {"$eq": participant},
40
+ '$and': filter_dict
41
  },
42
  include_metadata=True,
43
  )
 
59
  "Quarter": {"$eq": quarter},
60
  "Ticker": {"$eq": ticker},
61
  "QA_Flag": {"$eq": participant},
62
+ '$and': filter_dict
63
  },
64
  include_metadata=True,
65
  )
 
74
  "Quarter": {"$eq": quarter},
75
  "Ticker": {"$eq": ticker},
76
  "QA_Flag": {"$eq": participant},
77
+ '$and': filter_dict
78
  },
79
  include_metadata=True,
80
  )
 
103
  else:
104
  participant = "Question"
105
 
106
+ # Create filter dictionary based on keywords
107
+ filter_dict = [{'Keywords': word} for word in keywords]
108
+
109
+
110
  if year == "All":
111
  if quarter == "All":
112
  xc = index.query(
 
125
  "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
126
  "Ticker": {"$eq": ticker},
127
  "QA_Flag": {"$eq": participant},
128
+ '$and': filter_dict
129
  },
130
  include_metadata=True,
131
  )
 
146
  "Quarter": {"$eq": quarter},
147
  "Ticker": {"$eq": ticker},
148
  "QA_Flag": {"$eq": participant},
149
+ '$and': filter_dict
150
  },
151
  include_metadata=True,
152
  )
 
160
  "Quarter": {"$eq": quarter},
161
  "Ticker": {"$eq": ticker},
162
  "QA_Flag": {"$eq": participant},
163
+ '$and': filter_dict
164
  },
165
  include_metadata=True,
166
  )