Spaces: Running on Zero
Suchinthana committed · Commit ecb6154 · Parent(s): 922274d
Logging, prompt, pipe update
app.py CHANGED
@@ -12,23 +12,43 @@ from diffusers import StableDiffusionInpaintPipeline
 import spaces
 import logging
 import math
-from typing import List, Union
+from typing import List, Union # Make sure these are actually used or remove them
 
 # Set up logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
 logger = logging.getLogger(__name__)
 
+logger.info("Script starting. Initializing APIs and models.")
+
 # Initialize APIs
+try:
+    openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
+    logger.info("OpenAI client initialized.")
+except KeyError:
+    logger.error("OPENAI_API_KEY environment variable not set!")
+    # Handle this critical error, perhaps exit or raise
+    raise
+except Exception as e:
+    logger.error(f"Error initializing OpenAI client: {e}")
+    raise
+
+try:
+    geolocator = Nominatim(user_agent="geoapi_visualizemap") # More specific user agent
+    logger.info("Geolocator initialized.")
+except Exception as e:
+    logger.error(f"Error initializing Geolocator: {e}")
+    raise
 
 # Function to fetch coordinates
 @spaces.GPU
 def get_geo_coordinates(location_name):
+    logger.info(f"Attempting to fetch coordinates for: {location_name}")
     try:
-        location = geolocator.geocode(location_name)
+        location = geolocator.geocode(location_name, timeout=10) # Added timeout
         if location:
+            logger.info(f"Coordinates found for {location_name}: {[location.longitude, location.latitude]}")
             return [location.longitude, location.latitude]
+        logger.warning(f"No location data returned for {location_name}")
         return None
     except Exception as e:
         logger.error(f"Error fetching coordinates for {location_name}: {e}")
@@ -37,12 +57,14 @@ def get_geo_coordinates(location_name):
 # Function to process OpenAI chat response
 @spaces.GPU
 def process_openai_response(query):
+    logger.info(f"Processing OpenAI query: {query}")
+    try:
+        response = openai_client.chat.completions.create(
+            model="gpt-4o-mini",
+            messages=[
+                {
+                    "role": "system",
+                    "content": """
 You are an assistant that generates structured JSON output for geographical queries with city names. Your task is to generate a JSON object containing information about geographical features and their representation based on the user's query. Follow these rules:
 
 1. The JSON should always have the following structure:
@@ -93,207 +115,410 @@
 
 Generate similar JSON for the following query:
 """
+                },
+                {
+                    "role": "user",
+                    "content": query
+                }
+            ],
+            temperature=1,
+            max_tokens=2048,
+            top_p=1,
+            frequency_penalty=0,
+            presence_penalty=0,
+            response_format={"type": "json_object"}
+        )
+        content = response.choices[0].message.content
+        logger.info(f"Raw OpenAI response content: {content}")
+        parsed_response = json.loads(content)
+        logger.info(f"Parsed OpenAI response: {json.dumps(parsed_response, indent=2)}")
+        return parsed_response
+    except Exception as e:
+        logger.error(f"Error processing OpenAI response for query '{query}': {e}")
+        # Consider returning a default error structure or re-raising
+        raise
 
 # Generate GeoJSON from OpenAI response
 @spaces.GPU
+def generate_geojson(response_data): # Renamed to avoid confusion with http response
+    logger.info(f"Generating GeoJSON from OpenAI response_data: {json.dumps(response_data, indent=2)}")
+    try:
+        feature_type = response_data['output']['feature_representation']['type']
+        city_names = response_data['output']['feature_representation']['cities']
+        properties = response_data['output']['feature_representation']['properties']
+        logger.info(f"Feature type: {feature_type}, Cities: {city_names}")
 
+        coordinates = []
+        for city in city_names:
             coord = get_geo_coordinates(city)
             if coord:
                 coordinates.append(coord)
             else:
-                logger.warning(f"Coordinates not found for city: {city}")
+                logger.warning(f"Coordinates not found for city: {city}. Skipping.")
+
+        logger.info(f"Collected coordinates: {coordinates}")
+
+        # Ensure coordinates has the correct structure for each geometry type
+        if feature_type == "Point":
+            if not coordinates:
+                raise ValueError("Point type requires at least one coordinate.")
+            # GeoJSON Point expects a single coordinate pair, not a list of pairs
+            final_coordinates = coordinates[0] if coordinates else []
+        elif feature_type == "MultiPoint":
+            final_coordinates = coordinates # List of coordinate pairs
+        elif feature_type == "LineString":
+            if len(coordinates) < 2:
+                raise ValueError("LineString requires at least 2 coordinates.")
+            final_coordinates = coordinates # List of coordinate pairs
+        elif feature_type == "Polygon":
+            if len(coordinates) < 3:
+                raise ValueError("Polygon requires at least 3 coordinates.")
+            # Close the polygon by appending the first point at the end
+            if coordinates[0] != coordinates[-1]: # Check if already closed
+                coordinates.append(coordinates[0])
+            final_coordinates = [coordinates] # Nest coordinates for Polygon
+        else: # MultiLineString, MultiPolygon, GeometryCollection
+            logger.warning(f"Unsupported or complex feature_type: {feature_type}. Using raw coordinates.")
+            final_coordinates = coordinates # Or handle more specifically
+
+        geojson_data = {
+            "type": "FeatureCollection",
+            "features": [
+                {
+                    "type": "Feature",
+                    "properties": properties,
+                    "geometry": {
+                        "type": feature_type,
+                        "coordinates": final_coordinates,
+                    },
+                }
+            ],
+        }
+        logger.info(f"Generated GeoJSON: {json.dumps(geojson_data, indent=2)}")
+        return geojson_data
+    except KeyError as e:
+        logger.error(f"KeyError while generating GeoJSON: {e}. Response data: {json.dumps(response_data, indent=2)}")
+        raise
+    except ValueError as e:
+        logger.error(f"ValueError while generating GeoJSON: {e}. Coordinates: {coordinates if 'coordinates' in locals() else 'N/A'}")
+        raise
+    except Exception as e:
+        logger.error(f"Unexpected error in generate_geojson: {e}")
+        raise
 
 # Sort coordinates for a simple polygon (Reduce intersection points)
 def sort_coordinates_for_simple_polygon(geojson):
+    logger.info("Attempting to sort polygon coordinates.")
+    try:
+        coordinates = geojson['features'][0]['geometry']['coordinates'][0]
+        logger.info(f"Original polygon coordinates: {coordinates}")
+
+        if not coordinates or len(coordinates) < 3:
+            logger.warning("Not enough coordinates to sort for a polygon.")
+            return geojson
+
+        # Remove the last point if it duplicates the first (GeoJSON convention for polygons)
+        if coordinates[0] == coordinates[-1] and len(coordinates) > 1:
+            plot_coordinates = coordinates[:-1]
+        else:
+            plot_coordinates = coordinates
+
+        if not plot_coordinates or len(plot_coordinates) < 3: # Check again after potentially removing last point
+            logger.warning("Not enough unique coordinates to sort for a polygon after de-duplication.")
+            return geojson
+
+        # Calculate the centroid of the points
+        centroid_x = sum(point[0] for point in plot_coordinates) / len(plot_coordinates)
+        centroid_y = sum(point[1] for point in plot_coordinates) / len(plot_coordinates)
+        logger.info(f"Calculated centroid: ({centroid_x}, {centroid_y})")
+
+        def angle_from_centroid(point):
+            dx = point[0] - centroid_x
+            dy = point[1] - centroid_y
+            return math.atan2(dy, dx)
+
+        sorted_plot_coordinates = sorted(plot_coordinates, key=angle_from_centroid)
+        sorted_plot_coordinates.append(sorted_plot_coordinates[0]) # Close the polygon
+
+        geojson['features'][0]['geometry']['coordinates'][0] = sorted_plot_coordinates
+        logger.info(f"Sorted polygon coordinates: {sorted_plot_coordinates}")
+        return geojson
+    except Exception as e:
+        logger.error(f"Error sorting polygon coordinates: {e}")
+        return geojson # Return original on error
 
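For reference, a standalone sketch (not part of the commit) of the ordering idea sort_coordinates_for_simple_polygon relies on: sorting vertices by their angle around the centroid yields a ring whose edges do not cross. The coordinates below are hypothetical, not taken from the app.

# Illustration only: angle-from-centroid ordering on four made-up [lon, lat] pairs.
import math

points = [[80.0, 7.0], [77.6, 13.0], [72.8, 19.0], [88.4, 22.6]]
cx = sum(p[0] for p in points) / len(points)
cy = sum(p[1] for p in points) / len(points)

ring = sorted(points, key=lambda p: math.atan2(p[1] - cy, p[0] - cx))
ring.append(ring[0])  # close the ring, as the function above does

print(ring)  # vertices now wind around the centroid, so the polygon is simple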
 # Generate static map image
 @spaces.GPU
 def generate_static_map(geojson_data, invisible=False):
+    logger.info(f"Generating static map. Invisible: {invisible}. GeoJSON: {json.dumps(geojson_data, indent=2)}")
+    try:
+        m = StaticMap(600, 600)
+        color = '#1C00ff00' if invisible else '#42445A85' # Transparent if invisible, else semi-transparent blue/grey
+
+        for feature in geojson_data["features"]:
+            geom_type = feature["geometry"]["type"]
+            coords = feature["geometry"]["coordinates"]
+            logger.info(f"Processing feature type: {geom_type} with coords: {coords}")
+
+            if geom_type == "Point":
+                # Coords for Point is a single [lon, lat]
+                if coords and len(coords) == 2 and isinstance(coords[0], (int, float)):
+                    m.add_marker(CircleMarker((coords[0], coords[1]), color, 20 if invisible else 10)) # Adjusted size
+                else:
+                    logger.warning(f"Skipping Point due to invalid coordinate structure: {coords}")
+            elif geom_type == "MultiPoint":
+                # Coords for MultiPoint is a list of [lon, lat]
+                for coord_pair in coords:
+                    if coord_pair and len(coord_pair) == 2 and isinstance(coord_pair[0], (int, float)):
+                        m.add_marker(CircleMarker((coord_pair[0], coord_pair[1]), color, 20 if invisible else 10))
+                    else:
+                        logger.warning(f"Skipping point in MultiPoint due to invalid coordinate structure: {coord_pair}")
+            elif geom_type == "LineString":
+                # Coords for LineString is a list of [lon, lat]
+                if len(coords) >= 2:
+                    m.add_line(Polygon([(c[0], c[1]) for c in coords], "blue", 3)) # For LineString, use add_line or thicker Polygon outline
+                else:
+                    logger.warning(f"Skipping LineString, not enough points: {coords}")
+            elif geom_type == "Polygon":
+                # Coords for Polygon is a list containing one list of [lon, lat] (the exterior ring)
+                for polygon_ring in coords: # Should be only one for simple polygon
+                    if len(polygon_ring) >= 3:
+                        m.add_polygon(Polygon([(c[0], c[1]) for c in polygon_ring], color, '#0000AA' if not invisible else '#1C00ff00', 3 if not invisible else 0))
+                    else:
+                        logger.warning(f"Skipping polygon ring, not enough points: {polygon_ring}")
+            # Add handling for MultiLineString, MultiPolygon if your OpenAI might produce them
+            else:
+                logger.warning(f"Unsupported geometry type for static map: {geom_type}")
+
+        rendered_map = m.render(center=None, zoom=None) # Let it auto-center and zoom
+        logger.info(f"Static map rendered successfully. Invisible: {invisible}")
+        return rendered_map
+    except Exception as e:
+        logger.error(f"Error generating static map (invisible={invisible}): {e}")
+        # Return a placeholder or re-raise
+        return Image.new("RGB", (600, 600), color="grey") # Placeholder
 
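The functions above hinge on two data shapes: the parsed OpenAI reply that process_openai_response returns and the FeatureCollection that generate_geojson builds from it. A minimal sketch of both, with hypothetical values inferred from the keys the code reads (the exact schema lives in the unchanged part of the system prompt that the diff does not show):

# Hypothetical parsed OpenAI reply (keys inferred from generate_geojson above):
parsed_response = {
    "output": {
        "feature_representation": {
            "type": "Polygon",
            "cities": ["Mumbai", "Bengaluru", "Kolkata"],
            "properties": {"description": "Satellite view of a region spanning three Indian cities"},
        }
    }
}

# Roughly what generate_geojson builds from it; coordinates are illustrative [lon, lat]
# pairs, the polygon ring is closed and nested one level deep:
geojson_data = {
    "type": "FeatureCollection",
    "features": [{
        "type": "Feature",
        "properties": {"description": "Satellite view of a region spanning three Indian cities"},
        "geometry": {
            "type": "Polygon",
            "coordinates": [[[72.8, 19.0], [77.6, 13.0], [88.4, 22.6], [72.8, 19.0]]],
        },
    }],
}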
 # ControlNet pipeline setup
-pipeline = StableDiffusionInpaintPipeline.from_pretrained(
-)
-pipeline.to("cuda")
+logger.info("Initializing Stable Diffusion Inpaint Pipeline.")
+try:
+    # controlnet = ControlNetModel.from_pretrained("stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16)
+    # pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+    #     "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16 # Changed base model
+    # )
+    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-2-inpainting", # This is a full inpainting pipeline, not just a controlnet
+        torch_dtype=torch.float16,
+    )
+    pipeline.to("cuda")
+    logger.info("Stable Diffusion Inpaint Pipeline loaded to CUDA.")
+except Exception as e:
+    logger.error(f"Error initializing Stable Diffusion pipeline: {e}")
+    raise
+
+# This function was for ControlNet, may not be needed as-is for StableDiffusionInpaintPipeline
+# It expects init_image to be a NumPy array, and mask_image a NumPy array
 @spaces.GPU
-def make_inpaint_condition(
-    assert init_image.shape[0:1] == mask_image.shape[0:1], "image and image_mask must have the same image size"
-    init_image[mask_image > 0.5] = -1.0 # set as masked pixel
-    init_image = np.expand_dims(init_image, 0).transpose(0, 3, 1, 2)
-    init_image = torch.from_numpy(init_image)
-    return init_image
+def make_inpaint_condition(init_image_pil, mask_image_pil):
+    logger.info("Preparing inpaint condition (ControlNet specific, may need adjustment).")
+    # Ensure PIL Images are converted to NumPy arrays correctly
+    init_image_np = np.array(init_image_pil.convert("RGB")).astype(np.float32) / 255.0
+    mask_image_np = np.array(mask_image_pil.convert("L")).astype(np.float32) / 255.0 # Ensure mask is L
+
+    logger.info(f"Init image shape: {init_image_np.shape}, Mask image shape: {mask_image_np.shape}")
+
+    if init_image_np.shape[:2] != mask_image_np.shape[:2]:
+        logger.error(f"Image and mask dimensions mismatch: {init_image_np.shape[:2]} vs {mask_image_np.shape[:2]}")
+        # Resize mask to match image if necessary, or raise error
+        # For now, let's assume they should match and this is an error state
+        raise ValueError("Image and mask_image must have the same height and width.")
+
+    # This operation is specific to how some ControlNet inpainting expects masked areas.
+    # Standard SDInpaintPipeline might not need this.
+    # init_image_np[mask_image_np > 0.5] = -1.0 # set as masked pixel
+
+    # init_image_np = np.expand_dims(init_image_np, 0).transpose(0, 3, 1, 2)
+    # init_image_tensor = torch.from_numpy(init_image_np)
+    # logger.info(f"Processed init_image tensor shape: {init_image_tensor.shape}")
+    # return init_image_tensor
+
+    # For StableDiffusionInpaintPipeline, `image` and `mask_image` are passed directly as PIL Images or tensors.
+    # The `make_inpaint_condition` might be redundant if you are not using a ControlNet that specifically requires this format.
+    # If you were using ControlNet, this would be the control_image.
+    # For now, let's assume it's meant to be the 'image' input for SD Inpaint, preprocessed.
+    return init_image_pil # Or init_image_tensor if pipeline expects tensor
 
 
 @spaces.GPU
+def generate_satellite_image(base_image_pil, mask_image_pil, prompt):
+    logger.info(f"Generating satellite image with prompt: '{prompt}'")
+    logger.info(f"Base image type: {type(base_image_pil)}, Mask image type: {type(mask_image_pil)}")
+
+    try:
+        # StableDiffusionInpaintPipeline expects PIL Images or tensors for image and mask_image
+        # The `control_image` argument is not standard for StableDiffusionInpaintPipeline.
+        # It's specific to StableDiffusionControlNetInpaintPipeline.
+
+        # If you were using the ControlNet variant:
+        # control_image_tensor = make_inpaint_condition(base_image_pil, mask_image_pil)
+        # result = pipeline(
+        #     prompt=prompt,
+        #     image=base_image_pil, # or tensor version if pipeline prefers
+        #     mask_image=mask_image_pil, # or tensor version
+        #     control_image=control_image_tensor, # This is for ControlNet
+        #     strength=0.47, # strength might be called differently or not used in SD Inpaint
+        #     guidance_scale=9.5, # Adjusted scale
+        #     num_inference_steps=50 # Adjusted steps
+        # ).images[0]
+
+        # For StableDiffusionInpaintPipeline:
+        result = pipeline(
+            prompt=prompt,
+            image=base_image_pil, # PIL Image or PyTorch tensor
+            mask_image=mask_image_pil, # PIL Image or PyTorch tensor
+            guidance_scale=9.5, # More reasonable default
+            num_inference_steps=50 # More reasonable default
+        ).images[0]
+
+        logger.info("Satellite image generated successfully.")
+        return result
+    except Exception as e:
+        logger.error(f"Error generating satellite image: {e}")
+        return Image.new("RGB", base_image_pil.size, color="red") # Placeholder
 
 # Gradio UI
 @spaces.GPU
-def handle_query(query):
+def handle_query(query: str):
+    logger.info(f"--- Handling query: {query} ---")
+    try:
+        openai_response = process_openai_response(query)
+        logger.info(f"handle_query: OpenAI response received: type={type(openai_response)}")
+
+        geojson_data = generate_geojson(openai_response)
+        logger.info(f"handle_query: GeoJSON data generated: type={type(geojson_data)}")
+
+        processed_geojson_data = geojson_data
+        if geojson_data["features"][0]["geometry"]["type"] == 'Polygon':
+            logger.info("handle_query: Detected Polygon, attempting to sort coordinates.")
+            processed_geojson_data = sort_coordinates_for_simple_polygon(geojson_data)
+
+        map_image = generate_static_map(processed_geojson_data, invisible=False)
+        logger.info(f"handle_query: Visible map_image generated: type={type(map_image)}")
+
+        empty_map_image = generate_static_map(processed_geojson_data, invisible=True) # Use processed_geojson_data here too
+        logger.info(f"handle_query: Invisible empty_map_image generated: type={type(empty_map_image)}")
+
+        # Ensure images are PIL for diff
+        map_array = np.array(map_image.convert("RGB"))
+        empty_map_array = np.array(empty_map_image.convert("RGB"))
+
+        difference = np.abs(map_array - empty_map_array)
+        threshold = 10 # May need adjustment
+        mask_array = (np.sum(difference, axis=-1) > threshold).astype(np.uint8) * 255
+        mask_image = Image.fromarray(mask_array, mode="L")
+        logger.info(f"handle_query: Mask image generated: type={type(mask_image)}")
+
+        prompt_for_image = openai_response['output']['feature_representation']['properties']['description']
+        logger.info(f"handle_query: Prompt for satellite image: '{prompt_for_image}', type={type(prompt_for_image)}")
+
+        # Pass empty_map_image (which is the base map without visible markers)
+        # and the derived mask_image to the inpainting function
+        satellite_image = generate_satellite_image(
+            empty_map_image, mask_image, prompt_for_image
+        )
+        logger.info(f"handle_query: Satellite image generated: type={type(satellite_image)}")
+
+        # Ensure all returned image types are PIL Images
+        final_map_image = map_image if isinstance(map_image, Image.Image) else Image.new("RGB", (600,600), "grey")
+        final_satellite_image = satellite_image if isinstance(satellite_image, Image.Image) else Image.new("RGB", (600,600), "red")
+        final_empty_map_image = empty_map_image if isinstance(empty_map_image, Image.Image) else Image.new("RGB", (600,600), "grey")
+        final_mask_image = mask_image if isinstance(mask_image, Image.Image) else Image.new("L", (600,600), 0)
+
+        logger.info(f"handle_query: Returning types: {type(final_map_image)}, {type(final_satellite_image)}, {type(final_empty_map_image)}, {type(final_mask_image)}, {type(prompt_for_image)}")
+        return final_map_image, final_satellite_image, final_empty_map_image, final_mask_image, prompt_for_image
 
+    except Exception as e:
+        logger.error(f"--- Error in handle_query for query '{query}': {e} ---", exc_info=True)
+        # Return placeholder/error images and message
+        error_img = Image.new("RGB", (600, 600), "black")
+        error_text_img = ImageDraw.Draw(error_img)
+        error_text_img.text((10,10), f"Error: {e}", fill="white")
+        return error_img, error_img, error_img, error_img, f"Error processing query: {e}"
+
+def update_query(selected_query_value: str) -> str: # Added type hints
+    logger.info(f"Dropdown changed. Selected query: '{selected_query_value}', type: {type(selected_query_value)}")
+    return selected_query_value
+
+logger.info("Defining Gradio UI components.")
 query_options = [
     "Area covering south asian subcontinent",
-    "Mark a triangular area using New York, Boston, and Texas",
+    "Mark a triangular area using New York, Boston, and Texas", # Texas is a state, might cause issues with geocoding as a city point
     "Mark cities in India",
     "Show me Lotus Tower in a Map",
     "Mark the area of west germany",
     "Mark the area of the Amazon rainforest",
     "Mark the area of the Sahara desert"
+]
+logger.info(f"Query options: {query_options}")
+
+# It's crucial that the `value` parameters for components are of the type Gradio expects
+# for their schema generation, even before any function is called.
+# For gr.Textbox, `value` should be a string.
+# For gr.Dropdown, `value` should be one of the `choices` or None.
+
+try:
+    with gr.Blocks() as demo:
+        logger.info("Inside gr.Blocks() context manager.")
+        with gr.Row():
+            logger.info("Defining first gr.Row.")
+            selected_query = gr.Dropdown(label="Select Query", choices=query_options, value=query_options[-1], type="value") # Ensure type="value" if not default
+            logger.info(f"selected_query Dropdown defined. Initial value: '{query_options[-1]}', type: {type(query_options[-1])}")
+
+            query_input = gr.Textbox(label="Enter Query", value=str(query_options[-1])) # Ensure value is string
+            logger.info(f"query_input Textbox defined. Initial value: '{query_options[-1]}', type: {type(query_options[-1])}")
+
+            # The `change` event should not cause the schema error, but good to log
+            selected_query.change(fn=update_query, inputs=selected_query, outputs=query_input)
+            logger.info("selected_query.change event defined.")
+
+            submit_btn = gr.Button("Submit")
+            logger.info("submit_btn Button defined.")
+
+        with gr.Row():
+            logger.info("Defining second gr.Row for image outputs.")
+            map_output = gr.Image(label="Map Visualization") # No initial value needed here, will be populated by function
+            logger.info("map_output Image defined.")
+            satellite_output = gr.Image(label="Generated Map Image")
+            logger.info("satellite_output Image defined.")
+
+        with gr.Row():
+            logger.info("Defining third gr.Row for debug outputs.")
+            empty_map_output = gr.Image(label="Empty Visualization")
+            logger.info("empty_map_output Image defined.")
+            mask_output = gr.Image(label="Mask")
+            logger.info("mask_output Image defined.")
+            # For image_prompt, provide a default string value or None. An empty string is fine.
+            image_prompt_output = gr.Textbox(label="Image Prompt Used", value="") # Changed name to avoid conflict, ensure string value
+            logger.info(f"image_prompt_output Textbox defined. Initial value: '', type: str")
+
+        # The outputs list must match the number and expected types of what handle_query returns.
+        # handle_query returns: PIL.Image, PIL.Image, PIL.Image, PIL.Image, str
+        # Gradio components: gr.Image, gr.Image, gr.Image, gr.Image, gr.Textbox
+        # This mapping looks correct.
+        submit_btn.click(fn=handle_query,
+                         inputs=[query_input],
+                         outputs=[map_output, satellite_output, empty_map_output, mask_output, image_prompt_output])
+        logger.info("submit_btn.click event defined.")
+        logger.info("Gradio Blocks defined successfully.")
+
+except Exception as e:
+    logger.error(f"Error during Gradio UI definition: {e}", exc_info=True)
+    raise
 
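One detail of handle_query above worth flagging: map_array and empty_map_array are uint8, so the subtraction wraps around modulo 256 and np.abs cannot undo that, which inflates small negative differences. A wraparound-safe sketch of the same mask-from-difference idea (illustration only, not part of the commit; assumes two same-size RGB PIL images):

# Hypothetical standalone version of the mask step in handle_query, using a signed dtype.
import numpy as np
from PIL import Image

def mask_from_difference(map_image, empty_map_image, threshold=10):
    a = np.asarray(map_image.convert("RGB")).astype(np.int16)
    b = np.asarray(empty_map_image.convert("RGB")).astype(np.int16)
    diff = np.abs(a - b).sum(axis=-1)       # per-pixel difference summed over channels
    mask = (diff > threshold).astype(np.uint8) * 255
    return Image.fromarray(mask, mode="L")  # white where the two renders differ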
 if __name__ == "__main__":
+    logger.info("Launching Gradio demo.")
+    try:
+        demo.launch() # debug=True can sometimes give more frontend info, but not for this backend error
+        logger.info("Gradio demo launched.")
+    except Exception as e:
+        logger.error(f"Error launching Gradio demo: {e}", exc_info=True)
+        raise
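As a hypothetical usage note (not part of the commit): with the script's imports available and OPENAI_API_KEY set, the text-and-geometry path above can be exercised on its own, without the Gradio UI or the inpainting pipeline:

# Hypothetical smoke test of the OpenAI -> GeoJSON path only.
reply = process_openai_response("Mark the area of west germany")
geojson = generate_geojson(reply)
print(json.dumps(geojson, indent=2))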