# Timber-identification-CNN / S3_intermediateDataset.py
import os
import torch
from time import time
from torch import Tensor
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from S2_TimberDataset import build_dataloader
from typing import Callable
import numpy as np
# 3264 x 2448
def write_random_lowercase(n):
    """Return a random string of n lowercase ASCII letters, used as a unique file-name suffix."""
    min_lc = ord(b'a')
    len_lc = 26
    ba = bytearray(os.urandom(n))
    for i, b in enumerate(ba):
        ba[i] = min_lc + b % len_lc  # map 0..255 onto 97..122 ('a'..'z')
    return ba.decode("utf-8")


INTERMEDIATE_DIR = "data/intermediate"


class IntermediateDataset(Dataset):
    """Dataset of cached model outputs stored on disk as paired .pt (tensor) and .txt (label) files."""

    def __init__(self,
                 name,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> None:
        super().__init__()
        self.name = name
        self.device = device
        files = os.listdir(f"{INTERMEDIATE_DIR}/{self.name}")
        # Sort both lists so the n-th .pt file and the n-th .txt file share the same stem.
        self.tensors = np.array(sorted(f for f in files if os.path.splitext(f)[-1] == ".pt"))
        self.labels = np.array(sorted(f for f in files if os.path.splitext(f)[-1] == ".txt"))

    def __len__(self) -> int:
        return len(self.tensors)

    def __getitem__(self, idx: int):
        tensor = self.tensors[idx]
        labels = self.labels[idx]
        images = torch.load(f"{INTERMEDIATE_DIR}/{self.name}/{tensor}", map_location=self.device)
        with open(f"{INTERMEDIATE_DIR}/{self.name}/{labels}", 'r') as f:
            labels = f.readline().split("-")  # labels are stored as one '-'-separated line
        labels = Tensor(list(map(int, labels)))
        images = images.to(self.device)
        labels = labels.to(device=self.device, dtype=torch.int64)
        return images, labels

    @staticmethod
    def prepare_intermediate_dataset(pred: Callable, name: str, dataset: DataLoader, iterations=1) -> None:
        """Run `pred` over every batch of `dataset` and cache each output tensor and its labels to disk."""
        with torch.no_grad():
            for _ in range(iterations):
                for images, labels in tqdm(dataset):
                    out = pred(images)
                    labels = np.char.mod('%d', labels.cpu().numpy())
                    labels = '-'.join(labels)
                    # Timestamp plus a random suffix keeps file names unique across iterations.
                    file_name = f"{INTERMEDIATE_DIR}/{name}/{int(time())}_{write_random_lowercase(10)}"
                    torch.save(out, f"{file_name}.pt")
                    with open(f"{file_name}.txt", 'w') as f:
                        f.write(labels)
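

# A minimal usage sketch, not called anywhere in this file: cache the activations of a frozen
# torchvision backbone so later training only has to run a classifier head. The ResNet-18
# backbone, the torchvision dependency, and the "resnet_features" cache name are assumptions
# for illustration, not part of this project.
def _example_cache_backbone_features(loader: DataLoader) -> None:
    from torchvision import models  # assumed dependency, imported lazily on purpose
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    backbone = models.resnet18(weights="IMAGENET1K_V1").to(device).eval()
    # Drop the final fully connected layer so only pooled [B, 512, 1, 1] features are cached.
    feature_extractor = torch.nn.Sequential(*list(backbone.children())[:-1])
    IntermediateDataset.prepare_intermediate_dataset(
        lambda x: feature_extractor(x.to(device)), "resnet_features", loader)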


def build_intermediate_dataset_if_not_exists(pred_: Callable, name: str, dataset: DataLoader) -> None:
    """Build the cached dataset under INTERMEDIATE_DIR/name unless that folder already has files."""
    os.makedirs(f"{INTERMEDIATE_DIR}/{name}", exist_ok=True)
    if not os.listdir(f"{INTERMEDIATE_DIR}/{name}"):
        IntermediateDataset.prepare_intermediate_dataset(pred_, name, dataset)


def intermediate_dataset(name: str) -> DataLoader:
    # batch_size=1 because every cached .pt file already holds a full pre-batched tensor.
    return DataLoader(IntermediateDataset(name=name), batch_size=1)
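

# A minimal usage sketch, also not called anywhere in this file: train a small classifier head
# on the cached activations. The 512-feature width and 10 classes are hypothetical values, and
# squeeze(0) undoes the extra batch dimension that the batch_size=1 loader adds.
def _example_train_head_on_cache(name: str = "resnet_features") -> None:
    head = torch.nn.Linear(512, 10)
    optimizer = torch.optim.SGD(head.parameters(), lr=1e-2)
    for feats, labels in intermediate_dataset(name):
        logits = head(feats.squeeze(0).flatten(1).cpu())
        loss = torch.nn.functional.cross_entropy(logits, labels.squeeze(0).cpu())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()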


if __name__ == '__main__':
    # Smoke test: build a tiny cache with an identity "model" and read one item back.
    train, val, test = build_dataloader(train_ratio=0.01)
    build_intermediate_dataset_if_not_exists(lambda x: x, "testing", train)
    train_loader = DataLoader(IntermediateDataset("testing"), batch_size=1)
    images, labels = next(iter(train_loader))