k4d3 committed
Commit 13c6de4
1 Parent(s): c29c862

Signed-off-by: Balazs Horvath <acsipont@gmail.com>

jtp2.py DELETED
@@ -1,161 +0,0 @@
- import os
- import json
- from PIL import Image
- import safetensors.torch
- import timm
- from timm.models import VisionTransformer
- import torch
- from torchvision.transforms import transforms
- from torchvision.transforms import InterpolationMode
- import torchvision.transforms.functional as TF
- import argparse
- import pillow_jxl
-
- torch.set_grad_enabled(False)
-
- class Fit(torch.nn.Module):
-     def __init__(self, bounds: tuple[int, int] | int, interpolation=InterpolationMode.LANCZOS, grow: bool = True, pad: float | None = None):
-         super().__init__()
-         self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
-         self.interpolation = interpolation
-         self.grow = grow
-         self.pad = pad
-
-     def forward(self, img: Image) -> Image:
-         wimg, himg = img.size
-         hbound, wbound = self.bounds
-         hscale = hbound / himg
-         wscale = wbound / wimg
-         if not self.grow:
-             hscale = min(hscale, 1.0)
-             wscale = min(wscale, 1.0)
-         scale = min(hscale, wscale)
-         if scale == 1.0:
-             return img
-         hnew = min(round(himg * scale), hbound)
-         wnew = min(round(wimg * scale), wbound)
-         img = TF.resize(img, (hnew, wnew), self.interpolation)
-         if self.pad is None:
-             return img
-         hpad = hbound - hnew
-         wpad = wbound - wnew
-         tpad = hpad // 2
-         bpad = hpad - tpad
-         lpad = wpad // 2
-         rpad = wpad - lpad
-         return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)
-
-     def __repr__(self) -> str:
-         return f"{self.__class__.__name__}(bounds={self.bounds}, interpolation={self.interpolation.value}, grow={self.grow}, pad={self.pad})"
-
- class CompositeAlpha(torch.nn.Module):
-     def __init__(self, background: tuple[float, float, float] | float):
-         super().__init__()
-         self.background = (background, background, background) if isinstance(background, float) else background
-         self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)
-
-     def forward(self, img: torch.Tensor) -> torch.Tensor:
-         if img.shape[-3] == 3:
-             return img
-         alpha = img[..., 3, None, :, :]
-         img[..., :3, :, :] *= alpha
-         background = self.background.expand(-1, img.shape[-2], img.shape[-1])
-         if background.ndim == 1:
-             background = background[:, None, None]
-         elif background.ndim == 2:
-             background = background[None, :, :]
-         img[..., :3, :, :] += (1.0 - alpha) * background
-         return img[..., :3, :, :]
-
-     def __repr__(self) -> str:
-         return f"{self.__class__.__name__}(background={self.background})"
-
- transform = transforms.Compose([
-     Fit((384, 384)),
-     transforms.ToTensor(),
-     CompositeAlpha(0.5),
-     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
-     transforms.CenterCrop((384, 384)),
- ])
-
- model = timm.create_model("vit_so400m_patch14_siglip_384.webli", pretrained=False, num_classes=9083)  # type: VisionTransformer
-
- class GatedHead(torch.nn.Module):
-     def __init__(self, num_features: int, num_classes: int):
-         super().__init__()
-         self.num_classes = num_classes
-         self.linear = torch.nn.Linear(num_features, num_classes * 2)
-         self.act = torch.nn.Sigmoid()
-         self.gate = torch.nn.Sigmoid()
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self.linear(x)
-         x = self.act(x[:, :self.num_classes]) * self.gate(x[:, self.num_classes:])
-         return x
-
- model.head = GatedHead(min(model.head.weight.shape), 9083)
- safetensors.torch.load_model(model, "JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors")
-
- if torch.cuda.is_available():
-     model.cuda()
-     if torch.cuda.get_device_capability()[0] >= 7:  # tensor cores
-         model.to(dtype=torch.float16, memory_format=torch.channels_last)
-
- model.eval()
-
- with open("tags.json", "r") as file:
-     tags = json.load(file)  # type: dict
- allowed_tags = list(tags.keys())
-
- for idx, tag in enumerate(allowed_tags):
-     allowed_tags[idx] = tag.replace("_", " ")
-
- sorted_tag_score = {}
-
- def run_classifier(image, threshold):
-     global sorted_tag_score
-     img = image.convert('RGBA')
-     tensor = transform(img).unsqueeze(0)
-     if torch.cuda.is_available():
-         tensor = tensor.cuda()
-         if torch.cuda.get_device_capability()[0] >= 7:  # tensor cores
-             tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last)
-     with torch.no_grad():
-         probits = model(tensor)[0].cpu()
-         values, indices = probits.topk(250)
-     tag_score = dict()
-     for i in range(indices.size(0)):
-         tag_score[allowed_tags[indices[i]]] = values[i].item()
-     sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
-     return create_tags(threshold)
-
- def create_tags(threshold):
-     global sorted_tag_score
-     filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
-     text_no_impl = ", ".join(filtered_tag_score.keys())
-     return text_no_impl, filtered_tag_score
-
- def process_directory(directory, threshold):
-     results = {}
-     for root, _, files in os.walk(directory):
-         for file in files:
-             if file.lower().endswith(('.jpg', '.jpeg', '.png', '.jxl')):
-                 image_path = os.path.join(root, file)
-                 image = Image.open(image_path)
-                 tags, _ = run_classifier(image, threshold)
-                 results[image_path] = tags
-                 # Save tags to a text file with the same name as the image
-                 text_file_path = os.path.splitext(image_path)[0] + ".txt"
-                 with open(text_file_path, "w") as text_file:
-                     text_file.write(tags)
-     return results
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description="Run inference on a directory of images.")
-     parser.add_argument("directory", type=str, help="Target directory containing images.")
-     parser.add_argument("--threshold", type=float, default=0.2, help="Threshold for tag filtering.")
-     args = parser.parse_args()
-
-     results = process_directory(args.directory, args.threshold)
-     for image_path, tags in results.items():
-         print(f"{image_path}: {tags}")
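
For reference, per its argparse block the deleted tagger was a batch CLI: it walks a directory tree and writes a .txt tag file next to every .jpg, .jpeg, .png, or .jxl image it finds. A typical invocation (the directory path is a placeholder):

    python jtp2.py /path/to/images --threshold 0.2

Both JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors and tags.json are loaded by relative path, so the script has to be run from the directory that holds them.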
metrics/by_dagasi-v220240731004448/network_train/events.out.tfevents.1722379588.berilia.2834308.0 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9ccee82a99379d3deeea818c60394392931cfff3fe19fa7483dc6ad0fe3f5568
- size 403320
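
(Note: this and the following metrics deletions remove Git LFS pointer stubs, not the event data itself; each stub records only the LFS spec version, the sha256 content hash, and the payload size in bytes of a TensorBoard event file stored in LFS.)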
 
 
 
 
metrics/by_hax-v1e400-20240804130227/network_train/events.out.tfevents.1722769396.berilia.3289669.0 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:034baf9175755f1a4d71e099eef11c13eaaf4f20f784c7b9ff121342ef5c7ddb
- size 396905
 
 
 
 
metrics/by_jinxit-v2e400-20240729235422/network_train/events.out.tfevents.1722290163.berilia.3199627.0 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:2c46a98e965052fabfbc1b65bfd9290e33e1358a3cf7c4b26addd2a627114595
- size 843874
 
 
 
 
metrics/magic-normalized-v2e400-20240730013158/network_train/events.out.tfevents.1722295954.berilia.3268268.0 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:a24bc14ad339847b995826600261af14c07c923a4fc04248b7124d25fadce7c8
- size 544162
 
 
 
 
metrics/realistic-v7e400-20240802163709/network_train/events.out.tfevents.1722609499.berilia.1601635.0 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:55b3897bb017c365e498cb9545186be2243474821719180fb2fb5c519273d193
- size 820965
 
 
 
 
metrics/space-v2e200-20240730174030/network_train/events.out.tfevents.1722354067.berilia.1646768.0 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:b2288178ab681b4b4a20b12a65fe2627426add1bc4562530230952833ebb7950
- size 620514
 
 
 
 
metrics/stoat-v7e400-20240729181527/network_train/events.out.tfevents.1722269758.berilia.2962309.0 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:8d5a441c2f829509ff64f2ef9ee575c93fef551944b85f81214fef7bb98a06ed
- size 806521
 
 
 
 
train-pony.sh DELETED
@@ -1,103 +0,0 @@
- #!/usr/bin/env zsh
-
- NAME="stoat-v2s400"
-
- # alpha=1 @ dim=16 is the same lr as alpha=4 @ dim=256
- # --min_snr_gamma=1
-
- args=(
-     --pretrained_model_name_or_path=/home/kade/ComfyUI/models/checkpoints/ponyDiffusionV6XL_v6StartWithThisOne.safetensors
-     # Output, logging
-     --output_dir="/home/kade/output_dir/$NAME"
-     --output_name="$NAME"
-     --log_prefix="$NAME-"
-     --log_with=tensorboard
-     --logging_dir=/home/kade/output_dir/logs
-     --seed=1728871242
-
-     # Dataset
-     --train_data_dir=/home/kade/training_dir
-     --dataset_repeats=1
-     --resolution="1024,1024"
-     --enable_bucket
-     --bucket_reso_steps=32
-     --min_bucket_reso=256
-     --max_bucket_reso=2048
-     --flip_aug
-     --shuffle_caption
-     --cache_latents
-     --cache_latents_to_disk
-     --max_data_loader_n_workers=8
-     --persistent_data_loader_workers
-
-     # Network config
-     --network_dim=8
-     --network_alpha=4
-     --network_module="lycoris.kohya"
-     --network_args
-         "preset=full"
-         "conv_dim=256"
-         "conv_alpha=4"
-         "rank_dropout=0"
-         "module_dropout=0"
-         "use_tucker=False"
-         "use_scalar=False"
-         "rank_dropout_scale=False"
-         "algo=locon"
-         "dora_wd=False"
-         "train_norm=False"
-     --network_dropout=0
-
-     # Optimizer config
-     --optimizer_type=ClybW
-     --train_batch_size=8
-     --gradient_accumulation_steps=6
-     --max_grad_norm=1
-     --gradient_checkpointing
-     #--lr_warmup_steps=6
-     #--scale_weight_norms=1
-
-     # LR Scheduling
-     --max_train_steps=400
-     --learning_rate=0.0002
-     --unet_lr=0.0002
-     --text_encoder_lr=0.0001
-     --lr_scheduler="cosine"
-     --lr_scheduler_args="num_cycles=0.375"
-
-     # Noise
-     --multires_noise_iterations=12
-     --multires_noise_discount=0.4
-     #--min_snr_gamma=1
-
-     # Optimization, details
-     --no_half_vae
-     --sdpa
-     --mixed_precision="bf16"
-
-     # Saving
-     --save_model_as="safetensors"
-     --save_precision="fp16"
-     --save_every_n_steps=20
-     --save_state
-     # Either resume from a saved state
-     #--resume="$HOME/output_dir/wolflink-vfucks400" # Resume from saved state
-     #--skip_until_initial_step
-     # Or from a checkpoint
-     #--network_weights="$HOME/output_dir/wolflink-vfucks400/wolflink-vfucks400-step00000120.safetensors" # Resume from checkpoint (not needed with state, I think)
-     #--initial_step=120
-
-     # Sampling
-     --sample_every_n_steps=20
-     --sample_prompts=/home/kade/training_dir/sample-prompts.txt
-     --sample_sampler="euler_a"
-     --caption_extension=".txt"
- )
-
-
- cd ~/source/repos/sd-scripts
-
- #accelerate launch --num_cpu_threads_per_process=2
- python "./sdxl_train_network.py" "${args[@]}"
-
- cd ~
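
Two notes on the deleted script. The alpha comment at the top is consistent with an alpha/sqrt(dim) reading of LoRA scaling, since 1/sqrt(16) = 4/sqrt(256) = 0.25; under the plain alpha/dim convention the two settings would differ by a factor of 4 (which convention the author intended is not stated in the script). And because --save_state is enabled, a run can be resumed with the flags the script keeps commented out; a minimal sketch, with the state path a hypothetical stand-in modeled on the commented example:

    # Append resume flags before launching; --save_state above makes
    # sdxl_train_network.py emit resumable state during training.
    args+=(
        --resume="$HOME/output_dir/$NAME"   # hypothetical saved-state path
        --skip_until_initial_step           # fast-forward to the saved step
    )
    python "./sdxl_train_network.py" "${args[@]}"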