Spaces:
Running
on
Zero
Running
on
Zero
wli3221134
commited on
Update dataset.py
Browse files- dataset.py +3 -2
dataset.py
CHANGED
@@ -70,7 +70,7 @@ class DemoDataset(Dataset):
|
|
70 |
)
|
71 |
prompt_features.append(prompt_feature)
|
72 |
|
73 |
-
prompt_labels = torch.tensor(self.demonstration_labels, dtype=torch.long)
|
74 |
|
75 |
return {
|
76 |
'main_features': main_features,
|
@@ -113,7 +113,8 @@ def collate_fn(batch):
|
|
113 |
file_names = [item['file_name'] for item in batch]
|
114 |
file_paths = [item['file_path'] for item in batch]
|
115 |
|
116 |
-
|
|
|
117 |
|
118 |
return {
|
119 |
'main_features': main_features,
|
|
|
70 |
)
|
71 |
prompt_features.append(prompt_feature)
|
72 |
|
73 |
+
prompt_labels = torch.tensor([self.demonstration_labels], dtype=torch.long)
|
74 |
|
75 |
return {
|
76 |
'main_features': main_features,
|
|
|
113 |
file_names = [item['file_name'] for item in batch]
|
114 |
file_paths = [item['file_path'] for item in batch]
|
115 |
|
116 |
+
# 确保 prompt_labels 的形状正确 [batch_size, num_prompts]
|
117 |
+
prompt_labels = torch.cat([item['prompt_labels'] for item in batch], dim=0)
|
118 |
|
119 |
return {
|
120 |
'main_features': main_features,
|