Bingsu committed on
Commit
d9da3b4
1 Parent(s): c36f73b

feat: main feature

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import shutil
4
+ import subprocess
5
+ from pathlib import Path
6
+ from textwrap import dedent
7
+
8
+ import torch
9
+ import streamlit as st
10
+ import numpy as np
11
+ from PIL import Image
12
+ from transformers import CLIPTokenizer
13
+
14
+
15
def hex_to_rgb(s: str) -> tuple[int, int, int]:
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepts ``"rrggbb"`` and the shorthand ``"rgb"``, with or without leading
    ``#`` characters. A trailing alpha component (``"rrggbbaa"`` / ``"rgba"``)
    is accepted and ignored.

    Args:
        s: The color string, e.g. ``"#00f900"`` or ``"#0f0"``.

    Returns:
        The red, green and blue components.

    Raises:
        ValueError: If ``s`` is not a 3-, 4-, 6- or 8-digit hex color.
    """
    value = s.strip().lstrip("#")
    is_hex = set(value) <= set("0123456789abcdefABCDEF")
    if len(value) not in (3, 4, 6, 8) or not is_hex:
        # Fail loudly instead of silently decoding a truncated component.
        raise ValueError(f"invalid hex color: {s!r}")
    if len(value) in (3, 4):
        # Shorthand notation: every digit stands for a doubled pair (f -> ff).
        value = "".join(digit * 2 for digit in value)
    return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16))
18
+
19
+
20
# Color picker on the left, a read-only box echoing the chosen hex value
# on the right.
col1, col2 = st.columns([15, 85])
color = col1.color_picker("Pick a color", "#00f900")
col2.text_input("", color, disabled=True)

# The embedding name defaults to the upper-cased hex digits of the color.
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
# The name is used as a file name below, so reject empty names and anything
# carrying a directory component rather than writing outside `dataset/`.
if not emb_name or Path(emb_name).name != emb_name:
    st.warning("Embedding name must be a plain, non-empty file name")
    st.stop()
rgb = hex_to_rgb(color)

# 128x128 RGB image filled with the picked color; the 3-tuple broadcasts
# over the last (channel) axis.
img_array = np.full((128, 128, 3), rgb, dtype=np.uint8)

# Remove whatever a previous script run left behind before writing the
# new single-image dataset.
dataset_path = Path("dataset")
output_path = Path("output")
for stale_dir in (dataset_path, output_path):
    if stale_dir.exists():
        shutil.rmtree(stale_dir)

dataset_path.mkdir()
img_path = dataset_path / f"{emb_name}.png"
Image.fromarray(img_array).save(img_path)
tokenizer = CLIPTokenizer.from_pretrained(
    "Linaqruf/anything-v3.0", subfolder="tokenizer"
)
44
+
45
# Training hyper-parameters are collected in the sidebar.
with st.sidebar:
    init_text = st.text_input("Initializer", "init token name")
    steps = st.slider("Steps", 1, 100, 30, step=1)
    learning_rate = st.text_input("Learning rate", "0.005")

# The learning rate arrives as free text: show a warning instead of letting
# an unparsable or non-positive value surface as a traceback or a bad run.
try:
    learning_rate = float(learning_rate)
except ValueError:
    learning_rate = None
if learning_rate is None or not 0 < learning_rate < float("inf"):
    st.warning("Learning rate must be a positive number")
    st.stop()

# case 1: init_text is not a single token (this includes an empty
# initializer, which tokenizes to zero tokens)
token = tokenizer.tokenize(init_text)
if len(token) != 1:
    st.warning("init_text must be a single token")
    st.stop()

# case 2: emb_name already exists in the tokenizer
num_added_tokens = tokenizer.add_tokens(emb_name)
if num_added_tokens == 0:
    st.warning(f"The tokenizer already contains the token {emb_name}")
    st.stop()
62
+
63
# Build the training command as an argument list and run it without a
# shell: `emb_name` and `init_text` are typed in by the user, and
# interpolating them into a shell string would let them inject commands.
cmd = [
    "accelerate",
    "launch",
    "textual_inversion.py",
    "--pretrained_model_name_or_path=Linaqruf/anything-v3.0",
    "--train_data_dir=dataset",
    "--learnable_property=style",
    f"--placeholder_token={emb_name}",
    f"--initializer_token={init_text}",
    "--resolution=128",
    "--train_batch_size=1",
    "--repeats=1",
    "--gradient_accumulation_steps=1",
    f"--max_train_steps={steps}",
    f"--learning_rate={learning_rate}",
    "--output_dir=output",
    "--only_save_embeds",
]

if st.button("Start"):
    with st.spinner("Training..."):
        completed = subprocess.run(cmd, check=False)
    # Report a failed run instead of silently showing nothing.
    if completed.returncode != 0:
        st.error(f"Training failed with exit code {completed.returncode}")
87
+
88
# Nothing to offer for download until the embedding file exists.
result_path = Path("output", "learned_embeds.bin")
if not result_path.exists():
    st.stop()

# Workaround carried over from the original ("fix unknown error"): every
# tensor is rebuilt through a NumPy round trip and the file is rewritten.
# NOTE(review): the underlying error is not documented here; presumably the
# round trip detaches each tensor from its original storage. Confirm.
# NOTE(review): torch.load unpickles the file; only safe because the file
# is expected to come from the local training run.
trained_emb = torch.load(result_path, map_location="cpu")
for key in list(trained_emb):
    as_numpy = trained_emb[key].numpy()
    trained_emb[key] = torch.from_numpy(as_numpy)
torch.save(trained_emb, result_path)

# Serve the rewritten file under the embedding's name.
st.download_button("Download", result_path.read_bytes(), f"{emb_name}.pt")