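# Gradio demo: translate a Korean/Japanese/Chinese prompt into English with mBART-50,
# then generate an image with Stable Diffusion XL (base model plus optional refiner).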
import gradio as gr
import re
import datetime
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, MBartTokenizer
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, DDIMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler
import torch
import random
# token for SDXL
access_token="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"
# output path
path = "./output"
# OpenAI-style chat messages (currently unused in this demo)
messages = [{
    'role': 'system',
    'content': 'You are a helpful assistant for organizing prompts for generating images'
}]
# mBart settings
article_kr = "유엔의 대표는 시리아에 군사적인 해결책이 없다고 말합니다."  # example article: "The head of the UN says there is no military solution in Syria."
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer.src_lang = "ko_KR"
def translate_mBart(article_kr, language_code):
    # map the UI language name to the corresponding mBART-50 language code
    if language_code == "Korean":
        language_code = "ko_KR"
    elif language_code == "Japanese":
        language_code = "ja_XX"
    elif language_code == "Chinese":
        language_code = "zh_CN"
    tokenizer.src_lang = language_code
    encoded_ar = tokenizer(article_kr, return_tensors="pt")
    generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return result[0]
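# e.g. translate_mBart(article_kr, "Korean") returns the English translation as a single string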
# diffusers settings
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)  # defined but unused: the pipeline's scheduler is replaced with DPMSolverMultistepScheduler below
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refine_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
#pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, scheduler=lms ,use_auth_token=access_token)
pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16,use_auth_token=access_token)
#pipeline.load_lora_weights(".", weight_name="fashigirl-v6-sdxl-5ep-resize.safetensors")
#pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True, timestep_spacing="linspace")
pipeline.to("cuda")
#pipeline.enable_model_cpu_offload()
refine = StableDiffusionXLImg2ImgPipeline.from_pretrained(refine_model_id, torch_dtype=torch.float16, use_safetensors=True, use_auth_token=access_token)
refine.to("cuda")
#refine.enable_model_cpu_offload()
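# Both pipelines are kept resident on the GPU; the commented enable_model_cpu_offload() calls above
# are the lower-VRAM alternative (components stay on the CPU and are moved to the GPU only when needed).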
prompt = "1girl, solo, long hair, shirt, looking at viewer, white shirt, collared shirt, black eyes, smile, bow, black bow, closed mouth, portrait, brown hair, black hair, straight-on, bowtie, black bowtie, upper body, cloud, sky,huge breasts,shiny,shiny skin,milf,(mature female:1.2),<lora:fashigirl-v6-sdxl-5ep-resize:0.7>"
negative_prompt = "(low quality:1.3), (worst quality:1.3),(monochrome:0.8),(deformed:1.3),(malformed hands:1.4),(poorly drawn hands:1.4),(mutated fingers:1.4),(bad anatomy:1.3),(extra limbs:1.35),(poorly drawn face:1.4),(watermark:1.3),long neck,text,watermark,signature,logo"
seed = random.randint(0, 999999)
generator = torch.manual_seed(seed)
num_inference_steps = 60
guidance_scale = 7
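# text2img: translate the prompts to English when the selected language is not English,
# run the SDXL base pipeline, and optionally pass the result through the refiner.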
def text2img(language_code, prompt, negative_prompt, x, y, isRandom, fixedRandom, num_inference_steps, guidance_scale, use_refiner):
    # pick a fresh random seed or reuse the fixed one from the UI
    if isRandom:
        seed = random.randint(0, 999999)
    else:
        seed = int(fixedRandom)
    generator = torch.manual_seed(seed)
    # Check Korean prompts (older alphabet-detection approach, kept for reference)
    allPrompt = ["", ""]
    '''
    if prompt.upper() != prompt.lower():
        print("prompt is an alphabet")
        allPrompt[0] = prompt
    else:
        print("prompt is not an alphabet")
        allPrompt[0] = translate_mBart(prompt)
    if negative_prompt != "":
        if negative_prompt.upper() != negative_prompt.lower():
            print("negative prompt is an alphabet")
            allPrompt[1] = negative_prompt
        else:
            print("negative prompt is not an alphabet")
            allPrompt[1] = translate_mBart(negative_prompt)
    else:
        negative_prompt = ""
    '''
    # translate prompts to English unless the UI language is already English
    allPrompt = [prompt, negative_prompt]
    if language_code != "English":
        allPrompt[0] = translate_mBart(prompt, language_code)
    else:
        allPrompt[0] = prompt
    if negative_prompt != "":
        if language_code != "English":
            allPrompt[1] = translate_mBart(negative_prompt, language_code)
        else:
            allPrompt[1] = negative_prompt
    else:
        allPrompt[1] = ""
    print("prompts length : " + str(len(allPrompt)))
    print("prompt : " + allPrompt[0])
    print("negative prompt : " + allPrompt[1])
    _prompt = allPrompt[0]
    _negative_prompt = allPrompt[1]
    #_negative_prompt = translate(negative_prompt)
    image = pipeline(
        prompt=_prompt, negative_prompt=_negative_prompt, width=int(x), height=int(y),
        num_inference_steps=int(num_inference_steps), generator=generator, guidance_scale=float(guidance_scale)
    ).images[0]
    _seed = str(seed)
    #_prompt = re.sub(r"[^\uAC00-\uD7A30-9a-zA-Z\s]", "", _prompt)
    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    #image.save("./output/" + "sdxl_base_" + "_seed_" + _seed + "_time_" + timestamp + ".png")
    #image.save("sdxl_prompt_" + "_seed_" + _seed + ".png")
    print(seed)
    _info = "prompt : " + _prompt + " / negative prompt : " + _negative_prompt + " / seed : " + _seed + " / refine : " + str(use_refiner) + " / time : " + timestamp
    if use_refiner:
        image = img2img(prompt=_prompt, negative_prompt=_negative_prompt, image=image)
        return image, _info
    return image, _info
def img2img(prompt, negative_prompt, image):
    # run the SDXL refiner on the image produced by the base pipeline
    image = refine(prompt=prompt, negative_prompt=negative_prompt, image=image).images[0]
    #timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    #image.save("./output/" + "sdxl_refine_" + "_seed_" + timestamp + ".png")
    return image
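# Gradio interface: the inputs below are passed to text2img positionally, in this order.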
demo = gr.Interface(
    fn=text2img,
    inputs=[
        gr.Radio(["Korean", "Japanese", "Chinese", "English"], label="language", value="Korean"),
        gr.Text(label="prompt", value=""),
        gr.Text(label="negative prompt", value=""),
        gr.Slider(128, 2048, value=1024, label="width"),
        gr.Slider(128, 2048, value=1024, label="height"),
        gr.Checkbox(label="auto random seed", value=True),
        gr.Slider(1, 999999, label="fixed random seed"),
        gr.Slider(1, 50, value=20, label="inference steps"),
        gr.Slider(1, 20, value=7, label="guidance scale"),
        gr.Checkbox(label="refine"),
    ],
    outputs=[
        gr.Image(label="generated image"),
        gr.Textbox(label="properties"),
    ],
    title="다국어 SDXL",  # "Multilingual SDXL"
    #description="한글로 하는 SDXL",  # "SDXL in Korean"
    #article="test"
)
demo.launch()