# ehristoforu: "Duplicate from rinna/japanese-stable-diffusion" (commit c781418)
import random
import gradio as gr
import os
import torch
from torch import autocast
from diffusers import LMSDiscreteScheduler
from japanese_stable_diffusion import JapaneseStableDiffusionPipeline
from PIL import Image
from dotenv import load_dotenv
# Populate os.environ from a local .env file (if present) before reading secrets.
load_dotenv()
# Hugging Face access token used to download the model weights; None when unset.
ACCESS_TOKEN = os.getenv("ACCESS_TOKEN")
model_id = "rinna/japanese-stable-diffusion"
device = "cuda" if torch.cuda.is_available() else "cpu"
# LMS discrete scheduler with a scaled-linear beta schedule over 1000 training timesteps.
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
pipe = JapaneseStableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, use_auth_token=ACCESS_TOKEN)
pipe.to(device)
# Cast the UNet and text encoder to fp16 only on GPU: half-precision kernels for
# these modules are unsupported or very slow on CPU in many PyTorch builds, so the
# CPU fallback keeps them in fp32.
if device == "cuda":
    pipe.unet.half()
    pipe.text_encoder.half()
def infer(
    prompt,
    n_samples=4,
    guidance_scale=7.5,
    steps=50,
    seed="random",
):
    """Generate images for a prompt with the module-level ``pipe``.

    Args:
        prompt: Text prompt; the same prompt is used for every sample.
        n_samples: Number of images to generate (cast with ``int``, since
            Gradio sliders may deliver floats).
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps (cast with ``int``).
        seed: ``"random"`` (case-insensitive), a blank string, or ``None`` for a
            freshly drawn seed; otherwise any value ``int()`` accepts, e.g. ``"42"``.

    Returns:
        A list of PIL images, one per sample. Every image the pipeline flags in
        ``nsfw_content_detected`` is replaced by the placeholder ``nsfw.png``.

    Raises:
        ValueError: If ``seed`` is a non-random string that is not an integer.
    """
    wants_random = seed is None or (
        isinstance(seed, str) and seed.strip().lower() in ("", "random")
    )
    # randint is inclusive on both ends, so 2**32 - 1 keeps the seed in 32-bit range.
    seed_value = random.randint(0, 2**32 - 1) if wants_random else int(seed)
    generator = torch.Generator(device=device).manual_seed(seed_value)

    # Mixed precision is only meaningful on GPU; on a CPU-only host autocast("cuda")
    # would be disabled anyway (with a warning), so disable it explicitly.
    with autocast("cuda", enabled=device == "cuda"):
        output = pipe(
            prompt=[prompt] * int(n_samples),
            guidance_scale=guidance_scale,
            num_inference_steps=int(steps),
            generator=generator,
        )

    images = list(output.images)
    # The flag list may be None (e.g. when no safety checker ran); dict-style
    # lookup would raise KeyError in that case, so treat it as "nothing flagged".
    flags = getattr(output, "nsfw_content_detected", None) or [False] * len(images)
    if not any(flags):
        return images

    # Load the placeholder only when needed; copy() reads the pixel data so the
    # file handle can be closed by the context manager.
    with Image.open(r"nsfw.png") as placeholder:
        safe_image = placeholder.copy()
    return [safe_image if flagged else image for image, flagged in zip(images, flags)]
