VincentCroft commited on
Commit
ea59ebe
·
verified ·
1 Parent(s): 2f71709

Upload 3 files

Browse files
Files changed (3) hide show
  1. lstm_cnn_app.py +98 -0
  2. requirements.txt +7 -0
  3. tcn_app.py +107 -0
lstm_cnn_app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ lstm_cnn_app.py
3
+ Gradio app to serve the CNN-LSTM fault classification model.
4
+
5
+ Usage:
6
+ - Place a local model file named by LOCAL_MODEL_FILE in the same repo, or
7
+ - Set HUB_REPO and HUB_FILENAME to a public Hugging Face model repo + filename,
8
+ and the app will download it at startup using hf_hub_download.
9
+ """
10
+ import os
11
+ import numpy as np
12
+ import pandas as pd
13
+ import gradio as gr
14
+ from tensorflow.keras.models import load_model
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ # CONFIG: change these if your model filename/repo are different
18
+ LOCAL_MODEL_FILE = "lstm_cnn_model.h5"
19
+ HUB_REPO = "" # e.g., "username/lstm-cnn-model"
20
+ HUB_FILENAME = "" # e.g., "lstm_cnn_model.h5"
21
+
22
+ def get_model_path():
23
+ if os.path.exists(LOCAL_MODEL_FILE):
24
+ return LOCAL_MODEL_FILE
25
+ if HUB_REPO and HUB_FILENAME:
26
+ try:
27
+ print(f"Downloading {HUB_FILENAME} from {HUB_REPO} ...")
28
+ return hf_hub_download(repo_id=HUB_REPO, filename=HUB_FILENAME)
29
+ except Exception as e:
30
+ print("Failed to download from hub:", e)
31
+ return None
32
+
33
+ MODEL_PATH = get_model_path()
34
+ MODEL = None
35
+ if MODEL_PATH:
36
+ try:
37
+ MODEL = load_model(MODEL_PATH)
38
+ print("Loaded model:", MODEL_PATH)
39
+ except Exception as e:
40
+ print("Failed to load model:", e)
41
+ MODEL = None
42
+ else:
43
+ print("No model found. Please upload a model named", LOCAL_MODEL_FILE, "or set HUB_REPO/HUB_FILENAME.")
44
+
45
+ def prepare_input_array(arr, n_timesteps=1, n_features=None):
46
+ arr = np.array(arr)
47
+ if arr.ndim == 1:
48
+ if n_features is None:
49
+ return arr.reshape(1, n_timesteps, -1)
50
+ return arr.reshape(1, n_timesteps, n_features)
51
+ elif arr.ndim == 2:
52
+ return arr
53
+ else:
54
+ return arr
55
+
56
+ def predict_text(text, n_timesteps=1, n_features=None):
57
+ if MODEL is None:
58
+ return "模型未加载,请上传或配置模型。"
59
+ arr = np.fromstring(text, sep=',')
60
+ x = prepare_input_array(arr, n_timesteps=int(n_timesteps), n_features=(int(n_features) if n_features else None))
61
+ probs = MODEL.predict(x)
62
+ label = int(np.argmax(probs, axis=1)[0])
63
+ return f"预测类别: {label} (概率: {float(np.max(probs)):.4f})"
64
+
65
+ def predict_csv(file, n_timesteps=1, n_features=None):
66
+ if MODEL is None:
67
+ return {"error": "模型未加载,请上传或配置模型。"}
68
+ df = pd.read_csv(file.name)
69
+ X = df.values
70
+ if n_features:
71
+ X = X.reshape(X.shape[0], int(n_timesteps), int(n_features))
72
+ preds = MODEL.predict(X)
73
+ labels = preds.argmax(axis=1).tolist()
74
+ return {"labels": labels, "probs": preds.tolist()}
75
+
76
+ with gr.Blocks() as demo:
77
+ gr.Markdown("# CNN-LSTM Fault Classification")
78
+ gr.Markdown("上传 CSV(每行一个样本)或粘贴逗号分隔的一行特征进行预测。")
79
+ with gr.Row():
80
+ file_in = gr.File(label="上传 CSV(每行 = 一个样本)")
81
+ text_in = gr.Textbox(lines=2, placeholder="粘贴逗号分隔的一行特征,例如: 0.1,0.2,0.3,...")
82
+ n_ts = gr.Number(value=1, label="timesteps (整型)")
83
+ n_feat = gr.Number(value=None, label="features (可选,留空尝试自动推断)")
84
+ btn = gr.Button("预测")
85
+ out_text = gr.Textbox(label="单样本预测输出")
86
+ out_json = gr.JSON(label="批量预测结果 (labels & probs)")
87
+
88
+ def run_predict(file, text, n_timesteps, n_features):
89
+ if file is not None:
90
+ return "CSV 预测完成", predict_csv(file, n_timesteps, n_features)
91
+ if text:
92
+ return predict_text(text, n_timesteps, n_features), {}
93
+ return "请提供 CSV 或特征文本", {}
94
+
95
+ btn.click(run_predict, inputs=[file_in, text_in, n_ts, n_feat], outputs=[out_text, out_json])
96
+
97
+ if __name__ == '__main__':
98
+ demo.launch(server_name='0.0.0.0', server_port=7861)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=3.0
2
+ tensorflow>=2.6
3
+ numpy
4
+ pandas
5
+ scikit-learn
6
+ huggingface_hub
7
+ matplotlib
tcn_app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ tcn_app.py
3
+ Gradio app to serve the TCN fault classification model.
4
+
5
+ Usage:
6
+ - Place a local model file named by LOCAL_MODEL_FILE in the same repo, or
7
+ - Set HUB_REPO and HUB_FILENAME to a public Hugging Face model repo + filename,
8
+ and the app will download it at startup using hf_hub_download.
9
+
10
+ This file is ready to push to a Hugging Face Space (Gradio).
11
+ """
12
+ import os
13
+ import numpy as np
14
+ import pandas as pd
15
+ import gradio as gr
16
+ from tensorflow.keras.models import load_model
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ # CONFIG: change these if your model filename/repo are different
20
+ LOCAL_MODEL_FILE = "tcn_model.h5"
21
+ HUB_REPO = "" # e.g., "username/tcn-model-repo" (leave empty to disable)
22
+ HUB_FILENAME = "" # e.g., "tcn_model.h5"
23
+
24
+ def get_model_path():
25
+ # prefer local
26
+ if os.path.exists(LOCAL_MODEL_FILE):
27
+ return LOCAL_MODEL_FILE
28
+ # try hub
29
+ if HUB_REPO and HUB_FILENAME:
30
+ try:
31
+ print(f"Downloading {HUB_FILENAME} from {HUB_REPO} ...")
32
+ return hf_hub_download(repo_id=HUB_REPO, filename=HUB_FILENAME)
33
+ except Exception as e:
34
+ print("Failed to download from hub:", e)
35
+ return None
36
+
37
+ MODEL_PATH = get_model_path()
38
+ MODEL = None
39
+ if MODEL_PATH:
40
+ try:
41
+ MODEL = load_model(MODEL_PATH)
42
+ print("Loaded model:", MODEL_PATH)
43
+ except Exception as e:
44
+ print("Failed to load model:", e)
45
+ MODEL = None
46
+ else:
47
+ print("No model found. Please upload a model named", LOCAL_MODEL_FILE, "or set HUB_REPO/HUB_FILENAME.")
48
+
49
+ def prepare_input_array(arr, n_timesteps=1, n_features=None):
50
+ arr = np.array(arr)
51
+ # If input is 1D, reshape to (1, n_timesteps, n_features)
52
+ if arr.ndim == 1:
53
+ if n_features is None:
54
+ # if user didn't supply n_features, assume arr is already shaped as (timesteps*features,)
55
+ return arr.reshape(1, n_timesteps, -1)
56
+ return arr.reshape(1, n_timesteps, n_features)
57
+ elif arr.ndim == 2:
58
+ # Already (timesteps, features) or (samples, features)
59
+ if arr.shape[0] == 1:
60
+ return arr.reshape(1, arr.shape[1], -1)
61
+ return arr
62
+ else:
63
+ return arr
64
+
65
+ def predict_text(text, n_timesteps=1, n_features=None):
66
+ if MODEL is None:
67
+ return "模型未加载,请上传或配置模型。"
68
+ arr = np.fromstring(text, sep=',')
69
+ x = prepare_input_array(arr, n_timesteps=int(n_timesteps), n_features=(int(n_features) if n_features else None))
70
+ probs = MODEL.predict(x)
71
+ label = int(np.argmax(probs, axis=1)[0])
72
+ return f"预测类别: {label} (概率: {float(np.max(probs)):.4f})"
73
+
74
+ def predict_csv(file, n_timesteps=1, n_features=None):
75
+ if MODEL is None:
76
+ return {"error": "模型未加载,请上传或配置模型。"}
77
+ df = pd.read_csv(file.name)
78
+ X = df.values
79
+ if n_features:
80
+ X = X.reshape(X.shape[0], int(n_timesteps), int(n_features))
81
+ preds = MODEL.predict(X)
82
+ labels = preds.argmax(axis=1).tolist()
83
+ return {"labels": labels, "probs": preds.tolist()}
84
+
85
+ with gr.Blocks() as demo:
86
+ gr.Markdown("# TCN Fault Classification")
87
+ gr.Markdown("上传 CSV(每行一个样本)或粘贴逗号分隔的一行特征进行预测。")
88
+ with gr.Row():
89
+ file_in = gr.File(label="上传 CSV(每行 = 一个样本)")
90
+ text_in = gr.Textbox(lines=2, placeholder="粘贴逗号分隔的一行特征,例如: 0.1,0.2,0.3,...")
91
+ n_ts = gr.Number(value=1, label="timesteps (整型)")
92
+ n_feat = gr.Number(value=None, label="features (可选,留空尝试自动推断)")
93
+ btn = gr.Button("预测")
94
+ out_text = gr.Textbox(label="单样本预测输出")
95
+ out_json = gr.JSON(label="批量预测结果 (labels & probs)")
96
+
97
+ def run_predict(file, text, n_timesteps, n_features):
98
+ if file is not None:
99
+ return "CSV 预测完成", predict_csv(file, n_timesteps, n_features)
100
+ if text:
101
+ return predict_text(text, n_timesteps, n_features), {}
102
+ return "请提供 CSV 或特征文本", {}
103
+
104
+ btn.click(run_predict, inputs=[file_in, text_in, n_ts, n_feat], outputs=[out_text, out_json])
105
+
106
+ if __name__ == '__main__':
107
+ demo.launch(server_name='0.0.0.0', server_port=7860)