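"""Streaming-style inference entry point for diffusion (text-to-image) models.

Mirrors the generator interface of the text-generation streaming functions,
but each request produces a single image result rather than incremental
token chunks.
"""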
import gc

import torch

from fastchat.utils import build_logger

logger = build_logger("diffusion_infer", "diffusion_infer.log")

@torch.inference_mode()
def generate_stream_imagen(
    model,
    tokenizer,
    params,
    device,
    context_len=256,
    stream_interval=2,
):
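    """Generate an image for a single prompt and yield one result chunk.

    Unlike the token-streaming text workers, the diffusion pipeline returns
    its output from a single call, so this generator yields exactly once.
    ``context_len`` and ``stream_interval`` are accepted only for interface
    compatibility and are not used here.
    """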
    prompt = params["prompt"]

    # The prompt is tokenized only to report token usage; the diffusion
    # pipeline itself consumes the raw prompt string.
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoding.input_ids
    input_echo_len = input_ids.shape[-1]
    # The token-streaming machinery used by the text workers (GenerationConfig,
    # stopping criteria, TextIteratorStreamer and an incremental yield loop) is
    # not needed here: the diffusion pipeline returns the finished image from a
    # single infer_one_image() call.
    logger.info(f"prompt: {prompt}")
    logger.info(f"model.scheduler: {model.pipe.scheduler}")
    logger.info(f"model.type: {type(model)}")
    # logger.info(f"prompt: {prompt}")
    output = model.infer_one_image(prompt=prompt, seed=42)

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": 0,
            "total_tokens": input_echo_len,
        },
        "finish_reason": "stop",
    }

    # Release memory held for this request before returning control to the
    # serving loop.
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
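

# A minimal, illustrative sketch of how this generator might be driven.  The
# Dummy* classes below are hypothetical stand-ins: real callers pass a loaded
# diffusion model (exposing infer_one_image and a .pipe attribute with a
# scheduler) together with its tokenizer.
if __name__ == "__main__":

    class DummyEncoding:
        def __init__(self, n_tokens):
            self.input_ids = torch.zeros((1, n_tokens), dtype=torch.long)

        def to(self, device):
            return self

    class DummyTokenizer:
        def __call__(self, text, return_tensors="pt"):
            return DummyEncoding(len(text.split()))

    class DummyPipe:
        scheduler = None

    class DummyDiffusionModel:
        pipe = DummyPipe()

        def infer_one_image(self, prompt, seed=42):
            # A real model would return image data here.
            return f"<image for: {prompt}>"

    params = {"prompt": "a watercolor lighthouse at dawn"}
    for chunk in generate_stream_imagen(
        DummyDiffusionModel(), DummyTokenizer(), params, device="cpu"
    ):
        print(chunk["text"], chunk["usage"])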