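# NOTE(review): the lines below appear to be repository-viewer page residue
# (committer, commit message, hash, file size); kept as comments so the module parses.
# Joshua Lochner
# Add source code
# 5fbdd3c
# raw history blame
# No virus
# 3.76 kB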
import pickle
import os
from shared import CustomTokens
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModelArguments:
    """
    Options selecting the pretrained model and tokenizer to fine-tune from,
    and how they are fetched.

    Every field carries a ``help`` entry in its ``metadata`` describing it.
    """

    # Model identifier or local path (alternative noted in the original: t5-small).
    model_name_or_path: str = field(
        default='google/t5-v1_1-small',
        metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'},
    )

    # TODO remove? A separate ``config_name: Optional[str]`` field
    # ('Pretrained config name or path if not the same as model_name')
    # is currently disabled.

    # Falls back to ``model_name_or_path`` when left unset (see ``get_tokenizer``).
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={'help': 'Pretrained tokenizer name or path if not the same as model_name'},
    )

    # Download cache location forwarded to the ``from_pretrained`` calls.
    cache_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'Where to store the pretrained models downloaded from huggingface.co'},
    )

    # TODO remove?
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'},
    )

    # TODO remove?
    model_revision: str = field(
        default='main',
        metadata={'help': 'The specific model version to use (can be a branch name, tag name or commit id).'},
    )

    # Only a truthy value is forwarded (as ``True``); otherwise ``None`` is passed.
    use_auth_token: bool = field(
        default=False,
        metadata={
            'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script '
                    'with private models).'
        },
    )

    resize_position_embeddings: Optional[bool] = field(
        default=None,
        metadata={
            'help': "Whether to automatically resize the position embeddings if `max_source_length` exceeds the model's position embeddings."
        },
    )
def get_model(model_args, use_cache=True):
    """Load the seq2seq model described by ``model_args``.

    When ``use_cache`` is true and the file
    ``models/<model_name_or_path>/pytorch_model.bin`` exists, the model is
    loaded from that local ``models/`` directory instead of the configured
    name.

    Args:
        model_args: Object exposing ``model_name_or_path``, ``cache_dir``,
            ``model_revision`` and ``use_auth_token``.
        use_cache: Whether to prefer the local ``models/`` copy when present.

    Returns:
        Whatever ``AutoModelForSeq2SeqLM.from_pretrained`` returns.
    """
    source = model_args.model_name_or_path
    local_dir = f'models/{source}'
    if use_cache and os.path.exists(os.path.join(local_dir, 'pytorch_model.bin')):
        source = local_dir

    # Keyword arguments passed identically to both the config and model loaders.
    shared_kwargs = {
        'cache_dir': model_args.cache_dir,
        'revision': model_args.model_revision,
        # Forward ``True`` only when requested; otherwise pass ``None``.
        'use_auth_token': True if model_args.use_auth_token else None,
    }

    model_config = AutoConfig.from_pretrained(source, **shared_kwargs)
    return AutoModelForSeq2SeqLM.from_pretrained(
        source,
        # A '.ckpt' substring in the name/path turns on ``from_tf``.
        from_tf='.ckpt' in source,
        config=model_config,
        **shared_kwargs,
    )
def get_tokenizer(model_args, use_cache=True):
    """Load the tokenizer described by ``model_args`` and register custom tokens.

    The tokenizer name is ``model_args.tokenizer_name`` when it is truthy,
    otherwise ``model_args.model_name_or_path``. When ``use_cache`` is true and
    ``models/<name>/tokenizer.json`` exists, the tokenizer is loaded from that
    local ``models/`` directory instead.

    Args:
        model_args: Object exposing ``tokenizer_name``, ``model_name_or_path``,
            ``cache_dir``, ``use_fast_tokenizer``, ``model_revision`` and
            ``use_auth_token``.
        use_cache: Whether to prefer the local ``models/`` copy when present.

    Returns:
        The tokenizer, after it has been passed to
        ``CustomTokens.add_custom_tokens``.
    """
    source = model_args.tokenizer_name or model_args.model_name_or_path
    local_dir = f'models/{source}'

    has_local_copy = use_cache and os.path.exists(
        os.path.join(local_dir, 'tokenizer.json'))
    if has_local_copy:
        source = local_dir

    # Forward ``True`` only when requested; otherwise pass ``None``.
    auth_token = True if model_args.use_auth_token else None

    loaded = AutoTokenizer.from_pretrained(
        source,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=auth_token,
    )

    # Return value is ignored here — presumably the tokenizer is extended in place.
    CustomTokens.add_custom_tokens(loaded)
    return loaded
def get_classifier_vectorizer(classifier_args):
    """Unpickle the classifier and vectorizer named by ``classifier_args``.

    Both files are read from ``classifier_args.classifier_dir``; the
    classifier file is loaded first, then the vectorizer file.

    Args:
        classifier_args: Object exposing ``classifier_dir``,
            ``classifier_file`` and ``vectorizer_file``.

    Returns:
        A ``(classifier, vectorizer)`` tuple of the unpickled objects.
    """
    base_dir = classifier_args.classifier_dir

    def _unpickle(file_name):
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(os.path.join(base_dir, file_name), 'rb') as handle:
            return pickle.load(handle)

    loaded_classifier = _unpickle(classifier_args.classifier_file)
    loaded_vectorizer = _unpickle(classifier_args.vectorizer_file)
    return loaded_classifier, loaded_vectorizer