File size: 1,911 Bytes
2ec0615 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from transformers import GPT2LMHeadModel
from transformers import PreTrainedTokenizerFast
import os
class LoadModel:
    """Load a GPT-2 language model and its fast tokenizer from one location.

    The location is either a HuggingFace model repo id or a folder on disk.

    Example usage:

        # if loading model and tokenizer from Huggingface
        model_repo = "misnaej/the-jam-machine"
        model, tokenizer = LoadModel(
            model_repo, from_huggingface=True
        ).load_model_and_tokenizer()

        # if loading model and tokenizer from a local folder
        model_path = "models/model_2048_wholedataset"
        model, tokenizer = LoadModel(
            model_path, from_huggingface=False
        ).load_model_and_tokenizer()
    """

    def __init__(self, path, from_huggingface=True, device="cpu", revision=None):
        """Store the loading options and validate a local model path.

        Args:
            path: Relative path on a local/remote machine, or a model repo
                on HuggingFace.
            from_huggingface: True when ``path`` is a HuggingFace repo id,
                False when it is a folder that must already exist.
            device: Device the model is moved to after loading.
            revision: Optional revision forwarded to ``from_pretrained``;
                None leaves the library default in place.

        Raises:
            FileNotFoundError: If ``from_huggingface`` is False and ``path``
                does not exist.
        """
        # Only local folders can be checked up front; a repo id is resolved
        # later by ``from_pretrained``.
        if not from_huggingface and not os.path.exists(path):
            # The offending path goes into the message instead of stdout.
            raise FileNotFoundError(f"Model path does not exist: {path}")
        self.from_huggingface = from_huggingface
        self.path = path
        self.device = device
        self.revision = revision

    def _revision_kwargs(self):
        """Return the ``revision`` keyword for ``from_pretrained``, if one is set."""
        if self.revision is None:
            return {}
        return {"revision": self.revision}

    def load_model_and_tokenizer(self):
        """Load both artifacts and return them as ``(model, tokenizer)``."""
        model = self.load_model()
        tokenizer = self.load_tokenizer()
        return model, tokenizer

    def load_model(self):
        """Load the GPT-2 LM-head model and move it to ``self.device``."""
        model = GPT2LMHeadModel.from_pretrained(
            self.path, **self._revision_kwargs()
        )
        return model.to(self.device)

    def load_tokenizer(self):
        """Load the fast tokenizer stored alongside the model.

        Raises:
            FileNotFoundError: If loading from a local folder that contains
                no ``tokenizer.json`` file.
        """
        if not self.from_huggingface:
            tokenizer_file = os.path.join(self.path, "tokenizer.json")
            if not os.path.exists(tokenizer_file):
                raise FileNotFoundError(
                    f"There is no 'tokenizer.json' file in the defined {self.path}"
                )
        # Use the same revision as the model so the two artifacts stay paired.
        return PreTrainedTokenizerFast.from_pretrained(
            self.path, **self._revision_kwargs()
        )
|