Spaces:
Runtime error
Runtime error
File size: 1,595 Bytes
86756d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
"""
This module containes methods for words classification using desicion trees
"""
from __future__ import print_function
import os
import subprocess
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier, plot_tree
import graphviz
# ref: http://chrisstrelioff.ws/sandbox/2015/06/08/decision_trees_in_python_with_scikit_learn_and_pandas.html
input_file_path = 'text_words_labels.csv'
def get_data(input_file_path):
df = pd.read_csv(input_file_path)
return df
def encode_target(df, target_column):
"""Add column to df with integers for the target.
Args
----
df -- pandas DataFrame.
target_column -- column to map to int, producing
new Target column.
Returns
-------
df_mod -- modified DataFrame.
targets -- list of target names.
"""
df_mod = df.copy()
targets = df_mod[target_column].unique()
map_to_int = {name: n for n, name in enumerate(targets)}
df_mod["target"] = df_mod[target_column].replace(map_to_int)
return (df_mod, targets)
df = get_data(input_file_path)
df2, targets = encode_target(df, "target")
print("* df2.head()", df2[["target", "name"]].head(),
sep="\n", end="\n\n")
print("* df2.tail()", df2[["target", "name"]].tail(),
sep="\n", end="\n\n")
print("* targets", targets, sep="\n", end="\n\n")
features = [c for c in df2.columns.values if c != 'name' and c != 'isdefinite' and c != 'target']
y = df2["target"]
X = df2[features]
dt = DecisionTreeClassifier(min_samples_split=20, random_state=99)
dt.fit(X, y)
plot_tree(dt,max_depth=3)
|