rdose committed
Commit b3e926c
1 Parent(s): 03d1953

Update app.py

Files changed (1)
  1. app.py +55 -16
app.py CHANGED
@@ -12,8 +12,10 @@ import os
 from transformers import pipeline
 import itertools
 import pandas as pd
+import thefuzz

 OUT_HEADERS = ['E','S','G']
+DF_SP500 = pd.read_csv('SP500_constituents.zip',compression=dict(method='zip'))

 MODEL_TRANSFORMER_BASED = "distilbert-base-uncased"
 MODEL_ONNX_FNAME = "ESG_classifier_batch.onnx"
@@ -24,24 +26,59 @@ MODEL_SENTIMENT_ANALYSIS = "ProsusAI/finbert"

 #API_HF_SENTIMENT_URL = "https://api-inference.huggingface.co/models/cardiffnlp/twitter-roberta-base-sentiment"

+def get_company_sectors(extracted_names, threshold=0.95):
+    '''
+    '''
+    output = []
+    standard_names_tuples = []
+    for extracted_name in extracted_names:
+        name_match = thefuzz.process.extractOne(extracted_name,
+                                                DF_SP500.Name,
+                                                scorer=thefuzz.fuzz.token_set_ratio)
+        similarity = name_match[1]/100
+        if similarity >= threshold:
+            standard_names_tuples.append(name_match[:2])
+
+    for std_comp_name, _ in standard_names_tuples:
+        sectors = list(DF_SP500[['Name','Sector']].where(DF_SP500.Name == std_comp_name).dropna().itertuples(index=False, name=None))
+        output += sectors
+    return output
+
+def filter_spans(spans, keep_longest=True):
+    """Filter a sequence of spans and remove duplicates or overlaps. Useful for
+    creating named entities (where one token can only be part of one entity) or
+    when merging spans with `Retokenizer.merge`. When spans overlap, the (first)
+    longest span is preferred over shorter spans.
+    spans (Iterable[Span]): The spans to filter.
+    keep_longest (bool): Specify whether to keep longer or shorter spans.
+    RETURNS (List[Span]): The filtered spans.
+    """
+    get_sort_key = lambda span: (span.end - span.start, -span.start)
+    sorted_spans = sorted(spans, key=get_sort_key, reverse=keep_longest)
+    #print(f'sorted_spans: {sorted_spans}')
+    result = []
+    seen_tokens = set()
+    for span in sorted_spans:
+        # Check for end - 1 here because boundaries are inclusive
+        if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
+            result.append(span)
+            seen_tokens.update(range(span.start, span.end))
+    result = sorted(result, key=lambda span: span.start)
+    return result
+
+
+
 def _inference_ner_spancat(text, summary, penalty=0.5, normalise=True, limit_outputs=10):
     nlp = spacy.load("en_pipeline")
-    doc = nlp(text)
-    spans = doc.spans["sc"]
-    comp_raw_text = dict( sorted( dict(zip([str(x) for x in spans],[float(x)*penalty for x in spans.attrs['scores']])).items(), key=lambda x: x[1], reverse=True) )
-    doc = nlp(summary)
-    spans = doc.spans["sc"]
-    exceeds_one = 0.0
-    for comp_s in spans:
-        if str(comp_s) in comp_raw_text.keys():
-            comp_raw_text[str(comp_s)] = comp_raw_text[str(comp_s)] / penalty
-            temp_max = comp_raw_text[str(comp_s)] if comp_raw_text[str(comp_s)] > 1.0 else 0.0
-            exceeds_one = comp_raw_text[str(comp_s)] if temp_max > exceeds_one else exceeds_one
-    #This "exceeds_one" is a bit confusing. So the thing is that the penalty is reverted for each time the company appears in the summary and hence the value can exceed one when the company appears more than once. The normalisation means that all the other scores are divided by the maximum when any value exceeds one
-    if normalise and (exceeds_one > 1):
-        comp_raw_text = {k: v/exceeds_one for k, v in comp_raw_text.items()}
+    out = []
+    for doc in nlp.pipe(text):
+        spans = doc.spans["sc"]
+        #comp_raw_text = dict( sorted( dict(zip([str(x) for x in spans],[float(x)*penalty for x in spans.attrs['scores']])).items(), key=lambda x: x[1], reverse=True) )
+
+        company_list = list(set([str(span).replace('\'s', '') for span in filter_spans(spans, keep_longest=True)]))[:limit_outputs]
+        out.append(get_company_sectors(company_list))

-    return dict(itertools.islice(sorted(comp_raw_text.items(), key=lambda x: x[1], reverse=True), limit_outputs))
+    return out

 #def _inference_summary_model_pipeline(text):
 #    pipe = pipeline("text2text-generation", model=MODEL_SUMMARY_PEGASUS)
@@ -162,8 +199,10 @@ def inference(input_batch,isurl,use_archive,limit_companies=10):
     print("[i] Running sentiment using",MODEL_SENTIMENT_ANALYSIS ,"inference...")
     #sentiment = _inference_sentiment_model_via_api_query({"inputs": extracted['content']})
     sentiment = _inference_sentiment_model_pipeline(input_batch_content )
+    print("[i] Running NER using custom spancat inference...")
     #summary = _inference_summary_model_pipeline(input_batch_content )[0]['generated_text']
-    #ner_labels = _inference_ner_spancat(input_batch_content ,summary, penalty = 0.8, limit_outputs=limit_companies)
+    ner_labels = _inference_ner_spancat(input_batch_content ,limit_outputs=limit_companies)
+    print(ner_labels)
     df = pd.DataFrame(prob_outs,columns =['E','S','G'])
     if isurl:
         df['URL'] = url_list
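
Note (not part of the commit): the new get_company_sectors() helper fuzzy-matches extracted company mentions against the Name column of the S&P 500 constituents table and returns (Name, Sector) tuples, while filter_spans() appears to mirror spacy.util.filter_spans. The following is a minimal sketch of the matching idea only; the DF_DEMO table and the match_sectors name are hypothetical stand-ins for SP500_constituents.zip and the committed helper.

import pandas as pd
from thefuzz import fuzz, process  # documented import style; the commit itself uses `import thefuzz`

# Hypothetical stand-in for DF_SP500 (the real app loads SP500_constituents.zip).
DF_DEMO = pd.DataFrame({
    "Name":   ["Apple Inc.", "Microsoft Corporation", "Exxon Mobil Corporation"],
    "Sector": ["Information Technology", "Information Technology", "Energy"],
})

def match_sectors(extracted_names, threshold=0.95):
    """Map free-text company mentions to (Name, Sector) rows via fuzzy matching."""
    output = []
    for extracted_name in extracted_names:
        # extractOne over a Series behaves like a dict lookup: returns (value, score, index)
        name, score = process.extractOne(extracted_name, DF_DEMO.Name,
                                         scorer=fuzz.token_set_ratio)[:2]
        if score / 100 >= threshold:
            rows = DF_DEMO.loc[DF_DEMO.Name == name, ["Name", "Sector"]]
            output += list(rows.itertuples(index=False, name=None))
    return output

print(match_sectors(["Apple", "Microsoft"]))
# expected: [('Apple Inc.', 'Information Technology'),
#            ('Microsoft Corporation', 'Information Technology')]

With fuzz.token_set_ratio, a partial mention such as "Apple" scores 100 against "Apple Inc.", so it clears the 0.95 threshold, while unrelated names fall below it and are dropped.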