# Author: Bingsu
# Commit 9225658: "feat: capture progressbar"
# (web file-viewer residue: raw | history | blame | 3.23 kB)
from __future__ import annotations
import shlex
import shutil
import subprocess
from pathlib import Path
from textwrap import dedent
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTokenizer
def hex_to_rgb(s: str) -> tuple[int, int, int]:
    """Convert a hex color string such as ``"#00f900"`` to an ``(r, g, b)`` tuple.

    Leading ``#`` characters are optional. Both the 6-digit form (``"00f900"``)
    and the CSS 3-digit shorthand (``"0f0"``, expanded to ``"00ff00"``) are
    accepted; hex digits may be upper- or lower-case.

    Raises:
        ValueError: if *s* is not a 3- or 6-digit hexadecimal color.
    """
    value = s.lstrip("#")
    if len(value) == 3:
        # CSS shorthand: every digit is doubled ("0f0" -> "00ff00").
        value = "".join(ch * 2 for ch in value)
    # Validate explicitly: slicing a wrong-length string, or letting int() accept
    # signs such as "+1", would otherwise yield silently wrong channel values.
    if len(value) != 6 or not all(ch in "0123456789abcdefABCDEF" for ch in value):
        raise ValueError(f"Expected a 3- or 6-digit hex color, got {s!r}")
    return (int(value[0:2], 16), int(value[2:4], 16), int(value[4:6], 16))
# --- Color selection and the single-image "dataset" ---------------------------
col1, col2 = st.columns([15, 85])
color = col1.color_picker("Pick a color", "#00f900")
# Read-only echo of the picked hex value next to the color swatch.
col2.text_input("", color, disabled=True)
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
init_token = st.text_input("Initializer token", "init token name")

# The embedding name is used as a file name below (dataset/<emb_name>.png).
# Reject blank values and path separators, which would otherwise crash the
# save or write the image outside the dataset directory.
if not emb_name.strip() or "/" in emb_name or "\\" in emb_name:
    st.warning("Embedding name must be non-empty and must not contain '/' or '\\'")
    st.stop()

rgb = hex_to_rgb(color)
# A 128x128 RGB image filled with the picked color; the (r, g, b) tuple is
# broadcast over the last (channel) axis.
img_array = np.full((128, 128, 3), rgb, dtype=np.uint8)
# --- Fresh working directories for this run -------------------------------------
# Both directories are wiped every time this script executes (Streamlit
# re-executes it on each interaction). Stale images or embeddings from an
# earlier run can therefore never be picked up.
dataset_path = Path("dataset")
output_path = Path("output")
for stale_dir in (dataset_path, output_path):
    if stale_dir.exists():
        shutil.rmtree(stale_dir)
dataset_path.mkdir()

# The solid-color image is the only file in the dataset directory. The training
# command points --train_data_dir at this directory.
img_path = dataset_path / f"{emb_name}.png"
Image.fromarray(img_array).save(img_path)
# --- Training hyper-parameters ---------------------------------------------------
with st.sidebar:
    model_name = st.text_input("Model name", "Linaqruf/anything-v3.0")
    steps = st.slider("Steps", 1, 100, 30, step=1)
    learning_rate_text = st.text_input("Learning rate", "0.005")

# The learning rate is free-form text, so parse it here. A typo then shows a
# message instead of an uncaught ValueError traceback. The chained comparison
# also rejects 0, negative values, NaN and infinity; none of these can train.
try:
    learning_rate = float(learning_rate_text)
except ValueError:
    learning_rate = None
if learning_rate is None or not 0 < learning_rate < float("inf"):
    st.error(f"Learning rate must be a positive number, got {learning_rate_text!r}")
    st.stop()
# --- Validate the tokens against the model's tokenizer ---------------------------
try:
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
except (OSError, ValueError) as err:
    # Unknown or malformed model ids surface as OSError / ValueError from
    # from_pretrained; show them instead of crashing the app.
    st.error(f"Could not load the tokenizer for {model_name!r}: {err}")
    st.stop()

# Case 1: the initializer must map to exactly one existing token. An empty
# value produces zero tokens, so it is rejected here as well, not only values
# that split into several tokens.
token = tokenizer.tokenize(init_token)
if len(token) != 1:
    st.warning("Initializer token must be a single token")
    st.stop()

# Case 2: the embedding name (the new placeholder token) must not already be in
# the vocabulary. add_tokens returns how many tokens it actually added, which
# is 0 when emb_name is already present.
num_added_tokens = tokenizer.add_tokens(emb_name)
if num_added_tokens == 0:
    st.warning(f"The tokenizer already contains the token {emb_name}")
    st.stop()
# --- Build the training command ----------------------------------------------------
# The argv list is built directly rather than formatted into a string and
# re-split with shlex. User-supplied values (model name, embedding name,
# initializer token) then reach the child process verbatim. A stray quote can
# no longer make parsing fail, and a space can no longer inject extra --flags.
# The --flag=value form also keeps values that start with "-" attached to
# their flag.
cmd = [
    "accelerate",
    "launch",
    "textual_inversion.py",
    f"--pretrained_model_name_or_path={model_name}",
    "--train_data_dir=dataset",
    "--learnable_property=style",
    f"--placeholder_token={emb_name}",
    f"--initializer_token={init_token}",
    "--resolution=128",
    "--train_batch_size=1",
    "--repeats=1",
    "--gradient_accumulation_steps=1",
    f"--max_train_steps={steps}",
    f"--learning_rate={learning_rate}",
    "--output_dir=output",
    "--only_save_embeds",
]
# --- Run training and stream its log -----------------------------------------------
result_path = output_path / "learned_embeds.bin"
start_button = st.button("Start")
# Slot reserved before the log placeholder, so the download button renders
# above the training log once it is filled in.
download_button = st.empty()

if start_button:
    with st.spinner("Training..."):
        log_view = st.empty()
        log_lines = []
        try:
            # stderr is merged into stdout so one pipe carries the whole log.
            # A pipe that is never read can deadlock the child once the OS
            # pipe buffer fills up.
            proc = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                encoding="utf-8",
                errors="replace",
            )
        except FileNotFoundError:
            st.error("Could not start training: `accelerate` was not found on PATH")
            st.stop()
        with proc:
            # Text mode uses universal newlines, so the carriage returns that
            # progress bars use to redraw themselves also end a "line". Each
            # redraw is therefore appended and rendered immediately.
            while line := proc.stdout.readline():
                log_lines.append(line)
                log_view.code("".join(log_lines), language="bash")
        # Leaving the ``with`` block closed the pipe and waited for the child,
        # so returncode is set here.
        if proc.returncode != 0:
            st.error(f"Training failed with exit code {proc.returncode}")

# Nothing to offer for download until a run has produced the embedding file.
if not result_path.exists():
    st.stop()
# --- Re-save the embedding compactly and offer it for download ---------------------
# Works around the "unknown file volume bug": the file written by the training
# script can be far larger than its few values warrant. torch.save stores each
# tensor's whole underlying storage, so a tensor that is a view into a larger
# buffer drags that buffer along. That is presumably the cause here; confirm
# against textual_inversion.py. clone() gives every tensor storage of exactly
# its own size, and detach() first drops any autograd history. Unlike a NumPy
# round-trip, this works for every dtype, including bfloat16.
trained_emb = torch.load(result_path, map_location="cpu")
trained_emb = {
    key: value.detach().clone() if isinstance(value, torch.Tensor) else value
    for key, value in trained_emb.items()
}
torch.save(trained_emb, result_path)

file = result_path.read_bytes()
download_button.download_button(f"Download {emb_name}.pt", file, f"{emb_name}.pt")