import time import pathlib import kgen.models as models from kgen.formatter import seperate_tags, apply_format, apply_dtg_prompt from kgen.metainfo import TARGET from kgen.generate import tag_gen from kgen.logging import logger SEED_MAX = 2**31 - 1 DEFAULT_FORMAT = """<|special|>, <|characters|>, <|copyrights|>, <|artist|>, <|general|>, <|quality|>, <|meta|>, <|rating|>""" def process( prompt: str, aspect_ratio: float, seed: int, tag_length: str, ban_tags: str, format: str, temperature: float, ): propmt_preview = prompt.replace("\n", " ")[:40] logger.info(f"Processing propmt: {propmt_preview}...") logger.info(f"Processing with seed: {seed}") black_list = [tag.strip() for tag in ban_tags.split(",") if tag.strip()] all_tags = [tag.strip() for tag in prompt.strip().split(",") if tag.strip()] tag_length = tag_length.replace(" ", "_") len_target = TARGET[tag_length] tag_map = seperate_tags(all_tags) dtg_prompt = apply_dtg_prompt(tag_map, tag_length, aspect_ratio) for _, extra_tokens, iter_count in tag_gen( models.text_model, models.tokenizer, dtg_prompt, tag_map["special"] + tag_map["general"], len_target, black_list, temperature=temperature, top_p=0.8, top_k=80, max_new_tokens=512, max_retry=10, max_same_output=5, seed=seed % SEED_MAX, ): pass tag_map["general"] += extra_tokens prompt_by_dtg = apply_format(tag_map, format) logger.info( "Prompt processing done. General Tags Count: " f"{len(tag_map['general'] + tag_map['special'])}" f" | Total iterations: {iter_count}" ) return prompt_by_dtg if __name__ == "__main__": models.model_dir = pathlib.Path(__file__).parent / "models" file = models.download_gguf() files = models.list_gguf() file = files[-1] logger.info(f"Use gguf model from local file: {file}") models.load_model(file, gguf=True) prompt = """ 1girl, ask (askzy), masterpiece """ t0 = time.time_ns() result = process( prompt, aspect_ratio=1.0, seed=1, tag_length="long", ban_tags="", format=DEFAULT_FORMAT, temperature=1.35, ) t1 = time.time_ns() logger.info(f"Result:\n{result}") logger.info(f"Time cost: {(t1 - t0) / 10**6:.1f}ms")