import os
import urllib.request
import zipfile
from pathlib import Path
from typing import TypedDict

import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

class PokemonSample(TypedDict):
    text: torch.Tensor  # Text already tokenized
    image: torch.Tensor
    description: str  # Text before tokenization
    pokemon_name: str
    idx: int
    attention_mask: torch.Tensor

def reporthook(block_num, block_size, total_size):
    # Print download progress roughly every 16384 blocks
    if block_num % 16384 == 0:
        print(f"Downloading... {block_num * block_size / (1024 * 1024):.2f} MB")

def download_dataset_if_not_exists():
    dataset_dir = "dataset"
    pokedex_main_dir = os.path.join(dataset_dir, "pokedex-main")
    zip_url = "https://github.com/cristobalmitchell/pokedex/archive/refs/heads/main.zip"
    zip_path = "pokedex_main.zip"

    if os.path.exists(pokedex_main_dir):
        print(f"{pokedex_main_dir} already exists. Skipping download.")
        return

    os.makedirs(dataset_dir, exist_ok=True)
    print("Downloading dataset...")
    urllib.request.urlretrieve(zip_url, zip_path, reporthook)
    print("Download complete.")

    print("Extracting dataset...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(dataset_dir)
    print("Extraction complete.")

    os.remove(zip_path)

class PokemonDataset(Dataset):
    """Pokemon descriptions paired with images: text is tokenized, images are
    resized to 256x256 and normalized to [-1, 1]."""

    def __init__(
        self,
        tokenizer,
        csv_path="dataset/pokedex-main/data/pokemon.csv",
        image_dir="dataset/pokedex-main/images/small_images",
        max_length=128,
        augmentation_transforms=None,
    ):
        # The CSV is UTF-16 LE encoded and tab-separated
        self.df = pd.read_csv(csv_path, encoding="utf-16-le", delimiter="\t")
        self.image_dir = Path(image_dir)
        print(f"Dataset loaded: {len(self.df)} Pokemon with descriptions and images")

        self.tokenizer = tokenizer
        self.max_length = max_length

        # Final transform: tensor conversion, resize, optional augmentation,
        # then normalization to [-1, 1]
        transform_steps = [
            transforms.ToTensor(),
            transforms.Resize((256, 256), antialias=True),
        ]
        if augmentation_transforms is not None:
            transform_steps.append(augmentation_transforms)
        transform_steps.append(
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
        )
        self.final_transform = transforms.Compose(transform_steps)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int) -> PokemonSample:
        # Get the corresponding row
        row = self.df.iloc[idx]

        # === TEXT PREPROCESSING ===
        description = str(row["description"])

        # Tokenize the text
        encoded = self.tokenizer(
            description,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Extract token ids and attention mask
        text_ids = encoded["input_ids"].squeeze(0)  # Drop the batch dimension
        attention_mask = encoded["attention_mask"].squeeze(0)

        # === IMAGE LOADING AND PREPROCESSING ===
        # Build the image path from the national Pokedex number
        image_filename = f"{row['national_number']:03d}.png"
        image_path = self.image_dir / image_filename

        # Load the image
        image_rgba = Image.open(image_path).convert("RGBA")

        # Handle transparency: composite the image onto a white background
        background = Image.new("RGB", image_rgba.size, (255, 255, 255))
        background.paste(image_rgba, mask=image_rgba.split()[-1])

        # Apply the final transforms (ToTensor, Resize, Normalize)
        image_tensor = self.final_transform(background)

        # Build the sample (matches pokemon_dataset.py structure)
        sample = {
            "text": text_ids,
            "image": image_tensor,
            "description": description,
            "pokemon_name": row["english_name"],
            "idx": idx,
            "attention_mask": attention_mask,
        }
        return sample

download_dataset_if_not_exists()
print("Dataset ready!")