# anycoder-37529124 / model_handler.py
import torch
from diffusers import AutoPipelineForTextToImage


class ModelHandler:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
        self.pipeline = None
        # Load the pipeline eagerly so the first request does not pay the startup cost.
        self.load_model()
    def load_model(self):
        """
        Loads the model pipeline. Uses float16 on GPU to save memory.
        """
        try:
            print(f"Loading model: {self.model_id} on {self.device}...")
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            # AutoPipeline detects the pipeline architecture automatically
            self.pipeline = AutoPipelineForTextToImage.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                use_safetensors=True,
            )
            if self.device == "cuda":
                self.pipeline.to("cuda")
                # Optional: enable CPU offload if VRAM is limited (e.g. < 8 GB)
                # self.pipeline.enable_model_cpu_offload()
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Fallback or re-raise depending on deployment needs
            raise
    def infer(self, prompt, negative_prompt, width, height, num_inference_steps, guidance_scale, seed, progress_callback=None):
        """
        Runs inference on the loaded pipeline and returns the generated image.
        """
        if self.pipeline is None:
            self.load_model()
        generator = torch.Generator(device=self.device).manual_seed(int(seed))

        # Progress reporting: this (step, timestep, latents) signature matches the
        # legacy `callback` argument; exact callback support varies across
        # diffusers versions and pipeline types.
        def callback_dynamic(step, timestep, latents):
            if progress_callback:
                progress_callback((step, num_inference_steps))
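        # A rough alternative for newer diffusers releases (an assumption about the
        # installed version, not something this file checks): the legacy
        # `callback`/`callback_steps` arguments are deprecated there in favor of
        # `callback_on_step_end`, whose handler also receives the pipeline and must
        # return `callback_kwargs`, e.g.:
        #
        #   def on_step_end(pipe, step, timestep, callback_kwargs):
        #       if progress_callback:
        #           progress_callback((step + 1, num_inference_steps))
        #       return callback_kwargs
        #
        #   ...self.pipeline(..., callback_on_step_end=on_step_end)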

        image = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            # callback=callback_dynamic,  # Optional: enable for granular progress updates
            # callback_steps=1
        ).images[0]
        return image
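

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original handler: the prompt, the
    # 1024x1024 resolution, the step count, and the output filename below are
    # illustrative assumptions. Running this downloads the model weights and
    # performs a single text-to-image generation.
    handler = ModelHandler()
    result = handler.infer(
        prompt="A watercolor painting of a lighthouse at dawn",
        negative_prompt="blurry, low quality",
        width=1024,
        height=1024,
        num_inference_steps=25,
        guidance_scale=5.0,
        seed=42,
    )
    result.save("output.png")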