caliex committed on
Commit
24d7e1c
1 Parent(s): 616f933

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -4
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import matplotlib as mpl
3
  import matplotlib.pyplot as plt
4
  import numpy as np
 
5
  from sklearn import datasets
6
  from sklearn.mixture import GaussianMixture
7
  from sklearn.model_selection import StratifiedKFold
@@ -98,10 +99,24 @@ def classify_iris(cov_type):
98
  return output_path
99
 
100
 
 
 
 
 
 
 
 
 
 
101
  iface = gr.Interface(
102
- fn=classify_iris,
103
- inputs=gr.inputs.Radio(["spherical", "diag", "tied", "full"], label="Covariance Type"),
104
- outputs="image",
 
 
 
 
 
105
  title="Gaussian Mixture Model Covariance",
106
  description="Explore different covariance types for Gaussian mixture models (GMMs) in this demonstration. GMMs are commonly used for clustering, but in this example, we compare the obtained clusters with the actual classes from the dataset. By initializing the means of the Gaussians with the means of the classes in the training set, we ensure a valid comparison. The plots show the predicted labels on both training and test data using GMMs with spherical, diagonal, full, and tied covariance matrices. Interestingly, while full covariance is expected to perform best, it may overfit small datasets and struggle to generalize to held out test data. See the original scikit-learn example for more information: https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_covariances.html",
107
  examples=[
@@ -112,4 +127,4 @@ iface = gr.Interface(
112
  ],
113
  )
114
 
115
- iface.launch()
 
2
  import matplotlib as mpl
3
  import matplotlib.pyplot as plt
4
  import numpy as np
5
+ from PIL import Image
6
  from sklearn import datasets
7
  from sklearn.mixture import GaussianMixture
8
  from sklearn.model_selection import StratifiedKFold
 
99
  return output_path
100
 
101
 
102
def update_plot(cov_type):
    """Render the GMM classification plot for the given covariance type.

    Parameters
    ----------
    cov_type : str
        One of "spherical", "diag", "tied", "full" — forwarded to
        classify_iris, which writes the figure to disk and returns its path.

    Returns
    -------
    numpy.ndarray
        The rendered plot loaded as an image array, suitable for a
        Gradio image output of type "numpy".
    """
    # classify_iris saves the matplotlib figure and returns the file path;
    # load it back and hand Gradio a numpy array instead of a path.
    image_path = classify_iris(cov_type)
    return np.array(Image.open(image_path))
109
+
110
+
111
  iface = gr.Interface(
112
+ fn=update_plot,
113
+ inputs=gr.inputs.Radio(
114
+ ["spherical", "diag", "tied", "full"],
115
+ label="Covariance Type",
116
+ default="spherical"
117
+ ),
118
+ outputs=gr.outputs.Image(type="numpy"),
119
+ live=True,
120
  title="Gaussian Mixture Model Covariance",
121
  description="Explore different covariance types for Gaussian mixture models (GMMs) in this demonstration. GMMs are commonly used for clustering, but in this example, we compare the obtained clusters with the actual classes from the dataset. By initializing the means of the Gaussians with the means of the classes in the training set, we ensure a valid comparison. The plots show the predicted labels on both training and test data using GMMs with spherical, diagonal, full, and tied covariance matrices. Interestingly, while full covariance is expected to perform best, it may overfit small datasets and struggle to generalize to held out test data. See the original scikit-learn example for more information: https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_covariances.html",
122
  examples=[
 
127
  ],
128
  )
129
 
130
+ iface.launch()