nickmuchi committed
Commit 1487c65
1 Parent(s): 7dbe2d3

Update functions.py

Files changed (1)
  1. functions.py +3 -314
functions.py CHANGED
@@ -9,7 +9,7 @@ import plotly_express as px
  import nltk
  import plotly.graph_objects as go
  from optimum.onnxruntime import ORTModelForSequenceClassification
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForSeq2SeqLM
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
  from sentence_transformers import SentenceTransformer, CrossEncoder, util
  import streamlit as st
  import en_core_web_lg
@@ -73,18 +73,15 @@ def load_models():
      '''Load and cache all the models to be used'''
      q_model = ORTModelForSequenceClassification.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
      ner_model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
-     kg_model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
-     kg_tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
      q_tokenizer = AutoTokenizer.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
      ner_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
-     emb_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xl')
      sent_pipe = pipeline("text-classification",model=q_model, tokenizer=q_tokenizer)
      sum_pipe = pipeline("summarization",model="philschmid/flan-t5-base-samsum",clean_up_tokenization_spaces=True)
      ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, grouped_entities=True)
      cross_encoder = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1') #cross-encoder/ms-marco-MiniLM-L-12-v2
      sbert = SentenceTransformer('all-MiniLM-L6-v2')

-     return sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert
+     return sent_pipe, sum_pipe, ner_pipe, cross_encoder, sbert

  @st.cache_resource
  def get_spacy():
@@ -93,7 +90,7 @@ def get_spacy():

  nlp = get_spacy()

- sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert = load_models()
+ sent_pipe, sum_pipe, ner_pipe, cross_encoder, sbert = load_models()

  @st.cache_data
  def get_yt_audio(url):
@@ -696,317 +693,9 @@ def fin_ext(text):

  ## Knowledge Graphs code

- @st.cache_data
- def extract_relations_from_model_output(text):
-     relations = []
-     relation, subject, relation, object_ = '', '', '', ''
-     text = text.strip()
-     current = 'x'
-     text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
-     for token in text_replaced.split():
-         if token == "<triplet>":
-             current = 't'
-             if relation != '':
-                 relations.append({
-                     'head': subject.strip(),
-                     'type': relation.strip(),
-                     'tail': object_.strip()
-                 })
-                 relation = ''
-             subject = ''
-         elif token == "<subj>":
-             current = 's'
-             if relation != '':
-                 relations.append({
-                     'head': subject.strip(),
-                     'type': relation.strip(),
-                     'tail': object_.strip()
-                 })
-             object_ = ''
-         elif token == "<obj>":
-             current = 'o'
-             relation = ''
-         else:
-             if current == 't':
-                 subject += ' ' + token
-             elif current == 's':
-                 object_ += ' ' + token
-             elif current == 'o':
-                 relation += ' ' + token
-     if subject != '' and relation != '' and object_ != '':
-         relations.append({
-             'head': subject.strip(),
-             'type': relation.strip(),
-             'tail': object_.strip()
-         })
-     return relations
-
- def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
-                     article_publish_date=None, verbose=False):
-     # tokenize whole text
-     inputs = tokenizer([text], return_tensors="pt")
-
-     # compute span boundaries
-     num_tokens = len(inputs["input_ids"][0])
-     if verbose:
-         print(f"Input has {num_tokens} tokens")
-     num_spans = math.ceil(num_tokens / span_length)
-     if verbose:
-         print(f"Input has {num_spans} spans")
-     overlap = math.ceil((num_spans * span_length - num_tokens) /
-                         max(num_spans - 1, 1))
-     spans_boundaries = []
-     start = 0
-     for i in range(num_spans):
-         spans_boundaries.append([start + span_length * i,
-                                  start + span_length * (i + 1)])
-         start -= overlap
-     if verbose:
-         print(f"Span boundaries are {spans_boundaries}")
-
-     # transform input with spans
-     tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
-                   for boundary in spans_boundaries]
-     tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
-                     for boundary in spans_boundaries]
-     inputs = {
-         "input_ids": torch.stack(tensor_ids),
-         "attention_mask": torch.stack(tensor_masks)
-     }
-
-     # generate relations
-     num_return_sequences = 3
-     gen_kwargs = {
-         "max_length": 256,
-         "length_penalty": 0,
-         "num_beams": 3,
-         "num_return_sequences": num_return_sequences
-     }
-     generated_tokens = model.generate(
-         **inputs,
-         **gen_kwargs,
-     )
-
-     # decode relations
-     decoded_preds = tokenizer.batch_decode(generated_tokens,
-                                            skip_special_tokens=False)
-
-     # create kb
-     kb = KB()
-     i = 0
-     for sentence_pred in decoded_preds:
-         current_span_index = i // num_return_sequences
-         relations = extract_relations_from_model_output(sentence_pred)
-         for relation in relations:
-             relation["meta"] = {
-                 article_url: {
-                     "spans": [spans_boundaries[current_span_index]]
-                 }
-             }
-             kb.add_relation(relation, article_title, article_publish_date)
-         i += 1
-
-     return kb
-
  def get_article(url):
      article = Article(url)
      article.download()
      article.parse()
      return article

