DGSpitzer commited on
Commit
33c29ab
1 Parent(s): 159f3bb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import StableDiffusionPipeline
2
+ import gradio as gr
3
+ import torch
4
+
5
+ models = [
6
+ "DGSpitzer/Cyberpunk-Anime-Diffusion"
7
+ ]
8
+
9
+ prompt_prefixes = {
10
+ models[0]: "cyberpunk anime style"
11
+ }
12
+
13
+ current_model = models[0]
14
+ pipe = StableDiffusionPipeline.from_pretrained(current_model, torch_dtype=torch.float16)
15
+ if torch.cuda.is_available():
16
+ pipe = pipe.to("cuda")
17
+
18
+ device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
19
+
20
+ def on_model_change(model):
21
+
22
+ global current_model
23
+ global pipe
24
+ if model != current_model:
25
+ current_model = model
26
+ pipe = StableDiffusionPipeline.from_pretrained(current_model, torch_dtype=torch.float16)
27
+ if torch.cuda.is_available():
28
+ pipe = pipe.to("cuda")
29
+
30
+ def inference(prompt, guidance, steps):
31
+
32
+ prompt = prompt_prefixes[current_model] + prompt
33
+ image = pipe(prompt, num_inference_steps=int(steps), guidance_scale=guidance, width=512, height=512).images[0]
34
+ return image
35
+
36
+ with gr.Blocks() as demo:
37
+ gr.HTML(
38
+ """
39
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
40
+ <div
41
+ style="
42
+ display: inline-flex;
43
+ align-items: center;
44
+ gap: 0.8rem;
45
+ font-size: 1.75rem;
46
+ "
47
+ >
48
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
49
+ DGS Diffusion Space
50
+ </h1>
51
+ </div>
52
+ <p style="margin-bottom: 10px; font-size: 94%">
53
+ Demo for Cyberpunk Anime Diffusion. Based of Finetuned Diffusion by anzorq <a href="https://twitter.com/hahahahohohe">
54
+ </p>
55
+ </div>
56
+ """
57
+ )
58
+ with gr.Row():
59
+
60
+ with gr.Column():
61
+ model = gr.Dropdown(label="Model", choices=models, value=models[0])
62
+ prompt = gr.Textbox(label="Prompt", placeholder="{} is added automatically".format(prompt_prefixes[current_model]))
63
+ guidance = gr.Slider(label="Guidance scale", value=7.5, maximum=15)
64
+ steps = gr.Slider(label="Steps", value=50, maximum=100, minimum=2)
65
+ run = gr.Button(value="Run")
66
+ gr.Markdown(f"Running on: {device}")
67
+ with gr.Column():
68
+ image_out = gr.Image(height=512)
69
+
70
+ model.change(on_model_change, inputs=model, outputs=[])
71
+ run.click(inference, inputs=[prompt, guidance, steps], outputs=image_out)
72
+ gr.Examples([
73
+ ["portrait of a girl in dgs illustration style, Anime girl, female soldier working in a cyberpunk city, cleavage, ((perfect femine face)), intricate, 8k, highly detailed, shy, digital painting, intense, sharp focus", 7, 20],
74
+ ["a photo of muscular beard soldier male in dgs illustration style, half-body, holding robot arms, strong chest", 7.0, 20],
75
+ ["portrait of ((Harry Potter)) muscular ((male)) in dgs illustration style, photorealistic painting, soldier working in a cyberpunk city, cleavage, intricate, 8k, highly detailed, digital painting, intense, sharp focus", 7, 20],
76
+ ["portrait of (liu yifei) girl in dgs illustration style, soldier working in a cyberpunk city, cleavage, intricate, 8k, highly detailed, digital painting, intense, sharp focus", 7, 20],
77
+ ["portrait of in dgs illustration style, soldier working in a cyberpunk city, cleavage, intricate, 8k, highly detailed, digital painting, intense, sharp focus", 7, 20],
78
+ ], [prompt, guidance, steps], image_out, inference, cache_examples=torch.cuda.is_available())
79
+ gr.HTML('''
80
+ <div>
81
+ <p>Model by <a href="https://huggingface.co/DGSpitzer" style="text-decoration: underline;" target="_blank">@dgspitzer</a> ❤️</p>
82
+ </div>
83
+ <div>Space by
84
+ <a href="https://twitter.com/DGSpitzer">
85
+ <img alt="Twitter Follow" src="https://img.shields.io/twitter/follow/DGSpitzer?label=%40DGSpitzer&style=social">
86
+ </a>
87
+ </div>
88
+ ''')
89
+
90
+ demo.queue()
91
+ demo.launch()