github-actions[bot]
Auto-sync from demo at Tue Dec 16 08:21:05 UTC 2025
31086ae
import argparse
import os
import time
from importlib import resources
from typing import Any, Dict
import ray
import yaml
from dotenv import load_dotenv
from ray.data.block import Block
from ray.data.datasource.filename_provider import FilenameProvider
from graphgen.engine import Engine
from graphgen.operators import operators
from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger
# Populate os.environ from a local .env file (if one is found) before anything
# else reads configuration from the environment.
load_dotenv()

# Absolute path of the directory that contains this script; used as the
# declared default for --output_dir.
sys_path = os.path.abspath(os.path.dirname(__file__))
def set_working_dir(folder):
    """Ensure the directory *folder* exists, creating missing parents.

    Does nothing when *folder* is already a directory. Raises
    ``FileExistsError`` if *folder* exists but is not a directory.
    """
    if os.path.isdir(folder):
        return
    os.makedirs(folder, exist_ok=True)
def save_config(config_path, global_config):
    """Write *global_config* to *config_path* as a YAML document.

    Missing parent directories are created first. A bare filename (no
    directory component) is written relative to the current working directory.

    Args:
        config_path: Destination file path.
        global_config: Object to serialize. It is dumped in block style
            (``default_flow_style=False``) with non-ASCII characters kept
            verbatim (``allow_unicode=True``) and encoded as UTF-8.
    """
    parent_dir = os.path.dirname(config_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises, so only
    # create the parent when there is one. exist_ok=True also avoids the
    # check-then-create race of a separate os.path.exists() test.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as config_file:
        yaml.dump(
            global_config, config_file, default_flow_style=False, allow_unicode=True
        )
class NodeFilenameProvider(FilenameProvider):
    """Ray Data filename provider that prefixes every output file with a node id.

    Passed to ``Dataset.write_json`` so that files written for different
    pipeline nodes can be told apart by name.
    """

    def __init__(self, node_id: str):
        # Identifier of the pipeline node whose results are being written.
        self.node_id = node_id

    def get_filename_for_block(
        self, block: Block, write_uuid: str, task_index: int, block_index: int
    ) -> str:
        """Return the file name for one written block.

        Format: ``{node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl``.
        The extension is ``.jsonl`` because the records are written one JSON
        object per line.
        """
        return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"

    def get_filename_for_row(
        self,
        row: Dict[str, Any],
        write_uuid: str,
        task_index: int,
        block_index: int,
        row_index: int,
    ) -> str:
        """Not supported: this provider only names whole blocks.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Row-based filenames are not supported by write_json. "
            f"Node: {self.node_id}, write_uuid: {write_uuid}"
        )
def main():
    """Run the GraphGen pipeline from the command line.

    Loads the YAML config named by ``--config_file`` and sets up the driver
    logger. It then executes the configured operators with ``Engine`` and
    writes each result node's dataset as JSON Lines under
    ``<output_dir>/output/<unique_id>/<node_id>/``. Finally it saves a copy
    of the config to ``<output_dir>/output/<unique_id>/config.yaml``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        help="Config parameters for GraphGen.",
        # argparse applies ``type`` only to *string* defaults, so convert the
        # packaged resource path explicitly to keep args.config_file a str.
        default=str(
            resources.files("graphgen")
            .joinpath("configs")
            .joinpath("aggregated_config.yaml")
        ),
        type=str,
    )
    parser.add_argument(
        "--output_dir",
        help="Output directory for GraphGen.",
        # NOTE(review): because required=True, this default is never used.
        default=sys_path,
        required=True,
        type=str,
    )
    args = parser.parse_args()
    working_dir = args.output_dir

    with open(args.config_file, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader had code-execution issues in PyYAML < 5.4;
        # prefer yaml.safe_load if config files may come from untrusted sources.
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Second-resolution timestamp; names this run's output folder.
    unique_id = int(time.time())
    output_path = os.path.join(working_dir, "output", f"{unique_id}")
    set_working_dir(output_path)

    log_path = os.path.join(working_dir, "logs", "Driver.log")
    # Make sure the log directory exists before the logger is created
    # (a harmless no-op if set_logger already creates it).
    set_working_dir(os.path.dirname(log_path))
    driver_logger = set_logger(
        log_path,
        name="GraphGen",
        if_stream=True,
    )
    CURRENT_LOGGER_VAR.set(driver_logger)

    logger.info(
        "GraphGen with unique ID %s logging to %s",
        unique_id,
        log_path,
    )

    engine = Engine(config, operators)
    ds = ray.data.from_items([])
    results = engine.execute(ds)

    for node_id, dataset in results.items():
        # Keep the run root in output_path. Reassigning it here would nest each
        # node's folder inside the previous node's folder and misplace config.yaml.
        node_output_path = os.path.join(output_path, f"{node_id}")
        os.makedirs(node_output_path, exist_ok=True)
        dataset.write_json(
            node_output_path,
            filename_provider=NodeFilenameProvider(node_id),
            # One JSON object per line, keeping non-ASCII characters verbatim.
            pandas_json_args_fn=lambda: {
                "force_ascii": False,
                "orient": "records",
                "lines": True,
            },
        )
        logger.info("Node %s results saved to %s", node_id, node_output_path)

    save_config(os.path.join(output_path, "config.yaml"), config)
    logger.info("GraphGen completed successfully. Data saved to %s", output_path)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()