surahj
committed on
Commit
·
64f3974
0
Parent(s):
Initial commit
Browse files- .gitignore +277 -0
- README.md +314 -0
- pytest.ini +18 -0
- requirements.txt +7 -0
- run_tests.py +118 -0
- src/__init__.py +1 -0
- src/app.py +373 -0
- src/data_generator.py +164 -0
- src/model.py +283 -0
- tests/__init__.py +1 -0
- tests/test_app.py +355 -0
- tests/test_data_generator.py +278 -0
- tests/test_integration.py +308 -0
- tests/test_model.py +359 -0
.gitignore
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
share/python-wheels/
|
| 20 |
+
*.egg-info/
|
| 21 |
+
.installed.cfg
|
| 22 |
+
*.egg
|
| 23 |
+
MANIFEST
|
| 24 |
+
|
| 25 |
+
# Virtual environments
|
| 26 |
+
myenv/
|
| 27 |
+
venv/
|
| 28 |
+
env/
|
| 29 |
+
ENV/
|
| 30 |
+
env.bak/
|
| 31 |
+
venv.bak/
|
| 32 |
+
.venv/
|
| 33 |
+
.env/
|
| 34 |
+
|
| 35 |
+
# PyInstaller
|
| 36 |
+
*.manifest
|
| 37 |
+
*.spec
|
| 38 |
+
|
| 39 |
+
# Installer logs
|
| 40 |
+
pip-log.txt
|
| 41 |
+
pip-delete-this-directory.txt
|
| 42 |
+
|
| 43 |
+
# Unit test / coverage reports
|
| 44 |
+
htmlcov/
|
| 45 |
+
.tox/
|
| 46 |
+
.nox/
|
| 47 |
+
.coverage
|
| 48 |
+
.coverage.*
|
| 49 |
+
.cache
|
| 50 |
+
nosetests.xml
|
| 51 |
+
coverage.xml
|
| 52 |
+
*.cover
|
| 53 |
+
*.py,cover
|
| 54 |
+
.hypothesis/
|
| 55 |
+
.pytest_cache/
|
| 56 |
+
cover/
|
| 57 |
+
|
| 58 |
+
# Translations
|
| 59 |
+
*.mo
|
| 60 |
+
*.pot
|
| 61 |
+
|
| 62 |
+
# Django stuff:
|
| 63 |
+
*.log
|
| 64 |
+
local_settings.py
|
| 65 |
+
db.sqlite3
|
| 66 |
+
db.sqlite3-journal
|
| 67 |
+
|
| 68 |
+
# Flask stuff:
|
| 69 |
+
instance/
|
| 70 |
+
.webassets-cache
|
| 71 |
+
|
| 72 |
+
# Scrapy stuff:
|
| 73 |
+
.scrapy
|
| 74 |
+
|
| 75 |
+
# Sphinx documentation
|
| 76 |
+
docs/_build/
|
| 77 |
+
|
| 78 |
+
# PyBuilder
|
| 79 |
+
.pybuilder/
|
| 80 |
+
target/
|
| 81 |
+
|
| 82 |
+
# Jupyter Notebook
|
| 83 |
+
.ipynb_checkpoints
|
| 84 |
+
|
| 85 |
+
# IPython
|
| 86 |
+
profile_default/
|
| 87 |
+
ipython_config.py
|
| 88 |
+
|
| 89 |
+
# pyenv
|
| 90 |
+
.python-version
|
| 91 |
+
|
| 92 |
+
# pipenv
|
| 93 |
+
Pipfile.lock
|
| 94 |
+
|
| 95 |
+
# poetry
|
| 96 |
+
poetry.lock
|
| 97 |
+
|
| 98 |
+
# pdm
|
| 99 |
+
.pdm.toml
|
| 100 |
+
|
| 101 |
+
# PEP 582
|
| 102 |
+
__pypackages__/
|
| 103 |
+
|
| 104 |
+
# Celery stuff
|
| 105 |
+
celerybeat-schedule
|
| 106 |
+
celerybeat.pid
|
| 107 |
+
|
| 108 |
+
# SageMath parsed files
|
| 109 |
+
*.sage.py
|
| 110 |
+
|
| 111 |
+
# Environments
|
| 112 |
+
.env
|
| 113 |
+
.venv
|
| 114 |
+
env/
|
| 115 |
+
venv/
|
| 116 |
+
ENV/
|
| 117 |
+
env.bak/
|
| 118 |
+
venv.bak/
|
| 119 |
+
|
| 120 |
+
# Spyder project settings
|
| 121 |
+
.spyderproject
|
| 122 |
+
.spyproject
|
| 123 |
+
|
| 124 |
+
# Rope project settings
|
| 125 |
+
.ropeproject
|
| 126 |
+
|
| 127 |
+
# mkdocs documentation
|
| 128 |
+
/site
|
| 129 |
+
|
| 130 |
+
# mypy
|
| 131 |
+
.mypy_cache/
|
| 132 |
+
.dmypy.json
|
| 133 |
+
dmypy.json
|
| 134 |
+
|
| 135 |
+
# Pyre type checker
|
| 136 |
+
.pyre/
|
| 137 |
+
|
| 138 |
+
# pytype static type analyzer
|
| 139 |
+
.pytype/
|
| 140 |
+
|
| 141 |
+
# Cython debug symbols
|
| 142 |
+
cython_debug/
|
| 143 |
+
|
| 144 |
+
# PyCharm
|
| 145 |
+
.idea/
|
| 146 |
+
*.iws
|
| 147 |
+
*.iml
|
| 148 |
+
*.ipr
|
| 149 |
+
|
| 150 |
+
# VS Code
|
| 151 |
+
.vscode/
|
| 152 |
+
*.code-workspace
|
| 153 |
+
|
| 154 |
+
# Sublime Text
|
| 155 |
+
*.sublime-project
|
| 156 |
+
*.sublime-workspace
|
| 157 |
+
|
| 158 |
+
# Vim
|
| 159 |
+
*.swp
|
| 160 |
+
*.swo
|
| 161 |
+
*~
|
| 162 |
+
|
| 163 |
+
# Emacs
|
| 164 |
+
*~
|
| 165 |
+
\#*\#
|
| 166 |
+
/.emacs.desktop
|
| 167 |
+
/.emacs.desktop.lock
|
| 168 |
+
*.elc
|
| 169 |
+
auto-save-list
|
| 170 |
+
tramp
|
| 171 |
+
.\#*
|
| 172 |
+
|
| 173 |
+
# macOS
|
| 174 |
+
.DS_Store
|
| 175 |
+
.AppleDouble
|
| 176 |
+
.LSOverride
|
| 177 |
+
Icon
|
| 178 |
+
._*
|
| 179 |
+
.DocumentRevisions-V100
|
| 180 |
+
.fseventsd
|
| 181 |
+
.Spotlight-V100
|
| 182 |
+
.TemporaryItems
|
| 183 |
+
.Trashes
|
| 184 |
+
.VolumeIcon.icns
|
| 185 |
+
.com.apple.timemachine.donotpresent
|
| 186 |
+
.AppleDB
|
| 187 |
+
.AppleDesktop
|
| 188 |
+
Network Trash Folder
|
| 189 |
+
Temporary Items
|
| 190 |
+
.apdisk
|
| 191 |
+
|
| 192 |
+
# Windows
|
| 193 |
+
Thumbs.db
|
| 194 |
+
Thumbs.db:encryptable
|
| 195 |
+
ehthumbs.db
|
| 196 |
+
ehthumbs_vista.db
|
| 197 |
+
*.tmp
|
| 198 |
+
*.temp
|
| 199 |
+
Desktop.ini
|
| 200 |
+
$RECYCLE.BIN/
|
| 201 |
+
*.cab
|
| 202 |
+
*.msi
|
| 203 |
+
*.msix
|
| 204 |
+
*.msm
|
| 205 |
+
*.msp
|
| 206 |
+
*.lnk
|
| 207 |
+
|
| 208 |
+
# Linux
|
| 209 |
+
*~
|
| 210 |
+
.fuse_hidden*
|
| 211 |
+
.directory
|
| 212 |
+
.Trash-*
|
| 213 |
+
.nfs*
|
| 214 |
+
|
| 215 |
+
# Machine Learning specific
|
| 216 |
+
*.joblib
|
| 217 |
+
*.pkl
|
| 218 |
+
*.pickle
|
| 219 |
+
*.h5
|
| 220 |
+
*.hdf5
|
| 221 |
+
*.model
|
| 222 |
+
*.weights
|
| 223 |
+
*.ckpt
|
| 224 |
+
*.pt
|
| 225 |
+
*.pth
|
| 226 |
+
*.onnx
|
| 227 |
+
*.tflite
|
| 228 |
+
*.pb
|
| 229 |
+
*.savedmodel/
|
| 230 |
+
checkpoints/
|
| 231 |
+
models/
|
| 232 |
+
logs/
|
| 233 |
+
runs/
|
| 234 |
+
wandb/
|
| 235 |
+
mlruns/
|
| 236 |
+
.mlflow/
|
| 237 |
+
|
| 238 |
+
# Data files (uncomment if you don't want to track data)
|
| 239 |
+
# *.csv
|
| 240 |
+
# *.json
|
| 241 |
+
# *.xml
|
| 242 |
+
# *.xlsx
|
| 243 |
+
# *.xls
|
| 244 |
+
# *.parquet
|
| 245 |
+
# *.feather
|
| 246 |
+
# *.hdf
|
| 247 |
+
# *.h5
|
| 248 |
+
|
| 249 |
+
# Temporary files
|
| 250 |
+
*.tmp
|
| 251 |
+
*.temp
|
| 252 |
+
*.bak
|
| 253 |
+
*.backup
|
| 254 |
+
*.old
|
| 255 |
+
*.orig
|
| 256 |
+
|
| 257 |
+
# Logs
|
| 258 |
+
*.log
|
| 259 |
+
logs/
|
| 260 |
+
log/
|
| 261 |
+
|
| 262 |
+
# Configuration files with sensitive data
|
| 263 |
+
config.ini
|
| 264 |
+
secrets.json
|
| 265 |
+
.env.local
|
| 266 |
+
.env.production
|
| 267 |
+
.env.staging
|
| 268 |
+
|
| 269 |
+
# OS generated files
|
| 270 |
+
.DS_Store
|
| 271 |
+
.DS_Store?
|
| 272 |
+
._*
|
| 273 |
+
.Spotlight-V100
|
| 274 |
+
.Trashes
|
| 275 |
+
ehthumbs.db
|
| 276 |
+
Thumbs.db
|
| 277 |
+
|
README.md
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Daily Household Electricity Consumption Predictor
|
| 2 |
+
|
| 3 |
+
A web-based application designed to help Nigerian households estimate their daily electricity usage in kilowatt-hours (kWh). This project serves as a practical learning vehicle for Machine Learning Operations (MLOps), covering the full lifecycle from data preparation and model training to deployment, monitoring, and continuous improvement.
|
| 4 |
+
|
| 5 |
+
## 🎯 Project Goals
|
| 6 |
+
|
| 7 |
+
### Business Goals
|
| 8 |
+
|
| 9 |
+
- **Empower Households**: Provide users with a simple, accessible tool to understand and predict their daily electricity consumption
|
| 10 |
+
- **Promote Energy Awareness**: Help users identify factors influencing their electricity usage, encouraging more efficient energy habits
|
| 11 |
+
- **Inform Budgeting**: Enable users to better estimate their electricity bills, reducing financial surprises
|
| 12 |
+
- **Foundational MLOps Learning**: Serve as a concrete project to apply and understand core MLOps principles
|
| 13 |
+
|
| 14 |
+
### Machine Learning & Technical Goals
|
| 15 |
+
|
| 16 |
+
- **Accurate Prediction**: Develop a regression model capable of predicting daily kWh consumption with acceptable accuracy
|
| 17 |
+
- **User-Friendly Interface**: Create an intuitive web interface that allows easy input of features and clear display of predictions
|
| 18 |
+
- **Deployable Application**: Build a self-contained application that can be deployed to a public platform
|
| 19 |
+
- **MLOps Readiness**: Design the application with modularity and best practices that facilitate future MLOps implementation
|
| 20 |
+
|
| 21 |
+
## 🏗️ Project Structure
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
lin-re-model/
|
| 25 |
+
├── src/
|
| 26 |
+
│ ├── __init__.py
|
| 27 |
+
│ ├── data_generator.py # Synthetic data generation
|
| 28 |
+
│ ├── model.py # ML model training and prediction
|
| 29 |
+
│ └── app.py # Gradio web interface
|
| 30 |
+
├── tests/
|
| 31 |
+
│ ├── __init__.py
|
| 32 |
+
│ ├── test_data_generator.py # Data generator tests
|
| 33 |
+
│ ├── test_model.py # Model tests
|
| 34 |
+
│ ├── test_app.py # Application tests
|
| 35 |
+
│ └── test_integration.py # Integration tests
|
| 36 |
+
├── requirements.txt # Python dependencies
|
| 37 |
+
├── pytest.ini # Pytest configuration
|
| 38 |
+
├── run_tests.py # Test runner script
|
| 39 |
+
└── README.md # This file
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## 🚀 Quick Start
|
| 43 |
+
|
| 44 |
+
### Prerequisites
|
| 45 |
+
|
| 46 |
+
- Python 3.8 or higher
|
| 47 |
+
- pip (Python package installer)
|
| 48 |
+
|
| 49 |
+
### Installation
|
| 50 |
+
|
| 51 |
+
1. **Clone the repository** (if not already done):
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
git clone <repository-url>
|
| 55 |
+
cd lin-re-model
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
2. **Install dependencies**:
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
pip install -r requirements.txt
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
3. **Run the application**:
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
python src/app.py
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
4. **Open your browser** and navigate to `http://localhost:7860`
|
| 71 |
+
|
| 72 |
+
## 🧪 Testing
|
| 73 |
+
|
| 74 |
+
This project includes comprehensive tests to ensure code quality and functionality. The test suite covers:
|
| 75 |
+
|
| 76 |
+
- **Unit Tests**: Individual component testing
|
| 77 |
+
- **Integration Tests**: End-to-end workflow testing
|
| 78 |
+
- **Data Quality Tests**: Validation of synthetic data generation
|
| 79 |
+
- **Model Performance Tests**: Verification of model accuracy and consistency
|
| 80 |
+
|
| 81 |
+
### Running Tests
|
| 82 |
+
|
| 83 |
+
#### Option 1: Using the test runner script
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
# Run all tests with coverage
|
| 87 |
+
python run_tests.py
|
| 88 |
+
|
| 89 |
+
# Run only unit tests
|
| 90 |
+
python run_tests.py unit
|
| 91 |
+
|
| 92 |
+
# Run only integration tests
|
| 93 |
+
python run_tests.py integration
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
#### Option 2: Using pytest directly
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
# Run all tests
|
| 100 |
+
pytest
|
| 101 |
+
|
| 102 |
+
# Run with verbose output
|
| 103 |
+
pytest -v
|
| 104 |
+
|
| 105 |
+
# Run with coverage report
|
| 106 |
+
pytest --cov=src --cov-report=html
|
| 107 |
+
|
| 108 |
+
# Run specific test file
|
| 109 |
+
pytest tests/test_model.py
|
| 110 |
+
|
| 111 |
+
# Run specific test class
|
| 112 |
+
pytest tests/test_model.py::TestElectricityConsumptionModel
|
| 113 |
+
|
| 114 |
+
# Run specific test method
|
| 115 |
+
pytest tests/test_model.py::TestElectricityConsumptionModel::test_train_model
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
### Test Coverage
|
| 119 |
+
|
| 120 |
+
The test suite provides comprehensive coverage including:
|
| 121 |
+
|
| 122 |
+
- **Data Generator Tests**:
|
| 123 |
+
|
| 124 |
+
- Data generation with different parameters
|
| 125 |
+
- Data splitting functionality
|
| 126 |
+
- Data persistence (save/load)
|
| 127 |
+
- Data quality validation
|
| 128 |
+
- Reproducibility checks
|
| 129 |
+
|
| 130 |
+
- **Model Tests**:
|
| 131 |
+
|
| 132 |
+
- Model initialization and training
|
| 133 |
+
- Feature preparation and validation
|
| 134 |
+
- Prediction functionality
|
| 135 |
+
- Model evaluation metrics
|
| 136 |
+
- Model persistence (save/load)
|
| 137 |
+
- Error handling
|
| 138 |
+
|
| 139 |
+
- **Application Tests**:
|
| 140 |
+
|
| 141 |
+
- Web interface functionality
|
| 142 |
+
- User interaction flows
|
| 143 |
+
- Error handling in UI
|
| 144 |
+
- State management
|
| 145 |
+
|
| 146 |
+
- **Integration Tests**:
|
| 147 |
+
- Complete workflow testing
|
| 148 |
+
- End-to-end functionality
|
| 149 |
+
- Performance consistency
|
| 150 |
+
- Data quality across components
|
| 151 |
+
|
| 152 |
+
### Expected Test Results
|
| 153 |
+
|
| 154 |
+
When all tests pass, you should see output similar to:
|
| 155 |
+
|
| 156 |
+
```
|
| 157 |
+
🧪 Running Daily Household Electricity Consumption Predictor Tests
|
| 158 |
+
======================================================================
|
| 159 |
+
============================= test session starts ==============================
|
| 160 |
+
platform linux -- Python 3.8.x, pytest-7.4.0, pluggy-1.0.0
|
| 161 |
+
rootdir: /path/to/lin-re-model
|
| 162 |
+
plugins: cov-4.1.0
|
| 163 |
+
collected 45 tests
|
| 164 |
+
|
| 165 |
+
tests/test_app.py ................... [ 42%]
|
| 166 |
+
tests/test_data_generator.py ................... [ 78%]
|
| 167 |
+
tests/test_integration.py .......... [100%]
|
| 168 |
+
|
| 169 |
+
---------- coverage: platform linux, python 3.8.x-final-0 -----------
|
| 170 |
+
Name Stmts Miss Cover Missing
|
| 171 |
+
------------------------------------------------------------
|
| 172 |
+
src/__init__.py 1 0 100%
|
| 173 |
+
src/app.py 180 5 97% 180-185
|
| 174 |
+
src/data_generator.py 95 2 98% 95-97
|
| 175 |
+
src/model.py 180 8 96% 180-188
|
| 176 |
+
------------------------------------------------------------
|
| 177 |
+
TOTAL 456 15 97%
|
| 178 |
+
|
| 179 |
+
============================== 45 passed in 5.23s ==============================
|
| 180 |
+
|
| 181 |
+
✅ All tests passed!
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
## 📊 Model Features
|
| 185 |
+
|
| 186 |
+
The electricity consumption prediction model uses the following features:
|
| 187 |
+
|
| 188 |
+
1. **Average Daily Temperature** (°C): Numerical input (15-35°C range)
|
| 189 |
+
2. **Day of the Week**: Categorical input (Monday through Sunday)
|
| 190 |
+
3. **Major Event**: Boolean input (Holiday, Power Outage, etc.)
|
| 191 |
+
|
| 192 |
+
### Model Algorithm
|
| 193 |
+
|
| 194 |
+
- **Algorithm**: Linear Regression
|
| 195 |
+
- **Preprocessing**: StandardScaler for numerical features, OneHotEncoder for categorical features
|
| 196 |
+
- **Evaluation Metrics**: MSE, RMSE, MAE, R²
|
| 197 |
+
|
| 198 |
+
## 🎮 Using the Application
|
| 199 |
+
|
| 200 |
+
### Step 1: Generate Data & Train Model
|
| 201 |
+
|
| 202 |
+
1. Navigate to the "Data Generation & Training" tab
|
| 203 |
+
2. Adjust parameters as desired:
|
| 204 |
+
- Number of Data Points (100-5000)
|
| 205 |
+
- Noise Level (0.01-0.5)
|
| 206 |
+
- Training/Validation/Test Set Proportions
|
| 207 |
+
3. Click "Generate Data & Train Model"
|
| 208 |
+
4. Review the training metrics and evaluation results
|
| 209 |
+
|
| 210 |
+
### Step 2: Make Predictions
|
| 211 |
+
|
| 212 |
+
1. Navigate to the "Prediction" tab
|
| 213 |
+
2. Enter your parameters:
|
| 214 |
+
- Average Daily Temperature (15-35°C)
|
| 215 |
+
- Day of the Week
|
| 216 |
+
- Major Event (checkbox)
|
| 217 |
+
3. Click "Predict Consumption"
|
| 218 |
+
4. View your estimated daily electricity consumption
|
| 219 |
+
|
| 220 |
+
### Step 3: Understand the Model
|
| 221 |
+
|
| 222 |
+
1. Navigate to the "Model Information" tab
|
| 223 |
+
2. Click "Show Model Information"
|
| 224 |
+
3. Review feature coefficients and model interpretation
|
| 225 |
+
|
| 226 |
+
## 🔧 Development
|
| 227 |
+
|
| 228 |
+
### Adding New Tests
|
| 229 |
+
|
| 230 |
+
To add new tests:
|
| 231 |
+
|
| 232 |
+
1. **Unit Tests**: Add to appropriate test file in `tests/`
|
| 233 |
+
2. **Integration Tests**: Add to `tests/test_integration.py`
|
| 234 |
+
3. **Follow naming convention**: `test_<functionality>`
|
| 235 |
+
4. **Use descriptive docstrings**: Explain what the test validates
|
| 236 |
+
|
| 237 |
+
### Test Best Practices
|
| 238 |
+
|
| 239 |
+
- **Isolation**: Each test should be independent
|
| 240 |
+
- **Descriptive names**: Test names should clearly indicate what they test
|
| 241 |
+
- **Assertions**: Use specific assertions with meaningful messages
|
| 242 |
+
- **Coverage**: Aim for high test coverage (>95%)
|
| 243 |
+
- **Performance**: Tests should run quickly (<10 seconds total)
|
| 244 |
+
|
| 245 |
+
### Running Tests in Development
|
| 246 |
+
|
| 247 |
+
During development, you can run tests in different ways:
|
| 248 |
+
|
| 249 |
+
```bash
|
| 250 |
+
# Quick test run (no coverage)
|
| 251 |
+
pytest -x # Stop on first failure
|
| 252 |
+
|
| 253 |
+
# Run tests in parallel (if pytest-xdist installed)
|
| 254 |
+
pytest -n auto
|
| 255 |
+
|
| 256 |
+
# Run tests with detailed output
|
| 257 |
+
pytest -v -s
|
| 258 |
+
|
| 259 |
+
# Run tests and watch for changes
|
| 260 |
+
pytest-watch # Requires pytest-watch package
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
## 🚀 Deployment
|
| 264 |
+
|
| 265 |
+
### Local Deployment
|
| 266 |
+
|
| 267 |
+
```bash
|
| 268 |
+
python src/app.py
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
### Hugging Face Spaces Deployment
|
| 272 |
+
|
| 273 |
+
1. Create a new Space on Hugging Face
|
| 274 |
+
2. Upload the project files
|
| 275 |
+
3. Configure the Space to run `python src/app.py`
|
| 276 |
+
4. The application will be available at your Space URL
|
| 277 |
+
|
| 278 |
+
## 📈 Future Enhancements
|
| 279 |
+
|
| 280 |
+
### MLOps Features (Future Phases)
|
| 281 |
+
|
| 282 |
+
- **Data Versioning**: Implement DVC for data version control
|
| 283 |
+
- **Experiment Tracking**: Integrate MLflow or Weights & Biases
|
| 284 |
+
- **Model Registry**: Use MLflow Model Registry for model lifecycle management
|
| 285 |
+
- **Containerization**: Create Dockerfile for reproducible environments
|
| 286 |
+
- **CI/CD**: Set up GitHub Actions for automated testing and deployment
|
| 287 |
+
- **Model Monitoring**: Implement monitoring for data drift and performance degradation
|
| 288 |
+
- **Continuous Training**: Define triggers for automated retraining
|
| 289 |
+
|
| 290 |
+
### Model Improvements
|
| 291 |
+
|
| 292 |
+
- **Feature Engineering**: Add more complex features (historical averages, time of day, etc.)
|
| 293 |
+
- **Advanced Models**: Experiment with Random Forest, Gradient Boosting, etc.
|
| 294 |
+
- **Hyperparameter Tuning**: Implement automated hyperparameter optimization
|
| 295 |
+
- **Ensemble Methods**: Combine multiple models for better predictions
|
| 296 |
+
|
| 297 |
+
## 🤝 Contributing
|
| 298 |
+
|
| 299 |
+
1. Fork the repository
|
| 300 |
+
2. Create a feature branch
|
| 301 |
+
3. Make your changes
|
| 302 |
+
4. Add tests for new functionality
|
| 303 |
+
5. Ensure all tests pass
|
| 304 |
+
6. Submit a pull request
|
| 305 |
+
|
| 306 |
+
## 📄 License
|
| 307 |
+
|
| 308 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
| 309 |
+
|
| 310 |
+
## 🙏 Acknowledgments
|
| 311 |
+
|
| 312 |
+
- Gradio team for the excellent web interface framework
|
| 313 |
+
- Scikit-learn team for the machine learning library
|
| 314 |
+
- The MLOps community for best practices and guidance
|
pytest.ini
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool:pytest]
|
| 2 |
+
testpaths = tests
|
| 3 |
+
python_files = test_*.py
|
| 4 |
+
python_classes = Test*
|
| 5 |
+
python_functions = test_*
|
| 6 |
+
addopts =
|
| 7 |
+
-v
|
| 8 |
+
--tb=short
|
| 9 |
+
--strict-markers
|
| 10 |
+
--disable-warnings
|
| 11 |
+
--cov=src
|
| 12 |
+
--cov-report=term-missing
|
| 13 |
+
--cov-report=html:htmlcov
|
| 14 |
+
--cov-report=xml
|
| 15 |
+
markers =
|
| 16 |
+
unit: Unit tests
|
| 17 |
+
integration: Integration tests
|
| 18 |
+
slow: Slow running tests
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
scikit-learn==1.3.0
|
| 2 |
+
pandas==2.0.3
|
| 3 |
+
numpy==1.24.3
|
| 4 |
+
gradio==3.40.1
|
| 5 |
+
pytest==7.4.0
|
| 6 |
+
pytest-cov==4.1.0
|
| 7 |
+
joblib==1.3.2
|
run_tests.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test runner script for the Daily Household Electricity Consumption Predictor.
|
| 4 |
+
|
| 5 |
+
This script runs all tests and provides a summary of results.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import subprocess
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _execute_pytest(extra_args, label="tests"):
    """Run pytest in a subprocess and report success.

    Shared helper for the three public runners below, which previously
    duplicated the command construction and error handling verbatim.

    Args:
        extra_args: Additional pytest command-line arguments (list of str),
            inserted between the common flags and the ``tests/`` target.
        label: Human-readable name for the run, used in error output.

    Returns:
        True if pytest exited with status 0, False otherwise (including
        when the subprocess could not be launched at all).
    """
    cmd = [
        sys.executable,
        "-m",
        "pytest",
        "--verbose",
        "--tb=short",
        *extra_args,
        "tests/",
    ]
    try:
        result = subprocess.run(cmd, capture_output=False, text=True)
    except Exception as e:
        # Launch failures (e.g. pytest not installed) are reported rather
        # than raised so callers can translate them into an exit code.
        print(f"❌ Error running {label}: {e}")
        return False
    return result.returncode == 0


def run_tests():
    """Run the full test suite with coverage reporting.

    Changes the working directory to the project root first so that the
    relative ``tests/`` and ``src/`` paths resolve regardless of where
    the script is invoked from.

    Returns:
        True if all tests passed, False otherwise.
    """
    print("🧪 Running Daily Household Electricity Consumption Predictor Tests")
    print("=" * 70)

    # Run from the project root so relative paths (tests/, src/) resolve.
    project_root = Path(__file__).parent
    os.chdir(project_root)

    coverage_args = [
        "--cov=src",
        "--cov-report=term-missing",
        "--cov-report=html:htmlcov",
        "--cov-report=xml",
    ]
    return _execute_pytest(coverage_args, label="tests")


def run_unit_tests():
    """Run only tests marked with the ``unit`` pytest marker.

    Returns:
        True if the selected tests passed, False otherwise.
    """
    print("🧪 Running Unit Tests")
    print("=" * 40)
    return _execute_pytest(["-m", "unit"], label="unit tests")


def run_integration_tests():
    """Run only tests marked with the ``integration`` pytest marker.

    Returns:
        True if the selected tests passed, False otherwise.
    """
    print("🧪 Running Integration Tests")
    print("=" * 40)
    return _execute_pytest(["-m", "integration"], label="integration tests")


def main():
    """Dispatch to a test run based on the first CLI argument.

    Accepts ``unit``, ``integration``, or no argument (run everything).
    The argument comparison is case-insensitive.

    Returns:
        Process exit code: 0 on success, 1 on failure or unknown argument.
    """
    if len(sys.argv) > 1:
        test_type = sys.argv[1].lower()
        if test_type == "unit":
            success = run_unit_tests()
        elif test_type == "integration":
            success = run_integration_tests()
        else:
            print(f"❌ Unknown test type: {test_type}")
            print("Available options: unit, integration, all (default)")
            return 1
    else:
        success = run_tests()

    if success:
        print("\n✅ All tests passed!")
        return 0
    print("\n❌ Some tests failed!")
    return 1


if __name__ == "__main__":
    sys.exit(main())
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Daily Household Electricity Consumption Predictor
|
src/app.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio Web Application for Daily Household Electricity Consumption Predictor
|
| 3 |
+
|
| 4 |
+
This module provides a user-friendly web interface for the electricity consumption
|
| 5 |
+
prediction model using Gradio.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
from typing import Tuple, Dict, Any
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
# Add src to path for imports
|
| 16 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
| 17 |
+
|
| 18 |
+
from src.data_generator import DataGenerator
|
| 19 |
+
from src.model import ElectricityConsumptionModel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ElectricityPredictorApp:
|
| 23 |
+
"""Gradio application for electricity consumption prediction."""
|
| 24 |
+
|
| 25 |
+
def __init__(self):
    """Set up the app's collaborators and its training state.

    Creates a fixed-seed data generator (seed=42, so generated datasets
    are reproducible across runs), a fresh consumption model, and marks
    the model as not yet trained.
    """
    # No model has been fitted yet; prediction paths should check this flag.
    self.is_model_trained = False
    self.model = ElectricityConsumptionModel()
    self.data_generator = DataGenerator(seed=42)
|
| 30 |
+
|
| 31 |
+
def generate_and_train(
|
| 32 |
+
self,
|
| 33 |
+
n_samples: int,
|
| 34 |
+
noise_level: float,
|
| 35 |
+
train_size: float,
|
| 36 |
+
val_size: float,
|
| 37 |
+
test_size: float,
|
| 38 |
+
) -> Tuple[str, str, str]:
|
| 39 |
+
"""
|
| 40 |
+
Generate synthetic data and train the model.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
n_samples: Number of data points to generate
|
| 44 |
+
noise_level: Level of noise in the data
|
| 45 |
+
train_size: Proportion for training set
|
| 46 |
+
val_size: Proportion for validation set
|
| 47 |
+
test_size: Proportion for test set
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
Tuple of (data_info, training_metrics, evaluation_metrics)
|
| 51 |
+
"""
|
| 52 |
+
try:
|
| 53 |
+
# Generate data
|
| 54 |
+
data = self.data_generator.generate_data(n_samples, noise_level)
|
| 55 |
+
|
| 56 |
+
# Split data
|
| 57 |
+
train_data, val_data, test_data = self.data_generator.split_data(
|
| 58 |
+
data, train_size, val_size, test_size
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Store data for later use
|
| 62 |
+
self.train_data = train_data
|
| 63 |
+
self.val_data = val_data
|
| 64 |
+
self.test_data = test_data
|
| 65 |
+
|
| 66 |
+
# Train model
|
| 67 |
+
X_train = train_data.drop("consumption_kwh", axis=1)
|
| 68 |
+
y_train = train_data[["consumption_kwh"]]
|
| 69 |
+
|
| 70 |
+
training_metrics = self.model.train(X_train, y_train)
|
| 71 |
+
|
| 72 |
+
# Evaluate model
|
| 73 |
+
X_test = test_data.drop("consumption_kwh", axis=1)
|
| 74 |
+
y_test = test_data[["consumption_kwh"]]
|
| 75 |
+
|
| 76 |
+
evaluation_metrics = self.model.evaluate(X_test, y_test)
|
| 77 |
+
|
| 78 |
+
self.is_model_trained = True
|
| 79 |
+
|
| 80 |
+
# Format output strings
|
| 81 |
+
data_info = f"""
|
| 82 |
+
**Data Generated Successfully!**
|
| 83 |
+
|
| 84 |
+
- Total samples: {len(data)}
|
| 85 |
+
- Training samples: {len(train_data)}
|
| 86 |
+
- Validation samples: {len(val_data)}
|
| 87 |
+
- Test samples: {len(test_data)}
|
| 88 |
+
|
| 89 |
+
**Data Statistics:**
|
| 90 |
+
- Temperature range: {data['temperature'].min():.1f}°C - {data['temperature'].max():.1f}°C
|
| 91 |
+
- Consumption range: {data['consumption_kwh'].min():.1f} - {data['consumption_kwh'].max():.1f} kWh
|
| 92 |
+
- Average consumption: {data['consumption_kwh'].mean():.1f} kWh
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
training_metrics_str = f"""
|
| 96 |
+
**Training Metrics:**
|
| 97 |
+
- Mean Squared Error (MSE): {training_metrics['train_mse']:.4f}
|
| 98 |
+
- Root Mean Squared Error (RMSE): {training_metrics['train_rmse']:.4f}
|
| 99 |
+
- Mean Absolute Error (MAE): {training_metrics['train_mae']:.4f}
|
| 100 |
+
- R-squared (R²): {training_metrics['train_r2']:.4f}
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
evaluation_metrics_str = f"""
|
| 104 |
+
**Test Set Evaluation:**
|
| 105 |
+
- Mean Squared Error (MSE): {evaluation_metrics['test_mse']:.4f}
|
| 106 |
+
- Root Mean Squared Error (RMSE): {evaluation_metrics['test_rmse']:.4f}
|
| 107 |
+
- Mean Absolute Error (MAE): {evaluation_metrics['test_mae']:.4f}
|
| 108 |
+
- R-squared (R²): {evaluation_metrics['test_r2']:.4f}
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
return data_info, training_metrics_str, evaluation_metrics_str
|
| 112 |
+
|
| 113 |
+
except Exception as e:
|
| 114 |
+
error_msg = f"Error during data generation and training: {str(e)}"
|
| 115 |
+
return error_msg, "", ""
|
| 116 |
+
|
| 117 |
+
def predict_consumption(
|
| 118 |
+
self, temperature: float, day_of_week: str, major_event: bool
|
| 119 |
+
) -> str:
|
| 120 |
+
"""
|
| 121 |
+
Make a prediction for electricity consumption.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
temperature: Average daily temperature in Celsius
|
| 125 |
+
day_of_week: Day of the week
|
| 126 |
+
major_event: Whether there's a major event
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
Formatted prediction result
|
| 130 |
+
"""
|
| 131 |
+
if not self.is_model_trained:
|
| 132 |
+
return "**Error:** Model must be trained first. Please generate data and train the model."
|
| 133 |
+
|
| 134 |
+
try:
|
| 135 |
+
# Convert boolean to int
|
| 136 |
+
major_event_int = 1 if major_event else 0
|
| 137 |
+
|
| 138 |
+
# Make prediction
|
| 139 |
+
prediction = self.model.predict(temperature, day_of_week, major_event_int)
|
| 140 |
+
|
| 141 |
+
# Get model coefficients for explanation
|
| 142 |
+
coefficients = self.model.get_model_coefficients()
|
| 143 |
+
|
| 144 |
+
# Format result
|
| 145 |
+
result = f"""
|
| 146 |
+
**Prediction Result:**
|
| 147 |
+
|
| 148 |
+
**Estimated Daily Electricity Consumption: {prediction:.1f} kWh**
|
| 149 |
+
|
| 150 |
+
**Input Parameters:**
|
| 151 |
+
- Temperature: {temperature}°C
|
| 152 |
+
- Day of Week: {day_of_week}
|
| 153 |
+
- Major Event: {'Yes' if major_event else 'No'}
|
| 154 |
+
|
| 155 |
+
**Model Information:**
|
| 156 |
+
- Model Type: Linear Regression
|
| 157 |
+
- Intercept: {coefficients['intercept']:.4f}
|
| 158 |
+
- Number of Features: {len(coefficients['feature_names'])}
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
return result
|
| 162 |
+
|
| 163 |
+
except Exception as e:
|
| 164 |
+
return f"**Error during prediction:** {str(e)}"
|
| 165 |
+
|
| 166 |
+
def get_model_info(self) -> str:
|
| 167 |
+
"""
|
| 168 |
+
Get detailed information about the trained model.
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
Formatted model information
|
| 172 |
+
"""
|
| 173 |
+
if not self.is_model_trained:
|
| 174 |
+
return "**Error:** Model must be trained first."
|
| 175 |
+
|
| 176 |
+
try:
|
| 177 |
+
coefficients = self.model.get_model_coefficients()
|
| 178 |
+
print(coefficients)
|
| 179 |
+
|
| 180 |
+
# Create feature importance table
|
| 181 |
+
feature_importance = []
|
| 182 |
+
for i, (feature, coef) in enumerate(
|
| 183 |
+
zip(coefficients["feature_names"], coefficients["coefficients"])
|
| 184 |
+
):
|
| 185 |
+
feature_importance.append(f"| {feature} | {coef:.4f} |")
|
| 186 |
+
|
| 187 |
+
feature_table = "\n".join(feature_importance)
|
| 188 |
+
|
| 189 |
+
info = f"""
|
| 190 |
+
**Model Information:**
|
| 191 |
+
|
| 192 |
+
**Model Type:** Linear Regression
|
| 193 |
+
|
| 194 |
+
**Intercept:** {coefficients['intercept']:.4f}
|
| 195 |
+
|
| 196 |
+
**Feature Coefficients:**
|
| 197 |
+
| Feature | Coefficient |
|
| 198 |
+
|---------|-------------|
|
| 199 |
+
{feature_table}
|
| 200 |
+
|
| 201 |
+
**Interpretation:**
|
| 202 |
+
- Positive coefficients increase predicted consumption
|
| 203 |
+
- Negative coefficients decrease predicted consumption
|
| 204 |
+
- Temperature coefficient shows how much consumption changes per degree Celsius
|
| 205 |
+
- Day coefficients show consumption differences compared to Monday (baseline)
|
| 206 |
+
- Major event coefficient shows additional consumption during events
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
return info
|
| 210 |
+
|
| 211 |
+
except Exception as e:
|
| 212 |
+
return f"**Error getting model info:** {str(e)}"
|
| 213 |
+
|
| 214 |
+
def create_interface(self) -> gr.Interface:
|
| 215 |
+
"""
|
| 216 |
+
Create the Gradio interface.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
Gradio Interface object
|
| 220 |
+
"""
|
| 221 |
+
with gr.Blocks(
|
| 222 |
+
title="Daily Household Electricity Consumption Predictor"
|
| 223 |
+
) as interface:
|
| 224 |
+
gr.Markdown(
|
| 225 |
+
"""
|
| 226 |
+
# ⚡ Daily Household Electricity Consumption Predictor
|
| 227 |
+
|
| 228 |
+
This application helps Nigerian households estimate their daily electricity consumption
|
| 229 |
+
based on temperature, day of the week, and major events.
|
| 230 |
+
|
| 231 |
+
## How to Use:
|
| 232 |
+
1. **Generate Data & Train Model**: Click the button to generate synthetic data and train the model
|
| 233 |
+
2. **Make Predictions**: Enter your parameters and get consumption estimates
|
| 234 |
+
3. **View Model Info**: See how the model works and feature importance
|
| 235 |
+
"""
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
with gr.Tab("Data Generation & Training"):
|
| 239 |
+
gr.Markdown("### Step 1: Generate Synthetic Data and Train Model")
|
| 240 |
+
|
| 241 |
+
with gr.Row():
|
| 242 |
+
with gr.Column():
|
| 243 |
+
n_samples = gr.Slider(
|
| 244 |
+
minimum=100,
|
| 245 |
+
maximum=5000,
|
| 246 |
+
value=1000,
|
| 247 |
+
step=100,
|
| 248 |
+
label="Number of Data Points",
|
| 249 |
+
)
|
| 250 |
+
noise_level = gr.Slider(
|
| 251 |
+
minimum=0.01,
|
| 252 |
+
maximum=0.5,
|
| 253 |
+
value=0.1,
|
| 254 |
+
step=0.01,
|
| 255 |
+
label="Noise Level",
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
with gr.Column():
|
| 259 |
+
train_size = gr.Slider(
|
| 260 |
+
minimum=0.5,
|
| 261 |
+
maximum=0.9,
|
| 262 |
+
value=0.7,
|
| 263 |
+
step=0.05,
|
| 264 |
+
label="Training Set Proportion",
|
| 265 |
+
)
|
| 266 |
+
val_size = gr.Slider(
|
| 267 |
+
minimum=0.05,
|
| 268 |
+
maximum=0.3,
|
| 269 |
+
value=0.15,
|
| 270 |
+
step=0.05,
|
| 271 |
+
label="Validation Set Proportion",
|
| 272 |
+
)
|
| 273 |
+
test_size = gr.Slider(
|
| 274 |
+
minimum=0.05,
|
| 275 |
+
maximum=0.3,
|
| 276 |
+
value=0.15,
|
| 277 |
+
step=0.05,
|
| 278 |
+
label="Test Set Proportion",
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
train_button = gr.Button(
|
| 282 |
+
"Generate Data & Train Model", variant="primary"
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
with gr.Row():
|
| 286 |
+
data_info = gr.Markdown("**Data information will appear here...**")
|
| 287 |
+
|
| 288 |
+
with gr.Row():
|
| 289 |
+
training_metrics = gr.Markdown(
|
| 290 |
+
"**Training metrics will appear here...**"
|
| 291 |
+
)
|
| 292 |
+
evaluation_metrics = gr.Markdown(
|
| 293 |
+
"**Evaluation metrics will appear here...**"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
train_button.click(
|
| 297 |
+
fn=self.generate_and_train,
|
| 298 |
+
inputs=[n_samples, noise_level, train_size, val_size, test_size],
|
| 299 |
+
outputs=[data_info, training_metrics, evaluation_metrics],
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
with gr.Tab("Prediction"):
|
| 303 |
+
gr.Markdown("### Step 2: Predict Electricity Consumption")
|
| 304 |
+
|
| 305 |
+
with gr.Row():
|
| 306 |
+
with gr.Column():
|
| 307 |
+
temperature = gr.Slider(
|
| 308 |
+
minimum=15,
|
| 309 |
+
maximum=35,
|
| 310 |
+
value=25,
|
| 311 |
+
step=0.5,
|
| 312 |
+
label="Average Daily Temperature (°C)",
|
| 313 |
+
)
|
| 314 |
+
day_of_week = gr.Dropdown(
|
| 315 |
+
choices=[
|
| 316 |
+
"Monday",
|
| 317 |
+
"Tuesday",
|
| 318 |
+
"Wednesday",
|
| 319 |
+
"Thursday",
|
| 320 |
+
"Friday",
|
| 321 |
+
"Saturday",
|
| 322 |
+
"Sunday",
|
| 323 |
+
],
|
| 324 |
+
value="Monday",
|
| 325 |
+
label="Day of the Week",
|
| 326 |
+
)
|
| 327 |
+
major_event = gr.Checkbox(
|
| 328 |
+
label="Major Event (Holiday, Power Outage, etc.)",
|
| 329 |
+
value=False,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
with gr.Column():
|
| 333 |
+
predict_button = gr.Button(
|
| 334 |
+
"Predict Consumption", variant="primary"
|
| 335 |
+
)
|
| 336 |
+
prediction_result = gr.Markdown(
|
| 337 |
+
"**Prediction result will appear here...**"
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
predict_button.click(
|
| 341 |
+
fn=self.predict_consumption,
|
| 342 |
+
inputs=[temperature, day_of_week, major_event],
|
| 343 |
+
outputs=prediction_result,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
with gr.Tab("Model Information"):
|
| 347 |
+
gr.Markdown("### Step 3: Understand the Model")
|
| 348 |
+
|
| 349 |
+
info_button = gr.Button("Show Model Information", variant="secondary")
|
| 350 |
+
model_info = gr.Markdown("**Model information will appear here...**")
|
| 351 |
+
|
| 352 |
+
info_button.click(fn=self.get_model_info, inputs=[], outputs=model_info)
|
| 353 |
+
|
| 354 |
+
gr.Markdown(
|
| 355 |
+
"""
|
| 356 |
+
---
|
| 357 |
+
**Note:** This application uses synthetic data for demonstration purposes.
|
| 358 |
+
In a real-world scenario, you would use actual historical consumption data.
|
| 359 |
+
"""
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
return interface
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def main():
    """Build the predictor app and serve its UI on port 7860."""
    application = ElectricityPredictorApp()
    ui = application.create_interface()
    # Bind on all interfaces so the app is reachable inside a container.
    ui.launch(share=False, server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()
|
src/data_generator.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data Generator Module for Daily Household Electricity Consumption Predictor
|
| 3 |
+
|
| 4 |
+
This module generates synthetic data for training and testing the electricity consumption
|
| 5 |
+
prediction model. It creates realistic patterns based on temperature, day of week, and events.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from typing import Tuple, Optional
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DataGenerator:
    """Generates synthetic electricity consumption data for training and testing."""

    def __init__(self, seed: Optional[int] = 42):
        """
        Initialize the data generator.

        Args:
            seed: Random seed for reproducibility; ``None`` leaves the global
                RNG state untouched.
        """
        self.seed = seed
        if seed is not None:
            # Seed both NumPy and the stdlib RNG so every sampling path
            # in this module is reproducible.
            np.random.seed(seed)
            random.seed(seed)

    def generate_data(
        self, n_samples: int = 1000, noise_level: float = 0.1
    ) -> pd.DataFrame:
        """
        Generate synthetic electricity consumption data.

        Args:
            n_samples: Number of data points to generate
            noise_level: Level of noise to add to the data (0-1), expressed as
                a fraction of the noiseless consumption's standard deviation

        Returns:
            DataFrame with columns ``temperature``, ``day_of_week``,
            ``major_event`` (0/1) and the target ``consumption_kwh``
        """
        # Temperatures ~ N(25, 8), clipped to the model's accepted 15-35 range.
        temperatures = np.random.normal(25, 8, n_samples)
        temperatures = np.clip(temperatures, 15, 35)

        days_of_week = np.random.choice(
            [
                "Monday",
                "Tuesday",
                "Wednesday",
                "Thursday",
                "Friday",
                "Saturday",
                "Sunday",
            ],
            n_samples,
        )

        # 10% chance of a major event on any given day.
        major_events = np.random.choice([0, 1], n_samples, p=[0.9, 0.1])

        # Base consumption in kWh.
        base_consumption = 15.0

        # Temperature effect (higher temp = higher consumption due to AC/fans).
        temp_effect = 0.3 * (temperatures - 25)

        # Day-of-week effect (weekends typically higher consumption).
        day_effects = {
            "Monday": 0.5,
            "Tuesday": 0.3,
            "Wednesday": 0.2,
            "Thursday": 0.1,
            "Friday": 0.8,
            "Saturday": 1.5,
            "Sunday": 1.2,
        }
        day_effect = np.array([day_effects[day] for day in days_of_week])

        # Major event effect (events typically increase consumption).
        event_effect = major_events * 2.0

        consumption = base_consumption + temp_effect + day_effect + event_effect

        # Gaussian noise scaled by the spread of the noiseless signal.
        noise = np.random.normal(0, noise_level * np.std(consumption), n_samples)
        consumption += noise

        # Floor at 5 kWh so noise can never produce implausible values.
        consumption = np.maximum(consumption, 5.0)

        data = pd.DataFrame(
            {
                "temperature": temperatures,
                "day_of_week": days_of_week,
                "major_event": major_events,
                "consumption_kwh": consumption,
            }
        )

        return data

    def split_data(
        self,
        data: pd.DataFrame,
        train_size: float = 0.7,
        val_size: float = 0.15,
        test_size: float = 0.15,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Split data into training, validation, and test sets.

        Args:
            data: Input DataFrame
            train_size: Proportion for training set
            val_size: Proportion for validation set
            test_size: Proportion for test set

        Returns:
            Tuple of (train_data, val_data, test_data)

        Raises:
            ValueError: If the three proportions do not sum to 1.
        """
        # Validate with an explicit exception rather than ``assert``:
        # asserts are stripped when Python runs with -O, silently skipping
        # this check.
        if abs(train_size + val_size + test_size - 1.0) >= 1e-6:
            raise ValueError("Split proportions must sum to 1")

        # Shuffle before splitting so the splits are random but reproducible.
        data_shuffled = data.sample(frac=1, random_state=self.seed).reset_index(
            drop=True
        )

        n_samples = len(data_shuffled)
        train_end = int(n_samples * train_size)
        val_end = train_end + int(n_samples * val_size)

        train_data = data_shuffled[:train_end]
        val_data = data_shuffled[train_end:val_end]
        # The test split absorbs any rounding remainder.
        test_data = data_shuffled[val_end:]

        return train_data, val_data, test_data

    def save_data(self, data: pd.DataFrame, filepath: str) -> None:
        """
        Save data to CSV file.

        Args:
            data: DataFrame to save
            filepath: Path to save the file
        """
        data.to_csv(filepath, index=False)

    def load_data(self, filepath: str) -> pd.DataFrame:
        """
        Load data from CSV file.

        Args:
            filepath: Path to the file

        Returns:
            Loaded DataFrame
        """
        return pd.read_csv(filepath)
|
src/model.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Module for Daily Household Electricity Consumption Predictor
|
| 3 |
+
|
| 4 |
+
This module handles data preprocessing, model training, evaluation, and prediction
|
| 5 |
+
for the electricity consumption prediction model.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from sklearn.linear_model import LinearRegression
|
| 11 |
+
from sklearn.preprocessing import OneHotEncoder, StandardScaler
|
| 12 |
+
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
| 13 |
+
from sklearn.pipeline import Pipeline
|
| 14 |
+
from sklearn.compose import ColumnTransformer
|
| 15 |
+
import joblib
|
| 16 |
+
from typing import Tuple, Dict, Any, Optional
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ElectricityConsumptionModel:
    """Linear regression model for predicting daily electricity consumption."""

    def __init__(self):
        """Initialize the model with preprocessing pipeline."""
        self.model = None  # fitted sklearn Pipeline after train()/load_model()
        self.preprocessor = None  # kept for API compatibility; the pipeline owns it
        self.feature_names = None  # raw input column names seen at fit time
        self.is_trained = False

    def _create_preprocessor(self) -> ColumnTransformer:
        """
        Create preprocessing pipeline for the features.

        Returns:
            ColumnTransformer with preprocessing steps
        """
        # Numerical features (temperature) are standardized.
        numerical_features = ["temperature"]
        numerical_transformer = StandardScaler()

        # Categorical features (day_of_week) are one-hot encoded with the
        # first category (Monday, alphabetically "Friday"-first aside — see
        # get_model_coefficients) dropped to avoid collinearity.
        # scikit-learn >= 1.2 renamed ``sparse`` to ``sparse_output`` and
        # removed the old keyword entirely in 1.4, so try the modern spelling
        # first and fall back for older installations.
        categorical_features = ["day_of_week"]
        try:
            categorical_transformer = OneHotEncoder(drop="first", sparse_output=False)
        except TypeError:
            categorical_transformer = OneHotEncoder(drop="first", sparse=False)

        # Boolean features (major_event) - no transformation needed.
        boolean_features = ["major_event"]
        boolean_transformer = "passthrough"

        preprocessor = ColumnTransformer(
            transformers=[
                ("num", numerical_transformer, numerical_features),
                ("cat", categorical_transformer, categorical_features),
                ("bool", boolean_transformer, boolean_features),
            ],
            remainder="drop",
        )

        return preprocessor

    def _create_pipeline(self) -> Pipeline:
        """
        Create the complete model pipeline.

        Returns:
            Pipeline with preprocessing and model
        """
        preprocessor = self._create_preprocessor()
        model = LinearRegression()

        pipeline = Pipeline([("preprocessor", preprocessor), ("regressor", model)])

        return pipeline

    def prepare_features(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Prepare features for training/prediction.

        Args:
            data: Input DataFrame with raw features

        Returns:
            DataFrame restricted to the required columns, in canonical order

        Raises:
            ValueError: If columns are missing or values fall outside the
                accepted ranges.
        """
        required_columns = ["temperature", "day_of_week", "major_event"]

        # Validate input data.
        missing_columns = [col for col in required_columns if col not in data.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

        # Validate data types and ranges.
        if not all(data["temperature"].between(15, 35)):
            raise ValueError("Temperature must be between 15 and 35 degrees Celsius")

        valid_days = [
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ]
        if not all(day in valid_days for day in data["day_of_week"].unique()):
            raise ValueError(f"Day of week must be one of: {valid_days}")

        if not all(data["major_event"].isin([0, 1])):
            raise ValueError("Major event must be 0 or 1")

        return data[required_columns].copy()

    def train(self, X_train: pd.DataFrame, y_train: pd.DataFrame) -> Dict[str, float]:
        """
        Train the model on the provided data.

        Args:
            X_train: Training features
            y_train: Training targets (must contain ``consumption_kwh``)

        Returns:
            Dictionary with training metrics (MSE, RMSE, MAE, R²)
        """
        # Validate and order features.
        X_prepared = self.prepare_features(X_train)

        # Create and fit the full preprocessing + regression pipeline.
        self.model = self._create_pipeline()
        self.model.fit(X_prepared, y_train["consumption_kwh"])

        # Store raw feature names for later use.
        self.feature_names = X_prepared.columns.tolist()
        self.is_trained = True

        # Report in-sample fit quality.
        y_pred = self.model.predict(X_prepared)
        train_mse = mean_squared_error(y_train["consumption_kwh"], y_pred)
        metrics = {
            "train_mse": train_mse,
            "train_rmse": np.sqrt(train_mse),
            "train_mae": mean_absolute_error(y_train["consumption_kwh"], y_pred),
            "train_r2": r2_score(y_train["consumption_kwh"], y_pred),
        }

        return metrics

    def evaluate(self, X_test: pd.DataFrame, y_test: pd.DataFrame) -> Dict[str, float]:
        """
        Evaluate the model on test data.

        Args:
            X_test: Test features
            y_test: Test targets (must contain ``consumption_kwh``)

        Returns:
            Dictionary with evaluation metrics (MSE, RMSE, MAE, R²)

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if not self.is_trained:
            raise ValueError("Model must be trained before evaluation")

        X_prepared = self.prepare_features(X_test)

        y_pred = self.model.predict(X_prepared)

        test_mse = mean_squared_error(y_test["consumption_kwh"], y_pred)
        metrics = {
            "test_mse": test_mse,
            "test_rmse": np.sqrt(test_mse),
            "test_mae": mean_absolute_error(y_test["consumption_kwh"], y_pred),
            "test_r2": r2_score(y_test["consumption_kwh"], y_pred),
        }

        return metrics

    def predict(self, temperature: float, day_of_week: str, major_event: int) -> float:
        """
        Make a single prediction.

        Args:
            temperature: Average daily temperature in Celsius (15-35)
            day_of_week: Day of the week (e.g. "Monday")
            major_event: Whether there's a major event (0 or 1)

        Returns:
            Predicted electricity consumption in kWh (clamped at 0)

        Raises:
            ValueError: If the model has not been trained yet or the input
                fails validation.
        """
        if not self.is_trained:
            raise ValueError("Model must be trained before making predictions")

        # Wrap the scalar inputs in a one-row DataFrame for the pipeline.
        input_data = pd.DataFrame(
            {
                "temperature": [temperature],
                "day_of_week": [day_of_week],
                "major_event": [major_event],
            }
        )

        X_prepared = self.prepare_features(input_data)

        prediction = self.model.predict(X_prepared)[0]

        # Ensure non-negative prediction.
        return max(0, prediction)

    def get_model_coefficients(self) -> Dict[str, Any]:
        """
        Get model coefficients and feature names.

        Returns:
            Dictionary with ``feature_names``, ``coefficients`` and
            ``intercept``

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if not self.is_trained:
            raise ValueError("Model must be trained before accessing coefficients")

        feature_names = []

        # Numerical features.
        feature_names.extend(["temperature"])

        # Categorical features (one-hot encoded, first category dropped).
        # NOTE(review): this hard-coded list assumes all seven days appeared
        # in the training data and that the encoder dropped "Monday"; with
        # drop="first" the dropped category is the alphabetically first one
        # seen — confirm ordering against the fitted encoder's categories_.
        day_names = [
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ]
        feature_names.extend([f"day_{day.lower()}" for day in day_names])

        # Boolean features.
        feature_names.extend(["major_event"])

        coefficients = self.model.named_steps["regressor"].coef_
        intercept = self.model.named_steps["regressor"].intercept_

        return {
            "feature_names": feature_names,
            "coefficients": coefficients.tolist(),
            "intercept": float(intercept),
        }

    def save_model(self, filepath: str) -> None:
        """
        Save the trained model to disk.

        Args:
            filepath: Path to save the model

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if not self.is_trained:
            raise ValueError("Model must be trained before saving")

        # Create the parent directory if needed; guard against a bare
        # filename, where os.path.dirname() returns "" and os.makedirs("")
        # would raise FileNotFoundError.
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

        joblib.dump(self.model, filepath)

    def load_model(self, filepath: str) -> None:
        """
        Load a trained model from disk.

        Args:
            filepath: Path to the saved model

        Raises:
            FileNotFoundError: If no file exists at ``filepath``.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Model file not found: {filepath}")

        self.model = joblib.load(filepath)
        self.is_trained = True

        # Raw feature names expected by the loaded pipeline.
        self.feature_names = ["temperature", "day_of_week", "major_event"]
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Test package for Daily Household Electricity Consumption Predictor
|
tests/test_app.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for the Gradio Application module.
|
| 3 |
+
|
| 4 |
+
This module contains tests for the ElectricityPredictorApp class to ensure
|
| 5 |
+
the web interface functions correctly.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
from unittest.mock import patch, MagicMock
|
| 12 |
+
from src.app import ElectricityPredictorApp
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TestElectricityPredictorApp:
    """Test cases for ElectricityPredictorApp class.

    All model/data-generator calls are mocked with unittest.mock, so these
    tests exercise only the app's orchestration and output formatting —
    no real training happens.  Assertions pin exact substrings of the
    app's formatted markdown output.
    """

    def setup_method(self):
        """Set up a fresh test app for each test method."""
        self.app = ElectricityPredictorApp()

    def test_initialization(self):
        """Test app initialization."""
        assert self.app.data_generator is not None
        assert self.app.model is not None
        assert not self.app.is_model_trained

    def test_generate_and_train_success(self):
        """Test successful data generation and training."""
        # Mock the data generator methods
        with patch.object(
            self.app.data_generator, "generate_data"
        ) as mock_generate, patch.object(
            self.app.data_generator, "split_data"
        ) as mock_split, patch.object(
            self.app.model, "train"
        ) as mock_train, patch.object(
            self.app.model, "evaluate"
        ) as mock_evaluate:

            # Create mock data
            mock_data = pd.DataFrame(
                {
                    "temperature": [25.0, 30.0],
                    "day_of_week": ["Monday", "Tuesday"],
                    "major_event": [0, 1],
                    "consumption_kwh": [15.0, 18.0],
                }
            )
            mock_generate.return_value = mock_data

            # Create mock split data
            train_data = mock_data.iloc[:1]
            val_data = mock_data.iloc[1:2]
            test_data = mock_data.iloc[1:2]
            mock_split.return_value = (train_data, val_data, test_data)

            mock_train.return_value = {
                "train_mse": 2.5,
                "train_rmse": 1.58,
                "train_mae": 1.2,
                "train_r2": 0.85,
            }
            mock_evaluate.return_value = {
                "test_mse": 2.8,
                "test_rmse": 1.67,
                "test_mae": 1.3,
                "test_r2": 0.82,
            }

            # Call the method
            data_info, training_metrics, evaluation_metrics = (
                self.app.generate_and_train(
                    n_samples=1000,
                    noise_level=0.1,
                    train_size=0.7,
                    val_size=0.15,
                    test_size=0.15,
                )
            )

            # Check that methods were called with positional arguments,
            # pinning the app's internal call convention
            mock_generate.assert_called_once_with(1000, 0.1)
            mock_split.assert_called_once_with(mock_data, 0.7, 0.15, 0.15)
            mock_train.assert_called_once()
            mock_evaluate.assert_called_once()

            # Check that app state was updated
            assert self.app.is_model_trained
            assert hasattr(self.app, "train_data")
            assert hasattr(self.app, "val_data")
            assert hasattr(self.app, "test_data")

            # Check output strings contain expected information
            assert "Data Generated Successfully!" in data_info
            assert "Training Metrics:" in training_metrics
            assert "Test Set Evaluation:" in evaluation_metrics
            assert "2.5000" in training_metrics  # MSE value (4-decimal format)
            assert "0.8500" in training_metrics  # R² value (4-decimal format)

    def test_generate_and_train_error(self):
        """Test error handling in data generation and training."""
        # Mock the data generator to raise an exception
        with patch.object(
            self.app.data_generator,
            "generate_data",
            side_effect=Exception("Test error"),
        ):
            data_info, training_metrics, evaluation_metrics = (
                self.app.generate_and_train(
                    n_samples=1000,
                    noise_level=0.1,
                    train_size=0.7,
                    val_size=0.15,
                    test_size=0.15,
                )
            )

            # On failure the app reports the error in the first output and
            # leaves the metric strings empty
            assert "Error during data generation and training" in data_info
            assert training_metrics == ""
            assert evaluation_metrics == ""

    def test_predict_consumption_not_trained(self):
        """Test prediction when model is not trained."""
        result = self.app.predict_consumption(25.0, "Monday", False)

        assert "Model must be trained first" in result

    def test_predict_consumption_success(self):
        """Test successful prediction."""
        # Set up the app as if it's trained
        self.app.is_model_trained = True

        # Mock the model prediction
        with patch.object(
            self.app.model, "predict", return_value=16.5
        ) as mock_predict, patch.object(
            self.app.model, "get_model_coefficients"
        ) as mock_coeffs:

            mock_coeffs.return_value = {
                "feature_names": ["temperature", "major_event"],
                "coefficients": [0.3, 2.0],
                "intercept": 10.0,
            }

            result = self.app.predict_consumption(25.0, "Monday", True)

            # Check that prediction was called with the bool coerced to int
            mock_predict.assert_called_once_with(25.0, "Monday", 1)

            # Check output contains expected information
            assert "Estimated Daily Electricity Consumption: 16.5 kWh" in result
            assert "Temperature: 25.0°C" in result
            assert "Day of Week: Monday" in result
            assert "Major Event: Yes" in result
            assert "Model Type: Linear Regression" in result

    def test_predict_consumption_error(self):
        """Test error handling in prediction."""
        # Set up the app as if it's trained
        self.app.is_model_trained = True

        # Mock the model to raise an exception
        with patch.object(
            self.app.model, "predict", side_effect=Exception("Prediction error")
        ):
            result = self.app.predict_consumption(25.0, "Monday", False)

            assert "Error during prediction" in result

    def test_get_model_info_not_trained(self):
        """Test getting model info when model is not trained."""
        result = self.app.get_model_info()

        assert "Model must be trained first" in result

    def test_get_model_info_success(self):
        """Test successful model info retrieval."""
        # Set up the app as if it's trained
        self.app.is_model_trained = True

        # Mock the model coefficients
        with patch.object(self.app.model, "get_model_coefficients") as mock_coeffs:
            mock_coeffs.return_value = {
                "feature_names": ["temperature", "day_tuesday", "major_event"],
                "coefficients": [0.3, 0.5, 2.0],
                "intercept": 10.0,
            }

            result = self.app.get_model_info()

            # Check output contains expected information
            assert "**Model Information:**" in result
            assert "**Model Type:** Linear Regression" in result
            assert "**Intercept:** 10.0000" in result
            assert "**Feature Coefficients:**" in result
            assert "temperature" in result
            assert "major_event" in result
            assert "**Interpretation:**" in result

    def test_get_model_info_error(self):
        """Test error handling in model info retrieval."""
        # Set up the app as if it's trained
        self.app.is_model_trained = True

        # Mock the model to raise an exception
        with patch.object(
            self.app.model,
            "get_model_coefficients",
            side_effect=Exception("Info error"),
        ):
            result = self.app.get_model_info()

            assert "Error getting model info" in result

    def test_boolean_conversion_in_prediction(self):
        """Test that boolean values are correctly converted to integers."""
        # Set up the app as if it's trained
        self.app.is_model_trained = True

        # Mock the model prediction
        with patch.object(self.app.model, "predict") as mock_predict, patch.object(
            self.app.model, "get_model_coefficients"
        ) as mock_coeffs:

            mock_predict.return_value = 15.0
            mock_coeffs.return_value = {
                "feature_names": ["temperature", "major_event"],
                "coefficients": [0.3, 2.0],
                "intercept": 10.0,
            }

            # Test with True -> model should receive 1
            self.app.predict_consumption(25.0, "Monday", True)
            mock_predict.assert_called_with(25.0, "Monday", 1)

            # Test with False -> model should receive 0
            self.app.predict_consumption(25.0, "Monday", False)
            mock_predict.assert_called_with(25.0, "Monday", 0)

    def test_data_storage_after_training(self):
        """Test that data is properly stored after training."""
        # Mock the data generator
        with patch.object(
            self.app.data_generator, "generate_data"
        ) as mock_generate, patch.object(
            self.app.data_generator, "split_data"
        ) as mock_split, patch.object(
            self.app.model, "train"
        ) as mock_train, patch.object(
            self.app.model, "evaluate"
        ) as mock_evaluate:

            # Create mock data
            mock_data = pd.DataFrame(
                {
                    "temperature": [25.0, 30.0],
                    "day_of_week": ["Monday", "Tuesday"],
                    "major_event": [0, 1],
                    "consumption_kwh": [15.0, 18.0],
                }
            )
            mock_generate.return_value = mock_data

            train_data = mock_data.iloc[:1]
            val_data = mock_data.iloc[1:2]
            test_data = mock_data.iloc[1:2]
            mock_split.return_value = (train_data, val_data, test_data)

            mock_train.return_value = {
                "train_mse": 2.5,
                "train_rmse": 1.58,
                "train_mae": 1.2,
                "train_r2": 0.85,
            }
            mock_evaluate.return_value = {
                "test_mse": 2.8,
                "test_rmse": 1.67,
                "test_mae": 1.3,
                "test_r2": 0.82,
            }

            # Call the method (positional args match the app's signature)
            self.app.generate_and_train(1000, 0.1, 0.7, 0.15, 0.15)

            # Check that data is stored as attributes on the app
            assert hasattr(self.app, "train_data")
            assert hasattr(self.app, "val_data")
            assert hasattr(self.app, "test_data")
            assert len(self.app.train_data) == 1
            assert len(self.app.val_data) == 1
            assert len(self.app.test_data) == 1

    def test_interface_creation(self):
        """Test that the Gradio interface can be created."""
        # This test verifies that the interface creation doesn't raise exceptions
        try:
            interface = self.app.create_interface()
            assert interface is not None
        except Exception as e:
            pytest.fail(f"Interface creation failed: {e}")

    def test_prediction_output_format(self):
        """Test that prediction output is properly formatted."""
        # Set up the app as if it's trained
        self.app.is_model_trained = True

        # Mock the model
        with patch.object(
            self.app.model, "predict", return_value=16.5
        ) as mock_predict, patch.object(
            self.app.model, "get_model_coefficients"
        ) as mock_coeffs:

            mock_coeffs.return_value = {
                "feature_names": ["temperature", "major_event"],
                "coefficients": [0.3, 2.0],
                "intercept": 10.0,
            }

            result = self.app.predict_consumption(25.0, "Monday", False)

            # Check markdown section headers and formatted values
            assert "**Prediction Result:**" in result
            assert "**Input Parameters:**" in result
            assert "**Model Information:**" in result
            assert "Estimated Daily Electricity Consumption: 16.5 kWh" in result
            assert "Temperature: 25.0°C" in result
            assert "Day of Week: Monday" in result
            assert "Major Event: No" in result

    def test_model_info_output_format(self):
        """Test that model info output is properly formatted."""
        # Set up the app as if it's trained
        self.app.is_model_trained = True

        # Mock the model coefficients
        with patch.object(self.app.model, "get_model_coefficients") as mock_coeffs:
            mock_coeffs.return_value = {
                "feature_names": ["temperature", "day_tuesday", "major_event"],
                "coefficients": [0.3, 0.5, 2.0],
                "intercept": 10.0,
            }

            result = self.app.get_model_info()

            # Check formatting, including the markdown coefficient table
            assert "**Model Information:**" in result
            assert "**Model Type:**" in result
            assert "**Intercept:**" in result
            assert "**Feature Coefficients:**" in result
            assert "| Feature | Coefficient |" in result
            assert "**Interpretation:**" in result
            assert "Positive coefficients increase predicted consumption" in result
|
tests/test_data_generator.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for the DataGenerator module.
|
| 3 |
+
|
| 4 |
+
This module contains comprehensive tests for the DataGenerator class to ensure
|
| 5 |
+
proper data generation, splitting, and file operations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
import tempfile
|
| 12 |
+
import os
|
| 13 |
+
from src.data_generator import DataGenerator
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestDataGenerator:
    """Test cases for DataGenerator class.

    Most tests use seed=42 so results are deterministic; statistical
    assertions (correlations, day/event effects) depend on that fixed seed.
    """

    def test_initialization(self):
        """Test DataGenerator initialization with and without seed."""
        # Test with seed
        generator = DataGenerator(seed=42)
        assert generator.seed == 42

        # Test without seed
        generator_no_seed = DataGenerator(seed=None)
        assert generator_no_seed.seed is None

    def test_generate_data_basic(self):
        """Test basic data generation with default parameters."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data()

        # Check DataFrame structure
        assert isinstance(data, pd.DataFrame)
        assert len(data) == 1000  # Default n_samples
        assert list(data.columns) == [
            "temperature",
            "day_of_week",
            "major_event",
            "consumption_kwh",
        ]

        # Check data types (accept platform-dependent 32/64-bit variants)
        assert data["temperature"].dtype in [np.float64, np.float32]
        assert data["day_of_week"].dtype == "object"
        assert data["major_event"].dtype in [np.int64, np.int32]
        assert data["consumption_kwh"].dtype in [np.float64, np.float32]

    def test_generate_data_custom_parameters(self):
        """Test data generation with custom parameters."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=500, noise_level=0.2)

        assert len(data) == 500

        # Check temperature range
        assert data["temperature"].min() >= 15
        assert data["temperature"].max() <= 35

        # Check day of week values
        valid_days = [
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ]
        assert all(day in valid_days for day in data["day_of_week"].unique())

        # Check major event values (binary flag)
        assert all(event in [0, 1] for event in data["major_event"].unique())

        # Check consumption is positive
        assert all(data["consumption_kwh"] > 0)

    def test_generate_data_reproducibility(self):
        """Test that data generation is reproducible with the same seed."""
        # Reset numpy's global random seed to ensure reproducibility
        np.random.seed(42)

        generator1 = DataGenerator(seed=42)
        data1 = generator1.generate_data(n_samples=100)

        # Reset numpy random seed again before the second run
        np.random.seed(42)

        generator2 = DataGenerator(seed=42)
        data2 = generator2.generate_data(n_samples=100)

        pd.testing.assert_frame_equal(data1, data2)

    def test_generate_data_different_seeds(self):
        """Test that different seeds produce different data."""
        generator1 = DataGenerator(seed=42)
        generator2 = DataGenerator(seed=123)

        data1 = generator1.generate_data(n_samples=100)
        data2 = generator2.generate_data(n_samples=100)

        # Data should be different
        assert not data1.equals(data2)

    def test_split_data_basic(self):
        """Test basic data splitting functionality."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        train_data, val_data, test_data = generator.split_data(data)

        # Check default split proportions (70/15/15)
        assert len(train_data) == 700  # 70% of 1000
        assert len(val_data) == 150  # 15% of 1000
        assert len(test_data) == 150  # 15% of 1000

        # Check total samples
        assert len(train_data) + len(val_data) + len(test_data) == len(data)

        # Check all data is used
        all_data = pd.concat([train_data, val_data, test_data])
        assert len(all_data) == len(data)

    def test_split_data_custom_proportions(self):
        """Test data splitting with custom proportions."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        train_data, val_data, test_data = generator.split_data(
            data, train_size=0.6, val_size=0.2, test_size=0.2
        )

        assert len(train_data) == 600
        assert len(val_data) == 200
        assert len(test_data) == 200

    def test_split_data_validation(self):
        """Test that split proportions validation works."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=100)

        # Test invalid proportions (sums != 1.0 should be rejected)
        with pytest.raises(AssertionError):
            generator.split_data(data, train_size=0.5, val_size=0.3, test_size=0.3)

        with pytest.raises(AssertionError):
            generator.split_data(data, train_size=0.4, val_size=0.3, test_size=0.2)

    def test_split_data_reproducibility(self):
        """Test that data splitting is reproducible."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        # First split
        train1, val1, test1 = generator.split_data(data)

        # Second split with same data
        train2, val2, test2 = generator.split_data(data)

        # Results should be identical
        pd.testing.assert_frame_equal(train1, train2)
        pd.testing.assert_frame_equal(val1, val2)
        pd.testing.assert_frame_equal(test1, test2)

    def test_save_and_load_data(self):
        """Test saving and loading data to/from CSV."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=100)

        # delete=False so the path outlives the context manager; we remove
        # the file ourselves in the finally block below
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False
        ) as tmp_file:
            filepath = tmp_file.name

        try:
            # Save data
            generator.save_data(data, filepath)

            # Check file exists
            assert os.path.exists(filepath)

            # Load data
            loaded_data = generator.load_data(filepath)

            # Check data is identical after the CSV round trip
            pd.testing.assert_frame_equal(data, loaded_data)

        finally:
            # Clean up
            if os.path.exists(filepath):
                os.unlink(filepath)

    def test_data_statistics(self):
        """Test that generated data has reasonable statistics."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        # Temperature statistics
        assert 15 <= data["temperature"].mean() <= 35
        assert data["temperature"].std() > 0

        # Consumption statistics
        assert data["consumption_kwh"].mean() > 0
        assert data["consumption_kwh"].std() > 0

        # Day of week distribution
        day_counts = data["day_of_week"].value_counts()
        assert len(day_counts) == 7
        # All days should have some data
        assert all(count > 0 for count in day_counts.values)

        # Major event distribution (should be mostly 0s)
        event_counts = data["major_event"].value_counts()
        assert 0 in event_counts.index
        assert 1 in event_counts.index
        # Should be more 0s than 1s
        assert event_counts[0] > event_counts[1]

    def test_noise_level_effect(self):
        """Test that noise level affects data variability."""
        generator = DataGenerator(seed=42)

        # Generate data with low noise
        data_low_noise = generator.generate_data(n_samples=1000, noise_level=0.01)

        # Generate data with high noise
        data_high_noise = generator.generate_data(n_samples=1000, noise_level=0.5)

        # High noise should have higher standard deviation
        assert (
            data_high_noise["consumption_kwh"].std()
            > data_low_noise["consumption_kwh"].std()
        )

    def test_temperature_consumption_correlation(self):
        """Test that temperature and consumption have positive correlation."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        correlation = data["temperature"].corr(data["consumption_kwh"])
        assert correlation > 0  # Should be positive correlation

    def test_day_of_week_effect(self):
        """Test that different days have different consumption patterns."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        # Group by day and check consumption means
        day_consumption = data.groupby("day_of_week")["consumption_kwh"].mean()

        # Should have some variation between days
        assert day_consumption.std() > 0

        # Weekend days (Saturday, Sunday) should generally have higher consumption
        weekend_avg = (day_consumption["Saturday"] + day_consumption["Sunday"]) / 2
        weekday_avg = (
            day_consumption["Monday"]
            + day_consumption["Tuesday"]
            + day_consumption["Wednesday"]
            + day_consumption["Thursday"]
            + day_consumption["Friday"]
        ) / 5

        # This might not always be true due to randomness, but should be generally true
        # We'll just check that there's variation (seed=42 makes this stable)
        assert abs(weekend_avg - weekday_avg) > 0.1

    def test_major_event_effect(self):
        """Test that major events increase consumption."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        # Group by major event and check consumption means
        event_consumption = data.groupby("major_event")["consumption_kwh"].mean()

        # Consumption should be higher when there's a major event
        assert event_consumption[1] > event_consumption[0]
|
tests/test_integration.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration tests for the Daily Household Electricity Consumption Predictor.
|
| 3 |
+
|
| 4 |
+
This module contains integration tests that test the complete workflow
|
| 5 |
+
from data generation through model training to prediction.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
import tempfile
|
| 12 |
+
import os
|
| 13 |
+
from src.data_generator import DataGenerator
|
| 14 |
+
from src.model import ElectricityConsumptionModel
|
| 15 |
+
from src.app import ElectricityPredictorApp
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestIntegration:
    """Integration tests for the complete system."""

    def setup_method(self):
        """Set up test environment for each test method."""
        self.generator = DataGenerator(seed=42)
        self.model = ElectricityConsumptionModel()
        self.app = ElectricityPredictorApp()

    def test_complete_workflow(self):
        """Test the complete workflow from data generation to prediction."""
        # Step 1: Generate data
        data = self.generator.generate_data(n_samples=1000, noise_level=0.1)
        assert len(data) == 1000
        assert all(
            col in data.columns
            for col in ["temperature", "day_of_week", "major_event", "consumption_kwh"]
        )

        # Step 2: Split data
        train_data, val_data, test_data = self.generator.split_data(data)
        assert len(train_data) + len(val_data) + len(test_data) == len(data)

        # Step 3: Train model
        X_train = train_data.drop("consumption_kwh", axis=1)
        y_train = train_data[["consumption_kwh"]]
        train_metrics = self.model.train(X_train, y_train)

        assert self.model.is_trained
        assert "train_r2" in train_metrics
        assert train_metrics["train_r2"] > 0.3  # Reasonable performance

        # Step 4: Evaluate model
        X_test = test_data.drop("consumption_kwh", axis=1)
        y_test = test_data[["consumption_kwh"]]
        test_metrics = self.model.evaluate(X_test, y_test)

        assert "test_r2" in test_metrics
        assert test_metrics["test_r2"] > 0.3  # Reasonable performance

        # Step 5: Make predictions
        prediction1 = self.model.predict(25.0, "Monday", 0)
        prediction2 = self.model.predict(30.0, "Saturday", 1)

        assert prediction1 > 0
        assert prediction2 > 0
        assert (
            prediction2 > prediction1
        )  # Higher temp + weekend + event should increase consumption

    def test_app_integration(self):
        """Test the complete app workflow."""
        # Test data generation and training through the app
        data_info, training_metrics, evaluation_metrics = self.app.generate_and_train(
            n_samples=500,
            noise_level=0.1,
            train_size=0.7,
            val_size=0.15,
            test_size=0.15,
        )

        assert self.app.is_model_trained
        assert "Data Generated Successfully!" in data_info
        assert "Training Metrics:" in training_metrics
        assert "Test Set Evaluation:" in evaluation_metrics

        # Test prediction through the app
        prediction_result = self.app.predict_consumption(25.0, "Monday", False)
        assert "Estimated Daily Electricity Consumption:" in prediction_result
        assert "Temperature: 25.0°C" in prediction_result

        # Test model info through the app
        model_info = self.app.get_model_info()
        assert "Model Information:" in model_info
        assert "Feature Coefficients:" in model_info

    def test_model_persistence(self):
        """Test model saving and loading."""
        # Generate data and train model
        data = self.generator.generate_data(n_samples=500)
        train_data, _, _ = self.generator.split_data(data)

        X_train = train_data.drop("consumption_kwh", axis=1)
        y_train = train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        # Save model
        with tempfile.NamedTemporaryFile(suffix=".joblib", delete=False) as tmp_file:
            model_path = tmp_file.name

        try:
            self.model.save_model(model_path)
            assert os.path.exists(model_path)

            # Load model in new instance
            new_model = ElectricityConsumptionModel()
            new_model.load_model(model_path)

            assert new_model.is_trained

            # Test predictions are identical
            pred1 = self.model.predict(25.0, "Monday", 0)
            pred2 = new_model.predict(25.0, "Monday", 0)

            assert abs(pred1 - pred2) < 1e-10

        finally:
            if os.path.exists(model_path):
                os.unlink(model_path)

    def test_data_persistence(self):
        """Test data saving and loading."""
        # Generate data
        data = self.generator.generate_data(n_samples=100)

        # Save data
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
            data_path = tmp_file.name

        try:
            self.generator.save_data(data, data_path)
            assert os.path.exists(data_path)

            # Load data
            loaded_data = self.generator.load_data(data_path)

            # Check data is identical
            pd.testing.assert_frame_equal(data, loaded_data)

        finally:
            if os.path.exists(data_path):
                os.unlink(data_path)

    def test_model_performance_consistency(self):
        """Test that model performance is consistent across runs."""
        # Generate data
        data = self.generator.generate_data(n_samples=1000, noise_level=0.1)
        train_data, _, test_data = self.generator.split_data(data)

        # Train model multiple times with same data
        X_train = train_data.drop("consumption_kwh", axis=1)
        y_train = train_data[["consumption_kwh"]]
        X_test = test_data.drop("consumption_kwh", axis=1)
        y_test = test_data[["consumption_kwh"]]

        r2_scores = []
        for _ in range(3):
            model = ElectricityConsumptionModel()
            model.train(X_train, y_train)
            metrics = model.evaluate(X_test, y_test)
            r2_scores.append(metrics["test_r2"])

        # R² scores should be very similar (within 0.01)
        assert max(r2_scores) - min(r2_scores) < 0.01

    def test_feature_importance_consistency(self):
        """Test that feature importance is consistent with domain knowledge."""
        # Generate data and train model
        data = self.generator.generate_data(n_samples=1000)
        train_data, _, _ = self.generator.split_data(data)

        X_train = train_data.drop("consumption_kwh", axis=1)
        y_train = train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        # Get coefficients
        coefficients = self.model.get_model_coefficients()

        # Find temperature coefficient
        temp_idx = coefficients["feature_names"].index("temperature")
        temp_coef = coefficients["coefficients"][temp_idx]

        # Find major event coefficient
        event_idx = coefficients["feature_names"].index("major_event")
        event_coef = coefficients["coefficients"][event_idx]

        # Temperature should have positive effect (higher temp = higher consumption)
        assert temp_coef > 0

        # Major event should have positive effect (events increase consumption)
        assert event_coef > 0

    def test_prediction_bounds(self):
        """Test that predictions are within reasonable bounds."""
        # Generate data and train model
        data = self.generator.generate_data(n_samples=1000)
        train_data, _, _ = self.generator.split_data(data)

        X_train = train_data.drop("consumption_kwh", axis=1)
        y_train = train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        # Test predictions across different inputs
        predictions = []

        for temp in [15, 20, 25, 30, 35]:
            for day in [
                "Monday",
                "Tuesday",
                "Wednesday",
                "Thursday",
                "Friday",
                "Saturday",
                "Sunday",
            ]:
                for event in [0, 1]:
                    pred = self.model.predict(temp, day, event)
                    predictions.append(pred)

        # All predictions should be positive
        assert all(p > 0 for p in predictions)

        # Predictions should be within reasonable range (5-50 kWh)
        assert all(5 <= p <= 50 for p in predictions)

    def test_data_quality_checks(self):
        """Test that generated data meets quality requirements."""
        # Generate data
        data = self.generator.generate_data(n_samples=1000)

        # Check for missing values
        assert not data.isnull().any().any()

        # Check data types
        assert data["temperature"].dtype in [np.float64, np.float32]
        assert data["day_of_week"].dtype == "object"
        assert data["major_event"].dtype in [np.int64, np.int32]
        assert data["consumption_kwh"].dtype in [np.float64, np.float32]

        # Check value ranges
        assert data["temperature"].min() >= 15
        assert data["temperature"].max() <= 35
        assert all(data["major_event"].isin([0, 1]))
        assert all(data["consumption_kwh"] > 0)

        # Check day of week values
        valid_days = [
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ]
        assert all(day in valid_days for day in data["day_of_week"].unique())

        # Check correlations make sense
        temp_consumption_corr = data["temperature"].corr(data["consumption_kwh"])
        assert temp_consumption_corr > 0  # Positive correlation

    def test_error_handling(self):
        """Test error handling in the complete workflow."""
        # Test with invalid temperature
        with pytest.raises(ValueError):
            self.model.predict(10.0, "Monday", 0)  # Temperature too low

        with pytest.raises(ValueError):
            self.model.predict(40.0, "Monday", 0)  # Temperature too high

        # Test with invalid day
        with pytest.raises(ValueError):
            self.model.predict(25.0, "InvalidDay", 0)

        # Test with invalid major event
        with pytest.raises(ValueError):
            self.model.predict(25.0, "Monday", 2)  # Invalid value

        # Test prediction without training
        untrained_model = ElectricityConsumptionModel()
        with pytest.raises(ValueError):
            untrained_model.predict(25.0, "Monday", 0)

    def test_app_state_management(self):
        """Test that app state is properly managed."""
        # Initially not trained
        assert not self.app.is_model_trained

        # After training
        self.app.generate_and_train(500, 0.1, 0.7, 0.15, 0.15)
        assert self.app.is_model_trained

        # Check that data is stored
        assert hasattr(self.app, "train_data")
        assert hasattr(self.app, "val_data")
        assert hasattr(self.app, "test_data")

        # Check data sizes
        assert len(self.app.train_data) > 0
        assert len(self.app.val_data) > 0
        assert len(self.app.test_data) > 0
|
tests/test_model.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Tests for the Model module.

This module contains comprehensive tests for the ElectricityConsumptionModel class
to ensure proper model training, evaluation, prediction, and persistence.
"""

import pytest
import pandas as pd
import numpy as np
import tempfile
import os

from src.model import ElectricityConsumptionModel
from src.data_generator import DataGenerator
class TestElectricityConsumptionModel:
    """Test cases for ElectricityConsumptionModel class."""

    def setup_method(self):
        """Set up test data for each test method."""
        self.generator = DataGenerator(seed=42)
        self.data = self.generator.generate_data(n_samples=1000)
        self.train_data, self.val_data, self.test_data = self.generator.split_data(
            self.data
        )

        self.model = ElectricityConsumptionModel()

    def test_initialization(self):
        """Test model initialization."""
        model = ElectricityConsumptionModel()

        assert model.model is None
        assert model.preprocessor is None
        assert model.feature_names is None
        assert not model.is_trained

    def test_prepare_features_valid_data(self):
        """Test feature preparation with valid data."""
        # Test with valid data
        valid_data = pd.DataFrame(
            {
                "temperature": [25.0, 30.0],
                "day_of_week": ["Monday", "Saturday"],
                "major_event": [0, 1],
            }
        )

        prepared_data = self.model.prepare_features(valid_data)

        assert isinstance(prepared_data, pd.DataFrame)
        assert list(prepared_data.columns) == [
            "temperature",
            "day_of_week",
            "major_event",
        ]
        assert len(prepared_data) == 2

    def test_prepare_features_missing_columns(self):
        """Test feature preparation with missing columns."""
        invalid_data = pd.DataFrame(
            {
                "temperature": [25.0],
                "day_of_week": ["Monday"],
                # Missing major_event column
            }
        )

        with pytest.raises(ValueError, match="Missing required columns"):
            self.model.prepare_features(invalid_data)

    def test_prepare_features_invalid_temperature(self):
        """Test feature preparation with invalid temperature values."""
        invalid_data = pd.DataFrame(
            {
                "temperature": [10.0, 40.0],  # Outside valid range
                "day_of_week": ["Monday", "Tuesday"],
                "major_event": [0, 0],
            }
        )

        with pytest.raises(ValueError, match="Temperature must be between 15 and 35"):
            self.model.prepare_features(invalid_data)

    def test_prepare_features_invalid_day_of_week(self):
        """Test feature preparation with invalid day of week values."""
        invalid_data = pd.DataFrame(
            {"temperature": [25.0], "day_of_week": ["InvalidDay"], "major_event": [0]}
        )

        with pytest.raises(ValueError, match="Day of week must be one of"):
            self.model.prepare_features(invalid_data)

    def test_prepare_features_invalid_major_event(self):
        """Test feature preparation with invalid major event values."""
        invalid_data = pd.DataFrame(
            {
                "temperature": [25.0],
                "day_of_week": ["Monday"],
                "major_event": [2],  # Invalid value
            }
        )

        with pytest.raises(ValueError, match="Major event must be 0 or 1"):
            self.model.prepare_features(invalid_data)

    def test_train_model(self):
        """Test model training."""
        X_train = self.train_data.drop("consumption_kwh", axis=1)
        y_train = self.train_data[["consumption_kwh"]]

        metrics = self.model.train(X_train, y_train)

        # Check that model is trained
        assert self.model.is_trained
        assert self.model.model is not None
        assert self.model.feature_names is not None

        # Check metrics structure
        expected_metrics = ["train_mse", "train_rmse", "train_mae", "train_r2"]
        assert all(metric in metrics for metric in expected_metrics)

        # Check metric values are reasonable
        assert metrics["train_mse"] > 0
        assert metrics["train_rmse"] > 0
        assert metrics["train_mae"] > 0
        assert 0 <= metrics["train_r2"] <= 1

    def test_evaluate_model_not_trained(self):
        """Test evaluation when model is not trained."""
        X_test = self.test_data.drop("consumption_kwh", axis=1)
        y_test = self.test_data[["consumption_kwh"]]

        with pytest.raises(ValueError, match="Model must be trained before evaluation"):
            self.model.evaluate(X_test, y_test)

    def test_evaluate_model(self):
        """Test model evaluation."""
        # Train model first
        X_train = self.train_data.drop("consumption_kwh", axis=1)
        y_train = self.train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        # Evaluate model
        X_test = self.test_data.drop("consumption_kwh", axis=1)
        y_test = self.test_data[["consumption_kwh"]]

        metrics = self.model.evaluate(X_test, y_test)

        # Check metrics structure
        expected_metrics = ["test_mse", "test_rmse", "test_mae", "test_r2"]
        assert all(metric in metrics for metric in expected_metrics)

        # Check metric values are reasonable
        assert metrics["test_mse"] > 0
        assert metrics["test_rmse"] > 0
        assert metrics["test_mae"] > 0
        assert 0 <= metrics["test_r2"] <= 1

    def test_predict_not_trained(self):
        """Test prediction when model is not trained."""
        with pytest.raises(
            ValueError, match="Model must be trained before making predictions"
        ):
            self.model.predict(25.0, "Monday", 0)

    def test_predict_valid_inputs(self):
        """Test prediction with valid inputs."""
        # Train model first
        X_train = self.train_data.drop("consumption_kwh", axis=1)
        y_train = self.train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        # Test prediction
        prediction = self.model.predict(25.0, "Monday", 0)

        assert isinstance(prediction, float)
        assert prediction >= 0  # Should be non-negative

    def test_predict_different_inputs(self):
        """Test prediction with different input combinations."""
        # Train model first
        X_train = self.train_data.drop("consumption_kwh", axis=1)
        y_train = self.train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        # Test different temperature values
        pred1 = self.model.predict(20.0, "Monday", 0)
        pred2 = self.model.predict(30.0, "Monday", 0)

        # Higher temperature should generally lead to higher consumption
        assert pred2 > pred1

        # Test different days
        pred3 = self.model.predict(25.0, "Saturday", 0)
        pred4 = self.model.predict(25.0, "Monday", 0)

        # Should be different (though not necessarily higher/lower due to randomness)
        assert pred3 != pred4

        # Test with and without major event
        pred5 = self.model.predict(25.0, "Monday", 1)
        pred6 = self.model.predict(25.0, "Monday", 0)

        # Major event should increase consumption
        assert pred5 > pred6

    def test_get_model_coefficients_not_trained(self):
        """Test getting coefficients when model is not trained."""
        with pytest.raises(
            ValueError, match="Model must be trained before accessing coefficients"
        ):
            self.model.get_model_coefficients()

    def test_get_model_coefficients(self):
        """Test getting model coefficients."""
        # Train model first
        X_train = self.train_data.drop("consumption_kwh", axis=1)
        y_train = self.train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        coefficients = self.model.get_model_coefficients()

        # Check structure
        assert "feature_names" in coefficients
        assert "coefficients" in coefficients
        assert "intercept" in coefficients

        # Check types
        assert isinstance(coefficients["feature_names"], list)
        assert isinstance(coefficients["coefficients"], list)
        assert isinstance(coefficients["intercept"], float)

        # Check lengths
        assert len(coefficients["feature_names"]) == len(coefficients["coefficients"])
        assert len(coefficients["feature_names"]) > 0

    def test_save_model_not_trained(self):
        """Test saving model when not trained."""
        with tempfile.NamedTemporaryFile(suffix=".joblib", delete=False) as tmp_file:
            filepath = tmp_file.name

        try:
            with pytest.raises(ValueError, match="Model must be trained before saving"):
                self.model.save_model(filepath)
        finally:
            if os.path.exists(filepath):
                os.unlink(filepath)

    def test_save_and_load_model(self):
        """Test saving and loading model."""
        # Train model first
        X_train = self.train_data.drop("consumption_kwh", axis=1)
        y_train = self.train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        with tempfile.NamedTemporaryFile(suffix=".joblib", delete=False) as tmp_file:
            filepath = tmp_file.name

        try:
            # Save model
            self.model.save_model(filepath)
            assert os.path.exists(filepath)

            # Create new model and load
            new_model = ElectricityConsumptionModel()
            new_model.load_model(filepath)

            # Check that model is trained
            assert new_model.is_trained
            assert new_model.model is not None

            # Test prediction with loaded model
            original_pred = self.model.predict(25.0, "Monday", 0)
            loaded_pred = new_model.predict(25.0, "Monday", 0)

            # Predictions should be identical
            assert abs(original_pred - loaded_pred) < 1e-10

        finally:
            if os.path.exists(filepath):
                os.unlink(filepath)

    def test_load_model_file_not_found(self):
        """Test loading model from non-existent file."""
        with pytest.raises(FileNotFoundError):
            self.model.load_model("non_existent_file.joblib")

    def test_model_performance_reasonable(self):
        """Test that model performance is reasonable."""
        # Train model
        X_train = self.train_data.drop("consumption_kwh", axis=1)
        y_train = self.train_data[["consumption_kwh"]]
        train_metrics = self.model.train(X_train, y_train)

        # Evaluate model
        X_test = self.test_data.drop("consumption_kwh", axis=1)
        y_test = self.test_data[["consumption_kwh"]]
        test_metrics = self.model.evaluate(X_test, y_test)

        # R-squared should be reasonable (not too low, not perfect)
        assert 0.3 <= train_metrics["train_r2"] <= 0.995
        assert 0.3 <= test_metrics["test_r2"] <= 0.995

        # Test R-squared should not be much worse than train R-squared
        assert test_metrics["test_r2"] >= train_metrics["train_r2"] - 0.2

    def test_model_consistency(self):
        """Test that model predictions are consistent."""
        # Train model
        X_train = self.train_data.drop("consumption_kwh", axis=1)
        y_train = self.train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        # Make same prediction multiple times
        pred1 = self.model.predict(25.0, "Monday", 0)
        pred2 = self.model.predict(25.0, "Monday", 0)
        pred3 = self.model.predict(25.0, "Monday", 0)

        # All predictions should be identical
        assert abs(pred1 - pred2) < 1e-10
        assert abs(pred2 - pred3) < 1e-10

    def test_model_feature_importance(self):
        """Test that model captures feature importance correctly."""
        # Train model
        X_train = self.train_data.drop("consumption_kwh", axis=1)
        y_train = self.train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        coefficients = self.model.get_model_coefficients()

        # Temperature coefficient should be positive (higher temp = higher consumption)
        temp_idx = coefficients["feature_names"].index("temperature")
        assert coefficients["coefficients"][temp_idx] > 0

        # Major event coefficient should be positive (events increase consumption)
        event_idx = coefficients["feature_names"].index("major_event")
        assert coefficients["coefficients"][event_idx] > 0

    def test_model_with_extreme_values(self):
        """Test model behavior with extreme input values."""
        # Train model
        X_train = self.train_data.drop("consumption_kwh", axis=1)
        y_train = self.train_data[["consumption_kwh"]]
        self.model.train(X_train, y_train)

        # Test with minimum temperature
        min_pred = self.model.predict(15.0, "Monday", 0)
        assert min_pred >= 0

        # Test with maximum temperature
        max_pred = self.model.predict(35.0, "Monday", 0)
        assert max_pred >= 0

        # Test with major event
        event_pred = self.model.predict(25.0, "Monday", 1)
        assert event_pred >= 0
|