"""Segment text with SRX. """ __version__ = '0.0.2' import lxml.etree import regex from typing import ( List, Set, Tuple, Dict, Optional ) class SrxSegmenter: """Handle segmentation with SRX regex format. """ def __init__(self, rule: Dict[str, List[Tuple[str, Optional[str]]]], source_text: str) -> None: self.source_text = source_text self.non_breaks = rule.get('non_breaks', []) self.breaks = rule.get('breaks', []) def _get_break_points(self, regexes: List[Tuple[str, str]]) -> Set[int]: return set([ match.span(1)[1] for before, after in regexes for match in regex.finditer('({})({})'.format(before, after), self.source_text) ]) def get_non_break_points(self) -> Set[int]: """Return segment non break points """ return self._get_break_points(self.non_breaks) def get_break_points(self) -> Set[int]: """Return segment break points """ return self._get_break_points(self.breaks) def extract(self) -> Tuple[List[str], List[str]]: """Return segments and whitespaces. """ non_break_points = self.get_non_break_points() candidate_break_points = self.get_break_points() break_point = sorted(candidate_break_points - non_break_points) source_text = self.source_text segments = [] # type: List[str] whitespaces = [] # type: List[str] previous_foot = "" for start, end in zip([0] + break_point, break_point + [len(source_text)]): segment_with_space = source_text[start:end] candidate_segment = segment_with_space.strip() if not candidate_segment: previous_foot += segment_with_space continue head, segment, foot = segment_with_space.partition(candidate_segment) segments.append(segment) whitespaces.append('{}{}'.format(previous_foot, head)) previous_foot = foot whitespaces.append(previous_foot) return segments, whitespaces def parse(srx_filepath: str) -> Dict[str, Dict[str, List[Tuple[str, Optional[str]]]]]: """Parse SRX file and return it. :param srx_filepath: is soruce SRX file. :return: dict """ tree = lxml.etree.parse(srx_filepath) namespaces = { 'ns': 'http://www.lisa.org/srx20' } rules = {} for languagerule in tree.xpath('//ns:languagerule', namespaces=namespaces): rule_name = languagerule.attrib.get('languagerulename') if rule_name is None: continue current_rule = { 'breaks': [], 'non_breaks': [], } for rule in languagerule.xpath('ns:rule', namespaces=namespaces): is_break = rule.attrib.get('break', 'yes') == 'yes' rule_holder = current_rule['breaks'] if is_break else current_rule['non_breaks'] beforebreak = rule.find('ns:beforebreak', namespaces=namespaces) beforebreak_text = '' if beforebreak.text is None else beforebreak.text afterbreak = rule.find('ns:afterbreak', namespaces=namespaces) afterbreak_text = '' if afterbreak.text is None else afterbreak.text rule_holder.append((beforebreak_text, afterbreak_text)) rules[rule_name] = current_rule return rules