realfake / realfake /utils.py
devforfu
Movie stills binary classifier
c1f3687
from __future__ import annotations
import argparse
import datetime
import json
import os
from operator import itemgetter
from pathlib import Path
from typing import Callable
import requests
import pynvml
import PIL.Image
import torch
from pydantic import BaseSettings, BaseModel
class Args(BaseSettings):
    """Settings container whose fields can be supplied on the command line.

    Subclasses declare their options as regular fields. A field whose type is
    a ``BaseModel`` subclass is exposed as a group of dotted options named
    ``--<lowercased model class name>.<subfield name>``.
    """

    @classmethod
    def from_args(cls):
        """Build an instance of ``cls`` from ``sys.argv``.

        Plain fields are registered as ``--<field name>`` plus a short alias
        made of the first letter of every underscore-separated part of the
        name. Unknown command-line arguments are ignored.

        NOTE(review): argparse supplies a value (the field default) for every
        option, so every field is passed explicitly to the constructor;
        presumably this takes precedence over values from the env file --
        confirm against the pydantic settings priority rules.

        Returns:
            An instance of ``cls`` populated from the parsed arguments.
        """

        def is_submodel(type_) -> bool:
            # ``issubclass`` raises TypeError for annotations that are not
            # classes (e.g. ``Union[int, str]``), so check that first.
            return isinstance(type_, type) and issubclass(type_, BaseModel)

        parser = argparse.ArgumentParser()
        for field in cls.__fields__.values():
            if is_submodel(field.type_):
                prefix = field.type_.__name__.lower()
                for subfield in field.type_.__fields__.values():
                    parser.add_argument(
                        f"--{prefix}.{subfield.name}",
                        default=subfield.default,
                        required=subfield.required,
                    )
            else:
                # ``part[:1]`` tolerates empty parts (e.g. "a__b").
                short = "".join(part[:1] for part in field.name.split("_"))
                parser.add_argument(
                    f"-{short}",
                    f"--{field.name}",
                    default=field.default,
                    required=field.required,
                )

        args = vars(parser.parse_known_args()[0])

        # Regroup the dotted "<prefix>.<name>" entries into one dict per
        # sub-model field.
        nested = {}
        consumed = set()
        for field in cls.__fields__.values():
            if not is_submodel(field.type_):
                continue
            # Match on the prefix *including* the dot, otherwise an unrelated
            # top-level field such as "model_path" would be captured by the
            # "model" prefix and dropped from the top-level arguments.
            dotted = f"{field.type_.__name__.lower()}."
            sub_args = {}
            for key, value in args.items():
                if key.startswith(dotted):
                    consumed.add(key)
                    sub_args[key[len(dotted):]] = value
            nested[field.name] = sub_args

        args = {k: v for k, v in args.items() if k not in consumed}
        args.update(nested)
        return cls(**args)

    class Config:
        # Settings file and the prefix applied to environment variable names.
        env_file = ".env"
        env_file_encoding = "utf-8"
        env_prefix = "ARG_"
def inject_args(func: Callable) -> Callable:
    """Decorate ``func`` so an ``Args`` instance is passed as its first argument.

    The annotations of ``func`` are scanned for the first ``Args`` subclass.
    That class is instantiated once, at decoration time, via ``from_args()``
    and the same instance is prepended to the positional arguments of every
    call.

    Args:
        func: Function with at least one parameter annotated with an
            ``Args`` subclass.

    Returns:
        A wrapper that calls ``func(injected, *args, **kwargs)``.

    Raises:
        ValueError: If no annotation of ``func`` is an ``Args`` subclass.
    """
    import functools
    import typing

    # Resolve string annotations (as produced by
    # ``from __future__ import annotations``) into real classes; fall back to
    # the raw annotations when they cannot be resolved.
    try:
        hints = typing.get_type_hints(func)
    except (NameError, TypeError):
        hints = getattr(func, "__annotations__", {})

    injected = None
    for type_ in hints.values():
        # Skip annotations that are not classes: ``issubclass`` would raise
        # TypeError on them.
        if isinstance(type_, type) and issubclass(type_, Args):
            injected = type_.from_args()
            break
    if injected is None:
        raise ValueError(f"Function {func.__name__} is not annotated with an Args subclass.")

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(injected, *args, **kwargs)

    return wrapper
