gossminn's picture
First version
6680682
raw
history blame
1.16 kB
from typing import *
import torch
from allennlp.modules.span_extractors import SpanExtractor
@SpanExtractor.register('combo')
class ComboSpanExtractor(SpanExtractor):
    """Combine several span extractors by concatenating their outputs.

    Every sub-extractor is called with the same inputs. Their span
    representations are concatenated along the last (feature) dimension, in
    the order given. The output size is therefore the sum of the
    sub-extractors' output sizes.

    Registered with AllenNLP under the name ``"combo"``.

    Parameters
    ----------
    input_dim:
        Value reported by :meth:`get_input_dim`. It is stored as given and is
        not checked against the sub-extractors' own input dims.
    sub_extractors:
        The extractors to combine. Must be non-empty: with no sub-extractors,
        :meth:`forward` would call ``torch.cat`` on an empty list, which raises.
    """

    def __init__(self, input_dim: int, sub_extractors: List[SpanExtractor]):
        super().__init__()
        self.sub_extractors = sub_extractors
        # ``sub_extractors`` is a plain Python list, which torch does not track.
        # Register each entry explicitly so its parameters belong to this
        # module. The chosen names ("SpanExtractor-1", ...) end up in
        # ``state_dict`` keys; keep them stable so saved weights still load.
        for i, sub in enumerate(sub_extractors):
            self.add_module(f'SpanExtractor-{i+1}', sub)
        self.input_dim = input_dim

    def get_input_dim(self) -> int:
        """Return the ``input_dim`` passed to the constructor."""
        return self.input_dim

    def get_output_dim(self) -> int:
        """Return the sum of the sub-extractors' output dimensions."""
        return sum(sub.get_output_dim() for sub in self.sub_extractors)

    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: Optional[torch.BoolTensor] = None,
        span_indices_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:
        """Run every sub-extractor and concatenate their span embeddings.

        All four arguments are forwarded unchanged, as keywords, to each
        sub-extractor.

        Returns
        -------
        The sub-extractor outputs concatenated along the last dimension. The
        output is presumably shaped ``(batch_size, num_spans,
        get_output_dim())``, per the AllenNLP ``SpanExtractor`` contract. This
        shape is not enforced here.
        """
        outputs = [
            sub(
                sequence_tensor=sequence_tensor,
                span_indices=span_indices,
                # Pass the mask through as well: sub-extractors may rely on
                # it, and dropping it makes them silently run unmasked.
                sequence_mask=sequence_mask,
                span_indices_mask=span_indices_mask,
            )
            for sub in self.sub_extractors
        ]
        # Concatenate on the feature axis. For 3-D span outputs this equals the
        # former ``dim=2`` without hard-coding the tensor rank.
        return torch.cat(outputs, dim=-1)