Spaces:
Running
on
Zero
Running
on
Zero
| diff --git a/coglm_strategy.py b/coglm_strategy.py | |
| index d485715..a9eab3b 100644 | |
| --- a/coglm_strategy.py | |
| +++ b/coglm_strategy.py | |
| # here put the import lib | |
| import os | |
| +import pathlib | |
| import sys | |
| import math | |
| import random | |
| class CoglmStrategy: | |
| self._is_done = False | |
| self.outlier_count_down = torch.zeros(16) | |
| self.vis_list = [[]for i in range(16)] | |
| - self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long) | |
| + cluster_label_path = pathlib.Path(__file__).parent / 'cluster_label2.npy' | |
| + self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long) | |
| self.start_pos = -1 | |
| self.white_cluster = [] | |
| # self.fout = open('tmp.txt', 'w') | |
| class CoglmStrategy: | |
| def finalize(self, tokens, mems): | |
| self._is_done = False | |
| - return tokens, mems | |
| \ No newline at end of file | |
| + return tokens, mems | |
| diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py | |
| index 5b8dded..07e97fd 100644 | |
| --- a/sr_pipeline/dsr_sampling.py | |
| +++ b/sr_pipeline/dsr_sampling.py | |
| # here put the import lib | |
| import os | |
| +import pathlib | |
| import sys | |
| import math | |
| import random | |
| class IterativeEntfilterStrategy: | |
| self.invalid_slices = invalid_slices | |
| self.temperature = temperature | |
| self.topk = topk | |
| - self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long) | |
| + cluster_label_path = pathlib.Path(__file__).parents[1] / 'cluster_label2.npy' | |
| + self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long) | |
| def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None): | |