Upload 39 files
- examples/.DS_Store +0 -0
- examples/README.md +101 -0
- examples/dpo/train.sh +13 -0
- examples/dpo/train_config.yaml +41 -0
- examples/dpo/train_dataset.py +190 -0
- examples/dpo/train_dpo.py +465 -0
- examples/edit/train.sh +13 -0
- examples/edit/train_config.yaml +46 -0
- examples/edit/train_dataset.py +184 -0
- examples/edit/train_edit.py +460 -0
- examples/lora/train.sh +13 -0
- examples/lora/train_config.yaml +48 -0
- examples/lora/train_dataset.py +197 -0
- examples/lora/train_lora.py +461 -0
- examples/sft/train.sh +13 -0
- examples/sft/train_config.yaml +46 -0
- examples/sft/train_dataset.py +197 -0
- examples/sft/train_sft.py +433 -0
- longcat_image/__init__.py +0 -0
- longcat_image/dataset/__init__.py +2 -0
- longcat_image/dataset/data_utils.py +34 -0
- longcat_image/dataset/sampler.py +111 -0
- longcat_image/models/__init__.py +1 -0
- longcat_image/models/longcat_image_dit.py +231 -0
- longcat_image/pipelines/__init__.py +2 -0
- longcat_image/pipelines/pipeline_longcat_image.py +576 -0
- longcat_image/pipelines/pipeline_longcat_image_edit.py +512 -0
- longcat_image/pipelines/pipeline_output.py +20 -0
- longcat_image/utils/__init__.py +5 -0
- longcat_image/utils/dist_utils.py +33 -0
- longcat_image/utils/log_buffer.py +41 -0
- longcat_image/utils/model_utils.py +293 -0
- misc/__init__.py +0 -0
- misc/accelerate_config.yaml +21 -0
- misc/prompt_rewrite_api.py +113 -0
- requirements.txt +48 -0
- scripts/inference_edit.py +35 -0
- scripts/inference_t2i.py +42 -0
- setup.py +11 -0
examples/.DS_Store
ADDED
Binary file (6.15 kB).
examples/README.md
ADDED

## Training Examples

### 1. Training example for SFT or LoRA

- Data format

You need to create a jsonl file where each line is a JSON object containing the keys in the table below (an example line follows the table):

| key_word | Required | Description | Example |
|:---------------:|:--------:|:----------------:|:-----------:|
| `img_path` | Required | image path | `./data_example/images/0.png` |
| `prompt` | Required | text prompt | `A lovely little girl.` |
| `width` | Required | image width | `1024` |
| `height` | Required | image height | `1024` |
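For illustration, one line of such a jsonl file might look like this (the path and prompt below are just the placeholder values from the table):

```json
{"img_path": "./data_example/images/0.png", "prompt": "A lovely little girl.", "width": 1024, "height": 1024}
```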
- Training scripts

```bash
bash ./examples/sft/train.sh

# All training settings are in train_config.yaml

# --data_csv_root: path to the training data file
# --aspect_ratio_type: data bucketing strategy, one of mar_256, mar_512, mar_1024
# --pretrained_model_name_or_path: root directory of the model
# --diffusion_pretrain_weight: if a diffusion weight path is provided, load the transformer weights from that path instead of the default under pretrained_model_name_or_path
# --work_dir: root directory where checkpoints and logs are saved
# --resume_from_checkpoint: if set to 'latest', load the most recent checkpoint under work_dir; if a specific directory is provided, resume training from that directory
```
### 2. Training example for DPO

- Data format

You need to create a text file where each line is a JSON object containing the keys in the table below (an example line follows the table):

| key_word | Required | Description | Example |
|:---------------:|:--------:|:----------------:|:-----------:|
| `img_path_win` | Required | preferred (win) image path | `./data_example/images/0.png` |
| `img_path_lose` | Required | rejected (lose) image path | `./data_example/images/1.png` |
| `prompt` | Required | text prompt | `A lovely little girl.` |
| `width` | Required | image width | `1024` |
| `height` | Required | image height | `1024` |
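For illustration, one line of the preference file might look like this (placeholder paths from the table; `img_path_win` is the preferred image, `img_path_lose` the rejected one):

```json
{"img_path_win": "./data_example/images/0.png", "img_path_lose": "./data_example/images/1.png", "prompt": "A lovely little girl.", "width": 1024, "height": 1024}
```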
- Training scripts

```bash
bash ./examples/dpo/train.sh

# All training settings are in train_config.yaml

# --data_txt_root: path to the training data file
# --aspect_ratio_type: data bucketing strategy, one of mar_256, mar_512, mar_1024
# --pretrained_model_name_or_path: root directory of the model
# --diffusion_pretrain_weight: if a diffusion weight path is provided, load the transformer weights from that path instead of the default under pretrained_model_name_or_path
# --work_dir: root directory where checkpoints and logs are saved
# --resume_from_checkpoint: if set to 'latest', load the most recent checkpoint under work_dir; if a specific directory is provided, resume training from that directory
```
### 3. Training example for image editing

- Data format

You need to create a text file where each line is a JSON object containing the keys in the table below (an example line follows the table):

| key_word | Required | Description | Example |
|:---------------:|:--------:|:----------------:|:-----------:|
| `img_path` | Required | edited (target) image path | `./data_example/images/0_edited.png` |
| `ref_img_path` | Required | raw (source) image path | `./data_example/images/0.png` |
| `prompt` | Required | edit instruction | `change the dog to cat.` |
| `width` | Required | image width | `1024` |
| `height` | Required | image height | `1024` |
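For illustration, one line of the editing data file might look like this (placeholder paths from the table; `img_path` is the edited target, `ref_img_path` the original input):

```json
{"img_path": "./data_example/images/0_edited.png", "ref_img_path": "./data_example/images/0.png", "prompt": "change the dog to cat.", "width": 1024, "height": 1024}
```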
- Training scripts

```bash
bash ./examples/edit/train.sh

# All training settings are in train_config.yaml

# --data_txt_root: path to the training data file
# --aspect_ratio_type: data bucketing strategy, one of mar_256, mar_512, mar_1024
# --pretrained_model_name_or_path: root directory of the model
# --diffusion_pretrain_weight: if a diffusion weight path is provided, load the transformer weights from that path instead of the default under pretrained_model_name_or_path
# --work_dir: root directory where checkpoints and logs are saved
# --resume_from_checkpoint: if set to 'latest', load the most recent checkpoint under work_dir; if a specific directory is provided, resume training from that directory
```
examples/dpo/train.sh
ADDED

export TOKENIZERS_PARALLELISM=False
export NCCL_DEBUG=INFO
export NCCL_TIMEOUT=12000

script_dir=$(cd -- "$(dirname -- "$0")" &> /dev/null && pwd -P)
project_root=$(dirname "$(dirname "$script_dir")")
echo "script_dir" ${script_dir}

deepspeed_config_file=${project_root}/misc/accelerate_config.yaml

accelerate launch --mixed_precision bf16 --num_processes 8 --config_file ${deepspeed_config_file} \
    ${script_dir}/train_dpo.py \
    --config ${script_dir}/train_config.yaml
examples/dpo/train_config.yaml
ADDED

# 1. Data setting
data_txt_root: '/dataset/example/train_data_info.txt' # path to the training data file (one JSON object per line)
resolution: 1024
aspect_ratio_type: 'mar_1024' # data bucketing strategy, one of mar_256, mar_512, mar_1024
null_text_ratio: 0.1
dataloader_num_workers: 8
train_batch_size: 4

# 2. Model setting
text_tokenizer_max_length: 512 # tokenizer max length
pretrained_model_name_or_path: "/xxx/weights/Longcat-Image-Dev" # root directory of the model, containing vae, transformer, scheduler, etc.
diffusion_pretrain_weight: null # if a diffusion weight path is provided, load the transformer weights from that path
use_dynamic_shifting: true # scheduler dynamic shifting
resume_from_checkpoint: latest
# - "latest"              # loads the most recent step checkpoint
# - "/path/to/checkpoint" # resumes from the specified directory

# 3. Training setting
use_ema: False
ema_rate: 0.999
mixed_precision: 'bf16'
max_train_steps: 100000
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_clip: 1.0
learning_rate: 5.0e-6
adam_weight_decay: 1.0e-3
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.999
lr_num_cycles: 1
lr_power: 1.0
lr_scheduler: 'constant'
lr_warmup_steps: 1000
beta_dpo: 2000

# 4. Log setting
log_interval: 20
save_model_steps: 1000
work_dir: 'output/sft_model'
seed: 43
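The keys above are not consumed by a dedicated config class; `train_dpo.py` below merges the YAML over the parsed CLI arguments, so every key becomes an attribute of `args`. A minimal sketch of that flow (it assumes the YAML file is this `train_config.yaml`):

```python
import argparse
import yaml

# Minimal sketch of how train_dpo.py turns the YAML keys into args.<key> attributes.
args = argparse.Namespace(config='train_config.yaml', report_to='tensorboard')
with open(args.config, 'r') as f:
    config = yaml.safe_load(f)      # dict of the keys defined above
args_dict = vars(args)
args_dict.update(config)            # YAML values override / extend the CLI defaults
args = argparse.Namespace(**args_dict)
print(args.learning_rate, args.beta_dpo)
```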
examples/dpo/train_dataset.py
ADDED

from typing import Any, Callable, Dict, List, Optional, Union
import os
import random
import traceback
import math
import json
import numpy as np

import torch
import torch.distributed as dist
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer
from PIL import Image
from tqdm import tqdm

from longcat_image.dataset import MULTI_RESOLUTION_MAP
from longcat_image.utils import encode_prompt
from longcat_image.dataset import MultiResolutionDistributedSampler

Image.MAX_IMAGE_PIXELS = 2000000000

MAX_RETRY_NUMS = 100

class DpoPairDataSet(torch.utils.data.Dataset):
    def __init__(self,
                 cfg: dict,
                 txt_root: str,
                 tokenizer: AutoTokenizer,
                 resolution: tuple = (1024, 1024)):
        super(DpoPairDataSet, self).__init__()
        self.resolution = resolution
        self.text_tokenizer_max_length = cfg.text_tokenizer_max_length
        self.null_text_ratio = cfg.null_text_ratio
        self.aspect_ratio_type = cfg.aspect_ratio_type
        self.aspect_ratio = MULTI_RESOLUTION_MAP[self.aspect_ratio_type]
        self.tokenizer = tokenizer

        self.total_datas = []
        self.data_resolution_infos = []
        with open(txt_root, 'r') as f:
            lines = f.readlines()
            for line in tqdm(lines):
                data = json.loads(line.strip())
                try:
                    height, width = int(data['height']), int(data['width'])
                    self.data_resolution_infos.append((height, width))
                    self.total_datas.append(data)
                except Exception as e:
                    print(f'get error {e}, data {data}.')
                    continue
        self.data_nums = len(self.total_datas)

    def transform_img(self, image, original_size, target_size):
        img_h, img_w = original_size
        target_height, target_width = target_size

        original_aspect = img_h / img_w  # height/width
        crop_aspect = target_height / target_width

        if original_aspect >= crop_aspect:
            resize_width = target_width
            resize_height = math.ceil(img_h * (target_width / img_w))
        else:
            resize_width = math.ceil(img_w * (target_height / img_h))
            resize_height = target_height

        image = T.Compose([
            T.Resize((resize_height, resize_width), interpolation=InterpolationMode.BICUBIC),  # Image.LANCZOS
            T.CenterCrop((target_height, target_width)),
            T.ToTensor(),
            T.Normalize([.5], [.5]),
        ])(image)

        return image

    def __getitem__(self, index_tuple):
        index, target_size = index_tuple

        for _ in range(MAX_RETRY_NUMS):
            try:
                item = self.total_datas[index]
                img_path_win = item["img_path_win"]
                img_path_lose = item["img_path_lose"]

                prompt = item['prompt']

                if random.random() < self.null_text_ratio:
                    prompt = ''

                raw_image_win = Image.open(img_path_win).convert('RGB')
                raw_image_lose = Image.open(img_path_lose).convert('RGB')
                assert raw_image_win is not None and raw_image_lose is not None
                img_w, img_h = raw_image_win.size

                raw_image_win = self.transform_img(raw_image_win, original_size=(img_h, img_w), target_size=target_size)
                raw_image_lose = self.transform_img(raw_image_lose, original_size=(img_h, img_w), target_size=target_size)

                input_ids, attention_mask = encode_prompt(prompt, self.tokenizer, self.text_tokenizer_max_length)

                return {"image_win": raw_image_win, "image_lose": raw_image_lose, "prompt": prompt, 'input_ids': input_ids, 'attention_mask': attention_mask}

            except Exception as e:
                traceback.print_exc()
                print(f"failed read data {e}!!!")
                index = random.randint(0, self.data_nums - 1)

    def __len__(self):
        return self.data_nums

    def collate_fn(self, batchs):
        images_win = torch.stack([example["image_win"] for example in batchs])
        images_lose = torch.stack([example["image_lose"] for example in batchs])
        input_ids = torch.stack([example["input_ids"] for example in batchs])
        attention_mask = torch.stack([example["attention_mask"] for example in batchs])
        prompts = [example['prompt'] for example in batchs]
        batch_dict = {
            "images_win": images_win,
            "images_lose": images_lose,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompts": prompts,
        }
        return batch_dict


def build_dataloader(cfg: dict,
                     csv_root: str,
                     tokenizer: AutoTokenizer,
                     resolution: tuple = (1024, 1024)):
    dataset = DpoPairDataSet(cfg, csv_root, tokenizer, resolution)

    sampler = MultiResolutionDistributedSampler(batch_size=cfg.train_batch_size, dataset=dataset,
                                                data_resolution_infos=dataset.data_resolution_infos,
                                                bucket_info=dataset.aspect_ratio,
                                                epoch=0,
                                                num_replicas=None,
                                                rank=None
                                                )

    train_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        batch_size=cfg.train_batch_size,
        num_workers=cfg.dataloader_num_workers,
        sampler=sampler,
        shuffle=None,
    )
    return train_loader


if __name__ == '__main__':
    import sys
    import argparse
    from torchvision.transforms.functional import to_pil_image

    txt_root = 'xxxx'

    cfg = argparse.Namespace(
        txt_root=txt_root,
        text_tokenizer_max_length=256,
        resolution=1024,
        text_encoder_path="xxx",
        center_crop=True,
        dataloader_num_workers=0,
        null_text_ratio=0.1,
        train_batch_size=16,
        seed=0,
        aspect_ratio_type='mar_1024',
        revision=None)

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.text_encoder_path, trust_remote_code=True)
    data_loader = build_dataloader(cfg, cfg.txt_root, tokenizer, cfg.resolution)

    _oroot = f'./debug_data_example_show'
    os.makedirs(_oroot, exist_ok=True)

    cnt = 0
    for epoch in range(1):
        print(f"Start, epoch {epoch}!!!")
        for i_batch, batch in enumerate(data_loader):
            print(batch['attention_mask'].shape)
            print(batch['images_win'].shape, '-', batch['images_lose'].shape,)

            if cnt > 100:
                break
            cnt += 1
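A note on the `(index, target_size)` unpacking in `__getitem__` above: the `MultiResolutionDistributedSampler` (defined in `longcat_image/dataset/sampler.py`, not included in this upload) appears to yield index/bucket pairs rather than plain integers, so that every sample in a batch is resized and cropped to the same resolution bucket. A hypothetical sketch of one sampler element:

```python
# Hypothetical: what a single element produced by the sampler is assumed to look like.
sampler_element = (42, (1024, 1024))   # (dataset index, (bucket_height, bucket_width))
index, target_size = sampler_element   # mirrors `index, target_size = index_tuple` above
print(index, target_size)
```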
examples/dpo/train_dpo.py
ADDED

import os
import sys
import time
import warnings

import argparse
import yaml
import torch
import math
import logging
import transformers
import diffusers
import torch.nn.functional as F

from pathlib import Path
from transformers import Qwen2Model, Qwen2TokenizerFast
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate.logging import get_logger
from diffusers.models import AutoencoderKL
from diffusers.optimization import get_scheduler
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import EMAModel
from diffusers.utils.import_utils import is_xformers_available
from transformers import AutoTokenizer, AutoModel
from train_dataset import build_dataloader
from longcat_image.models import LongCatImageTransformer2DModel
from longcat_image.utils import LogBuffer
from longcat_image.utils import pack_latents, unpack_latents, calculate_shift, prepare_pos_ids

warnings.filterwarnings("ignore")  # ignore warnings

current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))

logger = get_logger(__name__)

def train(global_step=0):

    # Train!
    total_batch_size = args.train_batch_size * \
        accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataloader.dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    last_tic = time.time()

    # Now you train the model
    for epoch in range(first_epoch, args.num_train_epochs + 1):
        data_time_start = time.time()
        data_time_all = 0

        for step, batch in enumerate(train_dataloader):
            images_win = batch['images_win']
            images_lose = batch['images_lose']
            half_batch_size = images_win.shape[0]
            image = torch.concat([images_win, images_lose], dim=0)

            data_time_all += time.time() - data_time_start

            with torch.no_grad():
                latents = vae.encode(image.to(weight_dtype).to(accelerator.device)).latent_dist.sample()
                latents = latents.to(dtype=(weight_dtype))
                latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor

            text_input_ids = batch['input_ids'].to(accelerator.device)
            text_attention_mask = batch['attention_mask'].to(accelerator.device)

            with torch.no_grad():
                text_output = text_encoder(
                    input_ids=text_input_ids,
                    attention_mask=text_attention_mask,
                    output_hidden_states=True
                )
                prompt_embeds = text_output.hidden_states[-1].clone().detach()

            prompt_embeds = prompt_embeds.to(weight_dtype)
            prompt_embeds = torch.concat([prompt_embeds, prompt_embeds], dim=0)  # [batch_size, 256, 4096]
            prompt_embeds = prompt_embeds[:, args.prompt_template_encode_start_idx: -args.prompt_template_encode_end_idx, :]

            # Sample a random timestep for each image
            grad_norm = None
            with accelerator.accumulate(transformer):
                # Predict the noise residual
                optimizer.zero_grad()
                # logit-normal
                sigmas = torch.sigmoid(torch.randn((half_batch_size,), device=accelerator.device, dtype=latents.dtype))

                if args.use_dynamic_shifting:
                    sigmas = noise_scheduler.time_shift(mu, 1.0, sigmas)

                sigmas = torch.concat([sigmas, sigmas], dim=0)
                timesteps = sigmas * 1000.0
                sigmas = sigmas.view(-1, 1, 1, 1)

                noise = torch.randn_like(latents)
                noise = noise.chunk(2)[0].repeat(2, 1, 1, 1)

                noisy_latents = (1 - sigmas) * latents + sigmas * noise
                noisy_latents = noisy_latents.to(weight_dtype)

                packed_noisy_latents = pack_latents(
                    noisy_latents,
                    batch_size=latents.shape[0],
                    num_channels_latents=latents.shape[1],
                    height=latents.shape[2],
                    width=latents.shape[3],
                )

                guidance = None
                img_ids = prepare_pos_ids(modality_id=1,
                                          type='image',
                                          start=(prompt_embeds.shape[1], prompt_embeds.shape[1]),
                                          height=latents.shape[2] // 2,
                                          width=latents.shape[3] // 2).to(accelerator.device, dtype=torch.float64)

                timesteps = (
                    torch.tensor(timesteps)
                    .expand(noisy_latents.shape[0])
                    .to(device=accelerator.device)
                    / 1000
                )
                text_ids = prepare_pos_ids(modality_id=0,
                                           type='text',
                                           start=(0, 0),
                                           num_token=prompt_embeds.shape[1]).to(accelerator.device, torch.float64)

                with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
                    model_pred = transformer(packed_noisy_latents, prompt_embeds, timesteps,
                                             img_ids, text_ids, guidance, return_dict=False)[0]

                model_pred = unpack_latents(
                    model_pred,
                    height=latents.shape[2] * 8,
                    width=latents.shape[3] * 8,
                    vae_scale_factor=16,
                )

                target = noise - latents
                model_losses = torch.mean(
                    ((model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    1,
                )
                model_loss_win, model_loss_lose = model_losses.chunk(2)
                model_diff = model_loss_win - model_loss_lose

                with torch.no_grad():
                    ref_pred = ref_model(packed_noisy_latents, prompt_embeds, timesteps,
                                         img_ids, text_ids, guidance, return_dict=False)[0]

                    ref_pred = unpack_latents(
                        ref_pred,
                        height=latents.shape[2] * 8,
                        width=latents.shape[3] * 8,
                        vae_scale_factor=16,
                    )

                    ref_losses = torch.mean(
                        ((ref_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                        1,
                    )
                    ref_loss_win, ref_loss_lose = ref_losses.chunk(2)
                    ref_diff = (ref_loss_win - ref_loss_lose)

                inside_term = -args.beta_dpo * (model_diff - ref_diff)
                # implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
                loss = -F.logsigmoid(inside_term).mean()

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    grad_norm = transformer.get_global_grad_norm()

                optimizer.step()
                if not accelerator.optimizer_step_was_skipped:
                    lr_scheduler.step()

                if accelerator.sync_gradients and args.use_ema:
                    model_ema.step(transformer.parameters())

            lr = lr_scheduler.get_last_lr()[0]

            if accelerator.sync_gradients:
                global_step += 1

                bsz, ic, ih, iw = image.shape
                logs = {"loss": accelerator.gather(loss).mean().item(), 'aspect_ratio': (ih * 1.0 / iw)}
                if grad_norm is not None:
                    logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())

                log_buffer.update(logs)
                if (step + 1) % args.log_interval == 0 or (step + 1) == 1:
                    t = (time.time() - last_tic) / args.log_interval
                    t_d = data_time_all / args.log_interval

                    log_buffer.average()
                    info = f"Step={step+1}, Epoch={epoch}, global_step={global_step}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:(ch:{latents.shape[1]},h:{latents.shape[2]},w:{latents.shape[3]}), "
                    info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
                    logger.info(info)
                    last_tic = time.time()
                    log_buffer.clear()
                    data_time_all = 0
                logs.update(lr=lr)
                accelerator.log(logs, step=global_step)
            data_time_start = time.time()

            if global_step != 0 and global_step % args.save_model_steps == 0:
                save_path = os.path.join(
                    args.work_dir, f'checkpoints-{global_step}')
                if args.use_ema:
                    model_ema.store(transformer.parameters())
                    model_ema.copy_to(transformer.parameters())

                accelerator.save_state(save_path)

                if args.use_ema:
                    model_ema.restore(transformer.parameters())
                logger.info(f"Saved state to {save_path} (global_step: {global_step})")
            accelerator.wait_for_everyone()


def parse_args():
    parser = argparse.ArgumentParser(description="Process some integers.")
    parser.add_argument("--config", type=str, default='', help="config")
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    cur_dir = os.path.dirname(os.path.abspath(__file__))

    args = parse_args()

    if args.config != '' and os.path.exists(args.config):
        config = yaml.safe_load(open(args.config, 'r'))
    else:
        config = yaml.safe_load(open(f'{cur_dir}/train_config.yaml', 'r'))

    args_dict = vars(args)
    args_dict.update(config)
    args = argparse.Namespace(**args_dict)

    os.umask(0o000)
    os.makedirs(args.work_dir, exist_ok=True)

    log_dir = args.work_dir + f'/logs'
    os.makedirs(log_dir, exist_ok=True)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.work_dir, logging_dir=log_dir)

    with open(f'{log_dir}/train.yaml', 'w') as f:
        yaml.dump(args_dict, f)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    logger.info(f'using weight_dtype {weight_dtype}!!!')

    if args.diffusion_pretrain_weight:
        transformer = LongCatImageTransformer2DModel.from_pretrained(args.diffusion_pretrain_weight, ignore_mismatched_sizes=False)
        ref_model = LongCatImageTransformer2DModel.from_pretrained(args.diffusion_pretrain_weight, ignore_mismatched_sizes=False)
        logger.info(f'successfully loaded model weights for BaseModel and RefModel, {args.diffusion_pretrain_weight}!!!')
    else:
        transformer = LongCatImageTransformer2DModel.from_pretrained(os.path.join(args.pretrained_model_name_or_path, "transformer"), ignore_mismatched_sizes=False)
        ref_model = LongCatImageTransformer2DModel.from_pretrained(os.path.join(args.pretrained_model_name_or_path, "transformer"), ignore_mismatched_sizes=False)
        logger.info(f'successfully loaded model weights for BaseModel and RefModel, {args.pretrained_model_name_or_path+"/transformer"}!!!')

    transformer = transformer.train()
    ref_model.requires_grad_(False)
    ref_model.to(accelerator.device, weight_dtype)

    total_trainable_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
    logger.info(f">>>>>> total_trainable_params: {total_trainable_params}")

    if args.use_ema:
        model_ema = EMAModel(transformer.parameters(), decay=args.ema_rate)
    else:
        model_ema = None

    vae_dtype = torch.float32
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype).cuda().eval()

    text_encoder = AutoModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype, trust_remote_code=True).cuda().eval()
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", torch_dtype=weight_dtype, trust_remote_code=True)
    logger.info("all models loaded successfully")

    # build models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler")

    latent_size = int(args.resolution) // 8
    mu = calculate_shift(
        (latent_size // 2) ** 2,
        noise_scheduler.config.base_image_seq_len,
        noise_scheduler.config.max_image_seq_len,
        noise_scheduler.config.base_shift,
        noise_scheduler.config.max_shift,
    )

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for i, model in enumerate(models):
                model.save_pretrained(os.path.join(output_dir, "transformer"))
                if len(weights) != 0:
                    weights.pop()

    def load_model_hook(models, input_dir):
        while len(models) > 0:
            # pop models so that they are not loaded again
            model = models.pop()
            # load diffusers style into model
            load_model = LongCatImageTransformer2DModel.from_pretrained(
                input_dir, subfolder="transformer")
            model.register_to_config(**load_model.config)
            model.load_state_dict(load_model.state_dict())
            del load_model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = transformer.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    transformer.to(accelerator.device, dtype=weight_dtype)
    if args.use_ema:
        model_ema.to(accelerator.device, dtype=weight_dtype)

    train_dataloader = build_dataloader(args, args.data_txt_root, tokenizer, args.resolution,)
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    global_step = 0
    first_epoch = 0
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.work_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
        if path is None:
            logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            logger.info(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.work_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        try:
            accelerator.init_trackers('dpo', tracker_config)
        except Exception as e:
            logger.warning(f'get error in save config, {e}')
            accelerator.init_trackers(f"dpo_{timestamp}")

    transformer, optimizer, _, _ = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler)

    log_buffer = LogBuffer()
    train(global_step=global_step)
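For reference, the preference objective computed above is a Diffusion-DPO style loss, `-log sigmoid(-beta_dpo * ((l_w - l_l) - (l_w_ref - l_l_ref)))`, where each `l` is a per-sample flow-matching MSE. A minimal sketch with dummy numbers (the values are made up purely to show how the sign of the difference drives the loss):

```python
import torch
import torch.nn.functional as F

beta_dpo = 2000.0  # matches beta_dpo in train_config.yaml

# Dummy per-sample MSE losses (lower = better reconstruction of the flow target).
model_loss_win, model_loss_lose = torch.tensor([0.10]), torch.tensor([0.30])
ref_loss_win, ref_loss_lose = torch.tensor([0.20]), torch.tensor([0.25])

model_diff = model_loss_win - model_loss_lose   # negative when the trained model favors the winner
ref_diff = ref_loss_win - ref_loss_lose
inside_term = -beta_dpo * (model_diff - ref_diff)
loss = -F.logsigmoid(inside_term).mean()        # near zero when the model beats the reference on the winner
print(loss.item())
```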
examples/edit/train.sh
ADDED

export TOKENIZERS_PARALLELISM=False
export NCCL_DEBUG=INFO
export NCCL_TIMEOUT=12000

script_dir=$(cd -- "$(dirname -- "$0")" &> /dev/null && pwd -P)
project_root=$(dirname "$(dirname "$script_dir")")
echo "script_dir" ${script_dir}

deepspeed_config_file=${project_root}/misc/accelerate_config.yaml

accelerate launch --mixed_precision bf16 --num_processes 8 --config_file ${deepspeed_config_file} \
    ${script_dir}/train_edit.py \
    --config ${script_dir}/train_config.yaml
examples/edit/train_config.yaml
ADDED

# 1. Data setting
data_txt_root: '/dataset/example/train_data_info.txt' # path to the training data file (one JSON object per line)
resolution: 1024
aspect_ratio_type: 'mar_1024' # data bucketing strategy, one of mar_256, mar_512, mar_1024
null_text_ratio: 0.1
dataloader_num_workers: 8
train_batch_size: 4
repeats: 1

prompt_template_encode_prefix: "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
prompt_template_encode_suffix: '<|im_end|>\n<|im_start|>assistant\n'
prompt_template_encode_start_idx: 67
prompt_template_encode_end_idx: 5

# 2. Model setting
text_tokenizer_max_length: 512 # tokenizer max length
pretrained_model_name_or_path: "./weights/LongCat-Image-Edit" # root directory of the model, containing vae, transformer, scheduler, etc.
diffusion_pretrain_weight: null # if a diffusion weight path is provided, load the transformer weights from that path
use_dynamic_shifting: true # scheduler dynamic shifting
resume_from_checkpoint: latest
# - "latest"              # loads the most recent step checkpoint
# - "/path/to/checkpoint" # resumes from the specified directory

# 3. Training setting
use_ema: False
ema_rate: 0.999
mixed_precision: 'bf16'
max_train_steps: 100000
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_clip: 1.0
learning_rate: 1.0e-5
adam_weight_decay: 1.0e-2
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.999
lr_num_cycles: 1
lr_power: 1.0
lr_scheduler: 'constant'
lr_warmup_steps: 1000

# 4. Log setting
log_interval: 20
save_model_steps: 1000
work_dir: 'output/edit_model'
seed: 43
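The four `prompt_template_encode_*` keys control how the edit instruction is wrapped in a Qwen-VL style chat template before text encoding, and how many template tokens are later trimmed from the encoder hidden states (the DPO training script above applies the analogous slice `prompt_embeds[:, start_idx:-end_idx, :]`). A rough sketch of the wrapping, with the long system prompt shortened here for readability:

```python
# Rough sketch: how the template keys above are assumed to be combined around the instruction
# inside encode_prompt_edit (longcat_image/utils, not shown in this upload).
prefix = ("<|im_start|>system\n...<|im_end|>\n"
          "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>")  # shortened
suffix = "<|im_end|>\n<|im_start|>assistant\n"
instruction = "change the dog to cat."

full_text = prefix + instruction + suffix   # assumed input to the tokenizer / text encoder
# prompt_template_encode_start_idx / _end_idx then drop the template tokens from the
# resulting hidden states, keeping roughly the instruction span.
print(full_text)
```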
examples/edit/train_dataset.py
ADDED

from typing import Any, Callable, Dict, List, Optional, Union
import os
import random
import traceback
import math
import json
import numpy as np

import torch
import torch.distributed as dist
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer, AutoProcessor
from PIL import Image
from tqdm import tqdm

from longcat_image.dataset import MULTI_RESOLUTION_MAP
from longcat_image.utils import encode_prompt_edit
from longcat_image.dataset import MultiResolutionDistributedSampler

Image.MAX_IMAGE_PIXELS = 2000000000

MAX_RETRY_NUMS = 100

class Text2ImageLoraDataSet(torch.utils.data.Dataset):
    def __init__(self,
                 cfg: dict,
                 txt_root: str,
                 tokenizer: AutoTokenizer,
                 text_processor: AutoProcessor,
                 resolution: tuple = (1024, 1024),
                 repeats: int = 1):
        super(Text2ImageLoraDataSet, self).__init__()
        self.resolution = resolution
        self.text_tokenizer_max_length = cfg.text_tokenizer_max_length
        self.null_text_ratio = cfg.null_text_ratio
        self.aspect_ratio_type = cfg.aspect_ratio_type
        self.aspect_ratio = MULTI_RESOLUTION_MAP[self.aspect_ratio_type]
        self.tokenizer = tokenizer
        self.image_processor_vl = text_processor.image_processor

        self.prompt_template_encode_prefix = cfg.prompt_template_encode_prefix
        self.prompt_template_encode_suffix = cfg.prompt_template_encode_suffix
        self.prompt_template_encode_start_idx = cfg.prompt_template_encode_start_idx
        self.prompt_template_encode_end_idx = cfg.prompt_template_encode_end_idx

        self.total_datas = []
        self.data_resolution_infos = []
        with open(txt_root, 'r') as f:
            lines = f.readlines()
            lines *= cfg.repeats
            for line in tqdm(lines):
                data = json.loads(line.strip())
                try:
                    height, width = int(data['height']), int(data['width'])
                    self.data_resolution_infos.append((height, width))
                    self.total_datas.append(data)
                except Exception as e:
                    print(f'get error {e}, data {data}.')
                    continue
        self.data_nums = len(self.total_datas)
        print(f'get sampler {len(self.total_datas)}, from {txt_root}!!!')

    def transform_img(self, image, original_size, target_size):
        img_h, img_w = original_size
        target_height, target_width = target_size

        original_aspect = img_h / img_w  # height/width
        crop_aspect = target_height / target_width

        if original_aspect >= crop_aspect:
            resize_width = target_width
            resize_height = math.ceil(img_h * (target_width / img_w))
        else:
            resize_width = math.ceil(img_w * (target_height / img_h))
            resize_height = target_height

        image = T.Compose([
            T.Resize((resize_height, resize_width), interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop((target_height, target_width)),
            T.ToTensor(),
            T.Normalize([.5], [.5]),
        ])(image)

        return image

    def transform_img_vl(self, image, original_size, target_size):
        img_h, img_w = original_size
        target_height, target_width = target_size

        original_aspect = img_h / img_w  # height/width
        crop_aspect = target_height / target_width

        if original_aspect >= crop_aspect:
            resize_width = target_width
            resize_height = math.ceil(img_h * (target_width / img_w))
        else:
            resize_width = math.ceil(img_w * (target_height / img_h))
            resize_height = target_height

        image = T.Compose([
            T.Resize((resize_height, resize_width), interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop((target_height, target_width)),
            T.Resize((target_height // 2, target_width // 2)),
        ])(image)

        return image

    def __getitem__(self, index_tuple):
        index, target_size = index_tuple

        for _ in range(MAX_RETRY_NUMS):
            try:
                item = self.total_datas[index]
                img_path = item["img_path"]
                ref_img_path = item["ref_img_path"]
                prompt = item['prompt']

                if random.random() < self.null_text_ratio:
                    prompt = ''

                raw_image = Image.open(img_path).convert('RGB')
                ref_image = Image.open(ref_img_path).convert('RGB')
                assert raw_image is not None
                img_w, img_h = raw_image.size

                ref_image_vl = self.transform_img_vl(ref_image, original_size=(img_h, img_w), target_size=target_size)
                raw_image = self.transform_img(raw_image, original_size=(img_h, img_w), target_size=target_size)
                ref_image = self.transform_img(ref_image, original_size=(img_h, img_w), target_size=target_size)

                input_ids, attention_mask, pixel_values, image_grid_thw = encode_prompt_edit(prompt, ref_image_vl, self.tokenizer, self.image_processor_vl, self.text_tokenizer_max_length, self.prompt_template_encode_prefix, self.prompt_template_encode_suffix)
                return {"image": raw_image, "ref_image": ref_image, "prompt": prompt, 'input_ids': input_ids, 'attention_mask': attention_mask, 'pixel_values': pixel_values, 'image_grid_thw': image_grid_thw}

            except Exception as e:
                traceback.print_exc()
                print(f"failed read data {e}!!!")
                index = random.randint(0, self.data_nums - 1)

    def __len__(self):
        return self.data_nums

    def collate_fn(self, batchs):
        images = torch.stack([example["image"] for example in batchs])
        ref_images = torch.stack([example["ref_image"] for example in batchs])
        input_ids = torch.stack([example["input_ids"] for example in batchs])
        attention_mask = torch.stack([example["attention_mask"] for example in batchs])
        pixel_values = torch.stack([example["pixel_values"] for example in batchs])
        image_grid_thw = torch.stack([example["image_grid_thw"] for example in batchs])
        prompts = [example['prompt'] for example in batchs]
        batch_dict = {
            "images": images,
            "ref_images": ref_images,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompts": prompts,
            "pixel_values": pixel_values,
            "image_grid_thw": image_grid_thw
        }
        return batch_dict


def build_dataloader(cfg: dict,
                     csv_root: str,
                     tokenizer: AutoTokenizer,
                     text_processor: AutoProcessor,
                     resolution: tuple = (1024, 1024)):
    dataset = Text2ImageLoraDataSet(cfg, csv_root, tokenizer, text_processor, resolution)

    sampler = MultiResolutionDistributedSampler(batch_size=cfg.train_batch_size, dataset=dataset,
                                                data_resolution_infos=dataset.data_resolution_infos,
                                                bucket_info=dataset.aspect_ratio,
                                                epoch=0,
                                                num_replicas=None,
                                                rank=None
                                                )

    train_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        batch_size=cfg.train_batch_size,
        num_workers=cfg.dataloader_num_workers,
        sampler=sampler,
        shuffle=None,
    )
    return train_loader
examples/edit/train_edit.py
ADDED
@@ -0,0 +1,460 @@
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import warnings
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import yaml
|
| 8 |
+
import torch
|
| 9 |
+
import math
|
| 10 |
+
import logging
|
| 11 |
+
import transformers
|
| 12 |
+
import diffusers
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from transformers import Qwen2Model, Qwen2TokenizerFast
|
| 15 |
+
from accelerate import Accelerator, InitProcessGroupKwargs
|
| 16 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 17 |
+
from accelerate.logging import get_logger
|
| 18 |
+
from diffusers.models import AutoencoderKL
|
| 19 |
+
from diffusers.optimization import get_scheduler
|
| 20 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 21 |
+
from diffusers.training_utils import EMAModel
|
| 22 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 23 |
+
from transformers import AutoTokenizer, AutoModel, AutoProcessor
|
| 24 |
+
|
| 25 |
+
from train_dataset import build_dataloader
|
| 26 |
+
from longcat_image.models import LongCatImageTransformer2DModel
|
| 27 |
+
from longcat_image.utils import LogBuffer
|
| 28 |
+
from longcat_image.utils import pack_latents, unpack_latents, calculate_shift, prepare_pos_ids
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
warnings.filterwarnings("ignore") # ignore warning
|
| 32 |
+
|
| 33 |
+
current_file_path = Path(__file__).resolve()
|
| 34 |
+
sys.path.insert(0, str(current_file_path.parent.parent))
|
| 35 |
+
|
| 36 |
+
logger = get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
def train(global_step=0):
|
| 39 |
+
|
| 40 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 41 |
+
|
| 42 |
+
# Train!
|
| 43 |
+
total_batch_size = args.train_batch_size * \
|
| 44 |
+
accelerator.num_processes * args.gradient_accumulation_steps
|
| 45 |
+
|
| 46 |
+
logger.info("***** Running training *****")
|
| 47 |
+
logger.info(f" Num examples = {len(train_dataloader.dataset)}")
|
| 48 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 49 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 50 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 51 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 52 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 53 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 54 |
+
|
| 55 |
+
last_tic = time.time()
|
| 56 |
+
|
| 57 |
+
# Now you train the model
|
| 58 |
+
for epoch in range(first_epoch, args.num_train_epochs + 1):
|
| 59 |
+
data_time_start = time.time()
|
| 60 |
+
data_time_all = 0
|
| 61 |
+
|
| 62 |
+
for step, batch in enumerate(train_dataloader):
|
| 63 |
+
image = batch['images']
|
| 64 |
+
ref_image = batch['ref_images']
|
| 65 |
+
|
| 66 |
+
data_time_all += time.time() - data_time_start
|
| 67 |
+
|
| 68 |
+
with torch.no_grad():
|
| 69 |
+
latents = vae.encode(image.to(weight_dtype).to(accelerator.device)).latent_dist.sample()
|
| 70 |
+
latents = latents.to(dtype=(weight_dtype))
|
| 71 |
+
latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
|
| 72 |
+
|
| 73 |
+
ref_latents = vae.encode(ref_image.to(weight_dtype).to(accelerator.device)).latent_dist.sample()
|
| 74 |
+
ref_latents = ref_latents.to(dtype=(weight_dtype))
|
| 75 |
+
ref_latents = (ref_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
| 76 |
+
|
| 77 |
+
text_input_ids = batch['input_ids'].to(accelerator.device)
|
| 78 |
+
text_attention_mask = batch['attention_mask'].to(accelerator.device)
|
| 79 |
+
pixel_values = batch['pixel_values'].to(accelerator.device)
|
| 80 |
+
image_grid_thw = batch['image_grid_thw'].to(accelerator.device)
|
| 81 |
+
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
text_output = text_encoder(
|
| 84 |
+
input_ids=text_input_ids,
|
| 85 |
+
attention_mask=text_attention_mask,
|
| 86 |
+
pixel_values=pixel_values,
|
| 87 |
+
image_grid_thw=image_grid_thw,
|
| 88 |
+
output_hidden_states=True
|
| 89 |
+
)
|
| 90 |
+
prompt_embeds = text_output.hidden_states[-1].clone().detach()
|
| 91 |
+
|
| 92 |
+
prompt_embeds = prompt_embeds.to(weight_dtype)
|
| 93 |
+
prompt_embeds = prompt_embeds[:,args.prompt_template_encode_start_idx: -args.prompt_template_encode_end_idx ,:]
|
| 94 |
+
|
| 95 |
+
# Sample a random timestep for each image
|
| 96 |
+
grad_norm = None
|
| 97 |
+
with accelerator.accumulate(transformer):
|
| 98 |
+
# Predict the noise residual
|
| 99 |
+
optimizer.zero_grad()
|
| 100 |
+
# logit-normal
|
| 101 |
+
sigmas = torch.sigmoid(torch.randn((latents.shape[0],), device=accelerator.device, dtype=latents.dtype))
|
| 102 |
+
|
| 103 |
+
if args.use_dynamic_shifting:
|
| 104 |
+
sigmas = noise_scheduler.time_shift(mu, 1.0, sigmas)
|
| 105 |
+
|
| 106 |
+
timesteps = sigmas * 1000.0
|
| 107 |
+
sigmas = sigmas.view(-1, 1, 1, 1)
|
| 108 |
+
|
| 109 |
+
noise = torch.randn_like(latents)
|
| 110 |
+
|
| 111 |
+
noisy_latents = (1 - sigmas) * latents + sigmas * noise
|
| 112 |
+
noisy_latents = noisy_latents.to(weight_dtype)
|
| 113 |
+
|
| 114 |
+
packed_noisy_latents = pack_latents(
|
| 115 |
+
noisy_latents,
|
| 116 |
+
batch_size=latents.shape[0],
|
| 117 |
+
num_channels_latents=latents.shape[1],
|
| 118 |
+
height=latents.shape[2],
|
| 119 |
+
width=latents.shape[3],
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
packed_ref_latents = pack_latents(
|
| 123 |
+
ref_latents,
|
| 124 |
+
batch_size=ref_latents.shape[0],
|
| 125 |
+
num_channels_latents=ref_latents.shape[1],
|
| 126 |
+
height=ref_latents.shape[2],
|
| 127 |
+
width=ref_latents.shape[3],
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
guidance = None
|
| 131 |
+
img_ids = prepare_pos_ids(modality_id=1,
|
| 132 |
+
type='image',
|
| 133 |
+
start=(prompt_embeds.shape[1], prompt_embeds.shape[1]),
|
| 134 |
+
height=latents.shape[2]//2,
|
| 135 |
+
width=latents.shape[3]//2).to(accelerator.device, dtype=torch.float64)
|
| 136 |
+
img_ids_ref = prepare_pos_ids(modality_id=2,
|
| 137 |
+
type='image',
|
| 138 |
+
start=(prompt_embeds.shape[1], prompt_embeds.shape[1]),
|
| 139 |
+
height=ref_latents.shape[2]//2,
|
| 140 |
+
width=ref_latents.shape[3]//2).to(accelerator.device, dtype=torch.float64)
|
| 141 |
+
|
| 142 |
+
timesteps = (
|
| 143 |
+
torch.tensor(timesteps)
|
| 144 |
+
.expand(noisy_latents.shape[0])
|
| 145 |
+
.to(device=accelerator.device)
|
| 146 |
+
/ 1000
|
| 147 |
+
)
|
| 148 |
+
text_ids = prepare_pos_ids(modality_id=0,
|
| 149 |
+
type='text',
|
| 150 |
+
start=(0, 0),
|
| 151 |
+
num_token=prompt_embeds.shape[1]).to(accelerator.device, torch.float64)
|
| 152 |
+
|
| 153 |
+
img_ids = torch.cat([img_ids, img_ids_ref], dim=0)
|
| 154 |
+
latent_model_input = torch.cat([packed_noisy_latents, packed_ref_latents], dim=1)
|
| 155 |
+
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
|
| 156 |
+
model_pred = transformer(latent_model_input, prompt_embeds, timesteps,
|
| 157 |
+
img_ids, text_ids, guidance, return_dict=False)[0]
|
| 158 |
+
model_pred = model_pred[:, :packed_noisy_latents.size(1)]
|
| 159 |
+
|
| 160 |
+
model_pred = unpack_latents(
|
| 161 |
+
model_pred,
|
| 162 |
+
height=latents.shape[2] * 8,
|
| 163 |
+
width=latents.shape[3] * 8,
|
| 164 |
+
vae_scale_factor=16,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
target = noise - latents
|
| 168 |
+
loss = torch.mean(
|
| 169 |
+
((model_pred.float() - target.float()) ** 2).reshape(
|
| 170 |
+
target.shape[0], -1
|
| 171 |
+
),
|
| 172 |
+
1,
|
| 173 |
+
).mean()
|
| 174 |
+
|
| 175 |
+
accelerator.backward(loss)
|
| 176 |
+
|
| 177 |
+
if accelerator.sync_gradients:
|
| 178 |
+
grad_norm = transformer.get_global_grad_norm()
|
| 179 |
+
|
| 180 |
+
optimizer.step()
|
| 181 |
+
if not accelerator.optimizer_step_was_skipped:
|
| 182 |
+
lr_scheduler.step()
|
| 183 |
+
|
| 184 |
+
if accelerator.sync_gradients and args.use_ema:
|
| 185 |
+
model_ema.step(transformer.parameters())
|
| 186 |
+
|
| 187 |
+
lr = lr_scheduler.get_last_lr()[0]
|
| 188 |
+
|
| 189 |
+
if accelerator.sync_gradients:
|
| 190 |
+
bsz, ic, ih, iw = image.shape
|
| 191 |
+
logs = {"loss": accelerator.gather(loss).mean().item(), 'aspect_ratio': (ih*1.0 / iw)}
|
| 192 |
+
if grad_norm is not None:
|
| 193 |
+
logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
|
| 194 |
+
|
| 195 |
+
log_buffer.update(logs)
|
| 196 |
+
if (step + 1) % args.log_interval == 0 or (step + 1) == 1:
|
| 197 |
+
t = (time.time() - last_tic) / args.log_interval
|
| 198 |
+
t_d = data_time_all / args.log_interval
|
| 199 |
+
|
| 200 |
+
log_buffer.average()
|
| 201 |
+
info = f"Step={step+1}, Epoch={epoch}, global_step={global_step}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:(ch:{latents.shape[1]},h:{latents.shape[2]},w:{latents.shape[3]}), "
|
| 202 |
+
info += ', '.join([f"{k}:{v:.4f}" for k,v in log_buffer.output.items()])
|
| 203 |
+
logger.info(info)
|
| 204 |
+
last_tic = time.time()
|
| 205 |
+
log_buffer.clear()
|
| 206 |
+
data_time_all = 0
|
| 207 |
+
logs.update(lr=lr)
|
| 208 |
+
accelerator.log(logs, step=global_step)
|
| 209 |
+
global_step += 1
|
| 210 |
+
data_time_start = time.time()
|
| 211 |
+
|
| 212 |
+
if global_step != 0 and global_step % args.save_model_steps == 0:
|
| 213 |
+
save_path = os.path.join(args.work_dir, f'checkpoints-{global_step}')
|
| 214 |
+
if args.use_ema:
|
| 215 |
+
model_ema.store(transformer.parameters())
|
| 216 |
+
model_ema.copy_to(transformer.parameters())
|
| 217 |
+
|
| 218 |
+
accelerator.save_state(save_path)
|
| 219 |
+
|
| 220 |
+
if args.use_ema:
|
| 221 |
+
model_ema.restore(transformer.parameters())
|
| 222 |
+
logger.info(f"Saved state to {save_path} (global_step: {global_step})")
|
| 223 |
+
accelerator.wait_for_everyone()
|
| 224 |
+
|
| 225 |
+
if global_step >= args.max_train_steps:
|
| 226 |
+
break
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def parse_args():
|
| 230 |
+
parser = argparse.ArgumentParser(description="Process some integers.")
|
| 231 |
+
parser.add_argument("--config", type=str, default='', help="config")
|
| 232 |
+
parser.add_argument(
|
| 233 |
+
"--report_to",
|
| 234 |
+
type=str,
|
| 235 |
+
default="tensorboard",
|
| 236 |
+
help=(
|
| 237 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 238 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 239 |
+
),
|
| 240 |
+
)
|
| 241 |
+
parser.add_argument(
|
| 242 |
+
"--allow_tf32",
|
| 243 |
+
action="store_true",
|
| 244 |
+
help=(
|
| 245 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 246 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 247 |
+
),
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 251 |
+
)
|
| 252 |
+
args = parser.parse_args()
|
| 253 |
+
return args
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
if __name__ == '__main__':
|
| 257 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 258 |
+
|
| 259 |
+
args = parse_args()
|
| 260 |
+
|
| 261 |
+
if args.config != '' and os.path.exists(args.config):
|
| 262 |
+
config = yaml.safe_load(open(args.config, 'r'))
|
| 263 |
+
else:
|
| 264 |
+
config = yaml.safe_load(open(f'{cur_dir}/train_config.yaml', 'r'))
|
| 265 |
+
|
| 266 |
+
args_dict = vars(args)
|
| 267 |
+
args_dict.update(config)
|
| 268 |
+
args = argparse.Namespace(**args_dict)
|
| 269 |
+
|
| 270 |
+
os.umask(0o000)
|
| 271 |
+
os.makedirs(args.work_dir, exist_ok=True)
|
| 272 |
+
|
| 273 |
+
log_dir = args.work_dir + f'/logs'
|
| 274 |
+
os.makedirs(log_dir, exist_ok=True)
|
| 275 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.work_dir, logging_dir=log_dir)
|
| 276 |
+
|
| 277 |
+
with open(f'{log_dir}/train.yaml', 'w') as f:
|
| 278 |
+
yaml.dump(args_dict, f)
|
| 279 |
+
|
| 280 |
+
accelerator = Accelerator(
|
| 281 |
+
mixed_precision=args.mixed_precision,
|
| 282 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 283 |
+
log_with=args.report_to,
|
| 284 |
+
project_config= accelerator_project_config,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Make one log on every process with the configuration for debugging.
|
| 288 |
+
logging.basicConfig(
|
| 289 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 290 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 291 |
+
level=logging.INFO,
|
| 292 |
+
)
|
| 293 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 294 |
+
|
| 295 |
+
if accelerator.is_local_main_process:
|
| 296 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 297 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 298 |
+
else:
|
| 299 |
+
transformers.utils.logging.set_verbosity_error()
|
| 300 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 301 |
+
|
| 302 |
+
if args.seed is not None:
|
| 303 |
+
set_seed(args.seed)
|
| 304 |
+
|
| 305 |
+
weight_dtype = torch.float32
|
| 306 |
+
if accelerator.mixed_precision == "fp16":
|
| 307 |
+
weight_dtype = torch.float16
|
| 308 |
+
elif accelerator.mixed_precision == "bf16":
|
| 309 |
+
weight_dtype = torch.bfloat16
|
| 310 |
+
|
| 311 |
+
logger.info(f'using weight_dtype {weight_dtype}!!!')
|
| 312 |
+
|
| 313 |
+
if args.diffusion_pretrain_weight:
|
| 314 |
+
transformer = LongCatImageTransformer2DModel.from_pretrained(args.diffusion_pretrain_weight, ignore_mismatched_sizes=False)
|
| 315 |
+
logger.info(f'successful load model weight {args.diffusion_pretrain_weight}!!!')
|
| 316 |
+
else:
|
| 317 |
+
transformer = LongCatImageTransformer2DModel.from_pretrained(os.path.join(args.pretrained_model_name_or_path, "transformer"), ignore_mismatched_sizes=False)
|
| 318 |
+
logger.info(f'successful load model weight {args.pretrained_model_name_or_path+"/transformer"}!!!')
|
| 319 |
+
|
| 320 |
+
transformer = transformer.train()
|
| 321 |
+
|
| 322 |
+
total_trainable_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
|
| 323 |
+
logger.info(f">>>>>> total_trainable_params: {total_trainable_params}")
|
| 324 |
+
|
| 325 |
+
if args.use_ema:
|
| 326 |
+
model_ema = EMAModel(transformer.parameters(), decay=args.ema_rate)
|
| 327 |
+
else:
|
| 328 |
+
model_ema = None
|
| 329 |
+
|
| 330 |
+
vae_dtype = torch.float32
|
| 331 |
+
vae = AutoencoderKL.from_pretrained(
|
| 332 |
+
args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype).cuda().eval()
|
| 333 |
+
|
| 334 |
+
text_encoder = AutoModel.from_pretrained(
|
| 335 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder" , torch_dtype=weight_dtype, trust_remote_code=True).cuda().eval()
|
| 336 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 337 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer" , torch_dtype=weight_dtype, trust_remote_code=True)
|
| 338 |
+
text_processor = AutoProcessor.from_pretrained(
|
| 339 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer" , torch_dtype=weight_dtype, trust_remote_code=True)
|
| 340 |
+
logger.info("all models loaded successfully")
|
| 341 |
+
|
| 342 |
+
# build models
|
| 343 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 344 |
+
args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 345 |
+
|
| 346 |
+
latent_size = int(args.resolution) // 8
|
| 347 |
+
mu = calculate_shift(
|
| 348 |
+
(latent_size//2)**2,
|
| 349 |
+
noise_scheduler.config.base_image_seq_len,
|
| 350 |
+
noise_scheduler.config.max_image_seq_len,
|
| 351 |
+
noise_scheduler.config.base_shift,
|
| 352 |
+
noise_scheduler.config.max_shift,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 356 |
+
def save_model_hook(models, weights, output_dir):
|
| 357 |
+
if accelerator.is_main_process:
|
| 358 |
+
for i, model in enumerate(models):
|
| 359 |
+
model.save_pretrained(os.path.join(output_dir, "transformer"))
|
| 360 |
+
if len(weights) != 0:
|
| 361 |
+
weights.pop()
|
| 362 |
+
|
| 363 |
+
def load_model_hook(models, input_dir):
|
| 364 |
+
while len(models) > 0:
|
| 365 |
+
# pop models so that they are not loaded again
|
| 366 |
+
model = models.pop()
|
| 367 |
+
# load diffusers style into model
|
| 368 |
+
load_model = LongCatImageTransformer2DModel.from_pretrained(
|
| 369 |
+
input_dir, subfolder="transformer")
|
| 370 |
+
model.register_to_config(**load_model.config)
|
| 371 |
+
model.load_state_dict(load_model.state_dict())
|
| 372 |
+
del load_model
|
| 373 |
+
|
| 374 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 375 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 376 |
+
|
| 377 |
+
if args.gradient_checkpointing:
|
| 378 |
+
transformer.enable_gradient_checkpointing()
|
| 379 |
+
|
| 380 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 381 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 382 |
+
if args.allow_tf32:
|
| 383 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 384 |
+
|
| 385 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 386 |
+
if args.use_8bit_adam:
|
| 387 |
+
try:
|
| 388 |
+
import bitsandbytes as bnb
|
| 389 |
+
except ImportError:
|
| 390 |
+
raise ImportError(
|
| 391 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 392 |
+
)
|
| 393 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 394 |
+
else:
|
| 395 |
+
optimizer_class = torch.optim.AdamW
|
| 396 |
+
|
| 397 |
+
params_to_optimize = transformer.parameters()
|
| 398 |
+
optimizer = optimizer_class(
|
| 399 |
+
params_to_optimize,
|
| 400 |
+
lr=args.learning_rate,
|
| 401 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 402 |
+
weight_decay=args.adam_weight_decay,
|
| 403 |
+
eps=args.adam_epsilon,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
| 407 |
+
if args.use_ema:
|
| 408 |
+
model_ema.to(accelerator.device, dtype=weight_dtype)
|
| 409 |
+
|
| 410 |
+
train_dataloader = build_dataloader(args, args.data_txt_root, tokenizer, text_processor,args.resolution,)
|
| 411 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 412 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 413 |
+
|
| 414 |
+
lr_scheduler = get_scheduler(
|
| 415 |
+
args.lr_scheduler,
|
| 416 |
+
optimizer=optimizer,
|
| 417 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 418 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 419 |
+
num_cycles=args.lr_num_cycles,
|
| 420 |
+
power=args.lr_power,
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
global_step = 0
|
| 424 |
+
first_epoch = 0
|
| 425 |
+
# Potentially load in the weights and states from a previous save
|
| 426 |
+
if args.resume_from_checkpoint:
|
| 427 |
+
if args.resume_from_checkpoint != "latest":
|
| 428 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 429 |
+
else:
|
| 430 |
+
# Get the most recent checkpoint
|
| 431 |
+
dirs = os.listdir(args.work_dir)
|
| 432 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 433 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 434 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 435 |
+
if path is None:
|
| 436 |
+
logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
|
| 437 |
+
args.resume_from_checkpoint = None
|
| 438 |
+
initial_global_step = 0
|
| 439 |
+
else:
|
| 440 |
+
logger.info(f"Resuming from checkpoint {path}")
|
| 441 |
+
accelerator.load_state(os.path.join(args.work_dir, path))
|
| 442 |
+
global_step = int(path.split("-")[1])
|
| 443 |
+
|
| 444 |
+
initial_global_step = global_step
|
| 445 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 446 |
+
|
| 447 |
+
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
|
| 448 |
+
if accelerator.is_main_process:
|
| 449 |
+
tracker_config = dict(vars(args))
|
| 450 |
+
try:
|
| 451 |
+
accelerator.init_trackers('sft', tracker_config)
|
| 452 |
+
except Exception as e:
|
| 453 |
+
logger.warning(f'get error in save config, {e}')
|
| 454 |
+
accelerator.init_trackers(f"sft_{timestamp}")
|
| 455 |
+
|
| 456 |
+
transformer, optimizer, _, _ = accelerator.prepare(
|
| 457 |
+
transformer, optimizer, train_dataloader, lr_scheduler)
|
| 458 |
+
|
| 459 |
+
log_buffer = LogBuffer()
|
| 460 |
+
train(global_step=global_step)
|
examples/lora/train.sh
ADDED
@@ -0,0 +1,13 @@
export TOKENIZERS_PARALLELISM=False
export NCCL_DEBUG=INFO
export NCCL_TIMEOUT=12000

script_dir=$(cd -- "$(dirname -- "$0")" &> /dev/null && pwd -P)
project_root=$(dirname "$(dirname "$script_dir")")
echo "script_dir" ${script_dir}

deepspeed_config_file=${project_root}/misc/accelerate_config.yaml

accelerate launch --mixed_precision bf16 --num_processes 8 --config_file ${deepspeed_config_file} \
    ${script_dir}/train_lora.py \
    --config ${script_dir}/train_config.yaml
examples/lora/train_config.yaml
ADDED
@@ -0,0 +1,48 @@
# 1. Data setting

data_txt_root: '/dataset/example/train_data_info.txt' # training data jsonl file path
resolution: 1024
aspect_ratio_type: 'mar_1024' # data bucketing strategy: mar_256, mar_512, mar_1024
null_text_ratio: 0.1
dataloader_num_workers: 8
train_batch_size: 2
repeats: 100

prompt_template_encode_prefix: '<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n'
prompt_template_encode_suffix: '<|im_end|>\n<|im_start|>assistant\n'
prompt_template_encode_start_idx: 36
prompt_template_encode_end_idx: 5

# 2. Model setting
text_tokenizer_max_length: 512 # tokenizer max length
pretrained_model_name_or_path: "/xxx/weights/Longcat-Image-Dev" # root directory of the model, containing vae, transformer, scheduler, etc.
diffusion_pretrain_weight: null # if a specific diffusion weight path is provided, load the transformer parameters from that path instead.
use_dynamic_shifting: true # scheduler dynamic shifting
resume_from_checkpoint: latest
# - "latest" # Loads most recent step checkpoint
# - "/path/to/checkpoint" # Resumes from specified directory

# 3. Training setting
lora_rank: 32
use_ema: False
ema_rate: 0.999
mixed_precision: 'bf16'
max_train_steps: 100000
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_clip: 1.0
learning_rate: 1.0e-4
adam_weight_decay: 1.0e-3
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.999
lr_num_cycles: 1
lr_power: 1.0
lr_scheduler: 'constant'
lr_warmup_steps: 1000

# 4. Log setting
log_interval: 10
save_model_steps: 200
work_dir: 'output/lora_model'
seed: 43
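The file referenced by `data_txt_root` is read line by line and each line is parsed with `json.loads`, so every record needs the keys the dataset actually reads: `img_path`, `prompt`, `width`, `height`. A small sketch for producing a compatible record follows; the path and caption values are made-up placeholders.

```python
# Sketch: append one JSONL record with the keys read by Text2ImageLoraDataSet.
import json

record = {
    "img_path": "/dataset/example/images/0000.png",  # placeholder image path
    "prompt": "a cat sitting on a wooden table",     # placeholder caption
    "width": 1024,
    "height": 1024,
}

with open("train_data_info.txt", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```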
examples/lora/train_dataset.py
ADDED
@@ -0,0 +1,197 @@
from typing import Any, Callable, Dict, List, Optional, Union
import os
import random
import traceback
import math
import json
import numpy as np

import torch
import torch.distributed as dist
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer
from PIL import Image
from tqdm import tqdm

from longcat_image.dataset import MULTI_RESOLUTION_MAP
from longcat_image.utils import encode_prompt
from longcat_image.dataset import MultiResolutionDistributedSampler

Image.MAX_IMAGE_PIXELS = 2000000000

MAX_RETRY_NUMS = 100


class Text2ImageLoraDataSet(torch.utils.data.Dataset):
    def __init__(self,
                 cfg: dict,
                 txt_root: str,
                 tokenizer: AutoTokenizer,
                 resolution: tuple = (1024, 1024),
                 repeats: int = 1):
        super(Text2ImageLoraDataSet, self).__init__()
        self.resolution = resolution
        self.text_tokenizer_max_length = cfg.text_tokenizer_max_length
        self.null_text_ratio = cfg.null_text_ratio
        self.aspect_ratio_type = cfg.aspect_ratio_type
        self.aspect_ratio = MULTI_RESOLUTION_MAP[self.aspect_ratio_type]
        self.tokenizer = tokenizer

        self.prompt_template_encode_prefix = cfg.prompt_template_encode_prefix
        self.prompt_template_encode_suffix = cfg.prompt_template_encode_suffix
        self.prompt_template_encode_start_idx = cfg.prompt_template_encode_start_idx
        self.prompt_template_encode_end_idx = cfg.prompt_template_encode_end_idx

        # Each line of txt_root is a json record with img_path, prompt, width, height.
        self.total_datas = []
        self.data_resolution_infos = []
        with open(txt_root, 'r') as f:
            lines = f.readlines()
            lines *= cfg.repeats
            for line in tqdm(lines):
                data = json.loads(line.strip())
                try:
                    height, width = int(data['height']), int(data['width'])
                    self.data_resolution_infos.append((height, width))
                    self.total_datas.append(data)
                except Exception as e:
                    print(f'get error {e}, data {data}.')
                    continue
        self.data_nums = len(self.total_datas)
        print(f'get sampler {len(self.total_datas)}, from {txt_root}!!!')

    def transform_img(self, image, original_size, target_size):
        img_h, img_w = original_size
        target_height, target_width = target_size

        original_aspect = img_h / img_w  # height/width
        crop_aspect = target_height / target_width

        if original_aspect >= crop_aspect:
            resize_width = target_width
            resize_height = math.ceil(img_h * (target_width / img_w))
        else:
            resize_width = math.ceil(img_w * (target_height / img_h))
            resize_height = target_height

        image = T.Compose([
            T.Resize((resize_height, resize_width), interpolation=InterpolationMode.BICUBIC),  # Image.LANCZOS
            T.CenterCrop((target_height, target_width)),
            T.ToTensor(),
            T.Normalize([.5], [.5]),
        ])(image)

        return image

    def __getitem__(self, index_tuple):
        index, target_size = index_tuple

        for _ in range(MAX_RETRY_NUMS):
            try:
                item = self.total_datas[index]
                img_path = item["img_path"]
                prompt = item['prompt']

                if random.random() < self.null_text_ratio:
                    prompt = ''

                raw_image = Image.open(img_path).convert('RGB')
                assert raw_image is not None
                img_w, img_h = raw_image.size

                raw_image = self.transform_img(raw_image, original_size=(img_h, img_w), target_size=target_size)
                input_ids, attention_mask = encode_prompt(prompt, self.tokenizer, self.text_tokenizer_max_length,
                                                          self.prompt_template_encode_prefix, self.prompt_template_encode_suffix)
                return {"image": raw_image, "prompt": prompt, 'input_ids': input_ids, 'attention_mask': attention_mask}

            except Exception as e:
                traceback.print_exc()
                print(f"failed read data {e}!!!")
                index = random.randint(0, self.data_nums - 1)

    def __len__(self):
        return self.data_nums

    def collate_fn(self, batchs):
        images = torch.stack([example["image"] for example in batchs])
        input_ids = torch.stack([example["input_ids"] for example in batchs])
        attention_mask = torch.stack([example["attention_mask"] for example in batchs])
        prompts = [example['prompt'] for example in batchs]
        batch_dict = {
            "images": images,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompts": prompts,
        }
        return batch_dict


def build_dataloader(cfg: dict,
                     csv_root: str,
                     tokenizer: AutoTokenizer,
                     resolution: tuple = (1024, 1024)):
    dataset = Text2ImageLoraDataSet(cfg, csv_root, tokenizer, resolution)

    sampler = MultiResolutionDistributedSampler(batch_size=cfg.train_batch_size, dataset=dataset,
                                                data_resolution_infos=dataset.data_resolution_infos,
                                                bucket_info=dataset.aspect_ratio,
                                                epoch=0,
                                                num_replicas=None,
                                                rank=None)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        batch_size=cfg.train_batch_size,
        num_workers=cfg.dataloader_num_workers,
        sampler=sampler,
        shuffle=None,
    )
    return train_loader


if __name__ == '__main__':
    import sys
    import argparse
    from torchvision.transforms.functional import to_pil_image

    txt_root = 'xxx'
    cfg = argparse.Namespace(
        txt_root=txt_root,
        text_tokenizer_max_length=512,
        resolution=1024,
        text_encoder_path="xxx",
        center_crop=True,
        dataloader_num_workers=0,
        null_text_ratio=0.1,
        train_batch_size=16,
        repeats=1,
        seed=0,
        aspect_ratio_type='mar_1024',
        revision=None)

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.text_encoder_path, trust_remote_code=True)
    data_loader = build_dataloader(cfg, cfg.txt_root, tokenizer, cfg.resolution)

    _oroot = './debug_data_example_show'
    os.makedirs(_oroot, exist_ok=True)

    cnt = 0
    for epoch in range(1):
        print(f"Start, epoch {epoch}!!!")
        for i_batch, batch in enumerate(data_loader):
            print(batch['attention_mask'].shape)
            print(batch['images'].shape)

            batch_prompts = batch['prompts']
            for idx, per_img in enumerate(batch['images']):
                re_transforms = T.Compose([
                    T.Normalize(mean=[-0.5/0.5], std=[1.0/0.5])
                ])
                prompt = batch_prompts[idx]
                img = to_pil_image(re_transforms(per_img))
                prompt = prompt[:min(30, len(prompt))]
                oname = _oroot + f'/{str(i_batch)}_{str(idx)}_{prompt}.png'
                img.save(oname)
                if cnt > 100:
                    break
                cnt += 1
examples/lora/train_lora.py
ADDED
@@ -0,0 +1,461 @@
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import warnings
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import yaml
|
| 8 |
+
import torch
|
| 9 |
+
import math
|
| 10 |
+
import logging
|
| 11 |
+
import transformers
|
| 12 |
+
import diffusers
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from transformers import Qwen2Model, Qwen2TokenizerFast
|
| 15 |
+
from accelerate import Accelerator, InitProcessGroupKwargs
|
| 16 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 17 |
+
from accelerate.logging import get_logger
|
| 18 |
+
from diffusers.models import AutoencoderKL
|
| 19 |
+
from diffusers.optimization import get_scheduler
|
| 20 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 21 |
+
from diffusers.training_utils import EMAModel
|
| 22 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 23 |
+
from transformers import AutoTokenizer, AutoModel
|
| 24 |
+
|
| 25 |
+
from train_dataset import build_dataloader
|
| 26 |
+
from longcat_image.models import LongCatImageTransformer2DModel
|
| 27 |
+
from longcat_image.utils import LogBuffer
|
| 28 |
+
from longcat_image.utils import pack_latents, unpack_latents, calculate_shift, prepare_pos_ids
|
| 29 |
+
from peft import LoraConfig, set_peft_model_state_dict
|
| 30 |
+
from peft import LoraConfig, get_peft_model
|
| 31 |
+
from peft.utils import get_peft_model_state_dict
|
| 32 |
+
|
| 33 |
+
warnings.filterwarnings("ignore") # ignore warning
|
| 34 |
+
|
| 35 |
+
current_file_path = Path(__file__).resolve()
|
| 36 |
+
sys.path.insert(0, str(current_file_path.parent.parent))
|
| 37 |
+
|
| 38 |
+
logger = get_logger(__name__)
|
| 39 |
+
|
| 40 |
+
def train(global_step=0):
|
| 41 |
+
|
| 42 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 43 |
+
|
| 44 |
+
# Train!
|
| 45 |
+
total_batch_size = args.train_batch_size * \
|
| 46 |
+
accelerator.num_processes * args.gradient_accumulation_steps
|
| 47 |
+
|
| 48 |
+
logger.info("***** Running training *****")
|
| 49 |
+
logger.info(f" Num examples = {len(train_dataloader.dataset)}")
|
| 50 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 51 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 52 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 53 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 54 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 55 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 56 |
+
|
| 57 |
+
last_tic = time.time()
|
| 58 |
+
|
| 59 |
+
# Now you train the model
|
| 60 |
+
for epoch in range(first_epoch, args.num_train_epochs + 1):
|
| 61 |
+
data_time_start = time.time()
|
| 62 |
+
data_time_all = 0
|
| 63 |
+
|
| 64 |
+
for step, batch in enumerate(train_dataloader):
|
| 65 |
+
image = batch['images']
|
| 66 |
+
|
| 67 |
+
data_time_all += time.time() - data_time_start
|
| 68 |
+
|
| 69 |
+
with torch.no_grad():
|
| 70 |
+
latents = vae.encode(image.to(weight_dtype).to(accelerator.device)).latent_dist.sample()
|
| 71 |
+
latents = latents.to(dtype=(weight_dtype))
|
| 72 |
+
latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
|
| 73 |
+
|
| 74 |
+
text_input_ids = batch['input_ids'].to(accelerator.device)
|
| 75 |
+
text_attention_mask = batch['attention_mask'].to(accelerator.device)
|
| 76 |
+
|
| 77 |
+
with torch.no_grad():
|
| 78 |
+
text_output = text_encoder(
|
| 79 |
+
input_ids=text_input_ids,
|
| 80 |
+
attention_mask=text_attention_mask,
|
| 81 |
+
output_hidden_states=True
|
| 82 |
+
)
|
| 83 |
+
prompt_embeds = text_output.hidden_states[-1].clone().detach()
|
| 84 |
+
|
| 85 |
+
prompt_embeds = prompt_embeds.to(weight_dtype)
|
| 86 |
+
prompt_embeds = prompt_embeds[:,args.prompt_template_encode_start_idx: -args.prompt_template_encode_end_idx ,:]
|
| 87 |
+
|
| 88 |
+
# Sample a random timestep for each image
|
| 89 |
+
grad_norm = None
|
| 90 |
+
with accelerator.accumulate(transformer):
|
| 91 |
+
# Predict the noise residual
|
| 92 |
+
optimizer.zero_grad()
|
| 93 |
+
# logit-normal
|
| 94 |
+
sigmas = torch.sigmoid(torch.randn((latents.shape[0],), device=accelerator.device, dtype=latents.dtype))
|
| 95 |
+
|
| 96 |
+
if args.use_dynamic_shifting:
|
| 97 |
+
sigmas = noise_scheduler.time_shift(mu, 1.0, sigmas)
|
| 98 |
+
|
| 99 |
+
timesteps = sigmas * 1000.0
|
| 100 |
+
sigmas = sigmas.view(-1, 1, 1, 1)
|
| 101 |
+
|
| 102 |
+
noise = torch.randn_like(latents)
|
| 103 |
+
|
| 104 |
+
noisy_latents = (1 - sigmas) * latents + sigmas * noise
|
| 105 |
+
noisy_latents = noisy_latents.to(weight_dtype)
|
| 106 |
+
|
| 107 |
+
packed_noisy_latents = pack_latents(
|
| 108 |
+
noisy_latents,
|
| 109 |
+
batch_size=latents.shape[0],
|
| 110 |
+
num_channels_latents=latents.shape[1],
|
| 111 |
+
height=latents.shape[2],
|
| 112 |
+
width=latents.shape[3],
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
guidance = None
|
| 116 |
+
img_ids = prepare_pos_ids(modality_id=1,
|
| 117 |
+
type='image',
|
| 118 |
+
start=(prompt_embeds.shape[1], prompt_embeds.shape[1]),
|
| 119 |
+
height=latents.shape[2]//2,
|
| 120 |
+
width=latents.shape[3]//2).to(accelerator.device, dtype=torch.float64)
|
| 121 |
+
|
| 122 |
+
timesteps = (
|
| 123 |
+
torch.tensor(timesteps)
|
| 124 |
+
.expand(noisy_latents.shape[0])
|
| 125 |
+
.to(device=accelerator.device)
|
| 126 |
+
/ 1000
|
| 127 |
+
)
|
| 128 |
+
text_ids = prepare_pos_ids(modality_id=0,
|
| 129 |
+
type='text',
|
| 130 |
+
start=(0, 0),
|
| 131 |
+
num_token=prompt_embeds.shape[1]).to(accelerator.device, torch.float64)
|
| 132 |
+
|
| 133 |
+
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
|
| 134 |
+
model_pred = transformer(packed_noisy_latents, prompt_embeds, timesteps,
|
| 135 |
+
img_ids, text_ids, guidance, return_dict=False)[0]
|
| 136 |
+
|
| 137 |
+
model_pred = unpack_latents(
|
| 138 |
+
model_pred,
|
| 139 |
+
height=latents.shape[2] * 8,
|
| 140 |
+
width=latents.shape[3] * 8,
|
| 141 |
+
vae_scale_factor=16,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
target = noise - latents
|
| 145 |
+
loss = torch.mean(
|
| 146 |
+
((model_pred.float() - target.float()) ** 2).reshape(
|
| 147 |
+
target.shape[0], -1
|
| 148 |
+
),
|
| 149 |
+
1,
|
| 150 |
+
).mean()
|
| 151 |
+
|
| 152 |
+
accelerator.backward(loss)
|
| 153 |
+
|
| 154 |
+
if accelerator.sync_gradients:
|
| 155 |
+
grad_norm = transformer.get_global_grad_norm()
|
| 156 |
+
|
| 157 |
+
optimizer.step()
|
| 158 |
+
if not accelerator.optimizer_step_was_skipped:
|
| 159 |
+
lr_scheduler.step()
|
| 160 |
+
|
| 161 |
+
if accelerator.sync_gradients and args.use_ema:
|
| 162 |
+
model_ema.step(transformer.parameters())
|
| 163 |
+
|
| 164 |
+
lr = lr_scheduler.get_last_lr()[0]
|
| 165 |
+
|
| 166 |
+
if accelerator.sync_gradients:
|
| 167 |
+
bsz, ic, ih, iw = image.shape
|
| 168 |
+
logs = {"loss": accelerator.gather(loss).mean().item(), 'aspect_ratio': (ih*1.0 / iw)}
|
| 169 |
+
if grad_norm is not None:
|
| 170 |
+
logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
|
| 171 |
+
|
| 172 |
+
log_buffer.update(logs)
|
| 173 |
+
if (step + 1) % args.log_interval == 0 or (step + 1) == 1:
|
| 174 |
+
t = (time.time() - last_tic) / args.log_interval
|
| 175 |
+
t_d = data_time_all / args.log_interval
|
| 176 |
+
|
| 177 |
+
log_buffer.average()
|
| 178 |
+
info = f"Step={step+1}, Epoch={epoch}, global_step={global_step}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:(ch:{latents.shape[1]},h:{latents.shape[2]},w:{latents.shape[3]}), "
|
| 179 |
+
info += ', '.join([f"{k}:{v:.4f}" for k,v in log_buffer.output.items()])
|
| 180 |
+
logger.info(info)
|
| 181 |
+
last_tic = time.time()
|
| 182 |
+
log_buffer.clear()
|
| 183 |
+
data_time_all = 0
|
| 184 |
+
logs.update(lr=lr)
|
| 185 |
+
accelerator.log(logs, step=global_step)
|
| 186 |
+
global_step += 1
|
| 187 |
+
data_time_start = time.time()
|
| 188 |
+
|
| 189 |
+
if global_step != 0 and global_step % args.save_model_steps == 0:
|
| 190 |
+
accelerator.wait_for_everyone()
|
| 191 |
+
os.umask(0o000)
|
| 192 |
+
cur_lora_ckpt_save_dir = f"{args.work_dir}/checkpoints-{global_step}"
|
| 193 |
+
os.makedirs(cur_lora_ckpt_save_dir, exist_ok=True)
|
| 194 |
+
if accelerator.is_main_process:
|
| 195 |
+
if hasattr(transformer, 'module'):
|
| 196 |
+
transformer.module.save_pretrained(cur_lora_ckpt_save_dir)
|
| 197 |
+
else:
|
| 198 |
+
transformer.save_pretrained(cur_lora_ckpt_save_dir)
|
| 199 |
+
logger.info(f'Saved lora checkpoint of epoch {epoch} to {cur_lora_ckpt_save_dir}.')
|
| 200 |
+
accelerator.save_state(cur_lora_ckpt_save_dir)
|
| 201 |
+
accelerator.wait_for_everyone()
|
| 202 |
+
|
| 203 |
+
if global_step >= args.max_train_steps:
|
| 204 |
+
break
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def parse_args():
|
| 208 |
+
parser = argparse.ArgumentParser(description="Process some integers.")
|
| 209 |
+
parser.add_argument("--config", type=str, default='', help="config")
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--report_to",
|
| 212 |
+
type=str,
|
| 213 |
+
default="tensorboard",
|
| 214 |
+
help=(
|
| 215 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 216 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 217 |
+
),
|
| 218 |
+
)
|
| 219 |
+
parser.add_argument(
|
| 220 |
+
"--allow_tf32",
|
| 221 |
+
action="store_true",
|
| 222 |
+
help=(
|
| 223 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 224 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 225 |
+
),
|
| 226 |
+
)
|
| 227 |
+
parser.add_argument(
|
| 228 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 229 |
+
)
|
| 230 |
+
args = parser.parse_args()
|
| 231 |
+
return args
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
if __name__ == '__main__':
|
| 235 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 236 |
+
|
| 237 |
+
args = parse_args()
|
| 238 |
+
|
| 239 |
+
if args.config != '' and os.path.exists(args.config):
|
| 240 |
+
config = yaml.safe_load(open(args.config, 'r'))
|
| 241 |
+
else:
|
| 242 |
+
config = yaml.safe_load(open(f'{cur_dir}/train_config.yaml', 'r'))
|
| 243 |
+
|
| 244 |
+
args_dict = vars(args)
|
| 245 |
+
args_dict.update(config)
|
| 246 |
+
args = argparse.Namespace(**args_dict)
|
| 247 |
+
|
| 248 |
+
os.umask(0o000)
|
| 249 |
+
os.makedirs(args.work_dir, exist_ok=True)
|
| 250 |
+
|
| 251 |
+
log_dir = args.work_dir + f'/logs'
|
| 252 |
+
os.makedirs(log_dir, exist_ok=True)
|
| 253 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.work_dir, logging_dir=log_dir)
|
| 254 |
+
|
| 255 |
+
with open(f'{log_dir}/train.yaml', 'w') as f:
|
| 256 |
+
yaml.dump(args_dict, f)
|
| 257 |
+
|
| 258 |
+
accelerator = Accelerator(
|
| 259 |
+
mixed_precision=args.mixed_precision,
|
| 260 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 261 |
+
log_with=args.report_to,
|
| 262 |
+
project_config= accelerator_project_config,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# Make one log on every process with the configuration for debugging.
|
| 266 |
+
logging.basicConfig(
|
| 267 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 268 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 269 |
+
level=logging.INFO,
|
| 270 |
+
)
|
| 271 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 272 |
+
|
| 273 |
+
if accelerator.is_local_main_process:
|
| 274 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 275 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 276 |
+
else:
|
| 277 |
+
transformers.utils.logging.set_verbosity_error()
|
| 278 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 279 |
+
|
| 280 |
+
if args.seed is not None:
|
| 281 |
+
set_seed(args.seed)
|
| 282 |
+
|
| 283 |
+
weight_dtype = torch.float32
|
| 284 |
+
if accelerator.mixed_precision == "fp16":
|
| 285 |
+
weight_dtype = torch.float16
|
| 286 |
+
elif accelerator.mixed_precision == "bf16":
|
| 287 |
+
weight_dtype = torch.bfloat16
|
| 288 |
+
|
| 289 |
+
logger.info(f'using weight_dtype {weight_dtype}!!!')
|
| 290 |
+
|
| 291 |
+
if args.diffusion_pretrain_weight:
|
| 292 |
+
transformer = LongCatImageTransformer2DModel.from_pretrained(args.diffusion_pretrain_weight, ignore_mismatched_sizes=False)
|
| 293 |
+
logger.info(f'successful load model weight {args.diffusion_pretrain_weight}!!!')
|
| 294 |
+
else:
|
| 295 |
+
transformer = LongCatImageTransformer2DModel.from_pretrained(os.path.join(args.pretrained_model_name_or_path, "transformer"), ignore_mismatched_sizes=False)
|
| 296 |
+
logger.info(f'successful load model weight {args.pretrained_model_name_or_path+"/transformer"}!!!')
|
| 297 |
+
|
| 298 |
+
# transformer = transformer.train()
|
| 299 |
+
|
| 300 |
+
target_modules = [
|
| 301 |
+
"attn.to_k",
|
| 302 |
+
"attn.to_q",
|
| 303 |
+
"attn.to_v",
|
| 304 |
+
"attn.to_out.0",
|
| 305 |
+
"attn.add_k_proj",
|
| 306 |
+
"attn.add_q_proj",
|
| 307 |
+
"attn.add_v_proj",
|
| 308 |
+
"attn.to_add_out",
|
| 309 |
+
"ff.net.0.proj",
|
| 310 |
+
"ff.net.2",
|
| 311 |
+
"ff_context.net.0.proj",
|
| 312 |
+
"ff_context.net.2",
|
| 313 |
+
]
|
| 314 |
+
|
| 315 |
+
lora_config = LoraConfig(
|
| 316 |
+
r=args.lora_rank,
|
| 317 |
+
init_lora_weights="gaussian",
|
| 318 |
+
target_modules= target_modules ,
|
| 319 |
+
use_dora=False,
|
| 320 |
+
use_rslora=False
|
| 321 |
+
)
|
| 322 |
+
transformer = get_peft_model(transformer, lora_config)
|
| 323 |
+
transformer.print_trainable_parameters()
|
| 324 |
+
|
| 325 |
+
total_trainable_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
|
| 326 |
+
logger.info(f">>>>>> total_trainable_params: {total_trainable_params}")
|
| 327 |
+
|
| 328 |
+
if args.use_ema:
|
| 329 |
+
model_ema = EMAModel(transformer.parameters(), decay=args.ema_rate)
|
| 330 |
+
else:
|
| 331 |
+
model_ema = None
|
| 332 |
+
|
| 333 |
+
vae_dtype = torch.float32
|
| 334 |
+
vae = AutoencoderKL.from_pretrained(
|
| 335 |
+
args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype).cuda().eval()
|
| 336 |
+
|
| 337 |
+
text_encoder = AutoModel.from_pretrained(
|
| 338 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder" , torch_dtype=weight_dtype, trust_remote_code=True).cuda().eval()
|
| 339 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 340 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer" , torch_dtype=weight_dtype, trust_remote_code=True)
|
| 341 |
+
logger.info("all models loaded successfully")
|
| 342 |
+
|
| 343 |
+
# build models
|
| 344 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 345 |
+
args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 346 |
+
|
| 347 |
+
latent_size = int(args.resolution) // 8
|
| 348 |
+
mu = calculate_shift(
|
| 349 |
+
(latent_size//2)**2,
|
| 350 |
+
noise_scheduler.config.base_image_seq_len,
|
| 351 |
+
noise_scheduler.config.max_image_seq_len,
|
| 352 |
+
noise_scheduler.config.base_shift,
|
| 353 |
+
noise_scheduler.config.max_shift,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 357 |
+
def save_model_hook(models, weights, output_dir):
|
| 358 |
+
if accelerator.is_main_process:
|
| 359 |
+
for i, model in enumerate(models):
|
| 360 |
+
model.save_pretrained(os.path.join(output_dir, "transformer"))
|
| 361 |
+
if len(weights) != 0:
|
| 362 |
+
weights.pop()
|
| 363 |
+
|
| 364 |
+
def load_model_hook(models, input_dir):
|
| 365 |
+
while len(models) > 0:
|
| 366 |
+
# pop models so that they are not loaded again
|
| 367 |
+
model = models.pop()
|
| 368 |
+
# load diffusers style into model
|
| 369 |
+
load_model = LongCatImageTransformer2DModel.from_pretrained(
|
| 370 |
+
input_dir, subfolder="transformer")
|
| 371 |
+
model.register_to_config(**load_model.config)
|
| 372 |
+
model.load_state_dict(load_model.state_dict())
|
| 373 |
+
del load_model
|
| 374 |
+
|
| 375 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 376 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 377 |
+
|
| 378 |
+
if args.gradient_checkpointing:
|
| 379 |
+
        transformer.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = transformer.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    transformer.to(accelerator.device, dtype=weight_dtype)
    if args.use_ema:
        model_ema.to(accelerator.device, dtype=weight_dtype)

    train_dataloader = build_dataloader(args, args.data_txt_root, tokenizer, args.resolution)
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    global_step = 0
    first_epoch = 0
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.work_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
        if path is None:
            logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            logger.info(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.work_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        try:
            accelerator.init_trackers('sft', tracker_config)
        except Exception as e:
            logger.warning(f'get error in save config, {e}')
            accelerator.init_trackers(f"sft_{timestamp}")

    transformer, optimizer, _, _ = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler)

    log_buffer = LogBuffer()
    train(global_step=global_step)
examples/sft/train.sh
ADDED
@@ -0,0 +1,13 @@
export TOKENIZERS_PARALLELISM=False
export NCCL_DEBUG=INFO
export NCCL_TIMEOUT=12000

script_dir=$(cd -- "$(dirname -- "$0")" &> /dev/null && pwd -P)
project_root=$(dirname "$(dirname "$script_dir")")
echo "script_dir" ${script_dir}

deepspeed_config_file=${project_root}/misc/accelerate_config.yaml

accelerate launch --mixed_precision bf16 --num_processes 8 --config_file ${deepspeed_config_file} \
    ${script_dir}/train_sft.py \
    --config ${script_dir}/train_config.yaml
examples/sft/train_config.yaml
ADDED
@@ -0,0 +1,46 @@
# 1. Data setting
data_txt_root: '/dataset/example/train_data_info.txt' # path to the training jsonl/txt file
resolution: 1024
aspect_ratio_type: 'mar_1024' # data bucketing strategy: mar_256, mar_512, mar_1024
null_text_ratio: 0.1
dataloader_num_workers: 8
train_batch_size: 4
repeats: 1

prompt_template_encode_prefix: '<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n'
prompt_template_encode_suffix: '<|im_end|>\n<|im_start|>assistant\n'
prompt_template_encode_start_idx: 36
prompt_template_encode_end_idx: 5

# 2. Model setting
text_tokenizer_max_length: 512 # tokenizer max length
pretrained_model_name_or_path: "/xxx/weights/Longcat-Image-Dev" # root directory of the model (contains vae, transformer, scheduler, etc.)
diffusion_pretrain_weight: null # if a diffusion weight path is specified, load the transformer parameters from that path instead
use_dynamic_shifting: true # scheduler dynamic shifting
resume_from_checkpoint: latest
# - "latest" # loads the most recent step checkpoint
# - "/path/to/checkpoint" # resumes from the specified directory

# 3. Training setting
use_ema: False
ema_rate: 0.999
mixed_precision: 'bf16'
max_train_steps: 100000
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_clip: 1.0
learning_rate: 1.0e-5
adam_weight_decay: 1.0e-2
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.999
lr_num_cycles: 1
lr_power: 1.0
lr_scheduler: 'constant'
lr_warmup_steps: 1000

# 4. Log setting
log_interval: 20
save_model_steps: 1000
work_dir: 'output/sft_model'
seed: 43
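Note on the data file: `train_dataset.py` reads `data_txt_root` line by line with `json.loads`, so each line must be a JSON object carrying `img_path`, `prompt`, `width`, and `height`. A minimal sketch for writing such a file follows; the paths and prompts are placeholders, not shipped data.

```python
import json

# Hypothetical records: replace img_path/prompt with your own data.
records = [
    {"img_path": "/path/to/images/0001.png", "prompt": "A cat sitting on a windowsill.", "width": 1024, "height": 1024},
    {"img_path": "/path/to/images/0002.png", "prompt": "A mountain lake at sunrise.", "width": 832, "height": 1216},
]

# One JSON object per line, matching what Text2ImageLoraDataSet expects.
with open("/dataset/example/train_data_info.txt", "w") as f:
    for rec in records:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")
```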
examples/sft/train_dataset.py
ADDED
@@ -0,0 +1,197 @@
from typing import Any, Callable, Dict, List, Optional, Union
import os
import random
import traceback
import math
import json
import numpy as np

import torch
import torch.distributed as dist
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer
from PIL import Image
from tqdm import tqdm

from longcat_image.dataset import MULTI_RESOLUTION_MAP
from longcat_image.utils import encode_prompt
from longcat_image.dataset import MultiResolutionDistributedSampler

Image.MAX_IMAGE_PIXELS = 2000000000

MAX_RETRY_NUMS = 100

class Text2ImageLoraDataSet(torch.utils.data.Dataset):
    def __init__(self,
                 cfg: dict,
                 txt_root: str,
                 tokenizer: AutoTokenizer,
                 resolution: tuple = (1024, 1024),
                 repeats: int = 1):
        super(Text2ImageLoraDataSet, self).__init__()
        self.resolution = resolution
        self.text_tokenizer_max_length = cfg.text_tokenizer_max_length
        self.null_text_ratio = cfg.null_text_ratio
        self.aspect_ratio_type = cfg.aspect_ratio_type
        self.aspect_ratio = MULTI_RESOLUTION_MAP[self.aspect_ratio_type]
        self.tokenizer = tokenizer

        self.prompt_template_encode_prefix = cfg.prompt_template_encode_prefix
        self.prompt_template_encode_suffix = cfg.prompt_template_encode_suffix
        self.prompt_template_encode_start_idx = cfg.prompt_template_encode_start_idx
        self.prompt_template_encode_end_idx = cfg.prompt_template_encode_end_idx

        self.total_datas = []
        self.data_resolution_infos = []
        with open(txt_root, 'r') as f:
            lines = f.readlines()
            lines *= cfg.repeats
            for line in tqdm(lines):
                data = json.loads(line.strip())
                try:
                    height, width = int(data['height']), int(data['width'])
                    self.data_resolution_infos.append((height, width))
                    self.total_datas.append(data)
                except Exception as e:
                    print(f'get error {e}, data {data}.')
                    continue
        self.data_nums = len(self.total_datas)
        print(f'get sampler {len(self.total_datas)}, from {txt_root}!!!')

    def transform_img(self, image, original_size, target_size):
        img_h, img_w = original_size
        target_height, target_width = target_size

        original_aspect = img_h / img_w  # height/width
        crop_aspect = target_height / target_width

        if original_aspect >= crop_aspect:
            resize_width = target_width
            resize_height = math.ceil(img_h * (target_width / img_w))
        else:
            resize_width = math.ceil(img_w * (target_height / img_h))
            resize_height = target_height

        image = T.Compose([
            T.Resize((resize_height, resize_width), interpolation=InterpolationMode.BICUBIC),  # Image.LANCZOS
            T.CenterCrop((target_height, target_width)),
            T.ToTensor(),
            T.Normalize([.5], [.5]),
        ])(image)

        return image

    def __getitem__(self, index_tuple):
        index, target_size = index_tuple

        for _ in range(MAX_RETRY_NUMS):
            try:
                item = self.total_datas[index]
                img_path = item["img_path"]
                prompt = item['prompt']

                if random.random() < self.null_text_ratio:
                    prompt = ''

                raw_image = Image.open(img_path).convert('RGB')
                assert raw_image is not None
                img_w, img_h = raw_image.size

                raw_image = self.transform_img(raw_image, original_size=(img_h, img_w), target_size=target_size)
                input_ids, attention_mask = encode_prompt(prompt, self.tokenizer, self.text_tokenizer_max_length, self.prompt_template_encode_prefix, self.prompt_template_encode_suffix)
                return {"image": raw_image, "prompt": prompt, 'input_ids': input_ids, 'attention_mask': attention_mask}

            except Exception as e:
                traceback.print_exc()
                print(f"failed read data {e}!!!")
                index = random.randint(0, self.data_nums - 1)

    def __len__(self):
        return self.data_nums

    def collate_fn(self, batchs):
        images = torch.stack([example["image"] for example in batchs])
        input_ids = torch.stack([example["input_ids"] for example in batchs])
        attention_mask = torch.stack([example["attention_mask"] for example in batchs])
        prompts = [example['prompt'] for example in batchs]
        batch_dict = {
            "images": images,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompts": prompts,
        }
        return batch_dict


def build_dataloader(cfg: dict,
                     csv_root: str,
                     tokenizer: AutoTokenizer,
                     resolution: tuple = (1024, 1024)):
    dataset = Text2ImageLoraDataSet(cfg, csv_root, tokenizer, resolution)

    sampler = MultiResolutionDistributedSampler(batch_size=cfg.train_batch_size, dataset=dataset,
                                                data_resolution_infos=dataset.data_resolution_infos,
                                                bucket_info=dataset.aspect_ratio,
                                                epoch=0,
                                                num_replicas=None,
                                                rank=None
                                                )

    train_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        batch_size=cfg.train_batch_size,
        num_workers=cfg.dataloader_num_workers,
        sampler=sampler,
        shuffle=None,
    )
    return train_loader


if __name__ == '__main__':
    import sys
    import argparse
    from torchvision.transforms.functional import to_pil_image

    txt_root = 'xxx'
    cfg = argparse.Namespace(
        txt_root=txt_root,
        text_tokenizer_max_length=512,
        resolution=1024,
        text_encoder_path="xxx",
        center_crop=True,
        dataloader_num_workers=0,
        null_text_ratio=0.1,
        train_batch_size=16,
        seed=0,
        aspect_ratio_type='mar_1024',
        revision=None)

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.text_encoder_path, trust_remote_code=True)
    data_loader = build_dataloader(cfg, cfg.txt_root, tokenizer, cfg.resolution)

    _oroot = f'./debug_data_example_show'
    os.makedirs(_oroot, exist_ok=True)

    cnt = 0
    for epoch in range(1):
        print(f"Start, epoch {epoch}!!!")
        for i_batch, batch in enumerate(data_loader):
            print(batch['attention_mask'].shape)
            print(batch['images'].shape)

            batch_prompts = batch['prompts']
            for idx, per_img in enumerate(batch['images']):
                re_transforms = T.Compose([
                    T.Normalize(mean=[-0.5/0.5], std=[1.0/0.5])
                ])
                prompt = batch_prompts[idx]
                img = to_pil_image(re_transforms(per_img))
                prompt = prompt[:min(30, len(prompt))]
                oname = _oroot + f'/{str(i_batch)}_{str(idx)}_{prompt}.png'
                img.save(oname)
            if cnt > 100:
                break
            cnt += 1
examples/sft/train_sft.py
ADDED
|
@@ -0,0 +1,433 @@
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import warnings
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import yaml
|
| 8 |
+
import torch
|
| 9 |
+
import math
|
| 10 |
+
import logging
|
| 11 |
+
import transformers
|
| 12 |
+
import diffusers
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from transformers import Qwen2Model, Qwen2TokenizerFast
|
| 15 |
+
from accelerate import Accelerator, InitProcessGroupKwargs
|
| 16 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 17 |
+
from accelerate.logging import get_logger
|
| 18 |
+
from diffusers.models import AutoencoderKL
|
| 19 |
+
from diffusers.optimization import get_scheduler
|
| 20 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 21 |
+
from diffusers.training_utils import EMAModel
|
| 22 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 23 |
+
from transformers import AutoTokenizer, AutoModel
|
| 24 |
+
|
| 25 |
+
from train_dataset import build_dataloader
|
| 26 |
+
from longcat_image.models import LongCatImageTransformer2DModel
|
| 27 |
+
from longcat_image.utils import LogBuffer
|
| 28 |
+
from longcat_image.utils import pack_latents, unpack_latents, calculate_shift, prepare_pos_ids
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
warnings.filterwarnings("ignore") # ignore warning
|
| 32 |
+
|
| 33 |
+
current_file_path = Path(__file__).resolve()
|
| 34 |
+
sys.path.insert(0, str(current_file_path.parent.parent))
|
| 35 |
+
|
| 36 |
+
logger = get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
def train(global_step=0):
|
| 39 |
+
|
| 40 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 41 |
+
|
| 42 |
+
# Train!
|
| 43 |
+
total_batch_size = args.train_batch_size * \
|
| 44 |
+
accelerator.num_processes * args.gradient_accumulation_steps
|
| 45 |
+
|
| 46 |
+
logger.info("***** Running training *****")
|
| 47 |
+
logger.info(f" Num examples = {len(train_dataloader.dataset)}")
|
| 48 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 49 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 50 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 51 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 52 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 53 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 54 |
+
|
| 55 |
+
last_tic = time.time()
|
| 56 |
+
|
| 57 |
+
# Now you train the model
|
| 58 |
+
for epoch in range(first_epoch, args.num_train_epochs + 1):
|
| 59 |
+
data_time_start = time.time()
|
| 60 |
+
data_time_all = 0
|
| 61 |
+
|
| 62 |
+
for step, batch in enumerate(train_dataloader):
|
| 63 |
+
image = batch['images']
|
| 64 |
+
|
| 65 |
+
data_time_all += time.time() - data_time_start
|
| 66 |
+
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
latents = vae.encode(image.to(weight_dtype).to(accelerator.device)).latent_dist.sample()
|
| 69 |
+
latents = latents.to(dtype=(weight_dtype))
|
| 70 |
+
latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
|
| 71 |
+
|
| 72 |
+
text_input_ids = batch['input_ids'].to(accelerator.device)
|
| 73 |
+
text_attention_mask = batch['attention_mask'].to(accelerator.device)
|
| 74 |
+
|
| 75 |
+
with torch.no_grad():
|
| 76 |
+
text_output = text_encoder(
|
| 77 |
+
input_ids=text_input_ids,
|
| 78 |
+
attention_mask=text_attention_mask,
|
| 79 |
+
output_hidden_states=True
|
| 80 |
+
)
|
| 81 |
+
prompt_embeds = text_output.hidden_states[-1].clone().detach()
|
| 82 |
+
|
| 83 |
+
prompt_embeds = prompt_embeds.to(weight_dtype)
|
| 84 |
+
prompt_embeds = prompt_embeds[:,args.prompt_template_encode_start_idx: -args.prompt_template_encode_end_idx ,:]
|
| 85 |
+
|
| 86 |
+
# Sample a random timestep for each image
|
| 87 |
+
grad_norm = None
|
| 88 |
+
with accelerator.accumulate(transformer):
|
| 89 |
+
# Predict the noise residual
|
| 90 |
+
optimizer.zero_grad()
|
| 91 |
+
# logit-normal
|
| 92 |
+
sigmas = torch.sigmoid(torch.randn((latents.shape[0],), device=accelerator.device, dtype=latents.dtype))
|
| 93 |
+
|
| 94 |
+
if args.use_dynamic_shifting:
|
| 95 |
+
sigmas = noise_scheduler.time_shift(mu, 1.0, sigmas)
|
| 96 |
+
|
| 97 |
+
timesteps = sigmas * 1000.0
|
| 98 |
+
sigmas = sigmas.view(-1, 1, 1, 1)
|
| 99 |
+
|
| 100 |
+
noise = torch.randn_like(latents)
|
| 101 |
+
|
| 102 |
+
noisy_latents = (1 - sigmas) * latents + sigmas * noise
|
| 103 |
+
noisy_latents = noisy_latents.to(weight_dtype)
|
| 104 |
+
|
| 105 |
+
packed_noisy_latents = pack_latents(
|
| 106 |
+
noisy_latents,
|
| 107 |
+
batch_size=latents.shape[0],
|
| 108 |
+
num_channels_latents=latents.shape[1],
|
| 109 |
+
height=latents.shape[2],
|
| 110 |
+
width=latents.shape[3],
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
guidance = None
|
| 114 |
+
img_ids = prepare_pos_ids(modality_id=1,
|
| 115 |
+
type='image',
|
| 116 |
+
start=(prompt_embeds.shape[1], prompt_embeds.shape[1]),
|
| 117 |
+
height=latents.shape[2]//2,
|
| 118 |
+
width=latents.shape[3]//2).to(accelerator.device, dtype=torch.float64)
|
| 119 |
+
|
| 120 |
+
timesteps = (
|
| 121 |
+
torch.tensor(timesteps)
|
| 122 |
+
.expand(noisy_latents.shape[0])
|
| 123 |
+
.to(device=accelerator.device)
|
| 124 |
+
/ 1000
|
| 125 |
+
)
|
| 126 |
+
text_ids = prepare_pos_ids(modality_id=0,
|
| 127 |
+
type='text',
|
| 128 |
+
start=(0, 0),
|
| 129 |
+
num_token=prompt_embeds.shape[1]).to(accelerator.device, torch.float64)
|
| 130 |
+
|
| 131 |
+
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
|
| 132 |
+
model_pred = transformer(packed_noisy_latents, prompt_embeds, timesteps,
|
| 133 |
+
img_ids, text_ids, guidance, return_dict=False)[0]
|
| 134 |
+
|
| 135 |
+
model_pred = unpack_latents(
|
| 136 |
+
model_pred,
|
| 137 |
+
height=latents.shape[2] * 8,
|
| 138 |
+
width=latents.shape[3] * 8,
|
| 139 |
+
vae_scale_factor=16,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
target = noise - latents
|
| 143 |
+
loss = torch.mean(
|
| 144 |
+
((model_pred.float() - target.float()) ** 2).reshape(
|
| 145 |
+
target.shape[0], -1
|
| 146 |
+
),
|
| 147 |
+
1,
|
| 148 |
+
).mean()
|
| 149 |
+
|
| 150 |
+
accelerator.backward(loss)
|
| 151 |
+
|
| 152 |
+
if accelerator.sync_gradients:
|
| 153 |
+
grad_norm = transformer.get_global_grad_norm()
|
| 154 |
+
|
| 155 |
+
optimizer.step()
|
| 156 |
+
if not accelerator.optimizer_step_was_skipped:
|
| 157 |
+
lr_scheduler.step()
|
| 158 |
+
|
| 159 |
+
if accelerator.sync_gradients and args.use_ema:
|
| 160 |
+
model_ema.step(transformer.parameters())
|
| 161 |
+
|
| 162 |
+
lr = lr_scheduler.get_last_lr()[0]
|
| 163 |
+
|
| 164 |
+
if accelerator.sync_gradients:
|
| 165 |
+
bsz, ic, ih, iw = image.shape
|
| 166 |
+
logs = {"loss": accelerator.gather(loss).mean().item(), 'aspect_ratio': (ih*1.0 / iw)}
|
| 167 |
+
if grad_norm is not None:
|
| 168 |
+
logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
|
| 169 |
+
|
| 170 |
+
log_buffer.update(logs)
|
| 171 |
+
if (step + 1) % args.log_interval == 0 or (step + 1) == 1:
|
| 172 |
+
t = (time.time() - last_tic) / args.log_interval
|
| 173 |
+
t_d = data_time_all / args.log_interval
|
| 174 |
+
|
| 175 |
+
log_buffer.average()
|
| 176 |
+
info = f"Step={step+1}, Epoch={epoch}, global_step={global_step}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:(ch:{latents.shape[1]},h:{latents.shape[2]},w:{latents.shape[3]}), "
|
| 177 |
+
info += ', '.join([f"{k}:{v:.4f}" for k,v in log_buffer.output.items()])
|
| 178 |
+
logger.info(info)
|
| 179 |
+
last_tic = time.time()
|
| 180 |
+
log_buffer.clear()
|
| 181 |
+
data_time_all = 0
|
| 182 |
+
logs.update(lr=lr)
|
| 183 |
+
accelerator.log(logs, step=global_step)
|
| 184 |
+
global_step += 1
|
| 185 |
+
data_time_start = time.time()
|
| 186 |
+
|
| 187 |
+
if global_step != 0 and global_step % args.save_model_steps == 0:
|
| 188 |
+
save_path = os.path.join(args.work_dir, f'checkpoints-{global_step}')
|
| 189 |
+
if args.use_ema:
|
| 190 |
+
model_ema.store(transformer.parameters())
|
| 191 |
+
model_ema.copy_to(transformer.parameters())
|
| 192 |
+
|
| 193 |
+
accelerator.save_state(save_path)
|
| 194 |
+
|
| 195 |
+
if args.use_ema:
|
| 196 |
+
model_ema.restore(transformer.parameters())
|
| 197 |
+
logger.info(f"Saved state to {save_path} (global_step: {global_step})")
|
| 198 |
+
accelerator.wait_for_everyone()
|
| 199 |
+
|
| 200 |
+
if global_step >= args.max_train_steps:
|
| 201 |
+
break
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def parse_args():
|
| 205 |
+
parser = argparse.ArgumentParser(description="Process some integers.")
|
| 206 |
+
parser.add_argument("--config", type=str, default='', help="config")
|
| 207 |
+
parser.add_argument(
|
| 208 |
+
"--report_to",
|
| 209 |
+
type=str,
|
| 210 |
+
default="tensorboard",
|
| 211 |
+
help=(
|
| 212 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
| 213 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
| 214 |
+
),
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--allow_tf32",
|
| 218 |
+
action="store_true",
|
| 219 |
+
help=(
|
| 220 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 221 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 222 |
+
),
|
| 223 |
+
)
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
| 226 |
+
)
|
| 227 |
+
args = parser.parse_args()
|
| 228 |
+
return args
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
if __name__ == '__main__':
|
| 232 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 233 |
+
|
| 234 |
+
args = parse_args()
|
| 235 |
+
|
| 236 |
+
if args.config != '' and os.path.exists(args.config):
|
| 237 |
+
config = yaml.safe_load(open(args.config, 'r'))
|
| 238 |
+
else:
|
| 239 |
+
config = yaml.safe_load(open(f'{cur_dir}/train_config.yaml', 'r'))
|
| 240 |
+
|
| 241 |
+
args_dict = vars(args)
|
| 242 |
+
args_dict.update(config)
|
| 243 |
+
args = argparse.Namespace(**args_dict)
|
| 244 |
+
|
| 245 |
+
os.umask(0o000)
|
| 246 |
+
os.makedirs(args.work_dir, exist_ok=True)
|
| 247 |
+
|
| 248 |
+
log_dir = args.work_dir + f'/logs'
|
| 249 |
+
os.makedirs(log_dir, exist_ok=True)
|
| 250 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.work_dir, logging_dir=log_dir)
|
| 251 |
+
|
| 252 |
+
with open(f'{log_dir}/train.yaml', 'w') as f:
|
| 253 |
+
yaml.dump(args_dict, f)
|
| 254 |
+
|
| 255 |
+
accelerator = Accelerator(
|
| 256 |
+
mixed_precision=args.mixed_precision,
|
| 257 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 258 |
+
log_with=args.report_to,
|
| 259 |
+
project_config= accelerator_project_config,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# Make one log on every process with the configuration for debugging.
|
| 263 |
+
logging.basicConfig(
|
| 264 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 265 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 266 |
+
level=logging.INFO,
|
| 267 |
+
)
|
| 268 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 269 |
+
|
| 270 |
+
if accelerator.is_local_main_process:
|
| 271 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 272 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 273 |
+
else:
|
| 274 |
+
transformers.utils.logging.set_verbosity_error()
|
| 275 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 276 |
+
|
| 277 |
+
if args.seed is not None:
|
| 278 |
+
set_seed(args.seed)
|
| 279 |
+
|
| 280 |
+
weight_dtype = torch.float32
|
| 281 |
+
if accelerator.mixed_precision == "fp16":
|
| 282 |
+
weight_dtype = torch.float16
|
| 283 |
+
elif accelerator.mixed_precision == "bf16":
|
| 284 |
+
weight_dtype = torch.bfloat16
|
| 285 |
+
|
| 286 |
+
logger.info(f'using weight_dtype {weight_dtype}!!!')
|
| 287 |
+
|
| 288 |
+
if args.diffusion_pretrain_weight:
|
| 289 |
+
transformer = LongCatImageTransformer2DModel.from_pretrained(args.diffusion_pretrain_weight, ignore_mismatched_sizes=False)
|
| 290 |
+
logger.info(f'successful load model weight {args.diffusion_pretrain_weight}!!!')
|
| 291 |
+
else:
|
| 292 |
+
transformer = LongCatImageTransformer2DModel.from_pretrained(os.path.join(args.pretrained_model_name_or_path, "transformer"), ignore_mismatched_sizes=False)
|
| 293 |
+
logger.info(f'successful load model weight {args.pretrained_model_name_or_path+"/transformer"}!!!')
|
| 294 |
+
|
| 295 |
+
transformer = transformer.train()
|
| 296 |
+
|
| 297 |
+
total_trainable_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
|
| 298 |
+
logger.info(f">>>>>> total_trainable_params: {total_trainable_params}")
|
| 299 |
+
|
| 300 |
+
if args.use_ema:
|
| 301 |
+
model_ema = EMAModel(transformer.parameters(), decay=args.ema_rate)
|
| 302 |
+
else:
|
| 303 |
+
model_ema = None
|
| 304 |
+
|
| 305 |
+
vae_dtype = torch.float32
|
| 306 |
+
vae = AutoencoderKL.from_pretrained(
|
| 307 |
+
args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype).cuda().eval()
|
| 308 |
+
|
| 309 |
+
text_encoder = AutoModel.from_pretrained(
|
| 310 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder" , torch_dtype=weight_dtype, trust_remote_code=True).cuda().eval()
|
| 311 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 312 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer" , torch_dtype=weight_dtype, trust_remote_code=True)
|
| 313 |
+
logger.info("all models loaded successfully")
|
| 314 |
+
|
| 315 |
+
# build models
|
| 316 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 317 |
+
args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 318 |
+
|
| 319 |
+
latent_size = int(args.resolution) // 8
|
| 320 |
+
mu = calculate_shift(
|
| 321 |
+
(latent_size//2)**2,
|
| 322 |
+
noise_scheduler.config.base_image_seq_len,
|
| 323 |
+
noise_scheduler.config.max_image_seq_len,
|
| 324 |
+
noise_scheduler.config.base_shift,
|
| 325 |
+
noise_scheduler.config.max_shift,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
| 329 |
+
def save_model_hook(models, weights, output_dir):
|
| 330 |
+
if accelerator.is_main_process:
|
| 331 |
+
for i, model in enumerate(models):
|
| 332 |
+
model.save_pretrained(os.path.join(output_dir, "transformer"))
|
| 333 |
+
if len(weights) != 0:
|
| 334 |
+
weights.pop()
|
| 335 |
+
|
| 336 |
+
def load_model_hook(models, input_dir):
|
| 337 |
+
while len(models) > 0:
|
| 338 |
+
# pop models so that they are not loaded again
|
| 339 |
+
model = models.pop()
|
| 340 |
+
# load diffusers style into model
|
| 341 |
+
load_model = LongCatImageTransformer2DModel.from_pretrained(
|
| 342 |
+
input_dir, subfolder="transformer")
|
| 343 |
+
model.register_to_config(**load_model.config)
|
| 344 |
+
model.load_state_dict(load_model.state_dict())
|
| 345 |
+
del load_model
|
| 346 |
+
|
| 347 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 348 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 349 |
+
|
| 350 |
+
if args.gradient_checkpointing:
|
| 351 |
+
transformer.enable_gradient_checkpointing()
|
| 352 |
+
|
| 353 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 354 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 355 |
+
if args.allow_tf32:
|
| 356 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 357 |
+
|
| 358 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 359 |
+
if args.use_8bit_adam:
|
| 360 |
+
try:
|
| 361 |
+
import bitsandbytes as bnb
|
| 362 |
+
except ImportError:
|
| 363 |
+
raise ImportError(
|
| 364 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 365 |
+
)
|
| 366 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 367 |
+
else:
|
| 368 |
+
optimizer_class = torch.optim.AdamW
|
| 369 |
+
|
| 370 |
+
params_to_optimize = transformer.parameters()
|
| 371 |
+
optimizer = optimizer_class(
|
| 372 |
+
params_to_optimize,
|
| 373 |
+
lr=args.learning_rate,
|
| 374 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 375 |
+
weight_decay=args.adam_weight_decay,
|
| 376 |
+
eps=args.adam_epsilon,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
| 380 |
+
if args.use_ema:
|
| 381 |
+
model_ema.to(accelerator.device, dtype=weight_dtype)
|
| 382 |
+
|
| 383 |
+
train_dataloader = build_dataloader(args, args.data_txt_root, tokenizer, args.resolution,)
|
| 384 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 385 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 386 |
+
|
| 387 |
+
lr_scheduler = get_scheduler(
|
| 388 |
+
args.lr_scheduler,
|
| 389 |
+
optimizer=optimizer,
|
| 390 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 391 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 392 |
+
num_cycles=args.lr_num_cycles,
|
| 393 |
+
power=args.lr_power,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
global_step = 0
|
| 397 |
+
first_epoch = 0
|
| 398 |
+
# Potentially load in the weights and states from a previous save
|
| 399 |
+
if args.resume_from_checkpoint:
|
| 400 |
+
if args.resume_from_checkpoint != "latest":
|
| 401 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 402 |
+
else:
|
| 403 |
+
# Get the most recent checkpoint
|
| 404 |
+
dirs = os.listdir(args.work_dir)
|
| 405 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 406 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 407 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 408 |
+
if path is None:
|
| 409 |
+
logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
|
| 410 |
+
args.resume_from_checkpoint = None
|
| 411 |
+
initial_global_step = 0
|
| 412 |
+
else:
|
| 413 |
+
logger.info(f"Resuming from checkpoint {path}")
|
| 414 |
+
accelerator.load_state(os.path.join(args.work_dir, path))
|
| 415 |
+
global_step = int(path.split("-")[1])
|
| 416 |
+
|
| 417 |
+
initial_global_step = global_step
|
| 418 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 419 |
+
|
| 420 |
+
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
|
| 421 |
+
if accelerator.is_main_process:
|
| 422 |
+
tracker_config = dict(vars(args))
|
| 423 |
+
try:
|
| 424 |
+
accelerator.init_trackers('sft', tracker_config)
|
| 425 |
+
except Exception as e:
|
| 426 |
+
logger.warning(f'get error in save config, {e}')
|
| 427 |
+
accelerator.init_trackers(f"sft_{timestamp}")
|
| 428 |
+
|
| 429 |
+
transformer, optimizer, _, _ = accelerator.prepare(
|
| 430 |
+
transformer, optimizer, train_dataloader, lr_scheduler)
|
| 431 |
+
|
| 432 |
+
log_buffer = LogBuffer()
|
| 433 |
+
train(global_step=global_step)
|
longcat_image/__init__.py
ADDED
File without changes
longcat_image/dataset/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .data_utils import MULTI_RESOLUTION_MAP
from .sampler import MultiResolutionDistributedSampler
longcat_image/dataset/data_utils.py
ADDED
@@ -0,0 +1,34 @@
MULTI_ASPECT_RATIO_1024 = {
    '0.5': [704., 1408.], '0.52': [704., 1344.],
    '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
    '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
    '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
    '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
    '1.75': [1344., 768.], '2.0': [1408., 704.]
}


MULTI_ASPECT_RATIO_512 = {
    '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0]
}


MULTI_ASPECT_RATIO_256 = {
    '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
    '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
    '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
    '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
    '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
    '1.75': [336.0, 192.0], '2.0': [352.0, 176.0]
}

MULTI_RESOLUTION_MAP = {
    'mar_256': MULTI_ASPECT_RATIO_256,
    'mar_512': MULTI_ASPECT_RATIO_512,
    'mar_1024': MULTI_ASPECT_RATIO_1024,
}
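These tables drive the multi-resolution bucketing: the sampler assigns every image to the bucket whose aspect ratio (height/width) is nearest, and the dataset then resizes and center-crops to that bucket's `[height, width]`. A small illustration of the lookup (the helper function below is mine, not part of the package):

```python
import numpy as np

from longcat_image.dataset.data_utils import MULTI_RESOLUTION_MAP


def nearest_bucket(height, width, aspect_ratio_type='mar_1024'):
    # Mirrors MultiResolutionDistributedSampler.split_to_buckets: choose the
    # bucket whose aspect ratio is closest to the image's height/width ratio.
    bucket = MULTI_RESOLUTION_MAP[aspect_ratio_type]
    keys = sorted(bucket, key=float)
    ratios = np.array([float(k) for k in keys])
    idx = int(np.abs(height / width - ratios).argmin())
    return bucket[keys[idx]]  # [target_height, target_width]


print(nearest_bucket(720, 1280))  # a 16:9 image lands in the [768., 1344.] bucket
```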
longcat_image/dataset/sampler.py
ADDED
@@ -0,0 +1,111 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed as dist
import numpy as np
import copy
from torch.utils.data import IterableDataset

from longcat_image.utils.dist_utils import get_world_size, get_rank, get_local_rank

class MultiResolutionDistributedSampler(torch.utils.data.Sampler):
    def __init__(self,
                 batch_size: int,
                 dataset: IterableDataset,
                 data_resolution_infos: List,
                 bucket_info: dict,
                 num_replicas: int = None,
                 rank: int = None,
                 seed: int = 888,
                 epoch: int = 0,
                 shuffle: bool = True):

        if not dist.is_available():
            num_replicas = 1
            rank = 0
        else:
            num_replicas = get_world_size()
            rank = get_rank()

        self.len_items = len(dataset)
        bucket_info = {float(b): bucket_info[b] for b in bucket_info.keys()}
        self.aspect_ratios = np.array(sorted(list(bucket_info.keys())))
        self.resolutions = np.array([bucket_info[aspect] for aspect in self.aspect_ratios])

        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = epoch
        self.shuffle = shuffle
        self.seed = seed
        self.cur_rank_index = []
        self.rng = np.random.RandomState(seed + self.epoch)
        self.global_batch_size = batch_size * num_replicas
        self.data_resolution_infos = np.array(data_resolution_infos, dtype=np.float32)
        print(f'num_replicas {num_replicas}, cur rank {rank}!!!')

        self.split_to_buckets()
        self.num_samples = len(dataset) // num_replicas

    def split_to_buckets(self):
        self.buckets = {}
        self._buckets_bak = {}
        data_aspect_ratio = self.data_resolution_infos[:, 0] * 1.0 / self.data_resolution_infos[:, 1]
        bucket_id = np.abs(data_aspect_ratio[:, None] - self.aspect_ratios).argmin(axis=1)
        for i in range(len(self.aspect_ratios)):
            self.buckets[i] = np.where(bucket_id == i)[0]
            self._buckets_bak[i] = np.where(bucket_id == i)[0]
        for k, v in self.buckets.items():
            print(f'bucket {k}, resolutions {self.resolutions[k]}, sampler nums {len(v)}!!!')

    def get_batch_index(self):
        success_flag = False
        while not success_flag:
            bucket_ids = list(self.buckets.keys())
            bucket_probs = [len(self.buckets[bucket_id]) for bucket_id in bucket_ids]
            bucket_probs = np.array(bucket_probs, dtype=np.float32)
            bucket_probs = bucket_probs / bucket_probs.sum()
            bucket_ids = np.array(bucket_ids, dtype=np.int64)
            chosen_id = int(self.rng.choice(bucket_ids, 1, p=bucket_probs)[0])
            if len(self.buckets[chosen_id]) < self.global_batch_size:
                del self.buckets[chosen_id]
                continue
            batch_data = self.buckets[chosen_id][:self.global_batch_size]
            batch_data = (batch_data, self.resolutions[chosen_id])
            self.buckets[chosen_id] = self.buckets[chosen_id][self.global_batch_size:]
            if len(self.buckets[chosen_id]) == 0:
                del self.buckets[chosen_id]
            success_flag = True
        assert bool(self.buckets), 'There is not enough data in the current epoch.'
        return batch_data

    def shuffle_bucker_index(self):
        self.rng = np.random.RandomState(self.seed + self.epoch)
        self.buckets = copy.deepcopy(self._buckets_bak)
        for bucket_id in self.buckets.keys():
            self.rng.shuffle(self.buckets[bucket_id])

    def __iter__(self):
        return self

    def __next__(self):
        try:
            if len(self.cur_rank_index) == 0:
                global_batch_index, target_resolutions = self.get_batch_index()
                self.cur_rank_index = list(map(
                    int, global_batch_index[self.batch_size * self.rank:self.batch_size * (self.rank + 1)]))
                self.resolution = list(map(int, target_resolutions))
            data_index = self.cur_rank_index.pop(0)
            return (data_index, self.resolution)

        except Exception as e:
            self.epoch += 1
            self.shuffle_bucker_index()
            print(f'get error {e}.')
            raise StopIteration

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
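The sampler yields `(sample_index, [target_height, target_width])` tuples, which is why `Text2ImageLoraDataSet.__getitem__` unpacks an index tuple; all indices inside one global batch share a bucket, so every rank sees a single resolution per step. A rough standalone sketch, assuming a single non-distributed process where the `dist_utils` helpers fall back to a world size of 1 (the fake dataset and resolution list are placeholders):

```python
from longcat_image.dataset import MULTI_RESOLUTION_MAP, MultiResolutionDistributedSampler


class FakeDataset:
    # Only __len__ is needed by the sampler.
    def __len__(self):
        return 64


# (height, width) per sample; half square, half roughly 16:9.
resolution_infos = [(1024, 1024)] * 32 + [(768, 1344)] * 32

sampler = MultiResolutionDistributedSampler(
    batch_size=4,
    dataset=FakeDataset(),
    data_resolution_infos=resolution_infos,
    bucket_info=MULTI_RESOLUTION_MAP['mar_1024'],
)

for _, (index, (target_h, target_w)) in zip(range(8), sampler):
    print(index, target_h, target_w)  # each index comes with its bucket resolution
```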
longcat_image/models/__init__.py
ADDED
@@ -0,0 +1 @@
from .longcat_image_dit import LongCatImageTransformer2DModel
longcat_image/models/longcat_image_dit.py
ADDED
|
@@ -0,0 +1,231 @@
| 1 |
+
from typing import Any, Dict, List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from diffusers.models.transformers.transformer_flux import \
|
| 8 |
+
FluxTransformerBlock, FluxSingleTransformerBlock, \
|
| 9 |
+
AdaLayerNormContinuous, Transformer2DModelOutput
|
| 10 |
+
from diffusers.models.embeddings import Timesteps, TimestepEmbedding,FluxPosEmbed
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from accelerate.logging import get_logger
|
| 14 |
+
from diffusers.loaders import PeftAdapterMixin
|
| 15 |
+
|
| 16 |
+
logger = get_logger(__name__, log_level="INFO")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TimestepEmbeddings(nn.Module):
|
| 20 |
+
def __init__(self, embedding_dim):
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
| 24 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
| 25 |
+
|
| 26 |
+
def forward(self, timestep, hidden_dtype):
|
| 27 |
+
timesteps_proj = self.time_proj(timestep)
|
| 28 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
| 29 |
+
|
| 30 |
+
return timesteps_emb
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class LongCatImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin ):
|
| 34 |
+
"""
|
| 35 |
+
The Transformer model introduced in Flux.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
_supports_gradient_checkpointing = True
|
| 39 |
+
|
| 40 |
+
@register_to_config
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
patch_size: int = 1,
|
| 44 |
+
in_channels: int = 64,
|
| 45 |
+
num_layers: int = 19,
|
| 46 |
+
num_single_layers: int = 38,
|
| 47 |
+
attention_head_dim: int = 128,
|
| 48 |
+
num_attention_heads: int = 24,
|
| 49 |
+
joint_attention_dim: int = 3584,
|
| 50 |
+
pooled_projection_dim: int = 3584,
|
| 51 |
+
axes_dims_rope: List[int] = [16, 56, 56],
|
| 52 |
+
):
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.out_channels = in_channels
|
| 55 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 56 |
+
self.pooled_projection_dim = pooled_projection_dim
|
| 57 |
+
|
| 58 |
+
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
| 59 |
+
|
| 60 |
+
self.time_embed = TimestepEmbeddings(embedding_dim=self.inner_dim)
|
| 61 |
+
|
| 62 |
+
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
| 63 |
+
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
| 64 |
+
|
| 65 |
+
self.transformer_blocks = nn.ModuleList(
|
| 66 |
+
[
|
| 67 |
+
FluxTransformerBlock(
|
| 68 |
+
dim=self.inner_dim,
|
| 69 |
+
num_attention_heads=num_attention_heads,
|
| 70 |
+
attention_head_dim=attention_head_dim,
|
| 71 |
+
)
|
| 72 |
+
for i in range(num_layers)
|
| 73 |
+
]
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 77 |
+
[
|
| 78 |
+
FluxSingleTransformerBlock(
|
| 79 |
+
dim=self.inner_dim,
|
| 80 |
+
num_attention_heads=num_attention_heads,
|
| 81 |
+
attention_head_dim=attention_head_dim,
|
| 82 |
+
)
|
| 83 |
+
for i in range(num_single_layers)
|
| 84 |
+
]
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
self.norm_out = AdaLayerNormContinuous(
|
| 88 |
+
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
| 89 |
+
self.proj_out = nn.Linear(
|
| 90 |
+
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
| 91 |
+
|
| 92 |
+
self.gradient_checkpointing = False
|
| 93 |
+
|
| 94 |
+
self.initialize_weights()
|
| 95 |
+
|
| 96 |
+
self.use_checkpoint = [True] * num_layers
|
| 97 |
+
self.use_single_checkpoint = [True] * num_single_layers
|
| 98 |
+
|
| 99 |
+
def forward(
|
| 100 |
+
self,
|
| 101 |
+
hidden_states: torch.Tensor,
|
| 102 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 103 |
+
timestep: torch.LongTensor = None,
|
| 104 |
+
img_ids: torch.Tensor = None,
|
| 105 |
+
txt_ids: torch.Tensor = None,
|
| 106 |
+
guidance: torch.Tensor = None,
|
| 107 |
+
return_dict: bool = True,
|
| 108 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
| 109 |
+
"""
|
| 110 |
+
The forward method.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
| 114 |
+
Input `hidden_states`.
|
| 115 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
| 116 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
| 117 |
+
timestep ( `torch.LongTensor`):
|
| 118 |
+
Used to indicate denoising step.
|
| 119 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
| 120 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
| 121 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 122 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
| 123 |
+
tuple.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
| 127 |
+
`tuple` where the first element is the sample tensor.
|
| 128 |
+
"""
|
| 129 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 130 |
+
|
| 131 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 132 |
+
if guidance is not None:
|
| 133 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 134 |
+
else:
|
| 135 |
+
guidance = None
|
| 136 |
+
|
| 137 |
+
temb = self.time_embed( timestep, hidden_states.dtype )
|
| 138 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 139 |
+
|
| 140 |
+
if txt_ids.ndim == 3:
|
| 141 |
+
logger.warning(
|
| 142 |
+
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
| 143 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
| 144 |
+
)
|
| 145 |
+
txt_ids = txt_ids[0]
|
| 146 |
+
if img_ids.ndim == 3:
|
| 147 |
+
logger.warning(
|
| 148 |
+
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
| 149 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
| 150 |
+
)
|
| 151 |
+
img_ids = img_ids[0]
|
| 152 |
+
|
| 153 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
| 154 |
+
image_rotary_emb = self.pos_embed(ids)
|
| 155 |
+
|
| 156 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 157 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:
|
| 158 |
+
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
| 159 |
+
block,
|
| 160 |
+
hidden_states,
|
| 161 |
+
encoder_hidden_states,
|
| 162 |
+
temb,
|
| 163 |
+
image_rotary_emb,
|
| 164 |
+
)
|
| 165 |
+
else:
|
| 166 |
+
encoder_hidden_states, hidden_states = block(
|
| 167 |
+
hidden_states=hidden_states,
|
| 168 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 169 |
+
temb=temb,
|
| 170 |
+
image_rotary_emb=image_rotary_emb,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 174 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_single_checkpoint[index_block]:
|
| 175 |
+
encoder_hidden_states,hidden_states = self._gradient_checkpointing_func(
|
| 176 |
+
block,
|
| 177 |
+
hidden_states,
|
| 178 |
+
encoder_hidden_states,
|
| 179 |
+
temb,
|
| 180 |
+
image_rotary_emb,
|
| 181 |
+
)
|
| 182 |
+
else:
|
| 183 |
+
encoder_hidden_states, hidden_states = block(
|
| 184 |
+
hidden_states=hidden_states,
|
| 185 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 186 |
+
temb=temb,
|
| 187 |
+
image_rotary_emb=image_rotary_emb,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 191 |
+
output = self.proj_out(hidden_states)
|
| 192 |
+
|
| 193 |
+
if not return_dict:
|
| 194 |
+
return (output,)
|
| 195 |
+
|
| 196 |
+
return Transformer2DModelOutput(sample=output)
|
| 197 |
+
|
| 198 |
+
def initialize_weights(self):
|
| 199 |
+
# Initialize transformer layers:
|
| 200 |
+
def _basic_init(module):
|
| 201 |
+
if isinstance(module, nn.Linear):
|
| 202 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 203 |
+
if module.bias is not None:
|
| 204 |
+
nn.init.constant_(module.bias, 0)
|
| 205 |
+
self.apply(_basic_init)
|
| 206 |
+
|
| 207 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
| 208 |
+
w = self.x_embedder.weight.data
|
| 209 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 210 |
+
nn.init.constant_(self.x_embedder.bias, 0)
|
| 211 |
+
|
| 212 |
+
# Initialize caption embedding MLP:
|
| 213 |
+
nn.init.normal_(self.context_embedder.weight, std=0.02)
|
| 214 |
+
|
| 215 |
+
# Zero-out adaLN modulation layers in blocks:
|
| 216 |
+
for block in self.transformer_blocks:
|
| 217 |
+
nn.init.constant_(block.norm1.linear.weight, 0)
|
| 218 |
+
nn.init.constant_(block.norm1.linear.bias, 0)
|
| 219 |
+
nn.init.constant_(block.norm1_context.linear.weight, 0)
|
| 220 |
+
nn.init.constant_(block.norm1_context.linear.bias, 0)
|
| 221 |
+
|
| 222 |
+
for block in self.single_transformer_blocks:
|
| 223 |
+
nn.init.constant_(block.norm.linear.weight, 0)
|
| 224 |
+
nn.init.constant_(block.norm.linear.bias, 0)
|
| 225 |
+
|
| 226 |
+
# Zero-out output layers:
|
| 227 |
+
nn.init.constant_(self.norm_out.linear.weight, 0)
|
| 228 |
+
nn.init.constant_(self.norm_out.linear.bias, 0)
|
| 229 |
+
nn.init.constant_(self.proj_out.weight, 0)
|
| 230 |
+
nn.init.constant_(self.proj_out.bias, 0)
|
| 231 |
+
|
longcat_image/pipelines/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .pipeline_longcat_image import LongCatImagePipeline
from .pipeline_longcat_image_edit import LongCatImageEditPipeline
longcat_image/pipelines/pipeline_longcat_image.py
ADDED
|
@@ -0,0 +1,576 @@
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Union
import json
import numpy as np
import torch
import re

from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
)
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
)
from transformers import AutoTokenizer, AutoModel, AutoProcessor
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from longcat_image.utils.model_utils import split_quotation, prepare_pos_ids, calculate_shift, retrieve_timesteps, optimized_scale
from longcat_image.models.longcat_image_dit import LongCatImageTransformer2DModel
from longcat_image.pipelines.pipeline_output import LongCatImagePipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


SYSTEM_PROMPT_EN = """
You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all information from the user's original prompt without deleting or distorting any details.
Specific requirements are as follows:
1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as concise as possible.
2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields English output. The rewritten token count should not exceed 512.
3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the original prompt, such as lighting and textures.
4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography style**. If the user specifies a style, retain the user's style.
5. When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge to convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a giraffe").
6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50% OFF"`).
7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no specific text content is specified, you need to infer appropriate text content and enclose it in double quotes. For example, if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer with the image title 'Grassland'."
8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. For example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all.
9. Except for text content explicitly requested by the user, **adding any extra text content is prohibited**.
Here are examples of rewrites for different types of prompts:
# Examples (Few-Shot Learning)
1. User Input: An animal with nine lives.
Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home environment with light from the window filtering through curtains, creating a warm light and shadow effect. The shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image.
2. User Input: Create an anime-style tourism flyer with a grassland theme.
Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped rock. She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her left hand, smiling with her legs hanging naturally. The girl has dark brown shoulder-length hair with bangs covering her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To the girl's left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The grass extends into the distance, forming rolling green hills that fade in color as they recede. The sky occupies the upper half of the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is a line of text in italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue, and yellow, fluid lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere.
3. User Input: A Christmas sale poster with a red background, promoting a Buy 1 Get 1 Free milk tea offer.
Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center, golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries, and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. The image has high clarity, accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing strong visual appeal.
4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident posture.
Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile. Her skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a black spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and metal bars glinting in the light. Her outer layer is a beige knitted cardigan, soft in texture with visible knitting patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a relaxed posture. The background is a pure dark brown without extra decoration, making the figure the absolute focus. The figure is located in the center of the frame. Light enters from the upper right, creating bright spots on her left cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional and soft tone. Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are dominated by warm tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere. The overall style is natural, elegant, and artistic.
5. User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting.
Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil, stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under a clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into the tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the orchard, with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a realistic style, focusing on details and harmonious colors to showcase the natural beauty and development of the apple's life cycle.
6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate a four-color rainbow based on this rule. The color order from top to bottom is 3142.
Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the number "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the bottom green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast with the background colors to ensure good readability. The stripes have high color saturation and a slight texture. The overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing the numerical information. The image is high definition, with accurate colors and a consistent style, offering strong visual appeal.
7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a Chinese garden.
Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the stone tablet and the classical beauty of the garden.
# Output Format
Please directly output the rewritten and optimized Prompt content. Do not include any explanatory language or JSON formatting, and do not add opening or closing quotes yourself."""
SYSTEM_PROMPT_ZH = """
[The Chinese-language system prompt (the counterpart of SYSTEM_PROMPT_EN above, with the same rewrite rules, few-shot examples, and output-format section) is not reproduced: the original text is unrecoverably corrupted by a character-encoding error.]
"""

def get_prompt_language(prompt):
    pattern = re.compile(r'[\u4e00-\u9fff]')
    if bool(pattern.search(prompt)):
        return 'zh'
    return 'en'


class LongCatImagePipeline(
    DiffusionPipeline,
    FluxLoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
):
    r"""
    The pipeline for text-to-image generation.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
    _optional_components = ["image_encoder", "feature_extractor", "text_processor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: AutoModel,
        tokenizer: AutoTokenizer,
        text_processor: AutoProcessor,
        transformer: LongCatImageTransformer2DModel,
        image_encoder: CLIPVisionModelWithProjection,
        feature_extractor: CLIPImageProcessor = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            text_processor=text_processor,
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)

        self.prompt_template_encode_prefix = '<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n'
        self.prompt_template_encode_suffix = '<|im_end|>\n<|im_start|>assistant\n'
        self.prompt_template_encode_start_idx = 36
        self.prompt_template_encode_end_idx = 5
        self.default_sample_size = 128
        self.max_tokenizer_len = 512

    @torch.inference_mode()
    def rewire_prompt(self, prompt, device):
        language = get_prompt_language(prompt)
        if language == 'zh':
            question = SYSTEM_PROMPT_ZH + f"\n用户输入为:{prompt}\n改写后的prompt为:"
        else:
            question = SYSTEM_PROMPT_EN + f"\nUser Input: {prompt}\nRewritten prompt:"

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                ],
            }
        ]
        # Preparation for inference
        text = self.text_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.text_processor(
            text=[text], padding=True, return_tensors="pt",)
        inputs = inputs.to(device)

        generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.max_tokenizer_len)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.text_processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        rewrite_prompt = output_text
        return rewrite_prompt

    @torch.inference_mode()
    def encode_prompt(self,
                      prompts,
                      device,
                      dtype):

        prompts = [prompt.strip('"') if prompt.startswith('"') and prompt.endswith('"') else prompt for prompt in prompts]
        all_tokens = []
        for clean_prompt_sub, matched in split_quotation(prompts[0]):
            if matched:
                for sub_word in clean_prompt_sub:
                    tokens = self.tokenizer(sub_word, add_special_tokens=False)['input_ids']
                    all_tokens.extend(tokens)
            else:
                tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)['input_ids']
                all_tokens.extend(tokens)

        all_tokens = all_tokens[:self.max_tokenizer_len]
        text_tokens_and_mask = self.tokenizer.pad(
            {'input_ids': [all_tokens]},
            max_length=self.max_tokenizer_len,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt')

        prefix_tokens = self.tokenizer(self.prompt_template_encode_prefix, add_special_tokens=False)['input_ids']
        suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)['input_ids']
        prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
        suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)

        prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
        suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)

        input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1)
        attention_mask = torch.cat((prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1)

        input_ids = input_ids.unsqueeze(0).to(self.device)
        attention_mask = attention_mask.unsqueeze(0).to(self.device)

        text_output = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size]
        # clone to have a contiguous tensor
        prompt_embeds = text_output.hidden_states[-1].detach()
        prompt_embeds = prompt_embeds[:, self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx, :]

        text_ids = prepare_pos_ids(modality_id=0,
                                   type='text',
                                   start=(0, 0),
                                   num_token=prompt_embeds.shape[1]).to(device, dtype=dtype)

        return prompt_embeds, text_ids

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

        return latents

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        shape = (batch_size, num_channels_latents, height, width)
        latent_image_ids = prepare_pos_ids(modality_id=1,
                                           type='image',
                                           start=(self.max_tokenizer_len,
                                                  self.max_tokenizer_len),
                                           height=height // 2,
                                           width=width // 2).to(device, dtype=torch.float64)

        if latents is not None:
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device)
        latents = latents.to(dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

        return latents, latent_image_ids

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 4.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        enable_cfg_renorm: Optional[bool] = True,
        cfg_renorm_min: Optional[float] = 0.0,
        enable_prompt_rewrite: Optional[bool] = True,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            enable_cfg_renorm: Whether to enable cfg_renorm. Enabling cfg_renorm will improve image quality,
                but it may lead to a decrease in the stability of some image outputs.
            cfg_renorm_min: The minimum value of the cfg_renorm_scale range (0-1).
                cfg_renorm_min = 1.0, renorm has no effect, while cfg_renorm_min = 0.0, the renorm range is larger.
            enable_prompt_rewrite: whether to enable prompt rewrite.
        Examples:

        Returns:
            [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )
        pixel_step = self.vae_scale_factor * 2
        height = int(height / pixel_step) * pixel_step
        width = int(width / pixel_step) * pixel_step

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        if enable_prompt_rewrite:
            prompt = self.rewire_prompt(prompt, device)

        negative_prompt = '' if negative_prompt is None else negative_prompt
        negative_prompt = [negative_prompt] * num_images_per_prompt
        prompt = [prompt] * num_images_per_prompt

        prompt_embeds, text_ids = self.encode_prompt(prompt, device, dtype=torch.float64)
        negative_prompt_embeds, negative_text_ids = self.encode_prompt(negative_prompt, device, dtype=torch.float64)

        # 4. Prepare latent variables
        num_channels_latents = 16
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        guidance = None

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0).to(device)
        else:
            prompt_embeds = prompt_embeds.to(device)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t

                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
                    noise_pred = noise_pred_uncond + self.guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                    if enable_cfg_renorm:
                        cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True)
                        noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                        scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
                        noise_pred = noise_pred * scale

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

            if latents.dtype != self.vae.dtype:
                latents = latents.to(dtype=self.vae.dtype)

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        # self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return LongCatImagePipelineOutput(images=image)
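A minimal usage sketch for the text-to-image pipeline defined above. The checkpoint directory is a placeholder and the keyword arguments simply mirror the `__call__` signature shown in this diff; the repository also ships `scripts/inference_t2i.py` as its maintained entry point, so treat this as an illustrative sketch rather than the official recipe.

```python
import torch
from longcat_image.pipelines import LongCatImagePipeline

# Hypothetical local checkpoint directory laid out for DiffusionPipeline.from_pretrained.
MODEL_DIR = "./checkpoints/longcat-image"

pipe = LongCatImagePipeline.from_pretrained(MODEL_DIR, torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="A lovely little girl.",
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=4.5,            # CFG is applied whenever guidance_scale > 1
    enable_cfg_renorm=True,        # rescale the guided prediction toward the conditional norm
    enable_prompt_rewrite=False,   # skip the LLM-based prompt rewrite and use the raw prompt
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("t2i_example.png")
```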
longcat_image/pipelines/pipeline_longcat_image_edit.py
ADDED
@@ -0,0 +1,512 @@
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Union
import json
import numpy as np
import torch
import math
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
)
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
)
from transformers import AutoTokenizer, AutoModel, AutoProcessor

from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from longcat_image.utils.model_utils import split_quotation, prepare_pos_ids, calculate_shift, retrieve_timesteps, optimized_scale
from longcat_image.models.longcat_image_dit import LongCatImageTransformer2DModel
from longcat_image.pipelines.pipeline_output import LongCatImagePipelineOutput

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def calculate_dimensions(target_area, ratio):
    width = math.sqrt(target_area * ratio)
    height = width / ratio

    width = width if width % 16 == 0 else (width // 16 + 1) * 16
    height = height if height % 16 == 0 else (height // 16 + 1) * 16

    width = int(width)
    height = int(height)

    return width, height


class LongCatImageEditPipeline(
    DiffusionPipeline,
    FluxLoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
):
    r"""
    The pipeline for text-to-image generation.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
    _optional_components = ["image_encoder", "feature_extractor", "text_processor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: AutoModel,
        tokenizer: AutoTokenizer,
        text_processor: AutoProcessor,
        transformer=LongCatImageTransformer2DModel,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            text_processor=text_processor,
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.image_processor_vl = text_processor.image_processor

        self.image_token = "<|image_pad|>"
        self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
        self.prompt_template_encode_suffix = '<|im_end|>\n<|im_start|>assistant\n'
        self.prompt_template_encode_start_idx = 67
        self.prompt_template_encode_end_idx = 5
        self.default_sample_size = 128
        self.max_tokenizer_len = 512
        self.latent_channels = 16

    @torch.inference_mode()
    def encode_prompt(self,
                      image,
                      prompts,
                      device,
                      dtype):
        raw_vl_input = self.image_processor_vl(images=image, return_tensors="pt")
        pixel_values = raw_vl_input['pixel_values']
        image_grid_thw = raw_vl_input['image_grid_thw']

        prompts = [prompt.strip('"') if prompt.startswith('"') and prompt.endswith('"') else prompt for prompt in prompts]
        all_tokens = []

        for clean_prompt_sub, matched in split_quotation(prompts[0]):
            if matched:
                for sub_word in clean_prompt_sub:
                    tokens = self.tokenizer(sub_word, add_special_tokens=False)['input_ids']
                    all_tokens.extend(tokens)
            else:
                tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)['input_ids']
                all_tokens.extend(tokens)

        all_tokens = all_tokens[:self.max_tokenizer_len]
        text_tokens_and_mask = self.tokenizer.pad(
            {'input_ids': [all_tokens]},
            max_length=self.max_tokenizer_len,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt')
        text = self.prompt_template_encode_prefix

        merge_length = self.image_processor_vl.merge_size**2
        while self.image_token in text:
            num_image_tokens = image_grid_thw.prod() // merge_length
            text = text.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
        text = text.replace("<|placeholder|>", self.image_token)

        prefix_tokens = self.tokenizer(text, add_special_tokens=False)['input_ids']
        suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)['input_ids']
        prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
        suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)

        prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
        suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)

        input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1)
        attention_mask = torch.cat((prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1)

        pixel_values = pixel_values.to(self.device)
        image_grid_thw = image_grid_thw.to(self.device)

        input_ids = input_ids.unsqueeze(0).to(self.device)
        attention_mask = attention_mask.unsqueeze(0).to(self.device)

        text_output = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            output_hidden_states=True
        )
        # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size]
        # clone to have a contiguous tensor
        prompt_embeds = text_output.hidden_states[-1].detach()
        prompt_embeds = prompt_embeds[:, self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx, :]

        text_ids = prepare_pos_ids(modality_id=0,
                                   type='text',
                                   start=(0, 0),
                                   num_token=prompt_embeds.shape[1]).to(device, dtype=dtype)

        return prompt_embeds, text_ids

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

        return latents

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    def prepare_latents(
        self,
        image,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        prompt_embeds_length,
        device,
        generator,
        latents=None,
    ):
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        shape = (batch_size, num_channels_latents, height, width)

        image_latents, image_latents_ids = None, None

        if image is not None:
            image = image.to(device=self.device, dtype=dtype)
            image_latents = self.vae.encode(image).latent_dist
            image_latents = image_latents.mode()
            image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
            image_latents = image_latents.to(device=self.device, dtype=dtype)
            image_latents = self._pack_latents(
                image_latents, batch_size, num_channels_latents, height, width
            )

            image_latents_ids = prepare_pos_ids(modality_id=2,
                                                type='image',
                                                start=(prompt_embeds_length,
                                                       prompt_embeds_length),
                                                height=height // 2,
                                                width=width // 2).to(device, dtype=torch.float64)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
        else:
            latents = latents.to(device=device, dtype=dtype)

        latents_ids = prepare_pos_ids(modality_id=1,
                                      type='image',
                                      start=(prompt_embeds_length,
                                             prompt_embeds_length),
                                      height=height // 2,
                                      width=width // 2).to(device, dtype=torch.float64)

        return latents, image_latents, latents_ids, image_latents_ids

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[PipelineImageInput] = None,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):

        image_size = image[0].size if isinstance(image, list) else image.size
        calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1])

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
| 372 |
+
device = self._execution_device
|
| 373 |
+
|
| 374 |
+
image = self.image_processor.resize(image, calculated_height, calculated_width)
|
| 375 |
+
prompt_image = self.image_processor.resize(image, calculated_height//2, calculated_width//2)
|
| 376 |
+
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
|
| 377 |
+
|
| 378 |
+
negative_prompt = '' if negative_prompt is None else negative_prompt
|
| 379 |
+
negative_prompt = [negative_prompt]*num_images_per_prompt
|
| 380 |
+
prompt = [prompt]*num_images_per_prompt
|
| 381 |
+
|
| 382 |
+
prompt_embeds, text_ids = self.encode_prompt(
|
| 383 |
+
image=prompt_image,
|
| 384 |
+
prompts=prompt,
|
| 385 |
+
device=device,
|
| 386 |
+
dtype=torch.float64
|
| 387 |
+
)
|
| 388 |
+
negative_prompt_embeds, negative_text_ids = self.encode_prompt(
|
| 389 |
+
image=prompt_image,
|
| 390 |
+
prompts=negative_prompt,
|
| 391 |
+
device = device,
|
| 392 |
+
dtype=torch.float64
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
# 4. Prepare latent variables
|
| 396 |
+
num_channels_latents = 16
|
| 397 |
+
latents, image_latents, latents_ids, image_latents_ids = self.prepare_latents(
|
| 398 |
+
image,
|
| 399 |
+
batch_size * num_images_per_prompt,
|
| 400 |
+
num_channels_latents,
|
| 401 |
+
calculated_height,
|
| 402 |
+
calculated_width,
|
| 403 |
+
prompt_embeds.dtype,
|
| 404 |
+
prompt_embeds.shape[1],
|
| 405 |
+
device,
|
| 406 |
+
generator,
|
| 407 |
+
latents,
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# 5. Prepare timesteps
|
| 411 |
+
sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 412 |
+
image_seq_len = latents.shape[1]
|
| 413 |
+
mu = calculate_shift(
|
| 414 |
+
image_seq_len,
|
| 415 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 416 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 417 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 418 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 419 |
+
)
|
| 420 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 421 |
+
self.scheduler,
|
| 422 |
+
num_inference_steps,
|
| 423 |
+
device,
|
| 424 |
+
sigmas=sigmas,
|
| 425 |
+
mu=mu,
|
| 426 |
+
)
|
| 427 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 428 |
+
self._num_timesteps = len(timesteps)
|
| 429 |
+
|
| 430 |
+
# handle guidance
|
| 431 |
+
guidance = None
|
| 432 |
+
|
| 433 |
+
if self.joint_attention_kwargs is None:
|
| 434 |
+
self._joint_attention_kwargs = {}
|
| 435 |
+
|
| 436 |
+
if self.do_classifier_free_guidance:
|
| 437 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0).to(device)
|
| 438 |
+
else:
|
| 439 |
+
prompt_embeds = prompt_embeds.to(device)
|
| 440 |
+
|
| 441 |
+
if image is not None:
|
| 442 |
+
latent_image_ids = torch.cat([latents_ids, image_latents_ids], dim=0)
|
| 443 |
+
else:
|
| 444 |
+
latent_image_ids = latents_ids
|
| 445 |
+
|
| 446 |
+
# 6. Denoising loop
|
| 447 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 448 |
+
for i, t in enumerate(timesteps):
|
| 449 |
+
if self.interrupt:
|
| 450 |
+
continue
|
| 451 |
+
|
| 452 |
+
self._current_timestep = t
|
| 453 |
+
|
| 454 |
+
latent_model_input = latents
|
| 455 |
+
if image_latents is not None:
|
| 456 |
+
latent_model_input = torch.cat([latents, image_latents], dim=1)
|
| 457 |
+
|
| 458 |
+
latent_model_input = torch.cat([latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input
|
| 459 |
+
|
| 460 |
+
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
| 461 |
+
|
| 462 |
+
noise_pred = self.transformer(
|
| 463 |
+
hidden_states=latent_model_input,
|
| 464 |
+
timestep=timestep / 1000,
|
| 465 |
+
guidance=guidance,
|
| 466 |
+
encoder_hidden_states=prompt_embeds,
|
| 467 |
+
txt_ids=text_ids,
|
| 468 |
+
img_ids=latent_image_ids,
|
| 469 |
+
return_dict=False,
|
| 470 |
+
)[0]
|
| 471 |
+
noise_pred = noise_pred[:, :image_seq_len]
|
| 472 |
+
if self.do_classifier_free_guidance:
|
| 473 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
|
| 474 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 475 |
+
|
| 476 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 477 |
+
latents_dtype = latents.dtype
|
| 478 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 479 |
+
|
| 480 |
+
if latents.dtype != latents_dtype:
|
| 481 |
+
if torch.backends.mps.is_available():
|
| 482 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 483 |
+
latents = latents.to(latents_dtype)
|
| 484 |
+
|
| 485 |
+
# call the callback, if provided
|
| 486 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 487 |
+
progress_bar.update()
|
| 488 |
+
|
| 489 |
+
if XLA_AVAILABLE:
|
| 490 |
+
xm.mark_step()
|
| 491 |
+
|
| 492 |
+
self._current_timestep = None
|
| 493 |
+
|
| 494 |
+
if output_type == "latent":
|
| 495 |
+
image = latents
|
| 496 |
+
else:
|
| 497 |
+
latents = self._unpack_latents(latents, calculated_height, calculated_width, self.vae_scale_factor)
|
| 498 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 499 |
+
|
| 500 |
+
if latents.dtype != self.vae.dtype:
|
| 501 |
+
latents = latents.to(dtype=self.vae.dtype)
|
| 502 |
+
|
| 503 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 504 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 505 |
+
|
| 506 |
+
# Offload all models
|
| 507 |
+
# self.maybe_free_model_hooks()
|
| 508 |
+
|
| 509 |
+
if not return_dict:
|
| 510 |
+
return (image,)
|
| 511 |
+
|
| 512 |
+
return LongCatImagePipelineOutput(images=image)
|
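The resize step in `__call__` relies on `calculate_dimensions`, which is imported from elsewhere in the repo and is not part of this diff. A minimal sketch of the behaviour the call appears to rely on; the helper name, the rounding multiple of 32, and the exact rounding rule are illustrative assumptions, not the repo's implementation:

```python
import math

def calculate_dimensions_sketch(target_area: int, aspect_ratio: float, multiple: int = 32):
    # Hypothetical stand-in: pick a (width, height) pair with roughly `target_area`
    # pixels at the requested aspect ratio, snapped to a multiple that keeps the
    # latent grid divisible by 2 after 8x VAE compression and 2x2 packing.
    width = math.sqrt(target_area * aspect_ratio)
    height = width / aspect_ratio
    width = max(multiple, round(width / multiple) * multiple)
    height = max(multiple, round(height / multiple) * multiple)
    return int(width), int(height)

# A 1536x1024 (3:2) source image maps to roughly 1248x832 at ~1 megapixel.
print(calculate_dimensions_sketch(1024 * 1024, 1536 / 1024))
```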
longcat_image/pipelines/pipeline_output.py
ADDED
|
@@ -0,0 +1,20 @@
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image
from diffusers.utils import BaseOutput


@dataclass
class LongCatImagePipelineOutput(BaseOutput):
    """
    Output class for LongCat-Image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
longcat_image/utils/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
from .dist_utils import get_world_size, get_rank, get_local_rank
from .log_buffer import LogBuffer
from .model_utils import split_quotation, prepare_pos_ids, calculate_shift, \
    retrieve_timesteps, optimized_scale, pack_latents, unpack_latents, \
    encode_prompt, encode_prompt_edit
longcat_image/utils/dist_utils.py
ADDED
|
@@ -0,0 +1,33 @@
import os

import torch.distributed as dist


def is_distributed():
    return get_world_size() > 1


def get_world_size():
    if not dist.is_available():
        return 1
    return dist.get_world_size() if dist.is_initialized() else 1


def get_rank():
    if not dist.is_available():
        return 0
    return dist.get_rank() if dist.is_initialized() else 0


def get_local_rank():
    if not dist.is_available():
        return 0
    return int(os.getenv('LOCAL_RANK', 0)) if dist.is_initialized() else 0


def is_master():
    return get_rank() == 0


def is_local_master():
    return get_local_rank() == 0
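A minimal usage sketch for these helpers (assuming the package has been installed via `setup.py`); rank-0-only logging is the usual pattern:

```python
from longcat_image.utils import get_rank, get_world_size
from longcat_image.utils.dist_utils import is_master

# Without torch.distributed initialised these fall back to a single-process view.
if is_master():
    print(f"world_size={get_world_size()}, logging from rank {get_rank()}")
```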
longcat_image/utils/log_buffer.py
ADDED
|
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

import numpy as np


class LogBuffer:

    def __init__(self):
        self.val_history = OrderedDict()
        self.n_history = OrderedDict()
        self.output = OrderedDict()
        self.ready = False

    def clear(self) -> None:
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self) -> None:
        self.output.clear()
        self.ready = False

    def update(self, vars: dict, count: int = 1) -> None:
        assert isinstance(vars, dict)
        for key, var in vars.items():
            if key not in self.val_history:
                self.val_history[key] = []
                self.n_history[key] = []
            self.val_history[key].append(var)
            self.n_history[key].append(count)

    def average(self, n: int = 0) -> None:
        """Average latest n values or all values."""
        assert n >= 0
        for key in self.val_history:
            values = np.array(self.val_history[key][-n:])
            nums = np.array(self.n_history[key][-n:])
            avg = np.sum(values * nums) / np.sum(nums)
            self.output[key] = avg
        self.ready = True
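A short usage sketch: `update` stores per-step values together with their sample counts, and `average` produces a count-weighted mean:

```python
from longcat_image.utils import LogBuffer

buf = LogBuffer()
buf.update({'loss': 1.0}, count=2)   # a step averaged over 2 samples
buf.update({'loss': 0.5}, count=2)
buf.average()                        # count-weighted mean over the stored history
print(buf.output['loss'])            # 0.75
buf.clear()
```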
longcat_image/utils/model_utils.py
ADDED
|
@@ -0,0 +1,293 @@
from typing import Any, Callable, Dict, List, Optional, Union

import re
import math
from PIL import Image

import torch
import torch.nn as nn
import inspect
from transformers import AutoTokenizer, AutoProcessor


def pack_latents(latents, batch_size, num_channels_latents, height, width):
    latents = latents.view(
        batch_size, num_channels_latents, height // 2, 2, width // 2, 2
    )
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(
        batch_size, (height // 2) * (width // 2), num_channels_latents * 4
    )

    return latents


def unpack_latents(latents, height, width, vae_scale_factor):
    batch_size, num_patches, channels = latents.shape

    height = height // vae_scale_factor
    width = width // vae_scale_factor

    latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)

    latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)

    return latents


def split_quotation(prompt, quote_pairs=None):
    """
    Regex-based string splitting that identifies spans delimited by single or double quote pairs.

    Examples::
        >>> prompt_en = "Please write 'Hello' on the blackboard for me."
        >>> print(split_quotation(prompt_en))
        >>> # output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)]
    """
    word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+")
    matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt)
    mapping_word_internal_quote = []

    for i, word_src in enumerate(set(matches_word_internal_quote_pattern)):
        word_tgt = 'longcat_$##$_longcat' * (i + 1)
        prompt = prompt.replace(word_src, word_tgt)
        mapping_word_internal_quote.append([word_src, word_tgt])

    if quote_pairs is None:
        quote_pairs = [("'", "'"), ('"', '"'), ('“', '”'), ('‘', '’')]
    quotes = ["'", '"', '“', '”', '‘', '’']
    for q1 in quotes:
        for q2 in quotes:
            if (q1, q2) not in quote_pairs:
                quote_pairs.append((q1, q2))

    pattern = '|'.join([re.escape(q1) + r'[^' + re.escape(q1 + q2) +
                        r']*?' + re.escape(q2) for q1, q2 in quote_pairs])

    parts = re.split(f'({pattern})', prompt)

    result = []
    for part in parts:
        for word_src, word_tgt in mapping_word_internal_quote:
            part = part.replace(word_tgt, word_src)
        if re.match(pattern, part):
            if len(part):
                result.append((part, True))
        else:
            if len(part):
                result.append((part, False))
    return result


def encode_prompt(prompt: str, tokenizer: AutoTokenizer, text_tokenizer_max_length: int,
                  prompt_template_encode_prefix: str, prompt_template_encode_suffix: str):

    all_tokens = []
    for clean_prompt_sub, matched in split_quotation(prompt):
        if matched:
            for sub_word in clean_prompt_sub:
                tokens = tokenizer(sub_word, add_special_tokens=False)['input_ids']
                all_tokens.extend(tokens)
        else:
            tokens = tokenizer(clean_prompt_sub, add_special_tokens=False)['input_ids']
            all_tokens.extend(tokens)

    all_tokens = all_tokens[:text_tokenizer_max_length]
    text_tokens_and_mask = tokenizer.pad(
        {'input_ids': [all_tokens]},
        max_length=text_tokenizer_max_length,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt')

    prefix_tokens = tokenizer(prompt_template_encode_prefix, add_special_tokens=False)['input_ids']
    suffix_tokens = tokenizer(prompt_template_encode_suffix, add_special_tokens=False)['input_ids']
    prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
    suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)

    prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
    suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)

    input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1)
    attention_mask = torch.cat((prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1)

    input_ids = text_tokens_and_mask['input_ids'].squeeze(0)
    attention_mask = text_tokens_and_mask['attention_mask'].squeeze(0)

    return input_ids, attention_mask


def encode_prompt_edit(prompt: str, img: Image.Image, tokenizer: AutoTokenizer, image_processor_vl: AutoProcessor,
                       text_tokenizer_max_length: int, prompt_template_encode_prefix: str,
                       prompt_template_encode_suffix: str):
    raw_vl_input = image_processor_vl(images=img, return_tensors="pt")
    pixel_values = raw_vl_input['pixel_values'].squeeze(0)
    image_grid_thw = raw_vl_input['image_grid_thw'].squeeze(0)

    all_tokens = []
    for clean_prompt_sub, matched in split_quotation(prompt):
        if matched:
            for sub_word in clean_prompt_sub:
                tokens = tokenizer(sub_word, add_special_tokens=False)['input_ids']
                all_tokens.extend(tokens)
        else:
            tokens = tokenizer(clean_prompt_sub, add_special_tokens=False)['input_ids']
            all_tokens.extend(tokens)

    all_tokens = all_tokens[:text_tokenizer_max_length]
    text_tokens_and_mask = tokenizer.pad(
        {'input_ids': [all_tokens]},
        max_length=text_tokenizer_max_length,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt')

    text = prompt_template_encode_prefix
    merge_length = image_processor_vl.merge_size**2
    while "<|image_pad|>" in text:
        num_image_tokens = image_grid_thw.prod() // merge_length
        text = text.replace("<|image_pad|>", "<|placeholder|>" * num_image_tokens, 1)
    text = text.replace("<|placeholder|>", "<|image_pad|>")

    prefix_tokens = tokenizer(text, add_special_tokens=False)['input_ids']
    suffix_tokens = tokenizer(prompt_template_encode_suffix, add_special_tokens=False)['input_ids']
    prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
    suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)

    prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)
    suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype)

    input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1)
    attention_mask = torch.cat((prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1)

    # input_ids = text_tokens_and_mask['input_ids'].squeeze(0)
    # attention_mask = text_tokens_and_mask['attention_mask'].squeeze(0)

    return input_ids, attention_mask, pixel_values, image_grid_thw


def prepare_pos_ids(
        modality_id=0,
        type='text',
        start=(0, 0),
        num_token=None,
        height=None,
        width=None):
    if type == 'text':
        assert num_token
        if height or width:
            print(
                'Warning: The parameters of height and width will be ignored in "text" type.')
        pos_ids = torch.zeros(num_token, 3)
        pos_ids[..., 0] = modality_id
        pos_ids[..., 1] = torch.arange(num_token) + start[0]
        pos_ids[..., 2] = torch.arange(num_token) + start[1]
    elif type == 'image':
        assert height and width
        if num_token:
            print('Warning: The parameter of num_token will be ignored in "image" type.')
        pos_ids = torch.zeros(height, width, 3)
        pos_ids[..., 0] = modality_id
        pos_ids[..., 1] = (
            pos_ids[..., 1] + torch.arange(height)[:, None] + start[0]
        )
        pos_ids[..., 2] = (
            pos_ids[..., 2] + torch.arange(width)[None, :] + start[1]
        )
        pos_ids = pos_ids.reshape(height * width, 3)
    else:
        raise KeyError(f'Unknown type {type}, only support "text" or "image".')
    # pos_ids = pos_ids[None, :].repeat(batch_size, 1, 1)
    return pos_ids


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(
            scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


# @torch.cuda.amp.autocast(dtype=torch.float32)
@torch.amp.autocast('cuda', dtype=torch.float32)
def optimized_scale(positive_flat, negative_flat):

    # Calculate dot product
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)

    # Squared norm of the unconditional branch
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm

    return st_star
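A minimal sketch of how the packing and position-id helpers fit together; the 512x512 image size, the 77-token text length, and the factor of 16 passed to `unpack_latents` (8x VAE times 2x packing) are illustrative assumptions:

```python
import torch

from longcat_image.utils import pack_latents, unpack_latents, prepare_pos_ids

# A 512x512 image with an 8x VAE gives a 1x16x64x64 latent; packing groups each
# 2x2 latent patch into one token: sequence length (64/2)*(64/2)=1024, channels 16*4=64.
x = torch.randn(1, 16, 64, 64)
packed = pack_latents(x, batch_size=1, num_channels_latents=16, height=64, width=64)
print(packed.shape)                  # torch.Size([1, 1024, 64])

# unpack_latents divides the pixel size by the factor it is given before undoing
# the 2x2 packing, so passing 16 here makes the round trip exact in this sketch.
restored = unpack_latents(packed, height=512, width=512, vae_scale_factor=16)
print(torch.allclose(x, restored))   # True

# Position ids carry (modality_id, row, col); image tokens for the 32x32 token grid
# are offset so they start right after the text tokens.
txt_ids = prepare_pos_ids(modality_id=0, type='text', num_token=77)
img_ids = prepare_pos_ids(modality_id=1, type='image', start=(77, 77), height=32, width=32)
print(txt_ids.shape, img_ids.shape)  # torch.Size([77, 3]) torch.Size([1024, 3])
```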
misc/__init__.py
ADDED
File without changes
misc/accelerate_config.yaml
ADDED
|
@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  zero3_init_flag: true
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
use_cpu: false
misc/prompt_rewrite_api.py
ADDED
|
@@ -0,0 +1,113 @@
import os
from openai import OpenAI
import re

SYSTEM_PROMPT_EN = """
|
| 6 |
+
You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all information from the user's original prompt without deleting or distorting any details.
|
| 7 |
+
Specific requirements are as follows:
|
| 8 |
+
1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as concise as possible.
|
| 9 |
+
2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields English output. The rewritten token count should not exceed 512.
|
| 10 |
+
3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the original prompt, such as lighting and textures.
|
| 11 |
+
4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography style**. If the user specifies a style, retain the user's style.
|
| 12 |
+
5. When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge to convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a giraffe").
|
| 13 |
+
6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50% OFF"`).
|
| 14 |
+
7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no specific text content is specified, you need to infer appropriate text content and enclose it in double quotes. For example, if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer with the image title 'Grassland'."
|
| 15 |
+
8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. For example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all.
|
| 16 |
+
9. Except for text content explicitly requested by the user, **adding any extra text content is prohibited**.
|
| 17 |
+
Here are examples of rewrites for different types of prompts:
|
| 18 |
+
# Examples (Few-Shot Learning)
|
| 19 |
+
1. User Input: An animal with nine lives.
|
| 20 |
+
Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home environment with light from the window filtering through curtains, creating a warm light and shadow effect. The shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image.
|
| 21 |
+
2. User Input: Create an anime-style tourism flyer with a grassland theme.
|
| 22 |
+
Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped rock. She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her left hand, smiling with her legs hanging naturally. The girl has dark brown shoulder-length hair with bangs covering her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To the girl's left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The grass extends into the distance, forming rolling green hills that fade in color as they recede. The sky occupies the upper half of the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is a line of text in italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue, and yellow, fluid lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere.
|
| 23 |
+
3. User Input: A Christmas sale poster with a red background, promoting a Buy 1 Get 1 Free milk tea offer.
|
| 24 |
+
Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center, golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries, and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. The image has high clarity, accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing strong visual appeal.
|
| 25 |
+
4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident posture.
|
| 26 |
+
Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile. Her skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a black spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and metal bars glinting in the light. Her outer layer is a beige knitted cardigan, soft in texture with visible knitting patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a relaxed posture. The background is a pure dark brown without extra decoration, making the figure the absolute focus. The figure is located in the center of the frame. Light enters from the upper right, creating bright spots on her left cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional and soft tone. Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are dominated by warm tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere. The overall style is natural, elegant, and artistic.
|
| 27 |
+
5. User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting.
|
| 28 |
+
Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil, stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under a clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into the tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the orchard, with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a realistic style, focusing on details and harmonious colors to showcase the natural beauty and development of the apple's life cycle.
|
| 29 |
+
6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate a four-color rainbow based on this rule. The color order from top to bottom is 3142.
|
| 30 |
+
Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the number "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the bottom green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast with the background colors to ensure good readability. The stripes have high color saturation and a slight texture. The overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing the numerical information. The image is high definition, with accurate colors and a consistent style, offering strong visual appeal.
|
| 31 |
+
7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a Chinese garden.
|
| 32 |
+
Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the stone tablet and the classical beauty of the garden.
|
| 33 |
+
# Output Format
|
| 34 |
+
Please directly output the rewritten and optimized Prompt content. Do not include any explanatory language or JSON formatting, and do not add opening or closing quotes yourself."""
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
SYSTEM_PROMPT_ZH = """
|
| 38 |
+
ไฝ ๆฏไธๅๆ็ๅพๆจกๅ็prompt engineeringไธๅฎถใ็ฑไบๆ็ๅพๆจกๅๅฏน็จๆทprompt็็่งฃ่ฝๅๆ้๏ผไฝ ้่ฆ่ฏๅซ็จๆท่พๅ
ฅ็ๆ ธๅฟไธป้ขๅๆๅพ๏ผๅนถ้่ฟไผๅๆนๅๆๅๆจกๅ็็่งฃๅ็กฎๆงๅ็ๆ่ดจ้ใๆนๅๅฟ
้กปไธฅๆ ผไฟ็็จๆทๅๅงprompt็ๆๆไฟกๆฏ๏ผไธๅพๅ ๅๆๆฒ่งฃไปปไฝ็ป่ใ
|
| 39 |
+
ๅ
ทไฝ่ฆๆฑๅฆไธ๏ผ
|
| 40 |
+
1. ๆนๅไธ่ฝๅฝฑๅ็จๆทๅๅงprompt้่กจ่พพ็ไปปไฝไฟกๆฏ๏ผๆนๅๅ็promptๅบ่ฏฅไฝฟ็จ่ฟ่ดฏ็่ช็ถ่ฏญ่จ่กจ่พพ,ไธ่ฆๅบ็ฐไฝไฟกๆฏ้็ๅไฝๆ่ฟฐ๏ผๅฐฝๅฏ่ฝไฟๆๆนๅๅprompt้ฟๅบฆ็ฒพ็ฎใ
|
| 41 |
+
2. ่ฏท็กฎไฟ่พๅ
ฅๅ่พๅบ็่ฏญ่จ็ฑปๅไธ่ด๏ผไธญๆ่พๅ
ฅไธญๆ่พๅบ๏ผ่ฑๆ่พๅ
ฅ่ฑๆ่พๅบ๏ผๆนๅๅ็tokenๆฐ้ไธ่ฆ่ถ
่ฟ512ไธช;
|
| 42 |
+
3. ๆนๅๅ็ๆ่ฟฐๅบๅฝ่ฟไธๆญฅๅฎๅๅๅงpromptไธญๅบ็ฐ็ไธปไฝ็นๅพใ็พๅญฆๆๅทง๏ผๅฆๆๅ
ใ็บน็็ญ๏ผ
|
| 43 |
+
4. ๅฆๆๅๅงpromptๆฒกๆๆๅฎๅพ็้ฃๆ ผๆถ๏ผ็กฎไฟๆนๅๅ็promptไฝฟ็จ็ๅฎๆๅฝฑ้ฃๆ ผ๏ผๅฆๆ็จๆทๆๅฎไบๅพ็้ฃๆ ผ๏ผๅไฟ็็จๆท้ฃๆ ผ๏ผ
|
| 44 |
+
5. ๅฝๅๅงprompt้่ฆๆจ็ๆ่ฝๆ็กฎ็จๆทๆๅพๆถ๏ผๆ นๆฎไธ็็ฅ่ฏ่ฟ่ก้ๅฝ้ป่พๆจ็๏ผๅฐๆจก็ณๆฝ่ฑกๆ่ฟฐ่ฝฌๅไธบๅ
ทไฝๆๅไบ็ฉ๏ผไพ๏ผๅฐ"ๆ้ซ็ๅจ็ฉ"่ฝฌๅไธบ"ไธๅคด้ฟ้ข้นฟ"๏ผใ
|
| 45 |
+
6. ๅฝๅๅงprompt้่ฆ็ๆๆๅญๆถ๏ผ่ฏทไฝฟ็จๅๅผๅทๅๅฎๆๅญ้จๅ๏ผไพ๏ผ`"้ๆถ5ๆ"`๏ผใ
|
| 46 |
+
7. ๅฝๅๅงprompt้่ฆ็ๆ็ฝ้กตใlogoใuiใๆตทๆฅ็ญๆๅญๅบๆฏๆถ๏ผไธๆฒกๆๆๅฎๅ
ทไฝ็ๆๅญๅ
ๅฎนๆถ๏ผ้่ฆๆจๆญๅบๅ้็ๆๅญๅ
ๅฎน๏ผๅนถไฝฟ็จๅๅผๅทๅๅฎ๏ผๅฆ็จๆท่พๅ
ฅ๏ผไธไธชๆ
ๆธธๅฎฃไผ ๅ๏ผไปฅ่ๅไธบไธป้ขใๅบ่ฏฅๆนๅๆ๏ผไธไธชๆ
ๆธธๅฎฃไผ ๅ๏ผๅพ็ๆ ้ขไธบโ่ๅโใ
|
| 47 |
+
8. ๅฝๅๅงpromptไธญๅญๅจๅฆๅฎ่ฏๆถ๏ผ้่ฆ็กฎไฟๆนๅๅ็promptไธๅญๅจๅฆๅฎ่ฏ๏ผๅฆๆฒกๆ่น็ๆน่พน๏ผๆนๅๅ็promptไธ่ฝๅบ็ฐ่น่ฟไธช่ฏๆฑใ
|
| 48 |
+
9. ้ค้็จๆทๆๅฎ็ๆๅ็logo๏ผๅฆๅไธ่ฆๅขๅ ้ขๅค็ๅ็logo.
|
| 49 |
+
10. ้คไบ็จๆทๆ็กฎ่ฆๆฑไนฆๅ็ๆๅญๅ
ๅฎนๅค๏ผ**็ฆๆญขๅขๅ ไปปไฝ้ขๅค็ๆๅญๅ
ๅฎน**ใ
|
| 50 |
+
ไปฅไธๆฏ้ๅฏนไธๅ็ฑปๅpromptๆนๅ็็คบไพ๏ผ
|
| 51 |
+
|
| 52 |
+
# Examples (Few-Shot Learning)
|
| 53 |
+
1. ็จๆท่พๅ
ฅ: ไนๆกๅฝ็ๅจ็ฉใ
|
| 54 |
+
ๆนๅ่พๅบ: ไธๅช็ซ๏ผ่ขซๆๅ็้ณๅ
็ฌผ็ฝฉ็๏ผๆฏๅๆ่ฝฏ่ๅฏๆๅ
ๆณฝใ่ๆฏๆฏไธไธช่้็ๅฎถๅฑ
็ฏๅข๏ผ็ชๅค็ๅ
็บฟ้่ฟ็ชๅธ๏ผๅฝขๆๆธฉ้ฆจ็ๅ
ๅฝฑๆๆใ้ๅคด้็จไธญ่ท็ฆป่ง่ง๏ผ็ชๅบ็ซๆ ้ฒ่ๅฑ็ๅงฟๆใๅ
็บฟๅทงๅฆๅฐๆๅจ็ซ็่ธ้จ๏ผๅผบ่ฐๅฎ็ตๅจ็็ผ็ๅ็ฒพ่ด็่ก้กป๏ผๅขๅ ็ป้ข็ๅฑๆฌกๆไธไบฒๅๅใ
|
| 55 |
+
2. ็จๆท่พๅ
ฅ: ๅถไฝไธไธชๅจ็ป้ฃๆ ผ็ๆ
ๆธธๅฎฃไผ ๅ๏ผไปฅ่ๅไธบไธป้ขใ
|
| 56 |
+
ๆนๅ่พๅบ: ็ป้ขไธญๅคฎๅๅณไธ่ง๏ผไธไธช็ญๅๅฅณๅญฉไพง่บซๅๅจ็ฐ่ฒ็ไธ่งๅๅฝข็ถๅฒฉ็ณไธ๏ผๅฅน็ฉฟ็็ฝ่ฒ็ญ่ข่ฟ่กฃ่ฃๅๆฃ่ฒๅนณๅบ้๏ผๅทฆๆๆฟ็ไธๆ็ฝ่ฒๅฐ่ฑ๏ผ้ขๅธฆๅพฎ็ฌ๏ผๅ่
ฟ่ช็ถๅไธใๅฅณๅญฉ็ๅคดๅไธบๆทฑๆฃ่ฒ๏ผ้ฝ่ฉ็ญๅ๏ผๅๆตท่ฆ็้ขๅคด๏ผ็ผ็ๅๆฃ่ฒ๏ผๅดๅทดๅพฎๅผ ใๅฒฉ็ณ่กจ้ขๆๆทฑๆต
ไธไธ็็บน็ใๅฅณๅญฉ็ๅทฆไพงๅๅๆนๆฏ่็็่ๅฐ๏ผ่ๅถ็ป้ฟ๏ผๅ้ป็ปฟ่ฒ๏ผ้จๅ่ๅถๅจ้ณๅ
ไธๆณ็้่ฒ็ๅ
่๏ผไปฟไฝ่ขซ้ณๅ
็
งไบฎใ่ๅฐๅ่ฟๅคๅปถไผธ๏ผๅฝขๆ่ฟ็ปต่ตทไผ็็ปฟ่ฒๅฑฑไธ๏ผๅฑฑไธ็้ข่ฒ็ฑ่ฟๅ่ฟ้ๆธๅๆต
ใๅคฉ็ฉบๅ ๆฎไบ็ป้ข็ไธๅ้จๅ๏ผๅๆทก่่ฒ๏ผ็น็ผ็ๅ ๆต็ฝ่ฒ่ฌๆพ็ไบๅฝฉใ็ป้ข็ๅทฆไธ่งๆไธ่กๆๅญ๏ผๆๅญๅ
ๅฎนๆฏๆไฝใๆทฑ็ปฟ่ฒ็โExplore Nature's Peaceโใ่ฒๅฝฉไปฅ็ปฟ่ฒใ่่ฒๅ้ป่ฒไธบไธป๏ผ็บฟๆกๆต็
๏ผๅ
ๅฝฑๆๆๅฏนๆฏๆๆพ๏ผ่ฅ้ ๅบไธ็งๅฎ้ใ่้็ๆฐๅดใ
|
| 57 |
+
3. ็จๆท่พๅ
ฅ: ไธๅผ ไปฅ็บข่ฒไธบ่ๆฏ็ๅฃ่ฏ่ไฟ้ๆตทๆฅ๏ผไธป่ฆๅฎฃไผ ๅฅถ่ถไนฐไธ้ไธ็ไผๆ ๆดปๅจใ
|
| 58 |
+
ๆนๅ่พๅบ: ๆตทๆฅๆดไฝๅ็ฐ็บข่ฒ่ฐ๏ผไธๆนๅๅทฆไพง็น็ผ็็ฝ่ฒ้ช่ฑๅพๆก๏ผๅณไธๆนๆไธๆๅฌ้ๅถๅ็บข่ฒๆตๆ๏ผไปฅๅไธไธชๆพๆใๆตทๆฅไธญๅคฎๅไธไฝ็ฝฎ๏ผ้่ฒ็ซไฝๅญๆ ทโๅฃ่ฏ่ ๆๅฟๅ้ฆโๅฑ
ไธญๆๅ๏ผๅ็บข่ฒ็ฒไฝๅญโไนฐ1้1โใๆตทๆฅไธๆน๏ผไธคไธช่ฃ
ๆปก็็ ๅฅถ่ถ็้ๆๆฏๅญๅนถๆๆๆพ๏ผๆฏไธญๅฅถ่ถๅๆต
ๆฃ่ฒ๏ผๅบ้จๅไธญ้ดๆฃๅธ็ๆทฑๆฃ่ฒ็็ ใๆฏๅญไธๆน๏ผๅ ็งฏ็็ฝ่ฒ้ช่ฑ๏ผ้ช่ฑไธ่ฃ
้ฅฐ็ๆพๆใ็บข่ฒๆตๆๅๆพๆใๅณไธ่ง้็บฆๅฏ่งไธๆฃตๆจก็ณ็ๅฃ่ฏๆ ใๅพ็ๆธ
ๆฐๅบฆ้ซ๏ผๆๅญๅ
ๅฎนๅ็กฎ๏ผๆดไฝ่ฎพ่ฎก้ฃๆ ผ็ปไธ๏ผๅฃ่ฏไธป้ข็ชๅบ๏ผๆ็ๅธๅฑๅ็๏ผๅ
ทๆ่พๅผบ็่ง่งๅธๅผๅใ
|
| 59 |
+
4. ็จๆท่พๅ
ฅ: ไธไฝๅฅณๆงๅจๅฎคๅ
ไปฅ่ช็ถๅ
็บฟๆๆ๏ผๅฅน้ขๅธฆๅพฎ็ฌ๏ผๅ่ไบคๅ๏ผๅฑ็ฐๅบ่ฝปๆพ่ชไฟก็ๅงฟๆใ
|
| 60 |
+
ๆนๅ่พๅบ: ็ป้ขไธญๆฏไธไฝๅนด่ฝป็ไบๆดฒๅฅณๆง๏ผๅฅนๆฅๆๆทฑๆฃ่ฒ็้ฟๅ๏ผๅไธ่ช็ถๅฐๅ่ฝๅจๅ่ฉ๏ผ้จๅๅไธ่ขซๅ
็บฟ็
งไบฎ๏ผๅ็ฐๅบๆๅ็ๅ
ๆณฝใๅฅน็ไบๅฎ็ฒพ่ด๏ผ็ๆฏไฟฎ้ฟ๏ผ็ผ็ๆไบฎๆ็ฅ๏ผ็ณๅญๅๆทฑๆฃ่ฒ๏ผ็ผ็ฅ็ด่ง้ๅคด๏ผๆต้ฒๅบๅนณๅไธ่ชไฟกใ้ผปๆขๆบๆ๏ผๅดๅไธฐๆปก๏ผๆถๆ่ฃธ่ฒ็ณปๅ่๏ผๅด่งๅพฎๅพฎไธๆฌ๏ผๅฑ็ฐๅบๆต
ๆต
็ๅพฎ็ฌใๅฅน็่ค่ฒ็ฝ็๏ผ่ธ้ขๅ้้ชจๅค่ขซๆ่ฒ่ฐ็ๅ
็บฟ็
งไบฎ๏ผๅ็ฐๅบๅฅๅบท็็บขๆถฆๆใๅฅน็ฉฟ็ไธไปถ้ป่ฒ็็ปๅๅธฆ่ๅฟ๏ผ่ฉๅธฆ็บค็ป๏ผ้ฒๅบไผ็พ็้้ชจ็บฟๆกใ่้ขไธไฝฉๆด็ไธๆก้่ฒ็็ป้กน้พ๏ผ้กน้พ็ฑๅฐ็ ๅญ๏ฟฝ๏ฟฝๅ ไธช็ป้ฟ็้ๅฑๆก็ปๆ๏ผๅจๅ
็บฟไธ้ช็็ๅ
ๆณฝใๅฅน็ๅคๆญๆฏไธไปถ็ฑณ้ป่ฒ็้็ปๅผ่กซ๏ผๆ่ดจๆ่ฝฏ๏ผ่ขๅญ้จๅๆๆๆพ็้็ป็บน็ใๅฅนๅ่ไบคๅๅจ่ธๅ๏ผๅๆ่ขซๅผ่กซ็่ขๅญ่ฆ็๏ผๅงฟๆๆพๆพใ่ๆฏๆฏ็บฏ็ฒน็ๆทฑๆฃ่ฒ๏ผๆฒกๆๅคไฝ็่ฃ
้ฅฐ๏ผไฝฟๅพไบบ็ฉๆไธบ็ป้ข็็ปๅฏน็ฆ็นใไบบ็ฉไฝไบ็ป้ขไธญๅคฎใๅ
็บฟไป็ป้ข็ๅณไธๆนๅฐๅ
ฅ๏ผๅจไบบ็ฉ็ๅทฆไพง่ธ้ขใ่้ขๅ้้ชจๅคๅฝขๆๆไบฎ็ๅ
ๆ๏ผๅณไพงๅ็ฅๆพ้ดๅฝฑ๏ผ่ฅ้ ๅบ็ซไฝๆๅๆๅ็ๅฝฑ่ฐใๅพๅ็ป่ๆธ
ๆฐ๏ผไบบ็ฉ็็ฎ่ค็บน็ใๅไธไปฅๅ่กฃ็ฉๆ่ดจ้ฝๅพๅฐไบๅพๅฅฝ็ๅฑ็ฐใ่ฒๅฝฉไปฅๆ่ฒ่ฐไธบไธป๏ผ็ฑณ้ป่ฒๅๆทฑๆฃ่ฒ็ๆญ้
่ฅ้ ๅบๆธฉ้ฆจ่้็ๆฐๅดใๆดไฝๅ็ฐๅบไธ็ง่ช็ถใไผ้
ไธๅฏๆไบฒๅๅ็่บๆฏ้ฃๆ ผใ
|
| 61 |
+
5. ็จๆท่พๅ
ฅ๏ผๅไฝไธ็ณปๅๅพ็๏ผๅฑ็ฐ่นๆไป็งๅญๅฐ็ปๆ็็้ฟ่ฟ็จใ่ฏฅ็ณปๅๅพ็ๅบๅ
ๅซไปฅไธๅไธช้ถๆฎต๏ผ1. ๆญ็ง๏ผ2. ๅนผ่็้ฟ๏ผ3. ๆค็ฉๆ็๏ผ4. ๆๅฎ้ๆใ
|
| 62 |
+
ๆนๅ่พๅบ๏ผไธไธช4ๅฎซๆ ผ็็ฒพ็พๆๅพ๏ผๆ็ป่นๆ็็้ฟ่ฟ็จ๏ผ็ฒพ็กฎๆธ
ๆฐๅฐๆๆๆฏไธช้ถๆฎตใ1.โๆญ็งโ๏ผ็นๅ้ๅคด๏ผไธๅชๆ่ฝป่ฝปๅฐๅฐไธ้ขๅฐๅฐ็่นๆ็งๅญๆพๅ
ฅ่ฅๆฒ็ๆทฑ่ฒๅๅฃคไธญ๏ผๅๅฃค็็บน็ๅ็งๅญๅ
ๆป็่กจ้ขๆธ
ๆฐๅฏ่งใ่ๆฏๆฏ่ฑๅญ็ๆ็ฆ็ป้ข๏ผ็น็ผ็็ปฟ่ฒ็ๆ ๅถๅ้่ฟๆ ๅถๆดไธ็้ณๅ
ใ2.โๅนผ่็้ฟโ๏ผไธๆฃตๅนผๅฐ็่นๆๆ ่็ ดๅ่ๅบ๏ผๅซฉ็ปฟ็ๅถๅญๅๅคฉ็ฉบ่ๅฑใๅบๆฏ่ฎพๅฎๅจไธไธช็ๆบๅๅ็่ฑๅญไธญ๏ผๆธฉๆ็้ๅ
็
งไบฎไบๅฎใๅนผ่็็บค็ป็ปๆใ3.โๆค็ฉ็ๆ็โ๏ผไธๆฃตๆ็็่นๆๆ ๏ผๆ็นๅถ่๏ผๆๆปกไบๅซฉ็ปฟ็ๅถๅญๅๆญฃๅจ่ๅ็ๅฐ่นๆใ่ๆฏๆฏไธ็็ๆบๅๅ็ๆๅญ๏ผๆน่็ๅคฉ็ฉบไธ๏ผๆ้ฉณ็้ณๅ
่ฅ้ ๅบๅฎ้็ฅฅๅ็ๆฐๅดใ4.โ้ๆๆๅฎโ๏ผไธๅชๆไผธๅๆ ไธ๏ผๆไธไธไธชๆ็็็บข่นๆ๏ผ่นๆๅ
ๆป็ๆ็ฎๅจ้ณๅ
ไธ้ช้ชๅๅ
ใ็ป้ขๅฑ็ฐไบๆๅญ็ไธฐๆถๆฏ่ฑก๏ผ่ๆฏไธญๆๆพ็ไธ็ฏฎ็ฏฎ็่นๆ๏ผ็ปไบบไธ็งๅๆปกๆปก่ถณ็ๆ่งใๆฏๅน
ๆๅพ้ฝ้็จๅๅฎ้ฃๆ ผ๏ผๆณจ้็ป่๏ผ่ฒๅฝฉๅ่ฐ๏ผๅฑ็ฐไบ่นๆ็ๅฝๅจๆ็่ช็ถไน็พๅๅๅฑ่ฟ็จใ
|
| 63 |
+
6. ็จๆท่พๅ
ฅ๏ผ ๅฆๆ1ไปฃ่กจ็บข่ฒ๏ผ2ไปฃ่กจ็ปฟ่ฒ๏ผ3ไปฃ่กจ็ดซ่ฒ๏ผ4ไปฃ่กจ้ป่ฒ๏ผ่ฏทๆ็
งๆญค่งๅ็ๆๅ่ฒๅฝฉ่นใๅฎ็้ข่ฒ้กบๅบไปไธๅฐไธๆฏ3142
|
| 64 |
+
ๆนๅ่พๅบ๏ผๅพ็็ฑๅไธชๆฐดๅนณๆๅ็ๅฝฉ่ฒๆก็บน็ปๆ๏ผไปไธๅฐไธไพๆฌกไธบ็ดซ่ฒใ็บข่ฒใ้ป่ฒๅ็ปฟ่ฒใๆฏไธชๆก็บนไธ้ฝๅฑ
ไธญๆพ็ฝฎไธไธช็ฝ่ฒๆฐๅญใๆไธๆน็็ดซ่ฒๆก็บนไธๆฏๆฐๅญโ3โ๏ผๅ
ถไธๆน็บข่ฒๆก็บนไธๆฏๆฐๅญโ1โ๏ผๅไธๆน้ป่ฒๆก็บนไธๆฏๆฐๅญโ4โ๏ผๆไธๆน็็ปฟ่ฒๆก็บนไธๆฏๆฐๅญโ2โใๆๆๆฐๅญๅ้็จๆ ่กฌ็บฟๅญไฝ๏ผ้ข่ฒไธบ็บฏ็ฝ่ฒ๏ผไธ่ๆฏ่ฒๅฝขๆ้ฒๆๅฏนๆฏ๏ผ็กฎไฟไบ่ฏๅฅฝ็ๅฏ่ฏปๆงใๆก็บน็้ข่ฒ้ฅฑๅๅบฆ้ซ๏ผไธๅธฆๆ่ฝปๅพฎ็็บน็ๆ๏ผๆดไฝๆ็็ฎๆดๆไบ๏ผ่ง่งๆๆๆธ
ๆฐ๏ผๆฒกๆๅคไฝ็่ฃ
้ฅฐๅ
็ด ๏ผๅผบ่ฐไบๆฐๅญไฟกๆฏๆฌ่บซใๅพ็ๆดไฝๆธ
ๆฐๅบฆ้ซ๏ผ่ฒๅฝฉๅ็กฎ๏ผ้ฃๆ ผไธ่ด๏ผๅ
ทๆ่พๅผบ็่ง่งๅธๅผๅใ
|
| 65 |
+
7. ็จๆท่พๅ
ฅ๏ผ็ณ็ขไธๅป็โๅ
ณๅ
ณ้้ธ ๏ผๅจๆฒณไนๆดฒโ๏ผ่ช็ถๅ
็
ง๏ผ่ๆฏๆฏไธญๅผๅญๆ
|
| 66 |
+
ๆนๅ่พๅบ๏ผไธๅๅค่็็ณ็ขไธๅป็โๅ
ณๅ
ณ้้ธ ๏ผๅจๆฒณไนๆดฒโ๏ผ็ณ็ข่กจ้ขๅธๆปกๅฒๆ็็่ฟน๏ผๅญ่ฟนๆธ
ๆฐ่ๆทฑๅปใ่ช็ถๅ
็บฟไปไธๆนๆดไธ๏ผๆๅๅฐ็
งไบฎ็ณ็ข็ๆฏไธไธช็ป่๏ผๅขๅผบไบๅ
ถๅๅฒๆใ่ๆฏๆฏไธๅบงๅ
ธ้
็ไธญๅผๅญๆ๏ผๅญๆไธญๆ็ฟ ็ปฟ็็ซนๆใ่ฟ่็ๅฐๅพๅ้่ฐง็ๆฐดๆฑ ๏ผ่ฅ้ ๅบไธ็งๅฎ้่ๆ ่ฟ็ๆฐๅดใๆดไฝ็ป้ข้็จๅๅฎ้ฃๆ ผ๏ผ็ป่ไธฐๅฏ๏ผๅ
ๅฝฑๆๆ่ช็ถ๏ผ็ชๅบไบ็ณ็ข็ๆๅๅบ่ดๅๅญๆ็ๅคๅ
ธ็พใ
|
| 67 |
+
# ่พๅบๆ ผๅผ
|
| 68 |
+
่ฏท็ดๆฅ่พๅบๆนๅไผๅๅ็ Prompt ๅ
ๅฎน๏ผไธ่ฆๅ
ๅซไปปไฝ่งฃ้ๆง่ฏญ่จๆ JSON ๆ ผๅผ๏ผไธ่ฆ่ช่กๆทปๅ ๅผๅคดๆ็ปๅฐพ็ๅผๅทใ
|
| 69 |
+
"""


def contains_chinese(text):
    pattern = re.compile(r'[\u4e00-\u9fff]')
    if bool(pattern.search(text)):
        return 'zh'
    return 'en'


def prompt_rewrite_deepseek(prompt):
    client = OpenAI(api_key=os.environ.get('DEEPSEEK_API_KEY'),
                    base_url="https://api.deepseek.com")

    try:
        language = contains_chinese(prompt)
        if language == 'zh':
            question = SYSTEM_PROMPT_ZH + f"\n用户输入为：{prompt}\n改写后的prompt为："
        else:
            question = SYSTEM_PROMPT_EN + f"\nUser Input: {prompt}\nRewritten prompt:"

        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": question}
            ],
            temperature=0.7,
            max_tokens=512,
            stream=False
        )
        rewrite_prompt = response.choices[0].message.content
    except Exception as e:
        rewrite_prompt = prompt
        print(f"Error occurred: {e}")
    return rewrite_prompt


if __name__ == "__main__":
    # Input format
    import os
    os.environ['DEEPSEEK_API_KEY'] = 'sk-xxxxxxxxxxxxxxxxxxxxxxxxx'
    prompt = '一个年轻的亚洲美女'  # "a young Asian beauty"
    rewrite_prompt = prompt_rewrite_deepseek(prompt)

    print(rewrite_prompt)
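The rewrite rules above ask for any text that should be rendered in the image to be wrapped in double quotes; `split_quotation` in `model_utils` is what the prompt encoders use to find those spans, and `encode_prompt` then tokenizes the quoted span character by character so the exact glyphs are preserved. A small, runnable illustration:

```python
from longcat_image.utils import split_quotation

prompt = 'A Christmas poster with the headline "Buy 1 Get 1".'
print(split_quotation(prompt))
# [('A Christmas poster with the headline ', False), ('"Buy 1 Get 1"', True), ('.', False)]
```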
requirements.txt
ADDED
|
@@ -0,0 +1,48 @@
accelerate==1.11.0
certifi==2025.10.5
charset-normalizer==3.4.4
deepspeed==0.18.2
diffusers==0.35.2
filelock==3.20.0
fsspec==2025.10.0
hf-xet==1.2.0
huggingface-hub==0.36.0
idna==3.11
importlib_metadata==8.7.0
Jinja2==3.1.6
MarkupSafe==3.0.3
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.6
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
packaging==25.0
pillow==12.0.0
psutil==7.1.3
PyYAML==6.0.3
regex==2025.11.3
requests==2.32.5
safetensors==0.6.2
sympy==1.13.1
tokenizers==0.22.1
torch==2.6.0
torchvision==0.21.0
tqdm==4.67.1
transformers==4.57.1
triton==3.2.0
typing_extensions==4.15.0
urllib3==2.5.0
zipp==3.23.0
openai==2.8.1
peft==0.18.0
scripts/inference_edit.py
ADDED
|
@@ -0,0 +1,35 @@
import torch
from PIL import Image
from transformers import AutoProcessor
from longcat_image.models import LongCatImageTransformer2DModel
from longcat_image.pipelines import LongCatImageEditPipeline

if __name__ == '__main__':

    device = torch.device('cuda')
    checkpoint_dir = './weights/LongCat-Image-Edit'
    text_processor = AutoProcessor.from_pretrained(checkpoint_dir, subfolder='tokenizer')
    transformer = LongCatImageTransformer2DModel.from_pretrained(checkpoint_dir, subfolder='transformer',
                                                                 torch_dtype=torch.bfloat16, use_safetensors=True).to(device)

    pipe = LongCatImageEditPipeline.from_pretrained(
        checkpoint_dir,
        transformer=transformer,
        text_processor=text_processor,
    )
    pipe.to(device, torch.bfloat16)

    generator = torch.Generator("cpu").manual_seed(43)
    img = Image.open('assets/test.png').convert('RGB')
    prompt = '小猫变成狗'  # "turn the kitten into a dog"
    image = pipe(
        img,
        prompt,
        negative_prompt='',
        guidance_scale=4.5,
        num_inference_steps=50,
        num_images_per_prompt=1,
        generator=generator
    ).images[0]

    image.save('./edit_example.png')
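If GPU memory is tight when editing large inputs, the VAE helpers defined on the pipeline earlier in this diff can be switched on before the call; a small sketch that continues the script above (it reuses the script's `pipe` object rather than being standalone):

```python
# Trade a little speed for memory before running pipe(...).
pipe.enable_vae_slicing()   # decode the batch slice by slice
pipe.enable_vae_tiling()    # decode large images tile by tile
```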
scripts/inference_t2i.py
ADDED
|
@@ -0,0 +1,42 @@
import torch

from transformers import AutoProcessor
from longcat_image.models import LongCatImageTransformer2DModel
from longcat_image.pipelines import LongCatImagePipeline
from misc.prompt_rewrite_api import prompt_rewrite_deepseek

if __name__ == '__main__':

    device = torch.device('cuda')
    checkpoint_dir = './weights/LongCat-Image'

    text_processor = AutoProcessor.from_pretrained(checkpoint_dir, subfolder='tokenizer')
    transformer = LongCatImageTransformer2DModel.from_pretrained(checkpoint_dir, subfolder='transformer',
                                                                 torch_dtype=torch.bfloat16, use_safetensors=True).to(device)

    pipe = LongCatImagePipeline.from_pretrained(
        checkpoint_dir,
        transformer=transformer,
        text_processor=text_processor
    )
    pipe.to(device, torch.bfloat16)
| 24 |
+
prompt = 'ไธไธชๅนด่ฝป็ไบ่ฃๅฅณๆง๏ผ่บซ็ฉฟ้ป่ฒ้็ป่กซ๏ผๆญ้
็ฝ่ฒ้กน้พใๅฅน็ๅๆๆพๅจ่็ไธ๏ผ่กจๆ
ๆฌ้ใ่ๆฏๆฏไธๅ ต็ฒ็ณ็็ ๅข๏ผๅๅ็้ณๅ
ๆธฉๆๅฐๆดๅจๅฅน่บซไธ๏ผ่ฅ้ ๅบไธ็งๅฎ้่ๆธฉ้ฆจ็ๆฐๅดใ้ๅคด้็จไธญ่ท็ฆป่ง่ง๏ผ็ชๅบๅฅน็็ฅๆๅๆ้ฅฐ็็ป่ใๅ
็บฟๆๅๅฐๆๅจๅฅน็่ธไธ๏ผๅผบ่ฐๅฅน็ไบๅฎๅ้ฅฐๅ็่ดจๆ๏ผๅขๅ ็ป้ข็ๅฑๆฌกๆไธไบฒๅๅใๆดไธช็ป้ขๆๅพ็ฎๆด๏ผ็ ๅข็็บน็ไธ้ณๅ
็ๅ
ๅฝฑๆๆ็ธๅพ็ๅฝฐ๏ผ็ชๆพๅบไบบ็ฉ็ไผ้
ไธไปๅฎนใ'
|

    enable_prompt_rewrite_api = False
    if enable_prompt_rewrite_api:
        prompt = prompt_rewrite_deepseek(prompt)

    image = pipe(
        prompt,
        negative_prompt='',
        height=768,
        width=1344,
        guidance_scale=4.5,
        num_inference_steps=50,
        num_images_per_prompt=1,
        generator=torch.Generator("cpu").manual_seed(43),
        enable_cfg_renorm=True,
        enable_prompt_rewrite=not enable_prompt_rewrite_api
    ).images[0]
    image.save('./t2i_example.png')
setup.py
ADDED
|
@@ -0,0 +1,11 @@
from setuptools import setup, find_packages

setup(
    name="longcat_image",
    version="1.0",
    author="LongCat-Image",
    description="LongCat-Image training and inference codes.",
    packages=find_packages(),
    install_requires=[],
    dependency_links=[],
)