import time
import os
import torch
from dartrs.v2 import (
    V2Model,
    MixtralModel,
    MistralModel,
    compose_prompt,
    LengthTag,
    AspectRatioTag,
    RatingTag,
    IdentityTag,
)
from dartrs.dartrs import DartTokenizer
from dartrs.utils import get_generation_config
import gradio as gr
from gradio.components import Component
try:
    import spaces
except ImportError:
    # Fallback used when the ZeroGPU `spaces` package is unavailable (e.g. local runs):
    # spaces.GPU(...) degrades to a no-op decorator factory.
    class spaces:
        @staticmethod
        def GPU(*args, **kwargs):
            return lambda fn: fn
from output import UpsamplingOutput
from utils import ASPECT_RATIO_OPTIONS, RATING_OPTIONS, LENGTH_OPTIONS, IDENTITY_OPTIONS
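
# Optional Hugging Face token for downloading the model and tokenizer (may be None for public repos).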
HF_TOKEN = os.getenv("HF_TOKEN", None)
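
# Registry of available Dart v2 checkpoints: the MoE (Mixtral-based) and dense (Mistral-based) SFT models.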
ALL_MODELS = {
"dart-v2-moe-sft": {
"repo": "p1atdev/dart-v2-moe-sft",
"type": "sft",
"class": MixtralModel,
},
"dart-v2-sft": {
"repo": "p1atdev/dart-v2-sft",
"type": "sft",
"class": MistralModel,
},
}
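

# Load the tokenizer and model weights for the given ALL_MODELS entry.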
def prepare_models(model_config: dict):
model_name = model_config["repo"]
tokenizer = DartTokenizer.from_pretrained(model_name, auth_token=HF_TOKEN)
model = model_config["class"].from_pretrained(model_name, auth_token=HF_TOKEN)
return {
"tokenizer": tokenizer,
"model": model,
}
def normalize_tags(tokenizer: DartTokenizer, tags: str):
"""Just remove unk tokens."""
return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
@torch.no_grad()
def generate_tags(
    model: V2Model,
    tokenizer: DartTokenizer,
    prompt: str,
    ban_token_ids: list[int],
):
    output = model.generate(
        get_generation_config(
            prompt,
            tokenizer=tokenizer,
            temperature=1,
            top_p=0.9,
            top_k=100,
            max_new_tokens=256,
            ban_token_ids=ban_token_ids,
        ),
    )

    return output
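

# Gradio UI state holder: lazily loads the selected model and exposes the tag-upsampling callback.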
class V2UI:
    model_name: str | None = None
    model: V2Model
    tokenizer: DartTokenizer

    input_components: list[Component] = []
    generate_btn: gr.Button
    def on_generate(
        self,
        model_name: str,
        copyright_tags: str,
        character_tags: str,
        general_tags: str,
        rating_tag: RatingTag,
        aspect_ratio_tag: AspectRatioTag,
        length_tag: LengthTag,
        identity_tag: IdentityTag,
        ban_tags: str,
        *args,
    ) -> UpsamplingOutput:
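        # (Re)load the model and tokenizer only when the selected model changes.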
        if self.model_name is None or self.model_name != model_name:
            models = prepare_models(ALL_MODELS[model_name])
            self.model = models["model"]
            self.tokenizer = models["tokenizer"]
            self.model_name = model_name

        # normalize tags
        # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
        # character_tags = normalize_tags(self.tokenizer, character_tags)
        # general_tags = normalize_tags(self.tokenizer, general_tags)

        ban_token_ids = self.tokenizer.encode(ban_tags.strip())
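
        # Assemble the structured Dart v2 prompt from the user-provided tag groups and condition tags.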
        prompt = compose_prompt(
            prompt=general_tags,
            copyright=copyright_tags,
            character=character_tags,
            rating=rating_tag,
            aspect_ratio=aspect_ratio_tag,
            length=length_tag,
            identity=identity_tag,
        )

        start = time.time()
        upsampled_tags = generate_tags(
            self.model,
            self.tokenizer,
            prompt,
            ban_token_ids,
        )
        elapsed_time = time.time() - start

        return UpsamplingOutput(
            upsampled_tags=upsampled_tags,
            copyright_tags=copyright_tags,
            character_tags=character_tags,
            general_tags=general_tags,
            rating_tag=rating_tag,
            aspect_ratio_tag=aspect_ratio_tag,
            length_tag=length_tag,
            identity_tag=identity_tag,
            elapsed_time=elapsed_time,
        )
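
    # Create the input widgets and register them in order; layout and event wiring are left to the caller.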
    def ui(self):
        input_copyright = gr.Textbox(
            label="Copyright tags",
            placeholder="vocaloid",
        )
        input_character = gr.Textbox(
            label="Character tags",
            placeholder="hatsune miku",
        )
        input_general = gr.TextArea(
            label="General tags",
            lines=4,
            placeholder="1girl, ...",
            value="1girl, solo",
        )

        input_rating = gr.Radio(
            label="Rating",
            choices=list(RATING_OPTIONS),
            value="general",
        )
        input_aspect_ratio = gr.Radio(
            label="Aspect ratio",
            info="The aspect ratio of the image.",
            choices=list(ASPECT_RATIO_OPTIONS),
            value="tall",
        )
        input_length = gr.Radio(
            label="Length",
            info="The total length of the generated tags.",
            choices=list(LENGTH_OPTIONS),
            value="long",
        )
        input_identity = gr.Radio(
            label="Keep identity",
            info="How strictly to keep the identity of the character or subject. If you specify details of the subject in the prompt, choose `strict`; otherwise choose `none` or `lax`. `none` is the most creative but sometimes ignores the input prompt.",
            choices=list(IDENTITY_OPTIONS),
            value="none",
        )
        with gr.Accordion(label="Advanced options", open=False):
            input_ban_tags = gr.Textbox(
                label="Ban tags",
                info="Tags to ban from the output.",
                placeholder="alternate costume, ...",
            )
            model_name = gr.Dropdown(
                label="Model",
                choices=list(ALL_MODELS.keys()),
                value=list(ALL_MODELS.keys())[0],
            )
        self.input_components = [
            model_name,
            input_copyright,
            input_character,
            input_general,
            input_rating,
            input_aspect_ratio,
            input_length,
            input_identity,
            input_ban_tags,
        ]
    def get_inputs(self) -> list[Component]:
        return self.input_components
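

# A minimal wiring sketch (illustrative only; the actual Blocks layout, output components,
# and result formatting live outside this class):
#
#   v2 = V2UI()
#   with gr.Blocks() as demo:
#       v2.ui()
#       generate_btn = gr.Button("Generate")
#       output_text = gr.Textbox(label="Upsampled tags")
#       generate_btn.click(
#           fn=lambda *args: v2.on_generate(*args).upsampled_tags,
#           inputs=v2.get_inputs(),
#           outputs=[output_text],
#       )
#   demo.launch()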