import gc
from threading import Thread

import torch
from diffusers import DDIMScheduler
import transformers
from transformers import (
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

from fastchat.utils import build_logger

logger = build_logger("diffusion_infer", "diffusion_infer.log")


@torch.inference_mode()
def generate_stream_sde(
    model,
    tokenizer,
    params,
    device,
    context_len=256,
    stream_interval=2,
):
    """Run a diffusers text-to-image pipeline and yield the result in the
    FastChat streaming format. Diffusion pipelines do not stream tokens, so a
    single chunk is yielded; the PIL image is carried under the "text" key."""
    prompt = params["prompt"]
    # temperature = float(params.get("temperature", 1.0))
    # repetition_penalty = float(params.get("repetition_penalty", 1.0))
    # top_p = float(params.get("top_p", 1.0))
    # top_k = int(params.get("top_k", 50))  # -1 means disable
    # max_new_tokens = int(params.get("max_new_tokens", 1024))
    # stop_token_ids = params.get("stop_token_ids", None) or []
    # stop_token_ids.append(tokenizer.eos_token_id)
    #
    # decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    # streamer = TextIteratorStreamer(tokenizer, **decode_config)

    # Tokenize only to report prompt-token usage; the pipeline re-encodes the
    # prompt internally.
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoding.input_ids
    # encoding["decoder_input_ids"] = encoding["input_ids"].clone()
    input_echo_len = input_ids.shape[-1]
    #
    # generation_config = GenerationConfig(
    #     max_new_tokens=max_new_tokens,
    #     do_sample=temperature >= 1e-5,
    #     temperature=temperature,
    #     repetition_penalty=repetition_penalty,
    #     no_repeat_ngram_size=10,
    #     top_p=top_p,
    #     top_k=top_k,
    #     eos_token_id=stop_token_ids,
    # )
    #
    # class CodeBlockStopper(StoppingCriteria):
    #     def __call__(
    #         self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    #     ) -> bool:
    #         # Code-completion is open-end generation.
    #         # We check \n\n to stop at end of a code block.
    #         if list(input_ids[0][-2:]) == [628, 198]:
    #             return True
    #         return False
    #
    # gen_kwargs = dict(
    #     **encoding,
    #     streamer=streamer,
    #     generation_config=generation_config,
    #     stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]),
    # )

    # Swap in a DDIM scheduler for the sampling run.
    model.scheduler = DDIMScheduler.from_config(model.scheduler.config)
    logger.info(f"model.scheduler: {model.scheduler}")

    # The pipeline is called synchronously below; also running it in a
    # background thread would generate the image twice, so the threaded
    # invocation is disabled.
    # generation_kwargs = {"prompt": prompt}
    # thread = Thread(target=model, kwargs=generation_kwargs)
    # thread.start()

    # i = 0
    # output = ""
    # for new_text in streamer:
    #     i += 1
    #     output += new_text
    #     if i % stream_interval == 0 or i == max_new_tokens - 1:
    #         yield {
    #             "text": output,
    #             "usage": {
    #                 "prompt_tokens": input_echo_len,
    #                 "completion_tokens": i,
    #                 "total_tokens": input_echo_len + i,
    #             },
    #             "finish_reason": None,
    #         }
    #     if i >= max_new_tokens:
    #         break
    #
    # if i >= max_new_tokens:
    #     finish_reason = "length"
    # else:
    #     finish_reason = "stop"

    logger.info(f"prompt: {prompt}")
    output = model(prompt=prompt).images[0]  # PIL.Image.Image
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": 0,
            "total_tokens": input_echo_len,
        },
        "finish_reason": "stop",
    }
    # thread.join()

    # Clean up accelerator memory after generation.
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
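

# --- Illustrative usage sketch (not part of the original worker) ---
# Assumptions: a diffusers StableDiffusionPipeline serves as `model`, and its
# built-in CLIP tokenizer is reused as the `tokenizer` argument. The model id,
# prompt, and output path below are placeholders, not values from this repo.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    demo_device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = pipe.to(demo_device)

    demo_params = {"prompt": "a watercolor painting of a lighthouse at dusk"}
    for chunk in generate_stream_sde(pipe, pipe.tokenizer, demo_params, demo_device):
        image = chunk["text"]  # the generator yields a PIL image under the "text" key
        image.save("diffusion_infer_demo.png")
        print(chunk["usage"], chunk["finish_reason"])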