import torch
from typing import Optional, Tuple

from transformers import AutoConfig, AutoModel
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import (
    CheckpointCallback,
    InferCallback,
)

from utils.data import read_data


class Classifier(torch.nn.Module):
    """Transformer encoder with a mean-pooled linear classification head.

    Wraps a HuggingFace ``AutoModel`` backbone; the sequence output is
    mean-pooled over the token dimension, passed through dropout, and
    projected to ``num_classes`` logits.
    """

    def __init__(
        self,
        pretrained_model_name: str,
        num_classes: Optional[int] = None,
        dropout: float = 0.3,
    ):
        """
        Args:
            pretrained_model_name: HuggingFace model identifier
                (e.g. ``'distilbert-base-uncased'``).
            num_classes: number of output classes; must be provided in
                practice (a ``None`` would fail when building the head).
            dropout: dropout probability applied before the classifier.
        """
        super().__init__()
        config = AutoConfig.from_pretrained(
            pretrained_model_name, num_labels=num_classes
        )
        self.model = AutoModel.from_pretrained(pretrained_model_name, config=config)
        self.classifier = torch.nn.Linear(config.hidden_size, num_classes)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, features, attention_mask=None, head_mask=None):
        """Compute class logits for a batch of token-id sequences.

        Args:
            features: input token ids, shape ``(batch, seq_len)``.
            attention_mask: padding mask, same shape as ``features``;
                required (validated below).
            head_mask: optional attention-head mask forwarded to the backbone.

        Returns:
            Logits tensor of shape ``(batch, num_classes)``.

        Raises:
            ValueError: if ``attention_mask`` is not supplied.
        """
        # Explicit check instead of `assert`: asserts are stripped under -O.
        if attention_mask is None:
            raise ValueError("attention mask is none")
        bert_output = self.model(
            input_ids=features, attention_mask=attention_mask, head_mask=head_mask
        )
        # [0] is the last hidden state: (batch, seq_len, hidden_size).
        seq_output = bert_output[0]
        # Mean-pool over the token dimension (no masking of pad tokens here —
        # NOTE(review): padded positions are included in the mean; confirm
        # whether mask-weighted pooling was intended).
        pooled_output = seq_output.mean(axis=1)
        pooled_output = self.dropout(pooled_output)
        scores = self.classifier(pooled_output)
        return scores


def use(input_text: str) -> Tuple[bool, str]:
    """Classify a single text as ADE (adverse drug event) or not.

    Loads the best checkpoint from ``logdir/classifier`` and runs inference
    on ``input_text`` through a Catalyst runner.

    Args:
        input_text: raw text to classify.

    Returns:
        ``(is_ade, label)`` where ``is_ade`` is ``True`` for class index 0
        ('ADE') and the label is the corresponding class name.
    """
    text = [input_text]
    PRESENT_LABELS = ['ADE', 'NoAde']

    loader = read_data(text=text)

    model = Classifier(
        pretrained_model_name='distilbert-base-uncased',
        num_classes=2,
    )

    runner = SupervisedRunner(input_key=('features', 'attention_mask'))

    torch.cuda.empty_cache()
    runner.infer(
        model=model,
        loaders=loader,
        callbacks=[
            CheckpointCallback(
                resume='logdir/classifier/best.pth'
            ),
            InferCallback(),
        ],
        verbose=True,
    )

    # NOTE(review): index 0 into runner.callbacks follows the classic
    # Catalyst tutorial pattern for reading InferCallback predictions;
    # verify it still matches the installed Catalyst version's callback order.
    predicted_scores = runner.callbacks[0].predictions['logits']

    # Extract the scalar class index explicitly rather than comparing a
    # 1-element array against 0 inside `if` (fragile, numpy-deprecated).
    pred_idx = int(predicted_scores.argmax(axis=1)[0])
    return pred_idx == 0, PRESENT_LABELS[pred_idx]