Spaces:
Sleeping
Sleeping
some improvements
Browse files
src/image_classification/image_classification_trainer.py
CHANGED
@@ -52,6 +52,9 @@ class ImageClassificationTrainer(AbstractTrainer):
|
|
52 |
|
53 |
self.__train_model(images, parameters)
|
54 |
|
|
|
|
|
|
|
55 |
self.get_status().update_status(100, "Training completed")
|
56 |
|
57 |
except Exception as e:
|
|
|
52 |
|
53 |
self.__train_model(images, parameters)
|
54 |
|
55 |
+
if(self.get_status().is_training_aborted()):
|
56 |
+
return
|
57 |
+
|
58 |
self.get_status().update_status(100, "Training completed")
|
59 |
|
60 |
except Exception as e:
|
src/main.py
CHANGED
@@ -75,7 +75,7 @@ async def get_task_status(token_data: dict = Depends(verify_token)):
|
|
75 |
"status": status.get_status().value
|
76 |
}
|
77 |
|
78 |
-
@app.
|
79 |
async def stop_task(token_data: dict = Depends(verify_token)):
|
80 |
""" Stop the currently running training (if any). """
|
81 |
try:
|
|
|
75 |
"status": status.get_status().value
|
76 |
}
|
77 |
|
78 |
+
@app.put("/stop_training")
|
79 |
async def stop_task(token_data: dict = Depends(verify_token)):
|
80 |
""" Stop the currently running training (if any). """
|
81 |
try:
|
src/text_classification/text_classification_trainer.py
CHANGED
@@ -44,6 +44,9 @@ class TextClassificationTrainer(AbstractTrainer):
|
|
44 |
|
45 |
self.__train_model(tokenized_dataset, labels, label2id, id2label, parameters)
|
46 |
|
|
|
|
|
|
|
47 |
self.get_status().update_status(100, "Training completed")
|
48 |
|
49 |
except Exception as e:
|
@@ -66,6 +69,18 @@ class TextClassificationTrainer(AbstractTrainer):
|
|
66 |
dataset = load_dataset('csv', data_files=parameters.get_training_csv_file_path(), delimiter=parameters.get_training_csv_limiter())
|
67 |
|
68 |
dataset = dataset["train"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
dataset = dataset.train_test_split(test_size=0.2)
|
70 |
|
71 |
logger.info(dataset)
|
@@ -79,15 +94,6 @@ class TextClassificationTrainer(AbstractTrainer):
|
|
79 |
|
80 |
tokenized_dataset = dataset.map(preprocess_function, batched=True)
|
81 |
|
82 |
-
# Extract the labels
|
83 |
-
labels = tokenized_dataset['train'].unique('target')
|
84 |
-
label2id, id2label = dict(), dict()
|
85 |
-
for i, label in enumerate(labels):
|
86 |
-
label2id[label] = i
|
87 |
-
id2label[i] = label
|
88 |
-
|
89 |
-
logger.info(id2label)
|
90 |
-
|
91 |
# Rename the Target column to labels and remove unnecessary columns
|
92 |
tokenized_dataset = tokenized_dataset.rename_column('target', 'labels')
|
93 |
|
|
|
44 |
|
45 |
self.__train_model(tokenized_dataset, labels, label2id, id2label, parameters)
|
46 |
|
47 |
+
if(self.get_status().is_training_aborted()):
|
48 |
+
return
|
49 |
+
|
50 |
self.get_status().update_status(100, "Training completed")
|
51 |
|
52 |
except Exception as e:
|
|
|
69 |
dataset = load_dataset('csv', data_files=parameters.get_training_csv_file_path(), delimiter=parameters.get_training_csv_limiter())
|
70 |
|
71 |
dataset = dataset["train"]
|
72 |
+
|
73 |
+
# Extract the labels
|
74 |
+
#labels = tokenized_dataset['train'].unique('target')
|
75 |
+
labels = dataset.unique('target')
|
76 |
+
label2id, id2label = dict(), dict()
|
77 |
+
for i, label in enumerate(labels):
|
78 |
+
label2id[label] = i
|
79 |
+
id2label[i] = label
|
80 |
+
|
81 |
+
logger.info(id2label)
|
82 |
+
|
83 |
+
|
84 |
dataset = dataset.train_test_split(test_size=0.2)
|
85 |
|
86 |
logger.info(dataset)
|
|
|
94 |
|
95 |
tokenized_dataset = dataset.map(preprocess_function, batched=True)
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
# Rename the Target column to labels and remove unnecessary columns
|
98 |
tokenized_dataset = tokenized_dataset.rename_column('target', 'labels')
|
99 |
|