ahmadardhy's picture
Upload 9 files
58ccbe8 verified
import math
from xml.dom import minidom
from xml.etree import ElementTree as ET
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
from .c45_utils import decision, grow_tree
class C45(BaseEstimator, ClassifierMixin):
    """A C4.5 tree classifier.

    Parameters
    ----------
    attrNames : list of str, optional (default=None)
        Feature names handed to the tree builder and shown when printing
        the tree. Every non-alphanumeric character (spaces included) is
        dropped from each name before it is used. If left as None, the
        attributes are named attr0, attr1, ... etc.

    Attributes
    ----------
    tree_ : str
        The fitted decision tree serialised as an XML string.
    attrNames_ : list of str
        The attribute names actually used for the fitted tree (sanitised
        ``attrNames`` or the generated ``attrN`` names).
    resultType : type
        Type of the first training label; predictions are converted back
        to this type.
    X_, y_ : ndarray
        The validated training data and labels.

    See also
    --------
    DecisionTreeClassifier

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Decision_tree_learning
    .. [2] https://en.wikipedia.org/wiki/C4.5_algorithm
    .. [3] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification
           and Regression Trees", Wadsworth, Belmont, CA, 1984.
    .. [4] J. R. Quinlan, "C4.5: Programs for Machine Learning",
           Morgan Kaufmann Publishers, 1993

    Examples
    --------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.model_selection import cross_val_score
    >>> from c45 import C45
    >>> iris = load_iris()
    >>> clf = C45(attrNames=iris.feature_names)
    >>> cross_val_score(clf, iris.data, iris.target, cv=10)
    ...                             # doctest: +SKIP
    ...
    array([ 1.     ,  0.93...,  0.86...,  0.93...,  0.93...,
            0.93...,  0.93...,  1.     ,  0.93...,  1.      ])
    """

    def __init__(self, attrNames=None):
        # Stored untouched: scikit-learn's clone() rejects estimators whose
        # constructor alters a parameter, so name clean-up happens in fit().
        self.attrNames = attrNames

    @staticmethod
    def _sanitize_name(name):
        """Return *name* with every non-alphanumeric character removed."""
        return ''.join(ch for ch in name if ch.isalnum())

    def fit(self, X, y):
        """Build the decision tree from the training set (X, y).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples.
        y : array-like of shape (n_samples,)
            Class labels; they are converted to strings for tree building.

        Returns
        -------
        self : C45
            The fitted estimator.

        Raises
        ------
        ValueError
            If ``attrNames`` was given and its length differs from the
            number of features in X (or if X / y fail validation).
        """
        X, y = check_X_y(X, y)
        self.X_ = X
        self.y_ = y
        # Remembered so predict() can turn the string labels stored in the
        # tree back into the caller's label type.
        self.resultType = type(y[0])

        n_features = X.shape[1]
        if self.attrNames is None:
            names = [f'attr{idx}' for idx in range(n_features)]
        else:
            names = [self._sanitize_name(name) for name in self.attrNames]
        if len(names) != n_features:
            raise ValueError(
                f"attrNames has {len(names)} entries but X has "
                f"{n_features} features"
            )
        self.attrNames_ = names

        # The tree builder takes the data column-wise: one list per attribute.
        data = [list(column) for column in X.T]
        categories = [str(label) for label in y]

        root = ET.Element('DecisionTree')
        # grow_tree's result is read back from `root`, so it is expected to
        # populate that element in place.
        grow_tree(data, categories, root, names)
        self.tree_ = ET.tostring(root, encoding="unicode")
        return self

    def predict(self, X):
        """Predict a class label for every sample in X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to classify.

        Returns
        -------
        prediction : list
            One label per sample, converted to ``resultType``.
        """
        check_is_fitted(self, ['tree_', 'resultType'])
        X = check_array(X)

        # Estimators fitted (or pickled) before `attrNames_` existed kept
        # their resolved names in `attrNames`; fall back to that.
        names = getattr(self, 'attrNames_', None)
        if names is None:
            names = self.attrNames

        dom = minidom.parseString(self.tree_)
        root = dom.childNodes[0]

        prediction = []
        for sample in X:
            # decision() returns a mapping of candidate label -> score; the
            # meaning of the trailing 1 is defined by that helper.
            votes = decision(root, sample, names, 1)
            # Highest score wins; on ties the first-inserted label is kept,
            # matching a stable descending sort.
            answer = max(votes, key=votes.get)
            prediction.append(self.resultType(answer))
        return prediction

    def printTree(self):
        """Pretty-print the fitted tree's XML to standard output."""
        check_is_fitted(self, ['tree_'])
        dom = minidom.parseString(self.tree_)
        print(dom.toprettyxml(newl="\r\n"))