|
import time |
|
import pathlib |
|
|
|
import kgen.models as models |
|
from kgen.formatter import seperate_tags, apply_format, apply_dtg_prompt |
|
from kgen.metainfo import TARGET |
|
from kgen.generate import tag_gen |
|
from kgen.logging import logger |
|
|
|
|
|
# Modulus applied to user seeds in process() (``seed % SEED_MAX``) before they
# are handed to tag_gen. 2**31 - 1 is the largest signed 32-bit integer.
SEED_MAX = 2**31 - 1
# Output template passed as ``format`` to apply_format(); the __main__ demo at
# the bottom of this file uses it. The ``<|name|>`` markers are presumably
# substituted with the matching tag category from the tag map — TODO confirm
# against kgen.formatter.apply_format. The literal newlines and blank lines
# inside the string are part of the template.
DEFAULT_FORMAT = """<|special|>,
<|characters|>, <|copyrights|>,
<|artist|>,

<|general|>,

<|quality|>, <|meta|>, <|rating|>"""
|
|
|
|
|
def process(
    prompt: str,
    aspect_ratio: float,
    seed: int,
    tag_length: str,
    ban_tags: str,
    format: str,
    temperature: float,
):
    """Expand a comma-separated tag prompt with the DTG model and format it.

    Args:
        prompt: Comma-separated tags. Each tag is whitespace-stripped and
            empty entries are dropped.
        aspect_ratio: Forwarded to ``apply_dtg_prompt``.
        seed: Sampling seed. It is reduced modulo ``SEED_MAX`` before being
            forwarded to ``tag_gen``.
        tag_length: Length preset. Spaces are replaced with underscores, and
            the result is used both as the ``TARGET`` key and as the value
            passed to ``apply_dtg_prompt``.
        ban_tags: Comma-separated tags forwarded to ``tag_gen`` as a blacklist
            (stripped, empty entries dropped).
        format: Template passed to ``apply_format`` (``DEFAULT_FORMAT`` is an
            example).
        temperature: Sampling temperature forwarded to ``tag_gen``.

    Returns:
        The value ``apply_format`` produces for the expanded tag map.

    Raises:
        Whatever ``TARGET[tag_length]`` raises for an unknown preset
        (presumably ``KeyError`` if ``TARGET`` is a dict — confirm).
    """
    prompt_preview = prompt.replace("\n", " ")[:40]
    logger.info(f"Processing prompt: {prompt_preview}...")
    logger.info(f"Processing with seed: {seed}")

    black_list = [tag.strip() for tag in ban_tags.split(",") if tag.strip()]
    all_tags = [tag.strip() for tag in prompt.strip().split(",") if tag.strip()]

    # Normalize spaces to underscores before the TARGET lookup; the normalized
    # value is also what apply_dtg_prompt receives below.
    tag_length = tag_length.replace(" ", "_")
    len_target = TARGET[tag_length]

    tag_map = seperate_tags(all_tags)
    dtg_prompt = apply_dtg_prompt(tag_map, tag_length, aspect_ratio)

    # Only the values from tag_gen's final iteration are used. Pre-bind them so
    # an iterable that yields nothing leaves the tag map unchanged instead of
    # raising UnboundLocalError on the lines after the loop.
    extra_tokens = []
    iter_count = 0
    for _, extra_tokens, iter_count in tag_gen(
        models.text_model,
        models.tokenizer,
        dtg_prompt,
        tag_map["special"] + tag_map["general"],
        len_target,
        black_list,
        temperature=temperature,
        top_p=0.8,
        top_k=80,
        max_new_tokens=512,
        max_retry=10,
        max_same_output=5,
        # Python's % with a positive modulus always lands in [0, SEED_MAX),
        # including for negative seeds.
        seed=seed % SEED_MAX,
    ):
        pass

    tag_map["general"] += extra_tokens
    prompt_by_dtg = apply_format(tag_map, format)
    # The reported count covers both general and special tags.
    logger.info(
        "Prompt processing done. General Tags Count: "
        f"{len(tag_map['general']) + len(tag_map['special'])}"
        f" | Total iterations: {iter_count}"
    )
    return prompt_by_dtg
|
|
|
|
|
if __name__ == "__main__":
    # Demo: load a local gguf model and expand a sample prompt once, timing it.
    models.model_dir = pathlib.Path(__file__).parent / "models"

    # download_gguf() is called only for its side effect (presumably fetching
    # the default model into model_dir — confirm); the model actually loaded
    # is the last entry reported by list_gguf().
    models.download_gguf()
    files = models.list_gguf()
    if not files:
        raise SystemExit("No gguf model found via models.list_gguf(); cannot run demo.")
    file = files[-1]
    logger.info(f"Use gguf model from local file: {file}")
    models.load_model(file, gguf=True)

    prompt = """
1girl, ask (askzy), masterpiece
"""

    # perf_counter_ns is a monotonic clock intended for measuring intervals;
    # time_ns is wall-clock time and can jump if the system clock is adjusted.
    t0 = time.perf_counter_ns()
    result = process(
        prompt,
        aspect_ratio=1.0,
        seed=1,
        tag_length="long",
        ban_tags="",
        format=DEFAULT_FORMAT,
        temperature=1.35,
    )
    t1 = time.perf_counter_ns()

    logger.info(f"Result:\n{result}")
    logger.info(f"Time cost: {(t1 - t0) / 10**6:.1f}ms")
|
|