"""Korean-prompt SDXL demo: translate Korean prompts to English with mBART-50,
then generate images with Stable Diffusion XL (base + optional refiner)."""

import datetime
import os
import random
import re

import gradio as gr
import torch
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLPipeline,
)
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

# SECURITY: the original source hard-coded a real Hugging Face access token
# here. A committed credential is compromised and must be revoked. Supply the
# token via the environment instead:  export HF_TOKEN=hf_...
access_token = os.environ.get("HF_TOKEN")

# Directory where generated images are written.
path = "./output"

# OpenAI-style chat seed message (currently unused by the rest of the script;
# kept so any external code reading `messages` keeps working).
messages = [{
    'role': 'system',
    'content': 'You are a helpful assistant for organizing prompt for generating images'
}]

# --- mBART-50 many-to-many translation (Korean -> English) -----------------
article_kr = "유엔의 대표는 시리아에 군사적인 해결책이 없다고 말합니다."  # example article
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer.src_lang = "ko_KR"


def translate_mBart(article_kr):
    """Translate Korean text to English with mBART-50.

    Parameters
    ----------
    article_kr : str
        Korean source text.

    Returns
    -------
    str
        English translation (first decoded sequence).
    """
    encoded_ar = tokenizer(article_kr, return_tensors="pt")
    # Force the decoder to start generating in English.
    generated_tokens = model.generate(
        **encoded_ar,
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
    )
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return result[0]


# --- Diffusers / SDXL setup ------------------------------------------------
# LMS scheduler built but not attached below; kept because module-level names
# may be referenced elsewhere.
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refine_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"

pipeline = StableDiffusionXLPipeline.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    use_auth_token=access_token,
)
# DPM-Solver++ with Karras sigmas: good quality at moderate step counts.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config,
    use_karras_sigmas=True,
    timestep_respacing="linspace",
)
pipeline.to("cuda")

refine = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    refine_model_id,
    torch_dtype=torch.float16,
    use_safetensors=True,
    use_auth_token=access_token,
)
refine.to("cuda")

# Default prompts pre-filled in the UI (kept verbatim from the original).
prompt = "1girl, solo, long hair, shirt, looking at viewer, white shirt, collared shirt, black eyes, smile, bow, black bow, closed mouth, portrait, brown hair, black hair, straight-on, bowtie, black bowtie, upper body, cloud, sky,huge breasts,shiny,shiny skin,milf,(mature female:1.2),"
negative_prompt = "(low quality:1.3), (worst quality:1.3),(monochrome:0.8),(deformed:1.3),(malformed hands:1.4),(poorly drawn hands:1.4),(mutated fingers:1.4),(bad anatomy:1.3),(extra limbs:1.35),(poorly drawn face:1.4),(watermark:1.3),long neck,text,watermark,signature,logo"

# Module-level generation defaults.
seed = random.randint(0, 999999)
generator = torch.manual_seed(seed)
num_inference_steps = 60
guidance_scale = 7


def text2img(prompt, negative_prompt, x, y, isRandom, fixedRandom,
             num_inference_steps, guidance_scale, refine):
    """Generate one SDXL image from Korean prompts.

    Parameters
    ----------
    prompt, negative_prompt : str
        Korean text; translated to English via ``translate_mBart``.
    x, y : number
        Image width / height in pixels.
    isRandom : bool
        If True, draw a fresh random seed; otherwise use ``fixedRandom``.
    fixedRandom : number
        Seed used when ``isRandom`` is False.
    num_inference_steps, guidance_scale : number
        Diffusion steps and CFG scale.
    refine : bool
        If True, additionally run the SDXL refiner (note: this parameter
        shadows the module-level refiner pipeline of the same name; the
        pipeline is still reachable from ``img2img``).

    Returns
    -------
    PIL.Image.Image
        The generated (optionally refined) image; also saved to ./output.
    """
    if isRandom:
        seed = random.randint(0, 999999)
    else:
        seed = int(fixedRandom)
    generator = torch.manual_seed(seed)

    # Translate the two prompts separately: translating them joined with "/"
    # (the earlier approach) was unreliable when the separator got dropped.
    _prompt = translate_mBart(prompt)
    _negative_prompt = translate_mBart(negative_prompt)
    print("prompt : " + _prompt)
    print("negative prompt : " + _negative_prompt)

    # Heuristic "did translation produce cased (Latin) text?" check: a string
    # with cased letters satisfies upper() != lower(). If not, fall back to a
    # known-good English prompt.
    if _prompt.upper() != _prompt.lower():
        print(" it is an alphabet")
    else:
        print(" it is not an alphabet")
        _prompt = "traffic sign of stop says SDXL"

    image = pipeline(
        prompt=_prompt,
        negative_prompt=_negative_prompt,
        width=int(x),
        height=int(y),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        # BUG FIX: was int(guidance_scale), which silently truncated
        # fractional CFG values such as 7.5 down to 7.
        guidance_scale=float(guidance_scale),
    ).images[0]

    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    image.save("./output/" + "sdxl_base_" + "_seed_" + str(seed)
               + "_time_" + timestamp + ".png")
    print(seed)

    if refine:
        # BUG FIX: the refiner previously received a filename-sanitised prompt
        # (re.sub had stripped all parentheses/weights); pass the actual
        # translated prompts so base and refiner condition on the same text.
        image = img2img(prompt=_prompt, negative_prompt=_negative_prompt, image=image)
    return image


def img2img(prompt, negative_prompt, image):
    """Run the SDXL refiner over *image*, save and return the result."""
    image = refine(prompt=prompt, negative_prompt=negative_prompt, image=image).images[0]
    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    # BUG FIX: the filename previously labelled the timestamp as "_seed_".
    image.save("./output/" + "sdxl_refine_" + "_time_" + timestamp + ".png")
    return image


demo = gr.Interface(
    fn=text2img,
    inputs=[
        "text",                      # prompt (Korean)
        "text",                      # negative prompt (Korean)
        gr.Slider(0, 2048),          # width
        gr.Slider(0, 2048),          # height
        # BUG FIX: gr.Checkbox takes a label keyword, not a list positional
        # (the list was being treated as the checkbox *value*).
        gr.Checkbox(label="random"),
        "number",                    # fixed seed
        "number",                    # num_inference_steps
        "number",                    # guidance_scale
        gr.Checkbox(label="refine"),
    ],
    outputs=["image"],
    title="한글로 하는 SDXL",
)
demo.launch(share=True, debug=True)