File size: 1,027 Bytes
3835a42
 
 
 
 
 
b731827
3835a42
a9ffef1
d544db4
3835a42
 
b731827
 
 
3835a42
 
 
5a19953
3835a42
 
 
 
 
 
 
 
5a19953
3835a42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
from typing import Any, Dict, List, Union

from allennlp.common.params import Params
from long_coref.coref.prediction import CorefPredictor
from long_coref.coref.utils import ArchiveContent

# Name of the checkpoint directory (holding weights.th and config.json)
# that PreTrainedPipeline loads, relative to its ``path`` argument.
CHECKPOINT: str = "coref-spanbert-large-2021.03.10"


class PreTrainedPipeline:
    """Coreference-resolution pipeline built from a locally extracted checkpoint.

    The checkpoint folder named by ``CHECKPOINT`` is looked up under the
    ``path`` given to the constructor. It must contain ``weights.th`` and
    ``config.json``.
    """

    def __init__(self, path=""):
        """Load the coreference predictor from the extracted checkpoint.

        Args:
            path: Directory containing the ``CHECKPOINT`` folder. With the
                default ``""`` the checkpoint path is relative, so it is
                resolved against the current working directory.
        """
        checkpoint_dir = os.path.join(path, CHECKPOINT)
        archive_content = ArchiveContent(
            archive_dir=checkpoint_dir,
            weight_path=os.path.join(checkpoint_dir, "weights.th"),
            config=Params.from_file(os.path.join(checkpoint_dir, "config.json")),
        )
        self.predictor = CorefPredictor.from_extracted_archive(archive_content)

    @staticmethod
    def _split_paragraphs(data: Union[str, Dict[str, Any]]) -> List[str]:
        """Return the paragraphs of *data*, split on blank lines.

        *data* is either the raw text or a dict holding that text under the
        ``"inputs"`` key. Splitting is a plain ``split`` on two consecutive
        newline characters, so empty paragraphs are preserved as ``""``.
        """
        # Accept the {"inputs": text} payload as well as a bare string; the
        # bare-string path behaves exactly as before.
        text = data["inputs"] if isinstance(data, dict) else data
        return text.split("\n\n")

    def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
        """Resolve coreference across the paragraphs of the input text.

        Args:
            data: The input text, with paragraphs separated by a blank line,
                or a dict carrying that text under the ``"inputs"`` key.

        Returns:
            The predictor's result converted via ``to_dict()``. It will be
            serialized and returned to the caller.

        Raises:
            KeyError: if *data* is a dict without an ``"inputs"`` key.
        """
        paragraphs = self._split_paragraphs(data)
        prediction = self.predictor.resolve_paragraphs(paragraphs)
        return prediction.to_dict()