import base64
import datetime
import gradio as gr
import numpy as np
import os
import pytz
import psutil
import re
import random
import torch
import time
import time

from PIL import Image
from io import BytesIO
from PIL import Image
from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny

try:
    # Optional Intel GPU/CPU acceleration; absence is fine on other hardware.
    import intel_extension_for_pytorch as ipex
except ImportError:
    # Narrowed from a bare `except:` so real failures (KeyboardInterrupt,
    # errors raised *inside* the extension) are no longer silently swallowed.
    pass

# Runtime configuration pulled from the environment.
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)  # "True" keeps the NSFW checker
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)  # truthy enables torch.compile
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # Hugging Face token (read, not used below)

# Device selection priority: CUDA > Intel XPU > CPU. Apple MPS is handled
# separately below because it needs float32 weights.
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
torch_dtype = torch.float16

print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"TORCH_COMPILE: {TORCH_COMPILE}")
print(f"device: {device}")

if mps_available:
    # On Apple Silicon, run the model weights on CPU in float32; fp16 on MPS
    # is unreliable for this pipeline.
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32

if SAFETY_CHECKER == "True":
    pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7")
else:
    pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", safety_checker=None)

# The LCM scheduler enables few-step (2-10) inference with the LCM LoRA fused below.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to(device=torch_device, dtype=torch_dtype).to(device)
pipe.unet.to(memory_format=torch.channels_last)
pipe.set_progress_bar_config(disable=True)

# On machines with less than 64GB of RAM, trade some speed for a smaller
# attention memory footprint.
if psutil.virtual_memory().total < 64 * 1024**3:
    pipe.enable_attention_slicing()

if TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
    # One warmup call so compilation cost is paid before the first user request.
    pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)

# Load and fuse the LCM LoRA so few-step generations produce usable images.
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipe.fuse_lora()

def safe_filename(text):
    """Generate a safe filename from a string."""
    safe_text = re.sub(r'\W+', '_', text)
    timestamp = datetime.datetime.now().strftime("%Y%m%d")
    return f"{safe_text}_{timestamp}.png"
    
def encode_image(image):
    """Serialize *image* as PNG and return it base64-encoded as an ASCII str.

    *image* must expose a PIL-style ``save(fileobj, format=...)`` method.
    """
    buffered = BytesIO()
    # Bug fix: this save call was commented out, so the function always
    # encoded an empty buffer and returned "".
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

def fake_gan():
    """Collect saved images from the working directory for the gallery.

    Returns:
        A list of ``(filename, caption)`` tuples where the caption is the
        file's basename without its extension.

    Bug fix: the original paired each caption with ``random.choice(img_files)``,
    so gallery labels did not match their pictures; each file is now paired
    with its own name.
    """
    base_dir = os.getcwd()
    img_files = [
        name
        for name in os.listdir(base_dir)
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
    return [(name, os.path.splitext(name)[0]) for name in img_files]
    
def predict(prompt, guidance, steps, seed=1231231):
    """Run one LCM-LoRA generation and return the resulting PIL image.

    Args:
        prompt: text prompt for the diffusion pipeline.
        guidance: classifier-free guidance scale (low values suit LCM).
        steps: number of inference steps (2-10 works with the fused LCM LoRA).
        seed: RNG seed so identical inputs reproduce identical images.

    Returns:
        The first generated PIL image, or ``None`` if nothing was produced.
    """
    generator = torch.manual_seed(seed)
    last_time = time.time()
    results = pipe(
        prompt=prompt,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=guidance,
        width=512,
        height=512,
        output_type="pil",
    )
    print(f"Pipe took {time.time() - last_time} seconds")

    # Safety-checker metadata is only present when SAFETY_CHECKER == "True".
    nsfw_content_detected = (
        results.nsfw_content_detected[0]
        if "nsfw_content_detected" in results
        else False
    )
    if nsfw_content_detected:
        # Bug fix: the original created an orphaned gr.Button here; components
        # built inside an event handler are never rendered, so log instead.
        print("NSFW content detected in generated image.")

    # Build a filesystem-safe filename: YYYYMMDD_<sanitized prompt>.png.
    # Bug fix: `central` was computed but unused; the timestamp now uses it.
    central = pytz.timezone('US/Central')
    safe_date_time = datetime.datetime.now(central).strftime("%Y%m%d")
    replaced_prompt = prompt.replace(" ", "_").replace("\n", "_")
    safe_prompt = "".join(x for x in replaced_prompt if x.isalnum() or x == "_")[:90]
    filename = f"{safe_date_time}_{safe_prompt}.png"

    # Save the image and build a data-URI download link.
    if len(results.images) > 0:
        image_path = os.path.join("", filename)  # saved in the working directory
        results.images[0].save(image_path)
        print(f"#Image saved as {image_path}")
        # Bug fix: the original called encode_image(image), passing the
        # module-level gr.Image *component* instead of the generated image.
        encoded_image = encode_image(results.images[0])
        # Bug fix: the download attribute now carries the real filename.
        html_link = f'<a href="data:image/png;base64,{encoded_image}" download="{filename}">Download Image</a>'
        # NOTE(review): a gr.Markdown created inside a handler is not rendered;
        # kept for parity with the original intent, but it has no UI effect.
        gr.Markdown(html_link)

    return results.images[0] if len(results.images) > 0 else None


# Page styling passed to gr.Blocks: center the main column (max 40rem wide)
# and center-align the intro header.
css = """
#container{
    margin: 0 auto;
    max-width: 40rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""
# ---- Gradio UI: layout and event wiring ---------------------------------
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """## 🕹️ Stable Diffusion 1.5 - Real Time 🎨 Image Generation Using 🌐 Latent Consistency LoRAs""",
            elem_id="intro",
        )
        # Prompt entry plus an explicit "Generate" button.
        with gr.Row():
            with gr.Row():
                prompt = gr.Textbox(
                    placeholder="Insert your prompt here:", scale=5, container=False
                )
                generate_bt = gr.Button("Generate", scale=1)

        # Image Result from last prompt
        image = gr.Image(type="filepath")

        # Gallery of Generated Images with Image Names in Random Set to Download
        # NOTE(review): `text` is never wired to an event below; only `btn` is.
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Image Sets",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            )
            btn = gr.Button("Generate Gallery of Saved Images")
        gallery = gr.Gallery(
            label="Generated Images", show_label=False, elem_id="gallery"
        )

        # Advanced Generate Options
        # Low guidance values and few steps suit the fused LCM LoRA.
        with gr.Accordion("Advanced options", open=False):
            guidance = gr.Slider(
                label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
            )
            steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1)
            seed = gr.Slider(
                randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
            )

        # Display-only usage snippet showing how to reproduce with diffusers.
        with gr.Accordion("Run with diffusers"):
            gr.Markdown(
                """## Running LCM-LoRAs it with `diffusers`
            ```bash
            pip install diffusers==0.23.0
            ```
            
            ```py
            from diffusers import DiffusionPipeline, LCMScheduler

            pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7").to("cuda") 
            pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
            pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5") #yes, it's a normal LoRA

            results = pipe(
                prompt="ImageEditor",
                num_inference_steps=4,
                guidance_scale=0.0,
            )
            results.images[0]
            ```
            """
            )

        # Function IO Eventing and Controls
        # Regenerate on button click, on every prompt keystroke, and whenever
        # any advanced slider changes; the gallery refreshes on its own button.
        inputs = [prompt, guidance, steps, seed]
        generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        btn.click(fake_gan, None, gallery)
        prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        guidance.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)

# Queueing serializes generation requests so concurrent users don't contend
# for the single pipeline instance.
demo.queue()
demo.launch()