Spaces:
Sleeping
Sleeping
| """ Manages all pregenerated and user-generated data""" | |
| import hashlib | |
| import os | |
| from abc import ABC, abstractmethod | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict | |
| from huggingface_hub import HfApi | |
| from PIL import Image | |
| from constants import NUMBERS_LIST, DataPath, SessionFiles | |
DATASET_REPO_ID = "RomanShp/MNIST-ResNet-Demo-Data"


class ImageData(ABC):
    """Interface for saving image data.

    Subclasses decide where images live (the ``path`` property) and how
    ``save_data`` stores them. This base class provides the shared helpers
    for validating submissions, preprocessing sketches, writing PNG files and
    uploading them to the Hugging Face dataset repo ``DATASET_REPO_ID``.
    """

    # Single Hub client shared by all instances; created when the class is
    # defined (i.e. at import time).
    api = HfApi()

    # Target (width, height) of preprocessed images; subclasses may override.
    IMAGE_SIZE = (224, 224)

    def __init__(self) -> None:
        pass

    @property
    def path(self) -> Path:
        """Directory this data set reads from / writes to.

        Raises AttributeError if read before it has been assigned.
        """
        return self._path

    @path.setter
    def path(self, value: Path) -> None:
        self._path = value

    @abstractmethod
    def save_data(self, img: Image.Image, img_class: str) -> None:
        """Store ``img`` under the category ``img_class``."""

    def _save_png_image(self, img: Image.Image, path: Path) -> Path:
        """Write ``img`` as ``<path>/<datetime-hash>.png``.

        Creates ``path`` (and parents) if needed and returns the file written.
        """
        # exist_ok=True avoids the race between an exists() check and mkdir()
        # when two saves target a new directory at the same time.
        path.mkdir(parents=True, exist_ok=True)
        file = path / f"{self._datetime_hash()}.png"
        img.save(file)
        return file

    def _preprocess_data_img(self, img) -> Image.Image:
        """Turn raw sketch data into a grayscale PIL image of ``IMAGE_SIZE``.

        NOTE(review): assumes ``img`` is an array accepted by
        ``Image.fromarray`` in mode "L" (presumably a 2-D uint8 canvas) —
        confirm against the caller.
        """
        sketch_image = Image.fromarray(img, mode="L")
        return sketch_image.resize(self.IMAGE_SIZE, resample=Image.LANCZOS)

    def _check_submitted_data(self, img, img_cat: str) -> None:
        """Validate a submission before anything is written.

        Raises:
            ValueError: if ``img`` is None or ``img_cat`` is empty/None.
        """
        if img is None:
            raise ValueError("Image is empty")
        # `not img_cat` also rejects None, which would otherwise slip through
        # and fail later with a TypeError when joined onto a Path.
        if not img_cat:
            raise ValueError("Category is empty")

    def _datetime_hash(self) -> str:
        """Return the current local datetime as an MD5 hex digest.

        Used only to derive a file name, not for any security purpose.
        """
        return hashlib.md5(str(datetime.now()).encode()).hexdigest()

    def _upload_to_hub(self, file: Path) -> None:
        """Upload ``file`` to the dataset repo, mirroring its local layout.

        The path inside the repo is ``file`` taken relative to ``..``.
        ``Path.relative_to`` is purely lexical, so ``file`` must start with a
        ``..`` component; otherwise ValueError is raised. The write token is
        read from the ``DS_WRITE`` environment variable (None when unset).
        """
        # as_posix() always yields "/" separators, which the Hub expects,
        # regardless of the OS the app runs on.
        repo_path = file.relative_to("..").as_posix()
        self.api.upload_file(
            path_or_fileobj=file,
            path_in_repo=repo_path,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            token=os.environ.get("DS_WRITE"),
        )
