Sujatha commited on
Commit
192e408
·
verified ·
1 Parent(s): c8b62fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  import pandas as pd
3
  from pytorch_tabular import TabularModel
4
- from pytorch_tabular.config import DataConfig, ModelConfig, TrainerConfig
 
5
 
6
  # Sample Data
7
  data = {
@@ -17,7 +18,7 @@ data_config = DataConfig(
17
  target=["target"],
18
  continuous_cols=["feature1", "feature2", "feature3"]
19
  )
20
- model_config = ModelConfig(task="classification", num_classes=2)
21
  trainer_config = TrainerConfig(max_epochs=10)
22
 
23
  # Initialize and train model
 
1
  import gradio as gr
2
  import pandas as pd
3
  from pytorch_tabular import TabularModel
4
+ from pytorch_tabular.config import DataConfig, TrainerConfig
5
+ from pytorch_tabular.models import CategoryEmbeddingModelConfig
6
 
7
  # Sample Data
8
  data = {
 
18
  target=["target"],
19
  continuous_cols=["feature1", "feature2", "feature3"]
20
  )
21
+ model_config = CategoryEmbeddingModelConfig(task="classification") # No `num_classes`
22
  trainer_config = TrainerConfig(max_epochs=10)
23
 
24
  # Initialize and train model