Spaces:
Runtime error
Runtime error
anteaterho
committed on
Commit
•
0501d0a
1
Parent(s):
9728fd4
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import openai
|
3 |
+
import re
|
4 |
+
import datetime
|
5 |
+
|
6 |
+
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
7 |
+
|
8 |
+
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, DDIMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler
|
9 |
+
import torch
|
10 |
+
import random
|
11 |
+
|
12 |
+
# token for SDXL
|
13 |
+
# Hugging Face access token, read from the HF_TOKEN environment variable so
# that no secret is stored in the source file. None (variable unset) means no
# token is passed to from_pretrained().
# NOTE(review): any token that was previously committed here should be revoked.
import os

access_token = os.environ.get("HF_TOKEN")
|
14 |
+
# token for OpenAI
|
15 |
+
# OpenAI API key, read from the OPENAI_API_KEY environment variable so that no
# secret is stored in the source file.
# NOTE(review): any key that was previously committed here should be revoked.
import os

openai.api_key = os.environ.get("OPENAI_API_KEY")
|
16 |
+
|
17 |
+
# output path
# NOTE(review): the image.save() calls in text2img/img2img spell out
# "./output/" literally instead of reading this variable.
path = "./output"
|
19 |
+
|
20 |
+
#openai settings
|
21 |
+
# Start of the chat conversation sent by translate(): a single system message
# that sets the assistant's role.
messages=[{
    'role' : 'system',
    'content' : 'You are a helpful assistant for organizing prompt for generating images'
}]
|
25 |
+
|
26 |
+
def translate(msg):
    """Translate *msg* into English via the OpenAI chat API.

    Args:
        msg: Text to translate.

    Returns:
        The content of the first choice in the chat completion response.
    """
    # Build the conversation per call. Appending to the module-level
    # ``messages`` list would make every request carry all earlier prompts and
    # instructions, growing without bound and polluting later translations.
    conversation = messages + [
        {
            'role': 'assistant',
            'content': msg,
        },
        {
            'role': 'user',
            'content': 'Translate the sentence into English only. Keep the symbols intact when translating.',
        },
    ]

    res = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=conversation,
    )

    _msg = res['choices'][0]['message']['content']

    print(_msg)

    return _msg
|
47 |
+
|
48 |
+
# mBart settings
|
49 |
+
# Example Korean article. translate_mBart() takes its own ``article_kr``
# parameter, so this module-level value only serves as a sample input.
article_kr = "유엔의 대표는 시리아에 군사적인 해결책이 없다고 말합니다."  # example article
|
50 |
+
|
51 |
+
# Load the pretrained mBART-50 many-to-many checkpoint and its tokenizer once,
# at import time; translate_mBart() uses both.
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
# Input text is tokenized as Korean.
tokenizer.src_lang = "ko_KR"
|
54 |
+
|
55 |
+
def translate_mBart(article_kr):
    """Translate *article_kr* to English with the module-level mBART model.

    The text is tokenized with the shared ``tokenizer`` (whose source language
    is configured at module level), generation is forced to start with the
    ``en_XX`` language token, and the first decoded sequence is returned.

    Args:
        article_kr: Text to translate.

    Returns:
        The first decoded output string, with special tokens stripped.
    """
    model_inputs = tokenizer(article_kr, return_tensors="pt")
    english_id = tokenizer.lang_code_to_id["en_XX"]
    output_ids = model.generate(**model_inputs, forced_bos_token_id=english_id)
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
|
61 |
+
|
62 |
+
# diffusers settings
|
63 |
+
# LMS discrete scheduler with a scaled-linear beta schedule.
# NOTE(review): the only reference to ``lms`` in this file is a commented-out
# pipeline construction, so it is currently not used by the active pipelines.
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)
|
68 |
+
|
69 |
+
# Model repositories for the SDXL base (text-to-image) and refiner
# (image-to-image) pipelines.
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refine_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
#pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, scheduler=lms ,use_auth_token=access_token)
# Base pipeline, loaded in half precision.
pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, use_auth_token=access_token)
#pipeline.load_lora_weights(".", weight_name="fashigirl-v6-sdxl-5ep-resize.safetensors")
#pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config, rescale_beta_zero_snr=True, timestep_respacing="training")
# Swap in a DPM-Solver++ multistep scheduler with Karras sigmas. The scheduler
# option is spelled ``timestep_spacing``; a misspelled keyword is not a
# recognised config entry and would leave the inherited spacing unchanged.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config,
    use_karras_sigmas=True,
    timestep_spacing="linspace",
)

pipeline.to("cuda")

