File size: 3,203 Bytes
58ccbe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import math
from xml.dom import minidom
from xml.etree import ElementTree as ET

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

from .c45_utils import decision, grow_tree

class C45(BaseEstimator, ClassifierMixin):
    """A C4.5 tree classifier.

    Parameters
    ----------
    attrNames : list of str, optional (default=None)
        The feature names used to label attributes in the tree. When the
        model is fitted, every non-alphanumeric character (spaces included)
        is stripped from each name. If left default, attributes will be
        named attr0, attr1... etc.

    Attributes
    ----------
    X_ : array
        The training samples as returned by ``check_X_y``.
    y_ : array
        The training labels as returned by ``check_X_y``.
    resultType : type
        Type of the first training label; predictions are converted back
        to this type.
    attrNames_ : list of str
        The attribute names actually used for the fitted tree (sanitised
        ``attrNames`` or the generated ``attrN`` names).
    tree_ : str
        XML serialisation of the fitted decision tree.

    Notes
    -----
    The constructor stores ``attrNames`` untouched. scikit-learn's
    ``clone`` (used by ``cross_val_score``, ``GridSearchCV``...) refuses
    estimators whose ``__init__`` modifies or copies a parameter, so all
    name processing is deferred to ``fit``.

    See also
    --------
    DecisionTreeClassifier

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Decision_tree_learning
    .. [2] https://en.wikipedia.org/wiki/C4.5_algorithm
    .. [3] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification
           and Regression Trees", Wadsworth, Belmont, CA, 1984.
    .. [4] J. R. Quinlan, "C4.5: Programs for Machine Learning",
           Morgan Kaufmann Publishers, 1993

    Examples
    --------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.model_selection import cross_val_score
    >>> from c45 import C45
    >>> iris = load_iris()
    >>> clf = C45(attrNames=iris.feature_names)
    >>> cross_val_score(clf, iris.data, iris.target, cv=10)
    ...                             # doctest: +SKIP
    ...
    array([ 1.     ,  0.93...,  0.86...,  0.93...,  0.93...,
            0.93...,  0.93...,  1.     ,  0.93...,  1.      ])
    """

    def __init__(self, attrNames=None):
        # Keep the parameter exactly as passed (same object) so that
        # sklearn.base.clone() and get_params()/set_params() round-trip.
        self.attrNames = attrNames

    def fit(self, X, y):
        """Build the decision tree from the training set (X, y).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The training samples; validated with ``check_X_y``.
        y : array-like of shape (n_samples,)
            The class labels; validated with ``check_X_y``.

        Returns
        -------
        self : C45
            The fitted estimator.

        Raises
        ------
        ValueError
            If the number of attribute names differs from the number of
            features in ``X``.
        """
        X, y = check_X_y(X, y)
        n_features = X.shape[1]

        # Resolve the names into a local list instead of overwriting the
        # constructor parameter, so the estimator can be refitted on data
        # with a different number of features and cloned after fitting.
        if self.attrNames is None:
            attr_names = [f'attr{idx}' for idx in range(n_features)]
        else:
            # Only alphanumeric characters are kept; spaces are dropped too.
            attr_names = [
                ''.join(ch for ch in name if ch.isalnum())
                for name in self.attrNames
            ]

        # Explicit exception rather than `assert`, which vanishes under -O.
        if len(attr_names) != n_features:
            raise ValueError(
                f"attrNames has {len(attr_names)} entries but X has "
                f"{n_features} features"
            )

        self.X_ = X
        self.y_ = y
        self.resultType = type(y[0])
        self.attrNames_ = attr_names

        # grow_tree receives the data column-wise (one list per attribute)
        # and the labels as strings.
        data = [list(column) for column in X.T]
        categories = [str(label) for label in y]

        root = ET.Element('DecisionTree')
        grow_tree(data, categories, root, attr_names)
        self.tree_ = ET.tostring(root, encoding="unicode")
        return self

    def predict(self, X):
        """Predict a class label for every sample in X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input samples; validated with ``check_array``.

        Returns
        -------
        prediction : list
            One label per sample, each converted to ``resultType``.
        """
        check_is_fitted(self, ['tree_', 'resultType', 'attrNames_'])
        X = check_array(X)
        dom = minidom.parseString(self.tree_)
        root = dom.childNodes[0]
        prediction = []
        for sample in X:
            # decision() yields a mapping of candidate label -> score
            # (presumably a class weight; see c45_utils).
            scores = decision(root, sample, self.attrNames_, 1)
            # max() returns the first highest-scoring label, matching a
            # stable descending sort followed by taking element 0.
            answer = max(scores, key=scores.get)
            prediction.append(self.resultType(answer))
        return prediction

    def printTree(self):
        """Print the fitted tree to stdout as pretty-printed XML."""
        check_is_fitted(self, ['tree_'])
        dom = minidom.parseString(self.tree_)
        print(dom.toprettyxml(newl="\r\n"))