import time
import os

import torch
from typing import Callable

from dartrs.v2 import (
    V2Model,
    MixtralModel,
    MistralModel,
    compose_prompt,
    LengthTag,
    AspectRatioTag,
    RatingTag,
    IdentityTag,
)
from dartrs.dartrs import DartTokenizer
from dartrs.utils import get_generation_config

import gradio as gr
from gradio.components import Component

try:
    import spaces
except ImportError:
    # Fallback for environments without the ZeroGPU `spaces` package:
    # spaces.GPU(...) becomes a no-op decorator.
    class spaces:
        def GPU(*args, **kwargs):
            return lambda x: x

from output import UpsamplingOutput

HF_TOKEN = os.getenv("HF_TOKEN", None)
V2_ALL_MODELS = {
    "dart-v2-moe-sft": {
        "repo": "p1atdev/dart-v2-moe-sft",
        "type": "sft",
        "class": MixtralModel,
    },
    "dart-v2-sft": {
        "repo": "p1atdev/dart-v2-sft",
        "type": "sft",
        "class": MistralModel,
    },
}
def prepare_models(model_config: dict):
    """Load the tokenizer and model for the given entry of V2_ALL_MODELS."""
    model_name = model_config["repo"]
    tokenizer = DartTokenizer.from_pretrained(model_name, auth_token=HF_TOKEN)
    model = model_config["class"].from_pretrained(model_name, auth_token=HF_TOKEN)

    return {
        "tokenizer": tokenizer,
        "model": model,
    }
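
# Illustrative usage only (variable names are hypothetical):
#   models = prepare_models(V2_ALL_MODELS["dart-v2-moe-sft"])
#   tokenizer, model = models["tokenizer"], models["model"]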
def normalize_tags(tokenizer: DartTokenizer, tags: str):
    """Just remove unk tokens."""
    return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
def generate_tags(
    model: V2Model,
    tokenizer: DartTokenizer,
    prompt: str,
    ban_token_ids: list[int],
):
    output = model.generate(
        get_generation_config(
            prompt,
            tokenizer=tokenizer,
            temperature=1,
            top_p=0.9,
            top_k=100,
            max_new_tokens=256,
            ban_token_ids=ban_token_ids,
        ),
    )

    return output
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
    return (
        [f"1{noun}"]
        + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
        + [f"{maximum + 1}+{noun}s"]
    )


PEOPLE_TAGS = (
    _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
)
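
# For reference, with the defaults each call expands to counted variants, e.g.:
#   _people_tag("girl") -> ["1girl", "2girls", "3girls", "4girls", "5girls", "6+girls"]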
def gen_prompt_text(output: UpsamplingOutput):
    # separate people tags (e.g. 1girl)
    people_tags = []
    other_general_tags = []

    for tag in output.general_tags.split(","):
        tag = tag.strip()
        if tag in PEOPLE_TAGS:
            people_tags.append(tag)
        else:
            other_general_tags.append(tag)

    return ", ".join(
        [
            part.strip()
            for part in [
                *people_tags,
                output.character_tags,
                output.copyright_tags,
                *other_general_tags,
                output.upsampled_tags,
                output.rating_tag,
            ]
            if part.strip() != ""
        ]
    )
def elapsed_time_format(elapsed_time: float) -> str:
    return f"Elapsed: {elapsed_time:.2f} seconds"
def parse_upsampling_output(
    upsampler: Callable[..., UpsamplingOutput],
):
    def _parse_upsampling_output(*args) -> tuple[str, str, dict, dict]:
        output = upsampler(*args)

        return (
            gen_prompt_text(output),
            elapsed_time_format(output.elapsed_time),
            # two gr.update payloads that set their bound components interactive again
            gr.update(interactive=True),
            gr.update(interactive=True),
        )

    return _parse_upsampling_output
class V2UI:
    model_name: str | None = None
    model: V2Model
    tokenizer: DartTokenizer

    input_components: list[Component] = []
    generate_btn: gr.Button

    def on_generate(
        self,
        model_name: str,
        copyright_tags: str,
        character_tags: str,
        general_tags: str,
        rating_tag: RatingTag,
        aspect_ratio_tag: AspectRatioTag,
        length_tag: LengthTag,
        identity_tag: IdentityTag,
        ban_tags: str,
        *args,
    ) -> UpsamplingOutput:
        # lazily (re)load the model and tokenizer when the selection changes
        if self.model_name is None or self.model_name != model_name:
            models = prepare_models(V2_ALL_MODELS[model_name])
            self.model = models["model"]
            self.tokenizer = models["tokenizer"]
            self.model_name = model_name

        # normalize tags
        # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
        # character_tags = normalize_tags(self.tokenizer, character_tags)
        # general_tags = normalize_tags(self.tokenizer, general_tags)

        ban_token_ids = self.tokenizer.encode(ban_tags.strip())

        prompt = compose_prompt(
            prompt=general_tags,
            copyright=copyright_tags,
            character=character_tags,
            rating=rating_tag,
            aspect_ratio=aspect_ratio_tag,
            length=length_tag,
            identity=identity_tag,
        )

        start = time.time()
        upsampled_tags = generate_tags(
            self.model,
            self.tokenizer,
            prompt,
            ban_token_ids,
        )
        elapsed_time = time.time() - start

        return UpsamplingOutput(
            upsampled_tags=upsampled_tags,
            copyright_tags=copyright_tags,
            character_tags=character_tags,
            general_tags=general_tags,
            rating_tag=rating_tag,
            aspect_ratio_tag=aspect_ratio_tag,
            length_tag=length_tag,
            identity_tag=identity_tag,
            elapsed_time=elapsed_time,
        )
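
# Illustrative usage sketch (hypothetical; the actual Space wires V2UI into a
# Gradio UI with inputs bound to these arguments, and the tag literals below
# are placeholders, not confirmed RatingTag/AspectRatioTag/LengthTag/IdentityTag values):
#   ui = V2UI()
#   result = ui.on_generate(
#       model_name="dart-v2-moe-sft",
#       copyright_tags="",
#       character_tags="",
#       general_tags="1girl, solo",
#       rating_tag="<|rating:general|>",          # placeholder
#       aspect_ratio_tag="<|aspect_ratio:tall|>",  # placeholder
#       length_tag="<|length:long|>",              # placeholder
#       identity_tag="<|identity:none|>",          # placeholder
#       ban_tags="",
#   )
#   print(gen_prompt_text(result))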