Spaces:
Build error
Build error
from typing import * | |
import torch | |
from allennlp.modules.span_extractors import SpanExtractor | |
class ComboSpanExtractor(SpanExtractor):
    """Span extractor that concatenates the outputs of several sub-extractors.

    Each sub-extractor is called with the same inputs, and the resulting span
    representations are concatenated along the last dimension. The output
    dimension is therefore the sum of the sub-extractors' output dimensions.

    Args:
        input_dim: Value reported by ``get_input_dim``. It is stored as given
            and is not checked against the sub-extractors' input dimensions.
        sub_extractors: The extractors to combine, in concatenation order.
    """

    def __init__(self, input_dim: int, sub_extractors: List[SpanExtractor]):
        super().__init__()
        self.sub_extractors = sub_extractors
        # Storing modules in a plain list does not register them with torch.
        # Each one is therefore added explicitly, which makes its parameters
        # visible to .parameters(), .to() and state_dict under the key prefix
        # "SpanExtractor-<k>" (1-based). Keep these names stable, because
        # saved checkpoints depend on them.
        for position, sub in enumerate(sub_extractors, start=1):
            self.add_module(f'SpanExtractor-{position}', sub)
        self.input_dim = input_dim

    def get_input_dim(self) -> int:
        """Return the input dimension given at construction time."""
        return self.input_dim

    def get_output_dim(self) -> int:
        """Return the sum of the sub-extractors' output dimensions."""
        return sum(sub.get_output_dim() for sub in self.sub_extractors)

    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: Optional[torch.BoolTensor] = None,
        span_indices_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:
        """Run every sub-extractor and concatenate their span representations.

        All four arguments, including ``sequence_mask``, are passed unchanged
        to each sub-extractor.

        Returns:
            The sub-extractor outputs concatenated along the last dimension.
            The original code concatenated on dim 2, which presumably means
            outputs of shape (batch, num_spans, dim). TODO: confirm with
            callers. For 3-D outputs, dim=-1 is the same axis.
        """
        outputs = [
            sub(
                sequence_tensor=sequence_tensor,
                span_indices=span_indices,
                # Previously dropped. Sub-extractors that use the sequence
                # mask would otherwise silently receive None.
                sequence_mask=sequence_mask,
                span_indices_mask=span_indices_mask,
            )
            for sub in self.sub_extractors
        ]
        return torch.cat(outputs, dim=-1)