Timber-identification-CNN / S2_TimberDataset.py
Yapp99's picture
Small changes
475c7dc
import torch
from torch import Tensor
from torchvision import transforms
import cv2
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import pandas as pd
# 3264 x 2448
# Root folder holding one sub-directory of images per timber species.
DATA_DIR = "data/image/train"

# os.listdir() returns entries in arbitrary order, so sort them to keep the
# species -> id mapping identical across runs and machines. Only directories
# are species classes; stray files (e.g. ".DS_Store") are ignored.
# NOTE(review): checkpoints trained with the previous unsorted mapping may use
# a different id order — confirm before reusing them.
labels = sorted(
    entry
    for entry in os.listdir(DATA_DIR)
    if os.path.isdir(os.path.join(DATA_DIR, entry))
)

# Maps each species name to its integer class id (its position in `labels`).
label2id = {label: index for index, label in enumerate(labels)}
def compile_image_df(
    data_dir: str, split_at: float = 0.9
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Index the images under *data_dir* and split them into train/val frames.

    *data_dir* is expected to contain one sub-directory per species, each
    holding that species' image files. The split is done per species, so each
    species contributes the same fraction of its images to the training set.

    Args:
        data_dir: Root folder containing one sub-directory per species.
        split_at: Fraction of each species' images that goes to the training
            frame; the remainder goes to the validation frame.

    Returns:
        A ``(train, val)`` tuple of DataFrames with the columns ``Image_ID``
        (path of the image file) and ``Species`` (name of its sub-directory).
    """
    columns = ['Image_ID', 'Species']
    train_rows: list[tuple[str, str]] = []
    val_rows: list[tuple[str, str]] = []

    # os.listdir() order is arbitrary; sorting both levels makes the split
    # reproducible between runs.
    for species in sorted(os.listdir(data_dir)):
        species_dir = f"{data_dir}/{species}"
        if not os.path.isdir(species_dir):
            # Stray files next to the species folders are not classes.
            continue

        rows = [
            (f"{species_dir}/{img}", species)
            for img in sorted(os.listdir(species_dir))
        ]
        train_count = int(len(rows) * split_at)
        train_rows.extend(rows[:train_count])
        val_rows.extend(rows[train_count:])

    # Build each frame once instead of concatenating inside the loop; this
    # also yields a clean 0..n-1 index.
    train = pd.DataFrame(train_rows, columns=columns)
    val = pd.DataFrame(val_rows, columns=columns)
    return train, val
class TimberDataset(Dataset):
def __init__(self,
dataframe: pd.DataFrame,
is_train=False,
transform=None,
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> None:
super().__init__()
self.dataframe = dataframe
self.is_train = is_train
self.transform = transform
self.device = device
def __len__(self) -> int:
return len(self.dataframe)
def __getitem__(self, idx: list[int]|Tensor):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.dataframe.iloc[idx,0])
image = cv2.imread(img_name)
image = Image.fromarray(image)
label = self.dataframe.iloc[idx,1]
label = label2id[label]
label = torch.tensor(int(label))
if self.transform:
image = self.transform(image)
return image.to(self.device), label.to(self.device)
def build_dataloader(
    train_ratio=0.9,
    img_size=(640, 640),
    batch_size=12,
) -> tuple[DataLoader, DataLoader]:
    """Build the training and validation DataLoaders for the images in DATA_DIR.

    Args:
        train_ratio: Fraction of each species' images used for training;
            passed to ``compile_image_df`` as ``split_at``.
        img_size: Target size every image is resized to before it is
            converted to a tensor.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    train_df, val_df = compile_image_df(DATA_DIR, split_at=train_ratio)

    # Both splits share the same preprocessing: resize, then convert to tensor.
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])

    train_loader = DataLoader(
        TimberDataset(train_df, is_train=True, transform=transform),
        shuffle=True,
        batch_size=batch_size,
    )
    # The validation split is not training data, so it is flagged
    # is_train=False (it was previously marked True by mistake).
    val_loader = DataLoader(
        TimberDataset(val_df, is_train=False, transform=transform),
        batch_size=batch_size,
    )

    return train_loader, val_loader