import io
from typing import Optional

import pandas as pd
import torch
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import (
    CheckpointCallback,
    InferCallback,
)
from flair.data import Sentence
from transformers import AutoConfig, AutoModel

from utils.data import read_data


class Extractor(torch.nn.Module):
    """Pretrained transformer encoder with mean pooling and a linear head.

    Args:
        pretrained_model_name: Model name or path passed to
            ``AutoConfig.from_pretrained`` / ``AutoModel.from_pretrained``.
        num_classes: Number of output logits. Despite the ``None`` default it
            must be given, because ``torch.nn.Linear`` needs an integer size.
        dropout: Dropout probability applied to the pooled representation.
    """

    def __init__(
        self,
        pretrained_model_name: str,
        num_classes: Optional[int] = None,
        dropout: float = 0.3,
    ):
        super().__init__()
        config = AutoConfig.from_pretrained(
            pretrained_model_name, num_labels=num_classes
        )
        self.model = AutoModel.from_pretrained(pretrained_model_name, config=config)
        self.classifier = torch.nn.Linear(config.hidden_size, num_classes)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, features, attention_mask=None, head_mask=None):
        """Return classifier logits for a batch of token-id sequences.

        Args:
            features: Token ids fed to the encoder as ``input_ids``
                (presumably (batch, seq_len) -- TODO confirm against the loader).
            attention_mask: Required; forwarded to the encoder.
            head_mask: Optional; forwarded to the encoder.

        Returns:
            Output of the linear classifier over the mean-pooled encoder states.

        Raises:
            AssertionError: if ``attention_mask`` is None.
        """
        assert attention_mask is not None, 'attention mask is none'
        bert_output = self.model(
            input_ids=features, attention_mask=attention_mask, head_mask=head_mask
        )
        # Element 0 of the encoder output is the per-token hidden states.
        seq_output = bert_output[0]
        # NOTE(review): this mean also averages padded positions (the attention
        # mask is not applied here). The checkpoint loaded in `use` was
        # presumably trained with this exact pooling, so it is left unchanged.
        pooled_output = seq_output.mean(axis=1)
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)


def use(input_text) -> str:
    """Run the extractor over ``input_text`` and list the tokens tagged 'ADE'.

    The text is tokenized with ``write_IOB2_format``. The tokens are batched
    with ``read_data`` and scored by an ``Extractor`` restored from
    ``logdir/extractor/best.pth``.

    Args:
        input_text: Raw text to analyse.

    Returns:
        One line per 'ADE' token, formatted as ``"\\t - {position}\\t{token}\\n"``.
        ``position`` is the token's index in the tokenized input. The result
        is an empty string when no token is tagged 'ADE'.
    """
    df = write_IOB2_format(input_text)
    loader = read_data(text=df['token'].values.tolist())
    model = Extractor(
        pretrained_model_name='distilbert-base-uncased',
        num_classes=3,
    )
    runner = SupervisedRunner(input_key=('features', 'attention_mask'))
    torch.cuda.empty_cache()
    runner.infer(
        model=model,
        loaders=loader,
        callbacks=[
            CheckpointCallback(resume='logdir/extractor/best.pth'),
            InferCallback(),
        ],
        verbose=True,
    )
    # NOTE(review): confirm that runner.callbacks[0] resolves to the
    # InferCallback (which holds the predictions) in the installed catalyst
    # version.
    predicted_scores = runner.callbacks[0].predictions['logits']
    # Class ids 0 and 1 both map to 'ADE', and every other id maps to 'O'.
    # Presumably 0/1 are the B-/I- labels of the IOB2 scheme -- TODO confirm.
    df['tag'] = [
        'ADE' if class_id in (0, 1) else 'O'
        for class_id in predicted_scores.argmax(axis=1)
    ]
    ade_tokens = df.loc[df['tag'] == 'ADE', 'token']
    # A single join, rather than repeated `+=`, avoids quadratic string building.
    return ''.join(
        f'\t - {position}\t{token}\n' for position, token in ade_tokens.items()
    )


def write_IOB2_format(input_text):
    """Tokenize ``input_text`` into a DataFrame with one row per token.

    Columns:
        sentence: always 0 (the whole input is treated as one sentence).
        token: the token text, kept as ``str``.
        tag: NaN placeholder, later overwritten by ``use``.

    The frame is built directly instead of being round-tripped through a CSV
    string. The CSV approach failed on tokens containing commas or quotes
    (extra or garbled fields). Its NA handling turned tokens such as "NA" or
    "null" into NaN, and it parsed numeric tokens as numbers.
    """
    sent = Sentence(input_text, use_tokenizer=True)
    tokens = [token.text for token in sent]
    return pd.DataFrame(
        {
            'sentence': [0] * len(tokens),
            'token': tokens,
            'tag': [float('nan')] * len(tokens),
        },
        columns=['sentence', 'token', 'tag'],
    )