from __future__ import annotations import shutil import subprocess from pathlib import Path from textwrap import dedent import torch import streamlit as st import numpy as np from PIL import Image from transformers import CLIPTokenizer def hex_to_rgb(s: str) -> tuple[int, int, int]: value = s.lstrip("#") return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16)) col1, col2 = st.columns([15, 85]) color = col1.color_picker("Pick a color", "#00f900") col2.text_input("", color, disabled=True) emb_name = st.text_input("Embedding name", color.lstrip("#").upper()) rgb = hex_to_rgb(color) img_array = np.zeros((128, 128, 3), dtype=np.uint8) for i in range(3): img_array[..., i] = rgb[i] dataset_path = Path("dataset") output_path = Path("output") if dataset_path.exists(): shutil.rmtree(dataset_path) if output_path.exists(): shutil.rmtree(output_path) dataset_path.mkdir() img_path = dataset_path / f"{emb_name}.png" Image.fromarray(img_array).save(img_path) tokenizer = CLIPTokenizer.from_pretrained( "Linaqruf/anything-v3.0", subfolder="tokenizer" ) with st.sidebar: init_text = st.text_input("Initializer", "init token name") steps = st.slider("Steps", 1, 100, 30, step=1) learning_rate = st.text_input("Learning rate", "0.005") learning_rate = float(learning_rate) # case 1: init_text is not a single token token = tokenizer.tokenize(init_text) if len(token) > 1: st.warning("init_text must be a single token") st.stop() # case 2: init_text already exists in the tokenizer num_added_tokens = tokenizer.add_tokens(emb_name) if num_added_tokens == 0: st.warning(f"The tokenizer already contains the token {emb_name}") st.stop() cmd = """ accelerate launch textual_inversion.py \ --pretrained_model_name_or_path="Linaqruf/anything-v3.0" \ --train_data_dir="dataset" \ --learnable_property="style" \ --placeholder_token="{emb_name}" \ --initializer_token="{init}" \ --resolution=128 \ --train_batch_size=1 \ --repeats=1 \ --gradient_accumulation_steps=1 \ --max_train_steps={steps} \ --learning_rate={lr} \ --output_dir="output" \ --only_save_embeds """.strip() cmd = dedent(cmd).format( emb_name=emb_name, init=init_text, lr=learning_rate, steps=steps ) if st.button("Start"): with st.spinner("Training..."): subprocess.run(cmd, shell=True) result_path = Path("output") / "learned_embeds.bin" if not result_path.exists(): st.stop() # fix unknown error trained_emb = torch.load(result_path, map_location="cpu") for k, v in trained_emb.items(): trained_emb[k] = torch.from_numpy(v.numpy()) torch.save(trained_emb, result_path) file = result_path.read_bytes() st.download_button("Download", file, f"{emb_name}.pt")