| import argparse |
| import logging |
| from typing import List, Optional |
|
|
| import pandas as pd |
|
|
| from tabgan.sampler import ( |
| OriginalGenerator, |
| GANGenerator, |
| ForestDiffusionGenerator, |
| LLMGenerator, |
| ) |
|
|
|
|
| def _parse_cat_cols(raw: Optional[str]) -> Optional[List[str]]: |
| if not raw: |
| return None |
| return [c.strip() for c in raw.split(",") if c.strip()] |
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Return the argument parser for the synthetic-data command line."""
    parser = argparse.ArgumentParser(
        description="Generate synthetic tabular data using tabgan samplers."
    )
    parser.add_argument(
        "--input-csv",
        required=True,
        help="Path to input CSV file containing training data (with or without target column).",
    )
    parser.add_argument(
        "--target-col",
        default=None,
        help="Name of the target column in the CSV (optional).",
    )
    parser.add_argument(
        "--test-csv",
        default=None,
        help=(
            "Optional path to a CSV of test features passed to the sampler as its "
            "test dataframe. Defaults to the training features."
        ),
    )
    parser.add_argument(
        "--output-csv",
        required=True,
        help="Path to write the generated synthetic dataset as CSV.",
    )
    parser.add_argument(
        "--generator",
        choices=["original", "gan", "diffusion", "llm"],
        default="gan",
        help="Which sampler to use for generation.",
    )
    parser.add_argument(
        "--gen-x-times",
        type=float,
        default=1.1,
        help="Factor controlling how many synthetic samples to generate relative to the training size.",
    )
    parser.add_argument(
        "--cat-cols",
        default=None,
        help="Comma-separated list of categorical column names (e.g. 'year,gender').",
    )
    parser.add_argument(
        "--only-generated",
        action="store_true",
        help="If set, output only synthetic rows instead of original + synthetic.",
    )
    return parser


def _attach_target(new_train: pd.DataFrame, new_target, target_col: str) -> pd.DataFrame:
    """Return a copy of *new_train* with *new_target* stored in column *target_col*.

    pandas objects are converted to their underlying array first, so the
    assignment is positional. Assigning a Series directly would align on the
    index and could silently produce NaNs when the indices differ.

    Raises:
        ValueError: (from pandas) if the target length does not match the
            number of rows in *new_train*.
    """
    out_df = new_train.copy()
    values = new_target.to_numpy() if hasattr(new_target, "to_numpy") else new_target
    if getattr(values, "ndim", 1) > 1:
        # A single-column target frame arrives as shape (n, 1); flatten it.
        values = values.ravel()
    out_df[target_col] = values
    return out_df


def main() -> None:
    """
    Command-line interface for generating synthetic tabular data with tabgan.

    Reads ``--input-csv``, optionally splits off ``--target-col``, runs the
    selected tabgan sampler, and writes the result to ``--output-csv``. When
    ``--target-col`` is given and the sampler returns a target, it is written
    back into the output under the same column name.

    Example:
        tabgan-generate \\
            --input-csv train.csv \\
            --target-col target \\
            --generator gan \\
            --gen-x-times 1.5 \\
            --cat-cols year,gender \\
            --output-csv synthetic_train.csv

    Raises:
        ValueError: If ``--target-col`` is not a column of the input CSV.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.gen_x_times <= 0:
        # parser.error prints usage and exits with status 2.
        parser.error("--gen-x-times must be a positive number.")

    logging.basicConfig(level=logging.INFO)

    logging.info("Reading input CSV from %s", args.input_csv)
    df = pd.read_csv(args.input_csv)

    target_df = None
    train_df = df
    if args.target_col is not None:
        if args.target_col not in df.columns:
            raise ValueError(f"Target column '{args.target_col}' not found in input CSV.")
        target_df = df[[args.target_col]]
        train_df = df.drop(columns=[args.target_col])

    if args.test_csv is not None:
        logging.info("Reading test CSV from %s", args.test_csv)
        test_df = pd.read_csv(args.test_csv)
        if args.target_col is not None:
            # Keep the test frame consistent with train_df, which has the target removed.
            test_df = test_df.drop(columns=[args.target_col], errors="ignore")
    else:
        # No separate test set supplied: reuse the training features.
        test_df = train_df

    cat_cols = _parse_cat_cols(args.cat_cols)

    generator_map = {
        "original": OriginalGenerator,
        "gan": GANGenerator,
        "diffusion": ForestDiffusionGenerator,
        "llm": LLMGenerator,
    }
    generator_cls = generator_map[args.generator]

    logging.info("Initializing %s generator", generator_cls.__name__)
    generator = generator_cls(
        gen_x_times=args.gen_x_times,
        cat_cols=cat_cols,
        only_generated_data=args.only_generated,
    )

    logging.info("Generating synthetic data...")
    new_train, new_target = generator.generate_data_pipe(train_df, target_df, test_df)

    if args.target_col is not None and new_target is not None:
        out_df = _attach_target(new_train, new_target, args.target_col)
    else:
        if args.target_col is not None:
            logging.warning(
                "Sampler returned no target; output will not contain column '%s'.",
                args.target_col,
            )
        out_df = new_train

    logging.info("Writing synthetic data to %s", args.output_csv)
    out_df.to_csv(args.output_csv, index=False)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this module is executed directly, not when imported.
    main()
|
|
|
|