# Install

In [None]:
# Clone the kopyl/PixArt-alpha repository from GitHub into the current working directory.
!git clone https://github.com/kopyl/PixArt-alpha.git

In [None]:
# Make the cloned repository the working directory; later cells use paths relative to it
# (requirements.txt, tools/download.py, train_scripts/train.py).
%cd PixArt-alpha

In [None]:
# Install the pinned CUDA 11.7 builds of PyTorch, then the repository's requirements and wandb.
# %pip (rather than !pip) installs into the environment of the running kernel, so the
# packages are importable from this notebook even when several Python environments exist.
# requirements.txt is resolved relative to the working directory set by the %cd cell.
%pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117
%pip install -r requirements.txt
%pip install wandb

## Download model

In [None]:
# Fetch the PixArt-XL-2-512x512.pth checkpoint with the repository's download script.
# The script path is relative, so this must run from the repository root (see the %cd cell).
# NOTE(review): presumably saves under output/pretrained_models, which the feature-extraction
# cell passes as --pretrained_models_dir — confirm against tools/download.py.
!python tools/download.py --model_names "PixArt-XL-2-512x512.pth"

## Make dataset out of Hugging Face dataset

In [None]:
import os
from tqdm.notebook import tqdm
from datasets import load_dataset
import json

In [None]:
# Load the lambdalabs/pokemon-blip-captions dataset. The export cell below iterates its
# "train" split and reads the "image" and "text" fields of every item.
dataset = load_dataset("lambdalabs/pokemon-blip-captions")

In [None]:
# Export the loaded dataset to the on-disk layout consumed by the following cells:
#   <root_dir>/images/<n>.png            one image per sample
#   <root_dir>/captions/<n>.txt          the caption text of the sample
#   <root_dir>/partition/data_info.json  list of per-sample metadata records
root_dir = "/workspace/pixart-pokemon"
images_dir = "images"
captions_dir = "captions"

images_dir_absolute = os.path.join(root_dir, images_dir)
captions_dir_absolute = os.path.join(root_dir, captions_dir)

image_format = "png"
json_name = "partition/data_info.json"
absolute_json_name = os.path.join(root_dir, json_name)

# exist_ok=True creates any missing parents (including root_dir) and makes the cell
# safe to re-run, replacing the repeated "if not exists: makedirs" checks.
for directory in (
    images_dir_absolute,
    captions_dir_absolute,
    os.path.dirname(absolute_json_name),
):
    os.makedirs(directory, exist_ok=True)

# NOTE(review): every record is written as 512x512 with ratio 1 regardless of the actual
# image size; presumably this mirrors the --img_size 512 used for feature extraction — confirm.
width, height = 512, 512
ratio = 1

data_info = []

for order, item in enumerate(tqdm(dataset["train"])):
    image = item["image"]
    caption = item["text"]

    image.save(f"{images_dir_absolute}/{order}.{image_format}")

    # Explicit UTF-8 so the caption files do not depend on the platform's default encoding.
    with open(f"{captions_dir_absolute}/{order}.txt", "w", encoding="utf-8") as text_file:
        text_file.write(caption)

    data_info.append({
        "height": height,
        "width": width,
        "ratio": ratio,
        # Stored relative to root_dir, not as an absolute path.
        "path": f"images/{order}.{image_format}",
        "prompt": caption,
    })

with open(absolute_json_name, "w", encoding="utf-8") as json_file:
    json.dump(data_info, json_file)

## Extract features

In [None]:
# Run the repository's feature-extraction script over the exported dataset.
# The hard-coded paths must stay in sync with root_dir and json_name from the dataset
# export cell ("/workspace/pixart-pokemon" and "partition/data_info.json").
# NOTE(review): presumably writes T5 caption features and VAE image features into the two
# *_save_root directories — confirm against tools/extract_features.py.
!python /workspace/PixArt-alpha/tools/extract_features.py \
 --img_size 512 \
 --json_path "/workspace/pixart-pokemon/partition/data_info.json" \
 --t5_save_root "/workspace/pixart-pokemon/caption_feature_wmask" \
 --vae_save_root "/workspace/pixart-pokemon/img_vae_features" \
 --pretrained_models_dir "/workspace/PixArt-alpha/output/pretrained_models" \
 --dataset_root "/workspace/pixart-pokemon"

In [None]:
# Log in to Weights & Biases without storing the API key in the notebook source.
# The key is taken from the WANDB_API_KEY environment variable when it is set;
# otherwise it is requested through a hidden prompt. Never paste the key into this
# cell: notebook sources and outputs are easily committed or shared.
import os
from getpass import getpass

wandb_api_key = os.environ.get("WANDB_API_KEY") or getpass("W&B API key: ")
# IPython expands {wandb_api_key} before handing the command to the shell.
!wandb login {wandb_api_key}

## Train model

In [None]:
# Launch training through torch.distributed.launch with the Pokemon sample config.
# train_scripts/train.py and --work-dir are relative paths, so this cell must run from the
# repository root (see the %cd cell); the config file is referenced by absolute path.
# NOTE(review): --report_to/--loss_report_name presumably send the training loss to wandb
# under the name "train_loss" — confirm against train_scripts/train.py.
!python -m torch.distributed.launch \
 train_scripts/train.py \
 /workspace/PixArt-alpha/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py \
 --work-dir output/trained_model \
 --report_to="wandb" \
 --loss_report_name="train_loss"