Suchinthana commited on
Commit
ecb6154
·
1 Parent(s): 922274d

Logging, prompt, pipe update

Browse files
Files changed (1) hide show
  1. app.py +406 -181
app.py CHANGED
@@ -12,23 +12,43 @@ from diffusers import StableDiffusionInpaintPipeline
12
  import spaces
13
  import logging
14
  import math
15
- from typing import List, Union
16
 
17
  # Set up logging
18
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
19
  logger = logging.getLogger(__name__)
20
 
 
 
21
  # Initialize APIs
22
- openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
23
- geolocator = Nominatim(user_agent="geoapi")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Function to fetch coordinates
26
  @spaces.GPU
27
  def get_geo_coordinates(location_name):
 
28
  try:
29
- location = geolocator.geocode(location_name)
30
  if location:
 
31
  return [location.longitude, location.latitude]
 
32
  return None
33
  except Exception as e:
34
  logger.error(f"Error fetching coordinates for {location_name}: {e}")
@@ -37,12 +57,14 @@ def get_geo_coordinates(location_name):
37
  # Function to process OpenAI chat response
38
  @spaces.GPU
39
  def process_openai_response(query):
40
- response = openai_client.chat.completions.create(
41
- model="gpt-4o-mini",
42
- messages=[
43
- {
44
- "role": "system",
45
- "content": """
 
 
46
  You are an assistant that generates structured JSON output for geographical queries with city names. Your task is to generate a JSON object containing information about geographical features and their representation based on the user's query. Follow these rules:
47
 
48
  1. The JSON should always have the following structure:
@@ -93,207 +115,410 @@ You are an assistant that generates structured JSON output for geographical quer
93
 
94
  Generate similar JSON for the following query:
95
  """
96
- },
97
- {
98
- "role": "user",
99
- "content": query
100
- }
101
- ],
102
- temperature=1,
103
- max_tokens=2048,
104
- top_p=1,
105
- frequency_penalty=0,
106
- presence_penalty=0,
107
- response_format={"type": "json_object"}
108
- )
109
- return json.loads(response.choices[0].message.content)
 
 
 
 
 
 
 
 
110
 
111
  # Generate GeoJSON from OpenAI response
112
  @spaces.GPU
113
- def generate_geojson(response):
114
- logger.info(f"OpenAI response: {response}")
115
- feature_type = response['output']['feature_representation']['type']
116
- city_names = response['output']['feature_representation']['cities']
117
- properties = response['output']['feature_representation']['properties']
118
-
119
- coordinates = []
120
 
121
- # Fetch coordinates for cities
122
- for city in city_names:
123
- try:
124
  coord = get_geo_coordinates(city)
125
  if coord:
126
  coordinates.append(coord)
127
  else:
128
- logger.warning(f"Coordinates not found for city: {city}")
129
- except Exception as e:
130
- logger.error(f"Error fetching coordinates for {city}: {e}")
131
-
132
- if feature_type == "Polygon":
133
- if len(coordinates) < 3:
134
- raise ValueError("Polygon requires at least 3 coordinates.")
135
- # Close the polygon by appending the first point at the end
136
- coordinates.append(coordinates[0])
137
- coordinates = [coordinates] # Nest coordinates for Polygon
138
-
139
- # Create the GeoJSON object
140
- geojson_data = {
141
- "type": "FeatureCollection",
142
- "features": [
143
- {
144
- "type": "Feature",
145
- "properties": properties,
146
- "geometry": {
147
- "type": feature_type,
148
- "coordinates": coordinates,
149
- },
150
- }
151
- ],
152
- }
153
-
154
- return geojson_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  # Sort coordinates for a simple polygon (Reduce intersection points)
157
  def sort_coordinates_for_simple_polygon(geojson):
158
- # Extract coordinates from the GeoJSON
159
- coordinates = geojson['features'][0]['geometry']['coordinates'][0]
160
-
161
- # Remove the last point if it duplicates the first (GeoJSON convention for polygons)
162
- if coordinates[0] == coordinates[-1]:
163
- coordinates = coordinates[:-1]
164
-
165
- # Calculate the centroid of the points
166
- centroid_x = sum(point[0] for point in coordinates) / len(coordinates)
167
- centroid_y = sum(point[1] for point in coordinates) / len(coordinates)
168
-
169
- # Define a function to calculate the angle relative to the centroid
170
- def angle_from_centroid(point):
171
- dx = point[0] - centroid_x
172
- dy = point[1] - centroid_y
173
- return math.atan2(dy, dx)
174
-
175
- # Sort points by their angle from the centroid
176
- sorted_coordinates = sorted(coordinates, key=angle_from_centroid)
177
-
178
- # Close the polygon by appending the first point to the end
179
- sorted_coordinates.append(sorted_coordinates[0])
180
-
181
- # Update the GeoJSON with sorted coordinates
182
- geojson['features'][0]['geometry']['coordinates'][0] = sorted_coordinates
183
-
184
- return geojson
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  # Generate static map image
187
  @spaces.GPU
