# visual-arena/fastchat/model/model_imagenhub.py
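# Inference helper for ImagenHub text-to-image models, used by the FastChat
# model worker in this space (description inferred from the module contents).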
import gc
from threading import Thread
import torch
from diffusers import DDIMScheduler
import transformers
from transformers import (
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
from fastchat.utils import build_logger
logger = build_logger("diffusion_infer", "diffusion_infer.log")
@torch.inference_mode()
def generate_stream_imagen(
    model,
    tokenizer,
    params,
    device,
    context_len=256,
    stream_interval=2,
):
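    """Run text-to-image inference for an ImagenHub model.

    Unlike the token-streaming text generators, this yields a single result
    dict whose "text" field holds the image returned by
    ``model.infer_one_image``.
    """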
    prompt = params["prompt"]
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoding.input_ids
    # encoding["decoder_input_ids"] = encoding["input_ids"].clone()
    input_echo_len = input_ids.shape[-1]  # number of prompt tokens, not the batch size
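    # The block below is leftover scaffolding from the token-streaming path
    # used for text models; it is kept commented out for reference only.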
    #
    # generation_config = GenerationConfig(
    #     max_new_tokens=max_new_tokens,
    #     do_sample=temperature >= 1e-5,
    #     temperature=temperature,
    #     repetition_penalty=repetition_penalty,
    #     no_repeat_ngram_size=10,
    #     top_p=top_p,
    #     top_k=top_k,
    #     eos_token_id=stop_token_ids,
    # )
    #
    # class CodeBlockStopper(StoppingCriteria):
    #     def __call__(
    #         self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    #     ) -> bool:
    #         # Code completion is open-ended generation.
    #         # We check for \n\n to stop at the end of a code block.
    #         if list(input_ids[0][-2:]) == [628, 198]:
    #             return True
    #         return False
    #
    # gen_kwargs = dict(
    #     **encoding,
    #     streamer=streamer,
    #     generation_config=generation_config,
    #     stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]),
    # )
    # generation_kwargs = {"prompt": prompt}
    #
    # model.pipe.scheduler = DDIMScheduler.from_config(model.pipe.scheduler.config)
    # thread = Thread(target=model.infer_one_image, kwargs=generation_kwargs)
    # thread.start()
    # i = 0
    # output = ""
    # for new_text in streamer:
    #     i += 1
    #     output += new_text
    #     if i % stream_interval == 0 or i == max_new_tokens - 1:
    #         yield {
    #             "text": output,
    #             "usage": {
    #                 "prompt_tokens": input_echo_len,
    #                 "completion_tokens": i,
    #                 "total_tokens": input_echo_len + i,
    #             },
    #             "finish_reason": None,
    #         }
    #     if i >= max_new_tokens:
    #         break
    #
    # if i >= max_new_tokens:
    #     finish_reason = "length"
    # else:
    #     finish_reason = "stop"
logger.info(f"prompt: {prompt}")
logger.info(f"model.scheduler: {model.pipe.scheduler}")
logger.info(f"model.type: {type(model)}")
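    # Single, non-streaming generation; the seed is fixed so the same prompt
    # reproduces the same image across requests.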
    output = model.infer_one_image(prompt=prompt, seed=42)
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": 0,
            "total_tokens": input_echo_len,
        },
        "finish_reason": "stop",
    }
    # thread.join()
    # Clean up accelerator memory after generation.
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
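

# A minimal usage sketch (not part of the worker): `model` is assumed to be an
# ImagenHub text-to-image wrapper exposing `infer_one_image(prompt, seed)` and a
# `pipe` attribute, loaded elsewhere by the FastChat model worker; `tokenizer`
# is any HF tokenizer and is used only to report the prompt length.
#
#     params = {"prompt": "a watercolor painting of a lighthouse at dusk"}
#     for out in generate_stream_imagen(model, tokenizer, params, device="cuda"):
#         image = out["text"]  # whatever infer_one_image returns (typically a PIL image)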