| import os |
| import sys |
| from argparse import ArgumentParser |
| from getpass import getpass |
| from typing import List, Union |
|
|
| from requests.exceptions import HTTPError |
|
|
| from transformers.commands import BaseTransformersCLICommand |
| from transformers.hf_api import HfApi, HfFolder |
|
|
|
|
# Maximum number of files a single `upload` invocation will send; above this the upload is refused,
# since that many files is probably a mistake (e.g. the wrong folder was given).
UPLOAD_MAX_FILES = 15
|
|
|
|
class UserCommands(BaseTransformersCLICommand):
    """Registers the account (login / whoami / logout), S3 (ls / rm) and upload subcommands."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Attach the user-related subcommands to ``parser``.

        Every subparser stores a ``func`` default that builds the matching command object from the parsed
        arguments.

        Args:
            parser: used through ``add_parser``, so presumably the action returned by
                ``ArgumentParser.add_subparsers()`` despite the annotation -- confirm against the caller.
        """
        login_parser = parser.add_parser("login", help="Log in using the same credentials as on huggingface.co")
        login_parser.set_defaults(func=lambda args: LoginCommand(args))
        whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
        whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
        logout_parser = parser.add_parser("logout", help="Log out")
        logout_parser.set_defaults(func=lambda args: LogoutCommand(args))

        # `s3` groups its own nested subcommands.
        s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
        s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
        ls_parser = s3_subparsers.add_parser("ls", help="List the files you have uploaded to S3.")
        ls_parser.set_defaults(func=lambda args: ListObjsCommand(args))
        rm_parser = s3_subparsers.add_parser("rm", help="Delete one of your files from S3.")
        rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
        rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))

        upload_parser = parser.add_parser("upload", help="Upload a file or a folder to S3.")
        upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
        upload_parser.add_argument(
            "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
        )
        upload_parser.set_defaults(func=lambda args: UploadCommand(args))
|
|
|
|
class ANSI:
    """
    Tiny helper for terminal styling via ANSI escape sequences.

    Reference: en.wikipedia.org/wiki/ANSI_escape_code
    """

    _bold = "\u001b[1m"  # SGR 1: bold / increased intensity
    _reset = "\u001b[0m"  # SGR 0: reset all attributes

    @classmethod
    def bold(cls, s):
        """Return ``s`` (formatted as with ``str.format``) wrapped between the bold and reset sequences."""
        start, end = cls._bold, cls._reset
        return f"{start}{s}{end}"
|
|
|
|
class BaseUserCommand:
    """
    Shared base for the user subcommands.

    Attributes:
        args: the parsed command-line arguments the command was built from.
        _api: an ``HfApi`` client that subclasses use in their ``run`` methods.
    """

    def __init__(self, args):
        self.args, self._api = args, HfApi()
|
|
|
|
class LoginCommand(BaseUserCommand):
    """Interactive ``login`` subcommand: exchanges username/password for a token and saves it locally."""

    def run(self):
        """
        Print the banner, prompt for credentials, log in through ``HfApi.login`` and save the token.

        On an ``HTTPError`` from the login request, the error is printed and the process exits with status 1.
        """
        print(
            """
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|

"""
        )
        username = input("Username: ")
        password = getpass()
        try:
            token = self._api.login(username, password)
        except HTTPError as e:
            # Probably invalid credentials: show the server's error and fail.
            print(e)
            # sys.exit rather than the `site` builtin `exit`, which is meant for the interactive shell only.
            sys.exit(1)
        HfFolder.save_token(token)
        print("Login successful")
        print("Your token:", token, "\n")
        print("Your token has been saved to", HfFolder.path_token)
|
|
|
|
class WhoamiCommand(BaseUserCommand):
    """``whoami`` subcommand: prints the account associated with the locally saved token."""

    def run(self):
        """
        Print the result of ``HfApi.whoami`` for the saved token.

        With no saved token, prints "Not logged in" and exits with status 0. If the API request raises an
        ``HTTPError``, prints it and exits with status 1.
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            sys.exit()
        try:
            user = self._api.whoami(token)
        except HTTPError as e:
            print(e)
            # A failed lookup must not report success; the other commands also exit with 1 on HTTPError.
            sys.exit(1)
        print(user)
|
|
|
|
class LogoutCommand(BaseUserCommand):
    """``logout`` subcommand: deletes the saved token and notifies the API."""

    def run(self):
        """
        Delete the local token, then call ``HfApi.logout`` with it.

        With no saved token, prints "Not logged in" and exits with status 0. If the API call raises an
        ``HTTPError``, prints it and exits with status 1 (the local token has already been deleted by then).
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            sys.exit()
        # The local token is removed before the API call, so it is gone even if that call fails.
        HfFolder.delete_token()
        try:
            self._api.logout(token)
        except HTTPError as e:
            print(e)
            sys.exit(1)
        print("Successfully logged out.")
