long-coref / pipeline.py
kwang2049's picture
input type
5a19953
raw
history blame
1.07 kB
import os
from typing import Dict, List, Any
from long_coref.coref.prediction import CorefPredictor
from long_coref.coref.utils import ArchiveContent
from allennlp.common.params import Params
# Name of the checkpoint sub-directory, resolved relative to the `path` given to
# PreTrainedPipeline; "weights.th" and "config.json" are read from inside it.
CHECKPOINT = "coref-spanbert-large-2021.03.10"
class PreTrainedPipeline:
    """Callable wrapper around a ``CorefPredictor`` built from an extracted archive.

    The predictor is created once in ``__init__`` from the ``CHECKPOINT``
    directory and switched to the ``"cpu"`` device; calling the instance
    forwards text to it.
    """

    def __init__(self, path=""):
        """Build the predictor from the checkpoint stored under ``path``.

        Args:
            path: Directory containing the ``CHECKPOINT`` sub-directory. With
                the default empty string the checkpoint path is relative to
                the current working directory.
        """
        # All three archive locations live under the same checkpoint directory.
        checkpoint_dir = os.path.join(path, CHECKPOINT)
        weights_file = os.path.join(checkpoint_dir, "weights.th")
        config_file = os.path.join(checkpoint_dir, "config.json")

        archive = ArchiveContent(
            archive_dir=checkpoint_dir,
            weight_path=weights_file,
            config=Params.from_file(config_file),
        )

        self.predictor = CorefPredictor.from_extracted_archive(archive)
        self.predictor.set_device("cpu")

    def __call__(self, data: str) -> Dict[str, Any]:
        """Resolve the paragraphs contained in ``data``.

        Args:
            data: Input text. It is split on two consecutive newline
                characters and every resulting piece (including empty ones)
                is handed to the predictor as one paragraph.

        Returns:
            Whatever ``to_dict()`` yields for the object returned by the
            predictor's ``resolve_paragraphs``.
        """
        paragraphs = data.split("\n\n")
        resolved = self.predictor.resolve_paragraphs(paragraphs)
        return resolved.to_dict()