Spaces:
Running
Running
koji
commited on
Commit
•
1a94bc1
1
Parent(s):
9b9398a
streamlit app を作成
Browse files- .gitignore +1 -0
- app.py +152 -0
- requirements.txt +3 -1
- stable_diffusion_cpu.ipynb +0 -0
.gitignore
CHANGED
@@ -2,5 +2,6 @@
|
|
2 |
*.png
|
3 |
.ipynb_checkpoints/
|
4 |
__pycache__/
|
|
|
5 |
|
6 |
!/sample_output_images/**
|
|
|
2 |
*.png
|
3 |
.ipynb_checkpoints/
|
4 |
__pycache__/
|
5 |
+
.env
|
6 |
|
7 |
!/sample_output_images/**
|
app.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import sys
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import streamlit as st
|
9 |
+
import torch
|
10 |
+
from diffusers import StableDiffusionPipeline
|
11 |
+
from dotenv import load_dotenv
|
12 |
+
from huggingface_hub import notebook_login
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
load_dotenv(".env")
|
16 |
+
|
17 |
+
ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
|
18 |
+
|
19 |
+
sys.path.append("./")
|
20 |
+
from simulation import *
|
21 |
+
|
22 |
+
# シード値の固定
|
23 |
+
SEED = 42
|
24 |
+
np.random.seed(seed=SEED)
|
25 |
+
random.seed(SEED)
|
26 |
+
|
27 |
+
|
28 |
+
def main():
|
29 |
+
|
30 |
+
# 生息地を表すワード
|
31 |
+
HABITAT_WORDS = " Alien from Mars"
|
32 |
+
|
33 |
+
# パラメーター
|
34 |
+
GENOMS_SIZE = 4 # 遺伝配列 0, 1 のどちらかを要素とした配列のサイズ
|
35 |
+
TOUNAMENT_NUM = 10 # トーナメント方式で競わせる数
|
36 |
+
CROSSOVER_PB = 0.8 # cross over(交差) する確率
|
37 |
+
MUTATION_PB = 0.5 # mutation(突然変異)する確率
|
38 |
+
|
39 |
+
# グローバル変数
|
40 |
+
global best
|
41 |
+
|
42 |
+
POPURATIONS = st.slider(
|
43 |
+
label="人口数",
|
44 |
+
min_value=3,
|
45 |
+
max_value=3000,
|
46 |
+
value=500,
|
47 |
+
)
|
48 |
+
|
49 |
+
NUM_GENERATION = st.slider(
|
50 |
+
label="世代数",
|
51 |
+
min_value=10,
|
52 |
+
max_value=10000,
|
53 |
+
value=1000,
|
54 |
+
)
|
55 |
+
|
56 |
+
# キーワード候補
|
57 |
+
word_dict = {
|
58 |
+
"body_size": ["Fingertip sized", "Palm sized", "", "Tall", "Giant"],
|
59 |
+
"body_hair": ["Bald", "Smooth", "", "Furry", "Very Furry"],
|
60 |
+
"herd_num": ["Lone", "Pair", "", "Herd of", "Swarm of"],
|
61 |
+
"eating": ["No teeth", "Herbivorous", "Omnivorous", "Carnivorous", "Fang"],
|
62 |
+
"body_color": [
|
63 |
+
"Lightest skin",
|
64 |
+
"Lighter skin",
|
65 |
+
"",
|
66 |
+
"Darker skin",
|
67 |
+
"Darkest skin",
|
68 |
+
],
|
69 |
+
"ferocity": ["Peaceful", "Gentle", "", "Ferocious", "Tyrannical"],
|
70 |
+
}
|
71 |
+
|
72 |
+
if st.button("実行", key="ga"):
|
73 |
+
|
74 |
+
st.write("遺伝アルゴリズムの実行")
|
75 |
+
|
76 |
+
progress_bar_ga = st.progress(0)
|
77 |
+
|
78 |
+
# create first genetarion
|
79 |
+
generation = create_generation(POPURATIONS, GENOMS_SIZE)
|
80 |
+
|
81 |
+
progress_bar_ga.progress(50)
|
82 |
+
|
83 |
+
# アルゴリズムの実行
|
84 |
+
best, worst = ga_solve(
|
85 |
+
generation,
|
86 |
+
NUM_GENERATION,
|
87 |
+
POPURATIONS,
|
88 |
+
TOUNAMENT_NUM,
|
89 |
+
CROSSOVER_PB,
|
90 |
+
MUTATION_PB,
|
91 |
+
)
|
92 |
+
|
93 |
+
progress_bar_ga.progress(100)
|
94 |
+
|
95 |
+
st.write("遺伝アルゴリズム処理の終了")
|
96 |
+
|
97 |
+
st.write("画像生成の実行")
|
98 |
+
|
99 |
+
progress_bar_image = st.progress(0)
|
100 |
+
|
101 |
+
progress_bar_image.progress(0)
|
102 |
+
|
103 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
104 |
+
"CompVis/stable-diffusion-v1-4", use_auth_token=ACCESS_TOKEN
|
105 |
+
)
|
106 |
+
pipe.enable_attention_slicing()
|
107 |
+
|
108 |
+
progress_bar_image.progress(7)
|
109 |
+
|
110 |
+
device = "gpu" if torch.cuda.is_available() else "cpu"
|
111 |
+
|
112 |
+
print("used device is", device)
|
113 |
+
pipe.to(device)
|
114 |
+
|
115 |
+
# NSFWフィルターの回避
|
116 |
+
def null_safety(images, **kwargs):
|
117 |
+
return images, False
|
118 |
+
|
119 |
+
pipe.safety_checker = null_safety
|
120 |
+
|
121 |
+
last_generation = NUM_GENERATION - 1
|
122 |
+
|
123 |
+
plt.figure(figsize=(8, 8))
|
124 |
+
plt.rcParams["font.size"] = 9
|
125 |
+
|
126 |
+
words = (
|
127 |
+
get_word_for_image_generate(word_dict, best, last_generation)
|
128 |
+
+ HABITAT_WORDS
|
129 |
+
)
|
130 |
+
|
131 |
+
image = pipe(words)["sample"][0]
|
132 |
+
|
133 |
+
plt.title(f"{last_generation + 1}th\n{words}.")
|
134 |
+
plt.xticks([])
|
135 |
+
plt.yticks([])
|
136 |
+
plt.imshow(image)
|
137 |
+
|
138 |
+
progress_bar_image.progress(100)
|
139 |
+
|
140 |
+
plt.tight_layout()
|
141 |
+
buf = io.BytesIO()
|
142 |
+
plt.savefig(buf, format="png")
|
143 |
+
|
144 |
+
buf.seek(0)
|
145 |
+
im = Image.open(buf)
|
146 |
+
numpy_image = np.array(im)
|
147 |
+
st.image(numpy_image)
|
148 |
+
|
149 |
+
|
150 |
+
if __name__ == "__main__":
|
151 |
+
|
152 |
+
main()
|
requirements.txt
CHANGED
@@ -5,4 +5,6 @@ pandas==1.3.5
|
|
5 |
ftfy==6.1.1
|
6 |
spacy==3.4.1
|
7 |
matplotlib==3.5.3
|
8 |
-
notebook==6.4.12
|
|
|
|
|
|
5 |
ftfy==6.1.1
|
6 |
spacy==3.4.1
|
7 |
matplotlib==3.5.3
|
8 |
+
notebook==6.4.12
|
9 |
+
streamlit==1.13.0
|
10 |
+
python-dotenv==0.21.0
|
stable_diffusion_cpu.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|