Spaces:
Running
Running
| import argparse | |
| import os | |
| import time | |
| from importlib import resources | |
| from typing import Any, Dict | |
| import ray | |
| import yaml | |
| from dotenv import load_dotenv | |
| from ray.data.block import Block | |
| from ray.data.datasource.filename_provider import FilenameProvider | |
| from graphgen.engine import Engine | |
| from graphgen.operators import operators | |
| from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger | |
# Absolute path of the directory that holds this script; main() hands it to
# argparse as the default for --output_dir.
sys_path = os.path.abspath(os.path.dirname(__file__))

# python-dotenv: read a .env file (when one can be found) so that its entries
# become available as environment variables before anything else runs.
load_dotenv()
def set_working_dir(folder) -> None:
    """Make sure the directory *folder* exists.

    Missing parent directories are created as well, and a directory that is
    already present is left untouched. Despite the name, the current working
    directory of the process is not changed.
    """
    os.makedirs(name=folder, exist_ok=True)
def save_config(config_path, global_config):
    """Write *global_config* to *config_path* as block-style YAML (UTF-8).

    Args:
        config_path: Destination file. Its parent directory is created when
            it does not exist yet.
        global_config: Object passed to ``yaml.dump``.
    """
    parent_dir = os.path.dirname(config_path)
    # A bare file name has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so the parent is only created when there is one.
    # exist_ok=True avoids the race of a separate exists() check.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(config_path, "w", encoding="utf-8") as config_file:
        yaml.dump(
            global_config, config_file, default_flow_style=False, allow_unicode=True
        )
class NodeFilenameProvider(FilenameProvider):
    """Filename provider that prefixes every written file with a node id."""

    def __init__(self, node_id: str):
        # Identifier placed at the start of each generated file name.
        self.node_id = node_id

    def get_filename_for_block(
        self, block: Block, write_uuid: str, task_index: int, block_index: int
    ) -> str:
        """Return the file name for one written block.

        The result has the form
        ``{node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl``.
        The ``block`` argument itself is not inspected.
        """
        return "{}_{}_{:06d}_{:06d}.jsonl".format(
            self.node_id, write_uuid, task_index, block_index
        )

    def get_filename_for_row(
        self,
        row: Dict[str, Any],
        write_uuid: str,
        task_index: int,
        block_index: int,
        row_index: int,
    ) -> str:
        """Always raise: this provider only names files per block.

        Raises:
            NotImplementedError: on every call; the message carries the node
                id and ``write_uuid``.
        """
        message = (
            "Row-based filenames are not supported by write_json. "
            f"Node: {self.node_id}, write_uuid: {write_uuid}"
        )
        raise NotImplementedError(message)
def main():
    """Run a GraphGen pipeline from the command line and save its results.

    Command-line options:
        --config_file: YAML configuration handed to ``Engine``. Defaults to
            the ``configs/aggregated_config.yaml`` resource of the
            ``graphgen`` package.
        --output_dir: Base directory for outputs and logs. Defaults to the
            directory that contains this script.

    Files produced below ``output_dir``:
        logs/Driver.log                    log file passed to ``set_logger``
        output/<unique_id>/<node_id>/      JSON Lines files, one dir per node
        output/<unique_id>/config.yaml     copy of the configuration used
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        help="Config parameters for GraphGen.",
        default=resources.files("graphgen")
        .joinpath("configs")
        .joinpath("aggregated_config.yaml"),
        type=str,
    )
    # Deliberately not marked required: argparse never applies a default to
    # a required option, so the two settings cannot be combined.
    parser.add_argument(
        "--output_dir",
        help="Output directory for GraphGen.",
        default=sys_path,
        type=str,
    )
    args = parser.parse_args()
    working_dir = args.output_dir

    with open(args.config_file, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader accepts more than plain YAML data types.
        # If configuration files can come from untrusted sources, switch to
        # yaml.safe_load.
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Second-resolution timestamp that names the directory of this run.
    unique_id = int(time.time())
    output_path = os.path.join(working_dir, "output", f"{unique_id}")
    set_working_dir(output_path)

    log_path = os.path.join(working_dir, "logs", "Driver.log")
    driver_logger = set_logger(
        log_path,
        name="GraphGen",
        if_stream=True,
    )
    CURRENT_LOGGER_VAR.set(driver_logger)
    logger.info(
        "GraphGen with unique ID %s logging to %s",
        unique_id,
        log_path,
    )

    engine = Engine(config, operators)
    # The engine is started from an empty dataset.
    ds = ray.data.from_items([])
    results = engine.execute(ds)

    for node_id, dataset in results.items():
        # Use a separate variable per node. Rebinding output_path here would
        # nest every node's directory inside the previous node's directory
        # and move config.yaml into the last one.
        node_output_path = os.path.join(output_path, f"{node_id}")
        os.makedirs(node_output_path, exist_ok=True)
        dataset.write_json(
            node_output_path,
            filename_provider=NodeFilenameProvider(node_id),
            pandas_json_args_fn=lambda: {
                "force_ascii": False,
                "orient": "records",
                "lines": True,
            },
        )
        logger.info("Node %s results saved to %s", node_id, node_output_path)

    # Keep a copy of the configuration at the root of this run's output.
    save_config(os.path.join(output_path, "config.yaml"), config)
    logger.info("GraphGen completed successfully. Data saved to %s", output_path)
# Start the pipeline only when this file is executed as a script; importing
# the module does not call main().
if __name__ == "__main__":
    main()