# earnings-calls-qa / utils/entity_extraction.py
# Uploaded by awinml ("Upload 18 files (#4)", commit aeaab1d, 8.97 kB).
import logging
import re
from ast import literal_eval

from nltk.stem import PorterStemmer, WordNetLemmatizer
# Entity Extraction
def generate_ner_docs_prompt(query):
    """Build the few-shot prompt asking the model for companies and a period.

    The prompt tells the model to answer with a dict literal of the form
    ``{"companies": [...], "start-duration": (quarter, year),
    "end-duration": (quarter, year)}`` and shows five worked examples.

    Args:
        query: The user's question; it is appended after the examples.

    Returns:
        The full prompt string, ending with ``"ASSISTANT:"``.
    """
    few_shot_lines = [
        'USER: Extract the company names and time duration mentioned in the question. The entities should be extracted in the following format: {"companies": list of companies mentioned in the question,"start-duration": ("start-quarter", "start-year"), "end-duration": ("end-quarter", "end-year")}. Return {"companies": None, "start-duration": (None, None), "end-duration": (None, None)} if the entities are not found.',
        "Examples:",
        "What is Intel's update on the server chip roadmap and strategy for Q1 2019?",
        '{"companies": ["Intel"], "start-duration": ("Q1", "2019"), "end-duration": ("Q1", "2019")}',
        "What are the opportunities and challenges in the Indian market for Amazon in 2016?",
        '{"companies": ["Amazon"], "start-duration": ("Q1", "2016"), "end-duration": ("Q4", "2016")}',
        "What did analysts ask about the Cisco's Webex?",
        '{"companies": ["Cisco"], "start-duration": (None, None), "end-duration": (None, None)}',
        "What is the comparative performance analysis between Intel and AMD in key overlapping segments such as PC, Gaming, and Data Centers in Q2 to Q3 2018?",
        '{"companies": ["Intel", "AMD"], "start-duration": ("Q2", "2018"), "end-duration": ("Q3", "2018")}',
        "How did Microsoft and Amazon perform in terms of reliability and scalability of cloud for the years 2016 and 2017?",
        '{"companies": ["Microsoft", "Amazon"], "start-duration": ("Q1", "2016"), "end-duration": ("Q4", "2017")}',
    ]
    instructions = "\n".join(few_shot_lines)
    question_block = f"###Input: {query}\nASSISTANT:"
    # A blank line separates the examples from the new question.
    return "\n\n".join([instructions, question_block])
def extract_entities_docs(query, model):
    """Ask ``model`` for the companies and period in ``query`` and unpack them.

    The model is prompted (via ``generate_ner_docs_prompt``) to reply with a
    string holding a Python dict literal of the format:
    {"companies": list of companies mentioned in the question,"start-duration": ("start-quarter", "start-year"), "end-duration": ("end-quarter", "end-year")}

    Args:
        query: The user's question.
        model: Client exposing ``predict(prompt, api_name="/predict")``;
            presumably returns the generated text as ``str`` (verify).

    Returns:
        Tuple ``(companies, start_quarter, start_year, end_quarter, end_year)``.
        ``companies`` is whatever the model produced for "companies" (a list,
        or ``None`` when absent); each missing duration part is ``None``.

    Raises:
        ValueError / SyntaxError: if the reply is not a valid Python literal.
    """
    prompt = generate_ner_docs_prompt(query)
    string_of_dict = model.predict(prompt, api_name="/predict")
    if isinstance(string_of_dict, str):
        # literal_eval only strips leading spaces/tabs; leading newlines
        # followed by indentation would otherwise be a parse error.
        string_of_dict = string_of_dict.strip()
    # literal_eval only evaluates literals, so it is safe on model output.
    entities_dict = literal_eval(string_of_dict)
    # The prompt allows None values and the model may omit keys; fall back to
    # (None, None) so the unpacking below cannot fail.
    start_quarter, start_year = entities_dict.get("start-duration") or (None, None)
    end_quarter, end_year = entities_dict.get("end-duration") or (None, None)
    companies = entities_dict.get("companies")
    logging.getLogger(__name__).debug(
        "Extracted entities: %s",
        (companies, start_quarter, start_year, end_quarter, end_year),
    )
    return companies, start_quarter, start_year, end_quarter, end_year
