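"""Inference utilities for an adverse-drug-event (ADE) token extractor.

Tokenises free text, scores each token with a DistilBERT-based classifier via
Catalyst inference, and returns the tokens predicted to be ADE mentions.
"""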
import io

import pandas as pd
import torch
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import CheckpointCallback, InferCallback
from flair.data import Sentence
from transformers import AutoConfig, AutoModel

from utils.data import read_data  # project-local helper that builds the inference loader


class Extractor(torch.nn.Module):
    """Pretrained transformer encoder with a mean-pooled linear classification head."""

    def __init__(
        self, pretrained_model_name: str, num_classes: int = None, dropout: float = 0.3
    ):
        super().__init__()
        config = AutoConfig.from_pretrained(
            pretrained_model_name, num_labels=num_classes
        )
        self.model = AutoModel.from_pretrained(pretrained_model_name, config=config)
        self.classifier = torch.nn.Linear(config.hidden_size, num_classes)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, features, attention_mask=None, head_mask=None):
        assert attention_mask is not None, 'attention mask is none'
        bert_output = self.model(
            input_ids=features, attention_mask=attention_mask, head_mask=head_mask
        )
        # Mean-pool the final hidden states over the sequence dimension,
        # apply dropout, and classify the pooled representation.
        seq_output = bert_output[0]
        pooled_output = seq_output.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        scores = self.classifier(pooled_output)
        return scores
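
# Note: because forward() mean-pools the hidden states, an input batch of shape
# [batch, seq_len] yields one score vector of shape [batch, num_classes] per
# example rather than per-token scores; `use` below feeds each token as its own
# example, so predictions line up one-per-token with the dataframe rows.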


def use(input_text: str) -> str:
    # Tokenise the text into an IOB2-style dataframe and build a loader over its tokens.
    df = write_IOB2_format(input_text)
    loader = read_data(text=df['token'].values.tolist())

    model = Extractor(
        pretrained_model_name='distilbert-base-uncased',
        num_classes=3,
    )
    runner = SupervisedRunner(input_key=('features', 'attention_mask'))

    # Run inference with the trained checkpoint and collect the accumulated logits.
    torch.cuda.empty_cache()
    runner.infer(
        model=model,
        loaders=loader,
        callbacks=[
            CheckpointCallback(resume='logdir/extractor/best.pth'),
            InferCallback(),
        ],
        verbose=True,
    )
    predicted_scores = runner.callbacks[0].predictions['logits']

    # Classes 0 and 1 are treated as ADE mentions; everything else is tagged O.
    prediction = ['ADE' if i in (0, 1) else 'O' for i in predicted_scores.argmax(axis=1)]
    df['tag'] = prediction

    # Return the ADE tokens as a tab-indented, newline-separated list.
    response = df.loc[df['tag'] == 'ADE', 'token']
    tab = '\t'
    nl = '\n'
    response_string = ''
    for n, w in response.items():
        response_string += f'{tab} - {n}{tab}{w}{nl}'
    return response_string


def write_IOB2_format(input_text: str) -> pd.DataFrame:
    # Tokenise with flair and lay the tokens out as CSV rows
    # (sentence id, token, empty tag), then parse them into a dataframe.
    headers = 'sentence,token,tag'
    sent = Sentence(input_text, use_tokenizer=True)
    data_string = ''
    nl = '\n'
    for token in sent:
        data_string = data_string + f'{nl}0,{token.text},'
    data_string = f'{headers}{nl}{data_string}'
    df = pd.read_csv(io.StringIO(data_string), sep=',')
    return df
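

# Minimal usage sketch (assumes the trained checkpoint at
# 'logdir/extractor/best.pth' and the project's utils.data.read_data helper are
# available; the sample sentence is illustrative only).
if __name__ == '__main__':
    sample_text = 'The patient developed a severe rash after taking amoxicillin.'
    print(use(sample_text))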