christophalt commited on
Commit
864a486
1 Parent(s): dda894a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import gradio as gr
3
+ from dataclasses import dataclass
4
+ from prettytable import PrettyTable
5
+
6
+ from pytorch_ie import AnnotationList, BinaryRelation, Span, LabeledSpan, Pipeline, TextDocument, annotation_field
7
+ from pytorch_ie.models import TransformerSpanClassificationModel, TransformerTextClassificationModel
8
+ from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule, TransformerRETextClassificationTaskModule
9
+
10
+ from typing import List
11
+
12
+
13
+ @dataclass
14
+ class ExampleDocument(TextDocument):
15
+ entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
16
+ relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
17
+
18
+
19
+ model_name_or_path = "pie/example-ner-spanclf-conll03"
20
+ ner_taskmodule = TransformerSpanClassificationTaskModule.from_pretrained(model_name_or_path)
21
+ ner_model = TransformerSpanClassificationModel.from_pretrained(model_name_or_path)
22
+
23
+ ner_pipeline = Pipeline(model=ner_model, taskmodule=ner_taskmodule, device=-1, num_workers=0)
24
+
25
+ model_name_or_path = "pie/example-re-textclf-tacred"
26
+ re_taskmodule = TransformerRETextClassificationTaskModule.from_pretrained(model_name_or_path)
27
+ re_model = TransformerTextClassificationModel.from_pretrained(model_name_or_path)
28
+
29
+ re_pipeline = Pipeline(model=re_model, taskmodule=re_taskmodule, device=-1, num_workers=0)
30
+
31
+
32
+ def predict(text):
33
+ document = ExampleDocument(text)
34
+
35
+ ner_pipeline(document, predict_field="entities")
36
+
37
+ for entity in document.entities.predictions:
38
+ document.entities.append(entity)
39
+
40
+ re_pipeline(document, predict_field="relations")
41
+
42
+ t = PrettyTable()
43
+ t.field_names = ["head", "tail", "relation"]
44
+ t.align = "l"
45
+ for relation in document.relations.predictions:
46
+ t.add_row([str(relation.head), str(relation.tail), relation.label])
47
+
48
+ html = t.get_html_string(format=True)
49
+ html = (
50
+ "<div style='max-width:100%; max-height:360px; overflow:auto'>"
51
+ + html
52
+ + "</div>"
53
+ )
54
+
55
+ return html
56
+
57
+
58
+ iface = gr.Interface(
59
+ fn=predict,
60
+ inputs="textbox",
61
+ outputs="html",
62
+ )
63
+ iface.launch()