kwang2049 commited on
Commit
a9ffef1
1 Parent(s): 2269bb9
Files changed (1) hide show
  1. pipeline.py +2 -0
pipeline.py CHANGED
@@ -6,6 +6,7 @@ from allennlp.common.params import Params
6
 
7
  CHECKPOINT = "coref-spanbert-large-2021.03.10"
8
 
 
9
  class PreTrainedPipeline:
10
  def __init__(self, path=""):
11
  archive_content = ArchiveContent(
@@ -14,6 +15,7 @@ class PreTrainedPipeline:
14
  config=Params.from_file(os.path.join(path, CHECKPOINT, "config.json")),
15
  )
16
  self.predictor = CorefPredictor.from_extracted_archive(archive_content)
 
17
 
18
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
19
  """
 
6
 
7
  CHECKPOINT = "coref-spanbert-large-2021.03.10"
8
 
9
+
10
  class PreTrainedPipeline:
11
  def __init__(self, path=""):
12
  archive_content = ArchiveContent(
 
15
  config=Params.from_file(os.path.join(path, CHECKPOINT, "config.json")),
16
  )
17
  self.predictor = CorefPredictor.from_extracted_archive(archive_content)
18
+ self.predictor.set_device("cpu")
19
 
20
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
21
  """