# BATUTO_imagen / app.py
import gradio as gr
import torch
import gc
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AIImageGeneratorNSFW:
    def __init__(self):
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "segmind/Segmind-DE-XL"
        self.lora_id = "urn:air:sdxl:lora:civitai:141300@341068"
        self.is_model_loaded = False
        logger.info(f"Initializing on device: {self.device}")

    def load_model(self):
        if self.is_model_loaded:
            return True
        try:
            logger.info("Loading NSFW base model with LoRA and optimizations...")
            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
            tokenizer_1 = CLIPTokenizer.from_pretrained(self.model_id, subfolder="tokenizer", use_fast=False)
            tokenizer_2 = CLIPTokenizer.from_pretrained(self.model_id, subfolder="tokenizer_2", use_fast=False)
            text_encoder_1 = CLIPTextModel.from_pretrained(self.model_id, subfolder="text_encoder", torch_dtype=torch_dtype, low_cpu_mem_usage=True)
            text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(self.model_id, subfolder="text_encoder_2", torch_dtype=torch_dtype, low_cpu_mem_usage=True)
            # The SDXL pipeline takes the two tokenizers and text encoders as
            # separate keyword arguments, not as lists; it also has no
            # safety_checker component, so there is nothing to pass for it.
            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                self.model_id,
                tokenizer=tokenizer_1,
                tokenizer_2=tokenizer_2,
                text_encoder=text_encoder_1,
                text_encoder_2=text_encoder_2,
                torch_dtype=torch_dtype,
                scheduler=EulerDiscreteScheduler.from_pretrained(self.model_id, subfolder="scheduler"),
                use_safetensors=True,
                variant="fp16" if self.device == "cuda" else None,
            )
            # Load the LoRA with the current diffusers API (load_lora_weights);
            # a CivitAI AIR URN like self.lora_id must first be resolved to a
            # downloaded file (see the hypothetical download_civitai_lora()
            # sketch after the class). The default LoRA scale is 1.0.
            self.pipeline.load_lora_weights(self.lora_id)
            if self.device == "cuda":
                self.pipeline.enable_model_cpu_offload()
                self.pipeline.enable_vae_slicing()
                # The pipeline exposes the encoders as .text_encoder and
                # .text_encoder_2 (there is no .text_encoder_1 attribute).
                # torch.compile may conflict with the CPU-offload hooks; remove
                # it if compilation errors appear.
                self.pipeline.unet = torch.compile(self.pipeline.unet, mode="reduce-overhead")
                self.pipeline.text_encoder = torch.compile(self.pipeline.text_encoder, mode="reduce-overhead")
                self.pipeline.text_encoder_2 = torch.compile(self.pipeline.text_encoder_2, mode="reduce-overhead")
            self.is_model_loaded = True
            logger.info("NSFW model with LoRA loaded and optimized successfully.")
            return True
        except Exception as e:
            logger.error(f"Error loading NSFW model with LoRA: {e}")
            return False
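
    # Hypothetical companion to load_model(), not in the original app: a
    # minimal sketch of how the pipeline could be released to free GPU memory
    # between sessions on a shared Space.
    def unload_model(self):
        if self.pipeline is not None:
            self.pipeline = None
            self.is_model_loaded = False
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()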

    def generate_image(self, prompt, width=1024, height=576, steps=35, guidance_scale=12.0):
        if not self.is_model_loaded and not self.load_model():
            return None
        try:
            with torch.inference_mode():
                # Fresh random seed per call; width/height are rounded down to
                # multiples of 8 as required by the SDXL VAE.
                generator = torch.Generator(self.device).manual_seed(torch.randint(0, 2**32, (1,)).item())
                result = self.pipeline(
                    prompt=prompt,
                    width=(width // 8) * 8,
                    height=(height // 8) * 8,
                    num_inference_steps=steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                    output_type="pil",
                )
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()
            return result.images[0]
        except Exception as e:
            logger.error(f"Error generating NSFW image: {e}")
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()
            return None
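

# Hypothetical helper, not part of the original app: a minimal sketch of how a
# CivitAI LoRA could be downloaded so load_lora_weights() can consume it. The
# endpoint pattern, file names and local paths are assumptions; some CivitAI
# downloads also require an API token.
def download_civitai_lora(model_version_id, dest_dir="./loras"):
    """Download a LoRA .safetensors from CivitAI and return (directory, filename)."""
    import os
    import requests  # assumed to be available in the Space environment

    os.makedirs(dest_dir, exist_ok=True)
    filename = f"civitai_{model_version_id}.safetensors"
    path = os.path.join(dest_dir, filename)
    if not os.path.exists(path):
        url = f"https://civitai.com/api/download/models/{model_version_id}"
        with requests.get(url, stream=True, timeout=300) as resp:
            resp.raise_for_status()
            with open(path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
    return dest_dir, filename


# Example usage (hypothetical):
#   lora_dir, lora_file = download_civitai_lora(341068)
#   pipeline.load_lora_weights(lora_dir, weight_name=lora_file)
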
def initialize_generator_nsfw():
    # Lazily create and reuse a single module-level generator instance.
    global generator_nsfw
    if 'generator_nsfw' not in globals():
        generator_nsfw = AIImageGeneratorNSFW()
    return generator_nsfw


def generate_image_nsfw(prompt, width, height, steps, guidance_scale):
    gen = initialize_generator_nsfw()
    if not prompt.strip():
        return None
    return gen.generate_image(
        prompt=prompt,
        width=int(width),
        height=int(height),
        steps=int(steps),
        guidance_scale=float(guidance_scale),
    )
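

# Hypothetical smoke test, not part of the original app: a minimal sketch that
# exercises the generation path end to end, assuming the model and LoRA can be
# downloaded in the current environment. It is never called automatically.
def _smoke_test_generation():
    image = generate_image_nsfw("a detailed scenic landscape, golden hour", 1024, 576, 20, 7.5)
    if image is not None:
        image.save("smoke_test.png")
    return image is not None
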
def create_nsfw_interface():
    with gr.Blocks(title="AI NSFW Image Generator - Stable Diffusion XL") as iface:
        gr.Markdown("# 🎨 NSFW generator based on Stable Diffusion XL\n_Responsible use, adults only_")
        prompt = gr.Textbox(label="Prompt for the NSFW image", placeholder="Describe the explicit content...", lines=3)
        width = gr.Slider(512, 1536, value=1024, step=8, label="Width (pixels)")
        height = gr.Slider(512, 1536, value=576, step=8, label="Height (pixels)")
        steps = gr.Slider(10, 50, value=35, step=1, label="Inference steps")
        guidance_scale = gr.Slider(1.0, 20.0, value=12.0, step=0.1, label="Guidance scale")
        btn_generate = gr.Button("Generate NSFW Image")
        img_output = gr.Image(label="Generated image")
        btn_generate.click(
            fn=generate_image_nsfw,
            inputs=[prompt, width, height, steps, guidance_scale],
            outputs=img_output
        )
    return iface


# Declare the interface at module level so Hugging Face Spaces can serve it.
nsfw_app = create_nsfw_interface()
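# A hedged alternative for Spaces, where a generation can take tens of seconds:
# enable request queuing before launching, e.g. nsfw_app.queue(max_size=4).launch()
# (max_size=4 is an illustrative value, not something the original app sets).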

if __name__ == "__main__":
    nsfw_app.launch()