mgokg commited on
Commit
d0e9a20
·
verified ·
1 Parent(s): 6d108cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -95
app.py CHANGED
@@ -1,19 +1,41 @@
1
- import os
2
  import gradio as gr
 
3
  from google import genai
4
  from google.genai import types
5
  from gradio_client import Client
6
 
7
- # 1. Initialize the client for the external DB Timetable App
8
- # We use the Hugging Face Space ID provided in your documentation
9
- db_client = Client("mgokg/db-timetable-api")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def get_train_connection(dep: str, dest: str):
12
  """
13
  Fetches the train timetable between two cities using the external API.
14
  """
 
 
 
 
 
 
 
15
  try:
16
- # Calling the specific endpoint mentioned in the MCP docs: db_timetable_api_ui_wrapper
17
  result = db_client.predict(
18
  dep=dep,
19
  dest=dest,
@@ -23,8 +45,7 @@ def get_train_connection(dep: str, dest: str):
23
  except Exception as e:
24
  return f"Error fetching timetable: {str(e)}"
25
 
26
- # 2. Define the tool for Gemini
27
- # This tells the model how to use the Python function above
28
  train_tool = types.FunctionDeclaration(
29
  name="get_train_connection",
30
  description="Find train connections and timetables between a start location (dep) and a destination (dest).",
@@ -38,130 +59,117 @@ train_tool = types.FunctionDeclaration(
38
  )
39
  )
40
 
41
- # Map the string name to the actual python function
42
  tools_map = {
43
  "get_train_connection": get_train_connection
44
  }
45
 
46
- def generate(input_text, history):
47
- # Initialize Gemini Client
 
 
 
 
 
48
  try:
49
  client = genai.Client(
50
  api_key=os.environ.get("GEMINI_API_KEY"),
51
  )
52
  except Exception as e:
53
- yield f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", history
54
  return
55
 
56
- model = "gemini-2.0-flash-exp" # Or "gemini-2.0-flash" depending on availability
57
 
58
- # Prepare the conversation history for context
59
- # (Optional: You can add previous history here if you want multi-turn chat)
60
- contents = [
61
- types.Content(
62
- role="user",
63
- parts=[types.Part.from_text(text=input_text)],
64
- ),
65
- ]
66
-
67
- # 3. Configure tools (Google Search + Our Custom DB Tool)
68
  tools = [
69
- types.Tool(google_search=types.GoogleSearch()),
70
- types.Tool(function_declarations=[train_tool]),
 
 
71
  ]
72
 
73
  generate_content_config = types.GenerateContentConfig(
74
  temperature=0.4,
75
  tools=tools,
76
- # Automatic function calling allows the SDK to handle the loop,
77
- # but for granular control in Gradio, we often handle it manually below
78
- # or rely on the model to return a function call part.
79
  )
80
 
 
 
 
 
 
 
 
81
  response_text = ""
82
 
83
- # First API Call: Ask the model what to do
84
  try:
 
85
  response = client.models.generate_content(
86
  model=model,
87
  contents=contents,
88
  config=generate_content_config,
89
  )
90
- except Exception as e:
91
- yield f"Error during generation: {e}", history
92
- return
93
 
94
- # 4. Check if the model wants to call a function
95
- # We look at the first candidate's first part
96
- if response.candidates and response.candidates[0].content.parts:
97
- first_part = response.candidates[0].content.parts[0]
98
-
99
- # If it's a function call
100
- if first_part.function_call:
101
- fn_name = first_part.function_call.name
102
- fn_args = first_part.function_call.args
103
 
104
- # Execute the tool
105
- if fn_name in tools_map:
106
- status_msg = f"🔄 Checking trains from {fn_args.get('dep')} to {fn_args.get('dest')}..."
107
- yield status_msg, history
108
-
109
- api_result = tools_map[fn_name](**fn_args)
110
 
111
- # Send the result back to Gemini
112
- # We append the model's function call and our function response to history
113
- contents.append(response.candidates[0].content)
114
- contents.append(
115
- types.Content(
116
- role="tool",
117
- parts=[
118
- types.Part.from_function_response(
119
- name=fn_name,
120
- response={"result": api_result}
121
- )
122
- ]
 
 
 
 
123
  )
124
- )
125
-
126
- # Second API Call: Get the final natural language answer
127
- stream = client.models.generate_content_stream(
128
- model=model,
129
- contents=contents,
130
- config=generate_content_config # Keep tools enabled just in case
131
- )
132
-
133
- final_text = ""
134
- for chunk in stream:
135
- if chunk.text:
136
- final_text += chunk.text
137
- yield final_text, history
138
- return
 
 
 
 
139
 
140
- # If no function call, just return the text (e.g., normal chat or Google Search result)
141
- if response.text:
142
- yield response.text, history
143
 
144
  if __name__ == '__main__':
145
  with gr.Blocks() as demo:
146
- gr.Markdown("# Gemini 2.0 Flash + DB Timetable Tool")
 
 
 
147
 
148
- chatbot = gr.Chatbot(label="Conversation", height=400)
149
- msg = gr.Textbox(lines=1, label="Ask about trains (e.g., 'Train from Berlin to Munich')", placeholder="Enter message here...")
150
- clear = gr.Button("Clear")
151
-
152
- def user(user_message, history):
153
- return "", history + [[user_message, None]]
154
-
155
- def bot(history):
156
- user_message = history[-1][0]
157
- # Call generate and update the last message in history
158
- for partial_response, _ in generate(user_message, history):
159
- history[-1][1] = partial_response
160
- yield history
161
-
162
- msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
163
- bot, chatbot, chatbot
164
  )
165
- clear.click(lambda: None, None, chatbot, queue=False)
166
-
167
  demo.launch(show_error=True)
 
 
