TroglodyteDerivations commited on
Commit
06a4769
1 Parent(s): ae445de

Updated line 414 with: dispersion_values = np.array([self.Dispersion() for _ in positions]) [plot_dispersion_heatmap] method

Browse files
Files changed (1) hide show
  1. app.py +28 -27
app.py CHANGED
@@ -403,6 +403,34 @@ class GWO:
403
  dy = y.max() - y.min()
404
  return (dx + dy) / 2.0
405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  def plot_dispersion(self):
407
  """Plot the dispersion over time"""
408
  # Assuming self.giter stores the iteration number at which each best position was found
@@ -450,33 +478,6 @@ class GWO:
450
  #plt.close() # Close the figure to free up memory
451
  #return Image.open(buf)
452
 
453
- def plot_dispersion_heatmap(self, x_range, y_range, resolution=100):
454
- # Create a grid of points within the specified range
455
- x = np.linspace(*x_range, resolution)
456
- y = np.linspace(*y_range, resolution)
457
- X, Y = np.meshgrid(x, y)
458
- positions = np.vstack([X.ravel(), Y.ravel()]).T
459
-
460
- # Calculate the dispersion for each position in the grid
461
- dispersion_values = np.array([self.Dispersion(pos) for pos in positions])
462
- Z = dispersion_values.reshape(X.shape)
463
-
464
- # Plot the dispersion heatmap
465
- plt.figure(figsize=(10, 8))
466
- plt.pcolormesh(X, Y, Z, cmap='viridis', norm=LogNorm())
467
- plt.colorbar(label='Dispersion')
468
-
469
- # Set plot title and labels
470
- plt.title('Dispersion Heatmap')
471
- plt.xlabel('x')
472
- plt.ylabel('y')
473
-
474
- # Convert the plot to a PIL Image and return it
475
- buf = BytesIO()
476
- plt.savefig(buf, format='png')
477
- buf.seek(0)
478
- plt.close() # Close the figure to free up memory
479
- return Image.open(buf)
480
 
481
 
482
  def optimize(npart, ndim, max_iter):
 
403
  dy = y.max() - y.min()
404
  return (dx + dy) / 2.0
405
 
406
+ def plot_dispersion_heatmap(self, x_range, y_range, resolution=100):
407
+ # Create a grid of points within the specified range
408
+ x = np.linspace(*x_range, resolution)
409
+ y = np.linspace(*y_range, resolution)
410
+ X, Y = np.meshgrid(x, y)
411
+ positions = np.vstack([X.ravel(), Y.ravel()]).T
412
+
413
+ # Calculate the dispersion for each position in the grid
414
+ dispersion_values = np.array([self.Dispersion() for _ in positions])
415
+ Z = dispersion_values.reshape(X.shape)
416
+
417
+ # Plot the dispersion heatmap
418
+ plt.figure(figsize=(10, 8))
419
+ plt.pcolormesh(X, Y, Z, cmap='viridis')
420
+ plt.colorbar(label='Dispersion')
421
+
422
+ # Set plot title and labels
423
+ plt.title('Dispersion Heatmap')
424
+ plt.xlabel('x')
425
+ plt.ylabel('y')
426
+
427
+ # Convert the plot to a PIL Image and return it
428
+ buf = BytesIO()
429
+ plt.savefig(buf, format='png')
430
+ buf.seek(0)
431
+ plt.close() # Close the figure to free up memory
432
+ return Image.open(buf)
433
+
434
  def plot_dispersion(self):
435
  """Plot the dispersion over time"""
436
  # Assuming self.giter stores the iteration number at which each best position was found
 
478
  #plt.close() # Close the figure to free up memory
479
  #return Image.open(buf)
480
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
 
482
 
483
  def optimize(npart, ndim, max_iter):