hysts's picture
hysts HF staff
Add files
8353fd4
raw
history blame
1.88 kB
diff --git a/coglm_strategy.py b/coglm_strategy.py
index d485715..a9eab3b 100644
--- a/coglm_strategy.py
+++ b/coglm_strategy.py
@@ -8,6 +8,7 @@
# here put the import lib
import os
+import pathlib
import sys
import math
import random
@@ -58,7 +59,8 @@ class CoglmStrategy:
self._is_done = False
self.outlier_count_down = torch.zeros(16)
self.vis_list = [[]for i in range(16)]
- self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
+ cluster_label_path = pathlib.Path(__file__).parent / 'cluster_label2.npy'
+ self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
self.start_pos = -1
self.white_cluster = []
# self.fout = open('tmp.txt', 'w')
@@ -98,4 +100,4 @@ class CoglmStrategy:
def finalize(self, tokens, mems):
self._is_done = False
- return tokens, mems
\ No newline at end of file
+ return tokens, mems
diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py
index 5b8dded..07e97fd 100644
--- a/sr_pipeline/dsr_sampling.py
+++ b/sr_pipeline/dsr_sampling.py
@@ -8,6 +8,7 @@
# here put the import lib
import os
+import pathlib
import sys
import math
import random
@@ -28,7 +29,8 @@ class IterativeEntfilterStrategy:
self.invalid_slices = invalid_slices
self.temperature = temperature
self.topk = topk
- self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
+ cluster_label_path = pathlib.Path(__file__).parents[1] / 'cluster_label2.npy'
+ self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):