MNIST-ResNet-Demo / src /data_manager.py
shpr
Add env var for uploading to hub
7aaff91
""" Manages all pregenerated and user-generated data"""
import hashlib
import os
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Any, Dict
from huggingface_hub import HfApi
from PIL import Image
from constants import NUMBERS_LIST, DataPath, SessionFiles
# Hugging Face Hub repository id that saved images are uploaded to; it is
# passed to the upload call together with repo_type="dataset".
DATASET_REPO_ID = "RomanShp/MNIST-ResNet-Demo-Data"
class ImageData(ABC):
    """Interface for saving image data.

    Concrete subclasses provide the ``path`` property and ``save_data``.
    The private helpers defined here cover input validation, image
    preprocessing, writing PNG files to disk and uploading them to the
    Hub dataset repository named by ``DATASET_REPO_ID``.
    """

    # Hub client created once at class-definition time and shared by all
    # subclasses and instances.
    api = HfApi()

    @abstractmethod
    def __init__(self) -> None:
        pass

    @property
    @abstractmethod
    def path(self) -> Path:
        """Return the directory backing this data set.

        Subclasses may delegate here through ``super().path``. ``_path`` is
        only assigned by a ``path`` setter, so reading it before any setter
        ran raises ``AttributeError``.
        """
        return self._path

    @path.setter
    @abstractmethod
    def path(self, value: Path) -> None:
        self._path = value

    @abstractmethod
    def save_data(self, img: Image.Image, img_class: str) -> None:
        """Persist ``img`` under the category ``img_class``."""

    def _save_png_image(self, img: Image.Image, path: Path) -> Path:
        """Write ``img`` as a PNG file inside ``path``.

        Args:
            img: Image to write.
            path: Target directory; created (with parents) when missing.

        Returns:
            Path of the file that was written.
        """
        # exist_ok=True removes the window between an existence check and
        # the mkdir in which another save could create the same directory
        # and make mkdir raise FileExistsError.
        path.mkdir(parents=True, exist_ok=True)
        file = path / f"{self._datetime_hash()}.png"
        img.save(file)
        return file

    def _preprocess_data_img(self, img) -> Image.Image:
        """Turn a raw array into a 224x224 mode-"L" image.

        Args:
            img: Array-like accepted by ``Image.fromarray`` with mode "L"
                (presumably a 2-D uint8 sketch array -- confirm with caller).

        Returns:
            The image resized to 224x224 using Lanczos resampling.
        """
        sketch_image = Image.fromarray(img, mode="L")
        sketch_image = sketch_image.resize((224, 224), resample=Image.LANCZOS)
        return sketch_image

    def _check_submitted_data(self, img, img_cat: str) -> None:
        """Validate a submission before it is processed.

        Args:
            img: Submitted image data.
            img_cat: Submitted category label.

        Raises:
            ValueError: If ``img`` is ``None`` or ``img_cat`` is ``None``
                or an empty string.
        """
        if img is None:
            raise ValueError("Image is empty")
        # A missing category (None) is rejected here too; otherwise it would
        # only fail later, when the category is joined onto the target path,
        # with an unrelated TypeError.
        if img_cat is None or img_cat == "":
            raise ValueError("Category is empty")

    def _datetime_hash(self) -> str:
        """Returns current datetime as MD5 hash string"""
        # The digest only serves as a file name, not as a security measure.
        return hashlib.md5(str(datetime.now()).encode()).hexdigest()

    def _upload_to_hub(self, file: Path) -> None:
        """Upload ``file`` to the dataset repository on the Hub.

        The path inside the repository is ``file`` relative to ``".."`` with
        backslashes turned into forward slashes, so Windows-style paths map
        onto the same repository layout.

        Args:
            file: Local file to upload.

        Raises:
            ValueError: From ``Path.relative_to`` when ``file`` is not
                located under ``".."``.
        """
        repo_path = str(file.relative_to("..")).replace("\\", "/")
        self.api.upload_file(
            path_or_fileobj=file,
            path_in_repo=repo_path,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            # Write token comes from the DS_WRITE environment variable;
            # None is passed when the variable is unset.
            token=os.environ.get("DS_WRITE"),
        )
class UserFinetuneData(ImageData):
    """
    User-submitted fine-tuning images.

    The storage path is derived from the session path, per-category image
    counts are tracked in ``statistics``, and submitted images are saved
    to disk and uploaded.
    """

    def __init__(self) -> None:
        # Every known category starts with a count of zero.
        self.statistics = dict.fromkeys(NUMBERS_LIST, 0)

    @property
    def path(self) -> Path:
        """Return the user-data directory set through the setter."""
        return super().path

    @path.setter
    def path(self, value: Path) -> None:
        """Point at the user-data folder under ``value`` and recount images."""
        self._path = value / SessionFiles.USER_DATA_FOLDER.value
        # TODO: Move statistics update out of non related code
        self.statistics = self._get_statistics()

    def _update_statistics(self, img_class: str) -> None:
        """Increment the stored count for ``img_class`` by one."""
        self.statistics[img_class] += 1

    def _get_statistics(self) -> Dict[str, int]:
        """Count the directory entries of every category folder.

        Returns:
            Mapping of each entry of ``NUMBERS_LIST`` to the number of
            entries in its folder under ``path``; 0 when the folder does
            not exist.
        """

        def entries_in(category: str) -> int:
            folder = os.path.join(self.path, category)
            return len(os.listdir(folder)) if os.path.isdir(folder) else 0

        return {category: entries_in(category) for category in NUMBERS_LIST}

    def save_data(self, img: Image.Image, img_class: str) -> None:
        """Validate, preprocess, store and upload one submitted image.

        The count for ``img_class`` is incremented after the file has been
        written and before the upload starts.
        """
        self._check_submitted_data(img, img_class)
        written = self._save_png_image(
            self._preprocess_data_img(img),
            self._path / img_class,
        )
        self._update_statistics(img_class)
        self._upload_to_hub(written)
class UserPredictData(ImageData):
    """
    Images submitted by the user for prediction.

    The storage path is derived from the session path, no statistics are
    kept, and submitted images are saved to disk and uploaded.
    """

    def __init__(self) -> None:
        """Nothing to initialise; the path is assigned later via the setter."""

    @property
    def path(self) -> Path:
        """Return the prediction-data directory set through the setter."""
        return super().path

    @path.setter
    def path(self, value: Path) -> None:
        """Point at the ``data_predict`` folder under ``value``."""
        self._path = value / "data_predict"

    def save_data(self, img: Image.Image, img_class: str) -> None:
        """Validate, preprocess, store and upload one submitted image."""
        self._check_submitted_data(img, img_class)
        written = self._save_png_image(
            self._preprocess_data_img(img),
            self._path / img_class,
        )
        self._upload_to_hub(written)
class PregeneratedFinetuneData(ImageData):
    """
    Read-only pregenerated fine-tuning data.

    The path is predefined, no statistics are kept, and saving images is
    not supported.
    """

    def __init__(self) -> None:
        """Nothing to initialise; the path is a fixed constant."""

    @property
    def path(self) -> Path:
        """Return the fixed location of the pregenerated data set."""
        return DataPath.PREGEN_DATA_PATH.value

    @path.setter
    def path(self, _: Path) -> None:
        """Reject any attempt to change the path.

        Raises:
            AttributeError: Always.
        """
        message = "Path is fixed for pregenerated data set"
        raise AttributeError(message)

    def save_data(self, *_: Any) -> None:
        """Reject any attempt to save data.

        Raises:
            AttributeError: Always.
        """
        message = "Saving data for pregenerated set is not supported"
        raise AttributeError(message)