"""Multilingual SDXL text-to-image demo.

Prompts written in Korean, Japanese or Chinese are translated to English with
mBART-50 and then rendered with the Stable Diffusion XL base pipeline,
optionally followed by a pass through the SDXL refiner. A Gradio interface
exposes the controls.
"""
import datetime
import os
import random
import re

import gradio as gr
import torch
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLPipeline,
)
from transformers import (
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
    MBartTokenizer,
)

# Hugging Face access token, read from the environment so no credential is
# committed to source control. None is acceptable for public checkpoints.
# NOTE: a token was previously hard-coded here in plain text; revoke it.
access_token = os.environ.get("HF_TOKEN")

# ---------------------------------------------------------------------------
# mBART-50 translation
# ---------------------------------------------------------------------------
# mBART-50 checkpoints must be paired with the mBART-50 tokenizer: the older
# MBartTokenizer (mBART-25) appends the source-language code as a suffix,
# whereas mBART-50 expects it as a prefix.
MBART_MODEL_ID = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(MBART_MODEL_ID)
tokenizer = MBart50TokenizerFast.from_pretrained(MBART_MODEL_ID)
tokenizer.src_lang = "ko_KR"

# UI language name -> mBART-50 source-language code.
LANGUAGE_CODES = {
    "Korean": "ko_KR",
    "Japanese": "ja_XX",
    "Chinese": "zh_CN",
}


def translate_mBart(article_kr, lnaguage_code):
    """Translate *article_kr* into English with mBART-50.

    Args:
        article_kr: Source text (despite the name, any supported language).
        lnaguage_code: A UI language name ("Korean", "Japanese", "Chinese")
            or a raw mBART-50 language code such as "ko_KR". Names not in
            LANGUAGE_CODES are passed through unchanged as a language code.

    Returns:
        The English translation, or "" when *article_kr* is blank.
    """
    if not article_kr.strip():
        # Nothing to translate; skip the model call entirely.
        return ""
    # NOTE(review): this mutates the shared module-level tokenizer, so
    # concurrent requests in different languages could interfere.
    tokenizer.src_lang = LANGUAGE_CODES.get(lnaguage_code, lnaguage_code)
    encoded = tokenizer(article_kr, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded,
        # Force the decoder to begin with the English language token so the
        # output language is English.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("en_XX"),
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]


# ---------------------------------------------------------------------------
# SDXL pipelines
# ---------------------------------------------------------------------------
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refine_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"

pipeline = StableDiffusionXLPipeline.from_pretrained(
    base_model_id, torch_dtype=torch.float16, use_auth_token=access_token
)
# DPM-Solver++ with Karras sigmas. The scheduler option is `timestep_spacing`;
# the previously used `timestep_respacing` is not a DPMSolverMultistepScheduler
# parameter and therefore had no effect.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config, use_karras_sigmas=True, timestep_spacing="linspace"
)
pipeline.to("cuda")

# NOTE: this module-level name is the refiner pipeline. text2img's boolean
# parameter `refine` shadows it only inside text2img; img2img is a separate
# function, so its `refine` resolves to this pipeline.
refine = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    refine_model_id,
    torch_dtype=torch.float16,
    use_safetensors=True,
    use_auth_token=access_token,
)
refine.to("cuda")


def text2img(language_code, prompt, negative_prompt, x, y, isRandom, fixedRandom,
             num_inference_steps, guidance_scale, refine):
    """Generate an image from a (possibly non-English) prompt with SDXL.

    Args:
        language_code: Language of the prompts ("Korean", "Japanese",
            "Chinese" or "English"); non-English prompts are translated.
        prompt: Positive prompt.
        negative_prompt: Negative prompt; may be empty.
        x, y: Requested width and height in pixels. Both are rounded down to a
            multiple of 8 (minimum 8), because SDXL rejects other sizes.
        isRandom: When true, draw a fresh random seed in [0, 999999];
            otherwise use *fixedRandom*.
        fixedRandom: Seed used when *isRandom* is false.
        num_inference_steps: Number of denoising steps for the base pass.
        guidance_scale: Classifier-free guidance scale; fractional values
            are preserved.
        refine: When true, run the base image through the SDXL refiner.

    Returns:
        (image, info): the generated image and a text summary of the prompts,
        seed, refine flag and timestamp.
    """
    seed = random.randint(0, 999999) if isRandom else int(fixedRandom)
    generator = torch.manual_seed(seed)

    if language_code != "English":
        _prompt = translate_mBart(prompt, language_code)
        _negative_prompt = (
            translate_mBart(negative_prompt, language_code) if negative_prompt else ""
        )
    else:
        _prompt = prompt
        _negative_prompt = negative_prompt

    print("prompt : " + _prompt)
    print("negative prompt : " + _negative_prompt)

    # SDXL requires height and width to be divisible by 8.
    width = max(8, int(x) // 8 * 8)
    height = max(8, int(y) // 8 * 8)

    image = pipeline(
        prompt=_prompt,
        negative_prompt=_negative_prompt,
        width=width,
        height=height,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        guidance_scale=float(guidance_scale),
    ).images[0]

    _seed = str(seed)
    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    print(seed)
    _info = (
        f"prompt : {_prompt} / negative prompt : {_negative_prompt}"
        f" / seed : {_seed} / refine : {refine} / time : {timestamp}"
    )

    if refine:
        image = img2img(prompt=_prompt, negative_prompt=_negative_prompt, image=image)
    return image, _info


def img2img(prompt, negative_prompt, image):
    """Run *image* through the SDXL refiner with the same prompts.

    Returns:
        The first image produced by the module-level refiner pipeline.
    """
    return refine(prompt=prompt, negative_prompt=negative_prompt, image=image).images[0]


# ---------------------------------------------------------------------------
# Gradio UI ("다국어 SDXL" = "Multilingual SDXL")
# ---------------------------------------------------------------------------
demo = gr.Interface(
    fn=text2img,
    inputs=[
        gr.Radio(["Korean", "Japanese", "Chinese", "English"], label="language", value="Korean"),
        gr.Text(label="prompt", value=""),
        gr.Text(label="negative prompt", value=""),
        # step=8 keeps the dimensions valid for SDXL.
        gr.Slider(128, 2048, value=1024, step=8, label="width"),
        gr.Slider(128, 2048, value=1024, step=8, label="height"),
        gr.Checkbox(label="auto random seed", value=True),
        gr.Slider(1, 999999, step=1, label="fixed random seed"),
        gr.Slider(1, 50, value=20, step=1, label="inference steps"),
        gr.Slider(1, 20, value=7, step=0.5, label="guidance scale"),
        # Checkbox's first positional argument is its value, so the label must
        # be passed by keyword.
        gr.Checkbox(label="refine", value=False),
    ],
    outputs=[gr.Image(label="generated image"), gr.Textbox(label="properties")],
    title="다국어 SDXL",
)

if __name__ == "__main__":
    demo.launch()