awinml committed
Commit 6a79fd2
1 Parent(s): fb3af34

Upload 2 files
Files changed (2):
  1. app.py +7 -3
  2. utils.py +33 -26
app.py CHANGED
@@ -10,7 +10,8 @@ from utils import (
     clean_entities,
     create_dense_embeddings,
     create_sparse_embeddings,
-    extract_entities,
+    extract_quarter_year,
+    extract_ticker_spacy,
     format_query,
     get_flan_alpaca_xl_model,
     generate_alpaca_ner_prompt,
@@ -70,9 +71,12 @@ with col1:
     if ner_choice == "Alpaca":
         ner_prompt = generate_alpaca_ner_prompt(query_text)
         entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
-        company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(entity_text)
+        company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(
+            entity_text
+        )
     else:
-        company_ent, quarter_ent, year_ent = extract_entities(query_text, ner_model)
+        company_ent = extract_ticker_spacy(query_text, ner_model)
+        quarter_ent, year_ent = extract_quarter_year(query_text)
 
     ticker_index, quarter_index, year_index = clean_entities(
         company_ent, quarter_ent, year_ent
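
The reworked else branch resolves the company and the quarter/year with two separate helpers instead of a single extract_entities call. A minimal sketch of how the new pair could be exercised on its own; the query string is illustrative, and the harness assumes the en_core_web_trf pipeline introduced in this commit is installed:

```python
import spacy

from utils import extract_quarter_year, extract_ticker_spacy

# Transformer NER pipeline that get_spacy_model() now loads
# (requires spacy-transformers and the en_core_web_trf package).
ner_model = spacy.load("en_core_web_trf")

query_text = "What did AAPL report in Q2 2020?"  # illustrative query

# Company comes from spaCy's ORG entities; quarter and year come from
# the regex-based helper and no longer depend on a DATE entity.
company_ent = extract_ticker_spacy(query_text, ner_model)
quarter_ent, year_ent = extract_quarter_year(query_text)

# Any of the three can be None if nothing was matched.
print(company_ent, quarter_ent, year_ent)
```

Splitting the two concerns means a missed DATE entity no longer wipes out the quarter and year, since extract_quarter_year now matches directly against the raw query text.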
utils.py CHANGED
@@ -5,6 +5,7 @@ import requests
 import openai
 import pandas as pd
 import spacy
+import spacy_transformers
 import streamlit_scrollable_textbox as stx
 import torch
 from sentence_transformers import SentenceTransformer
@@ -33,13 +34,17 @@ def get_data():
 
 @st.experimental_singleton
 def get_spacy_model():
-    return spacy.load("en_core_web_sm")
+    return spacy.load("en_core_web_trf")
 
 
 @st.experimental_singleton
 def get_flan_alpaca_xl_model():
-    model = AutoModelForSeq2SeqLM.from_pretrained("/home/user/app/models/flan-alpaca-xl/")
-    tokenizer = AutoTokenizer.from_pretrained("/home/user/app/models/flan-alpaca-xl/")
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        "/home/user/app/models/flan-alpaca-xl/"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        "/home/user/app/models/flan-alpaca-xl/"
+    )
     return model, tokenizer
 
 
@@ -478,6 +483,7 @@ Answer:?"""
 
 # Entity Extraction
 
+
 def generate_alpaca_ner_prompt(query):
     prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Use the following guidelines to extract the entities representing the Company, Quarter, and Year in the sentence.
 
@@ -515,19 +521,27 @@ Company - Cisco, Quarter - none, Year - 2016
 ### Response:"""
     return prompt
 
+
 def generate_entities_flan_alpaca_inference_api(prompt):
     API_URL = "https://api-inference.huggingface.co/models/declare-lab/flan-alpaca-xl"
     API_TOKEN = st.secrets["hg_key"]
     headers = {"Authorization": f"Bearer {API_TOKEN}"}
     payload = {
         "inputs": prompt,
-        "parameters": {"do_sample": True, "temperature":0.1, "max_length":80},
-        "options": {"use_cache": False, "wait_for_model": True}
+        "parameters": {
+            "do_sample": True,
+            "temperature": 0.1,
+            "max_length": 80,
+        },
+        "options": {"use_cache": False, "wait_for_model": True},
     }
     try:
         data = json.dumps(payload)
+        # Key not used as headers=headers not passed
         response = requests.request("POST", API_URL, data=data)
-        output = json.loads(response.content.decode("utf-8"))[0]["generated_text"]
+        output = json.loads(response.content.decode("utf-8"))[0][
+            "generated_text"
+        ]
     except:
         output = ""
     print(output)
@@ -536,7 +550,7 @@ def generate_entities_flan_alpaca_inference_api(prompt):
 
 def generate_entities_flan_alpaca_checkpoint(model, tokenizer, prompt):
     model_inputs = tokenizer(prompt, return_tensors="pt")
-    input_ids = inputs["input_ids"]
+    input_ids = model_inputs["input_ids"]
     generation_output = model.generate(
         input_ids=input_ids,
         temperature=0.1,
@@ -547,9 +561,9 @@ def generate_entities_flan_alpaca_checkpoint(model, tokenizer, prompt):
     return output
 
 
-def format_entities_flan_alpaca(model_output):
+def format_entities_flan_alpaca(values):
     """
-    Extracts the text for each entity from the output generated by the
+    Extracts the text for each entity from the output generated by the
     Flan-Alpaca model.
     """
     try:
@@ -560,22 +574,22 @@ def format_entities_flan_alpaca(model_output):
         year = None
     try:
         company = company_string.split(" - ")[1].lower()
-        company = None if company.lower() == 'none' else company
+        company = None if company.lower() == "none" else company
     except:
         company = None
     try:
         quarter = quarter_string.split(" - ")[1]
-        quarter = None if quarter.lower() == 'none' else quarter
+        quarter = None if quarter.lower() == "none" else quarter
 
     except:
         quarter = None
     try:
         year = year_string.split(" - ")[1]
-        year = None if year.lower() == 'none' else year
+        year = None if year.lower() == "none" else year
 
     except:
         year = None
-
+
     print((company, quarter, year))
     return company, quarter, year
 
@@ -586,34 +600,27 @@ def extract_quarter_year(string):
     if year_match:
         year = year_match.group()
     else:
-        return None, None
+        year = None
 
     # Extract quarter from string
     quarter_match = re.search(r"Q\d", string)
     if quarter_match:
         quarter = "Q" + quarter_match.group()[1]
     else:
-        return None, None
+        quarter = None
 
     return quarter, year
 
 
-def extract_entities(query, model):
+def extract_ticker_spacy(query, model):
     doc = model(query)
     entities = {ent.label_: ent.text for ent in doc.ents}
+    print(entities.keys())
     if "ORG" in entities.keys():
         company = entities["ORG"].lower()
-        if "DATE" in entities.keys():
-            quarter, year = extract_quarter_year(entities["DATE"])
-            return company, quarter, year
-        else:
-            return company, None, None
     else:
-        if "DATE" in entities.keys():
-            quarter, year = extract_quarter_year(entities["DATE"])
-            return None, quarter, year
-        else:
-            return None, None, None
+        company = None
+    return company
 
 
 def clean_entities(company, quarter, year):
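
One note on the Inference API call above: the new inline comment points out that the Authorization header is built but never sent, because requests.request is called without headers=headers. A minimal sketch of the same call with the token actually attached; the prompt and token values here are placeholders, not taken from the commit:

```python
import json

import requests

API_URL = "https://api-inference.huggingface.co/models/declare-lab/flan-alpaca-xl"
API_TOKEN = "hf_..."  # placeholder; the app reads this from st.secrets["hg_key"]
headers = {"Authorization": f"Bearer {API_TOKEN}"}

prompt = "### Instruction: ... ### Response:"  # placeholder NER prompt

payload = {
    "inputs": prompt,
    "parameters": {"do_sample": True, "temperature": 0.1, "max_length": 80},
    "options": {"use_cache": False, "wait_for_model": True},
}

# Passing headers=headers is what authenticates the request; without it,
# the call is sent unauthenticated, which is what the committed code does.
response = requests.request(
    "POST", API_URL, headers=headers, data=json.dumps(payload)
)
output = response.json()[0]["generated_text"]
print(output)
```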