class UserFinetuneData(ImageData):
    """User-drawn fine-tuning images, stored per session.

    - ``path`` is set from the session directory, with
      ``SessionFiles.USER_DATA_FOLDER`` appended.
    - ``statistics`` maps each category in ``NUMBERS_LIST`` to the number of
      entries in that category's folder. It is recounted from disk when
      ``path`` is set and incremented on every save.
    - Saved images are also uploaded to the Hub.
    """

    def __init__(self) -> None:
        super().__init__()
        # One counter per category; all zero until a session path is set.
        self.statistics: Dict[str, int] = dict.fromkeys(NUMBERS_LIST, 0)

    @property
    def path(self) -> Path:
        """Session data directory (raises AttributeError until assigned)."""
        return self._path

    @path.setter
    def path(self, value: Path) -> None:
        self._path = value / SessionFiles.USER_DATA_FOLDER.value
        # TODO: Move statistics update out of non related code
        # Recount what already exists on disk for the new session directory.
        self.statistics = self._get_statistics()

    def _update_statistics(self, img_class: str) -> None:
        self.statistics[img_class] += 1

    def _get_statistics(self) -> Dict[str, int]:
        """Count the entries in each category sub-folder of the current path.

        Category folders that do not exist count as 0.
        """
        counts: Dict[str, int] = {}
        for cat_dir in NUMBERS_LIST:
            img_dir = self._path / cat_dir
            counts[cat_dir] = len(os.listdir(img_dir)) if img_dir.is_dir() else 0
        return counts

    def save_data(self, img: Image.Image, img_class: str) -> None:
        """Validate, preprocess, save, count and upload one user image.

        Raises:
            ValueError: if the image or category is empty, or the category is
                not one of the tracked categories (``NUMBERS_LIST``).
        """
        self._check_submitted_data(img, img_class)
        # Reject unknown categories before writing anything. Otherwise
        # _update_statistics would raise KeyError only after the PNG is
        # already on disk. This check also stops img_class from being used
        # as an arbitrary sub-path such as "../..".
        if img_class not in self.statistics:
            raise ValueError(f"Unknown category: {img_class!r}")
        processed_img = self._preprocess_data_img(img)
        saved_file = self._save_png_image(processed_img, self._path / img_class)
        self._update_statistics(img_class)
        self._upload_to_hub(saved_file)
class UserPredictData(ImageData):
    """Images submitted for prediction, stored per session.

    - ``path`` is the session directory with ``data_predict`` appended.
    - No statistics are kept.
    - Every saved image is also uploaded to the Hub.
    """

    def __init__(self) -> None:
        super().__init__()

    @property
    def path(self) -> Path:
        """Prediction data directory (raises AttributeError until assigned)."""
        return self._path

    @path.setter
    def path(self, value: Path) -> None:
        self._path = value / "data_predict"

    def save_data(self, img: Image.Image, img_class: str) -> None:
        """Validate, preprocess, save under ``<path>/<img_class>`` and upload.

        Raises:
            ValueError: if the image or category is empty.
        """
        self._check_submitted_data(img, img_class)
        processed_img = self._preprocess_data_img(img)
        saved_file = self._save_png_image(processed_img, self._path / img_class)
        self._upload_to_hub(saved_file)
class PregeneratedFinetuneData(ImageData):
    """Read-only, pregenerated fine-tuning data set.

    - The path is fixed to ``DataPath.PREGEN_DATA_PATH`` and cannot be
      reassigned.
    - No statistics are kept.
    - Saving new images is not supported.
    """

    def __init__(self) -> None:
        super().__init__()

    @property
    def path(self) -> Path:
        """Fixed location of the pregenerated data set."""
        return DataPath.PREGEN_DATA_PATH.value

    @path.setter
    def path(self, _: Path) -> None:
        # The data set location is fixed; assignment is an error.
        raise AttributeError("Path is fixed for pregenerated data set")

    def save_data(self, *_: Any) -> None:
        """Always raises: this data set is read-only.

        Raises:
            AttributeError: unconditionally.
        """
        raise AttributeError("Saving data for pregenerated set is not supported")