import os
import urllib.request
import zipfile
from pathlib import Path
from typing import TypedDict

import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset


class PokemonSample(TypedDict):
text: torch.Tensor # Text already tokenized
image: torch.Tensor
description: str # Text before tokenization
pokemon_name: str
idx: int
    attention_mask: torch.Tensor


def reporthook(block_num, block_size, total_size):
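    """urlretrieve progress callback: prints rough progress every 16384 blocks."""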
if block_num % 16384 == 0:
print(f"Downloading... {block_num * block_size / (1024 * 1024):.2f} MB")
def download_dataset_if_not_exists():
dataset_dir = "dataset"
pokedex_main_dir = os.path.join(dataset_dir, "pokedex-main")
zip_url = "https://github.com/cristobalmitchell/pokedex/archive/refs/heads/main.zip"
zip_path = "pokedex_main.zip"
if os.path.exists(pokedex_main_dir):
print(f"{pokedex_main_dir} already exists. Skipping download.")
return
os.makedirs(dataset_dir, exist_ok=True)
print("Downloading dataset...")
urllib.request.urlretrieve(zip_url, zip_path, reporthook)
print("Download complete.")
print("Extracting dataset...")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(dataset_dir)
print("Extraction complete.")
    os.remove(zip_path)


class PokemonDataset(Dataset):
def __init__(
self,
tokenizer,
csv_path="dataset/pokedex-main/data/pokemon.csv",
image_dir="dataset/pokedex-main/images/small_images",
max_length=128,
augmentation_transforms=None,
):
        self.df = pd.read_csv(csv_path, encoding="utf-16-le", delimiter="\t")
        self.image_dir = Path(image_dir)
        print(f"Dataset loaded: {len(self.df)} Pokemon with descriptions and images")
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Base pipeline: tensor conversion and resize; optional augmentations
        # run after the resize, then Normalize maps the image to [-1, 1].
        transform_steps = [
            transforms.ToTensor(),
            transforms.Resize((256, 256), antialias=True),
        ]
        if augmentation_transforms is not None:
            transform_steps.append(augmentation_transforms)
        transform_steps.append(
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        )
        self.final_transform = transforms.Compose(transform_steps)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int) -> PokemonSample:
        # Fetch the row for this index
        row = self.df.iloc[idx]

        # === TEXT PREPROCESSING ===
        description = str(row["description"])

        # Tokenize the description
        encoded = self.tokenizer(
            description,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Extract token ids and attention mask, dropping the batch dimension
        text_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # === IMAGE LOADING AND PREPROCESSING ===
        # Build the image path from the zero-padded national number
        image_filename = f"{row['national_number']:03d}.png"
        image_path = self.image_dir / image_filename

        # Load the image, keeping the alpha channel
        image_rgba = Image.open(image_path).convert("RGBA")

        # Handle transparency: composite the sprite onto a white background
        background = Image.new("RGB", image_rgba.size, (255, 255, 255))
        background.paste(image_rgba, mask=image_rgba.split()[-1])

        # Apply the final transforms (ToTensor, Resize, Normalize)
        image_tensor = self.final_transform(background)

        # Build the result following the PokemonSample structure
        sample: PokemonSample = {
            "text": text_ids,
            "image": image_tensor,
            "description": description,
            "pokemon_name": row["english_name"],
            "idx": idx,
            "attention_mask": attention_mask,
        }
        return sample
download_dataset_if_not_exists()
print("Dataset ready!")