Spaces:
Runtime error
Runtime error
diff --git a/coglm_strategy.py b/coglm_strategy.py | |
index cba87ce..40e4ece 100755 | |
--- a/coglm_strategy.py | |
+++ b/coglm_strategy.py | |
# here put the import lib | |
import os | |
+import pathlib | |
import sys | |
import math | |
import random | |
class CoglmStrategy: | |
self._is_done = False | |
self.outlier_count_down = 5 | |
self.vis_list = [[]for i in range(16)] | |
- self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long) | |
+ cluster_label_path = pathlib.Path(__file__).parent / 'cluster_label.npy' | |
+ self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long) | |
self.top_k_cluster = top_k_cluster | |
@property | |
class CoglmStrategy: | |
def finalize(self, tokens, mems): | |
self._is_done = False | |
- return tokens, mems | |
\ No newline at end of file | |
+ return tokens, mems | |
diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py | |
index a0d0298..f721573 100755 | |
--- a/sr_pipeline/dsr_sampling.py | |
+++ b/sr_pipeline/dsr_sampling.py | |
# here put the import lib | |
import os | |
+import pathlib | |
import sys | |
import math | |
import random | |
class IterativeEntfilterStrategy: | |
self.invalid_slices = invalid_slices | |
self.temperature = temperature | |
self.topk = topk | |
- self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long) | |
+ cluster_label_path = pathlib.Path(__file__).parents[1] / 'cluster_label.npy' | |
+ self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long) | |
self.temperature2 = temperature2 | |