anteaterho committed on
Commit
0501d0a
•
1 Parent(s): 9728fd4

Create app.py

Files changed (1)
app.py +161 -0
app.py ADDED
@@ -0,0 +1,161 @@
import gradio as gr
import openai
import os
import re
import datetime

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, DDIMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler
import torch
import random

# Hugging Face access token for SDXL (use your own token; do not commit real credentials)
access_token = "hf_..."
# OpenAI API key (use your own key; do not commit real credentials)
openai.api_key = "sk-..."

# output path
path = "./output"
os.makedirs(path, exist_ok=True)  # make sure the output directory exists before saving images

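# Optional (sketch): rather than hardcoding credentials, they could be read from
# environment variables, e.g.
#   access_token = os.environ.get("HF_TOKEN", "")
#   openai.api_key = os.environ.get("OPENAI_API_KEY", "")
# (HF_TOKEN and OPENAI_API_KEY are example variable names, not ones this repo defines.)
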
# OpenAI settings
messages = [{
    'role': 'system',
    'content': 'You are a helpful assistant for organizing prompts for generating images'
}]

# GPT-based translation helper (defined here but not used below; mBart handles translation instead)
def translate(msg):
    messages.append({
        'role': 'assistant',
        'content': msg
    })

    messages.append({
        'role': 'user',
        'content': 'Translate the sentence into English only. Keep the symbols intact when translating.'
    })

    res = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=messages
    )

    _msg = res['choices'][0]['message']['content']

    print(_msg)

    return _msg

# mBart settings
article_kr = "유엔의 대표는 시리아에 군사적인 해결책이 없다고 말합니다."  # example article: "The UN representative says there is no military solution in Syria."

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer.src_lang = "ko_KR"

# Translate a Korean sentence into English with mBart-50
def translate_mBart(article_kr):
    encoded_ar = tokenizer(article_kr, return_tensors="pt")
    generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    return result[0]

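# Quick check (sketch): translating the sample article should print an English sentence, e.g.
#   print(translate_mBart(article_kr))
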
# diffusers settings
# LMS scheduler (only used by the commented-out pipeline variant below)
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)

base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refine_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
#pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, scheduler=lms, use_auth_token=access_token)
pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, use_auth_token=access_token)
#pipeline.load_lora_weights(".", weight_name="fashigirl-v6-sdxl-5ep-resize.safetensors")
#pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config, rescale_beta_zero_snr=True, timestep_respacing="training")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True, timestep_spacing="linspace")

pipeline.to("cuda")

refine = StableDiffusionXLImg2ImgPipeline.from_pretrained(refine_model_id, torch_dtype=torch.float16, use_safetensors=True, use_auth_token=access_token)
refine.to("cuda")

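# Optional (sketch): on GPUs with limited VRAM, diffusers' enable_model_cpu_offload()
# could replace the .to("cuda") calls above, e.g.
#   pipeline.enable_model_cpu_offload()
#   refine.enable_model_cpu_offload()
# (requires the accelerate package; shown only as an alternative, not used here)
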
# Default prompt and sampling settings (text2img() below builds its own values from the UI inputs)
prompt = "1girl, solo, long hair, shirt, looking at viewer, white shirt, collared shirt, black eyes, smile, bow, black bow, closed mouth, portrait, brown hair, black hair, straight-on, bowtie, black bowtie, upper body, cloud, sky,huge breasts,shiny,shiny skin,milf,(mature female:1.2),<lora:fashigirl-v6-sdxl-5ep-resize:0.7>"
negative_prompt = "(low quality:1.3), (worst quality:1.3),(monochrome:0.8),(deformed:1.3),(malformed hands:1.4),(poorly drawn hands:1.4),(mutated fingers:1.4),(bad anatomy:1.3),(extra limbs:1.35),(poorly drawn face:1.4),(watermark:1.3),long neck,text,watermark,signature,logo"
seed = random.randint(0, 999999)
generator = torch.manual_seed(seed)
num_inference_steps = 60
guidance_scale = 7

def text2img(prompt, negative_prompt, x, y, isRandom, fixedRandom, num_inference_steps, guidance_scale, use_refine):
    seed = 0
    if isRandom:
        seed = random.randint(0, 999999)
    else:
        seed = int(fixedRandom)
    generator = torch.manual_seed(seed)

    # Translate the prompt and negative prompt from Korean to English
    #allPrompt = (translate(prompt + "/" + negative_prompt).split("/"))
    allPrompt = ["", ""]
    allPrompt[0] = translate_mBart(prompt)
    allPrompt[1] = translate_mBart(negative_prompt)

    print(len(allPrompt))
    print("prompt : " + allPrompt[0])
    print("negative prompt : " + allPrompt[1])
    _prompt = allPrompt[0]
    if len(allPrompt) > 1:
        _negative_prompt = allPrompt[1]
    else:
        _negative_prompt = " "

    # Check whether the translated prompt contains any alphabetic characters
    if _prompt.upper() != _prompt.lower():
        print("prompt contains alphabetic characters")
    else:
        print("prompt contains no alphabetic characters; using the fallback prompt")
        _prompt = "traffic sign of stop says SDXL"

    #_negative_prompt = translate(negative_prompt)
    image = pipeline(
        prompt=_prompt, negative_prompt=_negative_prompt, width=int(x), height=int(y), num_inference_steps=int(num_inference_steps), generator=generator, guidance_scale=int(guidance_scale)
    ).images[0]

    _seed = str(seed)
    # strip weighting syntax and other special characters from the prompt
    _prompt = re.sub(r"[^\uAC00-\uD7A30-9a-zA-Z\s]", "", _prompt)
    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    image.save(path + "/sdxl_base_seed_" + _seed + "_time_" + timestamp + ".png")
    #image.save("sdxl_prompt_" + "_seed_" + _seed + ".png")
    print(seed)

    # Optionally run the refiner pass on the base output
    if use_refine:
        image = img2img(prompt=_prompt, negative_prompt=_negative_prompt, image=image)

    return image


# Run the SDXL refiner on a generated image
def img2img(prompt, negative_prompt, image):

    image = refine(prompt=prompt, negative_prompt=negative_prompt, image=image).images[0]
    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    image.save(path + "/sdxl_refine_time_" + timestamp + ".png")

    return image

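# Note (sketch): the SDXL base + refiner pair can also be run as an "ensemble of expert
# denoisers" by passing denoising_end to the base pipeline (with output_type="latent")
# and denoising_start to the refiner, e.g. denoising_end=0.8 / denoising_start=0.8;
# the simpler image-to-image hand-off above is kept here.
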
demo = gr.Interface(
    fn=text2img,
    # Inputs follow text2img's parameter order: prompt, negative_prompt, x (width), y (height),
    # isRandom, fixedRandom (seed), num_inference_steps, guidance_scale, use_refine
    inputs=["text", "text", gr.Slider(0, 2048, label="width"), gr.Slider(0, 2048, label="height"), gr.Checkbox(label="random"), "number", "number", "number", gr.Checkbox(label="refine")],
    outputs=["image"],
    title="한글로 하는 SDXL",  # "SDXL in Korean"
)


demo.launch(share=True, debug=True)

#torch.cuda.empty_cache()
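
For reference, a minimal sketch of exercising the generation path without the Gradio UI, assuming the pipelines above are already loaded on a CUDA device with enough VRAM (the Korean prompt and parameter values below are only illustrative):

korean_prompt = "노을 지는 바닷가를 걷는 고양이"  # "a cat walking on a beach at sunset" (illustrative)
korean_negative = "저화질"  # "low quality" (illustrative)
img = text2img(
    prompt=korean_prompt,
    negative_prompt=korean_negative,
    x=1024, y=1024,                 # SDXL's native resolution
    isRandom=True, fixedRandom=0,   # fixedRandom is ignored when isRandom is True
    num_inference_steps=30,
    guidance_scale=7,
    use_refine=False,
)
# img is a PIL image; text2img() also writes a copy under ./output/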