import gradio as gr
import os

hf_token = os.environ.get("HF_TOKEN")

import spaces
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler, AutoencoderKL
import torch
import time


class Dummy:
    pass


resolutions = ["1024 1024", "1280 768", "1344 768", "768 1344", "768 1280"]

# Load pipeline: fp16-safe SDXL VAE + BRIA 2.2 FAST UNet inside the BRIA 2.2 pipeline
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained("briaai/BRIA-2.2-FAST", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained("briaai/BRIA-2.2", torch_dtype=torch.float16, unet=unet, vae=vae)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to('cuda')
del unet
del vae

pipe.force_zeros_for_empty_prompt = False

print("Optimizing BRIA 2.2 FAST - this could take a while")
t = time.time()
pipe.unet = torch.compile(
    pipe.unet, mode="reduce-overhead", fullgraph=True  # ~600 secs compilation
)

# Warm-up pass to trigger compilation
with torch.no_grad():
    outputs = pipe(
        prompt="an apple",
        num_inference_steps=8,
    )

# This will avoid future compilations on different shapes
unet_compiled = torch._dynamo.run(pipe.unet)
unet_compiled.config = pipe.unet.config
# Re-expose attributes the pipeline reads from the UNet, which the compiled wrapper hides
unet_compiled.add_embedding = Dummy()
unet_compiled.add_embedding.linear_1 = Dummy()
unet_compiled.add_embedding.linear_1.in_features = pipe.unet.add_embedding.linear_1.in_features
pipe.unet = unet_compiled
print(f"Optimizing finished successfully after {time.time() - t} secs")


@spaces.GPU(enable_queue=True)
def infer(prompt, seed, resolution):
    print(f"---\n{prompt}")

    # generator = torch.Generator("cuda").manual_seed(555)
    t = time.time()

    if seed == "-1":
        generator = None
    else:
        try:
            seed = int(seed)
            generator = torch.Generator("cuda").manual_seed(seed)
        except (TypeError, ValueError):
            generator = None

    w, h = resolution.split()
    w, h = int(w), int(h)
    image = pipe(prompt, num_inference_steps=8, generator=generator, width=w, height=h).images[0]
    print(f'gen time is {time.time() - t} secs')

    # Future:
    # Add number of steps as a parameter
    # if nsfw:
    #     raise gr.Error("Generated image is NSFW")

    return image


css = """
#col-container {
    margin: 0 auto;
    max-width: 580px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## BRIA 2.2 FAST")
        gr.HTML('''
        <p>
          This is a demo for BRIA 2.2 FAST, a fast version of the BRIA 2.2 text-to-image model. It is still trained on licensed data, and so provides full legal liability coverage for copyright and privacy infringement. Try it for free in our webapp demo. Are you a startup or a student? We encourage you to apply for our Startup Plan.
        </p>