| import torch | |
| import os | |
| import time | |
| from pt_variety_identifier.src.utils import setup_logger, create_output_dir | |
| from pt_variety_identifier.src.bert.data import Data | |
| from tqdm import tqdm | |
| from pt_variety_identifier.src.tunning import Tunning | |
| from pt_variety_identifier.src.bert.trainer import Trainer | |
| from pt_variety_identifier.src.bert.tester import Tester | |
| from pt_variety_identifier.src.bert.results import Results | |
| from pt_variety_identifier.src.bert.model import EnsembleIdentfier, LanguageIdentfier | |
| import torch.multiprocessing as mp | |
| from threading import Thread | |
| import logging | |
| import numpy as np | |
class Run:
    """Orchestrate tuning, training and testing of the BERT variety identifier.

    Construction records this module's directory and a Unix-timestamp run id,
    calls ``create_output_dir`` / ``setup_logger`` with them, and builds the
    shared ``Data`` loader. GPU work is spread across the visible CUDA devices:
    ``self.sem`` holds one permit per device and ``self.gpus_free`` holds the
    integer indices of devices not currently in use.
    """

    def __init__(self, dataset_name, tokenizer_name, model_name, batch_size, test_set_list) -> None:
        """Set up output/logging for this run and the shared data loader.

        Args:
            dataset_name: Dataset identifier forwarded to ``Data``.
            tokenizer_name: Tokenizer identifier forwarded to ``Data``.
            model_name: Model identifier forwarded to trainers, tuners and models.
            batch_size: Batch size forwarded to ``Data``.
            test_set_list: Test sets forwarded to ``Data``; also kept on ``self``.
        """
        self.CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
        self.CURRENT_TIME = int(time.time())
        self.num_gpus = torch.cuda.device_count()
        # One permit per visible GPU; whoever acquires a permit also pops an
        # index from gpus_free and must give both back when done.
        self.sem = mp.Semaphore(self.num_gpus)
        self.gpus_free = list(range(self.num_gpus))
        self.test_set_list = test_set_list
        create_output_dir(self.CURRENT_PATH, self.CURRENT_TIME)
        setup_logger(self.CURRENT_PATH, self.CURRENT_TIME)
        self.data = Data(
            dataset_name, tokenizer_name=tokenizer_name, batch_size=batch_size, test_set_list=test_set_list)
        self._DOMAINS = ['literature', 'legal', 'politics', 'web', 'social_media', 'journalistic']
        self.model_name = model_name
        tqdm.pandas()

    @staticmethod
    def _probability_grid(start, stop, step):
        """Return ``np.arange(start, stop, step)`` as plain floats rounded to 2 decimals.

        Rounding removes float noise such as ``0.6000000000000001``.
        """
        return [round(float(p), 2) for p in np.arange(start, stop, step)]

    @staticmethod
    def _gpu_index(device):
        """Return the integer index of a ``"cuda:N"`` device string, else ``None``.

        Raises:
            ValueError: if the part after ``"cuda:"`` is not an integer.
        """
        prefix, sep, index = device.partition(":")
        if prefix != "cuda" or not sep:
            return None
        return int(index)

    def _release_device(self, device):
        """Return a CUDA device to the free pool and release its semaphore permit.

        Non-CUDA devices (e.g. ``"cpu"``) hold no permit, so nothing happens.
        """
        gpu_index = self._gpu_index(device)
        if gpu_index is None:
            return
        logging.info(f"Freeing cuda:{gpu_index}")
        self.gpus_free.append(gpu_index)
        self.sem.release()

    def tune_with_gpu(self):
        """Grid-search POS/NER probabilities with one ``Tunning`` thread per free GPU.

        POS probabilities are 0.0..0.9 (step 0.1) and NER probabilities are
        0.0..0.8 (step 0.2). For each pair, a GPU permit is acquired, a GPU index
        is popped from ``self.gpus_free``, and ``tuner.run`` starts on a daemon
        thread. The semaphore and free list are handed to the tuner through
        ``params``. Blocks until every tuning thread has finished.
        """
        threads = []
        # Hoisted: the grids do not depend on the loop variables.
        pos_grid = self._probability_grid(0.0, 1.0, 0.1)
        ner_grid = self._probability_grid(0.0, 1.0, 0.2)
        for pos_prob in tqdm(pos_grid):
            for ner_prob in tqdm(ner_grid):
                # NOTE(review): this only progresses past num_gpus iterations if
                # Tunning.run releases self.sem and returns its GPU to
                # gpus_free — TODO confirm in Tunning.
                self.sem.acquire()
                gpu_in_use = self.gpus_free.pop()
                tuner = Tunning(self.data, self._DOMAINS,
                                Results, Trainer, Tester, 5_000,
                                self.CURRENT_PATH, self.CURRENT_TIME,
                                params={
                                    'epochs': 30,
                                    'early_stoping': 5,
                                    'model_name': self.model_name,
                                    'device': f"cuda:{gpu_in_use}",
                                    'sem': self.sem,
                                    'gpus_free': self.gpus_free,
                                })
                thread = Thread(target=tuner.run, args=(
                    pos_prob, pos_prob, ner_prob, ner_prob), daemon=True
                )
                threads.append(thread)
                # Threads must be started; joining an unstarted thread raises,
                # and the permit would never come back.
                thread.start()
        for t in threads:
            t.join()

    def tune_with_cpu(self):
        """Run a single ``Tunning`` pass on the CPU.

        ``tuner.run()`` is called with no arguments, so no POS/NER grid is
        swept here. Presumably Tunning.run has defaults for them — TODO confirm.
        """
        tuner = Tunning(self.data, self._DOMAINS,
                        Results, Trainer, Tester, 5_000,
                        self.CURRENT_PATH, self.CURRENT_TIME,
                        params={
                            'epochs': 30,
                            'early_stoping': 5,
                            'model_name': self.model_name,
                            'device': 'cpu',
                        })
        tuner.run()

    def tune(self):
        """Tune on the GPUs when CUDA is available, otherwise on the CPU."""
        if torch.cuda.is_available():
            return self.tune_with_gpu()
        return self.tune_with_cpu()

    def _train_domain(self, domain, gpu):
        """Train one model on ``domain`` using device ``gpu``, then free the device.

        Args:
            domain: Domain name passed to ``Data.load_domain`` and to the trainer
                as ``training_domain``.
            gpu: Torch device string, e.g. ``"cuda:0"`` or ``"cpu"``.

        A CUDA device is always returned to ``self.gpus_free``, and its permit
        is always released, even if loading or training raises. Otherwise
        other waiters would deadlock.
        """
        logging.info(f"Training {domain} domain")
        try:
            data = self.data.load_domain(domain, balance=True, pos_prob=None, ner_prob=None)
            # The validation set is used for all domains, not just `domain`.
            validation_dataset_dict = self.data.load_validation_set()
            trainer = Trainer(data, params={
                'epochs': 30,
                'early_stoping': 5,
                'model_name': self.model_name,
                'device': gpu,
                'CURRENT_PATH': self.CURRENT_PATH,
                'CURRENT_TIME': self.CURRENT_TIME,
                'training_domain': domain,
            }, validation_dataset_dict=validation_dataset_dict)
            best_results = trainer.train()
            logging.info(f"Best results for {domain} domain: {best_results}")
        finally:
            self._release_device(gpu)

    def train(self):
        """Train the ``'all'`` domain model and wait for it to finish.

        With GPUs available, each domain is trained on its own daemon thread
        with a dedicated GPU. Without GPUs, the semaphore has zero permits and
        ``acquire()`` would block forever, so training runs sequentially on the CPU.
        """
        domains = ['all']
        if self.num_gpus == 0:
            for domain in domains:
                self._train_domain(domain, "cpu")
            return
        threads = []
        for domain in domains:
            self.sem.acquire()
            gpu_in_use = self.gpus_free.pop()
            thread = Thread(target=self._train_domain, args=(domain, f"cuda:{gpu_in_use}"), daemon=True)
            threads.append(thread)
            thread.start()
        for t in threads:
            t.join()

    def test(self):
        """Evaluate the saved ``all.pt`` model on the test set and log the results.

        Loads ``<CURRENT_PATH>/out/<CURRENT_TIME>/models/all.pt`` into a
        ``LanguageIdentfier``. Uses CUDA when available, otherwise the CPU.
        """
        model_path = os.path.join(self.CURRENT_PATH, "out", str(self.CURRENT_TIME), "models", "all.pt")
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = LanguageIdentfier(self.model_name)
        logging.info(f"Loading model from {model_path}")
        # map_location lets a checkpoint saved on any GPU load on this machine.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        model.to(device)
        data = self.data.load_test_set(filter_label_2=True)
        tester = Tester(data, model, None)
        results = tester.validate()
        logging.info(f"Results for all: {results}")

    def test_ensemble(self):
        """Evaluate an ``EnsembleIdentfier`` built from this run's models directory.

        The ensemble is constructed from ``<CURRENT_PATH>/out/<CURRENT_TIME>/models``.
        It is evaluated with ``Tester.test`` (not ``validate``), and the results
        are logged.
        """
        data = self.data.load_test_set(filter_label_2=True)
        ensemble = EnsembleIdentfier(os.path.join(self.CURRENT_PATH, "out", str(self.CURRENT_TIME), "models"), self.model_name)
        tester = Tester(data, ensemble, None)
        results = tester.test()
        logging.info(f"Results for ensemble: {results}")