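"""Streamlit demo: train a textual-inversion embedding for a single flat color.

Builds a one-image dataset of the picked color, validates the placeholder and
initializer tokens against the CLIP tokenizer, then shells out to an
`accelerate`-launched textual_inversion.py script (assumed to be the diffusers
textual-inversion training example, located next to this file).
"""
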
from __future__ import annotations

import shutil
import subprocess
from pathlib import Path
from textwrap import dedent

import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTokenizer


def hex_to_rgb(s: str) -> tuple[int, int, int]:
    """Parse a "#rrggbb" hex string into an (r, g, b) tuple of ints."""
    value = s.lstrip("#")
    return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16))
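
# UI: a color picker, a read-only echo of the chosen hex value, and a text
# box naming the embedding (defaults to the hex digits, uppercased).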
col1, col2 = st.columns([15, 85])
color = col1.color_picker("Pick a color", "#00f900")
col2.text_input("", color, disabled=True)
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
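
# The whole training set is a single 128x128 solid-color image.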
rgb = hex_to_rgb(color)
img_array = np.zeros((128, 128, 3), dtype=np.uint8)
for i in range(3):
    img_array[..., i] = rgb[i]
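
# Recreate dataset/ and output/ from scratch on every rerun so stale images
# or embeddings never leak into a new training run.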
dataset_path = Path("dataset")
output_path = Path("output")
if dataset_path.exists():
    shutil.rmtree(dataset_path)
if output_path.exists():
    shutil.rmtree(output_path)
dataset_path.mkdir()
img_path = dataset_path / f"{emb_name}.png"
Image.fromarray(img_array).save(img_path)
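
# The tokenizer is loaded only for the validation below; the actual training
# happens in the subprocess.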
tokenizer = CLIPTokenizer.from_pretrained(
    "Linaqruf/anything-v3.0", subfolder="tokenizer"
)
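
# Hyperparameters in the sidebar. The learning rate is taken as text and
# parsed to float; invalid input will raise and abort the rerun.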
with st.sidebar:
    init_text = st.text_input("Initializer", "init token name")
    steps = st.slider("Steps", 1, 100, 30, step=1)
    learning_rate = st.text_input("Learning rate", "0.005")
    learning_rate = float(learning_rate)
# case 1: init_text is not a single token
token = tokenizer.tokenize(init_text)
if len(token) > 1:
    st.warning("init_text must be a single token")
    st.stop()
# case 2: init_text already exists in the tokenizer
num_added_tokens = tokenizer.add_tokens(emb_name)
if num_added_tokens == 0:
    st.warning(f"The tokenizer already contains the token {emb_name}")
    st.stop()
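
# Assemble the training command. {emb_name}/{init}/{steps}/{lr} are filled in
# by str.format below; the trailing backslashes collapse the literal into a
# single shell line, so dedent is effectively a no-op here.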
cmd = """
accelerate launch textual_inversion.py \
--pretrained_model_name_or_path="Linaqruf/anything-v3.0" \
--train_data_dir="dataset" \
--learnable_property="style" \
--placeholder_token="{emb_name}" \
--initializer_token="{init}" \
--resolution=128 \
--train_batch_size=1 \
--repeats=1 \
--gradient_accumulation_steps=1 \
--max_train_steps={steps} \
--learning_rate={lr} \
--output_dir="output" \
--only_save_embeds
""".strip()
cmd = dedent(cmd).format(
emb_name=emb_name, init=init_text, lr=learning_rate, steps=steps
)
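
# Training runs synchronously; the spinner blocks the UI until the
# subprocess exits.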
if st.button("Start"):
    with st.spinner("Training..."):
        subprocess.run(cmd, shell=True)
result_path = Path("output") / "learned_embeds.bin"
if not result_path.exists():
    st.stop()
# Re-save the embedding as plain CPU tensors via a numpy round-trip; this
# works around an otherwise-unexplained load error.
trained_emb = torch.load(result_path, map_location="cpu")
for k, v in trained_emb.items():
    trained_emb[k] = torch.from_numpy(v.numpy())
torch.save(trained_emb, result_path)
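
# Serve the re-saved weights for download; the .pt name is cosmetic, the
# payload is the same torch-serialized dict as learned_embeds.bin.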
file = result_path.read_bytes()
st.download_button("Download", file, f"{emb_name}.pt")