baxin commited on
Commit
8836e4a
·
verified ·
1 Parent(s): 3927af4

Create app.py

Browse files
Files changed (1) hide show
  1. src/app.py +129 -0
src/app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import streamlit as st
3
+ from cerebras.cloud.sdk import Cerebras
4
+ import openai
5
+ import os
6
+ from dotenv import load_dotenv
7
+
8
+ # --- Assuming config.py and utils.py exist ---
9
+ import config
10
+ import utils
11
+
12
+ # --- BASE_PROMPT のインポート ---
13
+ try:
14
+ from prompt import BASE_PROMPT
15
+ except ImportError:
16
+ st.error(
17
+ "Error: 'prompt.py' not found or 'BASE_PROMPT' is not defined within it.")
18
+ st.stop()
19
+
20
+ # --- Import column rendering functions ---
21
+ from chat_column import render_chat_column
22
+
23
+ # --- 環境変数読み込み ---
24
+ load_dotenv()
25
+
26
+ # --- Streamlit ページ設定 ---
27
+ st.set_page_config(page_icon="🤖x🎬", layout="wide",
28
+ page_title="Veo3 JSON Creator")
29
+
30
+ # --- UI 表示 ---
31
+ utils.display_icon("🤖x🎬")
32
+ st.title("Veo3 JSON Creator")
33
+ st.subheader("Generate json for Veo3",
34
+ divider="blue", anchor=False)
35
+
36
+ # --- APIキーの処理 ---
37
+ # (API Key logic remains the same)
38
+ api_key_from_env = os.getenv("CEREBRAS_API_KEY")
39
+ show_api_key_input = not bool(api_key_from_env)
40
+ cerebras_api_key = None
41
+
42
+ # --- サイドバーの設定 ---
43
+ # (Sidebar logic remains the same)
44
+ with st.sidebar:
45
+ st.title("Settings")
46
+ # Cerebras Key Input
47
+ if show_api_key_input:
48
+ st.markdown("### :red[Enter your Cerebras API Key below]")
49
+ api_key_input = st.text_input(
50
+ "Cerebras API Key:", type="password", key="cerebras_api_key_input_field")
51
+ if api_key_input:
52
+ cerebras_api_key = api_key_input
53
+ else:
54
+ cerebras_api_key = api_key_from_env
55
+ st.success("✓ Cerebras API Key loaded from environment")
56
+
57
+ # Model selection
58
+ model_option = st.selectbox(
59
+ "Choose a LLM model:",
60
+ options=list(config.MODELS.keys()),
61
+ format_func=lambda x: config.MODELS[x]["name"],
62
+ key="model_select"
63
+ )
64
+ # Max tokens slider
65
+ max_tokens_range = config.MODELS[model_option]["tokens"]
66
+ default_tokens = min(2048, max_tokens_range)
67
+ max_tokens = st.slider(
68
+ "Max Tokens (LLM):",
69
+ min_value=512,
70
+ max_value=max_tokens_range,
71
+ value=default_tokens,
72
+ step=512,
73
+ help="Max tokens for the LLM's text prompt response."
74
+ )
75
+
76
+
77
+ # --- メインアプリケーションロジック ---
78
+ # Re-check Cerebras API key
79
+ if not cerebras_api_key and show_api_key_input and 'cerebras_api_key_input_field' in st.session_state and st.session_state.cerebras_api_key_input_field:
80
+ cerebras_api_key = st.session_state.cerebras_api_key_input_field
81
+
82
+ if not cerebras_api_key:
83
+ st.error("Cerebras API Key is required. Please enter it in the sidebar or set the CEREBRAS_API_KEY environment variable.", icon="🚨")
84
+ st.stop()
85
+
86
+ # APIクライアント初期化
87
+ # (Client initialization remains the same)
88
+ llm_client = None
89
+ image_client = None
90
+ try:
91
+ llm_client = Cerebras(api_key=cerebras_api_key)
92
+ except Exception as e:
93
+ st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
94
+ st.stop()
95
+
96
+
97
+ # --- Session State Initialization ---
98
+ # Initialize state variables if they don't exist
99
+ if "messages" not in st.session_state:
100
+ st.session_state.messages = []
101
+ if "current_image_prompt_text" not in st.session_state:
102
+ st.session_state.current_image_prompt_text = ""
103
+ # --- MODIFICATION START ---
104
+ # Replace single image state with a list to store multiple images and their prompts
105
+ if "generated_images_list" not in st.session_state:
106
+ st.session_state.generated_images_list = [] # Initialize as empty list
107
+ # Remove old state variable if it exists (optional cleanup)
108
+ if "latest_generated_image" in st.session_state:
109
+ del st.session_state["latest_generated_image"]
110
+ # --- MODIFICATION END ---
111
+ if "selected_model" not in st.session_state:
112
+ st.session_state.selected_model = None
113
+
114
+
115
+ # --- Track selected model, but do not clear chat or image state on model change ---
116
+ if st.session_state.selected_model != model_option:
117
+ st.session_state.selected_model = model_option
118
+ # Optionally rerun to update UI, but do not clear messages or images
119
+ st.rerun()
120
+
121
+ # --- Define Main Columns ---
122
+ chat_col, image_col = st.columns([2, 1])
123
+
124
+ # --- Render Columns using imported functions ---
125
+ with chat_col:
126
+ render_chat_column(st, llm_client, model_option, max_tokens, BASE_PROMPT)
127
+
128
+ # --- Footer ---
129
+ st.markdown('<div style="text-align: center; margin-top: 2em; color: #888; font-size: 1.1em;">made with 💙 by baxin</div>', unsafe_allow_html=True)