koji commited on
Commit
1a94bc1
1 Parent(s): 9b9398a

streamlit app を作成

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. app.py +152 -0
  3. requirements.txt +3 -1
  4. stable_diffusion_cpu.ipynb +0 -0
.gitignore CHANGED
@@ -2,5 +2,6 @@
2
  *.png
3
  .ipynb_checkpoints/
4
  __pycache__/
 
5
 
6
  !/sample_output_images/**
 
2
  *.png
3
  .ipynb_checkpoints/
4
  __pycache__/
5
+ .env
6
 
7
  !/sample_output_images/**
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import random
4
+ import sys
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import streamlit as st
9
+ import torch
10
+ from diffusers import StableDiffusionPipeline
11
+ from dotenv import load_dotenv
12
+ from huggingface_hub import notebook_login
13
+ from PIL import Image
14
+
15
+ load_dotenv(".env")
16
+
17
+ ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
18
+
19
+ sys.path.append("./")
20
+ from simulation import *
21
+
22
+ # シード値の固定
23
+ SEED = 42
24
+ np.random.seed(seed=SEED)
25
+ random.seed(SEED)
26
+
27
+
28
+ def main():
29
+
30
+ # 生息地を表すワード
31
+ HABITAT_WORDS = " Alien from Mars"
32
+
33
+ # パラメーター
34
+ GENOMS_SIZE = 4 # 遺伝配列 0, 1 のどちらかを要素とした配列のサイズ
35
+ TOUNAMENT_NUM = 10 # トーナメント方式で競わせる数
36
+ CROSSOVER_PB = 0.8 # cross over(交差) する確率
37
+ MUTATION_PB = 0.5 # mutation(突然変異)する確率
38
+
39
+ # グローバル変数
40
+ global best
41
+
42
+ POPURATIONS = st.slider(
43
+ label="人口数",
44
+ min_value=3,
45
+ max_value=3000,
46
+ value=500,
47
+ )
48
+
49
+ NUM_GENERATION = st.slider(
50
+ label="世代数",
51
+ min_value=10,
52
+ max_value=10000,
53
+ value=1000,
54
+ )
55
+
56
+ # キーワード候補
57
+ word_dict = {
58
+ "body_size": ["Fingertip sized", "Palm sized", "", "Tall", "Giant"],
59
+ "body_hair": ["Bald", "Smooth", "", "Furry", "Very Furry"],
60
+ "herd_num": ["Lone", "Pair", "", "Herd of", "Swarm of"],
61
+ "eating": ["No teeth", "Herbivorous", "Omnivorous", "Carnivorous", "Fang"],
62
+ "body_color": [
63
+ "Lightest skin",
64
+ "Lighter skin",
65
+ "",
66
+ "Darker skin",
67
+ "Darkest skin",
68
+ ],
69
+ "ferocity": ["Peaceful", "Gentle", "", "Ferocious", "Tyrannical"],
70
+ }
71
+
72
+ if st.button("実行", key="ga"):
73
+
74
+ st.write("遺伝アルゴリズムの実行")
75
+
76
+ progress_bar_ga = st.progress(0)
77
+
78
+ # create first genetarion
79
+ generation = create_generation(POPURATIONS, GENOMS_SIZE)
80
+
81
+ progress_bar_ga.progress(50)
82
+
83
+ # アルゴリズムの実行
84
+ best, worst = ga_solve(
85
+ generation,
86
+ NUM_GENERATION,
87
+ POPURATIONS,
88
+ TOUNAMENT_NUM,
89
+ CROSSOVER_PB,
90
+ MUTATION_PB,
91
+ )
92
+
93
+ progress_bar_ga.progress(100)
94
+
95
+ st.write("遺伝アルゴリズム処理の終了")
96
+
97
+ st.write("画像生成の実行")
98
+
99
+ progress_bar_image = st.progress(0)
100
+
101
+ progress_bar_image.progress(0)
102
+
103
+ pipe = StableDiffusionPipeline.from_pretrained(
104
+ "CompVis/stable-diffusion-v1-4", use_auth_token=ACCESS_TOKEN
105
+ )
106
+ pipe.enable_attention_slicing()
107
+
108
+ progress_bar_image.progress(7)
109
+
110
+ device = "gpu" if torch.cuda.is_available() else "cpu"
111
+
112
+ print("used device is", device)
113
+ pipe.to(device)
114
+
115
+ # NSFWフィルターの回避
116
+ def null_safety(images, **kwargs):
117
+ return images, False
118
+
119
+ pipe.safety_checker = null_safety
120
+
121
+ last_generation = NUM_GENERATION - 1
122
+
123
+ plt.figure(figsize=(8, 8))
124
+ plt.rcParams["font.size"] = 9
125
+
126
+ words = (
127
+ get_word_for_image_generate(word_dict, best, last_generation)
128
+ + HABITAT_WORDS
129
+ )
130
+
131
+ image = pipe(words)["sample"][0]
132
+
133
+ plt.title(f"{last_generation + 1}th\n{words}.")
134
+ plt.xticks([])
135
+ plt.yticks([])
136
+ plt.imshow(image)
137
+
138
+ progress_bar_image.progress(100)
139
+
140
+ plt.tight_layout()
141
+ buf = io.BytesIO()
142
+ plt.savefig(buf, format="png")
143
+
144
+ buf.seek(0)
145
+ im = Image.open(buf)
146
+ numpy_image = np.array(im)
147
+ st.image(numpy_image)
148
+
149
+
150
+ if __name__ == "__main__":
151
+
152
+ main()
requirements.txt CHANGED
@@ -5,4 +5,6 @@ pandas==1.3.5
5
  ftfy==6.1.1
6
  spacy==3.4.1
7
  matplotlib==3.5.3
8
- notebook==6.4.12
 
 
 
5
  ftfy==6.1.1
6
  spacy==3.4.1
7
  matplotlib==3.5.3
8
+ notebook==6.4.12
9
+ streamlit==1.13.0
10
+ python-dotenv==0.21.0
stable_diffusion_cpu.ipynb CHANGED
The diff for this file is too large to render. See raw diff