Edit model card

f1: 88.43

precision recall f1-score support
DAT 0.96 0.97 0.96 182
DUR 0.79 0.82 0.80 50
LOC 0.70 0.79 0.74 206
MNY 0.87 1.00 0.93 20
NOH 0.91 0.93 0.92 1007
ORG 0.86 0.89 0.88 795
PER 0.92 0.95 0.94 853
PNT 0.78 0.78 0.78 60
POH 0.64 0.71 0.68 214
TIM 0.76 1.00 0.86 19
------- ----------- -------- ---------- ---------
micro avg 0.87 0.90 0.88 3406
macro avg 0.82 0.89 0.85 3406
weighted avg 0.87 0.90 0.89 3406
from transformers import TFBertModel, BertTokenizer
import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from konlpy.tag import Mecab

mecab = Mecab()

checkpoint_path = "./cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
latest = tf.train.latest_checkpoint(checkpoint_dir)

index_to_tag = {0: 'B-PER', 1: 'B-LOC', 2: 'I-ORG', 3: 'B-DAT', 4: 'O', 5: 'I-DUR', 6: 'I-TIM', 7: 'I-NOH', 8: 'B-MNY', 9: 'B-PNT', 10: 'I-PER', 11: 'I-PNT', 12: 'I-LOC', 13: 'I-DAT', 14: 'B-TIM', 15: 'B-POH', 16: 'B-NOH', 17: 'I-POH', 18: 'I-MNY', 19: 'B-ORG', 20: 'B-DUR'}

tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
model = TFBertForTokenClassification("klue/bert-base", num_labels=21)
model.load_weights(latest)

class TFBertForTokenClassification(tf.keras.Model):
    def __init__(self, model_name, num_labels):
        super(TFBertForTokenClassification, self).__init__()
        self.bert = TFBertModel.from_pretrained(model_name, from_pt=True)
        self.classifier = tf.keras.layers.Dense(num_labels,
                                                kernel_initializer=tf.keras.initializers.TruncatedNormal(0.02),
                                                name='classifier')

    def call(self, inputs):
        input_ids, attention_mask, token_type_ids = inputs
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        all_output = outputs[0]
        prediction = self.classifier(all_output)

        return prediction


def convert_examples_to_features_for_prediction(examples, max_seq_len, tokenizer,
                                 pad_token_id_for_segment=0, pad_token_id_for_label=-100):
    cls_token = tokenizer.cls_token
    sep_token = tokenizer.sep_token
    pad_token_id = tokenizer.pad_token_id

    input_ids, attention_masks, token_type_ids, label_masks = [], [], [], []

    for example in tqdm(examples):
        tokens = []
        label_mask = []
        for one_word in example:
            subword_tokens = tokenizer.tokenize(one_word)
            tokens.extend(subword_tokens)
            label_mask.extend([0]+ [pad_token_id_for_label] * (len(subword_tokens) - 1))

        special_tokens_count = 2
        if len(tokens) > max_seq_len - special_tokens_count:
            tokens = tokens[:(max_seq_len - special_tokens_count)]
            label_mask = label_mask[:(max_seq_len - special_tokens_count)]

        tokens += [sep_token]
        label_mask += [pad_token_id_for_label]

        tokens = [cls_token] + tokens
        label_mask = [pad_token_id_for_label] + label_mask


        input_id = tokenizer.convert_tokens_to_ids(tokens)
        attention_mask = [1] * len(input_id)
        padding_count = max_seq_len - len(input_id)
        input_id = input_id + ([pad_token_id] * padding_count)
        attention_mask = attention_mask + ([0] * padding_count)
        token_type_id = [pad_token_id_for_segment] * max_seq_len
        label_mask = label_mask + ([pad_token_id_for_label] * padding_count)

        assert len(input_id) == max_seq_len, "Error with input length {} vs {}".format(len(input_id), max_seq_len)
        assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(len(attention_mask), max_seq_len)
        assert len(token_type_id) == max_seq_len, "Error with token type length {} vs {}".format(len(token_type_id), max_seq_len)
        assert len(label_mask) == max_seq_len, "Error with labels length {} vs {}".format(len(label_mask), max_seq_len)

        input_ids.append(input_id)
        attention_masks.append(attention_mask)
        token_type_ids.append(token_type_id)
        label_masks.append(label_mask)

    input_ids = np.array(input_ids, dtype=int)
    attention_masks = np.array(attention_masks, dtype=int)
    token_type_ids = np.array(token_type_ids, dtype=int)
    label_masks = np.asarray(label_masks, dtype=np.int32)

    return (input_ids, attention_masks, token_type_ids), label_masks