188
  def generate_static_map(geojson_data, invisible=False):
189
- m = StaticMap(600, 600)
190
- logger.info(f"GeoJSON data: {geojson_data}")
191
-
192
- for feature in geojson_data["features"]:
193
- geom_type = feature["geometry"]["type"]
194
- coords = feature["geometry"]["coordinates"]
195
-
196
- if geom_type == "Point":
197
- m.add_marker(CircleMarker((coords[0][0], coords[0][1]), '#1C00ff00' if invisible else '#42445A85', 100))
198
- elif geom_type in ["MultiPoint", "LineString"]:
199
- for coord in coords:
200
- m.add_marker(CircleMarker((coord[0], coord[1]), '#1C00ff00' if invisible else '#42445A85', 100))
201
- elif geom_type in ["Polygon", "MultiPolygon"]:
202
- for polygon in coords:
203
- m.add_polygon(Polygon([(c[0], c[1]) for c in polygon], '#1C00ff00' if invisible else '#42445A85', 3))
204
-
205
- return m.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  # ControlNet pipeline setup
208
- # controlnet = ControlNetModel.from_pretrained("stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16)
209
- # pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
210
- # "stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
211
- # )
212
- # pipeline.to('cuda')
213
-
214
- pipeline = StableDiffusionInpaintPipeline.from_pretrained(
215
- "stabilityai/stable-diffusion-2-inpainting",
216
- torch_dtype=torch.float16,
217
- )
218
- pipeline.to("cuda")
219
-
 
 
 
 
 
 
220
  @spaces.GPU
221
- def make_inpaint_condition(init_image, mask_image):
222
- init_image = np.array(init_image.convert("RGB")).astype(np.float32) / 255.0
223
- mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
- assert init_image.shape[0:1] == mask_image.shape[0:1], "image and image_mask must have the same image size"
226
- init_image[mask_image > 0.5] = -1.0 # set as masked pixel
227
- init_image = np.expand_dims(init_image, 0).transpose(0, 3, 1, 2)
228
- init_image = torch.from_numpy(init_image)
229
- return init_image
230
 
231
  @spaces.GPU
232
- def generate_satellite_image(init_image, mask_image, prompt):
233
- control_image = make_inpaint_condition(init_image, mask_image)
234
- result = pipeline(
235
- prompt=prompt,
236
- image=init_image,
237
- mask_image=mask_image,
238
- control_image=control_image,
239
- strength=0.47,
240
- guidance_scale=95,
241
- num_inference_steps=250
242
- )
243
- return result.images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
  # Gradio UI
246
  @spaces.GPU
