"""Streamlit app: evolve creature traits with a genetic algorithm, then draw one.

The genetic-algorithm helpers come from the local ``simulation`` module. The
best individual returned by ``ga_solve`` is turned into a text prompt by
``get_word_for_image_generate`` and rendered with Stable Diffusion v1-4.
"""

import io
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from diffusers import StableDiffusionPipeline

# from dotenv import load_dotenv
from huggingface_hub import notebook_login
from PIL import Image

# For local runs: read the environment variables from a .env file instead.
# load_dotenv(".env")
# ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")

# On Hugging Face Spaces: read the token from the Space's secrets.
ACCESS_TOKEN = st.secrets["ACCESS_TOKEN"]

# Make the local ``simulation`` module importable.
sys.path.append("./")
from simulation import (  # noqa: E402
    create_generation,
    ga_solve,
    get_word_for_image_generate,
)

# Fix the random seeds so that runs are reproducible.
SEED = 42
np.random.seed(seed=SEED)
random.seed(SEED)

# Phrase appended to every image-generation prompt (the creature's habitat).
HABITAT_WORDS = " Alien from Mars"

# Genetic-algorithm parameters.
GENOMS_SIZE = 4  # size of the gene array; each element is either 0 or 1
TOUNAMENT_NUM = 10  # number of individuals competing in each tournament
CROSSOVER_PB = 0.8  # probability of crossover
MUTATION_PB = 0.5  # probability of mutation

MODEL_ID = "CompVis/stable-diffusion-v1-4"


def _null_safety(images, **kwargs):
    """Stand-in safety checker: return the images unchanged with a ``False`` flag."""
    return images, False


def _run_genetic_algorithm(populations, num_generation):
    """Run the GA with a progress bar and return the best individual.

    Args:
        populations: Number of individuals per generation.
        num_generation: Number of generations to evolve.

    Returns:
        The first value returned by ``ga_solve`` (its "best" result), which is
        later passed to ``get_word_for_image_generate``.
    """
    st.write("遺伝アルゴリズムの実行")
    progress_bar = st.progress(0)

    # Create the first generation.
    generation = create_generation(populations, GENOMS_SIZE)
    progress_bar.progress(50)

    best, _worst = ga_solve(
        generation,
        num_generation,
        populations,
        TOUNAMENT_NUM,
        CROSSOVER_PB,
        MUTATION_PB,
    )
    progress_bar.progress(100)
    st.write("遺伝アルゴリズム処理の終了")
    return best


def _load_pipeline(progress_bar):
    """Load Stable Diffusion onto the available device.

    Args:
        progress_bar: Streamlit progress bar advanced once the model is loaded.

    Returns:
        ``(pipe, device)`` where ``device`` is ``"cuda"`` or ``"cpu"``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID, use_auth_token=ACCESS_TOKEN
    )
    pipe.enable_attention_slicing()
    progress_bar.progress(7)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("used device is", device)
    pipe = pipe.to(device)

    # Bypass the NSFW filter, as the original app does.
    # NOTE(review): the stub reports a single ``False`` rather than one flag
    # per image; confirm this matches the installed diffusers version.
    pipe.safety_checker = _null_safety
    return pipe, device


def _show_captioned_image(image, title):
    """Draw ``image`` under ``title`` with matplotlib and display it in Streamlit.

    The figure is always closed so that repeated runs in a long-lived
    Streamlit server do not accumulate open matplotlib figures.
    """
    fig = plt.figure(figsize=(8, 8))
    try:
        plt.rcParams["font.size"] = 9
        plt.title(title)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(image)
        plt.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    buf.seek(0)
    with Image.open(buf) as im:
        numpy_image = np.array(im)
    st.image(numpy_image)


def main():
    """Render the controls; on the run button ("実行"), run the GA and draw the result."""
    populations = st.slider(
        label="人口数",
        min_value=3,
        max_value=3000,
        value=500,
    )
    num_generation = st.slider(
        label="世代数",
        min_value=10,
        max_value=10000,
        value=1000,
    )

    # Keyword candidates for each trait, used to build the prompt.
    word_dict = {
        "body_size": ["Fingertip sized", "Palm sized", "", "Tall", "Giant"],
        "body_hair": ["Bald", "Smooth", "", "Furry", "Very Furry"],
        "herd_num": ["Lone", "Pair", "", "Herd of", "Swarm of"],
        "eating": ["No teeth", "Herbivorous", "Omnivorous", "Carnivorous", "Fang"],
        "body_color": [
            "Lightest skin",
            "Lighter skin",
            "",
            "Darker skin",
            "Darkest skin",
        ],
        "ferocity": ["Peaceful", "Gentle", "", "Ferocious", "Tyrannical"],
    }

    if not st.button("実行", key="ga"):
        return

    best = _run_genetic_algorithm(populations, num_generation)

    st.write("画像生成の実行")
    progress_bar = st.progress(0)
    pipe, device = _load_pipeline(progress_bar)

    last_generation = num_generation - 1
    words = (
        get_word_for_image_generate(word_dict, best, last_generation)
        + HABITAT_WORDS
    )
    # Seed the diffusion noise as well, so the same GA result gives the same image.
    generator = torch.Generator(device=device).manual_seed(SEED)
    image = pipe(words, generator=generator)["images"][0]

    _show_captioned_image(image, f"{last_generation + 1}th\n{words}.")
    progress_bar.progress(100)


if __name__ == "__main__":
    main()