fashxp committed on
Commit
ade1b4d
·
1 Parent(s): 264e02e

some improvements

Browse files
src/image_classification/image_classification_trainer.py CHANGED
@@ -52,6 +52,9 @@ class ImageClassificationTrainer(AbstractTrainer):
52
 
53
  self.__train_model(images, parameters)
54
 
 
 
 
55
  self.get_status().update_status(100, "Training completed")
56
 
57
  except Exception as e:
 
52
 
53
  self.__train_model(images, parameters)
54
 
55
+ if(self.get_status().is_training_aborted()):
56
+ return
57
+
58
  self.get_status().update_status(100, "Training completed")
59
 
60
  except Exception as e:
src/main.py CHANGED
@@ -75,7 +75,7 @@ async def get_task_status(token_data: dict = Depends(verify_token)):
75
  "status": status.get_status().value
76
  }
77
 
78
- @app.get("/stop_training")
79
  async def stop_task(token_data: dict = Depends(verify_token)):
80
  """ Stop the currently running training (if any). """
81
  try:
 
75
  "status": status.get_status().value
76
  }
77
 
78
+ @app.put("/stop_training")
79
  async def stop_task(token_data: dict = Depends(verify_token)):
80
  """ Stop the currently running training (if any). """
81
  try:
src/text_classification/text_classification_trainer.py CHANGED
@@ -44,6 +44,9 @@ class TextClassificationTrainer(AbstractTrainer):
44
 
45
  self.__train_model(tokenized_dataset, labels, label2id, id2label, parameters)
46
 
 
 
 
47
  self.get_status().update_status(100, "Training completed")
48
 
49
  except Exception as e:
@@ -66,6 +69,18 @@ class TextClassificationTrainer(AbstractTrainer):
66
  dataset = load_dataset('csv', data_files=parameters.get_training_csv_file_path(), delimiter=parameters.get_training_csv_limiter())
67
 
68
  dataset = dataset["train"]
 
 
 
 
 
 
 
 
 
 
 
 
69
  dataset = dataset.train_test_split(test_size=0.2)
70
 
71
  logger.info(dataset)
@@ -79,15 +94,6 @@ class TextClassificationTrainer(AbstractTrainer):
79
 
80
  tokenized_dataset = dataset.map(preprocess_function, batched=True)
81
 
82
- # Extract the labels
83
- labels = tokenized_dataset['train'].unique('target')
84
- label2id, id2label = dict(), dict()
85
- for i, label in enumerate(labels):
86
- label2id[label] = i
87
- id2label[i] = label
88
-
89
- logger.info(id2label)
90
-
91
  # Rename the Target column to labels and remove unnecessary columns
92
  tokenized_dataset = tokenized_dataset.rename_column('target', 'labels')
93
 
 
44
 
45
  self.__train_model(tokenized_dataset, labels, label2id, id2label, parameters)
46
 
47
+ if(self.get_status().is_training_aborted()):
48
+ return
49
+
50
  self.get_status().update_status(100, "Training completed")
51
 
52
  except Exception as e:
 
69
  dataset = load_dataset('csv', data_files=parameters.get_training_csv_file_path(), delimiter=parameters.get_training_csv_limiter())
70
 
71
  dataset = dataset["train"]
72
+
73
+ # Extract the labels
74
+ #labels = tokenized_dataset['train'].unique('target')
75
+ labels = dataset.unique('target')
76
+ label2id, id2label = dict(), dict()
77
+ for i, label in enumerate(labels):
78
+ label2id[label] = i
79
+ id2label[i] = label
80
+
81
+ logger.info(id2label)
82
+
83
+
84
  dataset = dataset.train_test_split(test_size=0.2)
85
 
86
  logger.info(dataset)
 
94
 
95
  tokenized_dataset = dataset.map(preprocess_function, batched=True)
96
 
 
 
 
 
 
 
 
 
 
97
  # Rename the Target column to labels and remove unnecessary columns
98
  tokenized_dataset = tokenized_dataset.rename_column('target', 'labels')
99