File size: 2,572 Bytes
42d6a0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 29 14:56:09 2022

@author: Santiago Moreno
"""
import os 
import argparse
from functions import json_to_txt, training_model, characterize_data, upsampling_data, str2bool, usage_cuda,copy_data
# Resolve the directory that contains this file and make it the working
# directory of the process.
default_path = os.path.split(os.path.abspath(__file__))[0]
os.chdir(default_path)

        

def _is_error_code(value):
    """Return True when *value* is an integer error code.

    The data-preparation and training helpers signal failure by returning an
    ``int``; any other return value is treated as success. ``bool`` is
    excluded explicitly because it is a subclass of ``int``.
    """
    return isinstance(value, int) and not isinstance(value, bool)


def _build_parser():
    """Create the command-line parser for the training script.

    The summary text is passed as ``description`` so that argparse still
    generates the real usage line (flags and placeholders) for ``--help`` and
    for error messages.
    """
    parser = argparse.ArgumentParser(
        add_help=True,
        description='Train a new model with given data (GPU optional)',
    )
    parser.add_argument(
        '-f', '--fast', type=str2bool, nargs='?', const=True, default=False,
        help='Training fast option (Only for functioning test)',
        choices=(True, False), required=False,
    )
    # NOTE(review): nargs='?' lets `-m` / `-id` be given without a value, in
    # which case the attribute is None -- confirm the helpers tolerate that.
    parser.add_argument(
        '-m', '--model', type=str, nargs='?',
        help='New model name', required=True,
    )
    parser.add_argument(
        '-s', '--standard', type=str2bool, nargs='?', const=True, default=False,
        help='Standard CONLL input or not',
        choices=(True, False), required=False,
    )
    parser.add_argument(
        '-id', '--input_dir', type=str, nargs='?',
        help='Absolute path input directory', required=True,
    )
    parser.add_argument(
        '-u', '--up_sample_flag', type=str2bool, nargs='?', const=True, default=False,
        help='Boolean value to upsampling the data = True or not upsampling = False',
        required=False, choices=(True, False),
    )
    parser.add_argument(
        '-cu', '--cuda', type=str2bool, nargs='?', const=True, default=False,
        help='Boolean value for using cuda to Train the model (True). By defaul False.',
        choices=(True, False), required=False,
    )
    return parser


def main():
    """Prepare the input data and train a new model from the command line.

    Steps, in order:
      1. Parse the command-line options.
      2. Import the data: ``copy_data`` for standard CONLL input, otherwise
         ``json_to_txt``. An integer result from ``json_to_txt`` is reported
         as an error and stops the run.
      3. Optionally upsample the entities whose value reported by
         ``characterize_data`` is below 200.
      4. Call ``usage_cuda`` with the requested setting and print its result.
      5. Train the model and report the outcome.

    Errors are reported on stdout only; the function always returns None.
    """
    args = _build_parser().parse_args()

    # The fast option runs a single epoch instead of the regular 20.
    epochs = 1 if args.fast else 20

    if args.standard:
        copy_data(args.input_dir)
    else:
        conversion = json_to_txt(args.input_dir)
        if _is_error_code(conversion):
            print('Error processing the input documents, code error {}'.format(conversion))
            return

    if args.up_sample_flag:
        entities_dict = characterize_data()
        entities = list(entities_dict)
        # Pair each entity with its own value instead of indexing a parallel list.
        entities_to_upsample = [
            name for name, value in entities_dict.items() if value < 200
        ]
        # 0.8 is forwarded unchanged as the second argument; its meaning is
        # defined by upsampling_data -- TODO confirm.
        upsampling_data(entities_to_upsample, 0.8, entities)

    cuda_info = usage_cuda(bool(args.cuda))
    print(cuda_info)

    outcome = training_model(args.model, epochs)
    if _is_error_code(outcome):
        print('Error training the model, code error {}'.format(outcome))
    else:
        print('Training complete')


if __name__ == '__main__':
    main()