File size: 3,992 Bytes
ea847ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c9258d
ea847ad
 
 
 
 
7b1ae8d
 
c1f3687
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import annotations
import argparse
import datetime
import json
import os
from operator import itemgetter
from pathlib import Path
from typing import Callable

import requests
import pynvml
import PIL.Image
import torch
from pydantic import BaseSettings, BaseModel


class Args(BaseSettings):
    """Settings base class that can also be populated from command-line flags.

    Subclasses declare their fields as usual; :meth:`from_args` exposes every
    field as an ``argparse`` option and builds the instance from the parsed
    values.
    """

    @classmethod
    def from_args(cls):
        """Build an instance of ``cls`` from the process command line.

        Every plain field ``some_name`` becomes ``--some_name`` plus a short
        alias made of the initials of its underscore-separated parts
        (``-sn``). The short alias is skipped when it is already taken, either
        by ``-h`` (argparse's help flag) or by an earlier field, so the long
        flag always stays usable instead of argparse raising a conflict.

        A field whose type is a ``BaseModel`` subclass is expanded into one
        ``--<lowercased type name>.<subfield>`` option per subfield; the
        parsed values are regrouped into a dict passed under the field's name.

        Unknown command-line arguments are ignored (``parse_known_args``).

        Returns:
            An instance of ``cls`` constructed from the parsed values.
        """

        def is_nested(field) -> bool:
            # ``type_`` may be a typing construct rather than a class, in
            # which case ``issubclass`` would raise TypeError.
            return isinstance(field.type_, type) and issubclass(field.type_, BaseModel)

        parser = argparse.ArgumentParser()
        taken_short_flags = {"h"}  # "-h" is reserved by argparse for --help
        for field in cls.__fields__.values():
            if is_nested(field):
                prefix = field.type_.__name__.lower()
                for subfield in field.type_.__fields__.values():
                    parser.add_argument(
                        f"--{prefix}.{subfield.name}",
                        default=subfield.default,
                        required=subfield.required,
                    )
                continue

            flags = [f"--{field.name}"]
            short = "".join(part[0] for part in field.name.split("_") if part)
            if short and short not in taken_short_flags:
                taken_short_flags.add(short)
                flags.insert(0, f"-{short}")
            parser.add_argument(*flags, default=field.default, required=field.required)

        # NOTE(review): argparse always supplies a value (the field default
        # when the flag is absent), so every field is passed explicitly to the
        # constructor; presumably that takes precedence over values from the
        # environment / ``.env`` file -- confirm this is intended.
        parsed = vars(parser.parse_known_args()[0])

        consumed = set()
        for field in cls.__fields__.values():
            if not is_nested(field):
                continue
            # Match on "<prefix>." (with the dot) so that an unrelated
            # top-level field such as "model_name" is not swallowed by a
            # nested model whose prefix is "model".
            dotted_prefix = f"{field.type_.__name__.lower()}."
            sub_args = {}
            for key, value in parsed.items():
                if key.startswith(dotted_prefix):
                    consumed.add(key)
                    sub_args[key[len(dotted_prefix):]] = value
            parsed[field.name] = sub_args

        kwargs = {key: value for key, value in parsed.items() if key not in consumed}
        return cls(**kwargs)

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
        env_prefix = "ARG_"


def inject_args(func: Callable) -> Callable:
    """Decorate *func* so that a parsed ``Args`` instance is injected.

    The function's annotations are scanned for the first one that is a
    subclass of ``Args``. That class is instantiated once, at decoration time,
    via ``from_args()``, and the same instance is passed as the first
    positional argument on every call of the returned wrapper.

    Args:
        func: Function with at least one parameter annotated with an ``Args``
            subclass.

    Returns:
        A wrapper that calls ``func(injected, *args, **kwargs)``.

    Raises:
        ValueError: If no annotation of *func* is an ``Args`` subclass.
    """
    import functools
    import typing

    # This module uses ``from __future__ import annotations`` (and callers may
    # too), so raw ``__annotations__`` values can be plain strings; resolve
    # them to real objects before calling ``issubclass``.
    try:
        hints = typing.get_type_hints(func)
    except (NameError, TypeError):
        # One unresolvable annotation makes get_type_hints fail as a whole;
        # fall back to looking up simple names in the function's globals.
        namespace = getattr(func, "__globals__", {})
        hints = {
            name: namespace.get(value, value) if isinstance(value, str) else value
            for name, value in getattr(func, "__annotations__", {}).items()
        }

    injected = None
    for type_ in hints.values():
        # Skip typing constructs / unresolved strings: issubclass needs a class.
        if isinstance(type_, type) and issubclass(type_, Args):
            injected = type_.from_args()
            break

    if injected is None:
        raise ValueError(f"Function {func.__name__} is not annotated with an Args subclass.")

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(injected, *args, **kwargs)

    return wrapper


