"""HTTP endpoint exposing a ReLiK annotator to a GERBIL-style client.

The server accepts POST requests whose JSON body carries a ``text`` field,
normalises a few bracket/quote tokens before running the annotator, maps the
predicted character offsets back onto the original text and replies with a
JSON list of ``(start, length, label)`` triples.
"""

import argparse
import json
import logging
import os
import re
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Iterator, List, Optional, Tuple

from relik.inference.annotator import Relik
from relik.inference.data.objects import RelikOutput

# sys.path += ['../']
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))

logger = logging.getLogger(__name__)

# Token -> replacement applied by ``preprocess_document``. Insertion order is
# the alternation order of the compiled pattern below.
_PATTERN_SUBS = {
    "-LPR- ": " (",
    "-RPR-": ")",
    "\n\n": "\n",
    "-LRB-": "(",
    "-RRB-": ")",
    '","': ",",
}

# Compiled once; every token is escaped so it is always matched literally.
_PATTERN_RE = re.compile(
    "({})".format("|".join(re.escape(token) for token in _PATTERN_SUBS))
)


class GerbilAlbyManager:
    """Hold the annotator plus optional label mapping and response logging."""

    def __init__(
        self,
        annotator: Optional[Relik] = None,
        response_logger_dir: Optional[str] = None,
    ) -> None:
        self.annotator = annotator
        self.response_logger_dir = response_logger_dir
        # Used as the file name of the next response bundle written to disk.
        self.predictions_counter = 0
        # Optional label translation table, populated by ``set_mapping_file``.
        self.labels_mapping = None

    def annotate(self, document: str):
        """Run the annotator on ``document``.

        Returns:
            A list of ``(start, end, label)`` triples. When a labels mapping
            is loaded, each label is replaced by its mapped value (labels
            missing from the mapping are kept unchanged).

        Raises:
            RuntimeError: if no annotator has been configured.
        """
        if self.annotator is None:
            raise RuntimeError("No annotator configured on GerbilAlbyManager")

        relik_output: RelikOutput = self.annotator(document)
        # Each entry of ``labels`` unpacks into four fields; the last is unused.
        annotations = [
            (start, end, label) for start, end, label, _ in relik_output.labels
        ]
        if self.labels_mapping is None:
            return annotations
        return [
            (start, end, self.labels_mapping.get(label, label))
            for start, end, label in annotations
        ]

    def set_mapping_file(self, mapping_file_path: str):
        """Load a JSON object from ``mapping_file_path`` and store it inverted.

        The file maps ``key -> value``; the stored table maps ``value -> key``.
        """
        with open(mapping_file_path, encoding="utf-8") as f:
            labels_mapping = json.load(f)
        self.labels_mapping = {v: k for k, v in labels_mapping.items()}

    def write_response_bundle(
        self,
        document: str,
        new_document: str,
        annotations: list,
        mapped_annotations: list,
    ) -> None:
        """Dump one request/response bundle as ``<counter>.json``.

        Does nothing when ``response_logger_dir`` is ``None``. Inside this
        method ``annotations`` are sliced against ``document`` and
        ``mapped_annotations`` against ``new_document``.
        """
        if self.response_logger_dir is None:
            return

        # makedirs tolerates nested paths and an already-existing directory.
        os.makedirs(self.response_logger_dir, exist_ok=True)

        out_json_obj = dict(
            document=document,
            new_document=new_document,
            annotations=annotations,
            mapped_annotations=mapped_annotations,
        )
        out_json_obj["span_annotations"] = [
            (ss, se, document[ss:se], label) for (ss, se, label) in annotations
        ]
        out_json_obj["span_mapped_annotations"] = [
            (ss, se, new_document[ss:se], label)
            for (ss, se, label) in mapped_annotations
        ]

        out_path = os.path.join(
            self.response_logger_dir, f"{self.predictions_counter}.json"
        )
        with open(out_path, "w") as f:
            json.dump(out_json_obj, f, indent=2)

        self.predictions_counter += 1


manager = GerbilAlbyManager()


def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
    """Replace the tokens listed in ``_PATTERN_SUBS`` and record the shifts.

    Returns:
        ``(new_document, char2offset)``. Each ``char2offset`` entry is
        ``(index, offset)``: ``index`` is the position in ``new_document``
        right after a replacement and ``offset`` is the cumulative number of
        characters removed up to that point, i.e. the amount to add to a
        ``new_document`` index at or past ``index`` to reach ``document``.
    """
    pieces: List[str] = []
    char2offset: List[Tuple[int, int]] = []
    curr_offset = 0
    last_end = 0

    # finditer yields non-overlapping matches left to right, so the output can
    # be assembled in one pass instead of re-slicing the whole string per match.
    for match in _PATTERN_RE.finditer(document):
        span_start, span_end = match.span()
        span_sub = _PATTERN_SUBS[match.group(0)]

        pieces.append(document[last_end:span_start])
        pieces.append(span_sub)
        last_end = span_end

        curr_offset += (span_end - span_start) - len(span_sub)
        # span_end - curr_offset is where the replacement ends in the new text.
        char2offset.append((span_end - curr_offset, curr_offset))

    pieces.append(document[last_end:])
    return "".join(pieces), char2offset


