"""Command-line argument parsing for ESPER."""
import os
import json
import argparse
import logging
from pathlib import Path
from typing import Optional, Sequence

import torch

# The log level is taken from the LOGLEVEL environment variable (default INFO).
# It is upper-cased because logging.basicConfig only recognizes upper-case
# level names and raises ValueError for values such as "debug".
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO").strip().upper())
log = logging.getLogger(__name__)


def get_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the ESPER command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.
            Passing an explicit list lets tests or other code call this
            without touching ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace`` after two post-processing steps.
        ``cuda`` is set to ``torch.cuda.is_available()``. ``checkpoint`` is
        resolved to an absolute path string.
    """
    parser = argparse.ArgumentParser(description='ESPER')
    # Option names mix dashes and underscores. argparse exposes both as
    # underscore attributes (e.g. --init-model -> args.init_model).
    parser.add_argument(
        '--init-model', type=str, default='gpt2',
        help='language model used for policy.')
    parser.add_argument(
        '--label_path', type=str, default='./data/esper_demo/labels_all.json',
        help='style label info file path')
    parser.add_argument(
        '--checkpoint', type=str, default='./data/esper_demo/ckpt',
        help='checkpoint file path')
    parser.add_argument(
        '--prefix_length', type=int, default=10,
        help='prefix length for the visual mapper')
    parser.add_argument(
        '--clipcap_num_layers', type=int, default=1,
        help='num_layers for the visual mapper')
    parser.add_argument(
        '--use_transformer_mapper', action='store_true', default=False,
        help='use transformer mapper instead of mlp')
    parser.add_argument(
        '--use_label_prefix', action='store_true', default=False,
        help='label as prefixes')
    parser.add_argument(
        '--clip_model_type', type=str, default='ViT-B/32',
        help='clip backbone type')
    parser.add_argument(
        '--infer_no_repeat_size', type=int, default=2,
        help="no repeat ngram size for inference")
    parser.add_argument(
        '--response-length', type=int, default=20,
        help='number of tokens to generate for each prompt.')
    parser.add_argument(
        '--port', type=int, default=None,
        help="port for the demo server")

    # argv=None -> argparse falls back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Derived setting, not a CLI option: whether a CUDA device is usable.
    args.cuda = torch.cuda.is_available()

    if args.use_label_prefix:
        log.info('using label prefix')

    # Resolve to an absolute path so later consumers don't depend on the
    # current working directory.
    if args.checkpoint is not None:
        args.checkpoint = str(Path(args.checkpoint).resolve())

    return args