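"""Demo: expand a comma-separated tag prompt with the kgen DTG tag generator and print the result."""
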
import time
import pathlib

import kgen.models as models
from kgen.formatter import seperate_tags, apply_format, apply_dtg_prompt
from kgen.metainfo import TARGET
from kgen.generate import tag_gen
from kgen.logging import logger


SEED_MAX = 2**31 - 1
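# Output template: <|category|> placeholders are filled in by apply_format
# with the corresponding tag groups.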
DEFAULT_FORMAT = """<|special|>, 
<|characters|>, <|copyrights|>, 
<|artist|>, 

<|general|>, 

<|quality|>, <|meta|>, <|rating|>"""


def process(
    prompt: str,
    aspect_ratio: float,
    seed: int,
    tag_length: str,
    ban_tags: str,
    format: str,
    temperature: float,
):
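    """Expand the input tag prompt with generated tags and return the formatted prompt."""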
    prompt_preview = prompt.replace("\n", " ")[:40]
    logger.info(f"Processing prompt: {prompt_preview}...")
    logger.info(f"Processing with seed: {seed}")
    black_list = [tag.strip() for tag in ban_tags.split(",") if tag.strip()]
    all_tags = [tag.strip() for tag in prompt.strip().split(",") if tag.strip()]

    tag_length = tag_length.replace(" ", "_")
    len_target = TARGET[tag_length]

    tag_map = seperate_tags(all_tags)
    dtg_prompt = apply_dtg_prompt(tag_map, tag_length, aspect_ratio)
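    # tag_gen yields intermediate results; drain the generator and keep the
    # extra_tokens / iter_count from its final iteration.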
    for _, extra_tokens, iter_count in tag_gen(
        models.text_model,
        models.tokenizer,
        dtg_prompt,
        tag_map["special"] + tag_map["general"],
        len_target,
        black_list,
        temperature=temperature,
        top_p=0.8,
        top_k=80,
        max_new_tokens=512,
        max_retry=10,
        max_same_output=5,
        seed=seed % SEED_MAX,
    ):
        pass
    tag_map["general"] += extra_tokens
    prompt_by_dtg = apply_format(tag_map, format)
    logger.info(
        "Prompt processing done. General Tags Count: "
        f"{len(tag_map['general'] + tag_map['special'])}"
        f" | Total iterations: {iter_count}"
    )
    return prompt_by_dtg


if __name__ == "__main__":
    models.model_dir = pathlib.Path(__file__).parent / "models"

    # Download the default GGUF model, then pick the last locally listed file and load it.
    models.download_gguf()
    files = models.list_gguf()
    file = files[-1]
    logger.info(f"Using GGUF model from local file: {file}")
    models.load_model(file, gguf=True)

    prompt = """
1girl, ask (askzy), masterpiece
"""

    t0 = time.time_ns()
    result = process(
        prompt,
        aspect_ratio=1.0,
        seed=1,
        tag_length="long",
        ban_tags="",
        format=DEFAULT_FORMAT,
        temperature=1.35,
    )
    t1 = time.time_ns()
    logger.info(f"Result:\n{result}")
    logger.info(f"Time cost: {(t1 - t0) / 10**6:.1f}ms")