247
- def handle_query(query):
248
- response = process_openai_response(query)
249
- geojson_data = generate_geojson(response)
250
-
251
- if geojson_data["features"][0]["geometry"]["type"] == 'Polygon':
252
- geojson_data_coords = sort_coordinates_for_simple_polygon(geojson_data)
253
- map_image = generate_static_map(geojson_data_coords)
254
- else:
255
- map_image = generate_static_map(geojson_data)
256
- empty_map_image = generate_static_map(geojson_data, invisible=True)
257
-
258
- difference = np.abs(np.array(map_image.convert("RGB")) - np.array(empty_map_image.convert("RGB")))
259
- threshold = 10
260
- mask = (np.sum(difference, axis=-1) > threshold).astype(np.uint8) * 255
261
-
262
- mask_image = Image.fromarray(mask, mode="L")
263
- satellite_image = generate_satellite_image(
264
- empty_map_image, mask_image, response['output']['feature_representation']['properties']['description']
265
- )
266
-
267
- #return map_image, satellite_image, empty_map_image, mask_image, response
268
- return map_image, satellite_image, empty_map_image, mask_image, response['output']['feature_representation']['properties']['description']
269
-
270
- def update_query(selected_query):
271
- return selected_query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  query_options = [
274
  "Area covering south asian subcontinent",
275
- "Mark a triangular area using New York, Boston, and Texas",
276
  "Mark cities in India",
277
  "Show me Lotus Tower in a Map",
278
  "Mark the area of west germany",
279
  "Mark the area of the Amazon rainforest",
280
  "Mark the area of the Sahara desert"
281
- ]
282
-
283
- with gr.Blocks() as demo:
284
- with gr.Row():
285
- selected_query = gr.Dropdown(label="Select Query", choices=query_options, value=query_options[-1])
286
- query_input = gr.Textbox(label="Enter Query", value=query_options[-1])
287
- selected_query.change(update_query, inputs=selected_query, outputs=query_input)
288
- submit_btn = gr.Button("Submit")
289
- with gr.Row():
290
- map_output = gr.Image(label="Map Visualization")
291
- satellite_output = gr.Image(label="Generated Map Image")
292
- with gr.Row():
293
- empty_map_output = gr.Image(label="Empty Visualization")
294
- mask_output = gr.Image(label="Mask")
295
- image_prompt = gr.Textbox(label="Image Prompt Used")
296
- submit_btn.click(handle_query, inputs=[query_input], outputs=[map_output, satellite_output, empty_map_output, mask_output, image_prompt])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
  if __name__ == "__main__":
299
- demo.launch()
 
 
 
 
 
 
 
12
  import spaces
13
  import logging
14
  import math
15
+ from typing import List, Union # Make sure these are actually used or remove them
16
 
17
  # Set up logging
18
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
19
  logger = logging.getLogger(__name__)
20
 
21
+ logger.info("Script starting. Initializing APIs and models.")
22
+
23
  # Initialize APIs
24
+ try:
25
+ openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
26
+ logger.info("OpenAI client initialized.")
27
+ except KeyError:
28
+ logger.error("OPENAI_API_KEY environment variable not set!")
29
+ # Handle this critical error, perhaps exit or raise
30
+ raise
31
+ except Exception as e:
32
+ logger.error(f"Error initializing OpenAI client: {e}")
33
+ raise
34
+
35
+ try:
36
+ geolocator = Nominatim(user_agent="geoapi_visualizemap") # More specific user agent
37
+ logger.info("Geolocator initialized.")
38
+ except Exception as e:
39
+ logger.error(f"Error initializing Geolocator: {e}")
40
+ raise
41
 
42
  # Function to fetch coordinates
43
  @spaces.GPU
44
  def get_geo_coordinates(location_name):
45
+ logger.info(f"Attempting to fetch coordinates for: {location_name}")
46
  try:
47
+ location = geolocator.geocode(location_name, timeout=10) # Added timeout
48
  if location:
49
+ logger.info(f"Coordinates found for {location_name}: {[location.longitude, location.latitude]}")
50
  return [location.longitude, location.latitude]
51
+ logger.warning(f"No location data returned for {location_name}")
52
  return None
53
  except Exception as e:
54
  logger.error(f"Error fetching coordinates for {location_name}: {e}")
 
57
  # Function to process OpenAI chat response
58
  @spaces.GPU
59
  def process_openai_response(query):
60
+ logger.info(f"Processing OpenAI query: {query}")
61
+ try:
62
+ response = openai_client.chat.completions.create(
63
+ model="gpt-4o-mini",
64
+ messages=[
65
+ {
66
+ "role": "system",
67
+ "content": """
68
  You are an assistant that generates structured JSON output for geographical queries with city names. Your task is to generate a JSON object containing information about geographical features and their representation based on the user's query. Follow these rules:
69
 
70
  1. The JSON should always have the following structure:
 
115
 
116
  Generate similar JSON for the following query:
117
  """
118
+ },
119
+ {
120
+ "role": "user",
121
+ "content": query
122
+ }
123
+ ],
124
+ temperature=1,
125
+ max_tokens=2048,
126
+ top_p=1,
127
+ frequency_penalty=0,
128
+ presence_penalty=0,
129
+ response_format={"type": "json_object"}
130
+ )
131
+ content = response.choices[0].message.content
132
+ logger.info(f"Raw OpenAI response content: {content}")
133
+ parsed_response = json.loads(content)
134
+ logger.info(f"Parsed OpenAI response: {json.dumps(parsed_response, indent=2)}")
135
+ return parsed_response
136
+ except Exception as e:
137
+ logger.error(f"Error processing OpenAI response for query '{query}': {e}")
138
+ # Consider returning a default error structure or re-raising
139
+ raise
140
 
