# Author: gabrielandrade2
# Commit 348db8b: "Update README, add example code"
import mojimoji
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

import iob_util  # pip install git+https://github.com/gabrielandrade2/IOB-util.git
# Example: expand point annotations (marked with ⧫) into spans using the
# point-to-span estimation model, and print the result as inline XML.
model_name = "gabrielandrade2/point-to-span-estimation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)


def predict_spans(text):
    """Expand the point annotations in *text* into spans rendered as XML tags.

    Uses the module-level ``tokenizer`` and ``model``. The text is converted to
    full-width (zenkaku) characters, tokenized and tagged by the
    token-classification model. The predicted IOB tags are then rendered with
    ``iob_util.convert_iob_to_xml`` and the ``⧫`` point markers are removed.

    Args:
        text: Input text containing ``⧫`` point markers.

    Returns:
        The full-width text with the predicted spans as XML tags and the point
        markers removed.
    """
    # Convert to zenkaku (full-width) before tokenizing.
    text = mojimoji.han_to_zen(text)
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids)[0]  # first model output holds the logits
    # Batch of one sequence: take the best-scoring label id for each token.
    tag_ids = np.argmax(logits.cpu().numpy(), axis=2)[0].tolist()
    id2label = model.config.id2label
    tags = [id2label[t] for t in tag_ids]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # Drop the special tokens (CLS, SEP, PAD) so that only the text's own
    # tokens, and their tags, are passed on to the IOB -> XML conversion.
    # NOTE(review): assumes each remaining token is a plain piece of the text
    # (e.g. a character-level tokenizer); subword prefixes such as '##' would
    # otherwise appear in the output — confirm against the model's tokenizer.
    special = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}
    kept_tokens, kept_tags = [], []
    for token, tag in zip(tokens, tags):
        if token not in special:
            kept_tokens.append(token)
            kept_tags.append(tag)

    xml_text = iob_util.convert_iob_to_xml(kept_tokens, kept_tags)
    # The point markers are only model input; strip them from the output.
    return xml_text.replace('⧫', '')


# Point-annotated text
text = (
    "肥大型心⧫筋症、心房⧫細動に対してWF投与が開始となった。"
    "治療経過中に非持続性心⧫室頻拍が認められたためアミオダロンが併用となった。"
)
xml_text = predict_spans(text)
print(xml_text)