luisejdm commited on
Commit
e7ccdab
·
verified ·
1 Parent(s): bf21beb

Update data_generation.py

Browse files
Files changed (1) hide show
  1. data_generation.py +13 -3
data_generation.py CHANGED
@@ -1,6 +1,16 @@
 
1
  import pandas as pd
2
  from sdv.single_table import CTGANSynthesizer
3
 
 
 
 
 
 
 
 
 
 
4
 
5
  def generate_synthetic_training_data(n=30_000):
6
  """Generates synthetic training data using pre-trained CTGAN models for each credit score category.
@@ -10,9 +20,9 @@ def generate_synthetic_training_data(n=30_000):
10
  Returns:
11
  pd.DataFrame: The generated synthetic training data.
12
  """
13
- good_generator = CTGANSynthesizer.load("models/v4/synth_good.pkl")
14
- poor_generator = CTGANSynthesizer.load("models/v4/synth_poor.pkl")
15
- standard_generator = CTGANSynthesizer.load("models/v4/synth_standard.pkl")
16
 
17
  synth_good = good_generator.sample(n)
18
  synth_poor = poor_generator.sample(n)
 
1
+ import torch
2
  import pandas as pd
3
  from sdv.single_table import CTGANSynthesizer
4
 
5
+ # Patch torch.load to remap MPS tensors to CPU for environments without Apple Silicon
6
+ _original_torch_load = torch.load
7
+
8
+ def _cpu_map_load(*args, **kwargs):
9
+ kwargs.setdefault('map_location', 'cpu')
10
+ return _original_torch_load(*args, **kwargs)
11
+
12
+ torch.load = _cpu_map_load
13
+
14
 
15
  def generate_synthetic_training_data(n=30_000):
16
  """Generates synthetic training data using pre-trained CTGAN models for each credit score category.
 
20
  Returns:
21
  pd.DataFrame: The generated synthetic training data.
22
  """
23
+ good_generator = CTGANSynthesizer.load("../models/v4/synth_good.pkl")
24
+ poor_generator = CTGANSynthesizer.load("../models/v4/synth_poor.pkl")
25
+ standard_generator = CTGANSynthesizer.load("../models/v4/synth_standard.pkl")
26
 
27
  synth_good = good_generator.sample(n)
28
  synth_poor = poor_generator.sample(n)