141
  # Generate GeoJSON from OpenAI response
142
  @spaces.GPU
143
+ def generate_geojson(response_data): # Renamed to avoid confusion with http response
144
+ logger.info(f"Generating GeoJSON from OpenAI response_data: {json.dumps(response_data, indent=2)}")
145
+ try:
146
+ feature_type = response_data['output']['feature_representation']['type']
147
+ city_names = response_data['output']['feature_representation']['cities']
148
+ properties = response_data['output']['feature_representation']['properties']
149
+ logger.info(f"Feature type: {feature_type}, Cities: {city_names}")
150
 
151
+ coordinates = []
152
+ for city in city_names:
 
153
  coord = get_geo_coordinates(city)
154
  if coord:
155
  coordinates.append(coord)
156
  else:
157
+ logger.warning(f"Coordinates not found for city: {city}. Skipping.")
158
+
159
+ logger.info(f"Collected coordinates: {coordinates}")
160
+
161
+ # Ensure coordinates has the correct structure for each geometry type
162
+ if feature_type == "Point":
163
+ if not coordinates:
164
+ raise ValueError("Point type requires at least one coordinate.")
165
+ # GeoJSON Point expects a single coordinate pair, not a list of pairs
166
+ final_coordinates = coordinates[0] if coordinates else []
167
+ elif feature_type == "MultiPoint":
168
+ final_coordinates = coordinates # List of coordinate pairs
169
+ elif feature_type == "LineString":
170
+ if len(coordinates) < 2:
171
+ raise ValueError("LineString requires at least 2 coordinates.")
172
+ final_coordinates = coordinates # List of coordinate pairs
173
+ elif feature_type == "Polygon":
174
+ if len(coordinates) < 3:
175
+ raise ValueError("Polygon requires at least 3 coordinates.")
176
+ # Close the polygon by appending the first point at the end
177
+ if coordinates[0] != coordinates[-1]: # Check if already closed
178
+ coordinates.append(coordinates[0])
179
+ final_coordinates = [coordinates] # Nest coordinates for Polygon
180
+ else: # MultiLineString, MultiPolygon, GeometryCollection
181
+ logger.warning(f"Unsupported or complex feature_type: {feature_type}. Using raw coordinates.")
182
+ final_coordinates = coordinates # Or handle more specifically
183
+
184
+ geojson_data = {
185
+ "type": "FeatureCollection",
186
+ "features": [
187
+ {
188
+ "type": "Feature",
189
+ "properties": properties,
190
+ "geometry": {
191
+ "type": feature_type,
192
+ "coordinates": final_coordinates,
193
+ },
194
+ }
195
+ ],
196
+ }
197
+ logger.info(f"Generated GeoJSON: {json.dumps(geojson_data, indent=2)}")
198
+ return geojson_data
199
+ except KeyError as e:
200
+ logger.error(f"KeyError while generating GeoJSON: {e}. Response data: {json.dumps(response_data, indent=2)}")
201
+ raise
202
+ except ValueError as e:
203
+ logger.error(f"ValueError while generating GeoJSON: {e}. Coordinates: {coordinates if 'coordinates' in locals() else 'N/A'}")
204
+ raise
205
+ except Exception as e:
206
+ logger.error(f"Unexpected error in generate_geojson: {e}")
207
+ raise
208
 
209
  # Sort coordinates for a simple polygon (Reduce intersection points)
210
  def sort_coordinates_for_simple_polygon(geojson):
