import os
import gradio as gr
import re
import datetime
import random
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, DDIMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler
# Hugging Face token for the SDXL weights; read from the environment rather than
# hardcoding a secret in the source
access_token = os.environ.get("HF_TOKEN")
# output path
path = "./output"
# openai settings (currently unused in this app)
messages = [{
    'role': 'system',
    'content': 'You are a helpful assistant for organizing prompts for generating images'
}]
# mBart settings
article_kr = "μœ μ—”μ˜ λŒ€ν‘œλŠ” μ‹œλ¦¬μ•„μ— ꡰ사적인 해결책이 μ—†λ‹€κ³  λ§ν•©λ‹ˆλ‹€."  # example article: "The head of the UN says there is no military solution in Syria."
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
# mbart-large-50 needs the MBart50 tokenizer; plain MBartTokenizer only covers the 25-language model
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer.src_lang = "ko_KR"
def translate_mBart(article, language_code):
    # Map the UI language name to the mBART-50 language code
    if language_code == "Korean":
        language_code = "ko_KR"
    elif language_code == "Japanese":
        language_code = "ja_XX"
    elif language_code == "Chinese":
        language_code = "zh_CN"
    tokenizer.src_lang = language_code
    encoded = tokenizer(article, return_tensors="pt")
    # Force English as the generation target
    generated_tokens = model.generate(**encoded, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return result[0]
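# Quick sanity check (hypothetical usage, left commented out): translating the example
# article should yield roughly "The head of the UN says there is no military solution in Syria."
#print(translate_mBart(article_kr, "Korean"))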
# diffusers settings
# LMS scheduler (only referenced by the commented-out pipeline line below)
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refine_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
#pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, scheduler=lms, use_auth_token=access_token)
pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, use_auth_token=access_token)
#pipeline.load_lora_weights(".", weight_name="fashigirl-v6-sdxl-5ep-resize.safetensors")
#pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing")
# DPM++ multistep with Karras sigmas; the scheduler config key is timestep_spacing, not timestep_respacing
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True, timestep_spacing="linspace")
pipeline.to("cuda")
#pipeline.enable_model_cpu_offload()
refine = StableDiffusionXLImg2ImgPipeline.from_pretrained(refine_model_id, torch_dtype=torch.float16, use_safetensors=True, use_auth_token=access_token)
refine.to("cuda")
#refine.enable_model_cpu_offload()
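# The base pipeline produces the initial image; the refiner pipeline (used by img2img
# below) optionally runs a second denoising pass over that image to add detail.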
prompt = "1girl, solo, long hair, shirt, looking at viewer, white shirt, collared shirt, black eyes, smile, bow, black bow, closed mouth, portrait, brown hair, black hair, straight-on, bowtie, black bowtie, upper body, cloud, sky,huge breasts,shiny,shiny skin,milf,(mature female:1.2),<lora:fashigirl-v6-sdxl-5ep-resize:0.7>"
negative_prompt = "(low quality:1.3), (worst quality:1.3),(monochrome:0.8),(deformed:1.3),(malformed hands:1.4),(poorly drawn hands:1.4),(mutated fingers:1.4),(bad anatomy:1.3),(extra limbs:1.35),(poorly drawn face:1.4),(watermark:1.3),long neck,text,watermark,signature,logo"
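# NOTE: the <lora:...:0.7> tag in the prompt above is Automatic1111 syntax; diffusers
# ignores it unless the LoRA is actually loaded via load_lora_weights (commented out above).

# Module-level defaults; the Gradio UI below passes its own values into text2img,
# so these only matter if the functions are called directly.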
seed = random.randint(0, 999999)
generator = torch.manual_seed(seed)
num_inference_steps = 60
guidance_scale = 7
def text2img(language_code, prompt, negative_prompt, x, y, isRandom, fixedRandom, num_inference_steps, guidance_scale, use_refine):
    # Use a fresh random seed, or the fixed seed chosen in the UI
    if isRandom:
        seed = random.randint(0, 999999)
    else:
        seed = int(fixedRandom)
    generator = torch.manual_seed(seed)
    # Translate Korean/Japanese/Chinese prompts to English; English passes through unchanged
    allPrompt = [prompt, negative_prompt]
    if language_code != "English":
        allPrompt[0] = translate_mBart(prompt, language_code)
        if negative_prompt != "":
            allPrompt[1] = translate_mBart(negative_prompt, language_code)
    print("prompt : " + allPrompt[0])
    print("negative prompt : " + allPrompt[1])
    _prompt = allPrompt[0]
    _negative_prompt = allPrompt[1]
    image = pipeline(
        prompt=_prompt, negative_prompt=_negative_prompt, width=int(x), height=int(y),
        num_inference_steps=int(num_inference_steps), generator=generator, guidance_scale=float(guidance_scale)
    ).images[0]
    _seed = str(seed)
    #_prompt = re.sub(r"[^\uAC00-\uD7A30-9a-zA-Z\s]", "", _prompt)
    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    #image.save("./output/" + "sdxl_base_" + "_seed_" + _seed + "_time_" + timestamp + ".png")
    print(seed)
    _info = "prompt : " + _prompt + " / negative prompt : " + _negative_prompt + " / seed : " + _seed + " / refine : " + str(use_refine) + " / time : " + timestamp
    # Optionally run the refiner pass on the base image
    if use_refine:
        image = img2img(prompt=_prompt, negative_prompt=_negative_prompt, image=image)
    return image, _info
def img2img(prompt, negative_prompt, image):
    # Run the SDXL refiner on the base image (img2img)
    image = refine(prompt=prompt, negative_prompt=negative_prompt, image=image).images[0]
    #timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    #image.save("./output/" + "sdxl_refine_" + "_seed_" + timestamp + ".png")
    return image
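# The refiner call above relies on the pipeline's default strength; passing an explicit
# strength= (e.g. strength=0.3) would control how strongly the refiner reworks the base image.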
demo = gr.Interface(
    fn=text2img,
    inputs=[gr.Radio(["Korean", "Japanese", "Chinese", "English"], label="language", value="Korean"),
            gr.Text(label="prompt", value=""),
            gr.Text(label="negative prompt", value=""),
            gr.Slider(128, 2048, value=1024, label="width"),
            gr.Slider(128, 2048, value=1024, label="height"),
            gr.Checkbox(label="auto random seed", value=True),
            gr.Slider(1, 999999, label="fixed random seed"),
            gr.Slider(1, 50, value=20, label="inference steps"),
            gr.Slider(1, 20, value=7, label="guidance scale"),
            # label must be a keyword argument; gr.Checkbox(["refine"]) would set the value, not the label
            gr.Checkbox(label="refine")
            ],
    outputs=[gr.Image(label="generated image"),
             gr.Textbox(label="properties")],
    title="Multilingual SDXL",
    #description="SDXL with Korean prompts",
    #article="test"
)
demo.launch()