File size: 1,467 Bytes
348db8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import mojimoji
import numpy as np
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

import iob_util  # pip install git+https://github.com/gabrielandrade2/IOB-util.git

model_name = "gabrielandrade2/point-to-span-estimation"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

# Point-annotated text. "⧫" is the point marker; it is stripped from the
# final XML output below.
text = (
    "肥大型心⧫筋症、心房⧫細動に対してWF投与が開始となった。"
    "治療経過中に非持続性心⧫室頻拍が認められたためアミオダロンが併用となった。"
)

# Convert half-width (hankaku) characters to full-width (zenkaku) before
# encoding; presumably the model expects full-width input — confirm.
text = mojimoji.han_to_zen(text)

# Encode text as a batch containing a single sequence.
input_ids = tokenizer.encode(text, return_tensors="pt")

# Predict spans. This is inference only, so disable autograd instead of
# building a graph and detaching it afterwards.
with torch.no_grad():
    output = model(input_ids)
logits = output.logits.cpu().numpy()
# Take the only sequence in the batch and pick the highest-scoring label id
# at every token position.
tag_ids = np.argmax(logits[0], axis=-1).tolist()

# Convert model output to IOB label strings.
id2label = model.config.id2label
tags = [id2label[t] for t in tag_ids]

# Convert the single sequence's input_ids back to token strings.
tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

# Remove model special tokens (CLS, SEP, PAD). Tokens and tags are filtered
# together in one pass so the two lists stay position-aligned.
special_tokens = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}
kept = [(tok, tag) for tok, tag in zip(tokens, tags) if tok not in special_tokens]
tokens = [tok for tok, _ in kept]
tags = [tag for _, tag in kept]

# Convert from IOB to XML tag format, then drop the point markers.
xml_text = iob_util.convert_iob_to_xml(tokens, tags)
xml_text = xml_text.replace('⧫', '')
print(xml_text)