import gc
import torch
from diffusers import DDIMScheduler
from transformers import (
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
from fastchat.utils import build_logger

logger = build_logger("diffusion_infer", "diffusion_infer.log")

@torch.inference_mode()
def generate_stream_sde(
    model,
    tokenizer,
    params,
    device,
    context_len=256,
    stream_interval=2,
):
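    """Generate an image from ``params["prompt"]`` with a diffusers pipeline.

    Yields a single FastChat-style result dict whose "text" field carries the
    generated image, plus prompt-token usage. ``context_len`` and
    ``stream_interval`` are accepted for signature compatibility with the text
    generate_stream functions but are not used here.
    """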
    prompt = params["prompt"]
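    # The commented-out blocks below are the original text-streaming path
    # (sampling params, TextIteratorStreamer, stopping criteria); the diffusion
    # flow below does not use them.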
    # temperature = float(params.get("temperature", 1.0))
    # repetition_penalty = float(params.get("repetition_penalty", 1.0))
    # top_p = float(params.get("top_p", 1.0))
    # top_k = int(params.get("top_k", 50))  # -1 means disable
    # max_new_tokens = int(params.get("max_new_tokens", 1024))
    # stop_token_ids = params.get("stop_token_ids", None) or []
    # stop_token_ids.append(tokenizer.eos_token_id)
    #
    # decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    # streamer = TextIteratorStreamer(tokenizer, **decode_config)
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoding.input_ids
    # encoding["decoder_input_ids"] = encoding["input_ids"].clone()
    # Count prompt tokens: input_ids has shape [batch, seq_len], so len(input_ids)
    # would return the batch size rather than the token count.
    input_echo_len = input_ids.shape[-1]
    #
    # generation_config = GenerationConfig(
    #     max_new_tokens=max_new_tokens,
    #     do_sample=temperature >= 1e-5,
    #     temperature=temperature,
    #     repetition_penalty=repetition_penalty,
    #     no_repeat_ngram_size=10,
    #     top_p=top_p,
    #     top_k=top_k,
    #     eos_token_id=stop_token_ids,
    # )
    #
    # class CodeBlockStopper(StoppingCriteria):
    #     def __call__(
    #         self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    #     ) -> bool:
    #         # Code-completion is open-end generation.
    #         # We check \n\n to stop at end of a code block.
    #         if list(input_ids[0][-2:]) == [628, 198]:
    #             return True
    #         return False

    # gen_kwargs = dict(
    #     **encoding,
    #     streamer=streamer,
    #     generation_config=generation_config,
    #     stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]),
    # )
    # Rebuild the scheduler as DDIM from the pipeline's existing scheduler config.
    model.scheduler = DDIMScheduler.from_config(model.scheduler.config)
    logger.info(f"model.scheduler: {model.scheduler}")
    # i = 0
    # output = ""
    # for new_text in streamer:
    #     i += 1
    #     output += new_text
    #     if i % stream_interval == 0 or i == max_new_tokens - 1:
    #         yield {
    #             "text": output,
    #             "usage": {
    #                 "prompt_tokens": input_echo_len,
    #                 "completion_tokens": i,
    #                 "total_tokens": input_echo_len + i,
    #             },
    #             "finish_reason": None,
    #         }
    #     if i >= max_new_tokens:
    #         break
    #
    # if i >= max_new_tokens:
    #     finish_reason = "length"
    # else:
    #     finish_reason = "stop"
    logger.info(f"prompt: {prompt}")
    output = model(prompt=prompt).images[0]
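    # FastChat's result dict uses a "text" field; here it carries the PIL image
    # returned by the pipeline rather than a string.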

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": 0,
            "total_tokens": input_echo_len,
        },
        "finish_reason": "stop",
    }

    # Release cached accelerator memory after generation.
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()