|
|
|
|
class ListObjsCommand(BaseUserCommand):
    """``s3 ls`` subcommand: prints a table of the user's uploaded files."""

    def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
        """
        Render ``rows`` as a plain-text table: a header line, a dashed separator line, then one line per row.

        Each column is as wide as its longest cell (headers included); cells are padded to that width with
        ``str.format``, and every cell is followed by a single space.

        Inspired by:
        stackoverflow.com/a/8356620/593036
        stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
        """
        col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
        row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
        lines = [
            row_format.format(*headers),
            row_format.format(*["-" * w for w in col_widths]),
        ]
        lines.extend(row_format.format(*row) for row in rows)
        return "\n".join(lines)

    def run(self):
        """
        Fetch the user's objects with ``HfApi.list_objs`` and print them as a table.

        Exits with status 1 when not logged in or when the request raises an ``HTTPError``; exits with
        status 0 after printing "No shared file yet" when there is nothing to list.
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            sys.exit(1)
        try:
            objs = self._api.list_objs(token)
        except HTTPError as e:
            print(e)
            sys.exit(1)
        if not objs:
            print("No shared file yet")
            sys.exit()
        rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
        print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
|
|
|
|
class DeleteObjCommand(BaseUserCommand):
    """``s3 rm FILENAME`` subcommand: deletes one of the user's uploaded files."""

    def run(self):
        """
        Delete ``self.args.filename`` through ``HfApi.delete_obj`` and print "Done".

        Exits with status 1 when not logged in or when the request raises an ``HTTPError``.
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            sys.exit(1)
        try:
            self._api.delete_obj(token, filename=self.args.filename)
        except HTTPError as e:
            print(e)
            sys.exit(1)
        print("Done")
|
|
|
|
class UploadCommand(BaseUserCommand):
    """``upload PATH [--filename NAME]`` subcommand: uploads a file or a whole folder."""

    def walk_dir(self, rel_path, base_dir=None):
        """
        Recursively list all files in a folder.

        Args:
            rel_path: folder to walk, resolved against ``base_dir``; it is also the prefix of every returned
                remote filename.
            base_dir: directory ``rel_path`` is relative to. Defaults to the current working directory.

        Returns:
            A list of ``(filepath, filename)`` tuples, where ``filename`` is ``rel_path`` joined with the file's
            path inside the folder and ``filepath`` is ``base_dir`` joined with ``filename``. The files directly
            inside a folder come before the contents of its subfolders.
        """
        if base_dir is None:
            base_dir = os.getcwd()
        # Using the iterator as a context manager closes the directory handle deterministically.
        with os.scandir(os.path.join(base_dir, rel_path)) as it:
            entries: List[os.DirEntry] = list(it)
        files = []
        for entry in entries:
            if entry.is_file():
                filename = os.path.join(rel_path, entry.name)
                files.append((os.path.join(base_dir, filename), filename))
        for entry in entries:
            if entry.is_dir():
                files += self.walk_dir(os.path.join(rel_path, entry.name), base_dir=base_dir)
        return files

    def run(self):
        """
        Upload ``self.args.path`` after listing the planned uploads and asking for confirmation.

        A folder is uploaded file by file, with remote filenames prefixed by the folder's own name. A single
        file is uploaded under ``--filename`` if given, otherwise under its base name.

        Raises:
            ValueError: if ``--filename`` is combined with a folder, or if the path is neither a file nor
                a folder.
        """
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            sys.exit(1)
        local_path = os.path.abspath(self.args.path)
        if os.path.isdir(local_path):
            if self.args.filename is not None:
                raise ValueError("Cannot specify a filename override when uploading a folder.")
            # Walk from the folder's parent rather than from the current working directory, so any path
            # works (e.g. `upload a/b` or `upload .`) while remote names still start with the folder's name.
            parent_dir, rel_path = os.path.split(local_path)
            files = self.walk_dir(rel_path, base_dir=parent_dir)
        elif os.path.isfile(local_path):
            filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path)
            files = [(local_path, filename)]
        else:
            raise ValueError("Not a valid file or directory: {}".format(local_path))

        if sys.platform == "win32":
            # Remote filenames use "/" as separator; convert the local Windows separators.
            files = [(filepath, filename.replace(os.sep, "/")) for filepath, filename in files]

        if len(files) > UPLOAD_MAX_FILES:
            print(
                "About to upload {} files to S3. This is probably wrong. Please filter files before uploading.".format(
                    ANSI.bold(len(files))
                )
            )
            sys.exit(1)

        for filepath, filename in files:
            print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))

        # Tolerate stray whitespace around the answer; an empty answer means "yes" (the default).
        choice = input("Proceed? [Y/n] ").strip().lower()
        if choice not in ("", "y", "yes"):
            print("Abort")
            sys.exit()
        print(ANSI.bold("Uploading... This might take a while if files are large"))
        for filepath, filename in files:
            access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
            print("Your file now lives at:")
            print(access_url)
|
|