# Source provenance (HuggingFace upload-page residue, kept as comments so the file parses):
# LCA-PORVID's picture
# Upload 34 files
# ebdb5af verified
import torch
import os
import time
from pt_variety_identifier.src.utils import setup_logger, create_output_dir
from pt_variety_identifier.src.bert.data import Data
from tqdm import tqdm
from pt_variety_identifier.src.tunning import Tunning
from pt_variety_identifier.src.bert.trainer import Trainer
from pt_variety_identifier.src.bert.tester import Tester
from pt_variety_identifier.src.bert.results import Results
from pt_variety_identifier.src.bert.model import EnsembleIdentfier, LanguageIdentfier
import torch.multiprocessing as mp
from threading import Thread
import logging
import numpy as np
class Run:
    """Orchestrates tuning, training and testing of PT language-variety identifiers."""

    def __init__(self, dataset_name, tokenizer_name, model_name, batch_size, test_set_list) -> None:
        # Anchor all run artifacts to this file's directory plus a per-run timestamp.
        self.CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
        self.CURRENT_TIME = int(time.time())
        # One semaphore slot per visible GPU; workers pop/return indices via gpus_free.
        self.num_gpus = torch.cuda.device_count()
        self.sem = mp.Semaphore(self.num_gpus)
        self.gpus_free = list(range(self.num_gpus))
        self.test_set_list = test_set_list
        create_output_dir(self.CURRENT_PATH, self.CURRENT_TIME)
        setup_logger(self.CURRENT_PATH, self.CURRENT_TIME)
        self.data = Data(
            dataset_name,
            tokenizer_name=tokenizer_name,
            batch_size=batch_size,
            test_set_list=test_set_list,
        )
        self._DOMAINS = ['literature', 'legal', 'politics', 'web', 'social_media', 'journalistic']
        self.model_name = model_name
        # Enables pandas .progress_apply progress bars used downstream.
        tqdm.pandas()
def tune_with_gpu(self):
threads = []
for pos_prob in tqdm(range(np.arange(0.0, 1.0, 0.1))):
for ner_prob in tqdm(range(np.arange(0.0, 1.0, 0.2))):
pos_prob = round(pos_prob, 2)
ner_prob = round(ner_prob, 2)
self.sem.acquire()
gpu_in_use = self.gpus_free.pop()
tuner = Tunning(self.data, self._DOMAINS,
Results, Trainer, Tester, 5_000,
self.CURRENT_PATH, self.CURRENT_TIME,
params={
'epochs': 30,
'early_stoping': 5,
'model_name': self.model_name,
'device': f"cuda:{gpu_in_use}",
'sem': self.sem,
'gpus_free': self.gpus_free,
})
thread = Thread(target=tuner.run, args=(
pos_prob, pos_prob, ner_prob, ner_prob), daemon=True
)
threads.append(thread)
for t in threads:
t.join()
def tune_with_cpu(self):
tuner = Tunning(self.data, self._DOMAINS,
Results, Trainer, Tester, 5_000,
self.CURRENT_PATH, self.CURRENT_TIME,
params={
'epochs': 30,
'early_stoping': 5,
'model_name': self.model_name,
'device': 'cpu',
})
tuner.run()
def tune(self):
if torch.cuda.is_available():
return self.tune_with_gpu()
return self.tune_with_cpu()
def _train_domain(self, domain, gpu):
logging.info(f"Training {domain} domain")
data = self.data.load_domain(domain, balance=True, pos_prob=None, ner_prob=None)
validation_dataset_dict = self.data.load_validation_set()
"""
logging.info(f"Removing non training domains from validation set")
validation_dataset_dict = {
domain: validation_dataset_dict[domain]
}
"""
trainer = Trainer(data, params={
'epochs': 30,
'early_stoping': 5,
'model_name': self.model_name,
'device': gpu,
'CURRENT_PATH': self.CURRENT_PATH,
'CURRENT_TIME': self.CURRENT_TIME,
'training_domain': domain,
},validation_dataset_dict=validation_dataset_dict)
best_results = trainer.train()
logging.info(f"Best results for {domain} domain: {best_results}")
logging.info(f"Freeing cuda:{gpu[-1]}")
self.gpus_free.append(gpu[-1])
return self.sem.release()
def train(self):
threads = []
for domain in ['all']:
self.sem.acquire()
gpu_in_use = self.gpus_free.pop()
thread = Thread(target=self._train_domain, args=(domain, f"cuda:{gpu_in_use}"), daemon=True)
threads.append(thread)
thread.start()
for t in threads:
t.join()
def test(self):
model = LanguageIdentfier(self.model_name)
logging.info(f"Loading model from {os.path.join(self.CURRENT_PATH, 'out', str(self.CURRENT_TIME), 'models', 'all.pt')}")
model.load_state_dict(torch.load(os.path.join(self.CURRENT_PATH, "out", str(self.CURRENT_TIME), "models", "all.pt")))
model.eval()
model.to('cuda')
data = self.data.load_test_set(filter_label_2=True)
tester = Tester(data, model, None)
results = tester.validate()
logging.info(f"Results for all: {results}")
def test_ensemble(self):
data = self.data.load_test_set(filter_label_2=True)
ensemble = EnsembleIdentfier(os.path.join(self.CURRENT_PATH, "out", str(self.CURRENT_TIME), "models"), self.model_name)
tester = Tester(data, ensemble, None)
results = tester.test()
logging.info(f"Results for ensemble: {results}")