yzhuang commited on
Commit
b859b28
1 Parent(s): 597dc99

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +101 -0
README.md CHANGED
@@ -1,3 +1,104 @@
1
  ---
2
  license: mit
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ pipeline_tag: tabular-classification
4
  ---
5
+
6
+ <h1 align="center"> 🌲 MetaTree 🌲 </h1>
7
+ <p align="center"> <b>Learning a Decision Tree Algorithm with Transformers</b> (<a href="https://arxiv.org/abs/2402.03774">Zhuang et al. 2024</a>).
8
+ </p>
9
+
10
+ <p align="center"> MetaTree is a transformer-based decision tree algorithm. It learns from classical decision tree algorithms (greedy algorithm CART, optimal algorithm GOSDT), for better generalization capabilities.
11
+ </p>
12
+
13
+ ## Quickstart -- use MetaTree to generate decision tree models
14
+
15
+ Model is avaliable at https://huggingface.co/yzhuang/MetaTree
16
+
17
+ 1. Install `metatreelib`:
18
+
19
+ ```bash
20
+ pip install metatreelib
21
+ # Alternatively,
22
+ # clone then pip install -e .
23
+ # pip install git+https://github.com/EvanZhuang/MetaTree
24
+ ```
25
+
26
+ 2. Use MetaTree on your datasets to generate a decision tree model
27
+
28
+ ```python
29
+ from metatree.model_metatree import LlamaForMetaTree as MetaTree
30
+ from metatree.decision_tree_class import DecisionTree, DecisionTreeForest
31
+ from metatree.run_train import preprocess_dimension_patch
32
+ from transformers import AutoConfig
33
+ import imodels # pip install imodels
34
+
35
+ # Initialize Model
36
+ model_name_or_path = "yzhuang/MetaTree"
37
+
38
+ config = AutoConfig.from_pretrained(model_name_or_path)
39
+ model = MetaTree.from_pretrained(
40
+ model_name_or_path,
41
+ config=config,
42
+ )
43
+
44
+ # Load Datasets
45
+ X, y, feature_names = imodels.get_clean_dataset('fico', data_source='imodels')
46
+
47
+ print("Dataset Shapes X={}, y={}, Num of Classes={}".format(X.shape, y.shape, len(set(y))))
48
+
49
+ train_idx, test_idx = sklearn.model_selection.train_test_split(range(X.shape[0]), test_size=0.3, random_state=seed)
50
+
51
+ # Dimension Subsampling
52
+ feature_idx = np.random.choice(X.shape[1], 10, replace=False)
53
+ X = X[:, feature_idx]
54
+
55
+ test_X, test_y = X[test_idx], y[test_idx]
56
+
57
+ # Sample Train and Test Data
58
+ subset_idx = random.sample(train_idx, 256)
59
+ train_X, train_y = X[subset_idx], y[subset_idx]
60
+
61
+ input_x = torch.tensor(train_X, dtype=torch.float32)
62
+ input_y = torch.nn.functional.one_hot(torch.tensor(train_y)).float()
63
+
64
+ batch = {"input_x": input_x, "input_y": input_y, "input_y_clean": input_y}
65
+ batch = preprocess_dimension_patch(batch, n_feature=10, n_class=10)
66
+ model.depth = 2
67
+ outputs = model.generate_decision_tree(batch['input_x'], batch['input_y'], depth=model.depth)
68
+ decision_tree_forest.add_tree(DecisionTree(auto_dims=outputs.metatree_dimensions, auto_thresholds=outputs.tentative_splits, input_x=batch['input_x'], input_y=batch['input_y'], depth=model.depth))
69
+
70
+ print("Decision Tree Features: ", [x.argmax(dim=-1) for x in outputs.metatree_dimensions])
71
+ print("Decision Tree Threasholds: ", outputs.tentative_splits)
72
+ ```
73
+
74
+ 3. Inference with the decision tree model
75
+
76
+ ```python
77
+ tree_pred = decision_tree_forest.predict(torch.tensor(test_X, dtype=torch.float32))
78
+
79
+ accuracy = accuracy_score(test_y, tree_pred.argmax(dim=-1).squeeze(0))
80
+ print("MetaTree Test Accuracy: ", accuracy)
81
+ ```
82
+
83
+ ## Example Usage
84
+
85
+ We show a complete example of using MetaTree at [notebook](examples/example_usage.ipynb)
86
+
87
+ ## Questions?
88
+
89
+ If you have any questions related to the code or the paper, feel free to reach out to us at y5zhuang@ucsd.edu.
90
+
91
+
92
+ ## Citation
93
+
94
+ If you find our paper and code useful, please cite us:
95
+ ```r
96
+ @misc{zhuang2024learning,
97
+ title={Learning a Decision Tree Algorithm with Transformers},
98
+ author={Yufan Zhuang and Liyuan Liu and Chandan Singh and Jingbo Shang and Jianfeng Gao},
99
+ year={2024},
100
+ eprint={2402.03774},
101
+ archivePrefix={arXiv},
102
+ primaryClass={cs.LG}
103
+ }
104
+ ```