Spaces:
Paused
Paused
p@localhost
commited on
Commit
•
ee589a1
1
Parent(s):
872f4b4
Add application files
Browse files- app.py +51 -0
- model.py +48 -0
- requirements.txt +9 -0
app.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
def generate(prompt, n_prompt, modelName):
|
4 |
+
return models[modelName].process(prompt, n_prompt)
|
5 |
+
|
6 |
+
def create_demo():
|
7 |
+
with gr.Blocks() as demo:
|
8 |
+
with gr.Column():
|
9 |
+
prompt = gr.Textbox(label='Prompt')
|
10 |
+
n_prompt = gr.Textbox( label='Negative Prompt', value= 'low res')
|
11 |
+
modelName = gr.Dropdown(choices = list(models.keys()), label = "Model", value=list(models.keys())[0])
|
12 |
+
|
13 |
+
run_button = gr.Button('Generar')
|
14 |
+
|
15 |
+
gr.Markdown("lunarnaut txt2img demo")
|
16 |
+
result = gr.Gallery(label='Output', show_label=False, elem_id='gallery').style(columns=1, rows=1, preview=True)
|
17 |
+
|
18 |
+
inputs = [
|
19 |
+
prompt,
|
20 |
+
n_prompt,
|
21 |
+
modelName,
|
22 |
+
]
|
23 |
+
|
24 |
+
prompt.submit(
|
25 |
+
fn=generate,
|
26 |
+
inputs=inputs,
|
27 |
+
outputs=result
|
28 |
+
)
|
29 |
+
n_prompt.submit(
|
30 |
+
fn=generate,
|
31 |
+
inputs=inputs,
|
32 |
+
outputs=result
|
33 |
+
)
|
34 |
+
|
35 |
+
run_button.click(
|
36 |
+
fn=generate,
|
37 |
+
inputs=inputs,
|
38 |
+
outputs=result
|
39 |
+
)
|
40 |
+
return demo
|
41 |
+
|
42 |
+
|
43 |
+
if __name__ == '__main__':
|
44 |
+
from model import Model
|
45 |
+
models = {
|
46 |
+
"Stable Diffusion v1.5": Model("runwayml/stable-diffusion-v1-5"),
|
47 |
+
"Anything v3.0": Model("Linaqruf/anything-v3.0"),
|
48 |
+
"Realistic Vision v2.0": Model("SG161222/Realistic_Vision_V2.0"),
|
49 |
+
}
|
50 |
+
demo = create_demo()
|
51 |
+
demo.queue().launch()
|
model.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
|
4 |
+
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
|
5 |
+
from diffusers import DPMSolverMultistepScheduler
|
6 |
+
import torch
|
7 |
+
import PIL.Image
|
8 |
+
import numpy as np
|
9 |
+
import datetime
|
10 |
+
|
11 |
+
# Check environment
|
12 |
+
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
13 |
+
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
14 |
+
|
15 |
+
device = "cuda"
|
16 |
+
|
17 |
+
class Model:
|
18 |
+
def __init__(self, modelID):
|
19 |
+
|
20 |
+
self.modelID = modelID
|
21 |
+
self.pipe = StableDiffusionPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
|
22 |
+
self.pipe = self.pipe.to(device)
|
23 |
+
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
|
24 |
+
self.pipe.enable_xformers_memory_efficient_attention()
|
25 |
+
|
26 |
+
def process(self,
|
27 |
+
prompt: str,
|
28 |
+
negative_prompt: str,
|
29 |
+
guidance_scale:int = 7,
|
30 |
+
num_images:int = 1,
|
31 |
+
num_steps:int = 20,
|
32 |
+
):
|
33 |
+
seed = np.random.randint(0, np.iinfo(np.int32).max)
|
34 |
+
generator = torch.Generator(device).manual_seed(seed)
|
35 |
+
now = datetime.datetime.now()
|
36 |
+
print(now)
|
37 |
+
print(self.modelID)
|
38 |
+
print(prompt)
|
39 |
+
print(negative_prompt)
|
40 |
+
with torch.inference_mode():
|
41 |
+
images = self.pipe(prompt=prompt,
|
42 |
+
negative_prompt=negative_prompt,
|
43 |
+
guidance_scale=guidance_scale,
|
44 |
+
num_images_per_prompt=num_images,
|
45 |
+
num_inference_steps=num_steps,
|
46 |
+
generator=generator).images
|
47 |
+
|
48 |
+
return images
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
+
torch
|
3 |
+
accelerate==0.18.0
|
4 |
+
diffusers==0.16.0
|
5 |
+
gradio==3.30.0
|
6 |
+
safetensors==0.3.0
|
7 |
+
torchvision==0.15.1
|
8 |
+
transformers==4.28.1
|
9 |
+
xformers==0.0.18
|