211
+ logger.info("Attempting to sort polygon coordinates.")
212
+ try:
213
+ coordinates = geojson['features'][0]['geometry']['coordinates'][0]
214
+ logger.info(f"Original polygon coordinates: {coordinates}")
215
+
216
+ if not coordinates or len(coordinates) < 3:
217
+ logger.warning("Not enough coordinates to sort for a polygon.")
218
+ return geojson
219
+
220
+ # Remove the last point if it duplicates the first (GeoJSON convention for polygons)
221
+ if coordinates[0] == coordinates[-1] and len(coordinates) > 1:
222
+ plot_coordinates = coordinates[:-1]
223
+ else:
224
+ plot_coordinates = coordinates
225
+
226
+ if not plot_coordinates or len(plot_coordinates) < 3: # Check again after potentially removing last point
227
+ logger.warning("Not enough unique coordinates to sort for a polygon after de-duplication.")
228
+ return geojson
229
+
230
+ # Calculate the centroid of the points
231
+ centroid_x = sum(point[0] for point in plot_coordinates) / len(plot_coordinates)
232
+ centroid_y = sum(point[1] for point in plot_coordinates) / len(plot_coordinates)
233
+ logger.info(f"Calculated centroid: ({centroid_x}, {centroid_y})")
234
+
235
+ def angle_from_centroid(point):
236
+ dx = point[0] - centroid_x
237
+ dy = point[1] - centroid_y
238
+ return math.atan2(dy, dx)
239
+
240
+ sorted_plot_coordinates = sorted(plot_coordinates, key=angle_from_centroid)
241
+ sorted_plot_coordinates.append(sorted_plot_coordinates[0]) # Close the polygon
242
+
243
+ geojson['features'][0]['geometry']['coordinates'][0] = sorted_plot_coordinates
244
+ logger.info(f"Sorted polygon coordinates: {sorted_plot_coordinates}")
245
+ return geojson
246
+ except Exception as e:
247
+ logger.error(f"Error sorting polygon coordinates: {e}")
248
+ return geojson # Return original on error
249
 
250
  # Generate static map image
251
  @spaces.GPU
252
  def generate_static_map(geojson_data, invisible=False):
253
+ logger.info(f"Generating static map. Invisible: {invisible}. GeoJSON: {json.dumps(geojson_data, indent=2)}")
254
+ try:
255
+ m = StaticMap(600, 600)
256
+ color = '#1C00ff00' if invisible else '#42445A85' # Transparent if invisible, else semi-transparent blue/grey
257
+
258
+ for feature in geojson_data["features"]:
259
+ geom_type = feature["geometry"]["type"]
260
+ coords = feature["geometry"]["coordinates"]
261
+ logger.info(f"Processing feature type: {geom_type} with coords: {coords}")
262
+
263
+ if geom_type == "Point":
264
+ # Coords for Point is a single [lon, lat]
265
+ if coords and len(coords) == 2 and isinstance(coords[0], (int, float)):
266
+ m.add_marker(CircleMarker((coords[0], coords[1]), color, 20 if invisible else 10)) # Adjusted size
267
+ else:
268
+ logger.warning(f"Skipping Point due to invalid coordinate structure: {coords}")
269
+ elif geom_type == "MultiPoint":
270
+ # Coords for MultiPoint is a list of [lon, lat]
271
+ for coord_pair in coords:
272
+ if coord_pair and len(coord_pair) == 2 and isinstance(coord_pair[0], (int, float)):
273
+ m.add_marker(CircleMarker((coord_pair[0], coord_pair[1]), color, 20 if invisible else 10))
274
+ else:
275
+ logger.warning(f"Skipping point in MultiPoint due to invalid coordinate structure: {coord_pair}")
276
+ elif geom_type == "LineString":
277
+ # Coords for LineString is a list of [lon, lat]
278
+ if len(coords) >=2:
279
+ m.add_line(Polygon([(c[0], c[1]) for c in coords], "blue", 3)) # For LineString, use add_line or thicker Polygon outline
280
+ else:
281
+ logger.warning(f"Skipping LineString, not enough points: {coords}")
282
+ elif geom_type == "Polygon":
283
+ # Coords for Polygon is a list containing one list of [lon, lat] (the exterior ring)
284
+ for polygon_ring in coords: # Should be only one for simple polygon
285
+ if len(polygon_ring) >= 3:
286
+ m.add_polygon(Polygon([(c[0], c[1]) for c in polygon_ring], color, '#0000AA' if not invisible else '#1C00ff00', 3 if not invisible else 0))
287
+ else:
288
+ logger.warning(f"Skipping polygon ring, not enough points: {polygon_ring}")
289
+ # Add handling for MultiLineString, MultiPolygon if your OpenAI might produce them
290
+ else:
291
+ logger.warning(f"Unsupported geometry type for static map: {geom_type}")
292
+
293
+ rendered_map = m.render(center=None, zoom=None) # Let it auto-center and zoom
294
+ logger.info(f"Static map rendered successfully. Invisible: {invisible}")
295
+ return rendered_map
296
+ except Exception as e:
297
+ logger.error(f"Error generating static map (invisible={invisible}): {e}")
298
+ # Return a placeholder or re-raise
299
+ return Image.new("RGB", (600, 600), color="grey") # Placeholder
300
 
