"""Manages all pregenerated and user-generated data."""
import hashlib
import os
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Any, Dict

from huggingface_hub import HfApi
from PIL import Image

from constants import NUMBERS_LIST, DataPath, SessionFiles
# Hub repository id passed as ``repo_id`` (with ``repo_type="dataset"``) when
# saved images are uploaded.
DATASET_REPO_ID = "RomanShp/MNIST-ResNet-Demo-Data"
class ImageData(ABC):
    """Interface for saving image data.

    Concrete subclasses decide where images live (``path``) and how a
    submitted image is persisted (``save_data``). The helpers defined here
    cover validation, preprocessing, writing a PNG to disk and uploading the
    written file to the dataset repository ``DATASET_REPO_ID``.
    """

    # Hub client created once at class-definition time and shared by all
    # instances and subclasses.
    api = HfApi()

    def __init__(self) -> None:
        """Perform no setup; ``path`` has to be assigned before it is read."""

    @property
    def path(self) -> Path:
        """Directory this data set reads from / writes to.

        Raises:
            AttributeError: if no path has been assigned yet.
        """
        return self._path

    @path.setter
    def path(self, value: Path) -> None:
        self._path = value

    @abstractmethod
    def save_data(self, img: Image.Image, img_class: str) -> None:
        """Persist ``img`` under the category ``img_class``.

        Must be implemented by every concrete subclass.
        """

    def _save_png_image(self, img: Image.Image, path: Path) -> Path:
        """Write ``img`` as a PNG with a generated name inside ``path``.

        Args:
            img: image to save.
            path: target directory; created (with parents) when missing.

        Returns:
            The path of the file that was written.
        """
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        path.mkdir(parents=True, exist_ok=True)
        file = path / f"{self._datetime_hash()}.png"
        img.save(file)
        return file

    def _preprocess_data_img(self, img) -> Image.Image:
        """Turn submitted image data into a 224x224 mode-"L" PIL image.

        ``img`` is handed to ``Image.fromarray``.
        NOTE(review): presumably a 2-D uint8 array from the sketch widget;
        confirm against the caller.
        """
        sketch_image = Image.fromarray(img, mode="L")
        return sketch_image.resize((224, 224), resample=Image.LANCZOS)

    def _check_submitted_data(self, img, img_cat: str) -> None:
        """Validate a submission before it is processed.

        Raises:
            ValueError: if ``img`` is ``None`` or ``img_cat`` is empty/``None``.
        """
        # Identity check only: truthiness of an array-like image is ambiguous.
        if img is None:
            raise ValueError("Image is empty")
        # Covers both "" and None, so a missing category fails here instead
        # of later when it is joined onto the target path.
        if not img_cat:
            raise ValueError("Category is empty")

    def _datetime_hash(self) -> str:
        """Returns current datetime as MD5 hash string"""
        # Used only to generate file names, not for anything security related.
        return hashlib.md5(str(datetime.now()).encode()).hexdigest()

    def _upload_to_hub(self, file: Path):
        """Upload ``file`` to the dataset repository ``DATASET_REPO_ID``.

        The location inside the repository is ``file`` relative to ``..``,
        with backslashes converted to forward slashes. The write token is
        read from the ``DS_WRITE`` environment variable (``None`` if unset).

        Raises:
            ValueError: if ``file`` is not located under ``..``
                (raised by ``Path.relative_to``).
        """
        repo_path = str(file.relative_to("..")).replace("\\", "/")
        self.api.upload_file(
            path_or_fileobj=file,
            path_in_repo=repo_path,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            token=os.environ.get("DS_WRITE"),
        )
class UserFinetuneData(ImageData):
    """User-submitted fine-tuning images.

    * path is defined by the session
    * keeps per-category statistics, refreshed whenever the path is assigned
      and incremented on every saved image
    * saves images
    """

    def __init__(self) -> None:
        """Start with a zero count for every category in ``NUMBERS_LIST``."""
        self.statistics = dict.fromkeys(NUMBERS_LIST, 0)

    @property
    def path(self) -> Path:
        """Directory holding this session's fine-tuning images."""
        # Same backing attribute the setter below writes.
        return self._path

    @path.setter
    def path(self, value: Path) -> None:
        # ``value`` is the session directory; images go into its user-data
        # sub-folder.
        self._path = value / SessionFiles.USER_DATA_FOLDER.value
        # TODO: Move statistics update out of non related code
        self.statistics = self._get_statistics()

    def _update_statistics(self, img_class: str) -> None:
        """Increment the counter of ``img_class`` by one.

        Raises:
            KeyError: if ``img_class`` is not a key of ``statistics``.
        """
        self.statistics[img_class] += 1

    def _get_statistics(self) -> Dict[str, int]:
        """Count the directory entries per category under ``path``.

        Returns:
            Mapping of every category in ``NUMBERS_LIST`` to the number of
            entries in its sub-directory, or 0 when that directory is absent.
        """
        result = {}
        for category in NUMBERS_LIST:
            category_dir = Path(self.path) / category
            if category_dir.is_dir():
                result[category] = len(os.listdir(category_dir))
            else:
                result[category] = 0
        return result

    def save_data(self, img: Image.Image, img_class: str) -> None:
        """Validate, preprocess, store, count and upload a submitted image.

        Args:
            img: submitted image data (passed on to ``_preprocess_data_img``).
            img_class: category the image belongs to; used as sub-directory
                name and statistics key.

        Raises:
            ValueError: if the submission fails ``_check_submitted_data``.
        """
        self._check_submitted_data(img, img_class)
        processed_img = self._preprocess_data_img(img)
        saved_file = self._save_png_image(processed_img, self._path / img_class)
        # Counted before the upload: the statistics mirror the local files.
        self._update_statistics(img_class)
        self._upload_to_hub(saved_file)
class UserPredictData(ImageData):
    """User-submitted images collected from predictions.

    * path is defined by the session
    * no statistics
    * saves images
    """

    def __init__(self) -> None:
        """Perform no setup; ``path`` has to be assigned before saving."""

    @property
    def path(self) -> Path:
        """Directory holding this session's prediction images."""
        # Same backing attribute the setter below writes.
        return self._path

    @path.setter
    def path(self, value: Path) -> None:
        # ``value`` is the session directory; images go into "data_predict".
        self._path = value / "data_predict"

    def save_data(self, img: Image.Image, img_class: str) -> None:
        """Validate, preprocess, store and upload a submitted image.

        Args:
            img: submitted image data (passed on to ``_preprocess_data_img``).
            img_class: category the image belongs to; used as sub-directory
                name.

        Raises:
            ValueError: if the submission fails ``_check_submitted_data``.
        """
        self._check_submitted_data(img, img_class)
        processed_img = self._preprocess_data_img(img)
        saved_file = self._save_png_image(processed_img, self._path / img_class)
        self._upload_to_hub(saved_file)
class PregeneratedFinetuneData(ImageData):
    """Pregenerated fine-tuning images.

    * predefined data path
    * no statistics
    * does not save images
    """

    def __init__(self) -> None:
        """Perform no setup; the path is fixed."""

    @property
    def path(self) -> Path:
        """Fixed location of the pregenerated data set."""
        return DataPath.PREGEN_DATA_PATH.value

    @path.setter
    def path(self, _: Path) -> None:
        # The location is fixed, so any assignment is rejected.
        raise AttributeError("Path is fixed for pregenerated data set")

    def save_data(self, *_: Any) -> None:
        """Reject every save request.

        Raises:
            AttributeError: always.
        """
        raise AttributeError("Saving data for pregenerated set is not supported")