| | import torch |
| | import os |
| | import functools |
| | import base64 |
| | import numpy as np |
| | import gradio as gr |
| |
|
| | from typing import Any, Callable, Dict |
| |
|
| |
|
def load_state_dict(ckpt_path, location="cpu"):
    """
    Load a checkpoint file from disk and return its state dict.

    Paths whose extension is ``.safetensors`` (case-insensitive) are read with
    ``safetensors.torch.load_file``; every other path goes through
    ``torch.load``. The loaded object is then passed through
    ``get_state_dict``, which replaces it with its "state_dict" entry when
    one exists.

    Args:
        ckpt_path: Path to the checkpoint file.
        location: Device string the tensors are loaded onto, e.g. "cpu".

    Returns:
        The unwrapped state dict.
    """
    suffix = os.path.splitext(ckpt_path)[1].lower()

    if suffix != ".safetensors":
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from trusted sources.
        checkpoint = get_state_dict(
            torch.load(ckpt_path, map_location=torch.device(location))
        )
    else:
        # Imported lazily so safetensors is only needed for .safetensors files.
        import safetensors.torch

        checkpoint = safetensors.torch.load_file(ckpt_path, device=location)

    # For torch.load checkpoints this is a second unwrap, so a doubly nested
    # "state_dict" entry is peeled off as well.
    result = get_state_dict(checkpoint)
    print(f"Loaded state_dict from [{ckpt_path}]")
    return result
| |
|
| |
|
def get_state_dict(d):
    """
    Unwrap a checkpoint mapping.

    Returns ``d["state_dict"]`` when that key is present, otherwise ``d``
    itself (the same object, not a copy).
    """
    wrapper_key = "state_dict"
    unwrapped = d.get(wrapper_key, d)
    return unwrapped
| |
|
| |
|
def ndarray_lru_cache(max_size: int = 128, typed: bool = False):
    """
    Decorator enabling LRU caching for functions with numpy array arguments.

    Numpy arrays are mutable and therefore unhashable, so they cannot be used
    directly as ``functools.lru_cache`` keys. Every ``np.ndarray`` argument
    (positional or keyword) is wrapped in a private ``HashableNpArray`` view.
    The view hashes and compares by the array's shape, dtype and raw bytes.
    The decorated function itself still receives plain ``np.ndarray`` views,
    so element-wise operators such as ``==`` behave normally inside it.

    Args:
        max_size: Maximum number of cached results (``lru_cache`` maxsize).
        typed: Forwarded unchanged to ``functools.lru_cache``.

    Returns:
        A decorator. The decorated function also exposes ``cache_info`` and
        ``cache_clear`` from the underlying ``lru_cache``.

    Notes:
        - Cached return values are shared between calls. Mutating a returned
          array in place also changes what later cache hits return.
        - Cache keys are views of the caller's arrays (no copy is made), so
          avoid mutating argument arrays in place after a call.
    """

    def decorator(func: Callable):
        """The actual decorator that accepts the function to cache."""

        class HashableNpArray(np.ndarray):
            """ndarray view that hashes/compares by (shape, dtype, bytes)."""

            def __new__(cls, input_array):
                # A view, not a copy: wrapping shares the caller's memory.
                return np.asarray(input_array).view(cls)

            def __eq__(self, other):
                if not isinstance(other, np.ndarray):
                    return NotImplemented
                # Shape and dtype are compared first: identical bytes with a
                # different dtype (e.g. uint8 vs int8) or shape must not be
                # treated as the same key. Byte equality keeps __eq__
                # consistent with __hash__.
                return (
                    self.shape == other.shape
                    and self.dtype == other.dtype
                    and self.tobytes() == other.tobytes()
                )

            def __ne__(self, other):
                result = self.__eq__(other)
                return result if result is NotImplemented else not result

            def __hash__(self):
                return hash((self.shape, self.dtype, self.tobytes()))

        def wrap(item: Any):
            """Wrap ndarrays so they can be part of an lru_cache key."""
            return HashableNpArray(item) if isinstance(item, np.ndarray) else item

        def unwrap(item: Any):
            """Turn a HashableNpArray back into a plain ndarray view."""
            if isinstance(item, HashableNpArray):
                return item.view(np.ndarray)
            return item

        @functools.lru_cache(maxsize=max_size, typed=typed)
        def cached_func(*args, **kwargs):
            """Cache layer; receives HashableNpArray keys, calls func."""
            # Unwrap before calling func so neither func nor the arrays it
            # derives from its inputs inherit the scalar __eq__ override.
            return func(
                *[unwrap(arg) for arg in args],
                **{k: unwrap(v) for k, v in kwargs.items()},
            )

        @functools.wraps(func)
        def decorated_func(*args, **kwargs):
            """The decorated function that delegates to the cached call."""
            return cached_func(
                *[wrap(arg) for arg in args],
                **{k: wrap(v) for k, v in kwargs.items()},
            )

        decorated_func.cache_info = cached_func.cache_info
        decorated_func.cache_clear = cached_func.cache_clear
        return decorated_func

    return decorator
| |
|
| |
|
| | |
# Optional SVG rasterization support: svglib parses SVG into a ReportLab
# drawing and renderPM renders it. If any of these imports fail, SVG inputs
# are simply not converted.
try:
    import io
    from svglib.svglib import svg2rlg
    from reportlab.graphics import renderPM
except ImportError:
    svgsupport = False
else:
    svgsupport = True
| |
|
| |
|
def svg_preprocess(inputs: Dict, preprocess: Callable):
    """
    Rasterize an SVG data URL to PNG before passing ``inputs`` to ``preprocess``.

    If ``inputs["image"]`` is a ``data:image/svg+xml;base64,...`` string and
    the optional SVG dependencies were importable, the SVG is rendered to PNG.
    ``inputs["image"]`` is then replaced in place with a
    ``data:image/png;base64,...`` URL. In every other case ``inputs`` is
    passed through unchanged.

    Args:
        inputs: Mapping carrying the image under the "image" key (presumably
            a gradio image payload; verify against callers).
        preprocess: Callable applied to the (possibly updated) ``inputs``.

    Returns:
        None when ``inputs`` is falsy, otherwise ``preprocess(inputs)``.

    Raises:
        ValueError: If the SVG payload cannot be parsed (``svg2rlg`` returned
            None).
    """
    if not inputs:
        return None

    svg_prefix = "data:image/svg+xml;base64,"
    image = inputs.get("image")
    # Only string data URLs can carry SVG. Anything else (e.g. None) is left
    # for `preprocess` to handle instead of failing on `.startswith`.
    if isinstance(image, str) and image.startswith(svg_prefix) and svgsupport:
        svg_data = base64.b64decode(image[len(svg_prefix):])
        drawing = svg2rlg(io.BytesIO(svg_data))
        if drawing is None:
            # svg2rlg returns None on parse failure; fail with a clear error
            # rather than an AttributeError from inside renderPM.
            raise ValueError("Unable to parse SVG image data")
        png_data = renderPM.drawToString(drawing, fmt="PNG")
        encoded = base64.b64encode(png_data).decode("ascii")
        inputs["image"] = "data:image/png;base64," + encoded
    return preprocess(inputs)
| |
|
| |
|