"""Inference utilities for the ADE token extractor.

``use()`` classifies every token of the input text with the fine-tuned
DistilBERT checkpoint and returns the tokens predicted as ADE mentions.
"""
from typing import Optional

import torch
import pandas as pd
from transformers import AutoConfig, AutoModel
from flair.data import Sentence
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import (
    CheckpointCallback,
    InferCallback,
)

from utils.data import read_data


class Extractor(torch.nn.Module):
    """Transformer encoder (loaded via ``AutoModel``) with a linear
    classification head over the mean-pooled token embeddings."""

    def __init__(
        self,
        pretrained_model_name: str,
        num_classes: Optional[int] = None,
        dropout: float = 0.3,
    ):
        super().__init__()

        config = AutoConfig.from_pretrained(
            pretrained_model_name, num_labels=num_classes
        )

        # Pretrained encoder body plus a freshly initialized classifier head.
        self.model = AutoModel.from_pretrained(pretrained_model_name, config=config)
        self.classifier = torch.nn.Linear(config.hidden_size, num_classes)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, features, attention_mask=None, head_mask=None):
        assert attention_mask is not None, 'attention mask is none'
        bert_output = self.model(
            input_ids=features, attention_mask=attention_mask, head_mask=head_mask
        )
        # Last hidden states: (batch_size, seq_len, hidden_size).
        seq_output = bert_output[0]
        # Mean-pool over the sequence dimension -> (batch_size, hidden_size).
        pooled_output = seq_output.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        # Class scores (logits): (batch_size, num_classes).
        scores = self.classifier(pooled_output)

        return scores


def use(input_text: str) -> str:
    """Run the extractor over ``input_text`` and return a formatted string
    listing the tokens tagged as ADE, one per line."""
    df = write_IOB2_format(input_text)
    loader = read_data(text=df['token'].values.tolist())

    model = Extractor(
        pretrained_model_name='distilbert-base-uncased',
        num_classes=3,
    )

    runner = SupervisedRunner(input_key=('features', 'attention_mask'))

    torch.cuda.empty_cache()
    runner.infer(
        model=model,
        loaders=loader,
        callbacks=[
            # Restore the fine-tuned weights before running inference.
            CheckpointCallback(
                resume='logdir/extractor/best.pth'
            ),
            InferCallback(),
        ],
        verbose=True,
    )

    # Logits gathered by InferCallback over the whole loader.
    predicted_scores = runner.callbacks[0].predictions['logits']
    # The first two classes mark ADE tokens; everything else is tagged 'O'.
    prediction = ['ADE' if i in (0, 1) else 'O' for i in predicted_scores.argmax(axis=1)]
    df['tag'] = prediction
    response = df.loc[df['tag'] == 'ADE', 'token']
    # One output line per ADE token: "<tab> - <row index><tab><token>".
    lines = [f'\t - {index}\t{token}\n' for index, token in response.items()]
    return ''.join(lines)


def write_IOB2_format(input_text: str) -> pd.DataFrame:
    """Tokenize ``input_text`` and return an IOB2-style DataFrame with
    ``sentence``, ``token`` and (initially empty) ``tag`` columns."""
    sent = Sentence(input_text, use_tokenizer=True)
    # Build the DataFrame directly instead of round-tripping through a CSV
    # string, so tokens containing commas cannot corrupt the columns.
    return pd.DataFrame(
        {
            'sentence': 0,
            'token': [token.text for token in sent],
            'tag': None,
        }
    )
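

# Minimal usage sketch (illustrative addition): it assumes the fine-tuned
# checkpoint at 'logdir/extractor/best.pth' and the utils.data.read_data
# loader are available; the sample sentence below is made up.
if __name__ == '__main__':
    sample_text = 'The patient developed a severe rash after starting the new medication.'
    print(use(sample_text))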