# Stylesheet passed to gr.Blocks(css=...); targets Gradio's default classes and the
# elem_ids used in the layout (#gallery, #advanced-btn, #advanced-options).
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-color: black;
background: black;
}
input[type='range'] {
accent-color: black;
}
.dark input[type='range'] {
accent-color: #dfdfdf;
}
.container {
max-width: 730px;
margin: auto;
padding-top: 1.5rem;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
margin-left: auto;
margin-right: auto;
border-bottom-right-radius: .5rem !important;
border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
min-height: 20rem;
}
.details:hover {
text-decoration: underline;
}
.gr-button {
white-space: nowrap;
}
.gr-button:focus {
border-color: rgb(147 197 253 / var(--tw-border-opacity));
outline: none;
box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
--tw-border-opacity: 1;
--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
--tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
--tw-ring-opacity: .5;
}
#advanced-btn {
font-size: .7rem !important;
line-height: 19px;
margin-top: 12px;
margin-bottom: 12px;
padding: 2px 8px;
border-radius: 14px !important;
}
#advanced-options {
display: none;
margin-bottom: 20px;
}
.footer {
margin-bottom: 45px;
margin-top: 35px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
.acknowledgments h4{
margin: 1.25em 0 .25em 0;
font-weight: bold;
font-size: 115%;
}
"""
# Cached example rows, ordered like infer's parameters:
# (prompt, n_samples, guidance_scale, steps, seed). The prompts mean roughly
# "office worker, oil painting", "cat with sparkling eyes", and
# "watercolor of a summer festival at a shrine at dusk".
examples = [
    [example_prompt, 2, 7.5, 50, "random"]
    for example_prompt in (
        "サラリーマン 油絵",
        "キラキラ瞳の猫",
        "夕暮れの神社の夏祭りを描いた水彩画",
    )
]
# Root Blocks container, styled with the module-level stylesheet.
block = gr.Blocks(css=css)
# Page layout: header, prompt row, result gallery, collapsible advanced options,
# cached examples, event wiring, and footer; then start the queued server.
with block:
    gr.HTML(
        """
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<svg
width="0.65em"
height="0.65em"
viewBox="0 0 115 115"
fill="none"
xmlns="http://www.w3.org/2000/svg"
>
<rect width="23" height="23" fill="white"></rect>
<rect y="69" width="23" height="23" fill="white"></rect>
<rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="46" width="23" height="23" fill="white"></rect>
<rect x="46" y="69" width="23" height="23" fill="white"></rect>
<rect x="69" width="23" height="23" fill="black"></rect>
<rect x="69" y="69" width="23" height="23" fill="black"></rect>
<rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="115" y="46" width="23" height="23" fill="white"></rect>
<rect x="115" y="115" width="23" height="23" fill="white"></rect>
<rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="92" y="69" width="23" height="23" fill="white"></rect>
<rect x="69" y="46" width="23" height="23" fill="white"></rect>
<rect x="69" y="115" width="23" height="23" fill="white"></rect>
<rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="46" y="46" width="23" height="23" fill="black"></rect>
<rect x="46" y="115" width="23" height="23" fill="black"></rect>
<rect x="46" y="69" width="23" height="23" fill="black"></rect>
<rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="23" y="69" width="23" height="23" fill="black"></rect>
</svg>
<h1 style="font-weight: 900; margin-bottom: 7px;">
Japanese Stable Diffusion Demo
</h1>
</div>
<p style="margin-bottom: 10px; font-size: 94%">
<a
href="https://github.com/rinnakk/japanese-stable-diffusion/"
style="text-decoration: underline;"
target="_blank"
>Japanese Stable Diffusion</a
>
is a Japanese-language specific latent text-to-image diffusion model.
</p>
</div>
"""
    )
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt in Japanese",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt in Japanese",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")
        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
        # Hidden initially by the #advanced-options CSS rule; advanced_button's JS toggles it.
        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=200, value=50, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Textbox(
                value='random',
                placeholder="If you fix seed, you get same outputs all the time. You can set as integer like 42.",
                label="seed",
            )
        # One shared component list so the examples, submit, and click handlers cannot
        # drift apart; its order must match infer(prompt, n_samples, guidance_scale, steps, seed).
        infer_inputs = [text, samples, scale, steps, seed]
        ex = gr.Examples(
            examples=examples, fn=infer, inputs=infer_inputs, outputs=gallery, cache_examples=True
        )
        # Blank header for every example column (one per input), not just the first.
        ex.dataset.headers = [""] * len(infer_inputs)
        text.submit(infer, inputs=infer_inputs, outputs=gallery)
        btn.click(infer, inputs=infer_inputs, outputs=gallery)
        # Frontend-only handler (fn=None): the JS toggles the advanced-options row's visibility.
        advanced_button.click(
            None,
            [],
            text,
            _js="""
() => {
const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
}""",
        )
        gr.HTML(
            """
<div class="footer">
<p>Model by <a href="https://huggingface.co/rinna" style="text-decoration: underline;" target="_blank">rinna</a> - Gradio Demo by 🤗 Hugging Face
</p>
</div>
<div class="acknowledgments">
<h4>LICENSE</h4>
<p>The model is licensed with a <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" style="text-decoration: underline;" target="_blank">CreativeML Open RAIL-M</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;">read the license</a>.</p>
<h4>Limitations and Bias</h4>
<p>While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases. Japanese Stable Diffusion was trained on Japanese datasets including LAION-5B with Japanese captions, which consists of images that are primarily limited to Japanese descriptions. Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for. This affects the overall output of the model. Further, the ability of the model to generate content with non-Japanese prompts is significantly worse than with Japanese-language prompts. You can read more in the <a href="https://huggingface.co/rinna/japanese-stable-diffusion#limitations-and-bias" style="text-decoration: underline;" target="_blank">model card</a>.</p>
</div>
<br>
<br>
<i>This demo is based on the
<a
href="https://huggingface.co/spaces/stabilityai/stable-diffusion/"
style="text-decoration: underline;"
target="_blank"
>Stable Diffusion Demo</a
>.</i>
"""
        )
# Enable request queuing (at most 25 pending requests) and start the server.
block.queue(max_size=25).launch()