Umang-Bansal commited on
Commit
9ac0ba6
1 Parent(s): d0d13a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -14
app.py CHANGED
@@ -8,14 +8,18 @@ import scipy
8
  from scipy import signal
9
  import pickle
10
 
 
 
11
  def get_data_preview(file):
12
- data = pd.read_csv(file.name)
13
- return data.head()
 
14
 
15
- def label_data(file, start, end, label):
16
- data = pd.read_csv(file.name)
17
- data.loc[start:end, 'label'] = label # Label the specified range
18
- return data
 
19
 
20
  def preprocess_data(data):
21
  data.drop(columns=data.columns[0], axis=1, inplace=True)
@@ -86,7 +90,9 @@ def preprocess_data(data):
86
  df_features['label'] = labels
87
  return df_features
88
 
89
- def train_model(data):
 
 
90
  scaler = StandardScaler()
91
  X = data.drop('label', axis=1)
92
  y = data['label']
@@ -108,15 +114,17 @@ def train_model(data):
108
  with open(scaler_filename, 'wb') as file:
109
  pickle.dump(scaler, file)
110
 
111
- return f"Training complete! Model and scaler saved.", gr.File(model_filename), gr.File(scaler_filename)
112
 
113
 
114
  with gr.Blocks() as demo:
115
  file_input = gr.File(label="Upload CSV File")
116
  data_preview = gr.Dataframe(label="Data Preview", interactive=False)
117
- start_input = gr.Number(label="Start Index", value=0)
118
- end_input = gr.Number(label="End Index", value=100)
119
- label_input = gr.Number(label="Label Value", value=1)
 
 
120
  labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
121
  training_status = gr.Textbox(label="Training Status")
122
  model_file = gr.File(label="Download Trained Model")
@@ -124,8 +132,8 @@ with gr.Blocks() as demo:
124
 
125
  file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
126
  label_button = gr.Button("Label Data")
127
- label_button.click(label_data, inputs=[file_input, start_input, end_input, label_input], outputs=labeled_data_preview)
128
  train_button = gr.Button("Train Model")
129
- train_button.click(train_model, inputs=labeled_data_preview, outputs=[training_status, model_file, scaler_file])
130
 
131
- demo.launch()
 
8
  from scipy import signal
9
  import pickle
10
 
11
+ global_data = None
12
+
13
  def get_data_preview(file):
14
+ global global_data
15
+ global_data = pd.read_csv(file.name)
16
+ return global_data.head()
17
 
18
+ def label_data(ranges):
19
+ global global_data
20
+ for start, end, label in ranges:
21
+ global_data.loc[start:end, 'label'] = label
22
+ return global_data.head()
23
 
24
  def preprocess_data(data):
25
  data.drop(columns=data.columns[0], axis=1, inplace=True)
 
90
  df_features['label'] = labels
91
  return df_features
92
 
93
+ def train_model():
94
+ global global_data
95
+ data = preprocess_data(global_data)
96
  scaler = StandardScaler()
97
  X = data.drop('label', axis=1)
98
  y = data['label']
 
114
  with open(scaler_filename, 'wb') as file:
115
  pickle.dump(scaler, file)
116
 
117
+ return "Training complete! Model and scaler saved.", model_filename, scaler_filename
118
 
119
 
120
  with gr.Blocks() as demo:
121
  file_input = gr.File(label="Upload CSV File")
122
  data_preview = gr.Dataframe(label="Data Preview", interactive=False)
123
+ ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
124
+
125
+ #start_input = gr.Number(label="Start Index", value=0)
126
+ #end_input = gr.Number(label="End Index", value=100)
127
+ #label_input = gr.Number(label="Label Value", value=1)
128
  labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
129
  training_status = gr.Textbox(label="Training Status")
130
  model_file = gr.File(label="Download Trained Model")
 
132
 
133
  file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
134
  label_button = gr.Button("Label Data")
135
+ label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview)
136
  train_button = gr.Button("Train Model")
137
+ train_button.click(train_model, outputs=[training_status, model_file, scaler_file])
138
 
139
+ demo.launch()