def get_free_gpu() -> int:
    """Return the index of the GPU that reports the most free memory.

    Free memory is queried through NVML for every index in
    ``range(torch.cuda.device_count())``. Ties are resolved in favour of the
    lowest index.

    NOTE(review): NVML indices and torch device indices may differ when the
    set of visible devices is restricted -- confirm for the target setup.

    Returns:
        The index of the selected device.

    Raises:
        RuntimeError: If torch reports no CUDA devices.
    """
    pynvml.nvmlInit()
    try:
        free_memory = []
        for index in range(torch.cuda.device_count()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(index)
            free_memory.append((index, pynvml.nvmlDeviceGetMemoryInfo(handle).free))
    finally:
        # Balance the nvmlInit() call above so the NVML handle is released.
        pynvml.nvmlShutdown()

    if not free_memory:
        raise RuntimeError("No CUDA devices are available.")
    # max() returns the first maximal element, i.e. the lowest index on ties.
    return max(free_memory, key=itemgetter(1))[0]
def get_user_name() -> str:
    """Return the user name, taken as the last component of ``$HOME``.

    Raises:
        KeyError: If the ``HOME`` environment variable is not set.
    """
    # ``.name`` rather than ``.stem``: the latter would strip everything after
    # the last dot and truncate names such as "john.doe" to "john".
    return Path(os.environ["HOME"]).name
def get_storage_dir() -> Path:
    """Return the current user's directory under ``/fsx``."""
    storage_root = Path("/fsx")
    return storage_root / get_user_name()
def get_checkpoints_dir(*, timestamp: bool) -> Path:
    """Return the checkpoints directory inside the user's storage directory.

    Args:
        timestamp: When true, append a sub-directory named after the current
            timestamp (see ``now``).
    """
    checkpoints = get_storage_dir() / "checkpoints"
    if not timestamp:
        return checkpoints
    return checkpoints / now()
def now() -> str:
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    # ``datetime.utcnow()`` is deprecated since Python 3.12; an aware UTC
    # datetime yields the same formatted string.
    return datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")
def read_jsonl(path: Path) -> list:
    """Read a JSON Lines file and return the decoded records.

    Empty and whitespace-only lines are skipped.

    Args:
        path: Location of the file (``Path`` or string).

    Returns:
        A list with one decoded JSON value per non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly instead of relying on the locale default, and
    # stream line by line rather than loading the whole file as one string.
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def write_jsonl(path: Path, data: list):
    """Write ``data`` to ``path`` as JSON Lines, one record per line.

    An existing file is overwritten.

    Args:
        path: Destination file (``Path`` or string).
        data: Records to serialise; each is encoded with ``json.dumps``.
    """
    target = Path(path)
    with target.open("w") as stream:
        stream.writelines(json.dumps(record) + "\n" for record in data)
def get_image(url: str, filename: Path | None = None, *, timeout: float = 30.0) -> Path:
    """Download the image at ``url`` and save it to ``filename``.

    If ``filename`` already exists nothing is downloaded and the existing
    path is returned.

    Args:
        url: Address of the image.
        filename: Destination path. Defaults to the last ``/``-separated
            segment of ``url`` with ``.jpg`` appended, relative to the
            current working directory.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The path of the (possibly pre-existing) image file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if filename is None:
        filename = Path(f"{url.split('/')[-1]}.jpg")
    filename = Path(filename)
    if filename.exists():
        return filename

    # The context manager closes the connection even if decoding fails; the
    # timeout prevents hanging forever on an unresponsive server.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail with a clear HTTP error instead of trying to decode an error
        # page as an image.
        response.raise_for_status()
        PIL.Image.open(response.raw).save(filename)
    return filename
def find_latest_checkpoint(dirname: Path) -> Path | None:
    """Return the most recently modified ``*.ckpt`` file in ``dirname``.

    Only direct children of ``dirname`` are considered.

    Args:
        dirname: Directory to search (``Path`` or string).

    Returns:
        The checkpoint path with the largest modification time, or ``None``
        when the directory contains no ``*.ckpt`` files.
    """
    checkpoints = Path(dirname).glob("*.ckpt")
    return max(checkpoints, key=lambda path: path.stat().st_mtime, default=None)
def list_files(dirname: Path, exts: list[str] | None = None) -> list:
    """List the entries of ``dirname``, optionally filtered by extension.

    Args:
        dirname: Directory to list (``Path`` or string); not recursive.
        exts: Extensions without the leading dot (e.g. ``["jpg", "png"]``).
            When empty or ``None`` every entry is returned.

    Returns:
        A list of paths in directory-iteration order. Each entry appears at
        most once, even if it matches several extensions.
    """
    entries = Path(dirname).iterdir()
    if not exts:
        return list(entries)
    patterns = [f"*.{ext}" for ext in exts]
    # ``any`` keeps an entry once; iterating extensions in the comprehension
    # itself would repeat it for every extension it matches.
    return [entry for entry in entries if any(entry.match(p) for p in patterns)]