Plat committed on
Commit
4b08319
·
1 Parent(s): 5a8f4fc
Files changed (8)
  1. .gitignore +14 -0
  2. README.md +3 -3
  3. app.py +206 -122
  4. model/class_encoder.py +131 -0
  5. model/config.py +96 -0
  6. model/denoiser.py +833 -0
  7. model/pipeline.py +412 -0
  8. requirements.txt +2 -4
.gitignore ADDED
@@ -0,0 +1,14 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ /models
13
+ /output
14
+ /notebooks
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: JiT AnimeFace Demo
3
- emoji: 🖼
4
- colorFrom: purple
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 5.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
  title: JiT AnimeFace Demo
3
+ emoji: 🚀
4
+ colorFrom: red
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 6.1.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py CHANGED
@@ -1,154 +1,238 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
 
 
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
 
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
 
39
- generator = torch.Generator().manual_seed(seed)
 
40
 
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
 
51
- return image, seed
 
52
 
 
53
 
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
 
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
 
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
 
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
 
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
- result = gr.Image(label="Result", show_label=False)
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
 
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
 
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
 
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
  maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
 
126
  )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
  minimum=1,
131
- maximum=50,
 
132
  step=1,
133
- value=2, # Replace with defaults that work for your model
134
  )
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
- )
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
1
+ import spaces
2
+
3
+ import json
4
+ import yaml
5
+ import os
6
 
 
 
7
  import torch
8
 
9
+ import gradio as gr
 
10
 
11
+ from huggingface_hub import hf_hub_download
 
 
 
12
 
13
+ from model.pipeline import JiTModel, JiTConfig
14
+ from model.config import ClassContextConfig
15
 
 
 
16
 
17
+ MODEL_REPO = os.environ.get("MODEL_REPO", "p1atdev/JiT-AnimeFace-experiment")
18
+ MODEL_PATH = os.environ.get(
19
+ "MODEL_PATH", "jit-b256-p16-cls/12-jit-animeface_00043e_033368s.safetensors"
20
+ )
21
+ LABEL2ID_PATH = os.environ.get("LABEL2ID_PATH", "jit-b256-p16-cls/label2id.json")
22
+ CONFIG_PATH = os.environ.get("CONFIG_PATH", "jit-b256-p16-cls/config.yml")
23
 
24
+ DEVICE = (
25
+ torch.device("cuda")
26
+ if torch.cuda.is_available()
27
+ else torch.device("mps")
28
+ if torch.backends.mps.is_available()
29
+ else torch.device("cpu")
30
+ )
31
+ MAX_TOKEN_LENGTH = 32
32
 
33
+ model_map: dict[str, JiTModel] = {} # {model_path: model}
34
+ label2id_map: dict[str, dict] = {} # {label2id_path: label2id}
35
 
36
 
37
+ def get_file_path(repo: str, path: str) -> str:
38
+ """Hugging Face Hub ใ‹ใ‚‰ใƒ•ใ‚กใ‚คใƒซใ‚’ๅ–ๅพ—"""
39
 
40
+ return hf_hub_download(repo, path)
41
 
 
 
 
 
 
42
 
43
+ def load_label2id(label2id_path: str) -> dict:
44
+ """label2id.json ใ‚’่ชญใฟ่พผใ‚€"""
45
+ with open(label2id_path, "r") as f:
46
+ return json.load(f)
 
 
47
 
 
 
 
48
 
49
+ def load_config(config_path: str) -> JiTConfig:
50
+ """่จญๅฎšใƒ•ใ‚กใ‚คใƒซใ‚’่ชญใฟ่พผใ‚€"""
51
+ with open(config_path, "r") as f:
52
+ if config_path.endswith(".json"):
53
+ config_dict = json.load(f)
54
+ elif config_path.endswith((".yaml", ".yml")):
55
+ config_dict = yaml.safe_load(f)
56
+ else:
57
+ raise ValueError("Unsupported config file format. Use .json or .yaml/.yml")
58
 
59
+ return JiTConfig.model_validate(config_dict)
60
 
 
61
 
62
+ def load_model(
63
+ model_path: str,
64
+ label2id_path: str,
65
+ config_path: str,
66
+ device: torch.device,
67
+ ) -> tuple[JiTModel, dict]:
68
+ """ใƒขใƒ‡ใƒซใ‚’่ชญใฟ่พผใ‚€"""
69
 
70
+ if model_path in model_map: # use cache
71
+ model = model_map[model_path]
72
+ label2id = label2id_map[label2id_path]
73
+ return model, label2id
 
 
 
74
 
75
+ config = load_config(get_file_path(MODEL_REPO, config_path))
76
+ if isinstance(config.context_encoder, ClassContextConfig):
77
+ config.context_encoder.label2id_map_path = get_file_path(
78
+ MODEL_REPO, label2id_path
79
+ )
80
 
81
+ model = JiTModel.from_pretrained(
82
+ config=config,
83
+ checkpoint_path=get_file_path(MODEL_REPO, model_path),
84
+ )
85
+ model.eval()
86
+ model.requires_grad_(False)
87
+ model.to(device=device)
88
+ model_map[model_path] = model # cache
89
+
90
+ label2id = load_label2id(get_file_path(MODEL_REPO, label2id_path))
91
+ label2id_map[label2id_path] = label2id # cache
92
+
93
+ return model, label2id
94
+
95
+
96
+ @spaces.GPU(duration=5)
97
+ def generate_images(
98
+ prompt: str,
99
+ negative_prompt: str,
100
+ num_steps: int,
101
+ cfg_scale: float,
102
+ batch_size: int,
103
+ size: int,
104
+ seed: int,
105
+ #
106
+ model_path: str = MODEL_PATH,
107
+ label2id_path: str = LABEL2ID_PATH,
108
+ config_path: str = CONFIG_PATH,
109
+ progress=gr.Progress(track_tqdm=True),
110
+ ):
111
+ model, _label2id = load_model(
112
+ model_path=model_path,
113
+ label2id_path=label2id_path,
114
+ config_path=config_path,
115
+ device=DEVICE,
116
+ )
117
 
118
+ with torch.inference_mode():
119
+ images = model.generate(
120
+ prompt=[prompt] * batch_size,
121
+ negative_prompt=negative_prompt,
122
+ num_inference_steps=num_steps,
123
+ cfg_scale=cfg_scale,
124
+ height=size,
125
+ width=size,
126
+ max_token_length=MAX_TOKEN_LENGTH,
127
+ cfg_time_range=[0.1, 1.0],
128
+ seed=seed if seed >= 0 else None,
129
+ device=DEVICE,
130
+ execution_dtype=model.config.torch_dtype,
131
+ )
132
+
133
+ return images
134
+
135
+
136
+ def demo():
137
+ with gr.Blocks() as ui:
138
+ gr.Markdown(f"""
139
+ # JiT-AnimeFace Demo
140
+ Pixel-space x-prediction flow-matching model for anime face generation, trained from scratch.
141
+
142
+ See the full list of supported tags: [label2id.json](https://huggingface.co/{MODEL_REPO}/blob/main/{LABEL2ID_PATH}).
143
+ """)
144
 
145
+ with gr.Row():
146
+ with gr.Column():
147
+ prompt = gr.TextArea(
148
+ label="Prompt",
149
+ info="Space-separated tags. Not all of danbooru tags are supported. See the link above for full supported tags.",
150
+ value="general 1girl solo portrait looking_at_viewer blue_hair short_hair blush cat_ears open_mouth cat_ears animal_ears red_eyes white_background",
151
+ placeholder="e.g.: general 1girl solo portrait looking_at_viewer",
152
+ )
153
+ negative_prompt = gr.TextArea(
154
+ label="Negative Prompt",
155
+ value="retro_artstyle 1990s_(style) sketch",
156
+ lines=2,
157
+ placeholder="e.g.: retro_artstyle 1990s_(style) sketch",
158
+ )
159
+ num_steps = gr.Slider(
160
+ minimum=1,
161
+ maximum=100,
162
+ value=25,
163
+ step=4,
164
+ label="Number of Steps",
165
+ )
166
+ cfg_scale = gr.Slider(
167
+ minimum=1.0,
168
  maximum=10.0,
169
+ value=3.0,
170
+ step=0.25,
171
+ label="CFG Scale",
172
  )
173
+ batch_size = gr.Slider(
 
 
174
  minimum=1,
175
+ maximum=64,
176
+ value=16,
177
  step=1,
178
+ label="Batch Size",
179
+ )
180
+ size = gr.Slider(
181
+ minimum=64,
182
+ maximum=320,
183
+ value=256,
184
+ step=64,
185
+ label="Image Size",
186
+ )
187
+ seed = gr.Number(
188
+ value=-1,
189
+ label="Seed (-1 for random)",
190
  )
191
 
192
+ with gr.Column(scale=2):
193
+ generate_button = gr.Button("Generate Images", variant="primary")
194
+ output_gallery = gr.Gallery(
195
+ label="Generated Images",
196
+ columns=4,
197
+ height="768px",
198
+ preview=False,
199
+ show_label=True,
200
+ )
201
+
202
+ gr.Examples(
203
+ examples=[
204
+ [
205
+ "general 1girl solo portrait looking_at_viewer blue_hair short_hair blush cat_ears open_mouth cat_ears animal_ears red_eyes white_background",
206
+ "retro_artstyle 1990s_(style) sketch",
207
+ ]
208
+ ],
209
+ inputs=[prompt, negative_prompt],
210
+ )
211
+
212
+ gr.on(
213
+ triggers=[generate_button.click, prompt.submit],
214
+ fn=generate_images,
215
+ inputs=[
216
+ prompt,
217
+ negative_prompt,
218
+ num_steps,
219
+ cfg_scale,
220
+ batch_size,
221
+ size,
222
+ seed,
223
+ ],
224
+ outputs=output_gallery,
225
+ )
226
+
227
+ return ui
228
+
229
 
230
  if __name__ == "__main__":
231
+ load_model(
232
+ model_path=MODEL_PATH,
233
+ label2id_path=LABEL2ID_PATH,
234
+ config_path=CONFIG_PATH,
235
+ device=DEVICE,
236
+ )
237
+
238
+ demo().launch()
model/class_encoder.py ADDED
@@ -0,0 +1,131 @@
1
+ import warnings
2
+ from typing import NamedTuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ PromptType = str | list[str]
8
+
9
+
10
+ class ClassTokenizerOutput(NamedTuple):
11
+ class_ids: torch.Tensor
12
+ attention_mask: torch.Tensor
13
+
14
+
15
+ class ClassTokenizer:
16
+ def __init__(
17
+ self,
18
+ label2id: dict[str, int],
19
+ splitter: str = " ",
20
+ ) -> None:
21
+ self.label2id = label2id
22
+ self.id2label = {v: k for k, v in label2id.items()}
23
+ self.splitter = splitter
24
+
25
+ self.pad_token_id = len(label2id)
26
+
27
+ assert all([id < len(label2id) for id in label2id.values()]), (
28
+ "All label IDs must be less than the number of classes."
29
+ )
30
+
31
+ def normalize_prompts(
32
+ self,
33
+ class_names: PromptType,
34
+ ) -> list[str]:
35
+ _class_names: list[str] = (
36
+ class_names if isinstance(class_names, list) else [class_names]
37
+ )
38
+ return _class_names
39
+
40
+ def tokenize(
41
+ self,
42
+ prompts: PromptType,
43
+ max_length: int = 32,
44
+ ) -> ClassTokenizerOutput:
45
+ # 1. Normalize class names
46
+ _prompts = self.normalize_prompts(prompts)
47
+
48
+ # 2. Convert to IDs
49
+ class_ids = []
50
+ masks = []
51
+ for text in _prompts:
52
+ ids = []
53
+
54
+ for label in text.split(self.splitter):
55
+ if label.strip() == "":
56
+ continue
57
+ id = self.label2id.get(label.strip())
58
+ if id is not None: # 0 is OK
59
+ ids.append(id)
60
+ masks.append(1)
61
+ else:
62
+ warnings.warn(f"Label '{label}' not found in label2id mapping.")
63
+ class_ids.append(ids)
64
+
65
+ # 3. Pad to max_length
66
+ padded_class_ids = []
67
+ padded_masks = []
68
+
69
+ for _i, ids in enumerate(class_ids):
70
+ if len(ids) < max_length:
71
+ mask = [1] * len(ids) + [0] * (max_length - len(ids))
72
+ ids = ids + [self.pad_token_id] * (max_length - len(ids)) # padding idx
73
+ else:
74
+ mask = [1] * max_length
75
+ ids = ids[:max_length]
76
+
77
+ padded_class_ids.append(ids)
78
+ padded_masks.append(mask)
79
+
80
+ return ClassTokenizerOutput(
81
+ class_ids=torch.tensor(padded_class_ids, dtype=torch.long),
82
+ attention_mask=torch.tensor(padded_masks, dtype=torch.long),
83
+ )
84
+
85
+
86
+ class ClassEncoderOutput(NamedTuple):
87
+ embeddings: torch.Tensor
88
+ attention_mask: torch.Tensor
89
+
90
+
91
+ class ClassEncoder(nn.Module):
92
+ def __init__(
93
+ self,
94
+ label2id: dict[str, int],
95
+ embedding_dim: int,
96
+ ):
97
+ super().__init__()
98
+
99
+ self.num_classes = len(label2id)
100
+
101
+ self.pad_token_id = self.num_classes # padding idx
102
+
103
+ self.embedding = nn.Embedding(
104
+ self.num_classes + 1, # +1 for padding idx
105
+ embedding_dim,
106
+ padding_idx=self.num_classes,
107
+ )
108
+
109
+ self.tokenizer = ClassTokenizer(label2id)
110
+
111
+ def initialize_weights(self):
112
+ nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
113
+
114
+ def encode_prompts(
115
+ self,
116
+ prompts: PromptType,
117
+ max_token_length: int = 32,
118
+ ):
119
+ # 1. Tokenize prompts
120
+ class_ids, attention_mask = self.tokenizer.tokenize(
121
+ prompts,
122
+ max_length=max_token_length,
123
+ )
124
+
125
+ # 3. Get embeddings
126
+ embeddings = self.embedding(class_ids.to(self.embedding.weight.device))
127
+
128
+ return ClassEncoderOutput(
129
+ embeddings=embeddings,
130
+ attention_mask=attention_mask,
131
+ )
model/config.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch
2
+ import json
3
+
4
+ from typing import Literal
5
+ from pydantic import BaseModel
6
+
7
+
8
+ FP32_STR = ["float32", "fp32"]
9
+ FP16_STR = ["float16", "fp16", "half"]
10
+ BF16_STR = ["bfloat16", "bf16"]
11
+
12
+
13
+ def str_to_dtype(dtype_str: str) -> torch.dtype:
14
+ dtype_str = dtype_str.lower()
15
+ if dtype_str in FP32_STR:
16
+ return torch.float32
17
+ elif dtype_str in FP16_STR:
18
+ return torch.float16
19
+ elif dtype_str in BF16_STR:
20
+ return torch.bfloat16
21
+ else:
22
+ raise ValueError(f"Unsupported dtype string: {dtype_str}")
23
+
24
+
25
+ class DenoiserConfig(BaseModel):
26
+ patch_size: int = 16
27
+ in_channels: int = 3
28
+ out_channels: int = 3
29
+ hidden_size: int = 1024
30
+ depth: int = 24
31
+ num_heads: int = 16
32
+ mlp_ratio: float = 4.0
33
+ attn_dropout: float = 0.0
34
+ proj_dropout: float = 0.0
35
+
36
+ bottleneck_dim: int = 128
37
+ num_time_tokens: int = 4
38
+
39
+ rope_theta: float = 256.0
40
+ rope_axes_dims: list[int] = [16, 24, 24]
41
+ rope_axes_lens: list[int] = [256, 128, 128]
42
+ rope_zero_centered: list[bool] = [False, True, True]
43
+
44
+ context_dim: int
45
+
46
+
47
+ class JiT_B_16_Config(DenoiserConfig):
48
+ patch_size: int = 16
49
+
50
+ depth: int = 12
51
+ hidden_size: int = 768
52
+ num_heads: int = 12
53
+ bottleneck_dim: int = 128
54
+
55
+ context_dim: int = 768
56
+
57
+ rope_axes_dims: list[int] = [16, 24, 24] # sum = 64 = 768 / 12
58
+ rope_axes_lens: list[int] = [
59
+ 256, # max 256 token text
60
+ 128, # 2048x2048 image size
61
+ 128,
62
+ ]
63
+
64
+
65
+ ContextType = Literal["class", "text"]
66
+
67
+
68
+ class ClassContextConfig(BaseModel):
69
+ type: Literal["class"] = "class"
70
+ label2id_map_path: str
71
+
72
+ @property
73
+ def label2id(self) -> dict[str, int]:
74
+ with open(self.label2id_map_path, "r") as f:
75
+ label2id = json.load(f)
76
+
77
+ return label2id
78
+
79
+
80
+ class TextContextConfig(BaseModel):
81
+ type: Literal["text"] = "text"
82
+ pretrained_model: str = "p1atdev/Qwen3-VL-2B-Instruct-Text-Only"
83
+
84
+
85
+ ContextConfig = ClassContextConfig | TextContextConfig
86
+
87
+
88
+ class JiTConfig(BaseModel):
89
+ dtype: str = "float32"
90
+
91
+ context_encoder: ContextConfig
92
+ denoiser: DenoiserConfig = JiT_B_16_Config()
93
+
94
+ @property
95
+ def torch_dtype(self) -> torch.dtype:
96
+ return str_to_dtype(self.dtype)
model/denoiser.py ADDED
@@ -0,0 +1,833 @@
1
+ # Reference: https://github.com/LTH14/JiT/blob/main/model_jit.py
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint as checkpoint
8
+
9
+ import torch.nn.functional as F
10
+
11
+ from .config import DenoiserConfig
12
+
13
+
14
+ # https://github.com/huggingface/diffusers/blob/66bf7ea5be7099c8a47b9cba135f276d55247447/src/diffusers/models/embeddings.py#L27
15
+ def get_timestep_embedding(
16
+ timesteps: torch.Tensor,
17
+ embedding_dim: int,
18
+ flip_sin_to_cos: bool = False,
19
+ downscale_freq_shift: float = 1,
20
+ scale: float = 1,
21
+ max_period: int = 10000,
22
+ ):
23
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
24
+
25
+ half_dim = embedding_dim // 2
26
+ exponent = -math.log(max_period) * torch.arange(
27
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
28
+ )
29
+ exponent = exponent / (half_dim - downscale_freq_shift)
30
+
31
+ emb = torch.exp(exponent)
32
+ emb = timesteps[:, None].float() * emb[None, :]
33
+
34
+ # scale embeddings
35
+ emb = scale * emb
36
+
37
+ # concat sine and cosine embeddings
38
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
39
+
40
+ # flip sine and cosine embeddings
41
+ if flip_sin_to_cos:
42
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
43
+
44
+ # zero pad
45
+ if embedding_dim % 2 == 1:
46
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
47
+
48
+ return emb
49
+
50
+
51
+ class FP32RMSNorm(nn.RMSNorm):
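+ # compute RMSNorm in float32 for numerical stability, then cast back to the input dtype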
52
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
53
+ return F.rms_norm(
54
+ hidden_states.to(torch.float32),
55
+ self.normalized_shape,
56
+ weight=self.weight,
57
+ eps=self.eps,
58
+ ).to(hidden_states.dtype)
59
+
60
+
61
+ class BottleneckPatchEmbed(nn.Module):
62
+ """Image to Patch Embedding"""
63
+
64
+ def __init__(
65
+ self,
66
+ patch_size: int = 16,
67
+ in_channels: int = 3,
68
+ bottleneck_dim: int = 128,
69
+ hidden_dim: int = 768,
70
+ bias: bool = True,
71
+ ):
72
+ super().__init__()
73
+
74
+ self.patch_size = patch_size
75
+ self.in_channels = in_channels
76
+ self.bottleneck_dim = bottleneck_dim
77
+ self.hidden_dim = hidden_dim
78
+ self.bias = bias
79
+
80
+ self.proj_1 = nn.Conv2d(
81
+ in_channels,
82
+ bottleneck_dim,
83
+ kernel_size=patch_size,
84
+ stride=patch_size,
85
+ bias=False,
86
+ )
87
+ self.proj_2 = nn.Conv2d(
88
+ bottleneck_dim,
89
+ hidden_dim,
90
+ kernel_size=1,
91
+ stride=1,
92
+ bias=bias,
93
+ )
94
+
95
+ def forward(self, image: torch.Tensor) -> torch.Tensor:
96
+ # B, C, H, W = image.shape
97
+
98
+ # [B, C, H, W]
99
+ # -> [B, bottleneck_dim, H/patch_size, W/patch_size] (proj_1)
100
+ # -> [B, hidden_dim, H/patch_size, W/patch_size] (proj_2)
101
+ # -> [B, hidden_dim, num_patches] (flatten)
102
+ # -> [B, num_patches, hidden_dim] (transpose)
103
+ patches = (
104
+ self.proj_2(
105
+ self.proj_1(image),
106
+ )
107
+ .flatten(2)
108
+ .transpose(1, 2)
109
+ )
110
+
111
+ return patches
112
+
113
+
114
+ class TimestepEmbedder(nn.Module):
115
+ def __init__(
116
+ self,
117
+ hidden_dim: int,
118
+ freq_embedding_size: int = 256,
119
+ ):
120
+ super().__init__()
121
+
122
+ self.freq_embedding_size = freq_embedding_size
123
+
124
+ self.mlp = nn.Sequential(
125
+ nn.Linear(freq_embedding_size, hidden_dim, bias=True),
126
+ nn.SiLU(),
127
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
128
+ )
129
+
130
+ def forward(self, timestep: torch.Tensor) -> torch.Tensor:
131
+ freq_emb = get_timestep_embedding(
132
+ timestep,
133
+ embedding_dim=self.freq_embedding_size,
134
+ flip_sin_to_cos=True,
135
+ downscale_freq_shift=0,
136
+ )
137
+ time_embed = self.mlp(freq_emb.to(dtype=self.mlp[0].weight.dtype))
138
+
139
+ return time_embed
140
+
141
+
142
+ def apply_rope(
143
+ inputs: torch.Tensor, # (batch_size, num_heads, seq_len, dim)
144
+ freqs_cis: torch.Tensor, # (batch_size, seq_len, dim//2) complex64
145
+ ) -> torch.Tensor:
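+ # view adjacent channel pairs as complex numbers and rotate them by the precomputed unit phasors (done in fp32 for precision)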
146
+ batch_size, num_heads, seq_len, dim = inputs.shape
147
+
148
+ with torch.autocast(device_type="cuda", enabled=False):
149
+ inputs_cis = torch.view_as_complex(
150
+ inputs.float().view(batch_size, num_heads, seq_len, dim // 2, 2)
151
+ )
152
+ freqs_cis = freqs_cis.unsqueeze(1) # (batch_size, 1, seq_len, dim//2)
153
+ output = torch.view_as_real(inputs_cis * freqs_cis).flatten(3)
154
+
155
+ return output.type_as(inputs)
156
+
157
+
158
+ class RopeEmbedder:
159
+ def __init__(
160
+ self,
161
+ rope_theta: float = 256.0, # ref: Z-Image
162
+ axes_dims: list[int] = [32, 64, 64], # text, height, width
163
+ axes_lens: list[int] = [256, 128, 128], # text, height, width
164
+ zero_centered: list[bool] = [False, True, True],
165
+ ):
166
+ self.rope_theta = rope_theta
167
+ self.axes_dims = axes_dims
168
+ self.axes_lens = axes_lens
169
+ self.zero_centered = zero_centered
170
+
171
+ # text starts with 0, image axes are zero-centered
172
+
173
+ self.freqs_cis = self.precompute_freqs_cis(
174
+ theta=self.rope_theta,
175
+ dims=self.axes_dims,
176
+ lens=self.axes_lens,
177
+ zero_centered=self.zero_centered,
178
+ )
179
+
180
+ @staticmethod
181
+ def get_rope_freqs(
182
+ dim: int,
183
+ min_position: int = 0,
184
+ max_position: int = 128,
185
+ theta: float = 10000.0,
186
+ ) -> torch.Tensor:
187
+ freqs = 1.0 / (
188
+ theta
189
+ ** (
190
+ torch.arange(0, dim, 2, dtype=torch.float64, device=torch.device("cpu"))
191
+ / dim
192
+ )
193
+ )
194
+ positions = torch.arange(
195
+ start=min_position,
196
+ end=max_position,
197
+ dtype=torch.float64,
198
+ device=torch.device("cpu"),
199
+ )
200
+
201
+ freqs = torch.outer(positions, freqs).float() # (max_position, dim//2)
202
+ # ↓ pos, → dim//2
203
+ # [ min_position * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
204
+ # ...
205
+ # 0 * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
206
+ # 1 * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
207
+ # 2 * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
208
+ # ...
209
+ # max_position * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)] ]
210
+
211
+ freqs_cis = torch.polar(
212
+ abs=torch.ones_like(freqs),
213
+ angle=freqs,
214
+ ).to(torch.complex64) # (min_position~max_position, dim//2) complex64
215
+
216
+ # ๅคงใใ•ใฏๅค‰ใˆใšใซๅ›ž่ปขใ‚’่กจใ™่ค‡็ด ๆ•ฐ
217
+ return freqs_cis
218
+
219
+ @staticmethod
220
+ def precompute_freqs_cis(
221
+ theta: float,
222
+ dims: list[int],
223
+ lens: list[int],
224
+ zero_centered: list[bool],
225
+ ):
226
+ freqs_cis = []
227
+
228
+ for i, (dim, len_) in enumerate(zip(dims, lens)):
229
+ freq_cis = RopeEmbedder.get_rope_freqs(
230
+ dim=dim,
231
+ min_position=(len_ // 2) - len_ if zero_centered[i] else 0,
232
+ max_position=len_ // 2 if zero_centered[i] else len_,
233
+ theta=theta,
234
+ ) # (len_, dim//2) complex64
235
+
236
+ freqs_cis.append(freq_cis)
237
+
238
+ return freqs_cis
239
+
240
+ # get frequencies for given position ids
241
+ def __call__(self, position_ids: torch.Tensor):
242
+ # move to device
243
+ freqs_cis = [fc.to(position_ids.device) for fc in self.freqs_cis]
244
+
245
+ result = []
246
+ for i in range(len(self.axes_dims)):
247
+ index = (
248
+ position_ids[..., i : i + 1]
249
+ .repeat(
250
+ # match dimensions for each axis
251
+ 1, # batch size?
252
+ 1, # sequence length?
253
+ freqs_cis[i].shape[-1],
254
+ )
255
+ .to(torch.int64)
256
+ )
257
+ result.append(
258
+ torch.gather(
259
+ freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1),
260
+ dim=1,
261
+ index=index,
262
+ )
263
+ )
264
+
265
+ return torch.cat(result, dim=-1)
266
+
267
+
268
+ class Attention(nn.Module):
269
+ def __init__(
270
+ self,
271
+ dim: int,
272
+ num_heads: int = 8,
273
+ qkv_bias: bool = True,
274
+ qk_norm: bool = True,
275
+ attn_dropout: float = 0.0,
276
+ proj_dropout: float = 0.0,
277
+ ):
278
+ super().__init__()
279
+
280
+ self.num_heads = num_heads
281
+ self.head_dim = dim // num_heads
282
+
283
+ self.q_norm = FP32RMSNorm(self.head_dim) if qk_norm else nn.Identity()
284
+ self.k_norm = FP32RMSNorm(self.head_dim) if qk_norm else nn.Identity()
285
+
286
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
287
+ self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
288
+ self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
289
+ self.attn_dropout = nn.Dropout(attn_dropout)
290
+
291
+ self.to_o = nn.Linear(dim, dim)
292
+ self.proj_dropout = nn.Dropout(proj_dropout)
293
+
294
+ def _pre_attn_reshape(self, x: torch.Tensor):
295
+ batch_size, seq_len, dim = x.shape
296
+
297
+ # [B, N, D] -> [B, N, num_heads, D/num_heads] -> [B, num_heads, N, D/num_heads]
298
+ x = x.view(
299
+ batch_size,
300
+ seq_len,
301
+ self.num_heads,
302
+ self.head_dim,
303
+ ).permute(0, 2, 1, 3) # [B, num_heads, N, head_dim]
304
+
305
+ return x
306
+
307
+ def _post_attn_reshape(self, x: torch.Tensor):
308
+ batch_size, num_heads, seq_len, head_dim = x.shape
309
+
310
+ # [B, num_heads, N, head_dim] -> [B, N, num_heads, head_dim] -> [B, N, D]
311
+ x = (
312
+ x.permute(0, 2, 1, 3)
313
+ .contiguous()
314
+ .view(batch_size, seq_len, num_heads * head_dim)
315
+ )
316
+
317
+ return x
318
+
319
+ def forward(
320
+ self,
321
+ hidden_states: torch.Tensor,
322
+ rope_freqs: torch.Tensor,
323
+ mask: torch.Tensor | None = None, # 1: attend, 0: ignore
324
+ ) -> torch.Tensor:
325
+ batch_size, seq_len, _dim = hidden_states.shape
326
+
327
+ # QKV
328
+ q = self.to_q(hidden_states)
329
+ k = self.to_k(hidden_states)
330
+ v = self.to_v(hidden_states)
331
+
332
+ q = self._pre_attn_reshape(q) # [B, num_heads, N, head_dim]
333
+ k = self._pre_attn_reshape(k)
334
+ v = self._pre_attn_reshape(v)
335
+
336
+ # QKNorm
337
+ q = self.q_norm(q)
338
+ k = self.k_norm(k)
339
+
340
+ q = apply_rope(q, rope_freqs)
341
+ k = apply_rope(k, rope_freqs)
342
+
343
+ if mask is not None:
344
+ # mask: (batch_size, seq_len) -> (batch_size, num_heads, seq_len, seq_len)
345
+ mask = (
346
+ mask.bool()
347
+ .view(batch_size, 1, 1, seq_len)
348
+ .expand(-1, self.num_heads, seq_len, -1)
349
+ )
350
+
351
+ attn = F.scaled_dot_product_attention(
352
+ q,
353
+ k,
354
+ v,
355
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
356
+ attn_mask=mask,
357
+ is_causal=False,
358
+ ).to(hidden_states.dtype)
359
+ attn = self._post_attn_reshape(attn)
360
+
361
+ # output
362
+ out = self.to_o(attn)
363
+ out = self.proj_dropout(out)
364
+
365
+ return out
366
+
367
+
368
+ class SwiGLU(nn.Module):
369
+ def __init__(
370
+ self,
371
+ dim: int,
372
+ hidden_dim: int,
373
+ dropout: float = 0.0,
374
+ bias: bool = True,
375
+ ):
376
+ super().__init__()
377
+
378
+ hidden_dim = int(hidden_dim * 2 / 3)
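+ # the 2/3 factor keeps this three-matrix SwiGLU at roughly the parameter count of a standard two-matrix MLP of the requested width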
379
+
380
+ self.w_1 = nn.Linear(dim, hidden_dim, bias=bias)
381
+ self.w_2 = nn.Linear(dim, hidden_dim, bias=bias)
382
+ self.w_3 = nn.Linear(hidden_dim, dim, bias=bias)
383
+
384
+ self.ffn_dropout = nn.Dropout(dropout)
385
+
386
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
387
+ x_1 = self.w_1(hidden_states)
388
+ x_2 = self.w_2(hidden_states)
389
+
390
+ x = F.silu(x_1) * x_2
391
+
392
+ x = self.w_3(self.ffn_dropout(x))
393
+
394
+ return x
395
+
396
+
397
+ class FinalLayer(nn.Module):
398
+ def __init__(
399
+ self,
400
+ hidden_dim: int,
401
+ mlp_ratio: float,
402
+ patch_size: int,
403
+ out_channels: int,
404
+ ):
405
+ super().__init__()
406
+
407
+ self.norm_final = FP32RMSNorm(hidden_dim)
408
+
409
+ self.mlp = SwiGLU(
410
+ dim=hidden_dim,
411
+ hidden_dim=int(hidden_dim * mlp_ratio),
412
+ dropout=0.0,
413
+ bias=True,
414
+ )
415
+
416
+ self.linear = nn.Linear(
417
+ hidden_dim,
418
+ patch_size * patch_size * out_channels,
419
+ bias=True,
420
+ )
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ ) -> torch.Tensor:
426
+ x = self.norm_final(hidden_states)
427
+ x = self.mlp(x)
428
+ x = self.linear(x)
429
+
430
+ return x
431
+
432
+
433
+ class JiTBlock(nn.Module):
434
+ def __init__(
435
+ self,
436
+ hidden_dim: int,
437
+ num_heads: int,
438
+ mlp_ratio: float = 4.0,
439
+ attn_dropout: float = 0.0,
440
+ proj_dropout: float = 0.0,
441
+ ffn_dropout: float = 0.0,
442
+ qkv_bias: bool = True,
443
+ qk_norm: bool = True,
444
+ bias: bool = True,
445
+ ):
446
+ super().__init__()
447
+
448
+ self.norm1 = FP32RMSNorm(hidden_dim, eps=1e-6)
449
+ self.attn = Attention(
450
+ dim=hidden_dim,
451
+ num_heads=num_heads,
452
+ qkv_bias=qkv_bias,
453
+ qk_norm=qk_norm,
454
+ attn_dropout=attn_dropout,
455
+ proj_dropout=proj_dropout,
456
+ )
457
+
458
+ self.norm2 = FP32RMSNorm(hidden_dim)
459
+ self.mlp = SwiGLU(
460
+ dim=hidden_dim,
461
+ hidden_dim=int(hidden_dim * mlp_ratio),
462
+ dropout=ffn_dropout,
463
+ bias=bias,
464
+ )
465
+
466
+ def forward(
467
+ self,
468
+ hidden_states: torch.Tensor,
469
+ rope_freqs: torch.Tensor,
470
+ mask: torch.Tensor | None = None,
471
+ ):
472
+ # attn
473
+ hidden_states = hidden_states + self.attn(
474
+ self.norm1(hidden_states),
475
+ rope_freqs,
476
+ mask=mask,
477
+ )
478
+
479
+ # mlp
480
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
481
+
482
+ return hidden_states
483
+
484
+
485
+ class JiT(nn.Module):
486
+ def __init__(self, config: DenoiserConfig):
487
+ super().__init__()
488
+
489
+ self.config = config
490
+
491
+ assert (config.hidden_size // config.num_heads) == sum(config.rope_axes_dims), (
492
+ "The sum of rope_axes_dims must equal to hidden_size / num_heads = head_dim."
493
+ )
494
+
495
+ self.num_axes = len(
496
+ config.rope_axes_dims
497
+ ) # 0: image_index, 1: height, 2: width
498
+
499
+ # image patch embedder
500
+ self.patch_embedder = BottleneckPatchEmbed(
501
+ patch_size=config.patch_size,
502
+ in_channels=config.in_channels,
503
+ bottleneck_dim=config.bottleneck_dim,
504
+ hidden_dim=config.hidden_size,
505
+ bias=True,
506
+ )
507
+
508
+ # timestep embedder
509
+ self.time_embedder = TimestepEmbedder(
510
+ hidden_dim=config.hidden_size,
511
+ freq_embedding_size=256,
512
+ )
513
+ self.time_position_embeds = nn.Parameter(
514
+ torch.randn(
515
+ config.num_time_tokens,
516
+ config.hidden_size,
517
+ ),
518
+ requires_grad=True,
519
+ )
520
+
521
+ # RoPE embedder
522
+ self.rope_embedder = RopeEmbedder(
523
+ rope_theta=config.rope_theta,
524
+ axes_dims=config.rope_axes_dims,
525
+ axes_lens=config.rope_axes_lens,
526
+ zero_centered=config.rope_zero_centered,
527
+ )
528
+
529
+ # class condition or text embedding
530
+ self.context_embedder = nn.Linear(
531
+ config.context_dim,
532
+ config.hidden_size,
533
+ bias=True,
534
+ )
535
+
536
+ self.blocks = nn.ModuleList(
537
+ [
538
+ JiTBlock(
539
+ hidden_dim=config.hidden_size,
540
+ num_heads=config.num_heads,
541
+ mlp_ratio=config.mlp_ratio,
542
+ attn_dropout=config.attn_dropout,
543
+ proj_dropout=config.proj_dropout,
544
+ ffn_dropout=0.0,
545
+ qkv_bias=True,
546
+ qk_norm=True,
547
+ bias=True,
548
+ )
549
+ for _ in range(config.depth)
550
+ ]
551
+ )
552
+
553
+ self.final_layer = FinalLayer(
554
+ hidden_dim=config.hidden_size,
555
+ mlp_ratio=config.mlp_ratio,
556
+ patch_size=config.patch_size,
557
+ out_channels=config.in_channels,
558
+ )
559
+
560
+ self.gradient_checkpointing = False
561
+
562
+ def initialize_weights(self):
563
+ # Initialize weights
564
+ for m in self.modules():
565
+ if isinstance(m, nn.Linear):
566
+ nn.init.xavier_uniform_(m.weight)
567
+ if m.bias is not None:
568
+ nn.init.zeros_(m.bias)
569
+ elif isinstance(m, nn.RMSNorm):
570
+ nn.init.ones_(m.weight)
571
+
572
+ # patch embed
573
+ w_1 = self.patch_embedder.proj_1.weight
574
+ nn.init.xavier_uniform_(w_1.view([w_1.shape[0], -1]))
575
+ w_2 = self.patch_embedder.proj_2.weight
576
+ nn.init.xavier_uniform_(w_2.view([w_2.shape[0], -1]))
577
+ if self.patch_embedder.proj_2.bias is not None:
578
+ nn.init.zeros_(self.patch_embedder.proj_2.bias)
579
+
580
+ # time position embeds
581
+ nn.init.normal_(
582
+ self.time_position_embeds,
583
+ std=0.02,
584
+ )
585
+
586
+ # time embedder
587
+ nn.init.normal_(
588
+ self.time_embedder.mlp[0].weight, # type: ignore
589
+ std=0.02,
590
+ )
591
+ nn.init.normal_(
592
+ self.time_embedder.mlp[2].weight, # type: ignore
593
+ std=0.02,
594
+ )
595
+
596
+ def set_gradient_checkpointing(self, enable: bool = True):
597
+ self.gradient_checkpointing = enable
598
+
599
+ def prepare_image_position_ids(
600
+ self,
601
+ height: int,
602
+ width: int,
603
+ image_index: int,
604
+ ) -> torch.Tensor:
605
+ # [H/patch_size, W/patch_size]
606
+
607
+ patch_size = self.config.patch_size
608
+ h_patches = height // patch_size
609
+ w_patches = width // patch_size
610
+
611
+ position_ids = torch.zeros(
612
+ h_patches,
613
+ w_patches,
614
+ self.num_axes,
615
+ )
616
+
617
+ # image_index
618
+ position_ids[:, :, 0] = image_index # image
619
+
620
+ # height (y-index)
621
+ position_ids[:, :, 1] = (
622
+ torch.arange(
623
+ h_patches,
624
+ )
625
+ .unsqueeze(1)
626
+ .repeat(1, w_patches)
627
+ )
628
+ # width (x-index)
629
+ position_ids[:, :, 2] = (
630
+ torch.arange(
631
+ w_patches,
632
+ )
633
+ .unsqueeze(0)
634
+ .repeat(h_patches, 1)
635
+ )
636
+
637
+ return position_ids.view(-1, self.num_axes) # (num_patches, n_axes)
638
+
639
+ def prepare_context_position_ids(
640
+ self,
641
+ seq_len: int,
642
+ context_start_index: int = 0,
643
+ xy_position: int = 0,
644
+ ) -> torch.Tensor:
645
+ position_ids = torch.zeros(
646
+ seq_len,
647
+ self.num_axes,
648
+ )
649
+
650
+ # context_index (0, ..., seq_len-1)
651
+ position_ids[:, 0] = torch.arange(
652
+ context_start_index,
653
+ context_start_index + seq_len,
654
+ ) # text
655
+
656
+ # token indices are (0, 0)...(0, 0)
657
+ position_ids[:, 1] = xy_position
658
+ position_ids[:, 2] = xy_position
659
+
660
+ return position_ids
661
+
662
+ def prepare_time_position_ids(
663
+ self,
664
+ seq_len: int,
665
+ time_start_index: int,
666
+ xy_position: int = 0,
667
+ ) -> torch.Tensor:
668
+ position_ids = torch.zeros(
669
+ seq_len,
670
+ self.num_axes,
671
+ )
672
+
673
+ # time_index
674
+ position_ids[:, 0] = torch.arange(
675
+ time_start_index, time_start_index + seq_len
676
+ ) # time
677
+
678
+ # token indices are (0, 0)...(0, 0)
679
+ position_ids[:, 1] = xy_position
680
+ position_ids[:, 2] = xy_position
681
+
682
+ return position_ids
683
+
684
+ def unpatchify(
685
+ self,
686
+ patches: torch.Tensor,
687
+ height: int,
688
+ width: int,
689
+ ) -> torch.Tensor:
690
+ batch_size, num_patches, _patch_dim = patches.shape
691
+
692
+ patch_size = self.config.patch_size
693
+ out_channels = self.config.out_channels
694
+
695
+ h_patches = height // patch_size
696
+ w_patches = width // patch_size
697
+
698
+ assert num_patches == h_patches * w_patches, "Mismatch in number of patches"
699
+
700
+ # [B, N, patch_size*patch_size*C] -> [B, H_patch, W_patch, patch_size, patch_size, C]
701
+ patches = patches.view(
702
+ batch_size,
703
+ h_patches,
704
+ w_patches,
705
+ patch_size,
706
+ patch_size,
707
+ out_channels,
708
+ )
709
+
710
+ # [B, H_patch, W_patch, patch_size, patch_size, C]
711
+ # -> [B, C, H_patch, patch_size, W_patch, patch_size]
712
+ patches = patches.permute(0, 5, 1, 3, 2, 4)
713
+ # -> [B, C, H_img, W_img]
714
+ images = patches.reshape(batch_size, out_channels, height, width)
715
+
716
+ return images
717
+
718
+ def forward(
719
+ self,
720
+ image: torch.Tensor, # [B, C, H, W]
721
+ timestep: torch.Tensor, # [B]
722
+ context: torch.Tensor, # [B, context_len, context_dim]
723
+ context_mask: torch.Tensor | None = None, # [B, context_len]
724
+ ):
725
+ batch_size, _in_channels, height, width = image.shape
726
+
727
+ time_embed: torch.Tensor = self.time_embedder(timestep) # [B, hidden_dim]
728
+ time_tokens = time_embed.unsqueeze(1).repeat( # add seq_len dim
729
+ 1,
730
+ self.time_position_embeds.shape[0], # num_time_tokens
731
+ 1,
732
+ ) + self.time_position_embeds.unsqueeze(0).repeat( # add batch dim
733
+ batch_size,
734
+ 1,
735
+ 1,
736
+ ) # [B, num_time_tokens, hidden_dim]
737
+ num_time_tokens = time_tokens.shape[1]
738
+
739
+ context_embed = self.context_embedder(context)
740
+ context_len = context_embed.shape[1]
741
+
742
+ patches = self.patch_embedder(image) # [B, N, hidden_dim]]
743
+ patches_len = patches.shape[1]
744
+
745
+ # context -> time -> patches
746
+ context_position_ids = self.prepare_context_position_ids(
747
+ seq_len=context_len,
748
+ context_start_index=0,
749
+ )
750
+ time_position_ids = self.prepare_time_position_ids(
751
+ seq_len=num_time_tokens,
752
+ time_start_index=context_len,
753
+ )
754
+ patches_position_ids = self.prepare_image_position_ids(
755
+ height=height,
756
+ width=width,
757
+ image_index=context_len + num_time_tokens, # after context and time tokens
758
+ )
759
+
760
+ # actually: patches -> time -> context
761
+ position_ids = torch.cat(
762
+ [
763
+ patches_position_ids,
764
+ time_position_ids,
765
+ context_position_ids,
766
+ ],
767
+ dim=0,
768
+ ).view(1, -1, self.num_axes) # (1, total_seq_len, n_axes)
769
+
770
+ # prepare RoPE
771
+ freqs_cis = (
772
+ self.rope_embedder(position_ids=position_ids)
773
+ .repeat(
774
+ batch_size,
775
+ 1,
776
+ 1,
777
+ )
778
+ .to(device=image.device)
779
+ )
780
+
781
+ # attention mask
782
+ if context_mask is not None:
783
+ patches_mask = torch.ones(batch_size, patches_len, device=image.device)
784
+ time_mask = torch.ones(batch_size, num_time_tokens, device=image.device)
785
+ mask = torch.cat(
786
+ [
787
+ patches_mask,
788
+ time_mask,
789
+ context_mask.to(image.device),
790
+ ],
791
+ dim=1,
792
+ )
793
+ else:
794
+ # attend all
795
+ mask = torch.ones(
796
+ batch_size,
797
+ patches_len + num_time_tokens + context_len,
798
+ device=image.device,
799
+ )
800
+
801
+ for _i, block in enumerate(self.blocks):
802
+ tokens = torch.cat(
803
+ [
804
+ patches, # 16x16
805
+ time_tokens, # 4
806
+ context_embed, # 64
807
+ ],
808
+ dim=1, # cat in seq_len dimension
809
+ )
810
+
811
+ if self.gradient_checkpointing and self.training:
812
+ patches = checkpoint.checkpoint( # type: ignore
813
+ block,
814
+ tokens,
815
+ freqs_cis,
816
+ mask,
817
+ )[:, :patches_len, :]
818
+ else:
819
+ patches = block(
820
+ tokens,
821
+ rope_freqs=freqs_cis,
822
+ mask=mask,
823
+ )[:, :patches_len, :] # only keep patch tokens
824
+
825
+ patches = self.final_layer(patches)
826
+
827
+ pred_image = self.unpatchify(
828
+ patches,
829
+ height=height,
830
+ width=width,
831
+ )
832
+
833
+ return pred_image
model/pipeline.py ADDED
@@ -0,0 +1,412 @@
1
+ from tqdm import tqdm
2
+ from PIL import Image
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+ from accelerate import init_empty_weights
9
+ from safetensors.torch import load_file
10
+
11
+
12
+ from .denoiser import JiT
13
+ from .class_encoder import ClassEncoder
14
+ from .config import JiTConfig, ClassContextConfig
15
+ # from .text_encoder import TextEncoder
16
+
17
+ # from ...modules.quant import replace_by_prequantized_weights
18
+ # from ...utils import tensor as tensor_utils
19
+
20
+
21
+ def tensor_to_images(
22
+ tensor: torch.Tensor,
23
+ ) -> list[Image.Image]:
24
+ # -1~1 -> 0~255
25
+
26
+ # denormalize
27
+ tensor = tensor.clamp(-1.0, 1.0)
28
+ tensor = (tensor + 1.0) / 2.0 * 255.0
29
+
30
+ # permute
31
+ tensor = tensor.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C]
32
+
33
+ # convert to numpy array
34
+ image_array = tensor.cpu().float().numpy().astype(np.uint8)
35
+
36
+ return [Image.fromarray(image) for image in image_array]
37
+
38
+
39
+ class JiTModel(nn.Module):
40
+ denoiser: JiT
41
+ denoiser_class: type[JiT] = JiT
42
+
43
+ class_encoder: ClassEncoder
44
+
45
+ def __init__(
46
+ self,
47
+ config: JiTConfig,
48
+ ):
49
+ super().__init__()
50
+
51
+ self.config = config
52
+
53
+ self.denoiser = self.denoiser_class(config.denoiser)
54
+
55
+ if isinstance(config.context_encoder, ClassContextConfig):
56
+ self.class_encoder = ClassEncoder(
57
+ label2id=config.context_encoder.label2id,
58
+ embedding_dim=config.denoiser.context_dim,
59
+ )
60
+ else:
61
+ raise NotImplementedError(
62
+ "Only ClassContextConfig is supported in this version."
63
+ )
64
+
65
+ self.progress_bar = tqdm
66
+
67
+ def _load_checkpoint(
68
+ self,
69
+ checkpoint_path: str,
70
+ strict: bool = True,
71
+ ):
72
+ state_dict = load_file(checkpoint_path)
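+ # the safetensors checkpoint is one flat state dict; split it below by the 'denoiser.' / 'class_encoder.' key prefixes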
73
+
74
+ # replace_by_prequantized_weights(self, state_dict)
75
+
76
+ self.denoiser.load_state_dict(
77
+ {
78
+ key[len("denoiser.") :]: value
79
+ for key, value in state_dict.items()
80
+ if key.startswith("denoiser.")
81
+ },
82
+ strict=strict,
83
+ assign=True,
84
+ )
85
+ if self.class_encoder is not None:
86
+ self.class_encoder.load_state_dict(
87
+ {
88
+ key[len("class_encoder.") :]: value
89
+ for key, value in state_dict.items()
90
+ if key.startswith("class_encoder.")
91
+ },
92
+ strict=strict,
93
+ assign=True,
94
+ )
95
+ # if self.text_encoder is not None:
96
+ # self.text_encoder.model.load_state_dict(
97
+ # {
98
+ # key[len("text_encoder.") :]: value
99
+ # for key, value in state_dict.items()
100
+ # if key.startswith("text_encoder.")
101
+ # },
102
+ # strict=strict,
103
+ # assign=True,
104
+ # )
105
+
106
+ @classmethod
107
+ def from_pretrained(
108
+ cls,
109
+ config: JiTConfig,
110
+ checkpoint_path: str,
111
+ ) -> "JiTModel":
112
+ with init_empty_weights():
113
+ model = cls(config)
114
+
115
+ model._load_checkpoint(checkpoint_path)
116
+
117
+ return model
118
+
119
+ @classmethod
120
+ def new_with_config(
121
+ cls,
122
+ config: JiTConfig,
123
+ ) -> "JiTModel":
124
+ with init_empty_weights():
125
+ model = cls(config)
126
+
127
+ model.denoiser.to_empty(device="cpu")
128
+ model.denoiser.initialize_weights()
129
+
130
+ if isinstance(config.context_encoder, ClassContextConfig):
131
+ model.class_encoder.to_empty(device="cpu")
132
+ model.class_encoder.initialize_weights()
133
+ else:
134
+ # model.text_encoder = TextEncoder.from_remote(
135
+ # repo_id=config.context_encoder.pretrained_model,
136
+ # )
137
+ raise NotImplementedError(
138
+ "Only ClassContextConfig is supported in this version."
139
+ )
140
+
141
+ return model
142
+
143
+ def prepare_noisy_image(
144
+ self,
145
+ batch_size: int,
146
+ height: int,
147
+ width: int,
148
+ dtype: torch.dtype,
149
+ device: torch.device,
150
+ seed: int | None = None,
151
+ ):
152
+ if seed is not None:
153
+ generator = torch.Generator(device=device)
154
+ generator.manual_seed(seed)
155
+ noise = torch.randn(
156
+ (batch_size, 3, height, width),
157
+ dtype=dtype,
158
+ device=device,
159
+ generator=generator,
160
+ )
161
+ else:
162
+ noise = torch.randn(
163
+ (batch_size, 3, height, width),
164
+ dtype=dtype,
165
+ device=device,
166
+ )
167
+
168
+ return noise
169
+
170
+ def prepare_timesteps(
171
+ self,
172
+ num_inference_steps: int,
173
+ device: torch.device,
174
+ ):
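+ # uniform time grid from t=0 (pure noise) to t=1 (clean image); the sampling loop steps over consecutive differences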
175
+ timesteps = torch.linspace(
176
+ 0.0,
177
+ 1.0,
178
+ num_inference_steps + 1,
179
+ device=device,
180
+ )
181
+
182
+ return timesteps
183
+
184
+ def prepare_context_embeddings(
185
+ self,
186
+ prompts: str | list[str],
187
+ negative_prompt: str | list[str],
188
+ max_token_length: int = 64,
189
+ do_cfg: bool = False,
190
+ ):
191
+ # if self.text_encoder is not None:
192
+ # encoder_output = self.text_encoder.encode_prompts(
193
+ # prompts,
194
+ # negative_prompts=negative_prompt,
195
+ # use_negative_prompts=do_cfg,
196
+ # max_token_length=max_token_length,
197
+ # )
198
+ # if do_cfg:
199
+ # prompt_embeddings = torch.cat(
200
+ # [
201
+ # encoder_output.positive_embeddings,
202
+ # encoder_output.negative_embeddings,
203
+ # ]
204
+ # )
205
+ # attention_mask = torch.cat(
206
+ # [
207
+ # encoder_output.positive_attention_mask,
208
+ # encoder_output.negative_attention_mask,
209
+ # ]
210
+ # )
211
+ # else:
212
+ # prompt_embeddings = encoder_output.positive_embeddings
213
+ # attention_mask = encoder_output.positive_attention_mask
214
+
215
+ if self.class_encoder is not None:
216
+ embeddings, attention_mask = self.class_encoder.encode_prompts(
217
+ prompts,
218
+ max_token_length=max_token_length,
219
+ )
220
+ negative_embeddings, _ = self.class_encoder.encode_prompts(
221
+ negative_prompt,
222
+ max_token_length=max_token_length,
223
+ )
224
+ if do_cfg:
225
+ prompt_embeddings = torch.cat(
226
+ [
227
+ embeddings,
228
+ negative_embeddings,
229
+ ],
230
+ dim=0,
231
+ )
232
+ attention_mask = torch.cat(
233
+ [
234
+ attention_mask,
235
+ attention_mask,
236
+ ],
237
+ dim=0,
238
+ )
239
+ else:
240
+ prompt_embeddings = embeddings
241
+ else:
242
+ raise NotImplementedError("Only ClassEncoder is supported in this version.")
243
+
244
+ return prompt_embeddings, attention_mask
245
+
246
+ def to_pil_images(self, tensor: torch.Tensor) -> list[Image.Image]:
247
+ return tensor_to_images(tensor)
248
+
249
+ def image_to_velocity(
250
+ self,
251
+ image: torch.Tensor,
252
+ noisy: torch.Tensor,
253
+ timestep: torch.Tensor,
254
+ clamp_eps: float = 1e-5,
255
+ ):
256
+ return (image - noisy) / (1 - timestep.view(-1, 1, 1, 1)).clamp_min_(clamp_eps)
257
+
258
+ def renorm_cfg(
259
+ self,
260
+ positive_velocity: torch.Tensor,
261
+ cfg_velocity: torch.Tensor,
262
+ ) -> torch.Tensor:
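+ # rescale the guided velocity so its norm along the last dimension matches the positive (unguided) branch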
263
+ positive_norm = torch.norm(positive_velocity, dim=-1, keepdim=True)
264
+ cfg_norm = torch.norm(cfg_velocity, dim=-1, keepdim=True)
265
+
266
+ new_cfg_velocity = cfg_velocity * (positive_norm / cfg_norm)
267
+
268
+ return new_cfg_velocity
269
+
270
+ def dynamic_thresholding(
271
+ self,
272
+ images: torch.Tensor,
273
+ percentile: float = 0.995,
274
+ ) -> torch.Tensor:
275
+ """
276
+ Apply dynamic thresholding to the images.
277
+ Args:
278
+ images (torch.Tensor): The input images tensor.
279
+ percentile (float): The percentile value for thresholding.
280
+ Returns:
281
+ torch.Tensor: The thresholded images tensor.
282
+ """
283
+ batch_size = images.shape[0]
284
+ flattened_images = images.view(batch_size, -1)
285
+ abs_images = torch.abs(flattened_images)
286
+
287
+ s = torch.quantile(abs_images, percentile, dim=1, keepdim=True)
288
+ s = torch.clamp(s, min=1.0).view(batch_size, 1, 1, 1)
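+ # s >= 1 keeps images already inside [-1, 1] unchanged; larger outliers are squashed back into [-1, 1]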
289
+
290
+ thresholded_images = torch.clamp(images, -s, s) / s
291
+
292
+ return thresholded_images
293
+
294
+ def normalize_prompts(
295
+ self,
296
+ prompt: str | list[str],
297
+ ) -> list[str]:
298
+ return prompt if isinstance(prompt, list) else [prompt]
299
+
300
+ @torch.inference_mode()
301
+ def generate(
302
+ self,
303
+ prompt: str | list[str],
304
+ negative_prompt: str | list[str] | None = None,
305
+ width: int = 256,
306
+ height: int = 256,
307
+ num_inference_steps: int = 20,
308
+ cfg_scale: float = 2.0,
309
+ max_token_length: int = 64,
310
+ seed: int | None = None,
311
+ execution_dtype: torch.dtype = torch.bfloat16,
312
+ device: torch.device | str = torch.device("cuda"),
313
+ do_cfg_renorm: bool = False,
314
+ do_dynamic_thresholding: bool = False,
315
+ cfg_time_range: list[float] = [0.0, 1.0],
316
+ # do_offloading: bool = False,
317
+ ):
318
+ # 1. Prepare args
319
+ execution_device: torch.device = (
320
+ torch.device(device) if isinstance(device, str) else device
321
+ )
322
+ do_cfg = cfg_scale > 1.0
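+ # cfg_scale > 1 enables classifier-free guidance; positive and negative prompts are batched through the denoiser together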
323
+ timesteps = self.prepare_timesteps(
324
+ num_inference_steps=num_inference_steps,
325
+ device=execution_device,
326
+ )
327
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
328
+
329
+ # 2. Prepare noise
330
+ noisy_image = self.prepare_noisy_image(
331
+ batch_size=batch_size,
332
+ height=height,
333
+ width=width,
334
+ dtype=execution_dtype,
335
+ device=execution_device,
336
+ seed=seed,
337
+ )
338
+
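+ # 3. Encode prompts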
339
+ negative_prompts = [""] if negative_prompt is None else negative_prompt
340
+ negative_prompts = self.normalize_prompts(negative_prompts)
341
+ if len(negative_prompts) != batch_size and len(negative_prompts) == 1:
342
+ negative_prompts = negative_prompts * batch_size
343
+
344
+ prompt_embeddings, attention_mask = self.prepare_context_embeddings(
345
+ prompts=prompt,
346
+ negative_prompt=negative_prompts,
347
+ max_token_length=max_token_length,
348
+ do_cfg=do_cfg,
349
+ )
350
+
351
+ # 4. Denoising loop
352
+ with self.progress_bar(total=num_inference_steps) as pbar:
353
+ for i, timestep in enumerate(timesteps[:-1]):
354
+ image_input = torch.cat([noisy_image] * 2) if do_cfg else noisy_image
355
+
356
+ batch_timestep = timestep.expand(image_input.shape[0])
357
+
358
+ model_pred = self.denoiser(
359
+ image=image_input,
360
+ timestep=batch_timestep,
361
+ context=prompt_embeddings,
362
+ context_mask=attention_mask,
363
+ )
364
+
365
+ if do_cfg and cfg_time_range[0] <= float(timestep) <= cfg_time_range[1]:
366
+ image_pred_positive, image_pred_negative = model_pred.chunk(2)
367
+ v_pred_positive = self.image_to_velocity(
368
+ image=image_pred_positive,
369
+ noisy=noisy_image,
370
+ timestep=timestep.expand(batch_size),
371
+ )
372
+ v_pred_negative = self.image_to_velocity(
373
+ image=image_pred_negative,
374
+ noisy=noisy_image,
375
+ timestep=timestep.expand(batch_size),
376
+ )
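+ # classifier-free guidance in velocity space: v = v_pos + scale * (v_pos - v_neg)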
377
+ velocity = v_pred_positive + cfg_scale * (
378
+ v_pred_positive - v_pred_negative
379
+ )
380
+ if do_cfg_renorm:
381
+ velocity = self.renorm_cfg(
382
+ positive_velocity=v_pred_positive,
383
+ cfg_velocity=velocity,
384
+ )
385
+ if do_dynamic_thresholding:
386
+ # re-calculate the image prediction after cfg
387
+ image_pred = noisy_image + velocity * (1 - timestep)
388
+ image_pred = self.dynamic_thresholding(image_pred)
389
+ velocity = self.image_to_velocity(
390
+ image=image_pred,
391
+ noisy=noisy_image,
392
+ timestep=timestep.expand(batch_size),
393
+ )
394
+ else:
395
+ velocity = self.image_to_velocity(
396
+ image=model_pred[:batch_size],
397
+ noisy=noisy_image,
398
+ timestep=timestep.expand(batch_size),
399
+ )
400
+
401
+ # new noisy image
402
+ noisy_image = noisy_image + velocity * (timesteps[i + 1] - timestep)
403
+
404
+ pbar.update()
405
+
406
+ # now it should be clean
407
+ clean_image = noisy_image
408
+
409
+ # to PIL images
410
+ pil_images = self.to_pil_images(clean_image.cpu())
411
+
412
+ return pil_images
requirements.txt CHANGED
@@ -1,6 +1,4 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
  torch
5
  transformers
6
- xformers
 
1
+ spaces
 
 
2
  torch
3
  transformers
4
+ huggingface-hub