nickmuchi committed on
Commit e512522
1 Parent(s): b1bb7ce

Update functions.py


Added Knowledge Graph tab

Files changed (1)
  1. functions.py +323 -4
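
The diff below wires Babelscape/rebel-large into load_models() and adds helpers that parse REBEL's linearized generations into head/type/tail triplets for the new Knowledge Graph tab. As a minimal illustration of that format, here is a made-up sample string and the result the extract_relations_from_model_output() helper added in the diff would return for it (the entities and relation are invented for the example):

# Illustration only: hypothetical REBEL-style output and the triplets parsed from it
from functions import extract_relations_from_model_output  # added by this commit

sample = "<s><triplet> Tesla <subj> Austin <obj> headquarters location</s>"
triplets = extract_relations_from_model_output(sample)
# triplets == [{'head': 'Tesla', 'type': 'headquarters location', 'tail': 'Austin'}]
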
functions.py CHANGED
@@ -6,7 +6,7 @@ import plotly_express as px
 import nltk
 import plotly.graph_objects as go
 from optimum.onnxruntime import ORTModelForSequenceClassification
-from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
+from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForSeq2SeqLM
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
 import streamlit as st
 import en_core_web_lg
@@ -31,6 +31,8 @@ margin-bottom: 2.5rem">{}</div> """
 def load_models():
     q_model = ORTModelForSequenceClassification.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
     ner_model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
+    kg_model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
+    kg_tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
     q_tokenizer = AutoTokenizer.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
     ner_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
     sent_pipe = pipeline("text-classification",model=q_model, tokenizer=q_tokenizer)
@@ -38,7 +40,7 @@ def load_models():
     ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, grouped_entities=True)
     cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
 
-    return sent_pipe, sum_pipe, ner_pipe, cross_encoder
+    return sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer
 
 @st.experimental_singleton(suppress_st_warning=True)
 def load_asr_model(asr_model_name):
@@ -358,7 +360,324 @@ def make_spans(text,results):
 def fin_ext(text):
     results = remote_clx(sent_tokenizer(text))
     return make_spans(text,results)
-
+
+## Knowledge Graphs code
+
+def extract_relations_from_model_output(text):
+    relations = []
+    relation, subject, relation, object_ = '', '', '', ''
+    text = text.strip()
+    current = 'x'
+    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
+    for token in text_replaced.split():
+        if token == "<triplet>":
+            current = 't'
+            if relation != '':
+                relations.append({
+                    'head': subject.strip(),
+                    'type': relation.strip(),
+                    'tail': object_.strip()
+                })
+                relation = ''
+            subject = ''
+        elif token == "<subj>":
+            current = 's'
+            if relation != '':
+                relations.append({
+                    'head': subject.strip(),
+                    'type': relation.strip(),
+                    'tail': object_.strip()
+                })
+            object_ = ''
+        elif token == "<obj>":
+            current = 'o'
+            relation = ''
+        else:
+            if current == 't':
+                subject += ' ' + token
+            elif current == 's':
+                object_ += ' ' + token
+            elif current == 'o':
+                relation += ' ' + token
+    if subject != '' and relation != '' and object_ != '':
+        relations.append({
+            'head': subject.strip(),
+            'type': relation.strip(),
+            'tail': object_.strip()
+        })
+    return relations
+
+def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
+                    article_publish_date=None, verbose=False):
+    # tokenize whole text
+    inputs = tokenizer([text], return_tensors="pt")
+
+    # compute span boundaries
+    num_tokens = len(inputs["input_ids"][0])
+    if verbose:
+        print(f"Input has {num_tokens} tokens")
+    num_spans = math.ceil(num_tokens / span_length)
+    if verbose:
+        print(f"Input has {num_spans} spans")
+    overlap = math.ceil((num_spans * span_length - num_tokens) /
+                        max(num_spans - 1, 1))
+    spans_boundaries = []
+    start = 0
+    for i in range(num_spans):
+        spans_boundaries.append([start + span_length * i,
+                                 start + span_length * (i + 1)])
+        start -= overlap
+    if verbose:
+        print(f"Span boundaries are {spans_boundaries}")
+
+    # transform input with spans
+    tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
+                  for boundary in spans_boundaries]
+    tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
+                    for boundary in spans_boundaries]
+    inputs = {
+        "input_ids": torch.stack(tensor_ids),
+        "attention_mask": torch.stack(tensor_masks)
+    }
+
+    # generate relations
+    num_return_sequences = 3
+    gen_kwargs = {
+        "max_length": 256,
+        "length_penalty": 0,
+        "num_beams": 3,
+        "num_return_sequences": num_return_sequences
+    }
+    generated_tokens = model.generate(
+        **inputs,
+        **gen_kwargs,
+    )
+
+    # decode relations
+    decoded_preds = tokenizer.batch_decode(generated_tokens,
+                                           skip_special_tokens=False)
+
+    # create kb
+    kb = KB()
+    i = 0
+    for sentence_pred in decoded_preds:
+        current_span_index = i // num_return_sequences
+        relations = extract_relations_from_model_output(sentence_pred)
+        for relation in relations:
+            relation["meta"] = {
+                article_url: {
+                    "spans": [spans_boundaries[current_span_index]]
+                }
+            }
+            kb.add_relation(relation, article_title, article_publish_date)
+        i += 1
+
+    return kb
+
+def get_article(url):
+    article = Article(url)
+    article.download()
+    article.parse()
+    return article
+
+def from_url_to_kb(url, model, tokenizer):
+    article = get_article(url)
+    config = {
+        "article_title": article.title,
+        "article_publish_date": article.publish_date
+    }
+    kb = from_text_to_kb(article.text, model, tokenizer, article.url, **config)
+    return kb
+
+def get_news_links(query, lang="en", region="US", pages=1):
+    googlenews = GoogleNews(lang=lang, region=region)
+    googlenews.search(query)
+    all_urls = []
+    for page in range(pages):
+        googlenews.get_page(page)
+        all_urls += googlenews.get_links()
+    return list(set(all_urls))
+
+def from_urls_to_kb(urls, model, tokenizer, verbose=False):
+    kb = KB()
+    if verbose:
+        print(f"{len(urls)} links to visit")
+    for url in urls:
+        if verbose:
+            print(f"Visiting {url}...")
+        try:
+            kb_url = from_url_to_kb(url, model, tokenizer)
+            kb.merge_with_kb(kb_url)
+        except ArticleException:
+            if verbose:
+                print(f"  Couldn't download article at url {url}")
+    return kb
+
+def save_network_html(kb, filename="network.html"):
+    # create network
+    net = Network(directed=True, width="700px", height="700px")
+
+    # nodes
+    color_entity = "#00FF00"
+    for e in kb.entities:
+        net.add_node(e, shape="circle", color=color_entity)
+
+    # edges
+    for r in kb.relations:
+        net.add_edge(r["head"], r["tail"],
+                     title=r["type"], label=r["type"])
+
+    # save network
+    net.repulsion(
+        node_distance=200,
+        central_gravity=0.2,
+        spring_length=200,
+        spring_strength=0.05,
+        damping=0.09
+    )
+    net.set_edge_smooth('dynamic')
+    net.show(filename)
+
+def save_kb(kb, filename):
+    with open(filename, "wb") as f:
+        pickle.dump(kb, f)
+
+class CustomUnpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        if name == 'KB':
+            return KB
+        return super().find_class(module, name)
+
+def load_kb(filename):
+    res = None
+    with open(filename, "rb") as f:
+        res = CustomUnpickler(f).load()
+    return res
+
+class KB():
+    def __init__(self):
+        self.entities = {} # { entity_title: {...} }
+        self.relations = [] # [ head: entity_title, type: ..., tail: entity_title,
+          # meta: { article_url: { spans: [...] } } ]
+        self.sources = {} # { article_url: {...} }
+
+    def merge_with_kb(self, kb2):
+        for r in kb2.relations:
+            article_url = list(r["meta"].keys())[0]
+            source_data = kb2.sources[article_url]
+            self.add_relation(r, source_data["article_title"],
+                              source_data["article_publish_date"])
+
+    def are_relations_equal(self, r1, r2):
+        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
+
+    def exists_relation(self, r1):
+        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
+
+    def merge_relations(self, r2):
+        r1 = [r for r in self.relations
+              if self.are_relations_equal(r2, r)][0]
+
+        # if different article
+        article_url = list(r2["meta"].keys())[0]
+        if article_url not in r1["meta"]:
+            r1["meta"][article_url] = r2["meta"][article_url]
+
+        # if existing article
+        else:
+            spans_to_add = [span for span in r2["meta"][article_url]["spans"]
+                            if span not in r1["meta"][article_url]["spans"]]
+            r1["meta"][article_url]["spans"] += spans_to_add
+
+    def get_wikipedia_data(self, candidate_entity):
+        try:
+            page = wikipedia.page(candidate_entity, auto_suggest=False)
+            entity_data = {
+                "title": page.title,
+                "url": page.url,
+                "summary": page.summary
+            }
+            return entity_data
+        except:
+            return None
+
+    def add_entity(self, e):
+        self.entities[e["title"]] = {k:v for k,v in e.items() if k != "title"}
+
+    def add_relation(self, r, article_title, article_publish_date):
+        # check on wikipedia
+        candidate_entities = [r["head"], r["tail"]]
+        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]
+
+        # if one entity does not exist, stop
+        if any(ent is None for ent in entities):
+            return
+
+        # manage new entities
+        for e in entities:
+            self.add_entity(e)
+
+        # rename relation entities with their wikipedia titles
+        r["head"] = entities[0]["title"]
+        r["tail"] = entities[1]["title"]
+
+        # add source if not in kb
+        article_url = list(r["meta"].keys())[0]
+        if article_url not in self.sources:
+            self.sources[article_url] = {
+                "article_title": article_title,
+                "article_publish_date": article_publish_date
+            }
+
+        # manage new relation
+        if not self.exists_relation(r):
+            self.relations.append(r)
+        else:
+            self.merge_relations(r)
+
+    def get_textual_representation(self):
+        res = ""
+        res += "### Entities\n"
+        for e in self.entities.items():
+            # shorten summary
+            e_temp = (e[0], {k:(v[:100] + "..." if k == "summary" else v) for k,v in e[1].items()})
+            res += f"- {e_temp}\n"
+        res += "\n"
+        res += "### Relations\n"
+        for r in self.relations:
+            res += f"- {r}\n"
+        res += "\n"
+        res += "### Sources\n"
+        for s in self.sources.items():
+            res += f"- {s}\n"
+        return res
+
+def save_network_html(kb, filename="network.html"):
+    # create network
+    net = Network(directed=True, width="700px", height="700px", bgcolor="#eeeeee")
+
+    # nodes
+    color_entity = "#00FF00"
+    for e in kb.entities:
+        net.add_node(e, shape="circle", color=color_entity)
+
+    # edges
+    for r in kb.relations:
+        net.add_edge(r["head"], r["tail"],
+                     title=r["type"], label=r["type"])
+
+    # save network
+    net.repulsion(
+        node_distance=200,
+        central_gravity=0.2,
+        spring_length=200,
+        spring_strength=0.05,
+        damping=0.09
+    )
+    net.set_edge_smooth('dynamic')
+    net.show(filename)
+
+
 nlp = get_spacy()
-sent_pipe, sum_pipe, ner_pipe, cross_encoder = load_models()
+sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer = load_models()
 sbert = load_sbert('all-MiniLM-L12-v2')
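
For context, a minimal sketch of how the pieces added here might be exercised end to end outside the Streamlit app, assuming the rest of functions.py already imports the dependencies the new hunk relies on (math, torch, wikipedia, newspaper's Article/ArticleException, GoogleNews, pickle and pyvis' Network) and that the article URL is a placeholder:

# Hypothetical end-to-end run of the knowledge-graph helpers from this commit
from functions import load_models, from_url_to_kb, save_network_html

# load_models() now also returns the REBEL seq2seq model and its tokenizer
sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer = load_models()

url = "https://example.com/some-market-news-article"  # placeholder URL
kb = from_url_to_kb(url, kg_model, kg_tokenizer)       # article -> token spans -> triplets -> KB

print(kb.get_textual_representation())                 # entities, relations and sources as text
save_network_html(kb, filename="network.html")         # interactive pyvis graph for the new tab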