ArneBinder commited on
Commit
9d06087
1 Parent(s): 5a9013c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -6,13 +6,13 @@ from prettytable import PrettyTable
6
  from pytorch_ie.annotations import LabeledSpan, BinaryRelation
7
  from pytorch_ie.auto import AutoPipeline
8
  from pytorch_ie.core import AnnotationList, annotation_field
9
- from pytorch_ie.documents import TextDocument
10
 
11
  from typing import List
12
 
13
 
14
  @dataclass
15
- class ExampleDocument(TextDocument):
16
  entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
17
  relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
18
 
@@ -21,20 +21,22 @@ ner_model_name_or_path = "pie/example-ner-spanclf-conll03"
21
  re_model_name_or_path = "pie/example-re-textclf-tacred"
22
 
23
  ner_pipeline = AutoPipeline.from_pretrained(ner_model_name_or_path, device=-1, num_workers=0)
24
- re_pipeline = AutoPipeline.from_pretrained(re_model_name_or_path, device=-1, num_workers=0)
25
 
26
 
27
  def predict(text):
28
  document = ExampleDocument(text)
29
 
 
30
  ner_pipeline(document)
31
 
32
- print(f"list detected entities:")
33
- while len(document.entities.predictions) > 0:
34
- entity = document.entities.predictions.pop(0)
35
- print(f"entity detected: {entity}")
36
- document.entities.append(entity)
37
 
 
38
  re_pipeline(document)
39
 
40
  t = PrettyTable()
 
6
  from pytorch_ie.annotations import LabeledSpan, BinaryRelation
7
  from pytorch_ie.auto import AutoPipeline
8
  from pytorch_ie.core import AnnotationList, annotation_field
9
+ from pytorch_ie.documents import TextBasedDocument
10
 
11
  from typing import List
12
 
13
 
14
  @dataclass
15
+ class ExampleDocument(TextBasedDocument):
16
  entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
17
  relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
18
 
 
21
  re_model_name_or_path = "pie/example-re-textclf-tacred"
22
 
23
  ner_pipeline = AutoPipeline.from_pretrained(ner_model_name_or_path, device=-1, num_workers=0)
24
+ re_pipeline = AutoPipeline.from_pretrained(re_model_name_or_path, device=-1, num_workers=0, taskmodule_kwargs=dict(create_relation_candidates=True))
25
 
26
 
27
  def predict(text):
28
  document = ExampleDocument(text)
29
 
30
+ # execute NER pipeline
31
  ner_pipeline(document)
32
 
33
+ # show predicted entities and promote them from predictions to ground-truth annotations
34
+ print(f"detected entities:\n")
35
+ for entity in document.entities.predictions:
36
+ print(f"{entity}")
37
+ document.entities.append(entity.copy())
38
 
39
+ # execute RE pipeline
40
  re_pipeline(document)
41
 
42
  t = PrettyTable()