penscola commited on
Commit
53752d9
·
1 Parent(s): 8d1fe99

Upload 3 files

Browse files
Files changed (3) hide show
  1. Gradient Boosting_pipeline.pkl +3 -0
  2. predict.py +95 -0
  3. requirments.txt +19 -0
Gradient Boosting_pipeline.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56cd9c9ba8e2588f9dbfbf9e9759d04f41dec108bc3582bb5463d4f31d9a34f5
3
+ size 1597216
predict.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import pickle
4
+ import numpy as np
5
+ from sklearn.preprocessing import StandardScaler
6
+ from sklearn.compose import ColumnTransformer
7
+ from sklearn.pipeline import Pipeline
8
+ from sklearn.impute import SimpleImputer
9
+ from sklearn.ensemble import GradientBoostingClassifier
10
+
11
+ # Load the saved full pipeline from the file
12
+ # Load the saved full pipeline from the file
13
+ full_pipeline_path = '../../model/Gradient Boosting_pipeline.pkl'
14
+
15
+ with open(full_pipeline_path, 'rb') as f_in:
16
+ full_pipeline = pickle.load(f_in)
17
+
18
+ # Define the predict function
19
+ # Define the predict function
20
+ def predict(danceability, energy, key, loudness, speechiness, acousticness, instrumentalness,
21
+ liveness, valence, tempo, duration_ms, mode):
22
+ # Create a DataFrame from the input data
23
+ input_data = pd.DataFrame({
24
+ 'danceability': [danceability] if danceability else [0], # Replace None with default value
25
+ 'energy': [energy] if energy is not None else [0], # Replace None with default value
26
+ 'key': [key] if key else [0], # Replace None with default value
27
+ 'loudness': [loudness] if loudness else [0], # Replace None with default value
28
+ 'speechiness': [speechiness] if speechiness else [0], # Replace None with default value
29
+ 'acousticness': [acousticness] if acousticness else [0], # Replace None with default value
30
+ 'instrumentalness': [instrumentalness] if instrumentalness else [0], # Replace None with default value
31
+ 'liveness': [liveness] if liveness else [0], # Replace None with default value
32
+ 'valence': [valence] if valence else [0], # Replace None with default value
33
+ 'tempo': [tempo] if tempo else [0], # Replace None with default value
34
+ 'duration_ms': [duration_ms] if duration_ms else [0], # Replace None with default value
35
+ 'mode': [mode] if mode else [0], # Replace None with default value
36
+ })
37
+
38
+ # Make predictions using the loaded logistic regression model
39
+ #predict probabilities
40
+ predictions = full_pipeline.predict_proba(input_data)
41
+ #take the index of the maximum probability
42
+ index=np.argmax(predictions)
43
+ higher_pred_prob=round((predictions[0][index])*100)
44
+
45
+
46
+ #return predictions[0]
47
+ print(f'[Info] Predicted probabilities{predictions},{full_pipeline.classes_}')
48
+
49
+ return f'{full_pipeline.classes_[index]} with {higher_pred_prob}% confidence'
50
+
51
+ # Setting Gradio App Interface
52
+ with gr.Blocks(css=".gradio-container {background-color:grey }",theme=gr.themes.Base(primary_hue='blue'),title='Uriel') as demo:
53
+ gr.Markdown("# Spotify Genre Prediction #\n*This App allows the user to predict genre by entering values in the given fields. Any field left blank takes the default value.*")
54
+
55
+ # Receiving ALL Input Data here
56
+ gr.Markdown("**Demographic Data**")
57
+ with gr.Row():
58
+ danceability = gr.Number(label="Danceability ~ describes how suitable a track is for dancing based on musical elements.")
59
+ energy = gr.Number(label="Energy ~ measure from 0.0 to 1.0 and represents a perceptual measure of intensity and activity.")
60
+ key = gr.Number(label="Key ~ The estimated overall key of the track, If no key was detected, the value is -1")
61
+ loudness = gr.Number(label="Loudness ~ Overall loudness of a track in decibels (dB), range between -60 and 0 db.")
62
+
63
+ with gr.Row():
64
+ speechiness = gr.Number(label="Speechiness ~ indicates the modality (major or minor), represented by 1 and minor is 0")
65
+ acousticness = gr.Number(label="Acousticness ~ A confidence measure from 0.0 to 1.0 of whether the track is acoustic")
66
+ instrumentalness = gr.Number(label="Instrumentalness ~ Predicts whether a track contains no vocals, Rap or spoken word tracks")
67
+ liveness = gr.Number(label="Liveness ~ Detects the presence of an audience in the recording, measure from 0.0 to 1.0")
68
+
69
+ with gr.Row():
70
+ valence = gr.Number(label="Valence ~ A measure from 0.0 to 1.0 describing the musical positiveness conveyed by a track.")
71
+ tempo = gr.Number(label="Tempo ~ The overall estimated tempo of a track in beats per minute (BPM)")
72
+ duration_ms = gr.Number(label="Duration_ms ~ double Duration of song in milliseconds")
73
+ mode = gr.Number(label="Mode ~ Mode indicates the modality (major or minor) of a track, represented by 1 and minor is 0")
74
+
75
+ # Output Prediction
76
+ output = gr.Text(label="Outcome")
77
+ submit_button = gr.Button("Predict")
78
+
79
+ submit_button.click(fn= predict,
80
+ outputs= output,
81
+ inputs=[danceability, energy, key, loudness, speechiness, acousticness, instrumentalness, liveness, valence, tempo, duration_ms, mode]
82
+
83
+ ),
84
+
85
+ # Add the reset and flag buttons
86
+
87
+ def clear():
88
+ output.value = ""
89
+ return 'Predicted values have been reset'
90
+
91
+ clear_btn = gr.Button("Reset", variant="primary")
92
+ clear_btn.click(fn=clear, inputs=None, outputs=output)
93
+
94
+
95
+ demo.launch(inbrowser = True)
requirments.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ contourpy==1.2.0
2
+ cycler==0.12.1
3
+ fonttools==4.47.0
4
+ joblib==1.3.2
5
+ kiwisolver==1.4.5
6
+ matplotlib==3.8.2
7
+ numpy==1.26.2
8
+ packaging==23.2
9
+ pandas==2.1.4
10
+ Pillow==10.1.0
11
+ pyparsing==3.1.1
12
+ python-dateutil==2.8.2
13
+ pytz==2023.3.post1
14
+ scikit-learn==1.3.2
15
+ scipy==1.11.4
16
+ seaborn==0.13.0
17
+ six==1.16.0
18
+ threadpoolctl==3.2.0
19
+ tzdata==2023.3