fudii0921 commited on
Commit
96a52ff
·
verified ·
1 Parent(s): 8a42c81

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +247 -0
app.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from pathlib import Path
4
+ from typing import Annotated
5
+ from autogen import AssistantAgent, UserProxyAgent
6
+ from autogen.coding import LocalCommandLineCodeExecutor
7
+ import gradio as gr
8
+ from autogen import ConversableAgent
9
+ from autogen import register_function
10
+ import mysql.connector
11
+ import random
12
+ import requests
13
+ from groq import Groq
14
+ from dotenv import load_dotenv
15
+
16
+ tool_resp = ""
17
+
18
+ js = """
19
+ function createGradioAnimation() {
20
+ var container = document.createElement('div');
21
+ container.id = 'gradio-animation';
22
+ container.style.fontSize = '2em';
23
+ container.style.fontWeight = 'bold';
24
+ container.style.textAlign = 'center';
25
+ container.style.marginBottom = '20px';
26
+
27
+ var text = '部門収益分析';
28
+ for (var i = 0; i < text.length; i++) {
29
+ (function(i){
30
+ setTimeout(function(){
31
+ var letter = document.createElement('span');
32
+ var randomColor = "#" + Math.floor(Math.random() * 16777215).toString(16);
33
+ letter.style.color = randomColor;
34
+ letter.style.opacity = '0';
35
+ letter.style.transition = 'opacity 0.5s';
36
+ letter.innerText = text[i];
37
+
38
+ container.appendChild(letter);
39
+
40
+ setTimeout(function() {
41
+ letter.style.opacity = '1';
42
+ }, 50);
43
+
44
+ // Blink the text 3 times
45
+ for (var j = 0; j < 3; j++) {
46
+ setTimeout(function() {
47
+ letter.style.opacity = '0';
48
+ }, 500 + j * 1000);
49
+ setTimeout(function() {
50
+ letter.style.opacity = '1';
51
+ }, 1000 + j * 1000);
52
+ }
53
+ }, i * 250);
54
+ })(i);
55
+ }
56
+
57
+ var gradioContainer = document.querySelector('.gradio-container');
58
+ gradioContainer.insertBefore(container, gradioContainer.firstChild);
59
+
60
+ return 'Animation created';
61
+ }
62
+ """
63
+
64
+ load_dotenv(verbose=True)
65
+
66
+ conn = mysql.connector.connect(
67
+ host=os.environ.get("HOST"),
68
+ user=os.environ.get("USER_NAME"),
69
+ password=os.environ.get("PASSWORD"),
70
+ port=os.environ.get("PORT"),
71
+ database=os.environ.get("DB"),
72
+ ssl_disabled=True
73
+ )
74
+
75
+ cursor = conn.cursor(dictionary=True)
76
+
77
+ def get_rounrobin():
78
+ select_one_data_query = "select api from agentic_apis_count order by counts ASC"
79
+ cursor.execute(select_one_data_query)
80
+ result = cursor.fetchall()
81
+ first_api = result[0]['api']
82
+ return first_api
83
+
84
+ # MySQLに接続
85
+ def get_api_keys():
86
+ token = get_rounrobin()
87
+ os.environ["GROQ_API_KEY"] = token
88
+
89
+ return token
90
+
91
+ # Configure Groq
92
+ config_list = [{
93
+ "model": "llama-3.3-70b-versatile",
94
+ "api_key": os.environ["GROQ_API_KEY"],
95
+ "api_type": "groq"
96
+ }]
97
+
98
+
99
+ # Create a directory to store code files from code executor
100
+ work_dir = Path("coding")
101
+ work_dir.mkdir(exist_ok=True)
102
+ code_executor = LocalCommandLineCodeExecutor(work_dir=work_dir)
103
+
104
+ # Define revenue tool
105
+ #def get_current_revenue(location, unit="yen"):
106
+ def get_current_revenue(location):
107
+ """Get the revenue for some location"""
108
+ data = requests.get('https://www.ryhintl.com/dbjson/getjson?sqlcmd=select `title` as country,`snippet` as revenue from cohere_documents_auto')
109
+ # 元のデータ
110
+ data = json.loads(data.content)
111
+
112
+ # 指定された形式に変換
113
+ revenue_data = {item["country"]: {"revenue": item["revenue"]} for item in data}
114
+ #print("revenue data:",revenue_data)
115
+ tmp = json.dumps({
116
+ "location": location.title(),
117
+ "revenue": revenue_data[location]["revenue"],
118
+ "unit": ""
119
+ #"unit": unit
120
+ })
121
+ #print("tmp:",tmp)
122
+
123
+ return json.dumps({
124
+ "location": location.title(),
125
+ "revenue": revenue_data[location]["revenue"],
126
+ "unit": ""
127
+ })
128
+
129
+
130
+ #return json.dumps({"location": location, "revenue": "unknown"})
131
+
132
+ # Create an AI assistant that uses the kpi tool
133
+ assistant = AssistantAgent(
134
+ #assistant = ConversableAgent(
135
+ name="groq_assistant",
136
+ system_message="""あなたは、次のことができる役に立つAIアシスタントです。
137
+ - 情報検索ツールを使用する
138
+ - 結果を分析して自然言語のみで説明する""",
139
+ llm_config={"config_list": config_list}
140
+ )
141
+
142
+ # Create a user proxy agent that only handles code execution
143
+ user_proxy = UserProxyAgent(
144
+ #user_proxy = ConversableAgent(
145
+ name="user_proxy",
146
+ human_input_mode="NEVER",
147
+ code_execution_config={"work_dir":"coding", "use_docker":False},
148
+ max_consecutive_auto_reply=2,
149
+ #llm_config={"config_list": config_list}
150
+ )
151
+
152
+ '''user_proxy.register_function(
153
+ function_map={
154
+ "get_current_revenue": get_current_revenue
155
+ }
156
+ )'''
157
+
158
+
159
+
160
+
161
+ # Register weather tool with the assistant
162
+ @user_proxy.register_for_execution()
163
+ @assistant.register_for_llm(description="snippetの内容")
164
+ #@user_proxy.register_for_llm(description="Weather forecast for cities.")
165
+ def revenue_analysis(
166
+ location: Annotated[str, "title"]
167
+ #unit: Annotated[str, "Revenue unit (dollar/yen)"] = "yen"
168
+ ) -> str:
169
+ #revenue_details = get_current_revenue(location=location, unit=unit)
170
+ revenue_details = get_current_revenue(location=location)
171
+ revenues = json.loads(revenue_details)
172
+ #print("resp:",f"{revenues['location']}の内容は{revenues['revenue']}")
173
+ global tool_resp
174
+ tool_resp = tool_resp + f"\n\n{location}\n{revenues['location']}の内容は{revenues['revenue']}"
175
+
176
+ return f"{revenues['location']}の内容は{revenues['revenue']}"
177
+
178
+ def get_revenue_and_plot(div1, div2, div3):
179
+ get_api_keys()
180
+
181
+ # Start the conversation
182
+ resp = user_proxy.initiate_chat(
183
+ assistant,
184
+ message=f"""3つのことをやってみましょう:
185
+ 1. {div1}、{div2}、{div3}の内容をtoolを利用して抽出します。
186
+ 2. toolを利用して抽出された内容を詳しく分析します。
187
+ 3. 日本語で説明してください。
188
+ """
189
+ )
190
+
191
+ total_tokens = resp.cost['usage_including_cached_inference']['llama-3.3-70b-versatile']['total_tokens']
192
+
193
+ #update counts
194
+ select_one_data_query = "SELECT counts FROM agentic_apis_count where api = '"+os.environ["GROQ_API_KEY"]+"'"
195
+ cursor.execute(select_one_data_query)
196
+ ext_key = cursor.fetchall()
197
+ key = [item['counts'] for item in ext_key]
198
+ calculated = key[0]+total_tokens/10000
199
+
200
+ update_counts_query = "UPDATE agentic_apis_count SET counts = "+str(calculated)+" WHERE api = '"+os.environ["GROQ_API_KEY"]+"'"
201
+
202
+ cursor.execute(update_counts_query)
203
+ conn.commit()
204
+
205
+ groq_assistant_contents = [entry['content'] for entry in resp.chat_history if entry['role'] == 'user' and entry['name'] == 'groq_assistant']
206
+
207
+ global tool_resp
208
+ client = Groq(api_key=os.environ["GROQ_API_KEY"])
209
+ system_prompt = {
210
+ "role": "system",
211
+ "content": "You are a helpful assistant, answer questions concisely."
212
+ }
213
+
214
+ # Set the user prompt
215
+ user_input = tool_resp+"を要約してください。"
216
+ user_prompt = {
217
+ "role": "user", "content": user_input
218
+ }
219
+
220
+ # Initialize the chat history
221
+ chat_history = [system_prompt, user_prompt]
222
+
223
+ response = client.chat.completions.create(
224
+ model="llama-3.3-70b-versatile",
225
+ messages=chat_history,
226
+ max_tokens=1024,
227
+ temperature=0)
228
+
229
+ kekka = response.choices[0].message.content
230
+
231
+ usages = "使用トークン数: "+str(total_tokens)+ " \n"+kekka
232
+ return groq_assistant_contents,usages
233
+
234
+ # Create Gradio interface
235
+ iface = gr.Interface(
236
+ js=js,
237
+ fn=get_revenue_and_plot,
238
+ inputs=[gr.Dropdown(choices=["上期経営会議議事録", "セキュリティー会議資料", "コーポレートガバナンス会議資料"], label="上期経営会議議事録",value="上期経営会議議事録"), gr.Dropdown(choices=["上期経営会議議事録", "セキュリティー会議資料", "コーポレートガバナンス会議資料"], label="セキュリティー会議資料",value="セキュリティー会議資料"), gr.Dropdown(choices=["上期経営会議議事録", "セキュリティー会議資料", "コーポレートガバナンス会議資料"], label="コーポレートガバナンス会議資料",value="コーポレートガバナンス会議資料")],
239
+ outputs=[gr.Textbox(label="結果"),gr.Textbox(label="Usageデータとツール結果")],
240
+ title="資料の分析",
241
+ description="プロンプトを入力してデータを取得し、内容を分析します。",
242
+ submit_btn="実行",
243
+ clear_btn="クリア",
244
+ flagging_mode="never"
245
+ )
246
+
247
+ iface.launch()