301
  # ControlNet pipeline setup
302
+ logger.info("Initializing Stable Diffusion Inpaint Pipeline.")
303
+ try:
304
+ # controlnet = ControlNetModel.from_pretrained("stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16)
305
+ # pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
306
+ # "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16 # Changed base model
307
+ # )
308
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
309
+ "stabilityai/stable-diffusion-2-inpainting", # This is a full inpainting pipeline, not just a controlnet
310
+ torch_dtype=torch.float16,
311
+ )
312
+ pipeline.to("cuda")
313
+ logger.info("Stable Diffusion Inpaint Pipeline loaded to CUDA.")
314
+ except Exception as e:
315
+ logger.error(f"Error initializing Stable Diffusion pipeline: {e}")
316
+ raise
317
+
318
+ # This function was for ControlNet, may not be needed as-is for StableDiffusionInpaintPipeline
319
+ # It expects init_image to be a NumPy array, and mask_image a NumPy array
320
  @spaces.GPU
321
+ def make_inpaint_condition(init_image_pil, mask_image_pil):
322
+ logger.info("Preparing inpaint condition (ControlNet specific, may need adjustment).")
323
+ # Ensure PIL Images are converted to NumPy arrays correctly
324
+ init_image_np = np.array(init_image_pil.convert("RGB")).astype(np.float32) / 255.0
325
+ mask_image_np = np.array(mask_image_pil.convert("L")).astype(np.float32) / 255.0 # Ensure mask is L
326
+
327
+ logger.info(f"Init image shape: {init_image_np.shape}, Mask image shape: {mask_image_np.shape}")
328
+
329
+ if init_image_np.shape[:2] != mask_image_np.shape[:2]:
330
+ logger.error(f"Image and mask dimensions mismatch: {init_image_np.shape[:2]} vs {mask_image_np.shape[:2]}")
331
+ # Resize mask to match image if necessary, or raise error
332
+ # For now, let's assume they should match and this is an error state
333
+ raise ValueError("Image and mask_image must have the same height and width.")
334
+
335
+ # This operation is specific to how some ControlNet inpainting expects masked areas.
336
+ # Standard SDInpaintPipeline might not need this.
337
+ # init_image_np[mask_image_np > 0.5] = -1.0 # set as masked pixel
338
+
339
+ # init_image_np = np.expand_dims(init_image_np, 0).transpose(0, 3, 1, 2)
340
+ # init_image_tensor = torch.from_numpy(init_image_np)
341
+ # logger.info(f"Processed init_image tensor shape: {init_image_tensor.shape}")
342
+ # return init_image_tensor
343
+
344
+ # For StableDiffusionInpaintPipeline, `image` and `mask_image` are passed directly as PIL Images or tensors.
345
+ # The `make_inpaint_condition` might be redundant if you are not using a ControlNet that specifically requires this format.
346
+ # If you were using ControlNet, this would be the control_image.
347
+ # For now, let's assume it's meant to be the 'image' input for SD Inpaint, preprocessed.
348
+ return init_image_pil # Or init_image_tensor if pipeline expects tensor
349
 
 
 
 
 
 
350
 
351
  @spaces.GPU
352
+ def generate_satellite_image(base_image_pil, mask_image_pil, prompt):
353
+ logger.info(f"Generating satellite image with prompt: '{prompt}'")
354
+ logger.info(f"Base image type: {type(base_image_pil)}, Mask image type: {type(mask_image_pil)}")
355
+
356
+ try:
357
+ # StableDiffusionInpaintPipeline expects PIL Images or tensors for image and mask_image
358
+ # The `control_image` argument is not standard for StableDiffusionInpaintPipeline.
359
+ # It's specific to StableDiffusionControlNetInpaintPipeline.
360
+
361
+ # If you were using the ControlNet variant:
362
+ # control_image_tensor = make_inpaint_condition(base_image_pil, mask_image_pil)
363
+ # result = pipeline(
364
+ # prompt=prompt,
365
+ # image=base_image_pil, # or tensor version if pipeline prefers
366
+ # mask_image=mask_image_pil, # or tensor version
367
+ # control_image=control_image_tensor, # This is for ControlNet
368
+ # strength=0.47, # strength might be called differently or not used in SD Inpaint
369
+ # guidance_scale=9.5, # Adjusted scale
370
+ # num_inference_steps=50 # Adjusted steps
371
+ # ).images[0]
372
+
373
+ # For StableDiffusionInpaintPipeline:
374
+ result = pipeline(
375
+ prompt=prompt,
376
+ image=base_image_pil, # PIL Image or PyTorch tensor
377
+ mask_image=mask_image_pil, # PIL Image or PyTorch tensor
378
+ guidance_scale=9.5, # More reasonable default
379
+ num_inference_steps=50 # More reasonable default
380
+ ).images[0]
381
+
382
+ logger.info("Satellite image generated successfully.")
383
+ return result
384
+ except Exception as e:
385
+ logger.error(f"Error generating satellite image: {e}")
386
+ return Image.new("RGB", base_image_pil.size, color="red") # Placeholder
387
 
