# Hugging Face Space: multilingual SDXL text-to-image demo
# (page-header residue "Spaces / Runtime error" converted to comments — it was
#  scrape output, not code, and broke the Python syntax.)
import gradio as gr | |
import re | |
import datetime | |
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, MBartTokenizer | |
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, DDIMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler | |
import torch | |
import random | |
# --- Hugging Face auth token for SDXL ---
# SECURITY: a real token is hard-coded and therefore leaked to anyone who can
# read this file. Move it to an environment variable / Space secret and revoke
# this one.
access_token = "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"

# Output directory for generated images (the actual save calls are commented
# out further below, so nothing is written here at the moment).
path = "./output"

# OpenAI-style chat seed message (unused by the visible code; kept for
# backward compatibility).
messages = [{
    'role': 'system',
    'content': 'You are a helpful assistant for organizing prompt for generating images'
}]

# --- mBART-50 translation settings ---
article_kr = "μ μμ λνλ μ리μμ κ΅°μ¬μ μΈ ν΄κ²°μ± μ΄ μλ€κ³ λ§ν©λλ€."  # example article (Korean; mojibake in source)
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
# BUG FIX: mbart-large-50 checkpoints require the MBart-50 tokenizer. The plain
# MBartTokenizer used before is the 25-language (cc25) tokenizer and does not
# provide the 50-language codes that translate_mBart() looks up via
# `tokenizer.lang_code_to_id` — loading it here caused runtime failures.
# MBart50TokenizerFast is already imported at the top of the file.
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer.src_lang = "ko_KR"  # default source language; overwritten per call in translate_mBart()
def translate_mBart(article_kr, lnaguage_code):
    """Translate text into English with the module-level mBART-50 model.

    Args:
        article_kr: source-language text to translate.
        lnaguage_code: UI language name ("Korean", "Japanese", "Chinese",
            "English") or a raw mBART language code, which is passed through
            unchanged.  (The parameter name keeps its original typo so any
            keyword callers are not broken.)

    Returns:
        The English translation (first decoded sequence) as a string.
    """
    # Map UI language names to mBART-50 language codes; anything not in the
    # table is assumed to already be a valid mBART code (same fall-through
    # behavior as the original if/elif chain).  "English" is new and maps to
    # "en_XX" so an English->English call no longer crashes on a bad src_lang.
    _LANG_CODES = {
        "Korean": "ko_KR",
        "Japanese": "ja_XX",
        "Chinese": "zh_CN",
        "English": "en_XX",
    }
    tokenizer.src_lang = _LANG_CODES.get(lnaguage_code, lnaguage_code)
    encoded_ar = tokenizer(article_kr, return_tensors="pt")
    # Force English as the target language for generation.
    generated_tokens = model.generate(
        **encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"]
    )
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return result[0]
# --- diffusers / SDXL pipeline settings ---
# NOTE(review): this scheduler is constructed but never attached to a pipeline
# (the line that used it is commented out below); dead object kept as-is.
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refine_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
#pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, scheduler=lms ,use_auth_token=access_token)
# SDXL base pipeline in fp16; downloads weights on first run.
pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16,use_auth_token=access_token)
#pipeline.load_lora_weights(".", weight_name="fashigirl-v6-sdxl-5ep-resize.safetensors")
#pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config, rescale_beta_zero_snr=True, timestep_respacing="training")
# Swap in a DPM-Solver++ multistep scheduler with Karras sigmas.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True, timestep_respacing="linspace")
pipeline.to("cuda")  # requires a CUDA GPU; see cpu-offload alternative below
#pipeline.enable_model_cpu_offload()
# SDXL refiner pipeline (img2img), used when the "refine" checkbox is on.
refine = StableDiffusionXLImg2ImgPipeline.from_pretrained(refine_model_id, torch_dtype=torch.float16, use_safetensors=True, use_auth_token=access_token)
refine.to("cuda")
#refine.enable_model_cpu_offload()
# Default prompt / sampling parameters.  NOTE(review): these module-level
# values are shadowed by the parameters of text2img() below, so they only
# serve as documentation/examples — the UI sliders provide the real values.
prompt = "1girl, solo, long hair, shirt, looking at viewer, white shirt, collared shirt, black eyes, smile, bow, black bow, closed mouth, portrait, brown hair, black hair, straight-on, bowtie, black bowtie, upper body, cloud, sky,huge breasts,shiny,shiny skin,milf,(mature female:1.2),<lora:fashigirl-v6-sdxl-5ep-resize:0.7>"
negative_prompt = "(low quality:1.3), (worst quality:1.3),(monochrome:0.8),(deformed:1.3),(malformed hands:1.4),(poorly drawn hands:1.4),(mutated fingers:1.4),(bad anatomy:1.3),(extra limbs:1.35),(poorly drawn face:1.4),(watermark:1.3),long neck,text,watermark,signature,logo"
seed = random.randint(0, 999999)
generator = torch.manual_seed(seed)
num_inference_steps = 60
guidance_scale = 7
def text2img(language_code, prompt, negative_prompt, x, y, isRandom, fixedRandom, num_inference_steps, guidance_scale, refine):
    """Generate an SDXL image from a (possibly non-English) prompt.

    Args:
        language_code: "Korean" | "Japanese" | "Chinese" | "English".  Prompts
            in a non-English language are machine-translated to English first.
        prompt: positive prompt text.
        negative_prompt: negative prompt text; may be empty.
        x, y: output width and height in pixels (cast to int).
        isRandom: when truthy, draw a fresh random seed; otherwise use
            `fixedRandom`.
        fixedRandom: seed value used when isRandom is falsy.
        num_inference_steps: diffusion sampling steps (cast to int).
        guidance_scale: classifier-free guidance scale (cast to int — note
            this truncates fractional slider values).
        refine: when truthy, run the base output through the SDXL refiner.

    Returns:
        (image, info) — the generated PIL image and a human-readable string
        of the generation parameters.
    """
    # Seed handling: explicit seed so the run can be reproduced from `info`.
    seed = random.randint(0, 999999) if isRandom else int(fixedRandom)
    generator = torch.manual_seed(seed)

    # Translate prompts to English unless the UI says they already are.
    if language_code != "English":
        _prompt = translate_mBart(prompt, language_code)
    else:
        _prompt = prompt
    if negative_prompt != "" and language_code != "English":
        _negative_prompt = translate_mBart(negative_prompt, language_code)
    else:
        _negative_prompt = negative_prompt
    print("prompt : " + _prompt)
    print("negative prompt : " + _negative_prompt)

    image = pipeline(
        prompt=_prompt,
        negative_prompt=_negative_prompt,
        width=int(x),
        height=int(y),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        guidance_scale=int(guidance_scale),
    ).images[0]

    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    print(seed)
    _info = ("prompt : " + _prompt + " / negative prompt : " + _negative_prompt
             + " / seed : " + str(seed) + " / refine : " + str(refine)
             + " / time : " + timestamp)
    # Optional second pass through the refiner pipeline.
    if refine:
        image = img2img(prompt=_prompt, negative_prompt=_negative_prompt, image=image)
    return image, _info