def map_back_annotations(
    annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
) -> Iterator[Tuple[int, int, str]]:
    """Yield ``annotations`` with start/end shifted according to ``char_mapping``.

    ``char_mapping`` is the ``(index, offset)`` list produced by
    ``preprocess_document``; it is scanned in order and must be sorted by
    ``index``.
    """

    def map_char(char_idx: int) -> int:
        # The offset of the last entry whose index is <= char_idx applies.
        current_offset = 0
        for offset_idx, offset_value in char_mapping:
            if char_idx < offset_idx:
                break
            current_offset = offset_value
        return char_idx + current_offset

    for ss, se, label in annotations:
        yield map_char(ss), map_char(se), label


def annotate(document: str) -> Optional[List[Tuple[int, int, str]]]:
    """Annotate ``document`` with the module-level ``manager``.

    Returns:
        A list of ``(start, length, label)`` triples whose offsets refer to
        the original ``document``, or ``None`` when any mapped-back span does
        not reproduce the text of the span predicted on the preprocessed
        document.
    """
    new_document, mapping = preprocess_document(document)
    logger.info("Mapping: %s", mapping)
    logger.info("Document: %s", document)

    annotations = [
        (cs, ce, label.replace(" ", "_"))
        for cs, ce, label in manager.annotate(new_document)
    ]
    logger.info("New document: %s", new_document)

    mapped_annotations = (
        list(map_back_annotations(annotations, mapping)) if mapping else annotations
    )
    logger.info(
        "Annotations: %s",
        [(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations],
    )

    # Argument order is intentional: inside write_response_bundle the third
    # argument is sliced against ``document`` and the fourth against
    # ``new_document``.
    manager.write_response_bundle(
        document, new_document, mapped_annotations, annotations
    )

    # Every mapped-back span must read the same text as the predicted span.
    diff_mappings = [
        (new_document[ss:se], document[mss:mse])
        for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
    ]
    if any(new_span != orig_span for new_span, orig_span in diff_mappings):
        logger.warning(
            "Offset mapping mismatch (new span, original span): %s", diff_mappings
        )
        return None

    return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]


class GetHandler(BaseHTTPRequestHandler):
    """Request handler answering annotation POST requests."""

    def do_POST(self):
        """Annotate the ``text`` field of the JSON body and reply with JSON.

        Replies 400 when the request cannot be parsed and 500 when annotation
        fails, instead of committing to 200 before any work is done.
        """
        try:
            content_length = int(self.headers["Content-Length"])
            post_data = self.rfile.read(content_length)
            doc_text = read_json(post_data)
        except (TypeError, ValueError, KeyError) as e:
            # Missing/invalid Content-Length, undecodable or invalid JSON,
            # or no "text" field in the payload.
            logger.warning("Bad request: %r", e)
            self.send_error(400, "Invalid request body")
            return

        try:
            response = annotate(doc_text)
        except Exception:  # top-level boundary: log and report, keep serving
            logger.exception("Annotation failed")
            self.send_error(500, "Annotation failed")
            return

        body = json.dumps(response).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)


def read_json(post_data):
    """Decode a UTF-8 JSON request body and return its ``text`` field."""
    data = json.loads(post_data.decode("utf-8"))
    return data["text"]


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the server."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--relik-model-name", required=True)
    parser.add_argument("--responses-log-dir")
    parser.add_argument("--log-file", default="logs/logging.txt")
    parser.add_argument("--mapping-file")
    return parser.parse_args()


def _strip_candidate_definition(candidate: str) -> str:
    """Return the part of ``candidate`` before the definition separator.

    NOTE(review): the separator literal was an empty string, which makes
    ``str.split`` raise ``ValueError`` on every call. ``"<def>"`` is assumed
    to be the intended token - confirm against the retriever index format.
    """
    return candidate.split("<def>")[0].strip()


def main():
    """Configure logging and the annotator, then serve until interrupted."""
    args = parse_args()

    # Logging is configured first so that start-up messages reach the file.
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    file_handler = logging.FileHandler(args.log_file)
    file_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    # Without an explicit level the logger inherits the root's WARNING level
    # and every info() call in this module is dropped before the handler.
    logger.setLevel(logging.DEBUG)

    # init manager
    manager.response_logger_dir = args.responses_log_dir
    # manager.annotator = Relik.from_pretrained(args.relik_model_name)
    # NOTE: --relik-model-name is currently ignored in favour of the
    # hard-coded configuration below.
    print("Debugging, not using you relik model but an hardcoded one.")
    manager.annotator = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="relik/reader/models/relik-reader-deberta-base-new-data",
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn=_strip_candidate_definition,
    )

    if args.mapping_file is not None:
        manager.set_mapping_file(args.mapping_file)

    port = 6654
    server = HTTPServer(("localhost", port), GetHandler)
    logger.info(f"Starting server at http://localhost:{port}")

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        sys.exit(0)
    finally:
        # Release the listening socket on any exit path.
        server.server_close()


if __name__ == "__main__":
    main()