import gradio as gr
import os
hf_token = os.environ.get("HF_TOKEN")
import spaces
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler, AutoencoderKL
import torch
import time

class Dummy:
    """Empty attribute container.

    Instances carry no behavior of their own; the setup code below hangs
    arbitrary attributes on them (``add_embedding`` and ``linear_1``) to
    mimic a small part of the UNet's attribute tree.
    """

# Selectable output sizes. Each entry is "<width> <height>"; infer() splits
# the string on whitespace and passes the two integers to the pipeline.
resolutions = [
    "1024 1024",
    "1280 768",
    "1344 768",
    "768 1344",
    "768 1280",
]

# Load pipeline
# The VAE and the FAST UNet are loaded separately in fp16 and injected into the
# base BRIA 2.2 pipeline in place of its default components.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained("briaai/BRIA-2.2-FAST", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained("briaai/BRIA-2.2", torch_dtype=torch.float16, unet=unet, vae=vae)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
# NOTE(review): this statement was fused into the previous line as a stray
# `'cuda')` fragment, which is a syntax error. Restored as a device move;
# confirm against the upstream file.
pipe.to('cuda')
# The pipeline keeps its own references; drop the module-level names.
del unet
del vae
pipe.force_zeros_for_empty_prompt = False

print("Optimizing BRIA 2.2 FAST - this could take a while")
t = time.time()
pipe.unet = torch.compile(
    pipe.unet, mode="reduce-overhead", fullgraph=True  # 600 secs compilation
)
# torch.compile is lazy: run one throwaway generation so the compilation cost
# is paid here at startup rather than on the first user request.
with torch.no_grad():
    outputs = pipe(
        prompt="an apple",
        num_inference_steps=8,
    )

# This will avoid future compilations on different shapes
# NOTE(review): the right-hand side of this assignment was missing (syntax
# error). torch._dynamo.run() reuses already-compiled code without compiling
# anything new, which matches the comment above; confirm against upstream.
unet_compiled = torch._dynamo.run(pipe.unet)
# Copy onto the run-only wrapper the attributes that are read from the UNet:
# its config and add_embedding.linear_1.in_features (presumably accessed by the
# pipeline at call time -- verify against the diffusers pipeline in use).
unet_compiled.config = pipe.unet.config
unet_compiled.add_embedding = Dummy()
unet_compiled.add_embedding.linear_1 = Dummy()
unet_compiled.add_embedding.linear_1.in_features = pipe.unet.add_embedding.linear_1.in_features
pipe.unet = unet_compiled
print(f"Optimizing finished successfully after {time.time()-t} secs")

@spaces.GPU(enable_queue=True)
def infer(prompt, seed, resolution):
    """Generate one image for ``prompt`` using the module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for reproducible sampling. ``-1`` in any spelling (``"-1"``,
            ``-1``) requests a random seed. A value that cannot be parsed as
            an integer, or that the CUDA generator rejects, also falls back to
            a random seed rather than failing the request.
        resolution: Output size as a ``"<width> <height>"`` string.

    Returns:
        The first entry of the pipeline result's ``images``.

    Raises:
        ValueError: If ``resolution`` is not exactly two whitespace-separated
            integers.
    """
    print(f"""
        —/n {prompt}
    """)
    started = time.time()

    # Parse first so that every spelling of -1 selects the random-seed path.
    try:
        seed_value = int(seed)
    except (TypeError, ValueError):
        seed_value = -1

    generator = None
    if seed_value != -1:
        try:
            generator = torch.Generator("cuda").manual_seed(seed_value)
        except (OverflowError, RuntimeError) as err:
            # Best effort: an unusable seed should not abort the generation.
            print(f"Ignoring unusable seed {seed!r}: {err}")

    width, height = (int(part) for part in resolution.split())
    image = pipe(
        prompt,
        num_inference_steps=8,
        generator=generator,
        width=width,
        height=height,
    ).images[0]
    print(f'gen time is {time.time()-started} secs')
    # TODO: expose the number of inference steps as a parameter.
    # TODO: raise gr.Error("Generated image is NSFW") once an NSFW check exists.
    return image

# Custom stylesheet passed to gr.Blocks(css=...) below: horizontally centers
# the element with id "col-container" and caps its width at 580px.
css = """
#col-container{
    margin: 0 auto;
    max-width: 580px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## BRIA 2.2 FAST")