def get_free_gpu() -> int:
    """Return the index of the GPU that currently reports the most free memory.

    Free memory is queried through NVML for every index in
    ``range(torch.cuda.device_count())``. On ties the lowest index wins.

    NOTE(review): the indices come from torch's device count but are used as
    NVML indices; these presumably only line up when the set of visible
    devices is not restricted or reordered -- confirm for your environment.

    Returns:
        The index of the device with the largest ``free`` memory value.

    Raises:
        IndexError: If no device is reported.
    """
    pynvml.nvmlInit()
    try:
        free_by_index = []
        for index in range(torch.cuda.device_count()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(index)
            free_by_index.append((index, pynvml.nvmlDeviceGetMemoryInfo(handle).free))
    finally:
        # Pair every nvmlInit with a shutdown so the NVML session is released
        # even when a query fails.
        pynvml.nvmlShutdown()

    if not free_by_index:
        raise IndexError("No CUDA devices are available; cannot pick a free GPU.")
    # max() returns the first maximal element, i.e. the lowest index on ties.
    return max(free_by_index, key=itemgetter(1))[0]


def get_user_name() -> str:
    """Return the current user's name, taken from the ``HOME`` directory.

    The name is the last path component of ``$HOME``. The full component is
    used (not ``.stem``), so names containing a dot such as ``jane.doe`` are
    not truncated at the dot.

    Raises:
        KeyError: If the ``HOME`` environment variable is not set.
    """
    return Path(os.environ["HOME"]).name


def get_storage_dir() -> Path:
    """Return the per-user storage root, ``/fsx/<user name>``.

    The user name is whatever ``get_user_name()`` returns. The path is only
    constructed here; nothing is created on disk.
    """
    user_name = get_user_name()
    return Path(f"/fsx/{user_name}")


def get_checkpoints_dir(*, timestamp: bool) -> Path:
    """Return the checkpoints directory below the user's storage root.

    Args:
        timestamp: When true, append a sub-directory named after the current
            ``now()`` stamp; otherwise return the bare ``checkpoints``
            directory.

    Returns:
        The resulting path. Nothing is created on disk.
    """
    checkpoints_root = get_storage_dir() / "checkpoints"
    if not timestamp:
        return checkpoints_root
    return Path(f"{checkpoints_root}/{now()}")


def now() -> str:
    """Return the current UTC time formatted as ``YYYYmmdd_HHMMSS``.

    Uses an aware UTC clock instead of ``datetime.utcnow()``, which is
    deprecated since Python 3.12; the formatted output is unchanged.
    """
    return datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")


def read_jsonl(path: Path) -> list:
    """Read a JSON Lines file and return the decoded records in file order.

    Lines that are empty or contain only whitespace are skipped, so trailing
    newlines and blank separator lines do not raise. The file is decoded as
    UTF-8 regardless of the locale and is consumed line by line.

    Args:
        path: Location of the ``.jsonl`` file (``Path`` or string).

    Returns:
        A list with one decoded JSON value per non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def write_jsonl(path: Path, data: list):
    """Write *data* to *path* as JSON Lines, one ``json.dumps`` record per line.

    The file is opened in ``"w"`` mode, so existing content is replaced.
    Every record, including the last, is terminated by a newline.
    """
    with Path(path).open("w") as out:
        out.writelines(json.dumps(record) + "\n" for record in data)


def get_image(url: str, filename: Path | None = None, *, timeout: float | None = 30.0) -> Path:
    """Download the image at *url* to *filename* unless that file already exists.

    Args:
        url: Address of the image.
        filename: Target path. Defaults to the last ``/``-separated segment
            of *url* with ``.jpg`` appended, relative to the working directory.
        timeout: Timeout in seconds handed to ``requests.get``; ``None``
            disables it.

    Returns:
        The path of the (existing or newly written) file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if filename is None:
        filename = Path(f"{url.split('/')[-1]}.jpg")
    filename = Path(filename)
    if filename.exists():
        # Already on disk: reuse it, no network access.
        return filename

    # The context manager releases the connection even if decoding fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # The raw stream is not content-decoded (gzip/deflate) by default.
        response.raw.decode_content = True
        PIL.Image.open(response.raw).save(filename)
    return filename


def find_latest_checkpoint(dirname: Path) -> Path | None:
    """Return the most recently modified ``*.ckpt`` file directly in *dirname*.

    Only the directory itself is searched (the glob is not recursive), and
    recency is judged by each file's modification time.

    Args:
        dirname: Directory to search (``Path`` or string).

    Returns:
        The path of the newest checkpoint, or ``None`` when no ``*.ckpt``
        file is found.
    """
    return max(
        Path(dirname).glob("*.ckpt"),
        key=lambda path: path.stat().st_mtime,
        default=None,
    )


def list_files(dirname: Path, exts: list[str] | None = None) -> list[Path]:
    """List the entries of *dirname*, optionally filtered by extension.

    Args:
        dirname: Directory to list (``Path`` or string). Not recursive.
        exts: Extensions without the leading dot, e.g. ``["jpg", "png"]``.
            When empty or ``None`` every entry is returned.

    Returns:
        Paths of the matching entries, in the order ``iterdir`` yields them.
        Each entry appears at most once, even if it matches several of the
        given extensions (e.g. ``"gz"`` and ``"tar.gz"``).
    """
    entries = Path(dirname).iterdir()
    if not exts:
        return list(entries)
    patterns = [f"*.{ext}" for ext in exts]
    # any() keeps one copy per entry; a nested comprehension would emit the
    # entry once for every pattern it matches.
    return [entry for entry in entries if any(entry.match(pattern) for pattern in patterns)]