def ner_prediction(examples, max_seq_len, tokenizer, lang='ko'):

    if lang == 'ko':
        examples = [mecab.morphs(sent) for sent in examples]
    else:
        examples = [sent.split() for sent in examples]

    X_pred, label_masks = convert_examples_to_features_for_prediction(
        examples, max_seq_len=128, tokenizer=tokenizer)
    y_predicted = model.predict(X_pred)
    y_predicted = np.argmax(y_predicted, axis=2)

    pred_list = []
    result_list = []

    for i in range(0, len(label_masks)):
        pred_tag = []
        for label_index, pred_index in zip(label_masks[i], y_predicted[i]):
            if label_index != -100:
                pred_tag.append(index_to_tag[pred_index])

        pred_list.append(pred_tag)

    for example, pred in zip(examples, pred_list):
        one_sample_result = []
        for one_word, label_token in zip(example, pred):
            one_sample_result.append((one_word, label_token))
        result_list.append(one_sample_result)

    return result_list


sent1 = 'μšΈμ‚°μ—μ„œ ν™œλ™ν•˜κ³  μžˆλŠ” μ‹œκ°μ˜ˆμˆ  λΆ„μ•Ό κΉ€μœ κ²½ μž‘κ°€λŠ” 졜근 지역 AI κΈ°μ—… μ½”μ–΄λ‹·νˆ¬λ°μ΄μ™€μ˜ ν˜‘μ—…μ„ ν†΅ν•œ νŠΉλ³„ν•œ μ „μ‹œλ₯Ό μ—΄μ—ˆλ‹€.'
sent2 = 'κ°€μΉ˜κ΄€μ΄λ‚˜ 인식에 따라 세상을 λΆˆμ™„μ „ν•˜κ²Œ λ³΄λŠ” 인간이 ν•™μŠ΅μ„ 톡해 μΈμ§€ν•œ λΆ€λΆ„λ§Œμ„ μΈμ‹ν•˜λŠ” AI와 λΉ„μŠ·ν•˜λ‹€κ³  보고 μ „μ‹œλ₯Ό κΈ°νšν–ˆλ‹€.'
sent3 = 'λΆ€μ‚° κ΄‘μ•ˆλ¦¬ ν•΄λ³€κ³Ό λ‹¬λ§žμ΄ 고개 λ“± μœ λ™ 인ꡬ와 μ°¨λŸ‰ 이동이 λ§Žμ€ 지역 λͺ‡ 곳을 골라 CCTV 데이터 속 정보λ₯Ό μ–΄λ–»κ²Œ μΈμ‹ν•˜λŠ”μ§€, 곡간에 λŒ€ν•œ μ°°λ‚˜λ₯Ό ν‘œν˜„ν•œ μž‘κ°€μ˜ μž‘ν’ˆμ„ μ–΄λ–»κ²Œ μΈμ‹ν•˜λŠ”μ§€ 차이λ₯Ό λΉ„κ΅ν–ˆλ‹€.'
test_samples = [sent1, sent2, sent3]
ner_prediction(test_samples, max_seq_len=128, tokenizer=tokenizer, lang='ko')
[[('μšΈμ‚°', 'B-LOC'),
  ('μ—μ„œ', 'O'),
  ('ν™œλ™', 'O'),
  ('ν•˜', 'O'),
  ('κ³ ', 'O'),
  ('있', 'O'),
  ('λŠ”', 'O'),
  ('μ‹œκ°', 'O'),
  ('예술', 'O'),
  ('λΆ„μ•Ό', 'O'),
  ('κΉ€μœ κ²½', 'B-PER'),
  ('μž‘κ°€', 'O'),
  ('λŠ”', 'O'),
  ('졜근', 'O'),
  ('지역', 'O'),
  ('AI', 'O'),
  ('κΈ°μ—…', 'O'),
  ('μ½”μ–΄', 'B-ORG'),
  ('λ‹·', 'I-ORG'),
  ('투데이', 'I-ORG'),
  ('와', 'O'),
  ('의', 'O'),
  ('ν˜‘μ—…', 'O'),
  ('을', 'O'),
  ('ν†΅ν•œ', 'O'),
  ('νŠΉλ³„', 'O'),
  ('ν•œ', 'O'),
  ('μ „μ‹œ', 'O'),
  ('λ₯Ό', 'O'),
  ('μ—΄', 'O'),
  ('μ—ˆ', 'O'),
  ('λ‹€', 'O'),
  ('.', 'O')],
 [('κ°€μΉ˜κ΄€', 'O'),
  ('μ΄λ‚˜', 'O'),
  ('인식', 'O'),
  ('에', 'O'),
  ('따라', 'O'),
  ('세상', 'O'),
  ('을', 'O'),
  ('뢈', 'O'),
  ('μ™„μ „', 'O'),
  ('ν•˜', 'O'),
  ('게', 'O'),
  ('보', 'O'),
  ('λŠ”', 'O'),
  ('인간', 'O'),
  ('이', 'O'),
  ('ν•™μŠ΅', 'O'),
  ('을', 'O'),
  ('톡해', 'O'),
  ('인지', 'O'),
  ('ν•œ', 'O'),
  ('λΆ€λΆ„', 'O'),
  ('만', 'O'),
  ('을', 'O'),
  ('인식', 'O'),
  ('ν•˜', 'O'),
  ('λŠ”', 'O'),
  ('AI', 'O'),
  ('와', 'O'),
  ('λΉ„μŠ·', 'O'),
  ('ν•˜', 'O'),
  ('λ‹€κ³ ', 'O'),
  ('보', 'O'),
  ('κ³ ', 'O'),
  ('μ „μ‹œ', 'O'),
  ('λ₯Ό', 'O'),
  ('기획', 'O'),
  ('ν–ˆ', 'O'),
  ('λ‹€', 'O'),
  ('.', 'O')],
 [('λΆ€μ‚°', 'B-LOC'),
  ('κ΄‘μ•ˆλ¦¬', 'I-LOC'),
  ('ν•΄λ³€', 'I-LOC'),
  ('κ³Ό', 'O'),
  ('λ‹¬λ§žμ΄', 'B-LOC'),
  ('고개', 'I-LOC'),
  ('λ“±', 'O'),
  ('μœ λ™', 'O'),
  ('인ꡬ', 'O'),
  ('와', 'O'),
  ('μ°¨λŸ‰', 'O'),
  ('이동', 'O'),
  ('이', 'O'),
  ('많', 'O'),
  ('은', 'O'),
  ('지역', 'O'),
  ('λͺ‡', 'O'),
  ('κ³³', 'O'),
  ('을', 'O'),
  ('골라', 'O'),
  ('CCTV', 'O'),
  ('데이터', 'O'),
  ('속', 'O'),
  ('정보', 'O'),
  ('λ₯Ό', 'O'),
  ('μ–΄λ–»κ²Œ', 'O'),
  ('인식', 'O'),
  ('ν•˜', 'O'),
  ('λŠ”μ§€', 'O'),
  (',', 'O'),
  ('곡간', 'O'),
  ('에', 'O'),
  ('λŒ€ν•œ', 'O'),
  ('μ°°λ‚˜', 'O'),
  ('λ₯Ό', 'O'),
  ('ν‘œν˜„', 'O'),
  ('ν•œ', 'O'),
  ('μž‘κ°€', 'O'),
  ('의', 'O'),
  ('μž‘ν’ˆ', 'O'),
  ('을', 'O'),
  ('μ–΄λ–»κ²Œ', 'O'),
  ('인식', 'O'),
  ('ν•˜', 'O'),
  ('λŠ”μ§€', 'O'),
  ('차이', 'O'),
  ('λ₯Ό', 'O'),
  ('비ꡐ', 'O'),
  ('ν–ˆ', 'O'),
  ('λ‹€', 'O'),
  ('.', 'O')]]
tensorflow-estimator==2.5.0
tensorflow-gpu==2.5.3
transformers @ git+https://github.com/davidegazze/transformers@cf28c1db00410f0df3e654d9866e0ff1d3a45f29
numpy==1.24.3
konlpy==0.6.0
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference API
Unable to determine this model's library. Check the docs .