import gc
from pathlib import Path
import numpy as np
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
from solution import predict_wireframe
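
# NOTE: `predict_wireframe(entry)` comes from the competitor's solution module.
# From its usage below it is expected to return a tuple (vertices, edges),
# where `vertices` is an (N, 3) NumPy array and `edges` is presumably a list
# of vertex-index pairs; `empty_solution()` mirrors that contract.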


def empty_solution():
    """Return a minimal valid solution in case of an error."""
    return np.zeros((2, 3)), []


def main():
    """
    Main script for the S23DR 2025 Challenge.

    This script loads the test dataset using the competition's specific
    method, runs the prediction pipeline, and saves the results.
    """
    print("------------ Setting up data paths ------------")
    # This is the essential path where data is stored in the submission environment.
    data_path = Path('/tmp/data')

    print("------------ Loading dataset ------------")
    # This data loading logic is preserved from the original script to ensure
    # compatibility with the submission environment.
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(f"Found data files: {data_files}")
    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print(f"Dataset loaded successfully: {dataset}")

    print('------------ Starting prediction loop ---------------')
    solution = []
    for subset_name in dataset.keys():
        print(f"Predicting for subset: {subset_name}")
        for i, entry in enumerate(tqdm(dataset[subset_name], desc=f"Processing {subset_name}")):
            try:
                # Run your prediction pipeline.
                pred_vertices, pred_edges = predict_wireframe(entry)
            except Exception as e:
                # If your pipeline fails, provide an empty solution and log the error.
                print(f"Error processing sample {entry.get('order_id', 'UNKNOWN')}: {e}")
                pred_vertices, pred_edges = empty_solution()

            # Append the result in the required format.
            solution.append(
                {
                    'order_id': entry['order_id'],
                    'wf_vertices': pred_vertices.tolist(),
                    'wf_edges': pred_edges,
                }
            )

            # Periodically run garbage collection to manage memory.
            if (i + 1) % 50 == 0:
                gc.collect()

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet", index=False)
    print("------------ Done ------------")


if __name__ == "__main__":
    main()
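
# For reference, a minimal `solution.py` stub satisfying the interface assumed
# by this harness could look like the sketch below (hypothetical example, not
# the actual pipeline; it returns the same placeholder as empty_solution()):
#
#     import numpy as np
#
#     def predict_wireframe(entry):
#         # Return (vertices, edges): an (N, 3) array and a list of index pairs.
#         return np.zeros((2, 3)), []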