import pickle
import os
from shared import CustomTokens
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
default='google/t5-v1_1-small', # t5-small
metadata={
'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
)
# config_name: Optional[str] = field( # TODO remove?
# default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
# )
tokenizer_name: Optional[str] = field(
default=None, metadata={'help': 'Pretrained tokenizer name or path if not the same as model_name'}
)
cache_dir: Optional[str] = field(
default=None,
metadata={
'help': 'Where to store the pretrained models downloaded from huggingface.co'},
)
use_fast_tokenizer: bool = field( # TODO remove?
default=True,
metadata={
            'help': 'Whether to use one of the fast tokenizers (backed by the tokenizers library) or not.'},
)
model_revision: str = field( # TODO remove?
default='main',
metadata={
'help': 'The specific model version to use (can be a branch name, tag name or commit id).'},
)
use_auth_token: bool = field(
default=False,
metadata={
'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script '
'with private models).'
},
)
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
'help': "Whether to automatically resize the position embeddings if `max_source_length` exceeds the model's position embeddings."
},
)
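

# Example (sketch): these arguments are typically parsed from the command line
# with transformers.HfArgumentParser, e.g.:
#
#   from transformers import HfArgumentParser
#   parser = HfArgumentParser(ModelArguments)
#   (model_args,) = parser.parse_args_into_dataclasses()
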
def get_model(model_args, use_cache=True):
    name = model_args.model_name_or_path
    cached_path = f'models/{name}'

    # The model weights are saved after the tokenizer, so the presence of
    # pytorch_model.bin indicates a complete local copy:
    if use_cache and os.path.exists(os.path.join(cached_path, 'pytorch_model.bin')):
        name = cached_path
config = AutoConfig.from_pretrained(
name,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
name,
from_tf='.ckpt' in name,
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
return model
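

# Example usage (sketch, using the defaults above): `get_model` prefers a
# local copy under models/ and otherwise downloads from the Hub:
#
#   model_args = ModelArguments()  # defaults to google/t5-v1_1-small
#   model = get_model(model_args)
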
def get_tokenizer(model_args, use_cache=True):
name = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
cached_path = f'models/{name}'
if use_cache and os.path.exists(os.path.join(cached_path, 'tokenizer.json')):
name = cached_path
tokenizer = AutoTokenizer.from_pretrained(
name,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
CustomTokens.add_custom_tokens(tokenizer)
return tokenizer
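

# Example usage (sketch): the tokenizer is resolved the same way, and the
# extra tokens from shared.CustomTokens are registered on it. Note
# (assumption): after adding tokens, callers typically resize the model's
# embedding matrix to match:
#
#   tokenizer = get_tokenizer(model_args)
#   model.resize_token_embeddings(len(tokenizer))
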
# In-memory cache mapping a pickle file path to its unpickled object; shared
# by classifiers and vectorizers:
CLASSIFIER_CACHE = {}


def get_classifier_vectorizer(classifier_args, use_cache=True):
classifier_path = os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file)
if use_cache and classifier_path in CLASSIFIER_CACHE:
classifier = CLASSIFIER_CACHE[classifier_path]
    else:
        # The cache is (re)populated even when use_cache is False, so later
        # calls with use_cache=True pick up the freshly loaded object:
        with open(classifier_path, 'rb') as fp:
            classifier = CLASSIFIER_CACHE[classifier_path] = pickle.load(fp)

    vectorizer_path = os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file)
if use_cache and vectorizer_path in CLASSIFIER_CACHE:
vectorizer = CLASSIFIER_CACHE[vectorizer_path]
else:
with open(vectorizer_path, 'rb') as fp:
vectorizer = CLASSIFIER_CACHE[vectorizer_path] = pickle.load(fp)
return classifier, vectorizer
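

# Example usage (sketch): `classifier_args` is any object exposing
# `classifier_dir`, `classifier_file` and `vectorizer_file` (the attribute
# names come from the function above; the values below are illustrative
# assumptions):
#
#   from types import SimpleNamespace
#   classifier_args = SimpleNamespace(
#       classifier_dir='classifiers',
#       classifier_file='classifier.pickle',
#       vectorizer_file='vectorizer.pickle',
#   )
#   classifier, vectorizer = get_classifier_vectorizer(classifier_args)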