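"""GERBIL-style HTTP annotation server around a ReLiK annotator.

Accepts POST requests with a JSON body containing a "text" field, runs the
ReLiK annotator on a lightly normalized copy of the text, and answers with
(start, length, label) annotation triples mapped back onto the original text.
"""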
import argparse
import json
import logging
import os
import re
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Iterator, List, Optional, Tuple

# Make the repository root importable *before* pulling in the relik package,
# so a local checkout takes precedence over any installed version.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))

from relik.inference.annotator import Relik
from relik.inference.data.objects import RelikOutput

logger = logging.getLogger(__name__)


class GerbilAlbyManager:
    """Holds the annotator, an optional label mapping, and response-logging state."""

    def __init__(
        self,
        annotator: Optional[Relik] = None,
        response_logger_dir: Optional[str] = None,
    ) -> None:
        self.annotator = annotator
        self.response_logger_dir = response_logger_dir
        self.predictions_counter = 0
        self.labels_mapping = None

    def annotate(self, document: str) -> List[Tuple[int, int, str]]:
        relik_output: RelikOutput = self.annotator(document)
        annotations = [(ss, se, label) for ss, se, label, _ in relik_output.labels]
        if self.labels_mapping is not None:
            return [
                (ss, se, self.labels_mapping.get(label, label))
                for ss, se, label in annotations
            ]
        return annotations

    def set_mapping_file(self, mapping_file_path: str):
        # The file stores {target_label: model_label}; invert it so lookups go
        # from the model's labels to the target ones.
        with open(mapping_file_path) as f:
            labels_mapping = json.load(f)
        self.labels_mapping = {v: k for k, v in labels_mapping.items()}

    def write_response_bundle(
        self,
        document: str,
        new_document: str,
        annotations: list,
        mapped_annotations: list,
    ) -> None:
        """Dump one request/response pair to disk for later inspection."""
        if self.response_logger_dir is None:
            return

        os.makedirs(self.response_logger_dir, exist_ok=True)

        out_path = os.path.join(
            self.response_logger_dir, f"{self.predictions_counter}.json"
        )
        with open(out_path, "w") as f:
            out_json_obj = dict(
                document=document,
                new_document=new_document,
                annotations=annotations,
                mapped_annotations=mapped_annotations,
            )

            # Materialize the surface forms so the dumps are human-readable:
            # `annotations` index into `document`, `mapped_annotations` into
            # `new_document`.
            out_json_obj["span_annotations"] = [
                (ss, se, document[ss:se], label) for (ss, se, label) in annotations
            ]
            out_json_obj["span_mapped_annotations"] = [
                (ss, se, new_document[ss:se], label)
                for (ss, se, label) in mapped_annotations
            ]

            json.dump(out_json_obj, f, indent=2)

        self.predictions_counter += 1


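# Module-level singleton shared with the HTTP handler below; main() injects
# the actual Relik annotator and, optionally, a label mapping at startup.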
manager = GerbilAlbyManager()


def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
    """Normalize bracket markers and doubled newlines in a GERBIL document.

    Returns the rewritten document together with a character mapping: one
    (position-after-substitution, cumulative-offset) pair per replacement,
    which `map_back_annotations` uses to project spans back onto the original.
    """
    pattern_subs = {
        "-LPR- ": " (",
        "-RPR-": ")",
        "\n\n": "\n",
        "-LRB-": "(",
        "-RRB-": ")",
        '","': ",",
    }

    document_acc = document
    curr_offset = 0
    char2offset = []

    # Matches are found on the *original* document, so each span is shifted
    # left by the offset accumulated from earlier, shorter substitutions
    # before indexing into the partially rewritten document. re.escape guards
    # against regex metacharacters in the patterns.
    matchings = re.finditer(
        "({})".format("|".join(re.escape(p) for p in pattern_subs)), document
    )
    for span_matching in sorted(matchings, key=lambda x: x.span()[0]):
        span_start, span_end = span_matching.span()
        span_start -= curr_offset
        span_end -= curr_offset

        span_text = document_acc[span_start:span_end]
        span_sub = pattern_subs[span_text]
        document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:]

        curr_offset += len(span_text) - len(span_sub)

        char2offset.append((span_start + len(span_sub), curr_offset))

    return document_acc, char2offset


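# A hand-computed example of the mapping contract (illustrative):
#
#   preprocess_document("a -LRB-b-RRB-")  ->  ("a (b)", [(3, 4), (5, 8)])
#
# Each pair records the first position *after* a replacement together with
# the cumulative offset; map_char below uses the last applicable pair to
# shift span boundaries on the rewritten text back onto the original.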
def map_back_annotations(
    annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
) -> Iterator[Tuple[int, int, str]]:
    def map_char(char_idx: int) -> int:
        # Keep the offset of the last substitution whose end position is at or
        # before char_idx, then shift the index by it.
        current_offset = 0
        for offset_idx, offset_value in char_mapping:
            if char_idx >= offset_idx:
                current_offset = offset_value
            else:
                break
        return char_idx + current_offset

    for ss, se, label in annotations:
        yield map_char(ss), map_char(se), label


def annotate(document: str) -> Optional[List[Tuple[int, int, str]]]:
    new_document, mapping = preprocess_document(document)
    logger.info("Mapping: %s", mapping)
    logger.info("Document: %s", document)
    annotations = [
        (cs, ce, label.replace(" ", "_"))
        for cs, ce, label in manager.annotate(new_document)
    ]
    logger.info("New document: %s", new_document)
    mapped_annotations = (
        list(map_back_annotations(annotations, mapping))
        if len(mapping) > 0
        else annotations
    )

    logger.info(
        "Annotations: %s",
        [(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations],
    )

    # Mind the argument order: `mapped_annotations` index into the original
    # document and `annotations` into the preprocessed one, matching the
    # (annotations, mapped_annotations) parameters of write_response_bundle.
    manager.write_response_bundle(
        document, new_document, mapped_annotations, annotations
    )

    # Sanity check: every mapped span must cover the same surface form in the
    # original document as the unmapped span does in the preprocessed one.
    if not all(
        new_document[ss:se] == document[mss:mse]
        for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
    ):
        diff_mappings = [
            (new_document[ss:se], document[mss:mse])
            for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
        ]
        logger.error("Span mismatch after offset mapping: %s", diff_mappings)
        return None

    # GERBIL expects (start, length, label) triples.
    return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]


class GetHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)
        self.send_response(200)
        self.end_headers()
        doc_text = read_json(post_data)

        response = annotate(doc_text)

        self.wfile.write(bytes(json.dumps(response), "utf-8"))


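# Example request (illustrative; GERBIL may send additional fields, but only
# "text" is consumed here):
#
#   curl -X POST http://localhost:6654 \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Barack Obama visited Berlin."}'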
def read_json(post_data):
    data = json.loads(post_data.decode("utf-8"))
    text = data["text"]
    return text


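# Typical invocation (sketch; the script name is illustrative, and note that
# main() currently ignores --relik-model-name in favor of a hardcoded model):
#
#   python gerbil_server.py --relik-model-name <model> \
#       --responses-log-dir responses/ --mapping-file mapping.json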
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--relik-model-name", required=True)
    parser.add_argument("--responses-log-dir")
    parser.add_argument("--log-file", default="logs/logging.txt")
    parser.add_argument("--mapping-file")
    return parser.parse_args()


def main():
    args = parse_args()

    manager.response_logger_dir = args.responses_log_dir

    # NOTE: the --relik-model-name argument is currently ignored in favor of a
    # hardcoded debugging configuration.
    print("Debugging: not using your relik model but a hardcoded one.")
    manager.annotator = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="relik/reader/models/relik-reader-deberta-base-new-data",
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn=(lambda x: x.split("<def>")[0].strip()),
    )

    if args.mapping_file is not None:
        manager.set_mapping_file(args.mapping_file)

    # Configure file logging before anything is logged, so the startup message
    # below also reaches the log file; create the log directory if needed.
    os.makedirs(os.path.dirname(args.log_file) or ".", exist_ok=True)
    file_handler = logging.FileHandler(args.log_file)
    file_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.setLevel(logging.INFO)

    port = 6654
    server = HTTPServer(("localhost", port), GetHandler)
    logger.info(f"Starting server at http://localhost:{port}")

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        sys.exit(0)


if __name__ == "__main__":
    main()