File size: 2,436 Bytes
b06fb83
 
 
d705756
b06fb83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d705756
b06fb83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

# Preprocess functions
# Source code of `preprocess`, kept as a string so it can be stored in the
# config that is saved at the end of this file. A raw string is used so the
# regex backslashes below are stored exactly as written.
preprocess_code = r"""
def preprocess(text):
    import re
    import string
    import spacy

    try:
        # Make sure we are working with a string.
        text = str(text)

        # Remove HTML tags.
        text = re.sub(r"<.*?>", "", text)

        # Remove URLs: http(s)://..., www...., and any token containing a dot
        # followed by 2+ non-space characters (e.g. "example.org"; note this
        # also removes tokens such as "3.14").
        url_pattern = r"https?://\S+|www\.\S+|\S+\.\S{2,}"
        text = re.sub(url_pattern, "", text)

        # Remove punctuation. str methods return new strings, so each
        # transformation below must be assigned back to `text`.
        translator = str.maketrans("", "", string.punctuation)
        text = text.translate(translator)

        # Lower case.
        text = text.lower().strip()

        # Remove non-breaking spaces (\xa0).
        unicode_pattern = str.maketrans("", "", "\xa0")
        text = text.translate(unicode_pattern)

        # Replace escape characters (\n, \t, \r) with spaces.
        text = re.sub(r"[\n\t\r]", " ", text)

        # Remove stop words using spaCy.
        spacy.prefer_gpu()  # use the GPU if available; may reduce run time
        # NOTE: the model is reloaded on every call, which is slow; consider
        # caching it if this runs per request.
        nlp = spacy.load("en_core_web_sm")
        doc = nlp(text)
        text = " ".join(token.text for token in doc if not token.is_stop)

        # Collapse runs of whitespace and trim the ends.
        text = re.sub(r"\s+", " ", text).strip()
    except Exception as err:
        # Best effort: report the failure and return whatever was processed.
        print(f"error occurred during preprocessing: {err}")

    return text
"""


# Source code of `post_process`, kept as a string so it can be stored in the
# config. It maps a classifier output (anything with a `.logits` tensor) to
# the predicted label name.
postprocess_code = """
def post_process(output):
    classes = [
        'ACCOUNTANT', 'ADVOCATE', 'AGRICULTURE', 'APPAREL', 'ARTS',
        'AUTOMOBILE', 'AVIATION', 'BANKING', 'BPO', 'BUSINESS-DEVELOPMENT',
        'CHEF', 'CONSTRUCTION', 'CONSULTANT', 'DESIGNER', 'DIGITAL-MEDIA',
        'ENGINEERING', 'FINANCE', 'FITNESS', 'HEALTHCARE', 'HR',
        'INFORMATION-TECHNOLOGY', 'PUBLIC-RELATIONS', 'SALES', 'TEACHER',
    ]
    try:
        # Expects a single example: after squeeze() there should be one score
        # per class. A batched output makes .item() raise, which is reported
        # below.
        logits = output.logits.squeeze()
        # Take the argmax of the raw logits. Sigmoid is monotonic, so the winner
        # is the same, but it saturates to 1.0 for large logits in float32,
        # which would turn distinct scores into ties.
        best = logits.argmax(dim=-1).item()
        return classes[best]
    except Exception as err:
        print(f"Post-processing failed: {err}")
        return None
"""



from transformers import PretrainedConfig, AutoModel

class CustomConfig(PretrainedConfig):
    """Config that additionally carries pre/post-processing source code.

    Args:
        preprocess_function: source of the preprocessing function, or None.
        postprocess_function: source of the post-processing function, or None.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.

    Both extra values are kept as plain instance attributes. Whether they end
    up in the saved config depends on PretrainedConfig's serialization
    (presumably via to_dict) — verify against the transformers version in use.
    """

    def __init__(self, preprocess_function=None, postprocess_function=None, **kwargs):
        # Let the base class handle the standard config keyword arguments first.
        super().__init__(**kwargs)
        # Attach the function sources (or None) after base initialization.
        self.preprocess_function, self.postprocess_function = (
            preprocess_function,
            postprocess_function,
        )

    
# Bundle both function sources into a config and write it to the
# "config with functions" directory.
config = CustomConfig(
    postprocess_function=postprocess_code,
    preprocess_function=preprocess_code,
)
config.save_pretrained("config with functions")