jcjurado commited on
Commit
678797a
1 Parent(s): be84ac9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -6,6 +6,7 @@ from sklearn import preprocessing
6
  from sklearn.model_selection import train_test_split
7
  from sklearn.ensemble import RandomForestClassifier
8
  from sklearn.neural_network import MLPClassifier
 
9
  from sklearn.metrics import accuracy_score
10
 
11
  data = pd.read_csv('https://raw.githubusercontent.com/gradio-app/titanic/master/train.csv')
@@ -47,9 +48,9 @@ num_test = 0.20
47
  X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=num_test, random_state=23)
48
 
49
  #clf = RandomForestClassifier()
50
- MLP = MLPClassifier()
51
- MLP.fit(X_train, y_train)
52
- predictions = MLP.predict(X_test)
53
 
54
  def predict_survival(sex, age, fare):
55
  df = pd.DataFrame.from_dict({'Sex': [sex], 'Age': [age], 'Fare': [fare]})
 
6
  from sklearn.model_selection import train_test_split
7
  from sklearn.ensemble import RandomForestClassifier
8
  from sklearn.neural_network import MLPClassifier
9
+ from sklearn.tree import DecisionTreeClassifier
10
  from sklearn.metrics import accuracy_score
11
 
12
  data = pd.read_csv('https://raw.githubusercontent.com/gradio-app/titanic/master/train.csv')
 
48
  X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=num_test, random_state=23)
49
 
50
  #clf = RandomForestClassifier()
51
+ tree = DecisionTreeClassifier()
52
+ tree.fit(X_train, y_train)
53
+ predictions = tree.predict(X_test)
54
 
55
  def predict_survival(sex, age, fare):
56
  df = pd.DataFrame.from_dict({'Sex': [sex], 'Age': [age], 'Fare': [fare]})