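"""FastChat-style streaming inference adapter for a diffusion image model.

Wraps a blocking model.infer_one_image call in the generator protocol that
FastChat's serving workers expect from generate_stream_* functions.
"""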
import gc

import torch

from fastchat.utils import build_logger

logger = build_logger("diffusion_infer", "diffusion_infer.log")
def generate_stream_imagen(
    model,
    tokenizer,
    params,
    device,
    context_len=256,  # unused; kept to match FastChat's generate_stream signature
    stream_interval=2,  # unused; a diffusion call yields exactly one chunk
):
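    """Run diffusion image generation and yield it as a FastChat stream.

    Yields a single chunk in the generate_stream format: the output of
    model.infer_one_image under "text", prompt-only token usage under
    "usage", and finish_reason "stop". No completion tokens are counted,
    since the model produces an image rather than text.
    """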
    prompt = params["prompt"]
    seed = int(params.get("seed", 42))  # fixed default seed for reproducible sampling

    # Tokenize only to report prompt-token usage; the diffusion model itself
    # consumes the raw prompt string.
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoding.input_ids
    # input_ids has shape (batch, seq_len); count tokens, not batch entries.
    input_echo_len = input_ids.shape[1]
logger.info(f"prompt: {prompt}") | |
logger.info(f"model.scheduler: {model.pipe.scheduler}") | |
logger.info(f"model.type: {type(model)}") | |
# logger.info(f"prompt: {prompt}") | |
output = model.infer_one_image(prompt=prompt, seed=42) | |
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": 0,  # image output: no completion tokens
            "total_tokens": input_echo_len,
        },
        "finish_reason": "stop",
    }
    # Release cached accelerator memory after the request finishes.
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
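
# A minimal usage sketch (commented out, not executed on import), assuming a
# `model` object exposing the infer_one_image(prompt=..., seed=...) method and
# `pipe` attribute used above, and a Hugging Face `tokenizer`; the prompt text
# and device string are illustrative only:
#
#     params = {"prompt": "a watercolor lighthouse at dusk", "seed": 42}
#     for chunk in generate_stream_imagen(model, tokenizer, params, device="cuda"):
#         image = chunk["text"]  # whatever infer_one_image returns, passed through
#         print(chunk["usage"], chunk["finish_reason"])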