- def from_url_to_kb(url, model, tokenizer):
-     article = get_article(url)
-     config = {
-         "article_title": article.title,
-         "article_publish_date": article.publish_date
-     }
-     kb = from_text_to_kb(article.text, model, tokenizer, article.url, **config)
-     return kb
-
- def get_news_links(query, lang="en", region="US", pages=1):
-     googlenews = GoogleNews(lang=lang, region=region)
-     googlenews.search(query)
-     all_urls = []
-     for page in range(pages):
-         googlenews.get_page(page)
-         all_urls += googlenews.get_links()
-     return list(set(all_urls))
-
- def from_urls_to_kb(urls, model, tokenizer, verbose=False):
-     kb = KB()
-     if verbose:
-         print(f"{len(urls)} links to visit")
-     for url in urls:
-         if verbose:
-             print(f"Visiting {url}...")
-         try:
-             kb_url = from_url_to_kb(url, model, tokenizer)
-             kb.merge_with_kb(kb_url)
-         except ArticleException:
-             if verbose:
-                 print(f"  Couldn't download article at url {url}")
-     return kb
-
- def save_network_html(kb, filename="network.html"):
-     # create network
-     net = Network(directed=True, width="700px", height="700px")
-
-     # nodes
-     color_entity = "#00FF00"
-     for e in kb.entities:
-         net.add_node(e, shape="circle", color=color_entity)
-
-     # edges
-     for r in kb.relations:
-         net.add_edge(r["head"], r["tail"],
-                      title=r["type"], label=r["type"])
-
-     # save network
-     net.repulsion(
-         node_distance=200,
-         central_gravity=0.2,
-         spring_length=200,
-         spring_strength=0.05,
-         damping=0.09
-     )
-     net.set_edge_smooth('dynamic')
-     net.show(filename)
-
- def save_kb(kb, filename):
-     with open(filename, "wb") as f:
-         pickle.dump(kb, f)
-
- class CustomUnpickler(pickle.Unpickler):
-     def find_class(self, module, name):
-         if name == 'KB':
-             return KB
-         return super().find_class(module, name)
-
- def load_kb(filename):
-     res = None
-     with open(filename, "rb") as f:
-         res = CustomUnpickler(f).load()
-     return res
-
- class KB():
-     def __init__(self):
-         self.entities = {} # { entity_title: {...} }
-         self.relations = [] # [ head: entity_title, type: ..., tail: entity_title,
-                             #   meta: { article_url: { spans: [...] } } ]
-         self.sources = {} # { article_url: {...} }
-
-     def merge_with_kb(self, kb2):
-         for r in kb2.relations:
-             article_url = list(r["meta"].keys())[0]
-             source_data = kb2.sources[article_url]
-             self.add_relation(r, source_data["article_title"],
-                               source_data["article_publish_date"])
-
-     def are_relations_equal(self, r1, r2):
-         return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
-
-     def exists_relation(self, r1):
-         return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
-
-     def merge_relations(self, r2):
-         r1 = [r for r in self.relations
-               if self.are_relations_equal(r2, r)][0]
-
-         # if different article
-         article_url = list(r2["meta"].keys())[0]
-         if article_url not in r1["meta"]:
-             r1["meta"][article_url] = r2["meta"][article_url]
-
-         # if existing article
-         else:
-             spans_to_add = [span for span in r2["meta"][article_url]["spans"]
-                             if span not in r1["meta"][article_url]["spans"]]
-             r1["meta"][article_url]["spans"] += spans_to_add
-
-     def get_wikipedia_data(self, candidate_entity):
-         try:
-             page = wikipedia.page(candidate_entity, auto_suggest=False)
-             entity_data = {
-                 "title": page.title,
-                 "url": page.url,
-                 "summary": page.summary
-             }
-             return entity_data
-         except:
-             return None
-
-     def add_entity(self, e):
-         self.entities[e["title"]] = {k:v for k,v in e.items() if k != "title"}
-
-     def add_relation(self, r, article_title, article_publish_date):
-         # check on wikipedia
-         candidate_entities = [r["head"], r["tail"]]
-         entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]
-
-         # if one entity does not exist, stop
-         if any(ent is None for ent in entities):
-             return
-
-         # manage new entities
-         for e in entities:
-             self.add_entity(e)
-
-         # rename relation entities with their wikipedia titles
-         r["head"] = entities[0]["title"]
-         r["tail"] = entities[1]["title"]
-
-         # add source if not in kb
-         article_url = list(r["meta"].keys())[0]
-         if article_url not in self.sources:
-             self.sources[article_url] = {
-                 "article_title": article_title,
-                 "article_publish_date": article_publish_date
-             }
-
-         # manage new relation
-         if not self.exists_relation(r):
-             self.relations.append(r)
-         else:
-             self.merge_relations(r)
-
972
- def get_textual_representation(self):
973
- res = ""
974
- res += "### Entities\n"
975
- for e in self.entities.items():
976
- # shorten summary
977
- e_temp = (e[0], {k:(v[:100] + "..." if k == "summary" else v) for k,v in e[1].items()})
978
- res += f"- {e_temp}\n"
979
- res += "\n"
980
- res += "### Relations\n"
981
- for r in self.relations:
982
- res += f"- {r}\n"
983
- res += "\n"
984
- res += "### Sources\n"
985
- for s in self.sources.items():
986
- res += f"- {s}\n"
987
- return res
988
-
- def save_network_html(kb, filename="network.html"):
-     # create network
-     net = Network(directed=True, width="700px", height="700px", bgcolor="#eeeeee")
-
-     # nodes
-     color_entity = "#00FF00"
-     for e in kb.entities:
-         net.add_node(e, shape="circle", color=color_entity)
-
-     # edges
-     for r in kb.relations:
-         net.add_edge(r["head"], r["tail"],
-                      title=r["type"], label=r["type"])
-
-     # save network
-     net.repulsion(
-         node_distance=200,
-         central_gravity=0.2,
-         spring_length=200,
-         spring_strength=0.05,
-         damping=0.09
-     )
-     net.set_edge_smooth('dynamic')
-     net.show(filename)
 