File size: 1,911 Bytes
2ec0615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from transformers import GPT2LMHeadModel
from transformers import PreTrainedTokenizerFast
import os


class LoadModel:
    """
    Load a GPT-2 language model and its fast tokenizer, either from a
    HuggingFace Hub repo or from a local folder.

    Example usage:

    # if loading model and tokenizer from Huggingface
    model_repo = "misnaej/the-jam-machine"
    model, tokenizer = LoadModel(
        model_repo, from_huggingface=True
    ).load_model_and_tokenizer()

    # if loading model and tokenizer from a local folder
    model_path = "models/model_2048_wholedataset"
    model, tokenizer = LoadModel(
        model_path, from_huggingface=False
    ).load_model_and_tokenizer()

    """

    def __init__(self, path, from_huggingface=True, device="cpu", revision=None):
        """
        Store loading options and validate local paths eagerly.

        Args:
            path: a model repo id on HuggingFace when ``from_huggingface`` is
                True, otherwise a relative/absolute path on a local or remote
                machine.
            from_huggingface: if False, ``path`` must exist on disk.
            device: target passed to ``.to()`` on the loaded model.
            revision: optional revision forwarded to ``from_pretrained`` for
                BOTH the model and the tokenizer so they stay in sync; when
                None, the library's default revision is used.

        Raises:
            FileNotFoundError: if ``from_huggingface`` is False and ``path``
                does not exist.
        """
        # Fail fast on a bad local path, before any (slow) loading happens.
        if not from_huggingface and not os.path.exists(path):
            raise FileNotFoundError(f"Model path does not exist: {path}")
        self.from_huggingface = from_huggingface
        self.path = path
        self.device = device
        self.revision = revision

    def load_model_and_tokenizer(self):
        """Return a ``(model, tokenizer)`` tuple loaded from ``self.path``."""
        model = self.load_model()
        tokenizer = self.load_tokenizer()
        return model, tokenizer

    def _revision_kwargs(self):
        """Return the ``revision`` keyword for ``from_pretrained``, if one was set."""
        # Only forward ``revision`` when explicitly given, so the library's own
        # default applies otherwise.
        return {} if self.revision is None else {"revision": self.revision}

    def load_model(self):
        """Load the ``GPT2LMHeadModel`` from ``self.path`` and move it to ``self.device``."""
        return GPT2LMHeadModel.from_pretrained(
            self.path, **self._revision_kwargs()
        ).to(self.device)

    def load_tokenizer(self):
        """
        Load the ``PreTrainedTokenizerFast`` from ``self.path``.

        The same ``revision`` as the model is used, so a pinned Hub revision
        yields a matching tokenizer.

        Raises:
            FileNotFoundError: if loading locally and ``tokenizer.json`` is
                missing from ``self.path``.
        """
        # Hub repos are resolved by transformers itself; only local folders
        # can be checked for the tokenizer file up front.
        if not self.from_huggingface:
            tokenizer_file = os.path.join(self.path, "tokenizer.json")
            if not os.path.exists(tokenizer_file):
                raise FileNotFoundError(
                    f"There is no 'tokenizer.json' file in the defined {self.path}"
                )
        return PreTrainedTokenizerFast.from_pretrained(
            self.path, **self._revision_kwargs()
        )