Sadjad Alikhani commited on
Commit
6c2d844
·
verified ·
1 Parent(s): e1d11cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -84,18 +84,21 @@ def split_dataset(channels, labels, percentage_idx):
84
  def euclidean_distance(x, centroid):
85
  return np.linalg.norm(x - centroid)
86
 
87
- # Function to classify test data based on distance to class centroids
 
88
  def classify_based_on_distance(train_data, train_labels, test_data):
89
- centroid_0 = np.mean(train_data[train_labels == 0], axis=0)
90
- centroid_1 = np.mean(train_data[train_labels == 1], axis=0)
 
91
 
92
  predictions = []
93
  for test_point in test_data:
 
94
  dist_0 = euclidean_distance(test_point, centroid_0)
95
  dist_1 = euclidean_distance(test_point, centroid_1)
96
  predictions.append(0 if dist_0 < dist_1 else 1)
97
 
98
- return np.array(predictions)
99
 
100
  # Function to generate confusion matrix plot
101
  def plot_confusion_matrix(y_true, y_pred, title):
 
84
  def euclidean_distance(x, centroid):
85
  return np.linalg.norm(x - centroid)
86
 
87
+ import torch
88
+
89
  def classify_based_on_distance(train_data, train_labels, test_data):
90
+ # Compute the centroids for the two classes
91
+ centroid_0 = train_data[train_labels == 0].mean(dim=0) # Use torch.mean
92
+ centroid_1 = train_data[train_labels == 1].mean(dim=0) # Use torch.mean
93
 
94
  predictions = []
95
  for test_point in test_data:
96
+ # Compute Euclidean distance between the test point and each centroid
97
  dist_0 = euclidean_distance(test_point, centroid_0)
98
  dist_1 = euclidean_distance(test_point, centroid_1)
99
  predictions.append(0 if dist_0 < dist_1 else 1)
100
 
101
+ return torch.tensor(predictions) # Return predictions as a PyTorch tensor
102
 
103
  # Function to generate confusion matrix plot
104
  def plot_confusion_matrix(y_true, y_pred, title):