|
import json |
|
import logging |
|
import random |
|
from typing import * |
|
|
|
import numpy as np |
|
from allennlp.data.dataset_readers.dataset_reader import DatasetReader |
|
from allennlp.data.fields import MetadataField |
|
from allennlp.data.instance import Instance |
|
|
|
from .span_reader import SpanReader |
|
from ..utils import Span, VIRTUAL_ROOT, BIOSmoothing |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@DatasetReader.register('semantic_role_labeling')
class SRLDatasetReader(SpanReader):
    """Dataset reader registered under the name ``semantic_role_labeling``.

    Reads a JSON-lines file, turns every line into keyword arguments for
    :meth:`text_to_instance`, and optionally re-labels the annotated spans
    through an ontology mapping loaded from a JSON file.

    NOTE(review): ``self.debug``, ``self.n_span_removed`` and
    ``self.prepare_inputs`` are not defined here; they are presumably provided
    by ``SpanReader`` -- confirm against the base class.
    """

    def __init__(
            self,
            min_negative: int = 5,
            negative_ratio: float = 1.,
            event_only: bool = False,
            event_smoothing_factor: float = 0.,
            arg_smoothing_factor: float = 0.,
            ontology_mapping_path: Optional[str] = None,
            min_weight: float = 1e-2,
            max_weight: float = 1.0,
            **extra
    ):
        """Configure the reader.

        :param min_negative: Stored on the instance; not used by the code
            visible in this class.
        :param negative_ratio: Stored on the instance; not used by the code
            visible in this class.
        :param event_only: If true, children of every event are removed in
            :meth:`text_to_instance`.
        :param event_smoothing_factor: When non-zero, used as ``o_smooth`` of
            the smoothing attached to the root span.
        :param arg_smoothing_factor: When non-zero, used as ``o_smooth`` of the
            smoothing attached to every event.
        :param ontology_mapping_path: Optional path to a JSON ontology mapping.
            ``None`` disables ontology mapping.
        :param min_weight: Mapping weights below this value are set to zero.
        :param max_weight: Mapping weights above this value are clipped to it.
        :param extra: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**extra)
        self.min_negative = min_negative
        self.negative_ratio = negative_ratio
        self.event_only = event_only
        self.event_smooth_factor = event_smoothing_factor
        self.arg_smooth_factor = arg_smoothing_factor
        self.ontology_mapping = None
        if ontology_mapping_path is not None:
            self.ontology_mapping = self._load_ontology_mapping(
                ontology_mapping_path, min_weight, max_weight
            )

    @staticmethod
    def _load_ontology_mapping(path: str, min_weight: float, max_weight: float) -> Dict[str, Any]:
        """Load the ontology JSON at ``path`` and post-process its weight vectors.

        For both the ``'event'`` and ``'argument'`` mappings, every weight list
        becomes a float numpy array with small weights zeroed and large weights
        clipped; labels whose weights sum to (almost) nothing are dropped.
        A one-hot vector for ``VIRTUAL_ROOT`` is then added to the event mapping.

        :raises ValueError: if ``VIRTUAL_ROOT`` is not among the target labels.
        """
        # Context manager so the file handle is closed deterministically.
        with open(path, encoding='utf-8') as fp:
            ontology = json.load(fp)
        mapping = ontology['mapping']
        for kind in ('event', 'argument'):
            kept = dict()
            for label, raw_weights in mapping[kind].items():
                # dtype=float: an all-integer JSON list would otherwise yield an
                # integer array and truncate a fractional ``max_weight``.
                weights = np.array(raw_weights, dtype=float)
                weights[weights < min_weight] = 0.0
                weights[weights > max_weight] = max_weight
                # Drop labels that are effectively unmapped after thresholding.
                if weights.sum() > 1e-5:
                    kept[label] = weights
            mapping[kind] = kept
        target_labels = ontology['target']['label']
        vr_label = np.zeros(len(target_labels))
        vr_label[target_labels.index(VIRTUAL_ROOT)] = 1.0
        mapping['event'][VIRTUAL_ROOT] = vr_label
        return ontology

    def _read(self, file_path: str) -> Iterable[Instance]:
        """Yield one instance per JSON line of ``file_path``.

        Each non-blank line is parsed as a JSON object and expanded into keyword
        arguments of :meth:`text_to_instance`. After the whole file has been
        consumed, the number of removed spans is logged (if any) and the
        counter is reset.
        """
        with open(file_path, encoding='utf-8') as fp:
            # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
            all_lines = [json.loads(raw) for raw in fp if raw.strip()]
        if self.debug:
            # Fixed seed keeps the debug ordering reproducible; note that this
            # reseeds the module-level RNG, as the original code did.
            random.seed(1)
            random.shuffle(all_lines)
        for line in all_lines:
            ins = self.text_to_instance(**line)
            if ins is not None:
                yield ins
        if self.n_span_removed > 0:
            logger.warning(f'{self.n_span_removed} spans are removed.')
            self.n_span_removed = 0

    def apply_ontology_mapping(self, vr):
        """Re-label the events under ``vr`` and their children via the ontology mapping.

        Events whose label is missing from the event mapping are dropped, and so
        are children whose label is missing from the argument mapping. The
        surviving spans get their smoothing weights set to the sum of their
        mapping vector, and a new virtual root is built over the kept events.

        Requires ``self.ontology_mapping`` to be loaded. The spans reachable
        from ``vr`` are modified along the way.

        :param vr: Root span whose iteration yields events.
        :return: A new virtual root span holding the mapped events.
        """
        mapping = self.ontology_mapping['mapping']
        event_map, arg_map = mapping['event'], mapping['argument']
        new_events = list()
        for event in vr:
            if event.label not in event_map:
                continue
            event.child_smooth.weight = event.smooth_weight = event_map[event.label].sum()
            event = event.map_ontology(event_map, False, False)
            new_events.append(event)
            # Collect the mappable children first, then swap them in, so the
            # event's children are not modified while being iterated.
            new_children = list()
            for child in event:
                if child.label not in arg_map:
                    continue
                child.child_smooth.weight = child.smooth_weight = arg_map[child.label].sum()
                child = child.map_ontology(arg_map, False, False)
                new_children.append(child)
            event.remove_child()
            for child in new_children:
                event.add_child(child)
        new_vr = Span.virtual_root(new_events)
        # NOTE(review): the return value is discarded, so map_ontology with these
        # flags presumably updates the root in place -- confirm in Span.
        new_vr.map_ontology(event_map, True, False)
        return new_vr

    def text_to_instance(self, tokens, annotations=None, meta=None) -> Optional[Instance]:
        """Build an instance from tokens and optional span annotations.

        :param tokens: Passed through to ``self.prepare_inputs``.
        :param annotations: Either a ``Span`` or a JSON-like object accepted by
            ``Span.from_json``; ``None`` means no annotations are available.
        :param meta: Optional metadata mapping. It is copied, and
            ``'fully_annotated'`` defaults to ``True`` when absent.
        :return: The constructed instance.
        """
        # Work on a copy so the caller's dict is never mutated.
        meta = dict(meta) if meta else dict()
        meta.setdefault('fully_annotated', True)
        vr = None
        if annotations is not None:
            vr = annotations if isinstance(annotations, Span) else Span.from_json(annotations)
            if self.ontology_mapping is not None:
                vr = self.apply_ontology_mapping(vr)

            if self.event_smooth_factor != 0.0:
                # -1 is passed as o_smooth when the sentence is not fully annotated.
                o_smooth = self.event_smooth_factor if meta['fully_annotated'] else -1
                vr.child_smooth = BIOSmoothing(o_smooth=o_smooth)
            if self.arg_smooth_factor != 0.0:
                for event in vr:
                    event.child_smooth = BIOSmoothing(o_smooth=self.arg_smooth_factor)
            if self.event_only:
                for event in vr:
                    event.remove_child()
                    event.is_parent = False

        label_format = 'string' if self.ontology_mapping is None else 'list'
        fields = self.prepare_inputs(tokens, vr, True, label_format)
        fields['meta'] = MetadataField(meta)
        return Instance(fields)
|
|