# Refiner pipeline used by img2img(), also in half precision on the GPU.
refine = StableDiffusionXLImg2ImgPipeline.from_pretrained(refine_model_id, torch_dtype=torch.float16, use_safetensors=True, use_auth_token=access_token)
refine.to("cuda")
|
81 |
+
|
82 |
+
prompt = "1girl, solo, long hair, shirt, looking at viewer, white shirt, collared shirt, black eyes, smile, bow, black bow, closed mouth, portrait, brown hair, black hair, straight-on, bowtie, black bowtie, upper body, cloud, sky,huge breasts,shiny,shiny skin,milf,(mature female:1.2),<lora:fashigirl-v6-sdxl-5ep-resize:0.7>"
|
83 |
+
negative_prompt = "(low quality:1.3), (worst quality:1.3),(monochrome:0.8),(deformed:1.3),(malformed hands:1.4),(poorly drawn hands:1.4),(mutated fingers:1.4),(bad anatomy:1.3),(extra limbs:1.35),(poorly drawn face:1.4),(watermark:1.3),long neck,text,watermark,signature,logo"
|
84 |
+
# Module-level sampling defaults.
# NOTE(review): text2img() declares parameters/locals with these same names,
# so inside that function these module-level values are shadowed.
seed = random.randint(0, 999999)
# torch.manual_seed() seeds the global RNG and returns the default generator.
generator = torch.manual_seed(seed)
num_inference_steps = 60
guidance_scale = 7
|
88 |
+
|
89 |
+
def text2img(prompt, negative_prompt, x, y, isRandom, fixedRandom, num_inference_steps, guidance_scale, refine):
    """Generate an image with the SDXL base pipeline and return it.

    Both prompts are passed through ``translate_mBart`` before generation. The
    base image is saved under ``./output/``. When *refine* is truthy the image
    is additionally passed through ``img2img`` and the refined image is
    returned instead.

    Args:
        prompt: Positive prompt text (translated before use).
        negative_prompt: Negative prompt text (translated before use).
        x: Image width; converted with ``int``.
        y: Image height; converted with ``int``.
        isRandom: When truthy, draw a random seed in [0, 999999].
        fixedRandom: Seed used when *isRandom* is falsy; converted with ``int``.
        num_inference_steps: Number of denoising steps; converted with ``int``.
        guidance_scale: Guidance scale; converted with ``float``.
        refine: When truthy, run the result through ``img2img``.

    Returns:
        The generated (and optionally refined) image.
    """
    import os

    # An empty seed field may arrive as None; fall back to a random seed
    # instead of failing in int(None).
    if isRandom or fixedRandom is None:
        seed = random.randint(0, 999999)
    else:
        seed = int(fixedRandom)
    generator = torch.manual_seed(seed)

    # Translate prompt and negative prompt.
    _prompt = translate_mBart(prompt)
    _negative_prompt = translate_mBart(negative_prompt)
    print("prompt : " + _prompt)
    print("negative prompt : " + _negative_prompt)

    # A string whose upper- and lower-case forms are equal contains no cased
    # letters; in that case a fixed fallback prompt is used.
    if _prompt.upper() != _prompt.lower():
        print(" it is an alphabet")
    else:
        print(" it is not an alphabet")
        _prompt = "traffic sign of stop says SDXL"

    image = pipeline(
        prompt=_prompt,
        negative_prompt=_negative_prompt,
        width=int(x),
        height=int(y),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        # float() keeps fractional guidance values such as 7.5.
        guidance_scale=float(guidance_scale),
    ).images[0]

    # image.save() does not create missing directories.
    os.makedirs("./output", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    image.save("./output/" + "sdxl_base_" + "_seed_" + str(seed) + "_time_" + timestamp + ".png")
    print(seed)

    if refine:
        # The refiner gets the same prompt text the base pass used.
        image = img2img(prompt=_prompt, negative_prompt=_negative_prompt, image=image)

    return image
|
137 |
+
|
138 |
+
|
139 |
+
def img2img(prompt, negative_prompt, image):
    """Run *image* through the SDXL refiner pipeline, save it and return it.

    Args:
        prompt: Positive prompt passed to the refiner.
        negative_prompt: Negative prompt passed to the refiner.
        image: Input image to refine.

    Returns:
        The first image produced by the module-level ``refine`` pipeline.
    """
    import os

    image = refine(prompt=prompt, negative_prompt=negative_prompt, image=image).images[0]

    # image.save() does not create missing directories.
    os.makedirs("./output", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    # The value embedded in the name is a timestamp, so label it "_time_" as
    # text2img does for its timestamp component.
    image.save("./output/" + "sdxl_refine_" + "_time_" + timestamp + ".png")

    return image
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
# Gradio UI: one input component per text2img() parameter, in signature order.
demo = gr.Interface(
    fn=text2img,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Textbox(label="negative_prompt"),
        # Width/height default to 1024 in steps of 8; a slider left at its
        # minimum of 0 would request a zero-sized image.
        gr.Slider(256, 2048, value=1024, step=8, label="width"),
        gr.Slider(256, 2048, value=1024, step=8, label="height"),
        # Checkbox's first positional argument is its value, so the caption
        # has to be passed as ``label``.
        gr.Checkbox(value=True, label="random"),
        gr.Number(value=0, precision=0, label="seed"),
        gr.Number(value=60, precision=0, label="num_inference_steps"),
        gr.Number(value=7, label="guidance_scale"),
        gr.Checkbox(value=False, label="refine"),
    ],
    outputs=["image"],
    title="한글로 하는 SDXL",
)

demo.launch(share=True, debug=True)
|
160 |
+
|
161 |
+
#torch.cuda.empty_cache()
|