def img2img(prompt, negative_prompt, image):
    """Refine `image` with the module-level SDXL refiner pipeline.

    The refiner is conditioned on the same positive/negative prompts as the
    base pass; returns the first refined PIL image.
    """
    refined = refine(prompt=prompt, negative_prompt=negative_prompt, image=image)
    return refined.images[0]
# --- Gradio UI ---
# Input order must match the text2img() parameter order exactly.
demo = gr.Interface(
    fn=text2img,
    inputs=[
        gr.Radio(["Korean", "Japanese", "Chinese", "English"], label="language", value="Korean"),
        gr.Text(label=("prompt"), value=""),
        gr.Text(label=("negative prompt"), value=""),
        gr.Slider(128, 2048, value=1024, label="width"),
        gr.Slider(128, 2048, value=1024, label="height"),
        gr.Checkbox(label="auto random seed", value=True),
        gr.Slider(1, 999999, label="fixed random seed"),
        gr.Slider(1, 50, value=20, label="inference steps"),
        gr.Slider(1, 20, value=7, label="guidance scale"),
        # BUG FIX: the original `gr.Checkbox(["refine"])` passed a list as the
        # checkbox *value* (first positional arg) and left it unlabeled; a
        # labeled boolean "refine" toggle was clearly intended.
        gr.Checkbox(label="refine", value=False),
    ],
    outputs=[
        gr.Image(label="generated image"),
        gr.Textbox(label="properties"),
    ],
    title="λ€κ΅μ΄ SDXL",  # mojibake in source; presumably Korean for "multilingual SDXL" — confirm
    #description="νκΈλ‘ νλ SDXL",
    #article="test"
)
demo.launch()