def year_quarter_range(start_quarter, start_year, end_quarter, end_year):
    """Return every (quarter, year) pair from the start period to the end period.

    Both endpoints are included. Quarter labels are matched case-insensitively
    and ignore surrounding whitespace ("q2" == " Q2 ") and are emitted in
    upper case; years may be ``str`` or ``int`` and are emitted as ``str``.

    Example:
        year_quarter_range("Q2", "2020", "Q3", "2021")
        [('Q2', '2020'), ('Q3', '2020'), ('Q4', '2020'), ('Q1', '2021'), ('Q2', '2021'), ('Q3', '2021')]

    Returns:
        The list of pairs; empty if any argument is ``None`` or if the start
        period lies after the end period.

    Raises:
        ValueError: if a quarter is not one of Q1-Q4 or a year is not an integer.
    """
    if any(
        value is None
        for value in (start_quarter, start_year, end_quarter, end_year)
    ):
        return []
    quarters = ["Q1", "Q2", "Q3", "Q4"]
    # LLM output is not guaranteed to use upper-case, trimmed quarter labels.
    start_index = quarters.index(str(start_quarter).strip().upper())
    end_index = quarters.index(str(end_quarter).strip().upper())
    # Number every quarter on a single running scale (year * 4 + quarter index)
    # so the requested range becomes a plain integer interval.
    first = int(start_year) * 4 + start_index
    last = int(end_year) * 4 + end_index
    return [
        (quarters[count % 4], str(count // 4)) for count in range(first, last + 1)
    ]
def clean_companies(company_list):
    """Returns list of Tickers from list of companies.

    Each entry may be a company name (e.g. "Intel") or a ticker (e.g. "intc");
    matching is case-insensitive and ignores surrounding whitespace. Entries
    that match neither, including non-string entries, are dropped. Input order
    is preserved.

    Args:
        company_list: Iterable of names/tickers, or ``None`` (the entity
            extraction prompt returns ``None`` when no company is found).

    Returns:
        List of upper-case tickers, e.g. ``["INTC", "AMD"]``.
    """
    if not company_list:
        return []
    company_ticker_map = {
        "apple": "AAPL",
        "amd": "AMD",
        "amazon": "AMZN",
        "cisco": "CSCO",
        "google": "GOOGL",
        "microsoft": "MSFT",
        "nvidia": "NVDA",
        "asml": "ASML",
        "intel": "INTC",
        "micron": "MU",
    }
    # Tickers can also be given directly; derive them from the map so the two
    # lists cannot drift apart.
    known_tickers = set(company_ticker_map.values())
    ticker_list = []
    for company in company_list:
        if not isinstance(company, str):
            continue
        key = company.strip().lower()
        if key in company_ticker_map:
            ticker_list.append(company_ticker_map[key])
        elif key.upper() in known_tickers:
            ticker_list.append(key.upper())
    return ticker_list
def ticker_year_quarter_tuples_creator(ticker_list, year_quarter_range_list):
    """Pair every ticker with every (quarter, year) period.

    Args:
        ticker_list: Tickers such as ``["INTC", "AMD"]``.
        year_quarter_range_list: ``(quarter, year)`` pairs, e.g. the output
            of ``year_quarter_range``.

    Returns:
        ``(ticker, quarter, year)`` triples in ticker-major order; empty when
        either input is empty.
    """
    return [
        (ticker_symbol, period_quarter, period_year)
        for ticker_symbol in ticker_list
        for period_quarter, period_year in year_quarter_range_list
    ]
# Keyword Extraction
def generate_ner_keywords_prompt(query):
    """Build the few-shot prompt asking the model for the question's key topics.

    The prompt tells the model to answer with ``{"entities": [...]}`` and
    shows six worked examples.

    Args:
        query: The user's question; it is appended after the examples.

    Returns:
        The full prompt string, ending with ``"ASSISTANT:"``.
    """
    few_shot_lines = [
        'USER: Extract the entities which describe the key theme and topics being asked in the question. Extract the entities in the following format: {"entities":["keywords"]}.',
        "Examples:",
        "What is Intel's update on the server chip roadmap and strategy for Q1 2019?",
        '{"entities":["server"]}',
        "What are the opportunities and challenges in the Indian market for Amazon from Q1 to Q3 in 2016?",
        '{"entities":["indian"]}',
        "What is the comparative performance analysis between Intel and AMD in key overlapping segments such as PC, Gaming, and Data Centers in Q1 2016?",
        '{"entities":["PC","Gaming","Data Centers"]}',
        "What was Google's and Microsoft's capex spend for the last 2 years?",
        '{"entities":["capex"]}',
        "What did analysts ask about the cloud during Microsoft's earnings call in Q1 2018?",
        '{"entities":["cloud"]}',
        "What was the growth in Apple services revenue for 2017 Q3?",
        '{"entities":["services"]}',
    ]
    question_block = f"###Input: {query}\nASSISTANT:"
    # Unlike the company/period prompt, no blank line precedes the question.
    return "\n".join(few_shot_lines + [question_block])
def extract_entities_keywords(query, model):
    """Ask ``model`` for the key topics of ``query``.

    The model is prompted (via ``generate_ner_keywords_prompt``) to reply with
    a string holding a Python dict literal of the format:
    {"entities":["keywords"]}

    Args:
        query: The user's question.
        model: Client exposing ``predict(prompt, api_name="/predict")``;
            presumably returns the generated text as ``str`` (verify).

    Returns:
        The list under "entities"; ``[]`` if the key is missing or empty/None,
        so callers can always iterate the result.

    Raises:
        ValueError / SyntaxError: if the reply is not a valid Python literal.
    """
    prompt = generate_ner_keywords_prompt(query)
    string_of_dict = model.predict(prompt, api_name="/predict")
    if isinstance(string_of_dict, str):
        # literal_eval only strips leading spaces/tabs; leading newlines
        # followed by indentation would otherwise be a parse error.
        string_of_dict = string_of_dict.strip()
    # literal_eval only evaluates literals, so it is safe on model output.
    entities_dict = literal_eval(string_of_dict)
    keywords_list = entities_dict.get("entities") or []
    return keywords_list
def expand_list_of_lists(list_of_lists):
    """Flatten one level of nesting.

    Args:
        list_of_lists: A list of lists of strings.

    Returns:
        One list containing every string of every inner list, in order.
    """
    return [item for sublist in list_of_lists for item in sublist]
def all_keywords_combs(list_of_cleaned_keywords):
    """Expand keywords with their lower-case, stemmed and lemmatized variants.

    For every keyword the result contains the keyword itself, its lower-cased
    form, its Porter stem and its WordNet lemma (stem and lemma are computed
    from the keyword as given, not from its lower-cased form). Duplicates are
    removed. The caller's list is NOT modified (the previous version extended
    it in place).

    Args:
        list_of_cleaned_keywords: Iterable of keyword strings.

    Returns:
        Sorted list of unique keyword variants (sorted so the output is
        deterministic rather than depending on set iteration order).
    """
    # Materialise once so any iterable works and is traversed only once.
    keywords = list(list_of_cleaned_keywords)
    stemmer = PorterStemmer()
    lemmatizer = WordNetLemmatizer()
    variants = set(keywords)
    for text in keywords:
        variants.add(text.lower())
        variants.add(stemmer.stem(text))
        variants.add(lemmatizer.lemmatize(text))
    return sorted(variants)
def create_incorrect_entities_list():
    """Build the list of words to strip from extracted keywords.

    It covers quarter labels, the years 2016-2020, the company names also
    known to ``clean_companies``, and generic question words; every word is
    expanded with its variants via ``all_keywords_combs``.

    Returns:
        List of words (and their variants) to remove.
    """
    quarter_labels = ["q1", "q2", "q3", "q4"]
    covered_years = [str(year) for year in range(2016, 2021)]
    company_names = [
        "apple",
        "amd",
        "amazon",
        "cisco",
        "google",
        "microsoft",
        "nvidia",
        "asml",
        "intel",
        "micron",
    ]
    generic_terms = [
        "strategy",
        "roadmap",
        "impact",
        "opportunities",
        "challenges",
        "growth",
        "performance",
        "analysis",
        "segments",
        "comparative",
        "overlapping",
        "acquisition",
        "revenue",
    ]
    return all_keywords_combs(
        quarter_labels + covered_years + company_names + generic_terms
    )
def clean_keywords_all_combs(keywords_list):
    """Split keyword phrases into words, drop stop-words, and expand the rest.

    Each phrase is split on whitespace ("Data Centers" -> "data", "centers"),
    lower-cased, filtered against ``create_incorrect_entities_list()`` and the
    survivors are expanded with ``all_keywords_combs``.

    Args:
        keywords_list: Keyword phrases (e.g. from ``extract_entities_keywords``);
            ``None`` is treated as no keywords.

    Returns:
        List of cleaned keyword variants.
    """
    # A set makes each membership test O(1) instead of scanning a list.
    words_to_remove = set(create_incorrect_entities_list())
    # split() rather than split(" "): runs of spaces, tabs or newlines must not
    # produce empty-string "keywords".
    texts = expand_list_of_lists([text.split() for text in keywords_list or []])
    # Convert all strings to lowercase.
    lower_texts = [text.lower() for text in texts]
    cleaned_keywords = [text for text in lower_texts if text not in words_to_remove]
    return all_keywords_combs(cleaned_keywords)