388
  # Gradio UI
389
  @spaces.GPU
390
+ def handle_query(query: str):
391
+ logger.info(f"--- Handling query: {query} ---")
392
+ try:
393
+ openai_response = process_openai_response(query)
394
+ logger.info(f"handle_query: OpenAI response received: type={type(openai_response)}")
395
+
396
+ geojson_data = generate_geojson(openai_response)
397
+ logger.info(f"handle_query: GeoJSON data generated: type={type(geojson_data)}")
398
+
399
+ processed_geojson_data = geojson_data
400
+ if geojson_data["features"][0]["geometry"]["type"] == 'Polygon':
401
+ logger.info("handle_query: Detected Polygon, attempting to sort coordinates.")
402
+ processed_geojson_data = sort_coordinates_for_simple_polygon(geojson_data)
403
+
404
+ map_image = generate_static_map(processed_geojson_data, invisible=False)
405
+ logger.info(f"handle_query: Visible map_image generated: type={type(map_image)}")
406
+
407
+ empty_map_image = generate_static_map(processed_geojson_data, invisible=True) # Use processed_geojson_data here too
408
+ logger.info(f"handle_query: Invisible empty_map_image generated: type={type(empty_map_image)}")
409
+
410
+ # Ensure images are PIL for diff
411
+ map_array = np.array(map_image.convert("RGB"))
412
+ empty_map_array = np.array(empty_map_image.convert("RGB"))
413
+
414
+ difference = np.abs(map_array - empty_map_array)
415
+ threshold = 10 # May need adjustment
416
+ mask_array = (np.sum(difference, axis=-1) > threshold).astype(np.uint8) * 255
417
+ mask_image = Image.fromarray(mask_array, mode="L")
418
+ logger.info(f"handle_query: Mask image generated: type={type(mask_image)}")
419
+
420
+ prompt_for_image = openai_response['output']['feature_representation']['properties']['description']
421
+ logger.info(f"handle_query: Prompt for satellite image: '{prompt_for_image}', type={type(prompt_for_image)}")
422
+
423
+ # Pass empty_map_image (which is the base map without visible markers)
424
+ # and the derived mask_image to the inpainting function
425
+ satellite_image = generate_satellite_image(
426
+ empty_map_image, mask_image, prompt_for_image
427
+ )
428
+ logger.info(f"handle_query: Satellite image generated: type={type(satellite_image)}")
429
+
430
+ # Ensure all returned image types are PIL Images
431
+ final_map_image = map_image if isinstance(map_image, Image.Image) else Image.new("RGB", (600,600), "grey")
432
+ final_satellite_image = satellite_image if isinstance(satellite_image, Image.Image) else Image.new("RGB", (600,600), "red")
433
+ final_empty_map_image = empty_map_image if isinstance(empty_map_image, Image.Image) else Image.new("RGB", (600,600), "grey")
434
+ final_mask_image = mask_image if isinstance(mask_image, Image.Image) else Image.new("L", (600,600), 0)
435
+
436
+ logger.info(f"handle_query: Returning types: {type(final_map_image)}, {type(final_satellite_image)}, {type(final_empty_map_image)}, {type(final_mask_image)}, {type(prompt_for_image)}")
437
+ return final_map_image, final_satellite_image, final_empty_map_image, final_mask_image, prompt_for_image
438
 
