Spaces:
Runtime error
Runtime error
File size: 1,747 Bytes
9c00f5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
diff --git a/coglm_strategy.py b/coglm_strategy.py
index cba87ce..40e4ece 100755
--- a/coglm_strategy.py
+++ b/coglm_strategy.py
@@ -8,6 +8,7 @@
# here put the import lib
import os
+import pathlib
import sys
import math
import random
@@ -57,7 +58,8 @@ class CoglmStrategy:
self._is_done = False
self.outlier_count_down = 5
self.vis_list = [[]for i in range(16)]
- self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
+ cluster_label_path = pathlib.Path(__file__).parent / 'cluster_label.npy'
+ self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
self.top_k_cluster = top_k_cluster
@property
@@ -91,4 +93,4 @@ class CoglmStrategy:
def finalize(self, tokens, mems):
self._is_done = False
- return tokens, mems
\ No newline at end of file
+ return tokens, mems
diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py
index a0d0298..f721573 100755
--- a/sr_pipeline/dsr_sampling.py
+++ b/sr_pipeline/dsr_sampling.py
@@ -8,6 +8,7 @@
# here put the import lib
import os
+import pathlib
import sys
import math
import random
@@ -27,7 +28,8 @@ class IterativeEntfilterStrategy:
self.invalid_slices = invalid_slices
self.temperature = temperature
self.topk = topk
- self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
+ cluster_label_path = pathlib.Path(__file__).parents[1] / 'cluster_label.npy'
+ self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
self.temperature2 = temperature2
|