# Hugging Face Space status: Runtime error
import inspect
import os

import torch
from diffusers import AutoPipelineForTextToImage, DiffusionPipeline
class ModelHandler:
    """Load a diffusers text-to-image pipeline at construction and run prompts through it."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """
        Loads the model pipeline. Uses float16 for GPU to save memory.

        Loading is attempted with ``AutoPipelineForTextToImage`` first; if that
        raises ``ValueError`` (e.g. the checkpoint's pipeline class is not in the
        auto-pipeline text-to-image mapping) the generic ``DiffusionPipeline``
        loader is tried instead.

        Raises:
            Exception: any error from loading is printed and then re-raised
                unchanged.
        """
        try:
            print(f"Loading model: {self.model_id} on {self.device}...")
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            self.pipeline = self._from_pretrained(dtype)
            if self.device == "cuda":
                # If VRAM is limited (e.g. < 8GB), consider calling
                # self.pipeline.enable_model_cpu_offload() instead of moving the
                # whole pipeline to the GPU here.
                self.pipeline.to("cuda")
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Bare raise keeps the original exception and traceback intact.
            raise

    def _from_pretrained(self, dtype):
        """Instantiate the pipeline for ``self.model_id`` with weights in ``dtype``."""
        load_kwargs = {"torch_dtype": dtype, "use_safetensors": True}
        try:
            return AutoPipelineForTextToImage.from_pretrained(self.model_id, **load_kwargs)
        except ValueError as err:
            # AutoPipeline only resolves pipeline classes registered in its
            # text-to-image mapping; DiffusionPipeline can load any repo that
            # ships a diffusers pipeline config.
            print(
                f"AutoPipelineForTextToImage could not load {self.model_id} ({err}); "
                "falling back to DiffusionPipeline."
            )
            return DiffusionPipeline.from_pretrained(self.model_id, **load_kwargs)

    def _pipeline_accepts(self, name):
        """Return True if the loaded pipeline's ``__call__`` declares a parameter ``name``."""
        try:
            params = inspect.signature(self.pipeline.__call__).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: assume the keyword is unsupported.
            return False
        return name in params

    def infer(self, prompt, negative_prompt, width, height, num_inference_steps, guidance_scale, seed, progress_callback=None):
        """
        Runs inference on the loaded pipeline.

        Args:
            prompt: Text prompt passed to the pipeline.
            negative_prompt: Negative prompt passed to the pipeline.
            width: Output width; coerced to ``int`` in case callers (e.g. UI
                sliders) pass floats.
            height: Output height; coerced to ``int`` for the same reason.
            num_inference_steps: Number of denoising steps; coerced to ``int``.
            guidance_scale: Guidance scale passed through to the pipeline.
            seed: Seed for the ``torch.Generator``; coerced to ``int``.
            progress_callback: Optional callable. If the pipeline supports
                ``callback_on_step_end``, it is called after each step with a
                ``(completed_steps, num_inference_steps)`` tuple. Otherwise it is
                not called.

        Returns:
            The first entry of the pipeline output's ``images``.
        """
        if self.pipeline is None:
            self.load_model()

        width = int(width)
        height = int(height)
        num_inference_steps = int(num_inference_steps)
        generator = torch.Generator(device=self.device).manual_seed(int(seed))

        call_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "width": width,
            "height": height,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "generator": generator,
        }

        if progress_callback is not None and self._pipeline_accepts("callback_on_step_end"):
            def _on_step_end(pipe, step_index, timestep, callback_kwargs):
                # step_index is 0-based; report the number of completed steps.
                progress_callback((step_index + 1, num_inference_steps))
                # diffusers expects the callback to hand back the (possibly
                # modified) callback_kwargs dict.
                return callback_kwargs

            call_kwargs["callback_on_step_end"] = _on_step_end

        return self.pipeline(**call_kwargs).images[0]