439
+ except Exception as e:
440
+ logger.error(f"--- Error in handle_query for query '{query}': {e} ---", exc_info=True)
441
+ # Return placeholder/error images and message
442
+ error_img = Image.new("RGB", (600, 600), "black")
443
+ error_text_img = ImageDraw.Draw(error_img)
444
+ error_text_img.text((10,10), f"Error: {e}", fill="white")
445
+ return error_img, error_img, error_img, error_img, f"Error processing query: {e}"
446
+
447
+ def update_query(selected_query_value: str) -> str: # Added type hints
448
+ logger.info(f"Dropdown changed. Selected query: '{selected_query_value}', type: {type(selected_query_value)}")
449
+ return selected_query_value
450
+
451
+ logger.info("Defining Gradio UI components.")
452
  query_options = [
453
  "Area covering south asian subcontinent",
454
+ "Mark a triangular area using New York, Boston, and Texas", # Texas is a state, might cause issues with geocoding as a city point
455
  "Mark cities in India",
456
  "Show me Lotus Tower in a Map",
457
  "Mark the area of west germany",
458
  "Mark the area of the Amazon rainforest",
459
  "Mark the area of the Sahara desert"
460
+ ]
461
+ logger.info(f"Query options: {query_options}")
462
+
463
+ # It's crucial that the `value` parameters for components are of the type Gradio expects
464
+ # for their schema generation, even before any function is called.
465
+ # For gr.Textbox, `value` should be a string.
466
+ # For gr.Dropdown, `value` should be one of the `choices` or None.
467
+
468
+ try:
469
+ with gr.Blocks() as demo:
470
+ logger.info("Inside gr.Blocks() context manager.")
471
+ with gr.Row():
472
+ logger.info("Defining first gr.Row.")
473
+ selected_query = gr.Dropdown(label="Select Query", choices=query_options, value=query_options[-1], type="value") # Ensure type="value" if not default
474
+ logger.info(f"selected_query Dropdown defined. Initial value: '{query_options[-1]}', type: {type(query_options[-1])}")
475
+
476
+ query_input = gr.Textbox(label="Enter Query", value=str(query_options[-1])) # Ensure value is string
477
+ logger.info(f"query_input Textbox defined. Initial value: '{query_options[-1]}', type: {type(query_options[-1])}")
478
+
479
+ # The `change` event should not cause the schema error, but good to log
480
+ selected_query.change(fn=update_query, inputs=selected_query, outputs=query_input)
481
+ logger.info("selected_query.change event defined.")
482
+
483
+ submit_btn = gr.Button("Submit")
484
+ logger.info("submit_btn Button defined.")
485
+
486
+ with gr.Row():
487
+ logger.info("Defining second gr.Row for image outputs.")
488
+ map_output = gr.Image(label="Map Visualization") # No initial value needed here, will be populated by function
489
+ logger.info("map_output Image defined.")
490
+ satellite_output = gr.Image(label="Generated Map Image")
491
+ logger.info("satellite_output Image defined.")
492
+
493
+ with gr.Row():
494
+ logger.info("Defining third gr.Row for debug outputs.")
495
+ empty_map_output = gr.Image(label="Empty Visualization")
496
+ logger.info("empty_map_output Image defined.")
497
+ mask_output = gr.Image(label="Mask")
498
+ logger.info("mask_output Image defined.")
499
+ # For image_prompt, provide a default string value or None. An empty string is fine.
500
+ image_prompt_output = gr.Textbox(label="Image Prompt Used", value="") # Changed name to avoid conflict, ensure string value
501
+ logger.info(f"image_prompt_output Textbox defined. Initial value: '', type: str")
502
+
503
+ # The outputs list must match the number and expected types of what handle_query returns.
504
+ # handle_query returns: PIL.Image, PIL.Image, PIL.Image, PIL.Image, str
505
+ # Gradio components: gr.Image, gr.Image, gr.Image, gr.Image, gr.Textbox
506
+ # This mapping looks correct.
507
+ submit_btn.click(fn=handle_query,
508
+ inputs=[query_input],
509
+ outputs=[map_output, satellite_output, empty_map_output, mask_output, image_prompt_output])
510
+ logger.info("submit_btn.click event defined.")
511
+ logger.info("Gradio Blocks defined successfully.")
512
+
513
+ except Exception as e:
514
+ logger.error(f"Error during Gradio UI definition: {e}", exc_info=True)
515
+ raise
516
 
517
  if __name__ == "__main__":
518
+ logger.info("Launching Gradio demo.")
519
+ try:
520
+ demo.launch() # debug=True can sometimes give more frontend info, but not for this backend error
521
+ logger.info("Gradio demo launched.")
522
+ except Exception as e:
523
+ logger.error(f"Error launching Gradio demo: {e}", exc_info=True)
524
+ raise