wli3221134 commited on
Commit
6e8a76c
·
verified ·
1 Parent(s): 46689af

Update dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +3 -2
dataset.py CHANGED
@@ -70,7 +70,7 @@ class DemoDataset(Dataset):
70
  )
71
  prompt_features.append(prompt_feature)
72
 
73
- prompt_labels = torch.tensor(self.demonstration_labels, dtype=torch.long)
74
 
75
  return {
76
  'main_features': main_features,
@@ -113,7 +113,8 @@ def collate_fn(batch):
113
  file_names = [item['file_name'] for item in batch]
114
  file_paths = [item['file_path'] for item in batch]
115
 
116
- prompt_labels = torch.tensor([item['prompt_labels'] for item in batch], dtype=torch.long)
 
117
 
118
  return {
119
  'main_features': main_features,
 
70
  )
71
  prompt_features.append(prompt_feature)
72
 
73
+ prompt_labels = torch.tensor([self.demonstration_labels], dtype=torch.long)
74
 
75
  return {
76
  'main_features': main_features,
 
113
  file_names = [item['file_name'] for item in batch]
114
  file_paths = [item['file_path'] for item in batch]
115
 
116
+ # 确保 prompt_labels 的形状正确 [batch_size, num_prompts]
117
+ prompt_labels = torch.cat([item['prompt_labels'] for item in batch], dim=0)
118
 
119
  return {
120
  'main_features': main_features,