1
  import gradio as gr
2
+ import os
3
  from google import genai
4
  from google.genai import types
5
  from gradio_client import Client
6
 
7
# --- 1. Robust Client Initialization ---
# The handle is module-global so the tool function can reuse one connection.
# Failures are logged, not raised, so the app still starts without the API.
db_client = None

def init_db_client():
    """Open a connection to the external DB Timetable Space.

    Stores the resulting ``gradio_client.Client`` in the module-global
    ``db_client``; on failure leaves it as ``None`` and prints a warning.
    """
    global db_client
    try:
        # Use the DIRECT URL to avoid DNS resolution issues with the Hub API
        print("Connecting to DB Timetable API...")
        db_client = Client("https://mgokg-db-timetable-api.hf.space/")
    except Exception as err:
        print(f"Warning: Could not connect to DB Timetable API: {err}")
    else:
        print("Successfully connected.")

# Attempt connection on startup
init_db_client()
23
+
24
+ # --- 2. Tool Definition ---
25
 
26
  def get_train_connection(dep: str, dest: str):
27
  """
28
  Fetches the train timetable between two cities using the external API.
29
  """
30
+ global db_client
31
+ # If client failed to load initially, try one more time
32
+ if db_client is None:
33
+ init_db_client()
34
+ if db_client is None:
35
+ return "Error: The train database is currently unreachable. Please check your network connection."
36
+
37
  try:
38
+ # Calling the specific endpoint mentioned in the MCP docs
39
  result = db_client.predict(
40
  dep=dep,
41
  dest=dest,
 
45
  except Exception as e:
46
  return f"Error fetching timetable: {str(e)}"
47
 
48
+ # Define the tool schema for Gemini
 
49
  train_tool = types.FunctionDeclaration(
50
  name="get_train_connection",
51
  description="Find train connections and timetables between a start location (dep) and a destination (dest).",
 
59
  )
60
  )
61
 
62
# Dispatch table: Gemini tool name -> local Python implementation.
tools_map = dict(
    get_train_connection=get_train_connection,
)
66
 
67
# --- 3. Generation Logic ---

def generate(input_text):
    """Answer *input_text* with Gemini, resolving the train-timetable tool.

    Generator used as a Gradio event handler. Yields ``(markdown, textbox)``
    tuples: the first element is the (possibly partial) answer text, the
    second is the value written back to the input textbox — ``""`` on
    success to clear it, or the original ``input_text`` on error so the
    user can retry without retyping.
    """
    if not input_text:
        yield "", ""
        return

    try:
        client = genai.Client(
            api_key=os.environ.get("GEMINI_API_KEY"),
        )
    except Exception as e:
        yield f"Error initializing client: {e}. Make sure GEMINI_API_KEY is set.", input_text
        return

    model = "gemini-2.0-flash-exp"  # Ensure you use a model version that supports tools

    # Configure tools (Google Search + our custom DB tool in one Tool object;
    # Gemini 2.0 supports multi-tool use combining search with function calls)
    tools = [
        types.Tool(
            google_search=types.GoogleSearch(),
            function_declarations=[train_tool]
        )
    ]

    generate_content_config = types.GenerateContentConfig(
        temperature=0.4,
        tools=tools,
    )

    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=input_text)],
        ),
    ]

    response_text = ""

    try:
        # First API call: let the model decide whether to answer or call a tool
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )

        # Check if Gemini wants to call a function (first part of first candidate)
        if response.candidates and response.candidates[0].content.parts:
            part = response.candidates[0].content.parts[0]

            if part.function_call:
                fn_name = part.function_call.name
                fn_args = part.function_call.args

                if fn_name in tools_map:
                    # 1. Execute the Python function (calls the external Gradio app)
                    api_result = tools_map[fn_name](**fn_args)

                    # 2. Feed the result back to Gemini: append the model's own
                    # function-call turn, then our tool response, to the history
                    contents.append(response.candidates[0].content)
                    contents.append(
                        types.Content(
                            role="tool",
                            parts=[
                                types.Part.from_function_response(
                                    name=fn_name,
                                    response={"result": api_result}
                                )
                            ]
                        )
                    )

                    # 3. Stream the final natural-language answer
                    stream = client.models.generate_content_stream(
                        model=model,
                        contents=contents,
                        config=generate_content_config
                    )

                    for chunk in stream:
                        # BUG FIX: chunk.text can be None for chunks carrying
                        # non-text parts; concatenating None raises TypeError
                        # mid-stream. Guard before appending (as the previous
                        # revision of this file did).
                        if chunk.text:
                            response_text += chunk.text
                            yield response_text, ""
                    return

        # No function call: return the plain text (normal chat / Google Search)
        if response.text:
            yield response.text, ""

    except Exception as e:
        yield f"Error during generation: {e}", input_text
159
 
160
# --- 4. UI Setup ---

if __name__ == '__main__':
    with gr.Blocks() as demo:
        # Title is display-only; no need to keep a reference to it.
        gr.Markdown("# Gemini 2.0 Flash + DB Timetable Tool")
        output_textbox = gr.Markdown()
        input_textbox = gr.Textbox(lines=3, label="", placeholder="Ask for a train connection (e.g., 'Train from Berlin to Frankfurt')...")
        submit_button = gr.Button("Send")

        # generate yields (answer_markdown, textbox_value) pairs:
        # the answer streams into the Markdown output while the textbox is
        # cleared on success / restored on error.
        submit_button.click(
            fn=generate,
            inputs=input_textbox,
            outputs=[output_textbox, input_textbox]
        )
        # Also submit on Enter in the textbox, not just the button click.
        input_textbox.submit(
            fn=generate,
            inputs=input_textbox,
            outputs=[output_textbox, input_textbox]
        )

    demo.launch(show_error=True)