|
import gc |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from datasets import load_dataset |
|
from tqdm import tqdm |
|
from solution import predict_wireframe |
|
|
|
|
|
def empty_solution():
    """Build the fallback prediction used when a sample cannot be processed.

    Returns:
        A ``(vertices, edges)`` pair: a 2x3 float array filled with zeros and
        an empty edge list.
    """
    placeholder_vertices = np.zeros(shape=(2, 3))
    no_edges = []
    return placeholder_vertices, no_edges
|
|
|
|
|
def _find_data_files(data_path):
    """Collect the tar shards for each split found below *data_path*.

    Shards under a directory matching ``*public*`` form the ``validation``
    split, and shards under ``*private*`` form the ``test`` split. Paths are
    sorted so the processing order does not depend on filesystem enumeration
    order.
    """
    return {
        "validation": sorted(str(p) for p in data_path.rglob('*public*/**/*.tar')),
        "test": sorted(str(p) for p in data_path.rglob('*private*/**/*.tar')),
    }


def _to_record(order_id, pred_vertices, pred_edges):
    """Convert one prediction into a plain-Python submission row.

    Vertices are coerced to a float array of shape (N, 3), matching the
    layout of ``empty_solution()``. Edges are coerced to a list of
    ``[int, int]`` pairs. The row therefore serializes to parquet whether
    ``predict_wireframe`` returned ndarrays, lists or tuples.

    Raises:
        ValueError, TypeError: if the prediction cannot be coerced to that
        layout.
    """
    vertices = np.asarray(pred_vertices, dtype=float)
    if vertices.size == 0:
        # Treat an empty prediction as "zero vertices" rather than a 1-D array.
        vertices = vertices.reshape(0, 3)
    if vertices.ndim != 2 or vertices.shape[1] != 3:
        raise ValueError(f"expected vertices of shape (N, 3), got {vertices.shape}")
    edges = [[int(a), int(b)] for a, b in pred_edges]
    return {
        'order_id': order_id,
        'wf_vertices': vertices.tolist(),
        'wf_edges': edges,
    }


def _predict_record(entry):
    """Predict one sample and return its submission row.

    Any exception raised by ``predict_wireframe``, or while converting its
    output, is reported together with its traceback. The sample then gets
    the ``empty_solution()`` fallback, so one bad sample cannot abort the run.
    """
    import traceback

    order_id = entry['order_id']
    try:
        pred_vertices, pred_edges = predict_wireframe(entry)
        return _to_record(order_id, pred_vertices, pred_edges)
    except Exception as e:
        print(f"Error processing sample {order_id}: {e}")
        traceback.print_exc()
        return _to_record(order_id, *empty_solution())


def main():
    """
    Main script for the S23DR 2025 Challenge.

    This script loads the test dataset using the competition's specific
    method, runs the prediction pipeline, and saves the results to
    ``submission.parquet``.

    Every sample produces exactly one row. A sample whose prediction fails,
    or whose output cannot be serialized, receives the ``empty_solution()``
    fallback instead of aborting the whole run.
    """
    print("------------ Setting up data paths ------------")
    data_path = Path('/tmp/data')

    print("------------ Loading dataset ------------")
    data_files = _find_data_files(data_path)
    print(f"Found data files: {data_files}")

    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print(f"Dataset loaded successfully: {dataset}")

    print('------------ Starting prediction loop ---------------')
    solution = []
    for subset_name in dataset.keys():
        print(f"Predicting for subset: {subset_name}")
        for i, entry in enumerate(tqdm(dataset[subset_name], desc=f"Processing {subset_name}")):
            solution.append(_predict_record(entry))
            # Force a collection every 50 samples; presumably this caps peak
            # memory from large per-sample data (confirm if it still matters).
            if (i + 1) % 50 == 0:
                gc.collect()

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet", index=False)
    print("------------ Done ------------")
|
|
|
|
|
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()