# NOTE: non-code page residue captured with this file ("Spaces: Runtime error"); not part of the module.
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.runner import BaseModule | |
from mmocr.models.builder import DECODERS | |
class FCDecoder(BaseModule):
    """FC Decoder class for Ner.

    Applies dropout to the first element of the incoming ``outputs`` and
    projects it onto ``num_labels`` logits with a single linear layer. The
    predicted label of each position is the index of its largest logit.

    Args:
        num_labels (int): Number of categories mapped by entity label.
            Required in practice; the ``None`` default is only a placeholder
            and is rejected with a ``TypeError``.
        hidden_dropout_prob (float): The dropout probability of hidden layer.
        hidden_size (int): Hidden layer output layer channels.
        init_cfg (dict or list[dict]): Initialization config, forwarded
            unchanged to ``BaseModule``.
    """

    # NOTE(review): the visible code never uses the imported ``DECODERS``
    # registry; if this class is meant to be built from a config, confirm
    # that a ``@DECODERS.register_module()`` decorator was not lost.

    def __init__(self,
                 num_labels=None,
                 hidden_dropout_prob=0.1,
                 hidden_size=768,
                 init_cfg=[
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Uniform', layer='BatchNorm2d')
                 ]):
        # Fail early with a readable message instead of letting nn.Linear
        # raise an opaque tensor-size error for the placeholder default.
        if num_labels is None:
            raise TypeError(
                'FCDecoder requires `num_labels` to be an int, got None.')

        # NOTE(review): the default init_cfg targets Conv2d / BatchNorm2d,
        # while this module only builds Dropout and Linear layers; confirm
        # the intended initialization.
        super().__init__(init_cfg=init_cfg)
        self.num_labels = num_labels
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.classifier = nn.Linear(hidden_size, self.num_labels)

    def forward(self, outputs):
        """Compute per-position label logits and the predicted label ids.

        Args:
            outputs (Sequence): Only ``outputs[0]`` is read. It is fed to
                the linear classifier, so its last dimension must equal
                ``hidden_size``; the argmax over dim 2 requires at least
                three dimensions (presumably ``(batch, seq_len,
                hidden_size)`` -- confirm against the encoder).

        Returns:
            tuple: ``(logits, preds)`` where ``logits`` is the classifier
            output tensor (still attached to the autograd graph) and
            ``preds`` is a nested Python list holding, for every position,
            the index of the largest logit along dim 2.
        """
        hidden_states = self.dropout(outputs[0])
        logits = self.classifier(hidden_states)
        # Softmax is order-preserving, so the arg-max of the raw logits is
        # the same label the softmax probabilities would select. Taking it
        # directly avoids the softmax pass and copying the full float tensor
        # to the host; only the integer indices are converted to a list.
        preds = logits.argmax(dim=2).tolist()
        return logits, preds