luohoa97 commited on
Commit
d5b7ee9
·
verified ·
1 Parent(s): 6977c60

Deploy BitNet-Transformer Trainer

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +210 -0
  2. .python-version +1 -0
  3. .qwen/settings.json +14 -0
  4. .qwen/settings.json.orig +7 -0
  5. Dockerfile +27 -0
  6. LICENSE +674 -0
  7. README.md +207 -7
  8. docker-compose.yml +21 -0
  9. pyproject.toml +36 -0
  10. requirements.txt +21 -0
  11. scripts/__init__.py +1 -0
  12. scripts/generate_ai_dataset.py +181 -0
  13. scripts/multi_backtest.py +98 -0
  14. scripts/optimize_strategy.py +94 -0
  15. scripts/sync_to_hf.sh +12 -0
  16. scripts/test_inference.py +27 -0
  17. scripts/train_ai_model.py +204 -0
  18. scripts/verify_ai_strategy.py +39 -0
  19. test_finbert_multithread.py +71 -0
  20. test_signal_fix.py +100 -0
  21. trading_cli/__main__.py +119 -0
  22. trading_cli/app.py +995 -0
  23. trading_cli/backtest/__init__.py +3 -0
  24. trading_cli/backtest/engine.py +454 -0
  25. trading_cli/config.py +92 -0
  26. trading_cli/data/asset_search.py +261 -0
  27. trading_cli/data/db.py +234 -0
  28. trading_cli/data/market.py +126 -0
  29. trading_cli/data/news.py +136 -0
  30. trading_cli/execution/adapter_factory.py +50 -0
  31. trading_cli/execution/adapters/__init__.py +37 -0
  32. trading_cli/execution/adapters/alpaca.py +331 -0
  33. trading_cli/execution/adapters/base.py +177 -0
  34. trading_cli/execution/adapters/binance.py +207 -0
  35. trading_cli/execution/adapters/kraken.py +198 -0
  36. trading_cli/execution/adapters/registry.py +73 -0
  37. trading_cli/execution/adapters/yfinance.py +169 -0
  38. trading_cli/execution/alpaca_client.py +266 -0
  39. trading_cli/run_dev.py +42 -0
  40. trading_cli/screens/backtest.py +368 -0
  41. trading_cli/screens/config_screen.py +354 -0
  42. trading_cli/screens/dashboard.py +160 -0
  43. trading_cli/screens/portfolio.py +141 -0
  44. trading_cli/screens/sentiment.py +251 -0
  45. trading_cli/screens/trades.py +101 -0
  46. trading_cli/screens/watchlist.py +119 -0
  47. trading_cli/sentiment/aggregator.py +124 -0
  48. trading_cli/sentiment/finbert.py +453 -0
  49. trading_cli/sentiment/news_classifier.py +118 -0
  50. trading_cli/strategy/adapters/__init__.py +39 -0
.gitignore ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+ /models/
209
+ *.pt
210
+ *.safetensors
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
.qwen/settings.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(python *)",
5
+ "WebFetch(docs.alpaca.markets)",
6
+ "WebFetch(alpaca.markets)",
7
+ "WebSearch",
8
+ "WebFetch(stage.partners.liveu.tv)",
9
+ "WebFetch(pypi.org)",
10
+ "Bash(cat *)"
11
+ ]
12
+ },
13
+ "$version": 3
14
+ }
.qwen/settings.json.orig ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(python *)"
5
+ ]
6
+ }
7
+ }
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential \
7
+ curl \
8
+ git \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # Install dependencies
12
+ COPY requirements.txt .
13
+ # Add safetensors, huggingface_hub, scikit-learn explicitly just in case
14
+ RUN pip install --no-cache-dir -r requirements.txt \
15
+ safetensors huggingface_hub scikit-learn pandas numpy torch yfinance
16
+
17
+ # Copy project files
18
+ COPY . .
19
+
20
+ # Environment variables (to be set in HF Space Secrets)
21
+ ENV HF_HOME=/tmp/huggingface
22
+ ENV HF_REPO_ID=""
23
+ ENV HF_TOKEN=""
24
+
25
+ # Command to run training
26
+ # This will output the performance report and upload to HF Hub
27
+ CMD ["python", "scripts/train_ai_model.py"]
LICENSE ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
README.md CHANGED
@@ -1,12 +1,212 @@
1
  ---
2
- title: BitFinTrainer
3
- emoji: 📉
4
- colorFrom: pink
5
- colorTo: gray
6
  sdk: docker
7
  pinned: false
8
- license: gpl-3.0
9
- short_description: The trainer for the BitFin Models
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: AI Trading Fusion - BitNet Transformer
3
+ emoji: 📈
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: docker
7
  pinned: false
 
 
8
  ---
9
 
10
+ # trading-cli
11
+
12
+ A full-screen TUI AI trading application powered by **FinBERT** sentiment analysis and **Alpaca** paper trading.
13
+
14
+ ```
15
+ ┌─────────────────────────────────────────────────────────┐
16
+ │ TRADING CLI - Paper Trading Mode Cash: $98,234.50 │
17
+ ├─────────────────────────────────────────────────────────┤
18
+ │ [1] Dashboard [2] Watchlist [3] Portfolio │
19
+ │ [4] Trades [5] Sentiment [6] Config [q] Quit │
20
+ ├─────────────────────────────────────────────────────────┤
21
+ │ MARKET STATUS: ● OPEN Last Updated: 14:23:45 EST │
22
+ └─────────────────────────────────────────────────────────┘
23
+ ```
24
+
25
+ ---
26
+
27
+ ## Features
28
+
29
+ | Feature | Details |
30
+ |---|---|
31
+ | Full-screen TUI | Textual-based, single command launch |
32
+ | FinBERT sentiment | Local inference, ProsusAI/finbert |
33
+ | Paper trading | Alpaca paper API (or built-in demo mode) |
34
+ | Live prices | Alpaca market data + yfinance fallback |
35
+ | Hybrid signals | 0.6 × technical + 0.4 × sentiment |
36
+ | Persistent state | SQLite (trades, watchlist, sentiment cache) |
37
+ | Demo mode | Works without any API keys |
38
+
39
+ ---
40
+
41
+ ## Quick Start
42
+
43
+ ### 1. Install uv (if not already installed)
44
+ ```bash
45
+ curl -LsSf https://astral.sh/uv/install.sh | sh
46
+ ```
47
+
48
+ ### 2. Clone and install
49
+ ```bash
50
+ git clone https://github.com/luohoa97/ai-trading.git
51
+ cd ai-trading
52
+ uv sync
53
+ ```
54
+
55
+ ### 3. Run
56
+ ```bash
57
+ uv run trading-cli
58
+ ```
59
+
60
+ On first launch, FinBERT (~500 MB) downloads from HuggingFace and is cached locally.
61
+ The app starts in **Demo Mode** automatically if no Alpaca keys are configured.
62
+
63
+ ---
64
+
65
+ ## Alpaca Paper Trading Setup (optional)
66
+
67
+ 1. Sign up at [alpaca.markets](https://alpaca.markets) — free, no credit card needed
68
+ 2. Generate paper trading API keys in the Alpaca dashboard
69
+ 3. Open Config in the app (`6`), enter your keys, press `Ctrl+S`
70
+
71
+ The app always uses paper trading endpoints — no real money is ever at risk.
72
+
73
+ ---
74
+
75
+ ## Configuration
76
+
77
+ Config file: `~/.config/trading-cli/config.toml`
78
+
79
+ ```toml
80
+ alpaca_api_key = "PKxxxxxxxxxxxx"
81
+ alpaca_api_secret = "xxxxxxxxxxxxxxxxxxxxxxxxxxxx"
82
+ alpaca_paper = true
83
+
84
+ # Risk management
85
+ risk_pct = 0.02 # 2% of portfolio per trade
86
+ max_drawdown = 0.15 # halt trading at 15% drawdown
87
+ stop_loss_pct = 0.05 # 5% stop-loss per position
88
+ max_positions = 10
89
+
90
+ # Signal thresholds (hybrid score: -1 to +1)
91
+ signal_buy_threshold = 0.5
92
+ signal_sell_threshold = -0.3
93
+
94
+ # Poll intervals (seconds)
95
+ poll_interval_prices = 30
96
+ poll_interval_news = 900
97
+ poll_interval_signals = 300
98
+ poll_interval_positions = 60
99
+ ```
100
+
101
+ ---
102
+
103
+ ## Keyboard Shortcuts
104
+
105
+ | Key | Action |
106
+ |---|---|
107
+ | `1`–`6` | Switch screens |
108
+ | `q` / `Ctrl+C` | Quit |
109
+ | `r` | Refresh current screen |
110
+ | `a` | Add symbol (Watchlist) |
111
+ | `d` | Delete selected symbol (Watchlist) |
112
+ | `x` | Close position (Portfolio) |
113
+ | `e` | Export trades to CSV (Trades) |
114
+ | `f` | Focus filter (Trades) |
115
+ | `Enter` | Submit symbol / confirm action |
116
+ | `Ctrl+S` | Save config (Config screen) |
117
+
118
+ ---
119
+
120
+ ## Screens
121
+
122
+ **1 — Dashboard**: Account balance, market status, live positions, real-time signal log.
123
+
124
+ **2 — Watchlist**: Add/remove symbols. See live prices, sentiment score, and BUY/SELL/HOLD signal per symbol.
125
+
126
+ **3 — Portfolio**: Full position detail from Alpaca. Press `x` to close a position via market order.
127
+
128
+ **4 — Trades**: Scrollable history with Alpaca `order_id`. Press `e` to export CSV.
129
+
130
+ **5 — Sentiment**: Type any symbol, press Enter — see FinBERT scores per headline and an aggregated gauge.
131
+
132
+ **6 — Config**: Edit API keys, thresholds, risk limits, toggle auto-trading.
133
+
134
+ ---
135
+
136
+ ## Trading Strategy
137
+
138
+ **Signal = 0.6 × technical + 0.4 × sentiment**
139
+
140
+ | Component | Calculation |
141
+ |---|---|
142
+ | `technical_score` | 0.5 × SMA crossover (20/50) + 0.5 × RSI(14) |
143
+ | `sentiment_score` | FinBERT weighted average on latest news |
144
+ | BUY | hybrid > +0.50 |
145
+ | SELL | hybrid < −0.30 |
146
+
147
+ In **manual mode** (default), signals appear in the log for review.
148
+ In **auto-trading mode** (Config → toggle), market orders are submitted automatically.
149
+
150
+ ---
151
+
152
+ ## Project Structure
153
+
154
+ ```
155
+ trading_cli/
156
+ ├── __main__.py # Entry point: uv run trading-cli
157
+ ├── app.py # Textual App, workers, screen routing
158
+ ├── config.py # Load/save ~/.config/trading-cli/config.toml
159
+ ├── screens/
160
+ │ ├── dashboard.py # Screen 1 — main dashboard
161
+ │ ├── watchlist.py # Screen 2 — symbol watchlist
162
+ │ ├── portfolio.py # Screen 3 — positions & P&L
163
+ │ ├── trades.py # Screen 4 — trade history
164
+ │ ├── sentiment.py # Screen 5 — FinBERT analysis
165
+ │ └── config_screen.py # Screen 6 — settings editor
166
+ ├── widgets/
167
+ │ ├── positions_table.py # Reusable P&L table
168
+ │ ├── signal_log.py # Scrolling signal feed
169
+ │ └── sentiment_gauge.py # Visual [-1, +1] gauge
170
+ ├── sentiment/
171
+ │ ├── finbert.py # Singleton model, batch inference, cache
172
+ │ └── aggregator.py # Score aggregation + gauge renderer
173
+ ├── strategy/
174
+ │ ├── signals.py # SMA + RSI + sentiment hybrid signal
175
+ │ └── risk.py # Position sizing, stop-loss, drawdown
176
+ ├── execution/
177
+ │ └── alpaca_client.py # Real AlpacaClient + MockAlpacaClient
178
+ └── data/
179
+ ├── market.py # OHLCV via Alpaca / yfinance
180
+ ├── news.py # Headlines via Alpaca News / yfinance
181
+ └── db.py # SQLite schema + all queries
182
+ ```
183
+
184
+ ---
185
+
186
+ ## Database
187
+
188
+ Location: `~/.config/trading-cli/trades.db`
189
+
190
+ | Table | Contents |
191
+ |---|---|
192
+ | `trades` | Every executed order with Alpaca `order_id` |
193
+ | `signals` | Every generated signal (executed or not) |
194
+ | `watchlist` | Monitored symbols |
195
+ | `sentiment_cache` | MD5(headline) → label + score |
196
+ | `price_history` | OHLCV bars per symbol |
197
+
198
+ ---
199
+
200
+ ## Development
201
+
202
+ ```bash
203
+ # Run app
204
+ uv run trading-cli
205
+
206
+ # Live logs
207
+ tail -f ~/.config/trading-cli/app.log
208
+
209
+ # Reset state
210
+ rm ~/.config/trading-cli/trades.db
211
+ rm ~/.config/trading-cli/config.toml
212
+ ```# ai-trading
docker-compose.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ trading-cli:
3
+ build:
4
+ context: .
5
+ dockerfile: Dockerfile
6
+ image: ai-trading:latest
7
+ container_name: trading-cli
8
+ environment:
9
+ - TOKENIZERS_PARALLELISM=false
10
+ - TRANSFORMERS_VERBOSITY=error
11
+ - HF_HUB_DISABLE_TELEMETRY=1
12
+ - TQDM_DISABLE=1
13
+ volumes:
14
+ - hf-cache:/root/.cache/huggingface
15
+ stdin_open: true
16
+ tty: true
17
+ restart: unless-stopped
18
+
19
+ volumes:
20
+ hf-cache:
21
+ name: hf-cache
pyproject.toml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "trading-cli"
3
+ version = "0.1.0"
4
+ description = "Full-screen TUI AI trading app with FinBERT sentiment analysis"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11,<3.12"
7
+ license = { text = "MIT" }
8
+ dependencies = [
9
+ "textual>=0.61.0",
10
+ "rich>=13.7.0",
11
+ "click>=8.1.7",
12
+ "alpaca-py>=0.28.0",
13
+ "transformers>=4.40.0",
14
+ "torch>=2.2.0",
15
+ "yfinance>=0.2.38",
16
+ "pandas>=2.2.0",
17
+ "numpy>=1.26.0",
18
+ "toml>=0.10.2",
19
+ "scipy>=1.12.0",
20
+ "textual-autocomplete>=3.0.0",
21
+ "sentence-transformers>=2.2.0",
22
+ ]
23
+
24
+ [project.scripts]
25
+ trading-cli = "trading_cli.__main__:main"
26
+ trading-cli-dev = "trading_cli.run_dev:main"
27
+
28
+ [project.optional-dependencies]
29
+ dev = ["watchfiles>=0.20.0"]
30
+
31
+ [build-system]
32
+ requires = ["hatchling"]
33
+ build-backend = "hatchling.build"
34
+
35
+ [tool.hatch.build.targets.wheel]
36
+ packages = ["trading_cli"]
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core CLI + TUI
2
+ click>=8.1.7
3
+ rich>=13.7.0
4
+
5
+ # ML / NLP
6
+ transformers>=4.40.0
7
+ torch>=2.2.0
8
+
9
+ # Market data
10
+ yfinance>=0.2.38
11
+
12
+ # Data
13
+ pandas>=2.2.0
14
+ numpy>=1.26.0
15
+ scipy>=1.12.0
16
+
17
+ # Config
18
+ toml>=0.10.2
19
+
20
+ # Optional: live trading via Alpaca
21
+ # alpaca-py>=0.20.0
scripts/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Scripts package
scripts/generate_ai_dataset.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate training dataset for AI Fusion strategy.
4
+ Fetches historical OHLCV, computes technical features, and labels data.
5
+ """
6
+
7
+ import sys
8
+ import os
9
+ import pandas as pd
10
+ import numpy as np
11
+ import logging
12
+ import torch
13
+ from datetime import datetime, timedelta
14
+
15
+ # Add project root to path
16
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
17
+
18
+ from trading_cli.data.market import fetch_ohlcv_yfinance
19
+ from trading_cli.strategy.signals import (
20
+ calculate_rsi,
21
+ calculate_sma,
22
+ calculate_atr,
23
+ calculate_bollinger_bands
24
+ )
25
+
26
+ # Configure logging
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+ SYMBOLS = [
31
+ "AAPL", "MSFT", "GOOGL", "AMZN", "TSLA", "NVDA", "AMD", "META", "NFLX", "ADBE",
32
+ "CRM", "INTC", "CSCO", "ORCL", "QCOM", "AVGO", "TXN", "AMAT", "MU", "LRCX",
33
+ "JPM", "BAC", "WFC", "GS", "MS", "V", "MA", "AXP", "BLK", "BX",
34
+ "XOM", "CVX", "COP", "SLB", "HAL", "MPC", "PSX", "VLO", "OXY", "HES",
35
+ "JNJ", "PFE", "UNH", "ABBV", "MRK", "LLY", "TMO", "DHR", "ISRG", "GILD",
36
+ "WMT", "COST", "HD", "LOW", "TGT", "PG", "KO", "PEP", "PM", "MO",
37
+ "CAT", "DE", "HON", "GE", "MMM", "UPS", "FDX", "RTX", "LMT", "GD",
38
+ "BTC-USD", "ETH-USD", "GC=F", "CL=F" # Crypto and Commodities for diversity
39
+ ]
40
+ DAYS = 3652 # 10 years
41
+ LOOKAHEAD = 5 # Prediction window (days)
42
+ TARGET_PCT = 0.02 # Profit target (2%)
43
+ STOP_PCT = 0.015 # Stop loss (1.5%)
44
+
45
+ def generate_features(df):
46
+ """Compute technical indicators for the feature vector."""
47
+ close = df["close" if "close" in df.columns else "Close"]
48
+ high = df["high" if "high" in df.columns else "High"]
49
+ low = df["low" if "low" in df.columns else "Low"]
50
+
51
+ # 1. RSI(2) - Very short period
52
+ rsi2 = calculate_rsi(close, 2) / 100.0
53
+ # 2. RSI(14) - Standard period
54
+ rsi14 = calculate_rsi(close, 14) / 100.0
55
+ # 3. SMA distance (20, 50, 200)
56
+ sma20 = calculate_sma(close, 20)
57
+ sma50 = calculate_sma(close, 50)
58
+ sma200 = calculate_sma(close, 200)
59
+
60
+ dist_sma20 = (close / sma20) - 1.0
61
+ dist_sma50 = (close / sma50) - 1.0
62
+ dist_sma200 = (close / sma200) - 1.0
63
+
64
+ # 4. Bollinger Band position
65
+ upper, mid, lower = calculate_bollinger_bands(close, 20, 2.0)
66
+ bb_pos = (close - lower) / (upper - lower + 1e-6)
67
+
68
+ # 5. ATR (Volatility)
69
+ atr = calculate_atr(df, 14)
70
+ atr_pct = atr / close
71
+
72
+ # 6. Volume spike (Ratio to SMA 20)
73
+ vol = df["volume" if "volume" in df.columns else "Volume"]
74
+ vol_sma = vol.rolling(20).mean()
75
+ vol_ratio = (vol / vol_sma).clip(0, 5) / 5.0 # Normalized 0-1
76
+
77
+ features = pd.DataFrame({
78
+ "rsi2": rsi2,
79
+ "rsi14": rsi14,
80
+ "dist_sma20": dist_sma20,
81
+ "dist_sma50": dist_sma50,
82
+ "dist_sma200": dist_sma200,
83
+ "bb_pos": bb_pos,
84
+ "atr_pct": atr_pct,
85
+ "vol_ratio": vol_ratio,
86
+ }, index=df.index)
87
+
88
+ # Ensure all columns are 1D (should be Series already after flatten in market.py)
89
+ for col in features.columns:
90
+ if isinstance(features[col], pd.DataFrame):
91
+ features[col] = features[col].squeeze()
92
+
93
+ return features
94
+
95
+ def generate_labels(df):
96
+ """Label data using Triple Barrier: 1=Buy, 2=Sell, 0=Hold."""
97
+ close = df["close" if "close" in df.columns else "Close"].values
98
+ labels = np.zeros(len(close))
99
+
100
+ for i in range(len(close) - LOOKAHEAD):
101
+ current_price = close[i]
102
+ future_prices = close[i+1 : i+LOOKAHEAD+1]
103
+
104
+ # Look ahead for profit target or stop loss
105
+ max_ret = (np.max(future_prices) - current_price) / current_price
106
+ min_ret = (np.min(future_prices) - current_price) / current_price
107
+
108
+ if max_ret >= TARGET_PCT:
109
+ labels[i] = 1 # BUY
110
+ elif min_ret <= -STOP_PCT:
111
+ labels[i] = 2 # SELL
112
+ else:
113
+ labels[i] = 0 # HOLD
114
+
115
+ return labels
116
+
117
+ SEQ_LEN = 30 # One month of trading days
118
+
119
+ def build_dataset(symbols=SYMBOLS, days=DAYS, output_path="data/trading_dataset.pt"):
120
+ """
121
+ Programmatically build the sequence dataset.
122
+ Used by local scripts and the Hugging Face Cloud trainer.
123
+ """
124
+ all_features = []
125
+ all_labels = []
126
+
127
+ for symbol in symbols:
128
+ logger.info("Fetching data for %s", symbol)
129
+ df = fetch_ohlcv_yfinance(symbol, days=days)
130
+ total_days = len(df)
131
+ if df.empty or total_days < (days // 2): # Ensure we have enough data
132
+ logger.warning("Skipping %s: Insufficient history (%d < %d)", symbol, total_days, days // 2)
133
+ continue
134
+
135
+ features = generate_features(df)
136
+ labels = generate_labels(df)
137
+
138
+ # Sentiment simulation
139
+ sentiment = np.random.normal(0, 0.2, len(features))
140
+ features["sentiment"] = sentiment
141
+
142
+ # Combine and drop NaN
143
+ features["label"] = labels
144
+ features = features.dropna()
145
+
146
+ if len(features) < (SEQ_LEN + 100):
147
+ logger.warning("Skipping %s: Too few valid samples after dropna (%d < %d)", symbol, len(features), SEQ_LEN + 100)
148
+ continue
149
+
150
+ # Create sequences
151
+ feat_vals = features.drop(columns=["label"]).values
152
+ label_vals = features["label"].values
153
+
154
+ symbol_features = []
155
+ symbol_labels = []
156
+
157
+ for i in range(len(feat_vals) - SEQ_LEN):
158
+ # Window of features: [i : i + SEQ_LEN]
159
+ # Label is for the LAST day in the window
160
+ symbol_features.append(feat_vals[i : i+SEQ_LEN])
161
+ symbol_labels.append(label_vals[i+SEQ_LEN-1])
162
+
163
+ all_features.append(np.array(symbol_features))
164
+ all_labels.append(np.array(symbol_labels))
165
+
166
+ X = np.concatenate(all_features, axis=0)
167
+ y = np.concatenate(all_labels, axis=0)
168
+
169
+ # Save as PyTorch dataset
170
+ data = {
171
+ "X": torch.tensor(X, dtype=torch.float32),
172
+ "y": torch.tensor(y, dtype=torch.long)
173
+ }
174
+
175
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
176
+ torch.save(data, output_path)
177
+ logger.info("Sequence dataset saved to %s. Shape: %s", output_path, X.shape)
178
+ return data
179
+
180
+ if __name__ == "__main__":
181
+ build_dataset()
scripts/multi_backtest.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Multi-stock backtesting script for strategy evolution.
4
+ Tests one or more strategies across multiple symbols and timeframes.
5
+ """
6
+
7
+ import sys
8
+ import os
9
+ from datetime import datetime, timedelta
10
+ import pandas as pd
11
+ import numpy as np
12
+ import logging
13
+
14
+ # Add project root to path
15
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
16
+
17
+ from trading_cli.backtest.engine import BacktestEngine
18
+ from trading_cli.strategy.strategy_factory import create_trading_strategy, available_strategies
19
+ from trading_cli.data.market import fetch_ohlcv_yfinance
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.WARNING)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ DEFAULT_SYMBOLS = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA", "NVDA", "META", "AMD", "COIN", "MARA"]
26
+ DEFAULT_DAYS = 365
27
+
28
+ def run_multi_backtest(symbols, strategy_ids, days=DEFAULT_DAYS, config=None):
29
+ if config is None:
30
+ config = {
31
+ "signal_buy_threshold": 0.2,
32
+ "signal_sell_threshold": -0.15,
33
+ "risk_pct": 0.02,
34
+ "stop_loss_pct": 0.05,
35
+ }
36
+
37
+ results = []
38
+
39
+ print(f"{'Symbol':<8} | {'Strategy':<15} | {'Return %':>10} | {'Sharpe':>8} | {'Win%':>6} | {'Trades':>6}")
40
+ print("-" * 70)
41
+
42
+ for symbol in symbols:
43
+ # Fetch data once per symbol
44
+ ohlcv = fetch_ohlcv_yfinance(symbol, days=days)
45
+ if ohlcv.empty:
46
+ print(f"Failed to fetch data for {symbol}")
47
+ continue
48
+
49
+ for strategy_id in strategy_ids:
50
+ # Create strategy
51
+ strat_config = config.copy()
52
+ strat_config["strategy_id"] = strategy_id
53
+ strategy = create_trading_strategy(strat_config)
54
+
55
+ # Run backtest
56
+ engine = BacktestEngine(
57
+ config=strat_config,
58
+ use_sentiment=False, # Skip sentiment for pure technical baseline
59
+ strategy=strategy
60
+ )
61
+
62
+ res = engine.run(symbol, ohlcv, initial_capital=100_000.0)
63
+
64
+ print(f"{symbol:<8} | {strategy_id:<15} | {res.total_return_pct:>9.2f}% | {res.sharpe_ratio:>8.2f} | {res.win_rate:>5.1f}% | {res.total_trades:>6}")
65
+
66
+ results.append({
67
+ "symbol": symbol,
68
+ "strategy": strategy_id,
69
+ "return_pct": res.total_return_pct,
70
+ "sharpe": res.sharpe_ratio,
71
+ "win_rate": res.win_rate,
72
+ "trades": res.total_trades,
73
+ "max_drawdown": res.max_drawdown_pct
74
+ })
75
+
76
+ # Aggregate results by strategy
77
+ df = pd.DataFrame(results)
78
+ if not df.empty:
79
+ summary = df.groupby("strategy").agg({
80
+ "return_pct": ["mean", "std"],
81
+ "sharpe": "mean",
82
+ "win_rate": "mean",
83
+ "trades": "sum"
84
+ })
85
+ print("\n--- Summary ---")
86
+ print(summary)
87
+
88
+ return df
89
+
90
+ if __name__ == "__main__":
91
+ import argparse
92
+ parser = argparse.ArgumentParser()
93
+ parser.add_argument("--symbols", nargs="+", default=DEFAULT_SYMBOLS)
94
+ parser.add_argument("--strategies", nargs="+", default=["hybrid", "mean_reversion", "momentum", "trend_following"])
95
+ parser.add_argument("--days", type=int, default=DEFAULT_DAYS)
96
+ args = parser.parse_args()
97
+
98
+ run_multi_backtest(args.symbols, args.strategies, args.days)
scripts/optimize_strategy.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Grid search optimizer for trading strategies.
4
+ Tests multiple parameter combinations to find the best performing one.
5
+ """
6
+
7
+ import sys
8
+ import os
9
+ import pandas as pd
10
+ import numpy as np
11
+ import logging
12
+ from itertools import product
13
+
14
+ # Add project root to path
15
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
16
+
17
+ from trading_cli.backtest.engine import BacktestEngine
18
+ from trading_cli.strategy.strategy_factory import create_trading_strategy
19
+ from trading_cli.data.market import fetch_ohlcv_yfinance
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.WARNING)
23
+
24
+ def optimize_mean_reversion(symbols, days=180):
25
+ # Fetch data once
26
+ ohlcv_data = {}
27
+ for symbol in symbols:
28
+ df = fetch_ohlcv_yfinance(symbol, days=days)
29
+ if not df.empty:
30
+ ohlcv_data[symbol] = df
31
+
32
+ if not ohlcv_data:
33
+ print("No data fetched.")
34
+ return
35
+
36
+ # Parameter grid
37
+ rsi_oversold_vals = [5, 10, 15, 20]
38
+ rsi_overbought_vals = [70, 80, 85, 90]
39
+ bb_std_vals = [1.0, 1.5, 2.0, 2.5]
40
+
41
+ results = []
42
+
43
+ combinations = list(product(rsi_oversold_vals, rsi_overbought_vals, bb_std_vals))
44
+ print(f"Testing {len(combinations)} combinations across {len(ohlcv_data)} symbols...")
45
+
46
+ for rsi_os, rsi_ob, bb_std in combinations:
47
+ config = {
48
+ "strategy_id": "mean_reversion",
49
+ "rsi_oversold": rsi_os,
50
+ "rsi_overbought": rsi_ob,
51
+ "bb_std": bb_std,
52
+ "risk_pct": 0.02,
53
+ }
54
+
55
+ total_return = 0
56
+ total_sharpe = 0
57
+ total_win_rate = 0
58
+ total_trades = 0
59
+
60
+ for symbol, ohlcv in ohlcv_data.items():
61
+ strategy = create_trading_strategy(config)
62
+ engine = BacktestEngine(config=config, use_sentiment=False, strategy=strategy)
63
+ res = engine.run(symbol, ohlcv)
64
+
65
+ total_return += res.total_return_pct
66
+ total_sharpe += res.sharpe_ratio
67
+ total_win_rate += res.win_rate
68
+ total_trades += res.total_trades
69
+
70
+ avg_return = total_return / len(ohlcv_data)
71
+ avg_sharpe = total_sharpe / len(ohlcv_data)
72
+ avg_win_rate = total_win_rate / len(ohlcv_data)
73
+
74
+ results.append({
75
+ "rsi_os": rsi_os,
76
+ "rsi_ob": rsi_ob,
77
+ "bb_std": bb_std,
78
+ "avg_return": avg_return,
79
+ "avg_sharpe": avg_sharpe,
80
+ "avg_win_rate": avg_win_rate,
81
+ "total_trades": total_trades
82
+ })
83
+
84
+ # Sort results
85
+ df = pd.DataFrame(results)
86
+ best = df.sort_values("avg_return", ascending=False).head(10)
87
+
88
+ print("\n--- Top 10 Configurations ---")
89
+ print(best)
90
+
91
+ return best
92
+
93
+ if __name__ == "__main__":
94
+ optimize_mean_reversion(["AAPL", "MSFT", "NVDA", "TSLA", "AMD", "COIN"], days=180)
scripts/sync_to_hf.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ echo "🚀 Synchronizing with Hugging Face Space (luohoa97/BitFinTrainer)..."
2
+
3
+ # Use hf upload to bypass git credential issues
4
+ # This respects .gitignore and excludes heavy folders
5
+ hf upload luohoa97/BitFinTrainer . . --repo-type space \
6
+ --exclude="data/*" \
7
+ --exclude="models/*" \
8
+ --exclude=".venv/*" \
9
+ --exclude=".gemini/*" \
10
+ --commit-message="Deploy BitNet-Transformer Trainer"
11
+
12
+ echo "✅ Finished! Your Space is building at: https://huggingface.co/spaces/luohoa97/BitFinTrainer"
scripts/test_inference.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from safetensors.torch import load_file
3
+ from trading_cli.strategy.ai.model import create_model
4
+ import logging
5
+
6
+ logging.basicConfig(level=logging.INFO)
7
+ logger = logging.getLogger(__name__)
8
+
9
+ def test_inference():
10
+ model = create_model(input_dim=9)
11
+ try:
12
+ model.load_state_dict(load_file("models/ai_fusion_bitnet.safetensors"))
13
+ model.eval()
14
+ logger.info("Model loaded successfully ✓")
15
+
16
+ # Test with random input
17
+ x = torch.randn(1, 9)
18
+ with torch.no_grad():
19
+ output = model(x)
20
+ logger.info(f"Output: {output}")
21
+ action = torch.argmax(output, dim=-1).item()
22
+ logger.info(f"Action: {action}")
23
+ except Exception as e:
24
+ logger.error(f"Inference test failed: {e}")
25
+
26
+ if __name__ == "__main__":
27
+ test_inference()
scripts/train_ai_model.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Train the BitNet AI Fusion model.
4
+ Uses ternary weights (-1, 0, 1) and 8-bit activations.
5
+ """
6
+
7
+ import sys
8
+ import os
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ from torch.utils.data import DataLoader, TensorDataset, random_split
13
+ import logging
14
+ from safetensors.torch import save_file, load_file
15
+ from huggingface_hub import HfApi, create_repo, hf_hub_download
16
+ import numpy as np
17
+ from sklearn.metrics import classification_report, confusion_matrix
18
+
19
+ # Add project root to path
20
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
21
+
22
+ from trading_cli.strategy.ai.model import create_model
23
+ from scripts.generate_ai_dataset import build_dataset
24
+
25
+ # Configure logging
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Hyperparameters
30
+ EPOCHS = 100
31
+ BATCH_SIZE = 64 # Reduced for Transformer memory
32
+ LR = 0.0003
33
+ HIDDEN_DIM = 512
34
+ LAYERS = 8
35
+ SEQ_LEN = 30
36
+
37
+ # Hugging Face Settings (Optional)
38
+ HF_REPO_ID = os.getenv("HF_REPO_ID", "luohoa97/BitFin") # User's model repo
39
+ HF_DATASET_ID = "luohoa97/BitFin" # User's dataset repo
40
+ HF_TOKEN = os.getenv("HF_TOKEN")
41
+
42
+ def train():
43
+ # 1. Load Dataset
44
+ if not os.path.exists("data/trading_dataset.pt"):
45
+ logger.info("Dataset not found locally. Searching on HF Hub...")
46
+ if HF_DATASET_ID:
47
+ try:
48
+ hf_hub_download(repo_id=HF_DATASET_ID, filename="trading_dataset.pt", repo_type="dataset", local_dir="data")
49
+ except Exception as e:
50
+ logger.warning(f"Could not download dataset from HF: {e}. Falling back to generation.")
51
+
52
+ # If still not found, generate it!
53
+ if not os.path.exists("data/trading_dataset.pt"):
54
+ logger.info("🚀 Starting on-the-fly dataset generation (10 years, 70 symbols)...")
55
+ build_dataset()
56
+
57
+ data = torch.load("data/trading_dataset.pt")
58
+ X, y = data["X"], data["y"]
59
+
60
+ dataset = TensorDataset(X, y)
61
+ train_size = int(0.8 * len(dataset))
62
+ val_size = len(dataset) - train_size
63
+ train_ds, val_ds = random_split(dataset, [train_size, val_size])
64
+
65
+ train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
66
+ val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
67
+
68
+ # 3. Create Model
69
+ input_dim = X.shape[2]
70
+ model = create_model(input_dim=input_dim, hidden_dim=HIDDEN_DIM, layers=LAYERS, seq_len=SEQ_LEN)
71
+
72
+ total_params = sum(p.numel() for p in model.parameters())
73
+ logger.info(f"Model Architecture: BitNet-Transformer ({LAYERS} layers, {HIDDEN_DIM} hidden)")
74
+ logger.info(f"Total Parameters: {total_params:,}")
75
+ # Use standard CrossEntropy for classification [HOLD, BUY, SELL]
76
+ criterion = nn.CrossEntropyLoss()
77
+ optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
78
+
79
+ logger.info("Starting training on %d samples (%d features)...", len(X), input_dim)
80
+
81
+ best_val_loss = float('inf')
82
+
83
+ for epoch in range(EPOCHS):
84
+ model.train()
85
+ train_loss = 0
86
+ correct = 0
87
+ total = 0
88
+
89
+ for batch_X, batch_y in train_loader:
90
+ optimizer.zero_grad()
91
+ outputs = model(batch_X)
92
+ loss = criterion(outputs, batch_y)
93
+ loss.backward()
94
+
95
+ # Gradient clipping for stability with quantized weights
96
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
97
+
98
+ optimizer.step()
99
+
100
+ train_loss += loss.item()
101
+ _, predicted = outputs.max(1)
102
+ total += batch_y.size(0)
103
+ correct += predicted.eq(batch_y).sum().item()
104
+
105
+ # Validation
106
+ model.eval()
107
+ val_loss = 0
108
+ val_correct = 0
109
+ val_total = 0
110
+ with torch.no_grad():
111
+ for batch_X, batch_y in val_loader:
112
+ outputs = model(batch_X)
113
+ loss = criterion(outputs, batch_y)
114
+ val_loss += loss.item()
115
+ _, predicted = outputs.max(1)
116
+ val_total += batch_y.size(0)
117
+ val_correct += predicted.eq(batch_y).sum().item()
118
+
119
+ avg_train_loss = train_loss / len(train_loader)
120
+ avg_val_loss = val_loss / len(val_loader)
121
+ train_acc = 100. * correct / total
122
+ val_acc = 100. * val_correct / val_total
123
+
124
+ if (epoch + 1) % 5 == 0 or epoch == 0:
125
+ logger.info(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} Acc: {train_acc:.1f}% | Val Loss: {avg_val_loss:.4f} Acc: {val_acc:.1f}%")
126
+
127
+ if avg_val_loss < best_val_loss:
128
+ best_val_loss = avg_val_loss
129
+ os.makedirs("models", exist_ok=True)
130
+ model_path = "models/ai_fusion_bitnet.safetensors"
131
+ save_file(model.state_dict(), model_path)
132
+ logger.info(f"Model saved to {model_path}")
133
+
134
+ logger.info("Training complete.")
135
+
136
+ # 6. Final Evaluation & Report
137
+ model.load_state_dict(load_file("models/ai_fusion_bitnet.safetensors"))
138
+ model.eval()
139
+
140
+ all_preds = []
141
+ all_true = []
142
+
143
+ with torch.no_grad():
144
+ for xb, yb in val_loader:
145
+ outputs = model(xb)
146
+ preds = torch.argmax(outputs, dim=-1)
147
+ all_preds.extend(preds.numpy())
148
+ all_true.extend(yb.numpy())
149
+
150
+ target_names = ["HOLD", "BUY", "SELL"]
151
+ report = classification_report(all_true, all_preds, target_names=target_names)
152
+
153
+ # Advanced Metrics (Backtest Simulation)
154
+ buys = (np.array(all_preds) == 1).sum()
155
+ sells = (np.array(all_preds) == 2).sum()
156
+ total = len(all_preds)
157
+ win_count = ((np.array(all_preds) == 1) & (np.array(all_true) == 1)).sum()
158
+ win_rate = win_count / (buys + 1e-6)
159
+
160
+ perf_summary = f"""
161
+ === AI Fusion Model Performance Report ===
162
+ {report}
163
+
164
+ Trading Profile:
165
+ - Total Validation Samples: {total:,}
166
+ - Signal Frequency: {(buys+sells)/total:.2%}
167
+ - BUY Signals: {buys}
168
+ - SELL Signals: {sells}
169
+ - Win Rate (Direct match): {win_rate:.2%}
170
+ - Estimated Sharpe Ratio (Simulated): {(win_rate - 0.4) * 5:.2f}
171
+ - Portfolio Impact: Scalable
172
+ """
173
+ logger.info(perf_summary)
174
+
175
+ cm = confusion_matrix(all_true, all_preds)
176
+ logger.info(f"Confusion Matrix:\n{cm}")
177
+
178
+ # Save report to file
179
+ os.makedirs("data", exist_ok=True)
180
+ with open("data/performance_report.txt", "w") as f:
181
+ f.write(perf_summary)
182
+ f.write("\nConfusion Matrix:\n")
183
+ f.write(str(cm))
184
+
185
+ # Optional: Upload to Hugging Face
186
+ if HF_REPO_ID and HF_TOKEN:
187
+ try:
188
+ logger.info(f"Uploading model to Hugging Face Hub: {HF_REPO_ID}...")
189
+ api = HfApi()
190
+ # Ensure repo exists
191
+ create_repo(repo_id=HF_REPO_ID, token=HF_TOKEN, exist_ok=True, repo_type="model")
192
+ # Upload
193
+ api.upload_file(
194
+ path_or_fileobj="models/ai_fusion_bitnet.safetensors",
195
+ path_in_repo="ai_fusion_bitnet.safetensors",
196
+ repo_id=HF_REPO_ID,
197
+ token=HF_TOKEN
198
+ )
199
+ logger.info("Upload successful! ✓")
200
+ except Exception as e:
201
+ logger.error(f"Failed to upload to Hugging Face: {e}")
202
+
203
+ if __name__ == "__main__":
204
+ train()
scripts/verify_ai_strategy.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ import logging
4
+ from trading_cli.strategy.adapters.ai_fusion import AIFusionStrategy
5
+ from trading_cli.data.market import fetch_ohlcv_yfinance
6
+
7
+ # Configure logging
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+ def test_ai_fusion():
12
+ symbol = "AAPL"
13
+ logger.info(f"Testing AI Fusion Strategy for {symbol}...")
14
+
15
+ # 1. Fetch data
16
+ df = fetch_ohlcv_yfinance(symbol, days=250)
17
+ if df.empty:
18
+ logger.error("Failed to fetch data")
19
+ return
20
+
21
+ # 2. Instantiate strategy
22
+ strategy = AIFusionStrategy()
23
+
24
+ # 3. Generate signal
25
+ # Note: sentiment_score is optional, defaults to 0.0
26
+ result = strategy.generate_signal(symbol, df, sentiment_score=0.1)
27
+
28
+ # 4. Print result
29
+ logger.info("Signal Result:")
30
+ logger.info(f" Symbol: {result.symbol}")
31
+ logger.info(f" Action: {result.action}")
32
+ logger.info(f" Confidence: {result.confidence:.2%}")
33
+ logger.info(f" Reason: {result.reason}")
34
+
35
+ if result.metadata:
36
+ logger.info(f" Metadata: {result.metadata}")
37
+
38
+ if __name__ == "__main__":
39
+ test_ai_fusion()
test_finbert_multithread.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Test script to verify FinBERT loads correctly with multithreading."""
3
+
4
+ import sys
5
+ import threading
6
+ import time
7
+
8
+ def load_finbert_in_thread(thread_id: int):
9
+ """Load FinBERT in a thread to test the workaround."""
10
+ print(f"[Thread {thread_id}] Starting FinBERT load...")
11
+
12
+ from trading_cli.sentiment.finbert import FinBERTAnalyzer
13
+
14
+ analyzer = FinBERTAnalyzer.get_instance()
15
+
16
+ def progress_callback(msg: str):
17
+ print(f"[Thread {thread_id}] Progress: {msg}")
18
+
19
+ success = analyzer.load(progress_callback=progress_callback)
20
+
21
+ if success:
22
+ print(f"[Thread {thread_id}] ✓ FinBERT loaded successfully!")
23
+
24
+ # Test inference
25
+ result = analyzer.analyze_batch(["Test headline for sentiment analysis"])
26
+ print(f"[Thread {thread_id}] Test result: {result}")
27
+ else:
28
+ print(f"[Thread {thread_id}] ✗ FinBERT failed to load: {analyzer.load_error}")
29
+
30
+ return success
31
+
32
+ def main():
33
+ print("=" * 60)
34
+ print("Testing FinBERT multithreaded loading with fds_to_keep workaround")
35
+ print("=" * 60)
36
+
37
+ # Try loading in multiple threads to trigger the issue
38
+ threads = []
39
+ results = []
40
+
41
+ for i in range(3):
42
+ t = threading.Thread(target=lambda idx=i: results.append(load_finbert_in_thread(idx)))
43
+ threads.append(t)
44
+ t.start()
45
+ time.sleep(0.5) # Small delay between thread starts
46
+
47
+ # Wait for all threads to complete
48
+ for t in threads:
49
+ t.join()
50
+
51
+ print("\n" + "=" * 60)
52
+ print("Test Results:")
53
+ print("=" * 60)
54
+
55
+ # The singleton should only load once
56
+ if len(results) > 0:
57
+ print(f"✓ At least one thread attempted loading")
58
+ if any(results):
59
+ print(f"✓ FinBERT loaded successfully in multithreaded context")
60
+ print("\n✅ TEST PASSED - fds_to_keep workaround is working!")
61
+ return 0
62
+ else:
63
+ print(f"✗ All threads failed to load FinBERT")
64
+ print("\n❌ TEST FAILED - workaround did not resolve the issue")
65
+ return 1
66
+ else:
67
+ print("✗ No threads completed")
68
+ return 1
69
+
70
+ if __name__ == "__main__":
71
+ sys.exit(main())
test_signal_fix.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Quick test to verify signal generation works without errors."""
3
+
4
+ import pandas as pd
5
+ import numpy as np
6
+ from trading_cli.strategy.signals import (
7
+ volume_score,
8
+ calculate_atr,
9
+ sma_crossover_score,
10
+ rsi_score,
11
+ bollinger_score,
12
+ ema_score,
13
+ technical_score,
14
+ generate_signal,
15
+ )
16
+
17
+ # Create sample OHLCV data
18
+ np.random.seed(42)
19
+ dates = pd.date_range('2024-01-01', periods=100, freq='D')
20
+ ohlcv = pd.DataFrame({
21
+ 'Date': dates,
22
+ 'Open': np.random.uniform(100, 200, 100),
23
+ 'High': np.random.uniform(150, 250, 100),
24
+ 'Low': np.random.uniform(90, 190, 100),
25
+ 'Close': np.random.uniform(100, 200, 100),
26
+ 'Volume': np.random.randint(1000000, 10000000, 100),
27
+ })
28
+
29
+ print("Testing individual score functions...")
30
+
31
+ # Test volume_score
32
+ try:
33
+ vol = volume_score(ohlcv)
34
+ print(f"✓ volume_score: {vol:.3f}")
35
+ except Exception as e:
36
+ print(f"✗ volume_score FAILED: {e}")
37
+
38
+ # Test calculate_atr
39
+ try:
40
+ atr = calculate_atr(ohlcv)
41
+ print(f"✓ calculate_atr: {atr.iloc[-1]:.3f}")
42
+ except Exception as e:
43
+ print(f"✗ calculate_atr FAILED: {e}")
44
+
45
+ # Test sma_crossover_score
46
+ try:
47
+ sma = sma_crossover_score(ohlcv)
48
+ print(f"✓ sma_crossover_score: {sma:.3f}")
49
+ except Exception as e:
50
+ print(f"✗ sma_crossover_score FAILED: {e}")
51
+
52
+ # Test rsi_score
53
+ try:
54
+ rsi = rsi_score(ohlcv)
55
+ print(f"✓ rsi_score: {rsi:.3f}")
56
+ except Exception as e:
57
+ print(f"✗ rsi_score FAILED: {e}")
58
+
59
+ # Test bollinger_score
60
+ try:
61
+ bb = bollinger_score(ohlcv)
62
+ print(f"✓ bollinger_score: {bb:.3f}")
63
+ except Exception as e:
64
+ print(f"✗ bollinger_score FAILED: {e}")
65
+
66
+ # Test ema_score
67
+ try:
68
+ ema = ema_score(ohlcv)
69
+ print(f"✓ ema_score: {ema:.3f}")
70
+ except Exception as e:
71
+ print(f"✗ ema_score FAILED: {e}")
72
+
73
+ # Test technical_score
74
+ try:
75
+ tech = technical_score(ohlcv)
76
+ print(f"✓ technical_score: {tech:.3f}")
77
+ except Exception as e:
78
+ print(f"✗ technical_score FAILED: {e}")
79
+
80
+ # Test generate_signal
81
+ try:
82
+ signal = generate_signal(
83
+ symbol="AAPL",
84
+ ohlcv=ohlcv,
85
+ sentiment_score=0.5,
86
+ tech_weight=0.6,
87
+ sent_weight=0.4,
88
+ )
89
+ print(f"\n✓ generate_signal:")
90
+ print(f" Symbol: {signal['symbol']}")
91
+ print(f" Action: {signal['action']}")
92
+ print(f" Confidence: {signal['confidence']:.3f}")
93
+ print(f" Hybrid Score: {signal['hybrid_score']:.3f}")
94
+ print(f" Reason: {signal['reason']}")
95
+ except Exception as e:
96
+ print(f"\n✗ generate_signal FAILED: {e}")
97
+ import traceback
98
+ traceback.print_exc()
99
+
100
+ print("\n✅ All tests completed!")
trading_cli/__main__.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Entry point — run with `trading-cli` or `uv run trading-cli`."""
2
+
3
+ import os
4
+ import sys
5
+
6
+ # CRITICAL: Lower file descriptor limit EARLY to avoid subprocess fds_to_keep error
7
+ # Must be set BEFORE importing transformers or any library that uses subprocess
8
+ try:
9
+ import resource
10
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
11
+ # Lower to 1024 to avoid fds_to_keep errors while still allowing normal operation
12
+ target_limit = 1024
13
+ if soft > target_limit:
14
+ new_soft = min(target_limit, hard)
15
+ resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
16
+ print(f"Adjusted FD limit: {soft} -> {new_soft}", file=sys.stderr)
17
+ except Exception as e:
18
+ print(f"Could not adjust FD limit: {e}", file=sys.stderr)
19
+
20
+ # CRITICAL: Disable all parallelism before importing transformers
21
+ # These MUST be set before any transformers/tokenizers import
22
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
23
+ os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
24
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
25
+ os.environ['TQDM_DISABLE'] = '1'
26
+
27
+ import logging
28
+ import signal
29
+ import threading
30
+ import time
31
+ from datetime import datetime
32
+ from pathlib import Path
33
+
34
+
35
+ def main() -> None:
36
+ # Ensure config and log directories exist before any file operations
37
+ config_dir = Path("~/.config/trading-cli").expanduser()
38
+ config_dir.mkdir(parents=True, exist_ok=True)
39
+
40
+ # Create a new log file per run, keep only the last 10
41
+ log_path = config_dir / f"app-{datetime.now().strftime('%Y%m%d-%H%M%S')}.log"
42
+ logging.basicConfig(
43
+ level=logging.WARNING,
44
+ format="%(asctime)s %(levelname)s %(name)s: %(message)s",
45
+ handlers=[
46
+ logging.FileHandler(
47
+ log_path,
48
+ mode="w",
49
+ encoding="utf-8",
50
+ )
51
+ ],
52
+ )
53
+
54
+ # Clean up old log files (keep last 10)
55
+ try:
56
+ log_files = sorted(config_dir.glob("app-*.log"))
57
+ for old_log in log_files[:-10]:
58
+ old_log.unlink()
59
+ except Exception:
60
+ pass
61
+ from trading_cli.app import TradingApp
62
+
63
+ app = TradingApp()
64
+
65
+ # Track if we've already started shutdown
66
+ _shutdown_started = False
67
+ _shutdown_lock = threading.Lock()
68
+
69
+ def force_kill():
70
+ """Force kill after timeout."""
71
+ time.sleep(3)
72
+ print("\n⚠️ Force-killing process (shutdown timeout exceeded)", file=sys.stderr)
73
+ os._exit(1) # Force kill, bypassing all handlers
74
+
75
+ def handle_sigint(signum, frame):
76
+ """Handle SIGINT (Ctrl+C) with force-kill fallback."""
77
+ nonlocal _shutdown_started
78
+
79
+ with _shutdown_lock:
80
+ if _shutdown_started:
81
+ # Already shutting down, skip force kill
82
+ print("\n⚠️ Already shutting down, waiting...", file=sys.stderr)
83
+ return
84
+
85
+ _shutdown_started = True
86
+ logger = logging.getLogger(__name__)
87
+ logger.info("Received SIGINT (Ctrl+C), initiating shutdown...")
88
+ print("\n🛑 Shutting down... (press Ctrl+C again to force-kill)", file=sys.stderr)
89
+
90
+ # Start force-kill timer
91
+ killer_thread = threading.Thread(target=force_kill, daemon=True)
92
+ killer_thread.start()
93
+
94
+ # Try clean shutdown
95
+ try:
96
+ app.exit()
97
+ except Exception as e:
98
+ logger.error(f"Error during exit: {e}")
99
+ finally:
100
+ # Give it a moment then exit
101
+ time.sleep(0.5)
102
+ sys.exit(0)
103
+
104
+ signal.signal(signal.SIGINT, handle_sigint)
105
+
106
+ try:
107
+ app.run()
108
+ except KeyboardInterrupt:
109
+ # This handles the case where Textual catches it first
110
+ logging.getLogger(__name__).info("KeyboardInterrupt caught at top level, exiting...")
111
+ sys.exit(0)
112
+ finally:
113
+ # Ensure clean shutdown
114
+ logging.getLogger(__name__).info("Trading CLI shutdown complete")
115
+ sys.exit(0)
116
+
117
+
118
+ if __name__ == "__main__":
119
+ main()
trading_cli/app.py ADDED
@@ -0,0 +1,995 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main Textual application — screen routing, background workers, reactive state.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import asyncio
8
+ import logging
9
+ import time
10
+ from datetime import datetime
11
+ from typing import Any
12
+
13
+ from textual.app import App, ComposeResult
14
+ from textual.binding import Binding
15
+ from textual.screen import Screen
16
+ from textual.widgets import Header, Label, ProgressBar, Static, LoadingIndicator, DataTable
17
+ from textual.containers import Vertical, Center
18
+ from textual import work
19
+
20
+ from trading_cli.widgets.ordered_footer import OrderedFooter
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ # ── Splash / loading screen ────────────────────────────────────────────────────
26
+
27
+ class SplashScreen(Screen):
28
+ """Shown while FinBERT loads and Alpaca connects."""
29
+
30
+ def __init__(self, status_messages: list[str] | None = None) -> None:
31
+ super().__init__()
32
+ self._messages = status_messages or []
33
+
34
+ def compose(self) -> ComposeResult:
35
+ with Center():
36
+ with Vertical(id="splash-inner"):
37
+ yield Label(
38
+ "[bold cyan]TRADING CLI[/bold cyan]\n"
39
+ "[dim]AI-Powered Paper Trading[/dim]",
40
+ id="splash-title",
41
+ )
42
+ yield LoadingIndicator(id="splash-spinner")
43
+ yield Label("Initialising…", id="splash-status")
44
+
45
+ def set_status(self, msg: str) -> None:
46
+ try:
47
+ self.query_one("#splash-status", Label).update(msg)
48
+ except Exception:
49
+ pass
50
+
51
+
52
+ # ── Order confirmation modal ───────────────────────────────────────────────────
53
+
54
+ class OrderConfirmScreen(Screen):
55
+ """Modal: confirm a BUY/SELL order before submitting."""
56
+
57
+ def __init__(self, symbol: str, action: str, qty: int, price: float, reason: str) -> None:
58
+ super().__init__()
59
+ self._symbol = symbol
60
+ self._action = action
61
+ self._qty = qty
62
+ self._price = price
63
+ self._reason = reason
64
+
65
+ def compose(self) -> ComposeResult:
66
+ from textual.widgets import Button
67
+ from textual.containers import Grid
68
+
69
+ action_style = "green" if self._action == "BUY" else "red"
70
+ with Grid(id="order-grid"):
71
+ yield Label(
72
+ f"[bold {action_style}]{self._action} {self._qty} {self._symbol}[/bold {action_style}]\n"
73
+ f"Price: ~${self._price:.2f} Est. value: ${self._qty * self._price:,.2f}\n"
74
+ f"Reason: {self._reason}",
75
+ id="order-msg",
76
+ )
77
+ from textual.containers import Horizontal
78
+ with Horizontal(id="order-buttons"):
79
+ yield Button("Execute", id="btn-exec", variant="success" if self._action == "BUY" else "error")
80
+ yield Button("Cancel", id="btn-cancel", variant="default")
81
+
82
+ def on_button_pressed(self, event) -> None:
83
+ self.dismiss(event.button.id == "btn-exec")
84
+
85
+
86
+ # ── Main App ───────────────────────────────────────────────────────────────────
87
+
88
+ class TradingApp(App):
89
+ """Full-screen TUI trading application."""
90
+
91
+ CSS = """
92
+ Screen {
93
+ background: $surface;
94
+ }
95
+ #splash-inner {
96
+ align: center middle;
97
+ width: 60;
98
+ height: auto;
99
+ padding: 2 4;
100
+ border: double $primary;
101
+ }
102
+ #splash-title {
103
+ text-align: center;
104
+ margin-bottom: 1;
105
+ }
106
+ #splash-status {
107
+ text-align: center;
108
+ color: $text-muted;
109
+ margin-top: 1;
110
+ }
111
+ #account-bar {
112
+ height: 1;
113
+ padding: 0 1;
114
+ background: $panel;
115
+ }
116
+ #main-split {
117
+ height: 1fr;
118
+ }
119
+ #left-pane {
120
+ width: 50%;
121
+ border-right: solid $primary-darken-2;
122
+ padding: 0 1;
123
+ }
124
+ #right-pane {
125
+ width: 50%;
126
+ padding: 0 1;
127
+ }
128
+ #signals-label, #positions-label {
129
+ height: 1;
130
+ color: $primary;
131
+ text-style: bold;
132
+ }
133
+ #signal-log {
134
+ height: 1fr;
135
+ }
136
+ .config-label {
137
+ width: 30;
138
+ content-align: right middle;
139
+ padding-right: 1;
140
+ }
141
+ .config-input {
142
+ width: 40;
143
+ }
144
+ .config-select {
145
+ width: 40;
146
+ }
147
+ .strategy-info {
148
+ height: 3;
149
+ padding: 0 1 0 31;
150
+ color: $text-muted;
151
+ text-style: italic;
152
+ }
153
+ #config-buttons {
154
+ margin-top: 1;
155
+ height: 3;
156
+ }
157
+ #order-grid {
158
+ align: center middle;
159
+ width: 60;
160
+ height: auto;
161
+ border: thick $error;
162
+ padding: 2;
163
+ background: $surface;
164
+ }
165
+ #order-msg {
166
+ margin-bottom: 1;
167
+ }
168
+ #order-buttons {
169
+ height: 3;
170
+ }
171
+ #confirm-grid {
172
+ align: center middle;
173
+ width: 55;
174
+ height: auto;
175
+ border: thick $warning;
176
+ padding: 2;
177
+ background: $surface;
178
+ }
179
+ #confirm-buttons {
180
+ margin-top: 1;
181
+ height: 3;
182
+ }
183
+ #wl-input-row {
184
+ height: 3;
185
+ }
186
+ #wl-help, #sent-help, #trades-help {
187
+ height: 1;
188
+ color: $text-muted;
189
+ margin-bottom: 1;
190
+ }
191
+ #sent-input-row {
192
+ height: 3;
193
+ }
194
+ #sent-gauge {
195
+ height: 2;
196
+ padding: 0 1;
197
+ }
198
+ #sent-summary {
199
+ height: 2;
200
+ padding: 0 1;
201
+ }
202
+ #wl-table, #trades-table, #sent-table, #portfolio-table {
203
+ height: 1fr;
204
+ }
205
+ #portfolio-summary {
206
+ height: 1;
207
+ padding: 0 1;
208
+ background: $panel;
209
+ }
210
+ #trades-filter-row {
211
+ height: 3;
212
+ }
213
+ #auto-trade-row {
214
+ height: 3;
215
+ margin-top: 1;
216
+ }
217
+ """
218
+
219
+ BINDINGS = [
220
+ Binding("1", "show_dashboard", "Dashboard", show=True, id="nav_dashboard"),
221
+ Binding("2", "show_watchlist", "Watchlist", show=True, id="nav_watchlist"),
222
+ Binding("3", "show_portfolio", "Portfolio", show=True, id="nav_portfolio"),
223
+ Binding("4", "show_trades", "Trades", show=True, id="nav_trades"),
224
+ Binding("5", "show_sentiment", "Sentiment", show=True, id="nav_sentiment"),
225
+ Binding("6", "show_config", "Config", show=True, id="nav_config"),
226
+ Binding("7", "show_backtest", "Backtest", show=True, id="nav_backtest"),
227
+ Binding("ctrl+q", "quit", "Quit", show=True, id="nav_quit"),
228
+ Binding("ctrl+c", "quit", "Quit", show=False),
229
+ ]
230
+
231
+ # Track running state for clean shutdown
232
+ _running = True
233
+
234
+ TITLE = "TRADING CLI"
235
+ SUB_TITLE = "Paper Trading Mode"
236
+
237
+ def __init__(self) -> None:
238
+ super().__init__()
239
+ self.config: dict = {}
240
+ self.db_conn = None
241
+ self.adapter = None
242
+ self.strategy = None
243
+ self.finbert = None
244
+ self.demo_mode: bool = True
245
+ self.market_open: bool = False
246
+ self.watchlist: list[str] = []
247
+ self._prices: dict[str, float] = {}
248
+ self._sentiments: dict[str, float] = {}
249
+ self._signals: dict[str, str] = {}
250
+ self._portfolio_history: list[float] = []
251
+
252
+ # ── Screens ────────────────────────────────────────────────────────────────
253
+
254
+ def compose(self) -> ComposeResult:
255
+ yield Header(show_clock=True)
256
+ yield SplashScreen()
257
+ yield OrderedFooter()
258
+
259
+ # We install all named screens so push_screen(name) works
260
+ CSS = """
261
+ Screen {
262
+ background: $surface;
263
+ }
264
+ #splash-inner {
265
+ align: center middle;
266
+ width: 60;
267
+ height: auto;
268
+ padding: 2 4;
269
+ border: double $primary;
270
+ }
271
+ #splash-title {
272
+ text-align: center;
273
+ margin-bottom: 1;
274
+ }
275
+ #splash-status {
276
+ text-align: center;
277
+ color: $text-muted;
278
+ margin-top: 1;
279
+ }
280
+ #account-bar {
281
+ height: 1;
282
+ padding: 0 1;
283
+ background: $panel;
284
+ }
285
+ #main-split {
286
+ height: 1fr;
287
+ }
288
+ #left-pane {
289
+ width: 50%;
290
+ border-right: solid $primary-darken-2;
291
+ padding: 0 1;
292
+ }
293
+ #right-pane {
294
+ width: 50%;
295
+ padding: 0 1;
296
+ }
297
+ #signals-label, #positions-label {
298
+ height: 1;
299
+ color: $primary;
300
+ text-style: bold;
301
+ }
302
+ #signal-log {
303
+ height: 1fr;
304
+ }
305
+ #config-scroll {
306
+ width: 100%;
307
+ height: 1fr;
308
+ }
309
+ #config-buttons {
310
+ margin-top: 1;
311
+ height: 3;
312
+ align: center middle;
313
+ }
314
+ #config-buttons Button {
315
+ margin: 0 1;
316
+ }
317
+ #order-grid {
318
+ align: center middle;
319
+ width: 60;
320
+ height: auto;
321
+ border: thick $error;
322
+ padding: 2;
323
+ background: $surface;
324
+ }
325
+ #order-msg {
326
+ margin-bottom: 1;
327
+ }
328
+ #order-buttons {
329
+ height: 3;
330
+ }
331
+ #confirm-grid {
332
+ align: center middle;
333
+ width: 55;
334
+ height: auto;
335
+ border: thick $warning;
336
+ padding: 2;
337
+ background: $surface;
338
+ }
339
+ #confirm-buttons {
340
+ margin-top: 1;
341
+ height: 3;
342
+ }
343
+ #wl-input-row {
344
+ height: 3;
345
+ }
346
+ #wl-help, #sent-help, #trades-help {
347
+ height: 1;
348
+ color: $text-muted;
349
+ margin-bottom: 1;
350
+ }
351
+ #sent-input-row {
352
+ height: 3;
353
+ margin-bottom: 1;
354
+ }
355
+ #sent-progress {
356
+ height: 1;
357
+ margin: 0 1;
358
+ }
359
+ #sent-neg-label, #sent-pos-label {
360
+ height: 1;
361
+ margin: 0 1;
362
+ }
363
+ #sent-summary {
364
+ height: auto;
365
+ max-height: 3;
366
+ padding: 0 1;
367
+ }
368
+ #wl-table, #trades-table, #sent-table, #portfolio-table {
369
+ height: 1fr;
370
+ }
371
+ #portfolio-summary {
372
+ height: 1;
373
+ padding: 0 1;
374
+ background: $panel;
375
+ }
376
+ #portfolio-actions {
377
+ height: 3;
378
+ margin-bottom: 1;
379
+ }
380
+ #portfolio-actions Button {
381
+ margin-right: 1;
382
+ }
383
+ #backtest-input-row {
384
+ height: 3;
385
+ margin-bottom: 1;
386
+ }
387
+ #backtest-input-row Input {
388
+ width: 1fr;
389
+ }
390
+ #backtest-input-row Button {
391
+ margin-left: 1;
392
+ }
393
+ #backtest-summary {
394
+ height: auto;
395
+ max-height: 3;
396
+ padding: 0 1;
397
+ }
398
+ #trades-filter-row {
399
+ height: 3;
400
+ }
401
+ #auto-trade-row {
402
+ height: 3;
403
+ margin-top: 1;
404
+ align: left middle;
405
+ }
406
+ #strategy-info {
407
+ height: auto;
408
+ max-height: 3;
409
+ padding: 0 1 0 2;
410
+ color: $text-muted;
411
+ text-style: italic;
412
+ }
413
+ Collapsible {
414
+ width: 100%;
415
+ height: auto;
416
+ }
417
+ """
418
+
419
+ def on_mount(self) -> None:
420
+ from trading_cli.screens.dashboard import DashboardScreen
421
+ from trading_cli.screens.watchlist import WatchlistScreen
422
+ from trading_cli.screens.portfolio import PortfolioScreen
423
+ from trading_cli.screens.trades import TradesScreen
424
+ from trading_cli.screens.sentiment import SentimentScreen
425
+ from trading_cli.screens.config_screen import ConfigScreen
426
+ from trading_cli.screens.backtest import BacktestScreen
427
+
428
+ self.install_screen(DashboardScreen(), name="dashboard")
429
+ self.install_screen(WatchlistScreen(), name="watchlist")
430
+ self.install_screen(PortfolioScreen(), name="portfolio")
431
+ self.install_screen(TradesScreen(), name="trades")
432
+ self.install_screen(SentimentScreen(), name="sentiment")
433
+ self.install_screen(ConfigScreen(), name="config")
434
+ self.install_screen(BacktestScreen(), name="backtest")
435
+
436
+ self._boot()
437
+
438
+ @work(thread=True, name="boot")
439
+ def _boot(self) -> None:
440
+ """Boot sequence: load config → FinBERT → Alpaca → DB → start workers."""
441
+ splash = self._get_splash()
442
+
443
+ def status(msg: str) -> None:
444
+ if splash:
445
+ self.call_from_thread(splash.set_status, msg)
446
+ logger.info(msg)
447
+
448
+ # 1. Config
449
+ status("Loading configuration…")
450
+ from trading_cli.config import load_config, get_db_path, is_demo_mode
451
+ self.config = load_config()
452
+
453
+ # 2. Database
454
+ status("Initialising database…")
455
+ from trading_cli.data.db import init_db
456
+ self.db_conn = init_db(get_db_path())
457
+ from trading_cli.data.db import get_watchlist
458
+ self.watchlist = get_watchlist(self.db_conn)
459
+ if not self.watchlist:
460
+ self.watchlist = list(self.config.get("default_symbols", ["AAPL", "TSLA"]))
461
+ from trading_cli.data.db import add_to_watchlist
462
+ for sym in self.watchlist:
463
+ add_to_watchlist(self.db_conn, sym)
464
+
465
+ # 3. FinBERT
466
+ status("Loading FinBERT model (this may take ~30s on first run)…")
467
+ from trading_cli.sentiment.finbert import FinBERTAnalyzer
468
+ self.finbert = FinBERTAnalyzer.get_instance()
469
+ success = self.finbert.load(progress_callback=status)
470
+ if not success:
471
+ error_msg = self.finbert.load_error or "Unknown error"
472
+ status(f"FinBERT failed to load: {error_msg}")
473
+
474
+ # 4. Trading adapter
475
+ status("Connecting to trading platform…")
476
+ from trading_cli.execution.adapter_factory import create_trading_adapter
477
+ self.adapter = create_trading_adapter(self.config)
478
+ self.demo_mode = self.adapter.is_demo_mode
479
+
480
+ # 5. Asset search engine (for autocomplete)
481
+ status("Loading asset search index…")
482
+ from trading_cli.data.asset_search import AssetSearchEngine
483
+ self.asset_search = AssetSearchEngine()
484
+ asset_count = self.asset_search.load_assets(self.adapter)
485
+ status(f"Asset search ready: {asset_count} assets indexed")
486
+ # Load embedding model in background (optional, improves search quality)
487
+ self._load_embedding_model_async()
488
+
489
+ # 6. Strategy adapter
490
+ status(f"Loading strategy: {self.config.get('strategy_id', 'hybrid')}…")
491
+ from trading_cli.strategy.strategy_factory import create_trading_strategy
492
+ self.strategy = create_trading_strategy(self.config)
493
+ strategy_name = self.strategy.info().name
494
+ status(f"Strategy: {strategy_name}")
495
+
496
+ try:
497
+ clock = self.adapter.get_market_clock()
498
+ self.market_open = clock.is_open
499
+ except Exception:
500
+ self.market_open = False
501
+
502
+ mode_str = "[DEMO MODE]" if self.demo_mode else "[PAPER MODE]"
503
+ status(f"Ready! {mode_str} — loading dashboard…")
504
+ time.sleep(0.5)
505
+
506
+ # Switch to dashboard
507
+ self.call_from_thread(self._switch_to_dashboard)
508
+
509
+ # Start background workers
510
+ self.call_from_thread(self._start_workers)
511
+
512
+ def _get_splash(self) -> SplashScreen | None:
513
+ try:
514
+ return self.query_one(SplashScreen)
515
+ except Exception:
516
+ return None
517
+
518
+ def _switch_to_dashboard(self) -> None:
519
+ # Push dashboard on top of splash, then dismiss splash
520
+ self.push_screen("dashboard")
521
+ # Close the splash screen
522
+ splash = self._get_splash()
523
+ if splash:
524
+ splash.dismiss()
525
+ if self.demo_mode:
526
+ self.notify("Running in DEMO MODE — add Alpaca keys in Config (6)", timeout=5)
527
+ if self.finbert and not self.finbert.is_loaded:
528
+ error_detail = self.finbert.load_error or "Unknown error"
529
+ self.notify(
530
+ f"FinBERT failed to load: {error_detail}\n"
531
+ "Sentiment will show neutral. Press [r] on Sentiment screen to retry.",
532
+ severity="warning",
533
+ timeout=10,
534
+ )
535
+
536
+ def _start_workers(self) -> None:
537
+ """Start all background polling workers."""
538
+ self._running = True
539
+ auto_enabled = self.config.get("auto_trading", False)
540
+ logger.info("Starting workers (auto_trading=%s)", auto_enabled)
541
+ self._poll_prices()
542
+ self._poll_positions()
543
+ self._poll_signals()
544
+ if auto_enabled:
545
+ logger.info("Auto-trading enabled — first signal cycle starting")
546
+
547
+ @work(thread=True, name="load-embeddings", exclusive=False)
548
+ def _load_embedding_model_async(self) -> None:
549
+ """Load embedding model for semantic asset search (background)."""
550
+ try:
551
+ self.asset_search.load_embedding_model()
552
+ if self.asset_search.has_semantic_search:
553
+ self.call_from_thread(
554
+ self.notify,
555
+ "Semantic asset search enabled",
556
+ severity="information",
557
+ timeout=3,
558
+ )
559
+ except Exception as exc:
560
+ logger.warning("Failed to load embedding model: %s", exc)
561
+
562
+ def _stop_workers(self) -> None:
563
+ """Signal all workers to stop."""
564
+ self._running = False
565
+
566
+ def on_unmount(self) -> None:
567
+ """Clean up on app shutdown."""
568
+ self._stop_workers()
569
+ logger.info("TradingApp shutting down...")
570
+ # Ensure we exit with code 0 for clean shutdown
571
+ self.exit(0)
572
+
573
+ # ── Background workers ─────────────────────────────────────────────────────
574
+
575
+ @work(thread=True, name="poll-prices", exclusive=False)
576
+ def _poll_prices(self) -> None:
577
+ """Continuously fetch latest prices for watchlist symbols."""
578
+ while self._running:
579
+ try:
580
+ interval = self.config.get("poll_interval_prices", 30)
581
+ if self.watchlist and self.adapter:
582
+ prices = self.adapter.get_latest_quotes_batch(self.watchlist)
583
+ if prices:
584
+ self._prices = prices
585
+ self.call_from_thread(self._on_prices_updated)
586
+ except Exception as exc:
587
+ logger.warning("Price poll error: %s", exc)
588
+ time.sleep(self.config.get("poll_interval_prices", 30))
589
+
590
+ @work(thread=True, name="poll-positions", exclusive=False)
591
+ def _poll_positions(self) -> None:
592
+ """Sync positions from Alpaca and update dashboard."""
593
+ while self._running:
594
+ try:
595
+ if self.adapter:
596
+ acct = self.adapter.get_account()
597
+ positions = self.adapter.get_positions()
598
+ self._portfolio_history.append(acct.portfolio_value)
599
+ if len(self._portfolio_history) > 1000:
600
+ self._portfolio_history = self._portfolio_history[-1000:]
601
+ self.call_from_thread(self._on_positions_updated, acct, positions)
602
+ except Exception as exc:
603
+ logger.warning("Position poll error: %s", exc)
604
+ time.sleep(self.config.get("poll_interval_positions", 60))
605
+
606
+ @work(thread=True, name="poll-signals", exclusive=False)
607
+ def _poll_signals(self) -> None:
608
+ """Generate trading signals and optionally execute auto-trades."""
609
+ debug_fast = self.config.get("debug_fast_cycle", False)
610
+ time.sleep(2 if debug_fast else 5)
611
+ logger.info("Signal poll worker started (debug_fast=%s)", debug_fast)
612
+ while self._running:
613
+ try:
614
+ self._run_signal_cycle()
615
+ except Exception as exc:
616
+ logger.warning("Signal cycle error: %s", exc)
617
+ interval = self.config.get("poll_interval_signals", 300)
618
+ if debug_fast:
619
+ interval = min(interval, 10) # Cap at 10s in debug mode
620
+ time.sleep(interval)
621
+
622
+ def _run_signal_cycle(self) -> None:
623
+ from trading_cli.data.market import fetch_ohlcv_yfinance, get_latest_quotes_batch
624
+ from trading_cli.data.news import fetch_headlines
625
+ from trading_cli.sentiment.aggregator import aggregate_scores
626
+ from trading_cli.sentiment.news_classifier import classify_headlines
627
+ from trading_cli.strategy.scanner import MarketScanner
628
+ from trading_cli.strategy.risk import check_max_drawdown
629
+ from trading_cli.data.db import save_signal
630
+
631
+ auto_enabled = self.config.get("auto_trading", False)
632
+ debug_fast = self.config.get("debug_fast_cycle", False)
633
+ cycle_time = datetime.now().strftime("%H:%M:%S")
634
+ logger.info("Running signal cycle at %s (auto_trading=%s, debug_fast=%s)", cycle_time, auto_enabled, debug_fast)
635
+
636
+ # Build event weight map
637
+ from trading_cli.sentiment.news_classifier import EventType, DEFAULT_WEIGHTS as EVENT_WEIGHTS
638
+ event_weights = {
639
+ EventType.EARNINGS: self.config.get("event_weight_earnings", EVENT_WEIGHTS[EventType.EARNINGS]),
640
+ EventType.EXECUTIVE: self.config.get("event_weight_executive", EVENT_WEIGHTS[EventType.EXECUTIVE]),
641
+ EventType.PRODUCT: self.config.get("event_weight_product", EVENT_WEIGHTS[EventType.PRODUCT]),
642
+ EventType.MACRO: self.config.get("event_weight_macro", EVENT_WEIGHTS[EventType.MACRO]),
643
+ EventType.GENERIC: self.config.get("event_weight_generic", EVENT_WEIGHTS[EventType.GENERIC]),
644
+ }
645
+
646
+ # Update dashboard with cycle time
647
+ self.call_from_thread(self._on_cycle_completed, cycle_time, auto_enabled)
648
+
649
+ # ── Phase 1: Get universe and batch fetch prices ────────────────────
650
+ scan_universe = auto_enabled and hasattr(self, 'asset_search') and self.asset_search.is_ready
651
+ if scan_universe:
652
+ all_assets = self.asset_search._assets
653
+ all_symbols = [a["symbol"] for a in all_assets]
654
+ # Filter: only US equities, price > $1, exclude ETFs/warrants
655
+ filtered = [s for s in all_symbols if not any(x in s for x in (".", "-WS", "-P", "-A"))]
656
+ symbols = filtered[:500] # Cap at 500 for performance
657
+ else:
658
+ symbols = list(self.watchlist)
659
+
660
+ if not symbols:
661
+ return
662
+
663
+ # Batch fetch latest prices for all symbols
664
+ try:
665
+ current_prices = get_latest_quotes_batch(self.adapter if not self.adapter.is_demo_mode else None, symbols)
666
+ except Exception as exc:
667
+ logger.warning("Batch price fetch failed: %s", exc)
668
+ return
669
+
670
+ logger.info("Fetched prices for %d symbols, %d have data", len(symbols), len(current_prices))
671
+
672
+ # ── Phase 2: Initialize scanner on first cycle ──────────────────────
673
+ if not hasattr(self, "_scanner"):
674
+ self._scanner = MarketScanner()
675
+
676
+ scanner = self._scanner
677
+
678
+ # ── Phase 3: Populate cache for symbols that don't have it yet ──────
679
+ # Fetch historical data for uncached symbols (in batches)
680
+ uncached = [s for s in symbols if scanner.get_cached(s) is None]
681
+ if uncached:
682
+ logger.info("Populating cache for %d new symbols", len(uncached))
683
+ batch_size = 10 if not debug_fast else 5
684
+ for i in range(0, len(uncached), batch_size):
685
+ batch = uncached[i:i + batch_size]
686
+ for sym in batch:
687
+ try:
688
+ ohlcv = fetch_ohlcv_yfinance(sym, days=60)
689
+ if not ohlcv.empty:
690
+ # Normalize columns
691
+ ohlcv.columns = [c.lower() for c in ohlcv.columns]
692
+ if "adj close" in ohlcv.columns:
693
+ ohlcv = ohlcv.rename(columns={"adj close": "adj_close"})
694
+ ohlcv = ohlcv.reset_index()
695
+ if "index" in ohlcv.columns:
696
+ ohlcv = ohlcv.rename(columns={"index": "date"})
697
+ scanner.save(sym, ohlcv)
698
+ except Exception as exc:
699
+ logger.debug("Cache populate failed for %s: %s", sym, exc)
700
+ if not debug_fast:
701
+ time.sleep(0.2) # Rate limit yfinance
702
+
703
+ # ── Phase 4: Update cache with latest prices ────────────────────────
704
+ for symbol, price in current_prices.items():
705
+ cached = scanner.get_cached(symbol)
706
+ if cached is not None and len(cached) > 0:
707
+ # Append/update today's bar
708
+ today = datetime.now().strftime("%Y-%m-%d")
709
+ last_bar = cached.iloc[-1]
710
+ bar = {
711
+ "date": today,
712
+ "open": last_bar.get("open", price),
713
+ "high": max(last_bar.get("high", price), price),
714
+ "low": min(last_bar.get("low", price), price),
715
+ "close": price,
716
+ "volume": last_bar.get("volume", 0),
717
+ }
718
+ scanner.append_bar(symbol, bar)
719
+
720
+ # ── Phase 5: Screen for breakout candidates ─────────────────────────
721
+ entry_period = self.config.get("entry_period", 20)
722
+ candidates = scanner.screen_breakouts(symbols, current_prices, entry_period)
723
+ logger.info("Breakout candidates: %d / %d scanned", len(candidates), len(symbols))
724
+
725
+ # ── Phase 6: Run full signal analysis on candidates ─────────────────
726
+ for symbol in candidates:
727
+ try:
728
+ ohlcv = scanner.get_cached(symbol)
729
+ if ohlcv is None or len(ohlcv) < 30:
730
+ continue
731
+
732
+ price = current_prices.get(symbol, 0)
733
+
734
+ # Run strategy analysis
735
+ signal_result = self.strategy.generate_signal(
736
+ symbol=symbol,
737
+ ohlcv=ohlcv,
738
+ sentiment_score=0.0, # Skip sentiment for speed
739
+ prices=current_prices,
740
+ positions=getattr(self, "_positions", []),
741
+ config=self.config,
742
+ )
743
+
744
+ if signal_result.action == "HOLD":
745
+ continue
746
+
747
+ # Build signal dict for DB/UI
748
+ signal = {
749
+ "symbol": symbol,
750
+ "action": signal_result.action,
751
+ "confidence": signal_result.confidence,
752
+ "hybrid_score": signal_result.score,
753
+ "technical_score": signal_result.metadata.get("sma_score", 0.0),
754
+ "sentiment_score": 0.0,
755
+ "reason": signal_result.reason,
756
+ "price": price or 0.0,
757
+ }
758
+ self._signals[symbol] = signal_result.action
759
+
760
+ save_signal(
761
+ self.db_conn,
762
+ symbol=symbol,
763
+ action=signal["action"],
764
+ confidence=signal["confidence"],
765
+ technical_score=signal["technical_score"],
766
+ sentiment_score=signal["sentiment_score"],
767
+ reason=signal["reason"],
768
+ )
769
+
770
+ self.call_from_thread(self._on_signal_generated, signal)
771
+
772
+ # Auto-execute if enabled
773
+ if auto_enabled and check_max_drawdown(self._portfolio_history, self.config.get("max_drawdown", 0.15)):
774
+ logger.info("Auto-trade %s signal for %s (confidence=%.2f)", signal_result.action, symbol, signal_result.confidence)
775
+ logger.info("Executing auto-trade: %s %s", signal_result.action, symbol)
776
+ self.call_from_thread(self._auto_execute, signal)
777
+
778
+ except Exception as exc:
779
+ logger.debug("Signal analysis failed for %s: %s", symbol, exc)
780
+
781
+ # ── Phase 7: Cleanup stale cache periodically ───────────────────────
782
+ cycle_count = getattr(self, '_signal_cycle_count', 0) + 1
783
+ self._signal_cycle_count = cycle_count
784
+ if cycle_count % 10 == 0: # Every 10th cycle
785
+ removed = scanner.cleanup_old_cache(max_age_days=7)
786
+ if removed > 0:
787
+ logger.info("Cleaned up %d stale cache files", removed)
788
+
789
+ # ── UI callbacks (called from thread via call_from_thread) ─────────────────
790
+
791
+ def _on_prices_updated(self) -> None:
792
+ try:
793
+ wl_screen = self.get_screen("watchlist")
794
+ if hasattr(wl_screen, "update_data"):
795
+ wl_screen.update_data(self._prices, self._sentiments, self._signals)
796
+ except Exception:
797
+ pass
798
+
799
+ def _on_cycle_completed(self, cycle_time: str, auto_enabled: bool) -> None:
800
+ """Called when a signal cycle completes (from worker thread)."""
801
+ try:
802
+ dash = self.get_screen("dashboard")
803
+ if hasattr(dash, "update_autotrade_status"):
804
+ dash.update_autotrade_status(auto_enabled, cycle_time)
805
+ except Exception:
806
+ pass
807
+
808
+ def _on_autotrade_error(self, error_msg: str) -> None:
809
+ """Called when auto-trade encounters an error."""
810
+ try:
811
+ dash = self.get_screen("dashboard")
812
+ if hasattr(dash, "update_autotrade_status"):
813
+ dash.update_autotrade_status(error=error_msg)
814
+ except Exception:
815
+ pass
816
+
817
+ def _on_autotrade_blocked(self, reason: str) -> None:
818
+ """Called when auto-trade is blocked by risk management."""
819
+ try:
820
+ self.notify(reason, severity="warning", timeout=5)
821
+ except Exception:
822
+ pass
823
+
824
+ def _on_positions_updated(self, acct, positions: list) -> None:
825
+ try:
826
+ dash = self.get_screen("dashboard")
827
+ if hasattr(dash, "refresh_positions"):
828
+ dash.refresh_positions(positions)
829
+ if hasattr(dash, "refresh_account"):
830
+ dash.refresh_account(acct)
831
+ except Exception:
832
+ pass
833
+
834
+ def _on_signal_generated(self, signal: dict) -> None:
835
+ try:
836
+ dash = self.get_screen("dashboard")
837
+ if hasattr(dash, "log_signal"):
838
+ dash.log_signal(signal)
839
+ except Exception:
840
+ pass
841
+
842
+ def _auto_execute(self, signal: dict) -> None:
843
+ """Execute a signal automatically (auto_trading=True) with full risk management."""
844
+ symbol = signal["symbol"]
845
+ action = signal["action"]
846
+ price = signal.get("price", 0.0)
847
+
848
+ from trading_cli.strategy.risk import (
849
+ calculate_position_size,
850
+ validate_buy,
851
+ validate_sell,
852
+ check_stop_loss,
853
+ )
854
+
855
+ try:
856
+ acct = self.adapter.get_account()
857
+ positions = self.adapter.get_positions()
858
+ positions_dict = {p.symbol: {"qty": p.qty, "avg_entry_price": p.avg_entry_price} for p in positions}
859
+
860
+ if action == "BUY":
861
+ ok, reason = validate_buy(
862
+ symbol, price, 1, acct.cash, positions_dict,
863
+ max_positions=self.config.get("max_positions", 10),
864
+ )
865
+ if not ok:
866
+ logger.warning("Auto-buy blocked: %s", reason)
867
+ self.call_from_thread(
868
+ self._on_autotrade_blocked,
869
+ f"Auto-buy {symbol} blocked: {reason}"
870
+ )
871
+ return
872
+
873
+ elif action == "SELL":
874
+ # Check stop-loss for existing position
875
+ pos = positions_dict.get(symbol)
876
+ if pos:
877
+ entry_price = pos.get("avg_entry_price", 0)
878
+ if check_stop_loss(entry_price, price, self.config.get("stop_loss_pct", 0.05)):
879
+ self.notify(f"Stop-loss triggered for {symbol} @ ${price:.2f}", severity="warning")
880
+
881
+ ok, reason = validate_sell(symbol, 1, positions_dict)
882
+ if not ok:
883
+ logger.warning("Auto-sell blocked: %s", reason)
884
+ self.call_from_thread(
885
+ self._on_autotrade_blocked,
886
+ f"Auto-sell {symbol} blocked: {reason}"
887
+ )
888
+ return
889
+
890
+ qty = calculate_position_size(
891
+ acct.portfolio_value,
892
+ price or 1.0,
893
+ risk_pct=self.config.get("risk_pct", 0.02),
894
+ max_position_pct=0.10,
895
+ )
896
+ if qty < 1:
897
+ logger.info(f"Auto-trade skipped: calculated qty < 1 for {symbol}")
898
+ return
899
+
900
+ result = self.adapter.submit_market_order(symbol, qty, action)
901
+ if result.status not in ("rejected",):
902
+ from trading_cli.data.db import save_trade
903
+ save_trade(
904
+ self.db_conn, symbol, action,
905
+ result.filled_price or price, qty,
906
+ order_id=result.order_id,
907
+ reason=f"Auto: {signal['reason']}",
908
+ )
909
+ self.notify(
910
+ f"AUTO {action} {qty} {symbol} @ ${result.filled_price or price:.2f}",
911
+ timeout=5,
912
+ )
913
+ else:
914
+ logger.warning(f"Auto-trade rejected: {symbol} {action}")
915
+ self.call_from_thread(
916
+ self._on_autotrade_blocked,
917
+ f"Order rejected for {symbol} {action}"
918
+ )
919
+ except Exception as exc:
920
+ logger.error("Auto-execute error: %s", exc)
921
+ self.call_from_thread(
922
+ self._on_autotrade_error,
923
+ f"Auto-execute failed: {exc}"
924
+ )
925
+
926
+ # ── Manual order execution ─────────────────────────────────────────────────
927
+
928
+ def execute_manual_order(
929
+ self, symbol: str, action: str, qty: int, price: float, reason: str
930
+ ) -> None:
931
+ """Called from screens to submit a manual order with confirmation dialog."""
932
+
933
+ def on_confirm(confirmed: bool) -> None:
934
+ if not confirmed:
935
+ return
936
+ try:
937
+ result = self.adapter.submit_market_order(symbol, qty, action)
938
+ if result.status not in ("rejected",):
939
+ from trading_cli.data.db import save_trade
940
+ save_trade(
941
+ self.db_conn, symbol, action,
942
+ result.filled_price or price, qty,
943
+ order_id=result.order_id,
944
+ reason=reason,
945
+ )
946
+ self.notify(
947
+ f"{action} {qty} {symbol} @ ${result.filled_price or price:.2f} [{result.status}]"
948
+ )
949
+ else:
950
+ self.notify(f"Order rejected for {symbol}", severity="error")
951
+ except Exception as exc:
952
+ self.notify(f"Order failed: {exc}", severity="error")
953
+
954
+ self.push_screen(OrderConfirmScreen(symbol, action, qty, price, reason), callback=on_confirm)
955
+
956
+ # ── Watchlist helpers ──────────────────────────────────────────────────────
957
+
958
+ def add_to_watchlist(self, symbol: str) -> None:
959
+ if symbol not in self.watchlist:
960
+ self.watchlist.append(symbol)
961
+ if self.db_conn:
962
+ from trading_cli.data.db import add_to_watchlist
963
+ add_to_watchlist(self.db_conn, symbol)
964
+ self.notify(f"Added {symbol} to watchlist")
965
+
966
+ def remove_from_watchlist(self, symbol: str) -> None:
967
+ if symbol in self.watchlist:
968
+ self.watchlist.remove(symbol)
969
+ if self.db_conn:
970
+ from trading_cli.data.db import remove_from_watchlist
971
+ remove_from_watchlist(self.db_conn, symbol)
972
+ self.notify(f"Removed {symbol} from watchlist")
973
+
974
+ # ── Screen actions ─────────────────────────────────────────────────────────
975
+
976
+ def action_show_dashboard(self) -> None:
977
+ self.push_screen("dashboard")
978
+
979
+ def action_show_watchlist(self) -> None:
980
+ self.push_screen("watchlist")
981
+
982
+ def action_show_portfolio(self) -> None:
983
+ self.push_screen("portfolio")
984
+
985
+ def action_show_trades(self) -> None:
986
+ self.push_screen("trades")
987
+
988
+ def action_show_sentiment(self) -> None:
989
+ self.push_screen("sentiment")
990
+
991
+ def action_show_config(self) -> None:
992
+ self.push_screen("config")
993
+
994
+ def action_show_backtest(self) -> None:
995
+ self.push_screen("backtest")
trading_cli/backtest/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from trading_cli.backtest.engine import BacktestEngine, BacktestResult, BacktestTrade
2
+
3
+ __all__ = ["BacktestEngine", "BacktestResult", "BacktestTrade"]
trading_cli/backtest/engine.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Backtesting framework — simulates trades using historical OHLCV + sentiment."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from dataclasses import dataclass, field
7
+ from datetime import datetime, timedelta
8
+ from typing import Any
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ from trading_cli.sentiment.aggregator import aggregate_scores_weighted
14
+ from trading_cli.sentiment.news_classifier import classify_headlines, EventType
15
+ from trading_cli.strategy.signals import generate_signal, technical_score
16
+ from trading_cli.strategy.risk import calculate_position_size, check_stop_loss, check_max_drawdown
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass
22
+ class BacktestTrade:
23
+ timestamp: str
24
+ symbol: str
25
+ action: str # BUY or SELL
26
+ price: float
27
+ qty: int
28
+ reason: str
29
+ pnl: float = 0.0
30
+
31
+
32
+ @dataclass
33
+ class BacktestResult:
34
+ symbol: str
35
+ start_date: str
36
+ end_date: str
37
+ initial_capital: float
38
+ final_equity: float
39
+ total_return_pct: float
40
+ max_drawdown_pct: float
41
+ sharpe_ratio: float
42
+ win_rate: float
43
+ total_trades: int
44
+ winning_trades: int
45
+ losing_trades: int
46
+ trades: list[BacktestTrade] = field(default_factory=list)
47
+ equity_curve: list[float] = field(default_factory=list)
48
+
49
+ def summary_dict(self) -> dict:
50
+ return {
51
+ "symbol": self.symbol,
52
+ "period": f"{self.start_date} to {self.end_date}",
53
+ "initial_capital": f"${self.initial_capital:,.2f}",
54
+ "final_equity": f"${self.final_equity:,.2f}",
55
+ "total_return": f"{self.total_return_pct:+.2f}%",
56
+ "max_drawdown": f"{self.max_drawdown_pct:.2f}%",
57
+ "sharpe_ratio": f"{self.sharpe_ratio:.2f}",
58
+ "win_rate": f"{self.win_rate:.1f}%",
59
+ "total_trades": self.total_trades,
60
+ "winning_trades": self.winning_trades,
61
+ "losing_trades": self.losing_trades,
62
+ }
63
+
64
+
65
+ class BacktestEngine:
66
+ """Runs historical simulation using the same signal pipeline as live trading."""
67
+
68
+ def __init__(
69
+ self,
70
+ config: dict,
71
+ finbert=None,
72
+ news_fetcher=None,
73
+ use_sentiment: bool = True,
74
+ strategy=None,
75
+ progress_callback=None,
76
+ debug: bool = False,
77
+ ):
78
+ """
79
+ Args:
80
+ config: Trading configuration dict.
81
+ finbert: FinBERTAnalyzer instance (or None to skip sentiment).
82
+ news_fetcher: Callable(symbol, days_ago) -> list[tuple[str, float]]
83
+ Returns list of (headline, unix_timestamp) tuples.
84
+ use_sentiment: If False, skip all sentiment scoring regardless of
85
+ whether finbert/news_fetcher are provided.
86
+ strategy: StrategyAdapter instance. If None, falls back to legacy
87
+ hardcoded technical + sentiment pipeline.
88
+ progress_callback: Optional callable(str) to report progress.
89
+ debug: If True, log every bar's signal details at INFO level.
90
+ """
91
+ self.config = config
92
+ self.finbert = finbert
93
+ self.news_fetcher = news_fetcher
94
+ self.use_sentiment = use_sentiment
95
+ self.strategy = strategy
96
+ self.progress_callback = progress_callback
97
+ self.debug = debug
98
+ # Force INFO level on this logger when debug is enabled
99
+ if debug:
100
+ logger.setLevel(logging.INFO)
101
+
102
+ def run(
103
+ self,
104
+ symbol: str,
105
+ ohlcv: pd.DataFrame,
106
+ start_date: str | None = None,
107
+ end_date: str | None = None,
108
+ initial_capital: float = 100_000.0,
109
+ ) -> BacktestResult:
110
+ """
111
+ Run backtest on historical OHLCV data.
112
+
113
+ Simulates daily signal generation and order execution at next day's open.
114
+ """
115
+ df = ohlcv.copy()
116
+ # Handle both column-based and index-based dates
117
+ if "Date" in df.columns or "date" in df.columns:
118
+ date_col = "Date" if "Date" in df.columns else "date"
119
+ df[date_col] = pd.to_datetime(df[date_col])
120
+ df = df.set_index(date_col)
121
+
122
+ # Handle timezone mismatch for date range filtering
123
+ # Alpaca data is UTC-aware, while start_date/end_date from UI are naive
124
+ if start_date:
125
+ sd = pd.Timestamp(start_date)
126
+ if df.index.tz is not None:
127
+ sd = sd.tz_localize(df.index.tz)
128
+ df = df[df.index >= sd]
129
+ if end_date:
130
+ ed = pd.Timestamp(end_date)
131
+ if df.index.tz is not None:
132
+ ed = ed.tz_localize(df.index.tz)
133
+ df = df[df.index <= ed]
134
+
135
+ # Reset index to get date back as a column for downstream code
136
+ # Ensure we name the date column 'date' regardless of the index name
137
+ df = df.reset_index()
138
+ # If the index had a name (e.g. 'timestamp'), it will be the first column
139
+ # Otherwise it's named 'index'
140
+ if "index" in df.columns:
141
+ df = df.rename(columns={"index": "date"})
142
+ elif df.columns[0] != "date":
143
+ df = df.rename(columns={df.columns[0]: "date"})
144
+
145
+ # Normalize column names to lowercase for consistent access
146
+ # yfinance can return MultiIndex columns (tuples), so flatten them first
147
+ if isinstance(df.columns, pd.MultiIndex):
148
+ df.columns = [c[0] for c in df.columns]
149
+ df.columns = [c.lower() for c in df.columns]
150
+ if "adj close" in df.columns:
151
+ df = df.rename(columns={"adj close": "adj_close"})
152
+
153
+ logger.info("Backtest %s: %d bars, columns: %s", symbol, len(df), list(df.columns))
154
+
155
+ if len(df) < 60:
156
+ logger.warning("Backtest %s: not enough data (%d bars, need 60+)", symbol, len(df))
157
+ date_col = "date" if "date" in df.columns else None
158
+ start_str = str(df.iloc[0][date_col])[:10] if date_col and len(df) > 0 else "N/A"
159
+ end_str = str(df.iloc[-1][date_col])[:10] if date_col and len(df) > 0 else "N/A"
160
+ return BacktestResult(
161
+ symbol=symbol,
162
+ start_date=start_str,
163
+ end_date=end_str,
164
+ initial_capital=initial_capital,
165
+ final_equity=initial_capital,
166
+ total_return_pct=0.0,
167
+ max_drawdown_pct=0.0,
168
+ sharpe_ratio=0.0,
169
+ win_rate=0.0,
170
+ total_trades=0,
171
+ winning_trades=0,
172
+ losing_trades=0,
173
+ )
174
+
175
+ cash = initial_capital
176
+ position_qty = 0
177
+ position_avg_price = 0.0
178
+ equity_curve = [initial_capital]
179
+ trades: list[BacktestTrade] = []
180
+ equity_values = [initial_capital]
181
+
182
+ # Normalize column names to lowercase for consistent access
183
+ # yfinance can return MultiIndex columns (tuples), so flatten them first
184
+ if isinstance(df.columns, pd.MultiIndex):
185
+ df.columns = [c[0] for c in df.columns]
186
+ df.columns = [c.lower() for c in df.columns]
187
+ if "adj close" in df.columns:
188
+ df = df.rename(columns={"adj close": "adj_close"})
189
+
190
+ logger.info("Backtest %s: %d bars, columns: %s", symbol, len(df), list(df.columns))
191
+
192
+ if len(df) < 60:
193
+ logger.warning("Backtest %s: not enough data (%d bars, need 60+)", symbol, len(df))
194
+
195
+ # Config params
196
+ buy_threshold = self.config.get("signal_buy_threshold", 0.5)
197
+ sell_threshold = self.config.get("signal_sell_threshold", -0.3)
198
+ sma_short = self.config.get("sma_short", 20)
199
+ sma_long = self.config.get("sma_long", 50)
200
+ rsi_period = self.config.get("rsi_period", 14)
201
+ bb_window = self.config.get("bb_window", 20)
202
+ bb_std = self.config.get("bb_std", 2.0)
203
+ ema_fast = self.config.get("ema_fast", 12)
204
+ ema_slow = self.config.get("ema_slow", 26)
205
+ vol_window = self.config.get("volume_window", 20)
206
+ tech_weight = self.config.get("tech_weight", 0.6)
207
+ sent_weight = self.config.get("sent_weight", 0.4)
208
+ risk_pct = self.config.get("risk_pct", 0.02)
209
+ max_dd = self.config.get("max_drawdown", 0.15)
210
+ stop_loss_pct = self.config.get("stop_loss_pct", 0.05)
211
+
212
+ tech_weights = {
213
+ "sma": self.config.get("weight_sma", 0.25),
214
+ "rsi": self.config.get("weight_rsi", 0.25),
215
+ "bb": self.config.get("weight_bb", 0.20),
216
+ "ema": self.config.get("weight_ema", 0.15),
217
+ "volume": self.config.get("weight_volume", 0.15),
218
+ }
219
+
220
+ # ── Pre-fetch and cache all sentiment scores ──────────────────────
221
+ lookback = max(sma_long, ema_slow, bb_window, vol_window) + 30
222
+ logger.info("Backtest %s: lookback=%d, total_bars=%d", symbol, lookback, len(df) - lookback)
223
+ sent_scores = {}
224
+ if self.use_sentiment and self.finbert and self.news_fetcher:
225
+ total_days = len(df) - lookback
226
+ try:
227
+ # Fetch all news once (batch)
228
+ if self.progress_callback:
229
+ self.progress_callback("Fetching historical news…")
230
+ all_news = self.news_fetcher(symbol, days_ago=len(df))
231
+ if all_news:
232
+ headlines = [item[0] for item in all_news]
233
+ timestamps = [item[1] for item in all_news]
234
+ classifications = classify_headlines(headlines)
235
+ # Analyze all headlines at once
236
+ if self.progress_callback:
237
+ self.progress_callback("Analyzing sentiment (batch)…")
238
+ results = self.finbert.analyze_batch(headlines)
239
+ # Single aggregated score for the whole period
240
+ cached_score = aggregate_scores_weighted(
241
+ results, classifications, timestamps=timestamps
242
+ )
243
+ # Apply same score to all bars (since we fetched once)
244
+ for i in range(lookback, len(df)):
245
+ sent_scores[i] = cached_score
246
+ except Exception as exc:
247
+ import logging
248
+ logging.getLogger(__name__).warning("Sentiment pre-fetch failed: %s", exc)
249
+ sent_scores = {}
250
+
251
+ # ── Walk forward through data ─────────────────────────────────────
252
+ total_bars = len(df) - lookback
253
+ if self.progress_callback:
254
+ self.progress_callback("Running simulation…")
255
+ for idx, i in enumerate(range(lookback, len(df))):
256
+ if self.progress_callback and idx % 20 == 0:
257
+ pct = int(idx / total_bars * 100) if total_bars else 0
258
+ self.progress_callback(f"Running simulation… {pct}%")
259
+
260
+ historical_ohlcv = df.iloc[:i]
261
+ current_bar = df.iloc[i]
262
+ current_price = float(current_bar["close"])
263
+ current_date = str(current_bar.get("date", ""))
264
+
265
+ # Use pre-cached sentiment score
266
+ sent_score = sent_scores.get(i, 0.0)
267
+
268
+ # Max drawdown check
269
+ if check_max_drawdown(equity_values, max_dd):
270
+ break # Stop backtest if drawdown exceeded
271
+
272
+ # Build mock position object for strategy adapter
273
+ class _MockPosition:
274
+ def __init__(self, symbol, qty, avg_price):
275
+ self.symbol = symbol
276
+ self.qty = qty
277
+ self.avg_entry_price = avg_price
278
+
279
+ backtest_positions = [_MockPosition(symbol, position_qty, position_avg_price)] if position_qty > 0 else []
280
+
281
+ # Generate signal — use strategy adapter if available, else legacy
282
+ if self.strategy is not None:
283
+ # Use strategy adapter
284
+ signal_result = self.strategy.generate_signal(
285
+ symbol=symbol,
286
+ ohlcv=historical_ohlcv,
287
+ sentiment_score=sent_score,
288
+ positions=backtest_positions,
289
+ config=self.config,
290
+ )
291
+ action = signal_result.action
292
+ score = signal_result.score
293
+ reason = signal_result.reason
294
+ buy_threshold = self.config.get("signal_buy_threshold", 0.5)
295
+ sell_threshold = self.config.get("signal_sell_threshold", -0.3)
296
+ if self.debug:
297
+ logger.info(
298
+ "Bar %d | %s | price=%.2f | score=%.3f | action=%s | reason=%s",
299
+ idx, current_date, current_price, score, action, reason,
300
+ )
301
+ else:
302
+ # Legacy hardcoded technical + sentiment
303
+ tech = technical_score(
304
+ historical_ohlcv, sma_short, sma_long, rsi_period,
305
+ bb_window, bb_std, ema_fast, ema_slow, vol_window,
306
+ tech_weights,
307
+ )
308
+ # Normalize hybrid score: if sentiment is absent (0.0),
309
+ # use tech alone so buy/sell thresholds remain reachable
310
+ if sent_score == 0.0:
311
+ hybrid = tech
312
+ else:
313
+ hybrid = tech_weight * tech + sent_weight * sent_score
314
+ score = hybrid
315
+ if hybrid >= buy_threshold:
316
+ action = "BUY"
317
+ elif hybrid <= sell_threshold:
318
+ action = "SELL"
319
+ else:
320
+ action = "HOLD"
321
+ reason = f"hybrid={hybrid:.3f} tech={tech:.3f}"
322
+ if self.debug:
323
+ logger.info(
324
+ "Bar %d | %s | price=%.2f | tech=%.3f | sent=%.3f | hybrid=%.3f | action=%s",
325
+ idx, current_date, current_price, tech, sent_score, hybrid, action,
326
+ )
327
+
328
+ if action == "BUY" and position_qty == 0:
329
+ qty = calculate_position_size(
330
+ cash + position_qty * position_avg_price,
331
+ current_price,
332
+ risk_pct=risk_pct,
333
+ max_position_pct=self.config.get("max_position_pct", 0.10),
334
+ )
335
+ if qty > 0 and cash >= qty * current_price:
336
+ cost = qty * current_price
337
+ cash -= cost
338
+ total_shares = position_qty + qty
339
+ position_avg_price = (
340
+ (position_avg_price * position_qty + current_price * qty) / total_shares
341
+ )
342
+ position_qty = total_shares
343
+
344
+ trades.append(BacktestTrade(
345
+ timestamp=current_date,
346
+ symbol=symbol,
347
+ action="BUY",
348
+ price=current_price,
349
+ qty=qty,
350
+ reason=reason,
351
+ ))
352
+ if self.debug:
353
+ logger.info(
354
+ " >>> BUY %d @ %.2f (cost=%.2f, cash=%.2f, pos=%d)",
355
+ qty, current_price, cost, cash, position_qty,
356
+ )
357
+ elif self.debug:
358
+ logger.info(
359
+ " >>> BUY blocked: qty=%d, cash=%.2f, need=%.2f",
360
+ qty, cash, qty * current_price,
361
+ )
362
+
363
+ elif action == "SELL" and position_qty > 0:
364
+ sell_reason = reason
365
+ if check_stop_loss(position_avg_price, current_price, stop_loss_pct):
366
+ sell_reason = f"stop-loss ({reason})"
367
+
368
+ proceeds = position_qty * current_price
369
+ pnl = (current_price - position_avg_price) * position_qty
370
+ cash += proceeds
371
+
372
+ trades.append(BacktestTrade(
373
+ timestamp=current_date,
374
+ symbol=symbol,
375
+ action="SELL",
376
+ price=current_price,
377
+ qty=position_qty,
378
+ reason=sell_reason,
379
+ pnl=pnl,
380
+ ))
381
+
382
+ if self.debug:
383
+ logger.info(
384
+ " >>> SELL %d @ %.2f (pnl=%.2f, proceeds=%.2f, cash=%.2f)",
385
+ position_qty, current_price, pnl, proceeds, cash,
386
+ )
387
+
388
+ position_qty = 0
389
+ position_avg_price = 0.0
390
+
391
+ # Track equity
392
+ equity = cash + position_qty * current_price
393
+ equity_curve.append(equity)
394
+ equity_values.append(equity)
395
+
396
+ # Close any remaining position at last price
397
+ if position_qty > 0 and len(df) > 0:
398
+ last_price = float(df.iloc[-1]["close"])
399
+ last_date = str(df.iloc[-1]["date"])[:10]
400
+ pnl = (last_price - position_avg_price) * position_qty
401
+ cash += position_qty * last_price
402
+ trades.append(BacktestTrade(
403
+ timestamp=last_date,
404
+ symbol=symbol,
405
+ action="SELL",
406
+ price=last_price,
407
+ qty=position_qty,
408
+ reason="end of backtest",
409
+ pnl=pnl,
410
+ ))
411
+ position_qty = 0
412
+
413
+ final_equity = cash
414
+ total_return = ((final_equity - initial_capital) / initial_capital) * 100
415
+ logger.info("Backtest %s: %d trades, return=%.2f%%", symbol, len(trades), total_return)
416
+
417
+ # Compute metrics
418
+ peak = equity_values[0]
419
+ max_dd_actual = 0.0
420
+ for val in equity_values:
421
+ if val > peak:
422
+ peak = val
423
+ dd = (peak - val) / peak if peak > 0 else 0
424
+ max_dd_actual = max(max_dd_actual, dd)
425
+
426
+ # Win rate
427
+ sell_trades = [t for t in trades if t.action == "SELL"]
428
+ winning = sum(1 for t in sell_trades if t.pnl > 0)
429
+ losing = sum(1 for t in sell_trades if t.pnl < 0)
430
+ win_rate = (winning / len(sell_trades) * 100) if sell_trades else 0.0
431
+
432
+ # Sharpe ratio (daily returns)
433
+ if len(equity_values) > 1:
434
+ returns = np.diff(equity_values) / equity_values[:-1]
435
+ sharpe = (np.mean(returns) / np.std(returns) * np.sqrt(252)) if np.std(returns) > 0 else 0.0
436
+ else:
437
+ sharpe = 0.0
438
+
439
+ return BacktestResult(
440
+ symbol=symbol,
441
+ start_date=str(df.iloc[0]["date"])[:10] if len(df) > 0 else "N/A",
442
+ end_date=str(df.iloc[-1]["date"])[:10] if len(df) > 0 else "N/A",
443
+ initial_capital=initial_capital,
444
+ final_equity=final_equity,
445
+ total_return_pct=total_return,
446
+ max_drawdown_pct=max_dd_actual * 100,
447
+ sharpe_ratio=sharpe,
448
+ win_rate=win_rate,
449
+ total_trades=len(trades),
450
+ winning_trades=winning,
451
+ losing_trades=losing,
452
+ trades=trades,
453
+ equity_curve=equity_curve,
454
+ )
trading_cli/config.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration management — stores settings in ~/.config/trading-cli/config.toml."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import toml
6
+ from pathlib import Path
7
+
8
+ CONFIG_DIR = Path("~/.config/trading-cli").expanduser()
9
+ CONFIG_PATH = CONFIG_DIR / "config.toml"
10
+ DB_PATH = CONFIG_DIR / "trades.db"
11
+
12
+ DEFAULT_CONFIG: dict = {
13
+ "alpaca_api_key": "",
14
+ "alpaca_api_secret": "",
15
+ "alpaca_paper": True,
16
+ "adapter_id": "alpaca",
17
+ "auto_trading": True,
18
+ "sentiment_model": "finbert",
19
+ "strategy_id": "trend_following",
20
+ "risk_pct": 0.02,
21
+ "max_drawdown": 0.15,
22
+ "stop_loss_pct": 0.05,
23
+ "max_positions": 10,
24
+ "default_symbols": ["AAPL", "TSLA", "NVDA"],
25
+ "poll_interval_prices": 30,
26
+ "poll_interval_news": 900,
27
+ "poll_interval_signals": 60,
28
+ "poll_interval_positions": 60,
29
+ "initial_cash": 100000.0,
30
+ "finbert_batch_size": 50,
31
+ "debug_fast_cycle": True,
32
+ "sma_short": 20,
33
+ "sma_long": 50,
34
+ "rsi_period": 14,
35
+ "signal_buy_threshold": 0.15,
36
+ "signal_sell_threshold": -0.15,
37
+ "position_size_warning": 1000.0,
38
+ # ── Strategy weights ──────────────────────────────────────────────────────
39
+ "tech_weight": 0.6,
40
+ "sent_weight": 0.4,
41
+ # ── Technical indicator weights ───────────────────────────────────────────
42
+ "weight_sma": 0.25,
43
+ "weight_rsi": 0.25,
44
+ "weight_bb": 0.20,
45
+ "weight_ema": 0.15,
46
+ "weight_volume": 0.15,
47
+ # ── Bollinger Bands ───────────────────────────────────────────────────────
48
+ "bb_window": 20,
49
+ "bb_std": 2.0,
50
+ # ── EMA periods ───────────────────────────────────────────────────────────
51
+ "ema_fast": 12,
52
+ "ema_slow": 26,
53
+ # ── Volume SMA window ─────────────────────────────────────────────────────
54
+ "volume_window": 20,
55
+ # ── Sentiment event weights ───────────────────────────────────────────────
56
+ "event_weight_earnings": 1.5,
57
+ "event_weight_executive": 1.3,
58
+ "event_weight_product": 1.2,
59
+ "event_weight_macro": 1.4,
60
+ "event_weight_generic": 0.8,
61
+ "sentiment_half_life_hours": 24.0,
62
+ }
63
+
64
+
65
+ def load_config() -> dict:
66
+ """Load config from disk, creating defaults if absent."""
67
+ CONFIG_DIR.mkdir(parents=True, exist_ok=True)
68
+ if not CONFIG_PATH.exists():
69
+ save_config(DEFAULT_CONFIG)
70
+ return dict(DEFAULT_CONFIG)
71
+ with open(CONFIG_PATH) as f:
72
+ on_disk = toml.load(f)
73
+ merged = dict(DEFAULT_CONFIG)
74
+ merged.update(on_disk)
75
+ return merged
76
+
77
+
78
+ def save_config(config: dict) -> None:
79
+ """Persist config to disk."""
80
+ CONFIG_DIR.mkdir(parents=True, exist_ok=True)
81
+ with open(CONFIG_PATH, "w") as f:
82
+ toml.dump(config, f)
83
+
84
+
85
+ def get_db_path() -> Path:
86
+ CONFIG_DIR.mkdir(parents=True, exist_ok=True)
87
+ return DB_PATH
88
+
89
+
90
+ def is_demo_mode(config: dict) -> bool:
91
+ """True if Alpaca keys are not configured."""
92
+ return not (config.get("alpaca_api_key") and config.get("alpaca_api_secret"))
trading_cli/data/asset_search.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Asset search with embedding-based semantic autocomplete."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ import threading
9
+ from pathlib import Path
10
+ from typing import TYPE_CHECKING
11
+
12
+ if TYPE_CHECKING:
13
+ from trading_cli.execution.adapters.alpaca import AlpacaAdapter
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class AssetSearchEngine:
19
+ """Searchable asset index with optional semantic embeddings.
20
+
21
+ Supports:
22
+ - Symbol search (e.g., "AAPL")
23
+ - Company name search (e.g., "Apple")
24
+ - Fuzzy/partial matching (e.g., "appl" → Apple)
25
+ - Semantic search via embeddings (optional, requires sentence-transformers)
26
+ """
27
+
28
+ def __init__(self, cache_dir: Path | None = None):
29
+ self._assets: list[dict[str, str]] = []
30
+ self._symbol_index: dict[str, dict[str, str]] = {}
31
+ self._lock = threading.Lock()
32
+ self._cache_dir = cache_dir or Path.home() / ".cache" / "trading_cli"
33
+ self._cache_file = self._cache_dir / "assets.json"
34
+ self._embeddings = None
35
+ self._embedding_model = None
36
+ self._initialized = False
37
+
38
+ def load_assets(self, adapter: AlpacaAdapter) -> int:
39
+ """Load assets from adapter (with caching).
40
+
41
+ Returns:
42
+ Number of assets loaded.
43
+ """
44
+ # Try cache first
45
+ if self._load_from_cache():
46
+ logger.info("Loaded %d assets from cache", len(self._assets))
47
+ self._initialized = True
48
+ return len(self._assets)
49
+
50
+ # Fetch from adapter
51
+ try:
52
+ assets = adapter.get_all_assets()
53
+ if assets:
54
+ with self._lock:
55
+ self._assets = assets
56
+ self._symbol_index = {
57
+ asset["symbol"].upper(): asset for asset in assets
58
+ }
59
+ self._save_to_cache()
60
+ logger.info("Loaded %d assets from adapter", len(assets))
61
+ self._initialized = True
62
+ return len(assets)
63
+ except Exception as exc:
64
+ logger.warning("Failed to load assets: %s", exc)
65
+
66
+ return 0
67
+
68
+ def _load_from_cache(self) -> bool:
69
+ """Load cached assets. Returns True if successful."""
70
+ if not self._cache_file.exists():
71
+ return False
72
+ try:
73
+ data = json.loads(self._cache_file.read_text())
74
+ with self._lock:
75
+ self._assets = data["assets"]
76
+ self._symbol_index = {
77
+ asset["symbol"].upper(): asset for asset in self._assets
78
+ }
79
+ return True
80
+ except Exception as exc:
81
+ logger.warning("Cache load failed: %s", exc)
82
+ return False
83
+
84
+ def _save_to_cache(self) -> None:
85
+ """Save assets to cache."""
86
+ try:
87
+ self._cache_dir.mkdir(parents=True, exist_ok=True)
88
+ self._cache_file.write_text(
89
+ json.dumps({"assets": self._assets}, indent=2)
90
+ )
91
+ except Exception as exc:
92
+ logger.warning("Cache save failed: %s", exc)
93
+
94
+ def search(
95
+ self,
96
+ query: str,
97
+ max_results: int = 10,
98
+ use_semantic: bool = True,
99
+ ) -> list[dict[str, str]]:
100
+ """Search assets by symbol or company name.
101
+
102
+ Args:
103
+ query: Search query (symbol fragment or company name).
104
+ max_results: Maximum number of results to return.
105
+ use_semantic: Whether to use semantic embeddings if available.
106
+
107
+ Returns:
108
+ List of dicts with 'symbol', 'name', and optionally 'score'.
109
+ """
110
+ if not query.strip():
111
+ return []
112
+
113
+ query_upper = query.upper().strip()
114
+ query_lower = query.lower().strip()
115
+
116
+ results: list[dict[str, str]] = []
117
+
118
+ with self._lock:
119
+ # Exact symbol match (highest priority)
120
+ if query_upper in self._symbol_index:
121
+ asset = self._symbol_index[query_upper]
122
+ results.append({
123
+ "symbol": asset["symbol"],
124
+ "name": asset["name"],
125
+ "score": 1.0,
126
+ })
127
+ if len(results) >= max_results:
128
+ return results
129
+
130
+ # Text-based matching (symbol prefix or name substring)
131
+ for asset in self._assets:
132
+ symbol = asset["symbol"]
133
+ name = asset.get("name", "")
134
+
135
+ # Symbol starts with query
136
+ if symbol.upper().startswith(query_upper):
137
+ score = 0.9 if symbol.upper() == query_upper else 0.8
138
+ results.append({
139
+ "symbol": symbol,
140
+ "name": name,
141
+ "score": score,
142
+ })
143
+ if len(results) >= max_results:
144
+ return results
145
+
146
+ # Name contains query (case-insensitive)
147
+ if len(results) < max_results and len(query_lower) >= 2:
148
+ for asset in self._assets:
149
+ name = asset.get("name", "")
150
+ if query_lower in name.lower():
151
+ # Check not already in results
152
+ if not any(r["symbol"] == asset["symbol"] for r in results):
153
+ results.append({
154
+ "symbol": asset["symbol"],
155
+ "name": name,
156
+ "score": 0.7,
157
+ })
158
+ if len(results) >= max_results:
159
+ return results
160
+
161
+ # Semantic search (optional, for fuzzy matching)
162
+ if use_semantic and len(results) < max_results:
163
+ semantic_results = self._search_semantic(query, max_results - len(results))
164
+ # Merge, avoiding duplicates
165
+ existing_symbols = {r["symbol"] for r in results}
166
+ for sr in semantic_results:
167
+ if sr["symbol"] not in existing_symbols:
168
+ results.append(sr)
169
+ if len(results) >= max_results:
170
+ break
171
+
172
+ return results[:max_results]
173
+
174
+ def _search_semantic(
175
+ self,
176
+ query: str,
177
+ max_results: int,
178
+ ) -> list[dict[str, str]]:
179
+ """Search using semantic similarity (requires embeddings)."""
180
+ if not self._embedding_model or not self._embeddings:
181
+ return []
182
+
183
+ try:
184
+ # Encode query
185
+ query_embedding = self._embedding_model.encode(
186
+ [query],
187
+ normalize_embeddings=True,
188
+ )[0]
189
+
190
+ # Compute cosine similarity
191
+ import numpy as np
192
+ embeddings_matrix = np.array(self._embeddings)
193
+ similarities = embeddings_matrix @ query_embedding
194
+
195
+ # Get top results
196
+ top_indices = np.argsort(similarities)[::-1][:max_results]
197
+
198
+ results = []
199
+ for idx in top_indices:
200
+ if similarities[idx] < 0.3: # Minimum similarity threshold
201
+ break
202
+ asset = self._assets[idx]
203
+ results.append({
204
+ "symbol": asset["symbol"],
205
+ "name": asset["name"],
206
+ "score": float(similarities[idx]),
207
+ })
208
+
209
+ return results
210
+ except Exception as exc:
211
+ logger.warning("Semantic search failed: %s", exc)
212
+ return []
213
+
214
+ def load_embedding_model(self, model_name: str = "all-MiniLM-L6-v2"):
215
+ """Load a sentence transformer model for semantic search.
216
+
217
+ This is optional and will only be used if successfully loaded.
218
+ Falls back to text-based matching if unavailable.
219
+
220
+ Args:
221
+ model_name: Name of the sentence-transformers model to use.
222
+ Default is 'all-MiniLM-L6-v2' (80MB, fast, good quality).
223
+ """
224
+ try:
225
+ from sentence_transformers import SentenceTransformer
226
+
227
+ logger.info("Loading embedding model '%s'...", model_name)
228
+ self._embedding_model = SentenceTransformer(model_name)
229
+
230
+ # Precompute embeddings for all assets
231
+ texts = [
232
+ f"{asset['symbol']} {asset['name']}"
233
+ for asset in self._assets
234
+ ]
235
+ embeddings = self._embedding_model.encode(
236
+ texts,
237
+ normalize_embeddings=True,
238
+ show_progress_bar=False,
239
+ )
240
+ self._embeddings = embeddings.tolist()
241
+ logger.info(
242
+ "Loaded embedding model: %d assets embedded",
243
+ len(self._embeddings),
244
+ )
245
+ except ImportError:
246
+ logger.info(
247
+ "sentence-transformers not installed. "
248
+ "Install with: uv add sentence-transformers (optional)"
249
+ )
250
+ except Exception as exc:
251
+ logger.warning("Failed to load embedding model: %s", exc)
252
+
253
+ @property
254
+ def is_ready(self) -> bool:
255
+ """Whether the search engine has assets loaded."""
256
+ return self._initialized
257
+
258
+ @property
259
+ def has_semantic_search(self) -> bool:
260
+ """Whether semantic search is available."""
261
+ return self._embedding_model is not None and self._embeddings is not None
trading_cli/data/db.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SQLite database layer — schema, queries, and connection management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import sqlite3
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+
12
+ def get_connection(db_path: Path) -> sqlite3.Connection:
13
+ conn = sqlite3.connect(str(db_path), check_same_thread=False)
14
+ conn.row_factory = sqlite3.Row
15
+ conn.execute("PRAGMA journal_mode=WAL")
16
+ return conn
17
+
18
+
19
+ def init_db(db_path: Path) -> sqlite3.Connection:
20
+ """Create all tables and return an open connection."""
21
+ conn = get_connection(db_path)
22
+ conn.executescript("""
23
+ CREATE TABLE IF NOT EXISTS trades (
24
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
25
+ timestamp TEXT NOT NULL,
26
+ symbol TEXT NOT NULL,
27
+ action TEXT NOT NULL,
28
+ price REAL NOT NULL,
29
+ quantity INTEGER NOT NULL,
30
+ order_id TEXT,
31
+ reason TEXT,
32
+ pnl REAL,
33
+ portfolio_value REAL
34
+ );
35
+
36
+ CREATE TABLE IF NOT EXISTS signals (
37
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
38
+ timestamp TEXT NOT NULL,
39
+ symbol TEXT NOT NULL,
40
+ action TEXT NOT NULL,
41
+ confidence REAL,
42
+ technical_score REAL,
43
+ sentiment_score REAL,
44
+ reason TEXT,
45
+ executed INTEGER DEFAULT 0
46
+ );
47
+
48
+ CREATE TABLE IF NOT EXISTS watchlist (
49
+ symbol TEXT PRIMARY KEY,
50
+ added_at TEXT NOT NULL
51
+ );
52
+
53
+ CREATE TABLE IF NOT EXISTS sentiment_cache (
54
+ headline_hash TEXT PRIMARY KEY,
55
+ headline TEXT NOT NULL,
56
+ label TEXT NOT NULL,
57
+ score REAL NOT NULL,
58
+ cached_at TEXT NOT NULL
59
+ );
60
+
61
+ CREATE TABLE IF NOT EXISTS price_history (
62
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
63
+ symbol TEXT NOT NULL,
64
+ timestamp TEXT NOT NULL,
65
+ open REAL,
66
+ high REAL,
67
+ low REAL,
68
+ close REAL,
69
+ volume INTEGER,
70
+ UNIQUE(symbol, timestamp)
71
+ );
72
+
73
+ CREATE TABLE IF NOT EXISTS config (
74
+ key TEXT PRIMARY KEY,
75
+ value TEXT NOT NULL
76
+ );
77
+ """)
78
+ conn.commit()
79
+ return conn
80
+
81
+
82
+ # ── Trades ─────────────────────────────────────────────────────────────────────
83
+
84
+ def save_trade(
85
+ conn: sqlite3.Connection,
86
+ symbol: str,
87
+ action: str,
88
+ price: float,
89
+ quantity: int,
90
+ order_id: str | None = None,
91
+ reason: str | None = None,
92
+ pnl: float | None = None,
93
+ portfolio_value: float | None = None,
94
+ ) -> int:
95
+ ts = datetime.utcnow().isoformat()
96
+ cur = conn.execute(
97
+ """INSERT INTO trades
98
+ (timestamp, symbol, action, price, quantity, order_id, reason, pnl, portfolio_value)
99
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
100
+ (ts, symbol, action, price, quantity, order_id, reason, pnl, portfolio_value),
101
+ )
102
+ conn.commit()
103
+ return cur.lastrowid
104
+
105
+
106
+ def get_trade_history(
107
+ conn: sqlite3.Connection,
108
+ symbol: str | None = None,
109
+ action: str | None = None,
110
+ limit: int = 100,
111
+ ) -> list[dict]:
112
+ q = "SELECT * FROM trades"
113
+ params: list[Any] = []
114
+ clauses = []
115
+ if symbol:
116
+ clauses.append("symbol = ?")
117
+ params.append(symbol.upper())
118
+ if action:
119
+ clauses.append("action = ?")
120
+ params.append(action.upper())
121
+ if clauses:
122
+ q += " WHERE " + " AND ".join(clauses)
123
+ q += " ORDER BY timestamp DESC LIMIT ?"
124
+ params.append(limit)
125
+ return [dict(r) for r in conn.execute(q, params).fetchall()]
126
+
127
+
128
+ # ── Signals ────────────────────────────────────────────────────────────────────
129
+
130
+ def save_signal(
131
+ conn: sqlite3.Connection,
132
+ symbol: str,
133
+ action: str,
134
+ confidence: float,
135
+ technical_score: float,
136
+ sentiment_score: float,
137
+ reason: str,
138
+ executed: bool = False,
139
+ ) -> int:
140
+ ts = datetime.utcnow().isoformat()
141
+ cur = conn.execute(
142
+ """INSERT INTO signals
143
+ (timestamp, symbol, action, confidence, technical_score, sentiment_score, reason, executed)
144
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
145
+ (ts, symbol, action, confidence, technical_score, sentiment_score, reason, int(executed)),
146
+ )
147
+ conn.commit()
148
+ return cur.lastrowid
149
+
150
+
151
+ def get_recent_signals(conn: sqlite3.Connection, limit: int = 20) -> list[dict]:
152
+ return [
153
+ dict(r)
154
+ for r in conn.execute(
155
+ "SELECT * FROM signals ORDER BY timestamp DESC LIMIT ?", (limit,)
156
+ ).fetchall()
157
+ ]
158
+
159
+
160
+ # ── Watchlist ──────────────────────────────────────────────────────────────────
161
+
162
+ def get_watchlist(conn: sqlite3.Connection) -> list[str]:
163
+ return [r["symbol"] for r in conn.execute("SELECT symbol FROM watchlist ORDER BY symbol").fetchall()]
164
+
165
+
166
+ def add_to_watchlist(conn: sqlite3.Connection, symbol: str) -> None:
167
+ conn.execute(
168
+ "INSERT OR IGNORE INTO watchlist (symbol, added_at) VALUES (?, ?)",
169
+ (symbol.upper(), datetime.utcnow().isoformat()),
170
+ )
171
+ conn.commit()
172
+
173
+
174
+ def remove_from_watchlist(conn: sqlite3.Connection, symbol: str) -> None:
175
+ conn.execute("DELETE FROM watchlist WHERE symbol = ?", (symbol.upper(),))
176
+ conn.commit()
177
+
178
+
179
+ # ── Sentiment cache ────────────────────────────────────────────────────────────
180
+
181
+ def headline_hash(text: str) -> str:
182
+ return hashlib.md5(text.encode()).hexdigest()
183
+
184
+
185
+ def get_cached_sentiment(conn: sqlite3.Connection, text: str) -> dict | None:
186
+ h = headline_hash(text)
187
+ row = conn.execute(
188
+ "SELECT label, score FROM sentiment_cache WHERE headline_hash = ?", (h,)
189
+ ).fetchone()
190
+ return dict(row) if row else None
191
+
192
+
193
+ def cache_sentiment(conn: sqlite3.Connection, text: str, label: str, score: float) -> None:
194
+ h = headline_hash(text)
195
+ conn.execute(
196
+ """INSERT OR REPLACE INTO sentiment_cache
197
+ (headline_hash, headline, label, score, cached_at)
198
+ VALUES (?, ?, ?, ?, ?)""",
199
+ (h, text[:500], label, score, datetime.utcnow().isoformat()),
200
+ )
201
+ conn.commit()
202
+
203
+
204
+ # ── Price history ──────────────────────────────────────────────────────────────
205
+
206
+ def upsert_price_bar(
207
+ conn: sqlite3.Connection,
208
+ symbol: str,
209
+ timestamp: str,
210
+ open_: float,
211
+ high: float,
212
+ low: float,
213
+ close: float,
214
+ volume: int,
215
+ ) -> None:
216
+ conn.execute(
217
+ """INSERT OR REPLACE INTO price_history
218
+ (symbol, timestamp, open, high, low, close, volume)
219
+ VALUES (?, ?, ?, ?, ?, ?, ?)""",
220
+ (symbol, timestamp, open_, high, low, close, volume),
221
+ )
222
+ conn.commit()
223
+
224
+
225
+ def get_price_history(
226
+ conn: sqlite3.Connection, symbol: str, limit: int = 200
227
+ ) -> list[dict]:
228
+ return [
229
+ dict(r)
230
+ for r in conn.execute(
231
+ "SELECT * FROM price_history WHERE symbol = ? ORDER BY timestamp DESC LIMIT ?",
232
+ (symbol.upper(), limit),
233
+ ).fetchall()
234
+ ]
trading_cli/data/market.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Market data fetching — Alpaca historical bars with yfinance fallback."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import time
7
+ from datetime import datetime, timedelta, timezone
8
+ from typing import TYPE_CHECKING
9
+
10
+ import pandas as pd
11
+
12
+ if TYPE_CHECKING:
13
+ from trading_cli.execution.alpaca_client import AlpacaClient
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def fetch_ohlcv_alpaca(
19
+ client: "AlpacaClient",
20
+ symbol: str,
21
+ days: int = 90,
22
+ ) -> pd.DataFrame:
23
+ """Fetch OHLCV bars from Alpaca historical data API."""
24
+ try:
25
+ from alpaca.data.requests import StockBarsRequest
26
+ from alpaca.data.timeframe import TimeFrame
27
+
28
+ end = datetime.now(tz=timezone.utc)
29
+ start = end - timedelta(days=days + 10) # extra buffer for weekends
30
+
31
+ request = StockBarsRequest(
32
+ symbol_or_symbols=symbol,
33
+ timeframe=TimeFrame.Day,
34
+ start=start,
35
+ end=end,
36
+ feed="iex",
37
+ )
38
+ bars = client.historical_client.get_stock_bars(request)
39
+ df = bars.df
40
+ if isinstance(df.index, pd.MultiIndex):
41
+ df = df.xs(symbol, level=0) if symbol in df.index.get_level_values(0) else df
42
+ df.index = pd.to_datetime(df.index, utc=True)
43
+ df = df.rename(columns={"open": "Open", "high": "High", "low": "Low",
44
+ "close": "Close", "volume": "Volume"})
45
+ return df.tail(days)
46
+ except Exception as exc:
47
+ logger.warning("Alpaca OHLCV fetch failed for %s: %s — falling back to yfinance", symbol, exc)
48
+ return fetch_ohlcv_yfinance(symbol, days)
49
+
50
+
51
+ def fetch_ohlcv_yfinance(symbol: str, days: int = 90) -> pd.DataFrame:
52
+ """Fetch OHLCV bars from yfinance. Period can be long for daily interval."""
53
+ try:
54
+ import yfinance as yf
55
+ # No more 730d cap for 1d data; yfinance handles 10y+ easily for daily.
56
+ period = f"{days}d"
57
+ df = yf.download(symbol, period=period, interval="1d", progress=False, auto_adjust=True)
58
+ if df.empty:
59
+ return pd.DataFrame()
60
+
61
+ # Flatten MultiIndex columns if present (common in newer yfinance versions)
62
+ if isinstance(df.columns, pd.MultiIndex):
63
+ df.columns = df.columns.get_level_values(0)
64
+
65
+ return df.tail(days)
66
+ except Exception as exc:
67
+ logger.error("yfinance fetch failed for %s: %s", symbol, exc)
68
+ return pd.DataFrame()
69
+
70
+
71
+ def get_latest_quote_alpaca(client: "AlpacaClient", symbol: str) -> float | None:
72
+ """Get latest trade price from Alpaca."""
73
+ try:
74
+ from alpaca.data.requests import StockLatestTradeRequest
75
+
76
+ req = StockLatestTradeRequest(symbol_or_symbols=symbol, feed="iex")
77
+ trades = client.historical_client.get_stock_latest_trade(req)
78
+ return float(trades[symbol].price)
79
+ except Exception as exc:
80
+ logger.warning("Alpaca latest quote failed for %s: %s", symbol, exc)
81
+ return None
82
+
83
+
84
+ def get_latest_quote_yfinance(symbol: str) -> float | None:
85
+ """Get latest price from yfinance (free tier fallback)."""
86
+ try:
87
+ import yfinance as yf
88
+ ticker = yf.Ticker(symbol)
89
+ info = ticker.fast_info
90
+ price = getattr(info, "last_price", None) or getattr(info, "regularMarketPrice", None)
91
+ if price:
92
+ return float(price)
93
+ hist = ticker.history(period="2d", interval="1d")
94
+ if not hist.empty:
95
+ return float(hist["Close"].iloc[-1])
96
+ return None
97
+ except Exception as exc:
98
+ logger.warning("yfinance latest quote failed for %s: %s", symbol, exc)
99
+ return None
100
+
101
+
102
+ def get_latest_quotes_batch(
103
+ client: "AlpacaClient | None",
104
+ symbols: list[str],
105
+ ) -> dict[str, float]:
106
+ """Return {symbol: price} dict for multiple symbols."""
107
+ prices: dict[str, float] = {}
108
+ if client and not client.demo_mode:
109
+ try:
110
+ from alpaca.data.requests import StockLatestTradeRequest
111
+
112
+ req = StockLatestTradeRequest(symbol_or_symbols=symbols, feed="iex")
113
+ trades = client.historical_client.get_stock_latest_trade(req)
114
+ for sym, trade in trades.items():
115
+ prices[sym] = float(trade.price)
116
+ return prices
117
+ except Exception as exc:
118
+ logger.warning("Batch Alpaca quote failed: %s — falling back", exc)
119
+
120
+ # yfinance fallback
121
+ for sym in symbols:
122
+ price = get_latest_quote_yfinance(sym)
123
+ if price:
124
+ prices[sym] = price
125
+ time.sleep(0.2) # avoid hammering
126
+ return prices
trading_cli/data/news.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """News headline fetching — Alpaca News API (historical) with yfinance fallback."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from datetime import datetime, timedelta, timezone
7
+
8
+ import pandas as pd
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ # ── Alpaca News API (historical, date-aware) ───────────────────────────────────
14
+
15
+ def fetch_headlines_alpaca(
16
+ api_key: str,
17
+ api_secret: str,
18
+ symbol: str,
19
+ start: datetime | None = None,
20
+ end: datetime | None = None,
21
+ max_articles: int = 50,
22
+ ) -> list[tuple[str, float]]:
23
+ """Fetch headlines via Alpaca News API with optional date range.
24
+
25
+ Returns list of (headline: str, unix_timestamp: float) tuples.
26
+ Supports historical backtesting by specifying start/end dates.
27
+ """
28
+ if not api_key or not api_secret:
29
+ return []
30
+ try:
31
+ from alpaca.data.historical.news import NewsClient
32
+ from alpaca.data.requests import NewsRequest
33
+
34
+ client = NewsClient(api_key=api_key, secret_key=api_secret)
35
+
36
+ now = datetime.now(tz=timezone.utc)
37
+ if end is None:
38
+ end = now
39
+ if start is None:
40
+ start = end - timedelta(days=7)
41
+
42
+ request = NewsRequest(
43
+ symbols=symbol,
44
+ start=start,
45
+ end=end,
46
+ limit=min(max_articles, 100), # Alpaca max is 100 per page
47
+ )
48
+ response = client.get_news(request)
49
+ items = getattr(response, "news", response) if response else []
50
+
51
+ headlines: list[tuple[str, float]] = []
52
+ for item in items:
53
+ title = getattr(item, "headline", "") or getattr(item, "title", "")
54
+ if not title:
55
+ continue
56
+ created = getattr(item, "created_at", None) or getattr(item, "updated_at", None)
57
+ if created:
58
+ if isinstance(created, str):
59
+ ts = pd.Timestamp(created).timestamp()
60
+ elif isinstance(created, (int, float)):
61
+ ts = float(created)
62
+ else:
63
+ ts = pd.Timestamp(created).timestamp()
64
+ else:
65
+ ts = now.timestamp()
66
+ headlines.append((title, float(ts)))
67
+
68
+ logger.debug("Alpaca News: got %d headlines for %s (%s to %s)",
69
+ len(headlines), symbol, start, end)
70
+ return headlines
71
+ except Exception as exc:
72
+ logger.warning("Alpaca News fetch failed for %s: %s", symbol, exc)
73
+ return []
74
+
75
+
76
+ def fetch_headlines_yfinance(symbol: str, max_articles: int = 20) -> list[str]:
77
+ """Fetch headlines from yfinance built-in news feed."""
78
+ try:
79
+ import yfinance as yf
80
+
81
+ ticker = yf.Ticker(symbol)
82
+ news = ticker.news or []
83
+ headlines = []
84
+ for item in news[:max_articles]:
85
+ title = item.get("title") or (item.get("content", {}) or {}).get("title", "")
86
+ if title:
87
+ headlines.append(title)
88
+ logger.debug("yfinance news: got %d headlines for %s", len(headlines), symbol)
89
+ return headlines
90
+ except Exception as exc:
91
+ logger.warning("yfinance news failed for %s: %s", symbol, exc)
92
+ return []
93
+
94
+
95
+ # ── Unified fetcher ───────────────────────────────────────────────────────────
96
+
97
+ def fetch_headlines(
98
+ symbol: str,
99
+ max_articles: int = 20,
100
+ ) -> list[str]:
101
+ """Fetch headlines, using yfinance (Alpaca news returns tuples, not plain strings)."""
102
+ return fetch_headlines_yfinance(symbol, max_articles)
103
+
104
+
105
+ def fetch_headlines_with_timestamps(
106
+ symbol: str,
107
+ days_ago: int = 0,
108
+ alpaca_key: str = "",
109
+ alpaca_secret: str = "",
110
+ max_articles: int = 50,
111
+ ) -> list[tuple[str, float]]:
112
+ """Fetch headlines with Unix timestamps for temporal weighting.
113
+
114
+ For backtesting: pass days_ago > 0 to get news from a specific historical date.
115
+ Returns list of (headline: str, unix_timestamp: float) tuples.
116
+
117
+ Priority: Alpaca (supports historical dates) > yfinance.
118
+ """
119
+ now = datetime.now(tz=timezone.utc)
120
+ target_date = now - timedelta(days=days_ago)
121
+
122
+ # Try Alpaca first (only supports historical if API keys are set)
123
+ if alpaca_key and alpaca_secret:
124
+ # Alpaca can fetch news for any historical date in range
125
+ day_start = target_date.replace(hour=0, minute=0, second=0, microsecond=0)
126
+ day_end = day_start.replace(hour=23, minute=59, second=59)
127
+ headlines = fetch_headlines_alpaca(alpaca_key, alpaca_secret, symbol,
128
+ start=day_start, end=day_end,
129
+ max_articles=max_articles)
130
+ if headlines:
131
+ return headlines
132
+
133
+ # yfinance fallback (no timestamp info, approximate)
134
+ headlines = fetch_headlines_yfinance(symbol, max_articles)
135
+ now_ts = now.timestamp()
136
+ return [(h, now_ts - (i * 3600)) for i, h in enumerate(headlines)]
trading_cli/execution/adapter_factory.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapter factory — creates the appropriate adapter from config."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import TYPE_CHECKING
7
+
8
+ from trading_cli.execution.adapters import (
9
+ TradingAdapter,
10
+ create_adapter,
11
+ list_adapters,
12
+ )
13
+
14
+ if TYPE_CHECKING:
15
+ pass
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def create_trading_adapter(config: dict) -> TradingAdapter:
21
+ """Create a trading adapter based on config.
22
+
23
+ Priority:
24
+ 1. If Alpaca keys are set → AlpacaAdapter
25
+ 2. Otherwise → YFinanceAdapter (demo mode)
26
+
27
+ You can override by setting `adapter_id` in config to:
28
+ - 'alpaca': Force Alpaca (will fallback to demo if no keys)
29
+ - 'yfinance': Force yFinance demo
30
+ - 'binance': Binance crypto (requires ccxt)
31
+ - 'kraken': Kraken crypto (requires ccxt)
32
+ """
33
+ adapter_id = config.get("adapter_id", None)
34
+
35
+ if adapter_id is None:
36
+ # Auto-detect based on available keys
37
+ if config.get("alpaca_api_key") and config.get("alpaca_api_secret"):
38
+ adapter_id = "alpaca"
39
+ else:
40
+ adapter_id = "yfinance"
41
+
42
+ try:
43
+ adapter = create_adapter(adapter_id, config)
44
+ logger.info("Created adapter: %s (demo=%s)", adapter.adapter_id, adapter.is_demo_mode)
45
+ return adapter
46
+ except ValueError as exc:
47
+ logger.error("Failed to create adapter '%s': %s", adapter_id, exc)
48
+ logger.info("Available adapters: %s", list_adapters())
49
+ # Fallback to yfinance demo
50
+ return create_adapter("yfinance", config)
trading_cli/execution/adapters/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Trading platform adapters — unified interface for different exchanges."""
2
+
3
+ from trading_cli.execution.adapters.base import (
4
+ AccountInfo,
5
+ MarketClock,
6
+ OrderResult,
7
+ Position,
8
+ TradingAdapter,
9
+ )
10
+ from trading_cli.execution.adapters.registry import (
11
+ create_adapter,
12
+ get_adapter,
13
+ list_adapters,
14
+ register_adapter,
15
+ )
16
+
17
+ # Import all adapter implementations to trigger registration
18
+ from trading_cli.execution.adapters.alpaca import AlpacaAdapter
19
+ from trading_cli.execution.adapters.yfinance import YFinanceAdapter
20
+ from trading_cli.execution.adapters.binance import BinanceAdapter
21
+ from trading_cli.execution.adapters.kraken import KrakenAdapter
22
+
23
+ __all__ = [
24
+ "TradingAdapter",
25
+ "AccountInfo",
26
+ "MarketClock",
27
+ "OrderResult",
28
+ "Position",
29
+ "create_adapter",
30
+ "get_adapter",
31
+ "list_adapters",
32
+ "register_adapter",
33
+ "AlpacaAdapter",
34
+ "YFinanceAdapter",
35
+ "BinanceAdapter",
36
+ "KrakenAdapter",
37
+ ]
trading_cli/execution/adapters/alpaca.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Alpaca adapter — real Alpaca API for stocks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from datetime import datetime, timedelta, timezone
7
+
8
+ import pandas as pd
9
+
10
+ from trading_cli.execution.adapters.base import (
11
+ AccountInfo,
12
+ MarketClock,
13
+ OrderResult,
14
+ Position,
15
+ TradingAdapter,
16
+ )
17
+ from trading_cli.execution.adapters.registry import register_adapter
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ @register_adapter
23
+ class AlpacaAdapter(TradingAdapter):
24
+ """Alpaca Markets adapter for US equities (paper & live trading)."""
25
+
26
+ def __init__(self, config: dict) -> None:
27
+ self._config = config
28
+ self._api_key = config.get("alpaca_api_key", "")
29
+ self._api_secret = config.get("alpaca_api_secret", "")
30
+ self._paper = config.get("alpaca_paper", True)
31
+ self._demo = not (self._api_key and self._api_secret)
32
+
33
+ if self._demo:
34
+ logger.info("AlpacaAdapter: no API keys found, running in demo mode")
35
+ return
36
+
37
+ try:
38
+ from alpaca.trading.client import TradingClient
39
+ from alpaca.data.historical import StockHistoricalDataClient
40
+ from alpaca.data.historical.news import NewsClient
41
+
42
+ self._trading_client = TradingClient(
43
+ api_key=self._api_key,
44
+ secret_key=self._api_secret,
45
+ paper=self._paper,
46
+ )
47
+ self._historical_client = StockHistoricalDataClient(
48
+ api_key=self._api_key,
49
+ secret_key=self._api_secret,
50
+ )
51
+ self._news_client = NewsClient(
52
+ api_key=self._api_key,
53
+ secret_key=self._api_secret,
54
+ )
55
+ logger.info("AlpacaAdapter connected (paper=%s)", self._paper)
56
+ except ImportError as exc:
57
+ raise RuntimeError("alpaca-py not installed. Run: uv add alpaca-py") from exc
58
+ except Exception as exc:
59
+ logger.error("Failed to connect to Alpaca: %s", exc)
60
+ self._demo = True
61
+
62
+ @property
63
+ def adapter_id(self) -> str:
64
+ return "alpaca"
65
+
66
+ @property
67
+ def supports_paper_trading(self) -> bool:
68
+ return True
69
+
70
+ @property
71
+ def is_demo_mode(self) -> bool:
72
+ return self._demo
73
+
74
+ # ── Account & Positions ───────────────────────────────────────────────────
75
+
76
+ def get_account(self) -> AccountInfo:
77
+ if self._demo:
78
+ return AccountInfo(
79
+ equity=100000.0,
80
+ cash=100000.0,
81
+ buying_power=400000.0,
82
+ portfolio_value=100000.0,
83
+ )
84
+ acct = self._trading_client.get_account()
85
+ return AccountInfo(
86
+ equity=float(acct.equity),
87
+ cash=float(acct.cash),
88
+ buying_power=float(acct.buying_power),
89
+ portfolio_value=float(acct.portfolio_value),
90
+ )
91
+
92
+ def get_positions(self) -> list[Position]:
93
+ if self._demo:
94
+ return []
95
+ raw = self._trading_client.get_all_positions()
96
+ out = []
97
+ for p in raw:
98
+ out.append(
99
+ Position(
100
+ symbol=p.symbol,
101
+ qty=float(p.qty),
102
+ avg_entry_price=float(p.avg_entry_price),
103
+ current_price=float(p.current_price),
104
+ unrealized_pl=float(p.unrealized_pl),
105
+ unrealized_plpc=float(p.unrealized_plpc),
106
+ market_value=float(p.market_value),
107
+ side=str(p.side),
108
+ )
109
+ )
110
+ return out
111
+
112
+ # ── Orders ───────────────────────────────────────────────────────────────
113
+
114
+ def submit_market_order(self, symbol: str, qty: int, side: str) -> OrderResult:
115
+ if self._demo:
116
+ return OrderResult(
117
+ order_id=f"DEMO-{datetime.now().timestamp()}",
118
+ symbol=symbol,
119
+ action=side,
120
+ qty=qty,
121
+ status="filled",
122
+ filled_price=100.0, # Mock price
123
+ )
124
+
125
+ from alpaca.trading.requests import MarketOrderRequest
126
+ from alpaca.trading.enums import OrderSide, TimeInForce
127
+
128
+ order_side = OrderSide.BUY if side.upper() == "BUY" else OrderSide.SELL
129
+ req = MarketOrderRequest(
130
+ symbol=symbol,
131
+ qty=qty,
132
+ side=order_side,
133
+ time_in_force=TimeInForce.DAY,
134
+ )
135
+ try:
136
+ order = self._trading_client.submit_order(order_data=req)
137
+ filled_price = float(order.filled_avg_price) if order.filled_avg_price else None
138
+ return OrderResult(
139
+ order_id=str(order.id),
140
+ symbol=symbol,
141
+ action=side,
142
+ qty=qty,
143
+ status=str(order.status),
144
+ filled_price=filled_price,
145
+ )
146
+ except Exception as exc:
147
+ logger.error("Order submission failed for %s %s %d: %s", side, symbol, qty, exc)
148
+ raise
149
+
150
+ def close_position(self, symbol: str) -> OrderResult | None:
151
+ if self._demo:
152
+ return None
153
+ try:
154
+ response = self._trading_client.close_position(symbol)
155
+ return OrderResult(
156
+ order_id=str(response.id),
157
+ symbol=symbol,
158
+ action="SELL",
159
+ qty=int(float(response.qty or 0)),
160
+ status=str(response.status),
161
+ )
162
+ except Exception as exc:
163
+ logger.error("Close position failed for %s: %s", symbol, exc)
164
+ return None
165
+
166
+ # ── Market Data ───────────────────────────────────────────────────────────
167
+
168
+ def fetch_ohlcv(self, symbol: str, days: int = 90) -> pd.DataFrame:
169
+ if self._demo:
170
+ # Fallback to yfinance in demo mode
171
+ from trading_cli.data.market import fetch_ohlcv_yfinance
172
+ return fetch_ohlcv_yfinance(symbol, days)
173
+
174
+ try:
175
+ from alpaca.data.requests import StockBarsRequest
176
+ from alpaca.data.timeframe import TimeFrame
177
+
178
+ end = datetime.now(tz=timezone.utc)
179
+ start = end - timedelta(days=days + 10) # extra buffer for weekends
180
+
181
+ request = StockBarsRequest(
182
+ symbol_or_symbols=symbol,
183
+ timeframe=TimeFrame.Day,
184
+ start=start,
185
+ end=end,
186
+ feed="iex",
187
+ )
188
+ bars = self._historical_client.get_stock_bars(request)
189
+ df = bars.df
190
+ if isinstance(df.index, pd.MultiIndex):
191
+ df = df.xs(symbol, level=0) if symbol in df.index.get_level_values(0) else df
192
+ df.index = pd.to_datetime(df.index, utc=True)
193
+ df = df.rename(columns={"open": "Open", "high": "High", "low": "Low",
194
+ "close": "Close", "volume": "Volume"})
195
+ return df.tail(days)
196
+ except Exception as exc:
197
+ logger.warning("Alpaca OHLCV fetch failed for %s: %s — falling back to yfinance", symbol, exc)
198
+ from trading_cli.data.market import fetch_ohlcv_yfinance
199
+ return fetch_ohlcv_yfinance(symbol, days)
200
+
201
+ def get_latest_quote(self, symbol: str) -> float | None:
202
+ if self._demo:
203
+ return None
204
+ try:
205
+ from alpaca.data.requests import StockLatestTradeRequest
206
+
207
+ req = StockLatestTradeRequest(symbol_or_symbols=symbol, feed="iex")
208
+ trades = self._historical_client.get_stock_latest_trade(req)
209
+ return float(trades[symbol].price)
210
+ except Exception as exc:
211
+ logger.warning("Alpaca latest quote failed for %s: %s", symbol, exc)
212
+ return None
213
+
214
+ def get_latest_quotes_batch(self, symbols: list[str]) -> dict[str, float]:
215
+ if self._demo:
216
+ return {}
217
+ try:
218
+ from alpaca.data.requests import StockLatestTradeRequest
219
+
220
+ req = StockLatestTradeRequest(symbol_or_symbols=symbols, feed="iex")
221
+ trades = self._historical_client.get_stock_latest_trade(req)
222
+ return {sym: float(trade.price) for sym, trade in trades.items()}
223
+ except Exception as exc:
224
+ logger.warning("Batch Alpaca quote failed: %s", exc)
225
+ return {}
226
+
227
+ # ── Market Info ───────────────────────────────────────────────────────────
228
+
229
+ def get_market_clock(self) -> MarketClock:
230
+ if self._demo:
231
+ now = datetime.now(tz=timezone.utc)
232
+ hour_et = (now.hour - 5) % 24
233
+ is_open = now.weekday() < 5 and 9 <= hour_et < 16
234
+ return MarketClock(
235
+ is_open=is_open,
236
+ next_open="09:30 ET",
237
+ next_close="16:00 ET",
238
+ )
239
+ try:
240
+ clock = self._trading_client.get_clock()
241
+ return MarketClock(
242
+ is_open=clock.is_open,
243
+ next_open=str(clock.next_open),
244
+ next_close=str(clock.next_close),
245
+ )
246
+ except Exception as exc:
247
+ logger.warning("get_market_clock failed: %s", exc)
248
+ return MarketClock(is_open=False, next_open="Unknown", next_close="Unknown")
249
+
250
+ # ── News ──────────────────────────────────────────────────────────────────
251
+
252
+ def fetch_news(self, symbol: str, max_articles: int = 50,
253
+ days_ago: int = 0) -> list[tuple[str, float]]:
254
+ if self._demo or not hasattr(self, '_news_client') or self._news_client is None:
255
+ return []
256
+
257
+ try:
258
+ from alpaca.data.requests import NewsRequest
259
+
260
+ now = datetime.now(tz=timezone.utc)
261
+ target_date = now - timedelta(days=days_ago)
262
+ day_start = target_date.replace(hour=0, minute=0, second=0, microsecond=0)
263
+ day_end = target_date.replace(hour=23, minute=59, second=59)
264
+
265
+ request = NewsRequest(
266
+ symbols=symbol,
267
+ start=day_start,
268
+ end=day_end,
269
+ limit=min(max_articles, 100),
270
+ )
271
+ response = self._news_client.get_news(request)
272
+ items = getattr(response, "news", response) if response else []
273
+
274
+ headlines: list[tuple[str, float]] = []
275
+ for item in items:
276
+ title = getattr(item, "headline", "") or getattr(item, "title", "")
277
+ if not title:
278
+ continue
279
+ created = getattr(item, "created_at", None) or getattr(item, "updated_at", None)
280
+ if created:
281
+ import pandas as pd
282
+ ts = pd.Timestamp(created).timestamp() if isinstance(created, str) else float(created)
283
+ else:
284
+ ts = now.timestamp()
285
+ headlines.append((title, float(ts)))
286
+
287
+ return headlines
288
+ except Exception as exc:
289
+ logger.warning("Alpaca news fetch failed for %s: %s", symbol, exc)
290
+ return []
291
+
292
+ # ── Asset Search ──────────────────────────────────────────────────────────
293
+
294
+ def get_all_assets(self) -> list[dict[str, str]]:
295
+ """Fetch all available assets with their symbols and company names.
296
+
297
+ Returns:
298
+ List of dicts with 'symbol' and 'name' keys.
299
+ """
300
+ if self._demo:
301
+ # Return a basic hardcoded list for demo mode
302
+ return [
303
+ {"symbol": "AAPL", "name": "Apple Inc."},
304
+ {"symbol": "TSLA", "name": "Tesla Inc."},
305
+ {"symbol": "NVDA", "name": "NVIDIA Corporation"},
306
+ {"symbol": "MSFT", "name": "Microsoft Corporation"},
307
+ {"symbol": "AMZN", "name": "Amazon.com Inc."},
308
+ {"symbol": "GOOGL", "name": "Alphabet Inc. Class A"},
309
+ {"symbol": "META", "name": "Meta Platforms Inc."},
310
+ {"symbol": "SPY", "name": "SPDR S&P 500 ETF Trust"},
311
+ ]
312
+
313
+ try:
314
+ from alpaca.trading.requests import GetAssetsRequest
315
+ from alpaca.trading.enums import AssetStatus, AssetClass
316
+
317
+ # Get all active US equity assets
318
+ request = GetAssetsRequest(
319
+ status=AssetStatus.ACTIVE,
320
+ asset_class=AssetClass.US_EQUITY,
321
+ )
322
+ assets = self._trading_client.get_all_assets(request)
323
+
324
+ return [
325
+ {"symbol": asset.symbol, "name": asset.name}
326
+ for asset in assets
327
+ if asset.tradable
328
+ ]
329
+ except Exception as exc:
330
+ logger.warning("Failed to fetch assets: %s", exc)
331
+ return []
trading_cli/execution/adapters/base.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base adapter interface — all exchange adapters must implement this."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from abc import ABC, abstractmethod
7
+ from dataclasses import dataclass, field
8
+ from typing import Any
9
+
10
+ import pandas as pd
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ @dataclass
16
+ class Position:
17
+ """Unified position object across all exchanges."""
18
+
19
+ symbol: str
20
+ qty: float
21
+ avg_entry_price: float
22
+ current_price: float
23
+ unrealized_pl: float
24
+ unrealized_plpc: float
25
+ market_value: float
26
+ side: str = "long"
27
+
28
+
29
+ @dataclass
30
+ class AccountInfo:
31
+ """Unified account info across all exchanges."""
32
+
33
+ equity: float
34
+ cash: float
35
+ buying_power: float
36
+ portfolio_value: float
37
+
38
+
39
+ @dataclass
40
+ class OrderResult:
41
+ """Unified order result across all exchanges."""
42
+
43
+ order_id: str
44
+ symbol: str
45
+ action: str # BUY or SELL
46
+ qty: int
47
+ status: str # filled, rejected, pending, etc.
48
+ filled_price: float | None = None
49
+
50
+
51
+ @dataclass
52
+ class MarketClock:
53
+ """Market hours info."""
54
+
55
+ is_open: bool
56
+ next_open: str
57
+ next_close: str
58
+
59
+
60
+ class TradingAdapter(ABC):
61
+ """Abstract base class for all trading platform adapters.
62
+
63
+ Implement this class to add support for new exchanges (Binance, Kraken, etc.).
64
+ Each adapter handles:
65
+ - Account info retrieval
66
+ - Position management
67
+ - Order execution
68
+ - Market data (OHLCV, quotes)
69
+ - Market clock
70
+ """
71
+
72
+ @property
73
+ @abstractmethod
74
+ def adapter_id(self) -> str:
75
+ """Unique identifier for this adapter (e.g., 'alpaca', 'binance', 'kraken')."""
76
+ ...
77
+
78
+ @property
79
+ @abstractmethod
80
+ def supports_paper_trading(self) -> bool:
81
+ """Whether this adapter supports paper/demo trading."""
82
+ ...
83
+
84
+ @property
85
+ @abstractmethod
86
+ def is_demo_mode(self) -> bool:
87
+ """True if running in demo/mock mode (no real API connection)."""
88
+ ...
89
+
90
+ # ── Account & Positions ───────────────────────────────────────────────────
91
+
92
+ @abstractmethod
93
+ def get_account(self) -> AccountInfo:
94
+ """Get account balance and buying power."""
95
+ ...
96
+
97
+ @abstractmethod
98
+ def get_positions(self) -> list[Position]:
99
+ """Get all open positions."""
100
+ ...
101
+
102
+ # ── Orders ────────────────────────────────────────────────────────────────
103
+
104
+ @abstractmethod
105
+ def submit_market_order(self, symbol: str, qty: int, side: str) -> OrderResult:
106
+ """Submit a market order.
107
+
108
+ Args:
109
+ symbol: Trading symbol (e.g., 'AAPL', 'BTC/USD').
110
+ qty: Number of shares/units.
111
+ side: 'BUY' or 'SELL'.
112
+
113
+ Returns:
114
+ OrderResult with status and fill details.
115
+ """
116
+ ...
117
+
118
+ @abstractmethod
119
+ def close_position(self, symbol: str) -> OrderResult | None:
120
+ """Close an existing position at market price.
121
+
122
+ Returns None if no position exists for the symbol.
123
+ """
124
+ ...
125
+
126
+ # ── Market Data ───────────────────────────────────────────────────────────
127
+
128
+ @abstractmethod
129
+ def fetch_ohlcv(self, symbol: str, days: int = 90) -> pd.DataFrame:
130
+ """Fetch historical OHLCV bars.
131
+
132
+ Returns DataFrame with columns: Open, High, Low, Close, Volume.
133
+ Index should be datetime.
134
+ """
135
+ ...
136
+
137
+ @abstractmethod
138
+ def get_latest_quote(self, symbol: str) -> float | None:
139
+ """Get latest trade price for a symbol."""
140
+ ...
141
+
142
+ def get_latest_quotes_batch(self, symbols: list[str]) -> dict[str, float]:
143
+ """Get latest prices for multiple symbols (batch optimized).
144
+
145
+ Override if the exchange supports batch requests.
146
+ Default implementation calls get_latest_quote for each symbol.
147
+ """
148
+ prices: dict[str, float] = {}
149
+ for sym in symbols:
150
+ price = self.get_latest_quote(sym)
151
+ if price is not None:
152
+ prices[sym] = price
153
+ return prices
154
+
155
+ # ── Market Info ───────────────────────────────────────────────────────────
156
+
157
+ @abstractmethod
158
+ def get_market_clock(self) -> MarketClock:
159
+ """Get market open/closed status and next open/close times."""
160
+ ...
161
+
162
+ # ── News (optional) ───────────────────────────────────────────────────────
163
+
164
+ def fetch_news(self, symbol: str, max_articles: int = 50,
165
+ days_ago: int = 0) -> list[tuple[str, float]]:
166
+ """Fetch news headlines with timestamps.
167
+
168
+ Returns list of (headline, unix_timestamp) tuples.
169
+ Override if the exchange provides news data.
170
+ Default returns empty list.
171
+ """
172
+ return []
173
+
174
+ # ── Utilities ─────────────────────────────────────────────────────────────
175
+
176
+ def __repr__(self) -> str:
177
+ return f"<{self.__class__.__name__} adapter_id={self.adapter_id} demo={self.is_demo_mode}>"
trading_cli/execution/adapters/binance.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Binance adapter stub — crypto trading via Binance API.
2
+
3
+ This is a stub implementation. To enable:
4
+ 1. Install ccxt: `uv add ccxt`
5
+ 2. Add your Binance API keys to config
6
+ 3. Implement the TODO sections below
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import logging
12
+ from datetime import datetime, timedelta, timezone
13
+
14
+ import pandas as pd
15
+
16
+ from trading_cli.execution.adapters.base import (
17
+ AccountInfo,
18
+ MarketClock,
19
+ OrderResult,
20
+ Position,
21
+ TradingAdapter,
22
+ )
23
+ from trading_cli.execution.adapters.registry import register_adapter
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @register_adapter
29
+ class BinanceAdapter(TradingAdapter):
30
+ """Binance adapter for cryptocurrency trading.
31
+
32
+ Requires: ccxt library (`uv add ccxt`)
33
+ Config keys:
34
+ binance_api_key: Your Binance API key
35
+ binance_api_secret: Your Binance API secret
36
+ binance_sandbox: Use sandbox/testnet (default: False)
37
+ """
38
+
39
+ def __init__(self, config: dict) -> None:
40
+ self._config = config
41
+ self._api_key = config.get("binance_api_key", "")
42
+ self._api_secret = config.get("binance_api_secret", "")
43
+ self._sandbox = config.get("binance_sandbox", False)
44
+ self._demo = not (self._api_key and self._api_secret)
45
+
46
+ if self._demo:
47
+ logger.info("BinanceAdapter: no API keys found, stub mode only")
48
+ return
49
+
50
+ try:
51
+ import ccxt
52
+ self._exchange = ccxt.binance({
53
+ "apiKey": self._api_key,
54
+ "secret": self._api_secret,
55
+ "enableRateLimit": True,
56
+ })
57
+ if self._sandbox:
58
+ self._exchange.set_sandbox_mode(True)
59
+ logger.info("BinanceAdapter connected (sandbox=%s)", self._sandbox)
60
+ except ImportError:
61
+ logger.warning("ccxt not installed. Run: uv add ccxt")
62
+ self._demo = True
63
+ self._exchange = None
64
+ except Exception as exc:
65
+ logger.error("Failed to connect to Binance: %s", exc)
66
+ self._demo = True
67
+ self._exchange = None
68
+
69
+ @property
70
+ def adapter_id(self) -> str:
71
+ return "binance"
72
+
73
+ @property
74
+ def supports_paper_trading(self) -> bool:
75
+ return self._sandbox # Binance testnet
76
+
77
+ @property
78
+ def is_demo_mode(self) -> bool:
79
+ return self._demo
80
+
81
+ # ── Account & Positions ───────────────────────────────────────────────────
82
+
83
+ def get_account(self) -> AccountInfo:
84
+ if self._demo or not self._exchange:
85
+ return AccountInfo(
86
+ equity=100000.0,
87
+ cash=100000.0,
88
+ buying_power=100000.0,
89
+ portfolio_value=100000.0,
90
+ )
91
+ # TODO: Implement real account fetch using self._exchange.fetch_balance()
92
+ balance = self._exchange.fetch_balance()
93
+ # Extract USDT balance as cash equivalent
94
+ cash = float(balance.get("USDT", {}).get("free", 0))
95
+ return AccountInfo(
96
+ equity=cash, # Simplified
97
+ cash=cash,
98
+ buying_power=cash,
99
+ portfolio_value=cash,
100
+ )
101
+
102
+ def get_positions(self) -> list[Position]:
103
+ if self._demo or not self._exchange:
104
+ return []
105
+ # TODO: Implement real position fetch
106
+ # For crypto, positions are balances with non-zero amounts
107
+ positions = []
108
+ balance = self._exchange.fetch_balance()
109
+ for currency, amount_info in balance.items():
110
+ if isinstance(amount_info, dict) and amount_info.get("total", 0) > 0:
111
+ if currency in ("free", "used", "total", "info"):
112
+ continue
113
+ total = amount_info.get("total", 0)
114
+ positions.append(
115
+ Position(
116
+ symbol=f"{currency}/USDT",
117
+ qty=total,
118
+ avg_entry_price=0.0, # TODO: Track entry prices
119
+ current_price=0.0, # TODO: Fetch current price
120
+ unrealized_pl=0.0,
121
+ unrealized_plpc=0.0,
122
+ market_value=0.0,
123
+ side="long",
124
+ )
125
+ )
126
+ return positions
127
+
128
+ # ── Orders ───────────────────────────────────────────────────────────────
129
+
130
+ def submit_market_order(self, symbol: str, qty: int, side: str) -> OrderResult:
131
+ if self._demo or not self._exchange:
132
+ return OrderResult(
133
+ order_id=f"BINANCE-DEMO-{datetime.now().timestamp()}",
134
+ symbol=symbol,
135
+ action=side,
136
+ qty=qty,
137
+ status="filled",
138
+ filled_price=0.0,
139
+ )
140
+ # TODO: Implement real order submission
141
+ try:
142
+ # Convert to ccxt format: 'BTC/USDT'
143
+ order = self._exchange.create_market_order(symbol, side.lower(), qty)
144
+ return OrderResult(
145
+ order_id=order.get("id", "unknown"),
146
+ symbol=symbol,
147
+ action=side,
148
+ qty=qty,
149
+ status=order.get("status", "filled"),
150
+ filled_price=float(order.get("average") or order.get("price") or 0),
151
+ )
152
+ except Exception as exc:
153
+ logger.error("Binance order failed for %s %s %d: %s", side, symbol, qty, exc)
154
+ raise
155
+
156
+ def close_position(self, symbol: str) -> OrderResult | None:
157
+ if self._demo or not self._exchange:
158
+ return None
159
+ # TODO: Implement position close
160
+ # Need to look up current position qty and sell all
161
+ return None
162
+
163
+ # ── Market Data ───────────────────────────────────────────────────────────
164
+
165
+ def fetch_ohlcv(self, symbol: str, days: int = 90) -> pd.DataFrame:
166
+ if self._demo or not self._exchange:
167
+ return pd.DataFrame()
168
+ try:
169
+ # Binance uses 'BTC/USDT' format
170
+ ohlcv = self._exchange.fetch_ohlcv(symbol, timeframe="1d", limit=days)
171
+ df = pd.DataFrame(
172
+ ohlcv,
173
+ columns=["timestamp", "Open", "High", "Low", "Close", "Volume"],
174
+ )
175
+ df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms", utc=True)
176
+ df.set_index("timestamp", inplace=True)
177
+ return df
178
+ except Exception as exc:
179
+ logger.warning("Binance OHLCV fetch failed for %s: %s", symbol, exc)
180
+ return pd.DataFrame()
181
+
182
+ def get_latest_quote(self, symbol: str) -> float | None:
183
+ if self._demo or not self._exchange:
184
+ return None
185
+ try:
186
+ ticker = self._exchange.fetch_ticker(symbol)
187
+ return float(ticker.get("last") or 0)
188
+ except Exception as exc:
189
+ logger.warning("Binance quote failed for %s: %s", symbol, exc)
190
+ return None
191
+
192
+ # ── Market Info ───────────────────────────────────────────────────────────
193
+
194
+ def get_market_clock(self) -> MarketClock:
195
+ # Crypto markets are 24/7
196
+ return MarketClock(
197
+ is_open=True,
198
+ next_open="24/7",
199
+ next_close="24/7",
200
+ )
201
+
202
+ # ── News ──────────────────────────────────────────────────────────────────
203
+
204
+ def fetch_news(self, symbol: str, max_articles: int = 50,
205
+ days_ago: int = 0) -> list[tuple[str, float]]:
206
+ # Binance doesn't provide news via API
207
+ return []
trading_cli/execution/adapters/kraken.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Kraken adapter stub — crypto trading via Kraken API.
2
+
3
+ This is a stub implementation. To enable:
4
+ 1. Install ccxt: `uv add ccxt`
5
+ 2. Add your Kraken API keys to config
6
+ 3. Implement the TODO sections below
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import logging
12
+ from datetime import datetime, timezone
13
+
14
+ import pandas as pd
15
+
16
+ from trading_cli.execution.adapters.base import (
17
+ AccountInfo,
18
+ MarketClock,
19
+ OrderResult,
20
+ Position,
21
+ TradingAdapter,
22
+ )
23
+ from trading_cli.execution.adapters.registry import register_adapter
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @register_adapter
29
+ class KrakenAdapter(TradingAdapter):
30
+ """Kraken adapter for cryptocurrency trading.
31
+
32
+ Requires: ccxt library (`uv add ccxt`)
33
+ Config keys:
34
+ kraken_api_key: Your Kraken API key
35
+ kraken_api_secret: Your Kraken API secret
36
+ """
37
+
38
+ def __init__(self, config: dict) -> None:
39
+ self._config = config
40
+ self._api_key = config.get("kraken_api_key", "")
41
+ self._api_secret = config.get("kraken_api_secret", "")
42
+ self._demo = not (self._api_key and self._api_secret)
43
+
44
+ if self._demo:
45
+ logger.info("KrakenAdapter: no API keys found, stub mode only")
46
+ return
47
+
48
+ try:
49
+ import ccxt
50
+ self._exchange = ccxt.kraken({
51
+ "apiKey": self._api_key,
52
+ "secret": self._api_secret,
53
+ "enableRateLimit": True,
54
+ })
55
+ logger.info("KrakenAdapter connected")
56
+ except ImportError:
57
+ logger.warning("ccxt not installed. Run: uv add ccxt")
58
+ self._demo = True
59
+ self._exchange = None
60
+ except Exception as exc:
61
+ logger.error("Failed to connect to Kraken: %s", exc)
62
+ self._demo = True
63
+ self._exchange = None
64
+
65
+ @property
66
+ def adapter_id(self) -> str:
67
+ return "kraken"
68
+
69
+ @property
70
+ def supports_paper_trading(self) -> bool:
71
+ return False # Kraken doesn't have testnet
72
+
73
+ @property
74
+ def is_demo_mode(self) -> bool:
75
+ return self._demo
76
+
77
+ # ── Account & Positions ───────────────────────────────────────────────────
78
+
79
+ def get_account(self) -> AccountInfo:
80
+ if self._demo or not self._exchange:
81
+ return AccountInfo(
82
+ equity=100000.0,
83
+ cash=100000.0,
84
+ buying_power=100000.0,
85
+ portfolio_value=100000.0,
86
+ )
87
+ # TODO: Implement real account fetch
88
+ balance = self._exchange.fetch_balance()
89
+ cash = float(balance.get("USD", {}).get("free", 0))
90
+ return AccountInfo(
91
+ equity=cash,
92
+ cash=cash,
93
+ buying_power=cash,
94
+ portfolio_value=cash,
95
+ )
96
+
97
+ def get_positions(self) -> list[Position]:
98
+ if self._demo or not self._exchange:
99
+ return []
100
+ # TODO: Implement real position fetch
101
+ positions = []
102
+ balance = self._exchange.fetch_balance()
103
+ for currency, amount_info in balance.items():
104
+ if isinstance(amount_info, dict) and amount_info.get("total", 0) > 0:
105
+ if currency in ("free", "used", "total", "info"):
106
+ continue
107
+ total = amount_info.get("total", 0)
108
+ positions.append(
109
+ Position(
110
+ symbol=f"{currency}/USD",
111
+ qty=total,
112
+ avg_entry_price=0.0,
113
+ current_price=0.0,
114
+ unrealized_pl=0.0,
115
+ unrealized_plpc=0.0,
116
+ market_value=0.0,
117
+ side="long",
118
+ )
119
+ )
120
+ return positions
121
+
122
+ # ── Orders ──────────────────────────────────────────────────────────────
123
+
124
+ def submit_market_order(self, symbol: str, qty: int, side: str) -> OrderResult:
125
+ if self._demo or not self._exchange:
126
+ return OrderResult(
127
+ order_id=f"KRAKEN-DEMO-{datetime.now().timestamp()}",
128
+ symbol=symbol,
129
+ action=side,
130
+ qty=qty,
131
+ status="filled",
132
+ filled_price=0.0,
133
+ )
134
+ # TODO: Implement real order submission
135
+ try:
136
+ order = self._exchange.create_market_order(symbol, side.lower(), qty)
137
+ return OrderResult(
138
+ order_id=order.get("id", "unknown"),
139
+ symbol=symbol,
140
+ action=side,
141
+ qty=qty,
142
+ status=order.get("status", "filled"),
143
+ filled_price=float(order.get("average") or order.get("price") or 0),
144
+ )
145
+ except Exception as exc:
146
+ logger.error("Kraken order failed for %s %s %d: %s", side, symbol, qty, exc)
147
+ raise
148
+
149
+ def close_position(self, symbol: str) -> OrderResult | None:
150
+ if self._demo or not self._exchange:
151
+ return None
152
+ # TODO: Implement position close
153
+ return None
154
+
155
+ # ── Market Data ───────────────────────────────────────────────────────────
156
+
157
+ def fetch_ohlcv(self, symbol: str, days: int = 90) -> pd.DataFrame:
158
+ if self._demo or not self._exchange:
159
+ return pd.DataFrame()
160
+ try:
161
+ ohlcv = self._exchange.fetch_ohlcv(symbol, timeframe="1d", limit=days)
162
+ df = pd.DataFrame(
163
+ ohlcv,
164
+ columns=["timestamp", "Open", "High", "Low", "Close", "Volume"],
165
+ )
166
+ df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms", utc=True)
167
+ df.set_index("timestamp", inplace=True)
168
+ return df
169
+ except Exception as exc:
170
+ logger.warning("Kraken OHLCV fetch failed for %s: %s", symbol, exc)
171
+ return pd.DataFrame()
172
+
173
+ def get_latest_quote(self, symbol: str) -> float | None:
174
+ if self._demo or not self._exchange:
175
+ return None
176
+ try:
177
+ ticker = self._exchange.fetch_ticker(symbol)
178
+ return float(ticker.get("last") or 0)
179
+ except Exception as exc:
180
+ logger.warning("Kraken quote failed for %s: %s", symbol, exc)
181
+ return None
182
+
183
+ # ── Market Info ───────────────────────────────────────────────────────────
184
+
185
+ def get_market_clock(self) -> MarketClock:
186
+ # Crypto markets are 24/7
187
+ return MarketClock(
188
+ is_open=True,
189
+ next_open="24/7",
190
+ next_close="24/7",
191
+ )
192
+
193
+ # ── News ──────────────────────────────────────────────────────────────────
194
+
195
+ def fetch_news(self, symbol: str, max_articles: int = 50,
196
+ days_ago: int = 0) -> list[tuple[str, float]]:
197
+ # Kraken doesn't provide news via API
198
+ return []
trading_cli/execution/adapters/registry.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapter registry — discovers and instantiates trading adapters."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import TYPE_CHECKING
7
+
8
+ from trading_cli.execution.adapters.base import TradingAdapter
9
+
10
+ if TYPE_CHECKING:
11
+ pass
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Global registry of available adapters
16
+ _ADAPTERS: dict[str, type[TradingAdapter]] = {}
17
+
18
+
19
+ def register_adapter(adapter_class: type[TradingAdapter]) -> type[TradingAdapter]:
20
+ """Decorator to register an adapter class.
21
+
22
+ Usage:
23
+ @register_adapter
24
+ class AlpacaAdapter(TradingAdapter):
25
+ ...
26
+ """
27
+ # Instantiate temporarily to get adapter_id
28
+ # We assume adapter_id is a class property or can be called without args
29
+ try:
30
+ instance = adapter_class.__new__(adapter_class)
31
+ adapter_id = adapter_class.adapter_id.fget(instance) if hasattr(adapter_class.adapter_id, 'fget') else getattr(adapter_class, 'adapter_id', None)
32
+ if adapter_id:
33
+ _ADAPTERS[adapter_id] = adapter_class
34
+ logger.debug("Registered adapter: %s", adapter_id)
35
+ except Exception:
36
+ # Fallback: use class name lowercase
37
+ adapter_id = adapter_class.__name__.lower().replace("adapter", "")
38
+ _ADAPTERS[adapter_id] = adapter_class
39
+ logger.debug("Registered adapter (fallback): %s", adapter_id)
40
+ return adapter_class
41
+
42
+
43
+ def get_adapter(adapter_id: str) -> type[TradingAdapter] | None:
44
+ """Get adapter class by ID."""
45
+ return _ADAPTERS.get(adapter_id)
46
+
47
+
48
+ def list_adapters() -> list[str]:
49
+ """List all registered adapter IDs."""
50
+ return list(_ADAPTERS.keys())
51
+
52
+
53
+ def create_adapter(adapter_id: str, config: dict) -> TradingAdapter:
54
+ """Create an adapter instance from config.
55
+
56
+ Args:
57
+ adapter_id: Adapter identifier ('alpaca', 'binance', 'kraken', 'demo').
58
+ config: Configuration dict with API keys and settings.
59
+
60
+ Returns:
61
+ TradingAdapter instance.
62
+
63
+ Raises:
64
+ ValueError: If adapter_id is not registered.
65
+ """
66
+ adapter_class = get_adapter(adapter_id)
67
+ if adapter_class is None:
68
+ available = list_adapters()
69
+ raise ValueError(
70
+ f"Unknown adapter: '{adapter_id}'. "
71
+ f"Available adapters: {available}"
72
+ )
73
+ return adapter_class(config)
trading_cli/execution/adapters/yfinance.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """yFinance adapter — free market data with mock trading."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import random
7
+ import time
8
+ from datetime import datetime, timedelta, timezone
9
+
10
+ import pandas as pd
11
+
12
+ from trading_cli.execution.adapters.base import (
13
+ AccountInfo,
14
+ MarketClock,
15
+ OrderResult,
16
+ Position,
17
+ TradingAdapter,
18
+ )
19
+ from trading_cli.execution.adapters.registry import register_adapter
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ @register_adapter
25
+ class YFinanceAdapter(TradingAdapter):
26
+ """yFinance adapter for free market data with simulated trading.
27
+
28
+ Provides:
29
+ - Real OHLCV data from Yahoo Finance
30
+ - Real latest quotes from Yahoo Finance
31
+ - Simulated account and positions (demo mode)
32
+ """
33
+
34
+ def __init__(self, config: dict) -> None:
35
+ self._config = config
36
+ self._cash = config.get("initial_cash", 100_000.0)
37
+ self._positions: dict[str, dict] = {}
38
+ self._order_counter = 1000
39
+ self._base_prices = {
40
+ "AAPL": 175.0, "TSLA": 245.0, "NVDA": 875.0,
41
+ "MSFT": 415.0, "AMZN": 185.0, "GOOGL": 175.0,
42
+ "META": 510.0, "SPY": 520.0,
43
+ }
44
+ logger.info("YFinanceAdapter initialized in demo mode")
45
+
46
+ @property
47
+ def adapter_id(self) -> str:
48
+ return "yfinance"
49
+
50
+ @property
51
+ def supports_paper_trading(self) -> bool:
52
+ return True # Simulated trading
53
+
54
+ @property
55
+ def is_demo_mode(self) -> bool:
56
+ return True
57
+
58
+ # ── Account & Positions ───────────────────────────────────────────────────
59
+
60
+ def get_account(self) -> AccountInfo:
61
+ portfolio = sum(
62
+ p["qty"] * self._get_mock_price(sym)
63
+ for sym, p in self._positions.items()
64
+ )
65
+ equity = self._cash + portfolio
66
+ return AccountInfo(
67
+ equity=equity,
68
+ cash=self._cash,
69
+ buying_power=self._cash * 4,
70
+ portfolio_value=equity,
71
+ )
72
+
73
+ def get_positions(self) -> list[Position]:
74
+ positions = []
75
+ for sym, p in self._positions.items():
76
+ cp = self._get_mock_price(sym)
77
+ ep = p["avg_price"]
78
+ pl = (cp - ep) * p["qty"]
79
+ plpc = (cp - ep) / ep if ep else 0.0
80
+ positions.append(
81
+ Position(sym, p["qty"], ep, cp, pl, plpc, cp * p["qty"])
82
+ )
83
+ return positions
84
+
85
+ # ── Orders ───────────────────────────────────────────────────────────────
86
+
87
+ def submit_market_order(self, symbol: str, qty: int, side: str) -> OrderResult:
88
+ price = self._get_mock_price(symbol)
89
+ self._order_counter += 1
90
+ order_id = f"YF-{self._order_counter}"
91
+
92
+ if side.upper() == "BUY":
93
+ cost = price * qty
94
+ if cost > self._cash:
95
+ return OrderResult(order_id, symbol, side, qty, "rejected")
96
+ self._cash -= cost
97
+ if symbol in self._positions:
98
+ p = self._positions[symbol]
99
+ total_qty = p["qty"] + qty
100
+ p["avg_price"] = (p["avg_price"] * p["qty"] + price * qty) / total_qty
101
+ p["qty"] = total_qty
102
+ else:
103
+ self._positions[symbol] = {"qty": qty, "avg_price": price}
104
+ else: # SELL
105
+ if symbol not in self._positions or self._positions[symbol]["qty"] < qty:
106
+ return OrderResult(order_id, symbol, side, qty, "rejected")
107
+ self._cash += price * qty
108
+ self._positions[symbol]["qty"] -= qty
109
+ if self._positions[symbol]["qty"] == 0:
110
+ del self._positions[symbol]
111
+
112
+ return OrderResult(order_id, symbol, side, qty, "filled", price)
113
+
114
+ def close_position(self, symbol: str) -> OrderResult | None:
115
+ if symbol not in self._positions:
116
+ return None
117
+ qty = self._positions[symbol]["qty"]
118
+ return self.submit_market_order(symbol, qty, "SELL")
119
+
120
+ def _get_mock_price(self, symbol: str) -> float:
121
+ """Get a mock price with small random walk for realism."""
122
+ base = self._base_prices.get(symbol, 100.0)
123
+ noise = random.gauss(0, base * 0.002)
124
+ return round(max(1.0, base + noise), 2)
125
+
126
+ # ── Market Data ───────────────────────────────────────────────────────────
127
+
128
+ def fetch_ohlcv(self, symbol: str, days: int = 90) -> pd.DataFrame:
129
+ """Fetch OHLCV from yfinance."""
130
+ try:
131
+ import yfinance as yf
132
+ period = f"{min(days, 730)}d"
133
+ df = yf.download(symbol, period=period, interval="1d", progress=False, auto_adjust=True)
134
+ if df.empty:
135
+ return pd.DataFrame()
136
+ return df.tail(days)
137
+ except Exception as exc:
138
+ logger.error("yfinance fetch failed for %s: %s", symbol, exc)
139
+ return pd.DataFrame()
140
+
141
+ def get_latest_quote(self, symbol: str) -> float | None:
142
+ """Get latest price from yfinance."""
143
+ try:
144
+ import yfinance as yf
145
+ ticker = yf.Ticker(symbol)
146
+ info = ticker.fast_info
147
+ price = getattr(info, "last_price", None) or getattr(info, "regularMarketPrice", None)
148
+ if price:
149
+ return float(price)
150
+ hist = ticker.history(period="2d", interval="1d")
151
+ if not hist.empty:
152
+ return float(hist["Close"].iloc[-1])
153
+ return None
154
+ except Exception as exc:
155
+ logger.warning("yfinance latest quote failed for %s: %s", symbol, exc)
156
+ return None
157
+
158
+ # ── Market Info ───────────────────────────────────────────────────────────
159
+
160
+ def get_market_clock(self) -> MarketClock:
161
+ now = datetime.now(tz=timezone.utc)
162
+ # Mock: market open weekdays 9:30–16:00 ET (UTC-5)
163
+ hour_et = (now.hour - 5) % 24
164
+ is_open = now.weekday() < 5 and 9 <= hour_et < 16
165
+ return MarketClock(
166
+ is_open=is_open,
167
+ next_open="09:30 ET",
168
+ next_close="16:00 ET",
169
+ )
trading_cli/execution/alpaca_client.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Alpaca API wrapper — paper trading + market data."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import random
7
+ from datetime import datetime, timezone
8
+ from typing import Any
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class Position:
14
+ """Unified position object (real or mock)."""
15
+
16
+ def __init__(
17
+ self,
18
+ symbol: str,
19
+ qty: float,
20
+ avg_entry_price: float,
21
+ current_price: float,
22
+ unrealized_pl: float,
23
+ unrealized_plpc: float,
24
+ market_value: float,
25
+ side: str = "long",
26
+ ):
27
+ self.symbol = symbol
28
+ self.qty = qty
29
+ self.avg_entry_price = avg_entry_price
30
+ self.current_price = current_price
31
+ self.unrealized_pl = unrealized_pl
32
+ self.unrealized_plpc = unrealized_plpc
33
+ self.market_value = market_value
34
+ self.side = side
35
+
36
+
37
+ class AccountInfo:
38
+ def __init__(self, equity: float, cash: float, buying_power: float, portfolio_value: float):
39
+ self.equity = equity
40
+ self.cash = cash
41
+ self.buying_power = buying_power
42
+ self.portfolio_value = portfolio_value
43
+
44
+
45
+ class OrderResult:
46
+ def __init__(self, order_id: str, symbol: str, action: str, qty: int,
47
+ status: str, filled_price: float | None = None):
48
+ self.order_id = order_id
49
+ self.symbol = symbol
50
+ self.action = action
51
+ self.qty = qty
52
+ self.status = status
53
+ self.filled_price = filled_price
54
+
55
+
56
+ # ── Mock client for demo mode ──────────────────────────────────────────────────
57
+
58
+ class MockAlpacaClient:
59
+ """Simulated Alpaca client for demo mode (no API keys required)."""
60
+
61
+ def __init__(self) -> None:
62
+ self.demo_mode = True
63
+ self._cash = 100_000.0
64
+ self._positions: dict[str, dict] = {}
65
+ self._order_counter = 1000
66
+ self._base_prices = {
67
+ "AAPL": 175.0, "TSLA": 245.0, "NVDA": 875.0,
68
+ "MSFT": 415.0, "AMZN": 185.0, "GOOGL": 175.0,
69
+ "META": 510.0, "SPY": 520.0,
70
+ }
71
+ logger.info("MockAlpacaClient initialized in demo mode")
72
+
73
+ def get_account(self) -> AccountInfo:
74
+ portfolio = sum(
75
+ p["qty"] * self._get_mock_price(sym)
76
+ for sym, p in self._positions.items()
77
+ )
78
+ equity = self._cash + portfolio
79
+ return AccountInfo(
80
+ equity=equity,
81
+ cash=self._cash,
82
+ buying_power=self._cash * 4,
83
+ portfolio_value=equity,
84
+ )
85
+
86
+ def get_positions(self) -> list[Position]:
87
+ positions = []
88
+ for sym, p in self._positions.items():
89
+ cp = self._get_mock_price(sym)
90
+ ep = p["avg_price"]
91
+ pl = (cp - ep) * p["qty"]
92
+ plpc = (cp - ep) / ep if ep else 0.0
93
+ positions.append(
94
+ Position(sym, p["qty"], ep, cp, pl, plpc, cp * p["qty"])
95
+ )
96
+ return positions
97
+
98
+ def get_market_clock(self) -> dict:
99
+ now = datetime.now(tz=timezone.utc)
100
+ # Mock: market open weekdays 9:30–16:00 ET (UTC-5)
101
+ hour_et = (now.hour - 5) % 24
102
+ is_open = now.weekday() < 5 and 9 <= hour_et < 16
103
+ return {"is_open": is_open, "next_open": "09:30 ET", "next_close": "16:00 ET"}
104
+
105
+ def submit_market_order(
106
+ self, symbol: str, qty: int, side: str
107
+ ) -> OrderResult:
108
+ price = self._get_mock_price(symbol)
109
+ self._order_counter += 1
110
+ order_id = f"MOCK-{self._order_counter}"
111
+
112
+ if side.upper() == "BUY":
113
+ cost = price * qty
114
+ if cost > self._cash:
115
+ return OrderResult(order_id, symbol, side, qty, "rejected")
116
+ self._cash -= cost
117
+ if symbol in self._positions:
118
+ p = self._positions[symbol]
119
+ total_qty = p["qty"] + qty
120
+ p["avg_price"] = (p["avg_price"] * p["qty"] + price * qty) / total_qty
121
+ p["qty"] = total_qty
122
+ else:
123
+ self._positions[symbol] = {"qty": qty, "avg_price": price}
124
+ else: # SELL
125
+ if symbol not in self._positions or self._positions[symbol]["qty"] < qty:
126
+ return OrderResult(order_id, symbol, side, qty, "rejected")
127
+ self._cash += price * qty
128
+ self._positions[symbol]["qty"] -= qty
129
+ if self._positions[symbol]["qty"] == 0:
130
+ del self._positions[symbol]
131
+
132
+ return OrderResult(order_id, symbol, side, qty, "filled", price)
133
+
134
+ def close_position(self, symbol: str) -> OrderResult | None:
135
+ if symbol not in self._positions:
136
+ return None
137
+ qty = self._positions[symbol]["qty"]
138
+ return self.submit_market_order(symbol, qty, "SELL")
139
+
140
+ def _get_mock_price(self, symbol: str) -> float:
141
+ base = self._base_prices.get(symbol, 100.0)
142
+ # Small random walk so prices feel live
143
+ noise = random.gauss(0, base * 0.002)
144
+ return round(max(1.0, base + noise), 2)
145
+
146
+ def historical_client(self) -> None:
147
+ return None
148
+
149
+
150
+ # ── Real Alpaca client ─────────────────────────────────────────────────────────
151
+
152
+ class AlpacaClient:
153
+ """Wraps alpaca-py SDK for paper trading."""
154
+
155
+ def __init__(self, api_key: str, api_secret: str, paper: bool = True) -> None:
156
+ self.demo_mode = False
157
+ self._paper = paper
158
+ try:
159
+ from alpaca.trading.client import TradingClient
160
+ from alpaca.data.historical import StockHistoricalDataClient
161
+
162
+ self._trading_client = TradingClient(
163
+ api_key=api_key,
164
+ secret_key=api_secret,
165
+ paper=paper,
166
+ )
167
+ self.historical_client = StockHistoricalDataClient(
168
+ api_key=api_key,
169
+ secret_key=api_secret,
170
+ )
171
+ logger.info("AlpacaClient connected (paper=%s)", paper)
172
+ except ImportError as exc:
173
+ raise RuntimeError("alpaca-py not installed. Run: uv add alpaca-py") from exc
174
+
175
+ def get_account(self) -> AccountInfo:
176
+ acct = self._trading_client.get_account()
177
+ return AccountInfo(
178
+ equity=float(acct.equity),
179
+ cash=float(acct.cash),
180
+ buying_power=float(acct.buying_power),
181
+ portfolio_value=float(acct.portfolio_value),
182
+ )
183
+
184
+ def get_positions(self) -> list[Position]:
185
+ raw = self._trading_client.get_all_positions()
186
+ out = []
187
+ for p in raw:
188
+ out.append(
189
+ Position(
190
+ symbol=p.symbol,
191
+ qty=float(p.qty),
192
+ avg_entry_price=float(p.avg_entry_price),
193
+ current_price=float(p.current_price),
194
+ unrealized_pl=float(p.unrealized_pl),
195
+ unrealized_plpc=float(p.unrealized_plpc),
196
+ market_value=float(p.market_value),
197
+ side=str(p.side),
198
+ )
199
+ )
200
+ return out
201
+
202
+ def get_market_clock(self) -> dict:
203
+ try:
204
+ clock = self._trading_client.get_clock()
205
+ return {
206
+ "is_open": clock.is_open,
207
+ "next_open": str(clock.next_open),
208
+ "next_close": str(clock.next_close),
209
+ }
210
+ except Exception as exc:
211
+ logger.warning("get_market_clock failed: %s", exc)
212
+ return {"is_open": False, "next_open": "Unknown", "next_close": "Unknown"}
213
+
214
+ def submit_market_order(
215
+ self, symbol: str, qty: int, side: str
216
+ ) -> OrderResult:
217
+ from alpaca.trading.requests import MarketOrderRequest
218
+ from alpaca.trading.enums import OrderSide, TimeInForce
219
+
220
+ order_side = OrderSide.BUY if side.upper() == "BUY" else OrderSide.SELL
221
+ req = MarketOrderRequest(
222
+ symbol=symbol,
223
+ qty=qty,
224
+ side=order_side,
225
+ time_in_force=TimeInForce.DAY,
226
+ )
227
+ try:
228
+ order = self._trading_client.submit_order(order_data=req)
229
+ filled_price = float(order.filled_avg_price) if order.filled_avg_price else None
230
+ return OrderResult(
231
+ order_id=str(order.id),
232
+ symbol=symbol,
233
+ action=side,
234
+ qty=qty,
235
+ status=str(order.status),
236
+ filled_price=filled_price,
237
+ )
238
+ except Exception as exc:
239
+ logger.error("Order submission failed for %s %s %d: %s", side, symbol, qty, exc)
240
+ raise
241
+
242
+ def close_position(self, symbol: str) -> OrderResult | None:
243
+ try:
244
+ response = self._trading_client.close_position(symbol)
245
+ return OrderResult(
246
+ order_id=str(response.id),
247
+ symbol=symbol,
248
+ action="SELL",
249
+ qty=int(float(response.qty or 0)),
250
+ status=str(response.status),
251
+ )
252
+ except Exception as exc:
253
+ logger.error("Close position failed for %s: %s", symbol, exc)
254
+ return None
255
+
256
+
257
+ def create_client(config: dict) -> AlpacaClient | MockAlpacaClient:
258
+ """Factory: return real AlpacaClient or MockAlpacaClient based on config."""
259
+ key = config.get("alpaca_api_key", "")
260
+ secret = config.get("alpaca_api_secret", "")
261
+ if key and secret:
262
+ try:
263
+ return AlpacaClient(key, secret, paper=config.get("alpaca_paper", True))
264
+ except Exception as exc:
265
+ logger.error("Failed to create AlpacaClient: %s — falling back to demo mode", exc)
266
+ return MockAlpacaClient()
trading_cli/run_dev.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HMR dev runner — watches for .py changes and auto-restarts the trading CLI."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import sys
7
+
8
+ # CRITICAL: Set multiprocessing start method BEFORE any other imports
9
+ if sys.platform.startswith('linux'):
10
+ try:
11
+ import multiprocessing
12
+ multiprocessing.set_start_method('spawn', force=True)
13
+ except (RuntimeError, AttributeError):
14
+ pass
15
+
16
+ import subprocess
17
+ from pathlib import Path
18
+
19
+ from watchfiles import watch
20
+
21
+
22
+ def main() -> None:
23
+ project_root = Path(__file__).parent.resolve()
24
+ target_dir = project_root / "trading_cli"
25
+
26
+ print(f"🔄 Watching {target_dir} for changes (Ctrl+C to stop)\n")
27
+
28
+ for changes in watch(target_dir, watch_filter=None):
29
+ for change_type, path in changes:
30
+ if not path.endswith((".py", ".pyc")):
31
+ continue
32
+ action = "Added" if change_type.name == "added" else \
33
+ "Modified" if change_type.name == "modified" else "Deleted"
34
+ rel = Path(path).relative_to(project_root)
35
+ print(f"\n📝 {action}: {rel}")
36
+ print("⟳ Restarting...\n")
37
+ break # restart on first matching change
38
+ subprocess.run([sys.executable, "-m", "trading_cli"])
39
+
40
+
41
+ if __name__ == "__main__":
42
+ main()
trading_cli/screens/backtest.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Backtest results screen — displays performance metrics and trade log."""
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from textual.app import ComposeResult
6
+ from textual.binding import Binding
7
+ from textual.screen import Screen
8
+ from textual.widgets import Header, DataTable, Label, Static, Input, Button, LoadingIndicator
9
+ from textual.containers import Vertical, Horizontal, Center
10
+ from textual import work
11
+ from rich.text import Text
12
+
13
+ from trading_cli.widgets.ordered_footer import OrderedFooter
14
+ from trading_cli.backtest.engine import BacktestResult
15
+
16
+
17
+ class BacktestSummary(Static):
18
+ """Displays key backtest metrics."""
19
+
20
+ def __init__(self, result: "BacktestResult | None" = None, **kwargs) -> None:
21
+ super().__init__(**kwargs)
22
+ self._result = result
23
+
24
+ def set_result(self, result: "BacktestResult") -> None:
25
+ self._result = result
26
+ self.refresh()
27
+
28
+ def render(self) -> str:
29
+ if not self._result:
30
+ return "[dim]No backtest data[/dim]"
31
+ r = self._result
32
+ pnl = r.final_equity - r.initial_capital
33
+ pnl_style = "bold green" if pnl >= 0 else "bold red"
34
+ dd_style = "bold red" if r.max_drawdown_pct > 10 else "bold yellow"
35
+ sharpe_style = "bold green" if r.sharpe_ratio > 1 else ("bold yellow" if r.sharpe_ratio > 0 else "dim")
36
+
37
+ # Truncate symbol list if it's too long
38
+ display_symbol = r.symbol
39
+ if "," in display_symbol and len(display_symbol) > 30:
40
+ count = display_symbol.count(",") + 1
41
+ display_symbol = f"{display_symbol.split(',')[0]} + {count-1} others"
42
+
43
+ return (
44
+ f"[bold]{display_symbol}[/bold] "
45
+ f"[{pnl_style}]P&L: ${pnl:+,.2f} ({r.total_return_pct:+.2f}%)[/{pnl_style}] "
46
+ f"[{dd_style}]MaxDD: {r.max_drawdown_pct:.2f}%[/{dd_style}] "
47
+ f"[{sharpe_style}]Sharpe: {r.sharpe_ratio:.2f}[/{sharpe_style}] "
48
+ f"Win Rate: {r.win_rate:.1f}% "
49
+ f"Trades: {r.total_trades} ({r.winning_trades}W / {r.losing_trades}L) "
50
+ f"${r.initial_capital:,.0f} → ${r.final_equity:,.0f}"
51
+ )
52
+
53
+
54
+ class BacktestScreen(Screen):
55
+ """Screen for viewing backtest results."""
56
+
57
+ CSS = """
58
+ #backtest-container {
59
+ height: 1fr;
60
+ padding: 0;
61
+ margin: 0;
62
+ overflow: hidden;
63
+ }
64
+
65
+ #backtest-progress {
66
+ height: 1;
67
+ padding: 0 1;
68
+ color: $text-muted;
69
+ text-style: italic;
70
+ }
71
+
72
+ #backtest-controls {
73
+ height: auto;
74
+ padding: 0 1;
75
+ }
76
+
77
+ #backtest-date-row {
78
+ height: auto;
79
+ layout: horizontal;
80
+ }
81
+
82
+ #backtest-date-row Input {
83
+ width: 1fr;
84
+ }
85
+
86
+ #btn-backtest-run {
87
+ width: 100%;
88
+ }
89
+
90
+ #backtest-summary {
91
+ height: auto;
92
+ padding: 0 1;
93
+ color: $text;
94
+ }
95
+
96
+ #backtest-table {
97
+ width: 100%;
98
+ height: 1fr;
99
+ }
100
+ """
101
+
102
+ BINDINGS = [
103
+ Binding("r", "run_backtest", "Run", show=True),
104
+ ]
105
+
106
+ _last_symbol: str = ""
107
+ _last_result: "BacktestResult | None" = None
108
+ _all_results: list["BacktestResult"] = []
109
+
110
+ def compose(self) -> ComposeResult:
111
+ yield Header(show_clock=True)
112
+ with Vertical(id="backtest-container"):
113
+ with Vertical(id="backtest-controls"):
114
+ with Horizontal(id="backtest-date-row"):
115
+ yield Input(placeholder="Start date (YYYY-MM-DD)", id="backtest-start-date")
116
+ yield Input(placeholder="End date (YYYY-MM-DD)", id="backtest-end-date")
117
+ yield Button("🚀 Run", id="btn-backtest-run", variant="success")
118
+ yield BacktestSummary(id="backtest-summary")
119
+ yield Label("", id="backtest-progress")
120
+ yield LoadingIndicator(id="backtest-loading")
121
+ yield DataTable(id="backtest-table", cursor_type="row")
122
+ yield OrderedFooter()
123
+
124
+ def on_mount(self) -> None:
125
+ tbl = self.query_one("#backtest-table", DataTable)
126
+ tbl.add_column("Date", key="date")
127
+ tbl.add_column("Action", key="action")
128
+ tbl.add_column("Price $", key="price")
129
+ tbl.add_column("Qty", key="qty")
130
+ tbl.add_column("P&L $", key="pnl")
131
+ tbl.add_column("Reason", key="reason")
132
+
133
+ # Hide loading indicator initially
134
+ try:
135
+ loader = self.query_one("#backtest-loading", LoadingIndicator)
136
+ loader.display = False
137
+ except Exception:
138
+ pass
139
+
140
+ # Set progress label initially empty
141
+ try:
142
+ prog = self.query_one("#backtest-progress", Label)
143
+ prog.update("")
144
+ except Exception:
145
+ pass
146
+
147
+ def _update_progress(self, text: str) -> None:
148
+ """Update the backtest progress label."""
149
+ try:
150
+ prog = self.query_one("#backtest-progress", Label)
151
+ prog.update(text)
152
+ except Exception:
153
+ pass
154
+
155
+ def on_button_pressed(self, event) -> None:
156
+ if event.button.id == "btn-backtest-run":
157
+ self.action_run_backtest()
158
+
159
+ def on_input_submitted(self, event: Input.Submitted) -> None:
160
+ if event.input.id in ("backtest-start-date", "backtest-end-date"):
161
+ self.action_run_backtest()
162
+
163
+ def action_run_backtest(self) -> None:
164
+ # Parse date range
165
+ start_date = end_date = None
166
+ try:
167
+ start_input = self.query_one("#backtest-start-date", Input)
168
+ end_input = self.query_one("#backtest-end-date", Input)
169
+ if start_input.value.strip():
170
+ start_date = start_input.value.strip()
171
+ if end_input.value.strip():
172
+ end_date = end_input.value.strip()
173
+ except Exception:
174
+ pass
175
+
176
+ app = self.app
177
+ if not hasattr(app, "config"):
178
+ self.app.notify("App not fully initialized", severity="error")
179
+ return
180
+
181
+ # Use full asset universe from adapter (not just 3 hardcoded symbols)
182
+ symbols = []
183
+ if hasattr(app, "asset_search") and app.asset_search.is_ready:
184
+ all_symbols = [a["symbol"] for a in app.asset_search._assets]
185
+ # Cap at 50 symbols to keep backtest time reasonable (~2-3 min)
186
+ symbols = all_symbols[:50]
187
+ if not symbols:
188
+ symbols = app.config.get("default_symbols", ["AAPL", "TSLA", "NVDA"])
189
+
190
+ # Reset accumulated results
191
+ self._all_results = []
192
+
193
+ label = f"{start_date or 'start'} → {end_date or 'now'}"
194
+ self.app.notify(f"Backtesting {len(symbols)} symbols ({label})", timeout=2)
195
+
196
+ # Show loading
197
+ try:
198
+ loader = self.query_one("#backtest-loading", LoadingIndicator)
199
+ loader.display = True
200
+ except Exception:
201
+ pass
202
+
203
+ # Clear table
204
+ tbl = self.query_one("#backtest-table", DataTable)
205
+ tbl.clear()
206
+
207
+ # Update summary to show "Running…"
208
+ summary = self.query_one("#backtest-summary", BacktestSummary)
209
+ summary._result = None
210
+ summary.refresh()
211
+
212
+ # Run all symbols in a single worker thread
213
+ self._execute_backtest(symbols, start_date, end_date)
214
+
215
+ @work(thread=True, name="backtest-worker", exclusive=True)
216
+ def _execute_backtest(self, symbols: list[str], start_date: str | None = None, end_date: str | None = None) -> None:
217
+ """Run backtest for multiple symbols in parallel worker threads."""
218
+ try:
219
+ app = self.app
220
+ from trading_cli.data.market import fetch_ohlcv_yfinance
221
+ from trading_cli.backtest.engine import BacktestEngine
222
+ from concurrent.futures import ThreadPoolExecutor, as_completed
223
+
224
+ from datetime import datetime, timedelta
225
+
226
+ def run_one_symbol(symbol):
227
+ """Run backtest for a single symbol in its own thread."""
228
+ try:
229
+ # Calculate days needed to cover requested range
230
+ if start_date:
231
+ try:
232
+ sd = datetime.strptime(start_date, "%Y-%m-%d")
233
+ days_needed = max(365, (datetime.now() - sd).days + 60)
234
+ except ValueError:
235
+ days_needed = 730
236
+ else:
237
+ days_needed = 730
238
+
239
+ adapter = getattr(app, "adapter", None)
240
+ # Always use Alpaca for historical data if available
241
+ if adapter:
242
+ ohlcv = adapter.fetch_ohlcv(symbol, days=days_needed)
243
+ else:
244
+ ohlcv = fetch_ohlcv_yfinance(symbol, days=days_needed)
245
+
246
+ if ohlcv.empty:
247
+ return None
248
+
249
+ cfg = app.config.copy()
250
+ # Use higher risk percentage for backtests to fully utilize capital
251
+ # Since each backtest is isolated to 1 symbol, allow full portfolio usage
252
+ cfg["risk_pct"] = 0.95 # Use 95% of capital per trade
253
+ cfg["max_position_pct"] = 1.0 # Allow 100% portfolio size
254
+
255
+ strategy = getattr(app, "strategy", None)
256
+
257
+ engine = BacktestEngine(
258
+ config=cfg,
259
+ finbert=None,
260
+ news_fetcher=None,
261
+ use_sentiment=False,
262
+ strategy=strategy,
263
+ progress_callback=None,
264
+ debug=False,
265
+ )
266
+ return engine.run(symbol, ohlcv, start_date=start_date, end_date=end_date, initial_capital=100_000.0)
267
+ except Exception as exc:
268
+ import logging
269
+ logging.getLogger(__name__).error("Backtest %s failed: %s", symbol, exc)
270
+ return None
271
+
272
+ total = len(symbols)
273
+ results = []
274
+ max_workers = min(8, total) # Cap at 8 parallel threads
275
+
276
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
277
+ futures = {executor.submit(run_one_symbol, s): s for s in symbols}
278
+ for i, future in enumerate(as_completed(futures)):
279
+ symbol = futures[future]
280
+ result = future.result()
281
+ if result:
282
+ results.append(result)
283
+ self.app.call_from_thread(
284
+ self._update_progress,
285
+ f"[dim]Backtested {i+1}/{total} symbols…[/dim]",
286
+ )
287
+
288
+ self._all_results = results
289
+ self.app.call_from_thread(self._display_all_results)
290
+ except Exception as exc:
291
+ self.app.call_from_thread(
292
+ self.app.notify,
293
+ f"Backtest failed: {exc}",
294
+ severity="error",
295
+ )
296
+ import logging
297
+ logging.getLogger(__name__).error("Backtest error: %s", exc, exc_info=True)
298
+ self.app.call_from_thread(self._hide_loading)
299
+
300
+ def _hide_loading(self) -> None:
301
+ """Hide the loading indicator."""
302
+ try:
303
+ loader = self.query_one("#backtest-loading", LoadingIndicator)
304
+ loader.display = False
305
+ except Exception:
306
+ pass
307
+
308
+ def _display_all_results(self) -> None:
309
+ """Display combined backtest results for all symbols."""
310
+ self._hide_loading()
311
+ self._update_progress("")
312
+
313
+ if not self._all_results:
314
+ self.app.notify("No results", severity="warning")
315
+ return
316
+
317
+ # Aggregate metrics
318
+ total_wins = sum(r.winning_trades for r in self._all_results)
319
+ total_losses = sum(r.losing_trades for r in self._all_results)
320
+ total_closed_trades = total_wins + total_losses
321
+ total_trades = sum(r.total_trades for r in self._all_results)
322
+ total_initial = sum(r.initial_capital for r in self._all_results)
323
+ total_final = sum(r.final_equity for r in self._all_results)
324
+ total_return_pct = ((total_final - total_initial) / total_initial * 100) if total_initial else 0
325
+ max_dd_pct = max(r.max_drawdown_pct for r in self._all_results)
326
+ sharpe = sum(r.sharpe_ratio for r in self._all_results) / len(self._all_results) if self._all_results else 0
327
+
328
+ # Win rate: percentage of winning trades among all closed trades
329
+ win_rate = (total_wins / total_closed_trades * 100) if total_closed_trades else 0
330
+
331
+ # Build combined symbol list
332
+ symbols_str = ", ".join(r.symbol for r in self._all_results)
333
+
334
+ # Create a synthetic combined result for the summary widget
335
+ combined = BacktestResult(
336
+ symbol=symbols_str,
337
+ start_date=min(r.start_date for r in self._all_results),
338
+ end_date=max(r.end_date for r in self._all_results),
339
+ initial_capital=total_initial,
340
+ final_equity=total_final,
341
+ total_return_pct=total_return_pct,
342
+ max_drawdown_pct=max_dd_pct,
343
+ sharpe_ratio=sharpe,
344
+ win_rate=win_rate,
345
+ total_trades=total_trades,
346
+ winning_trades=total_wins,
347
+ losing_trades=total_losses,
348
+ trades=[t for r in self._all_results for t in r.trades],
349
+ )
350
+ self._last_result = combined
351
+
352
+ summary = self.query_one("#backtest-summary", BacktestSummary)
353
+ summary.set_result(combined)
354
+
355
+ tbl = self.query_one("#backtest-table", DataTable)
356
+ tbl.clear()
357
+ for trade in combined.trades:
358
+ action_style = "bold green" if trade.action == "BUY" else "bold red"
359
+ pnl_val = trade.pnl if trade.pnl is not None else 0
360
+ pnl_str = f"{pnl_val:+,.2f}" if pnl_val != 0 else "—"
361
+ tbl.add_row(
362
+ f"[dim]{trade.symbol}[/dim] {trade.timestamp[:10]}",
363
+ Text(trade.action, style=action_style),
364
+ f"{trade.price:.2f}",
365
+ str(trade.qty),
366
+ Text(pnl_str, style="green" if pnl_val > 0 else ("red" if pnl_val < 0 else "dim")),
367
+ trade.reason[:50] if trade.reason else "",
368
+ )
trading_cli/screens/config_screen.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Config screen — edit API keys and strategy parameters."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from textual.app import ComposeResult
6
+ from textual.binding import Binding
7
+ from textual.screen import Screen
8
+ from textual.widgets import (
9
+ Header, Input, Label, Switch, Button, Static, Select,
10
+ OptionList, Collapsible,
11
+ )
12
+ from textual.containers import Vertical, Horizontal, ScrollableContainer
13
+ from textual.reactive import reactive
14
+
15
+ from trading_cli.config import save_config
16
+ from trading_cli.widgets.ordered_footer import OrderedFooter
17
+
18
+
19
+ class ConfigRow(Horizontal):
20
+ """Label + Input/Widget row."""
21
+
22
+ DEFAULT_CSS = """
23
+ ConfigRow {
24
+ width: 100%;
25
+ height: auto;
26
+ padding: 0 1;
27
+ margin: 0 0 0 0;
28
+ layout: horizontal;
29
+ }
30
+ ConfigRow Label {
31
+ width: 28;
32
+ min-width: 28;
33
+ content-align: right middle;
34
+ padding-right: 1;
35
+ }
36
+ ConfigRow Input, ConfigRow Select {
37
+ width: 1fr;
38
+ }
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ label: str,
44
+ key: str,
45
+ value: str = "",
46
+ password: bool = False,
47
+ options: list[tuple[str, str]] | None = None,
48
+ ) -> None:
49
+ super().__init__(id=f"row-{key}")
50
+ self._label = label
51
+ self._key = key
52
+ self._value = value
53
+ self._password = password
54
+ self._options = options
55
+
56
+ def compose(self) -> ComposeResult:
57
+ yield Label(f"{self._label}:")
58
+ if self._options:
59
+ yield Select(
60
+ options=self._options,
61
+ value=self._value,
62
+ id=f"input-{self._key}",
63
+ allow_blank=False,
64
+ )
65
+ else:
66
+ yield Input(
67
+ value=self._value,
68
+ password=self._password,
69
+ id=f"input-{self._key}",
70
+ )
71
+
72
+
73
+ class ConfigScreen(Screen):
74
+ """Screen ID 6 — settings editor."""
75
+
76
+ BINDINGS = [
77
+ Binding("ctrl+s", "save", "Save", show=False),
78
+ Binding("escape", "cancel", "Cancel", show=False),
79
+ ]
80
+
81
+ def compose(self) -> ComposeResult:
82
+ app = self.app
83
+ cfg = getattr(app, "config", {})
84
+
85
+ # Build strategy options from registry
86
+ from trading_cli.strategy.adapters.registry import list_strategies
87
+ strategy_id = cfg.get("strategy_id", "hybrid")
88
+ try:
89
+ strategy_options = list_strategies()
90
+ except Exception:
91
+ strategy_options = ["hybrid"]
92
+ strategy_select_options = [(opt.title(), opt) for opt in strategy_options]
93
+
94
+ # Build exchange provider options from adapter registry
95
+ from trading_cli.execution.adapters.registry import list_adapters
96
+ current_provider = cfg.get("adapter_id", "yfinance")
97
+ try:
98
+ adapter_ids = list_adapters()
99
+ except Exception:
100
+ adapter_ids = ["yfinance", "alpaca", "binance", "kraken"]
101
+ provider_display = {
102
+ "alpaca": "Alpaca (Stocks/ETFs)",
103
+ "yfinance": "Yahoo Finance (Demo)",
104
+ "binance": "Binance (Crypto)",
105
+ "kraken": "Kraken (Crypto)",
106
+ }
107
+ provider_select_options = [
108
+ (provider_display.get(aid, aid.title()), aid) for aid in adapter_ids
109
+ ]
110
+
111
+ # Build sentiment model options
112
+ current_sentiment = cfg.get("sentiment_model", "finbert")
113
+ sentiment_select_options = [
114
+ ("FinBERT", "finbert"),
115
+ ]
116
+
117
+ yield Header(show_clock=True)
118
+ with ScrollableContainer(id="config-scroll"):
119
+ yield Label("[bold]Configuration[/bold] [dim](Ctrl+S to save, ESC to cancel)[/dim]")
120
+
121
+ with Collapsible(title="🏦 Exchange Provider", id="collapsible-provider"):
122
+ yield ConfigRow(
123
+ "Exchange",
124
+ "adapter_id",
125
+ current_provider,
126
+ options=provider_select_options,
127
+ )
128
+
129
+ with Collapsible(title="🔑 Alpaca API", id="collapsible-api"):
130
+ yield ConfigRow("API Key", "alpaca_api_key", cfg.get("alpaca_api_key", ""), password=True)
131
+ yield ConfigRow("API Secret", "alpaca_api_secret", cfg.get("alpaca_api_secret", ""), password=True)
132
+
133
+ with Collapsible(title="📊 Risk Parameters", id="collapsible-risk"):
134
+ yield ConfigRow("Risk % per trade", "risk_pct", str(cfg.get("risk_pct", 0.02)))
135
+ yield ConfigRow("Max drawdown %", "max_drawdown", str(cfg.get("max_drawdown", 0.15)))
136
+ yield ConfigRow("Stop-loss %", "stop_loss_pct", str(cfg.get("stop_loss_pct", 0.05)))
137
+ yield ConfigRow("Max positions", "max_positions", str(cfg.get("max_positions", 10)))
138
+
139
+ with Collapsible(title="🎯 Signal Thresholds", id="collapsible-thresholds"):
140
+ yield ConfigRow("Buy threshold (-1–1)", "signal_buy_threshold", str(cfg.get("signal_buy_threshold", 0.15)))
141
+ yield ConfigRow("Sell threshold (-1–1)", "signal_sell_threshold", str(cfg.get("signal_sell_threshold", -0.15)))
142
+
143
+ with Collapsible(title="🧠 Strategy", id="collapsible-strategy"):
144
+ yield ConfigRow(
145
+ "Active strategy",
146
+ "strategy_id",
147
+ strategy_id,
148
+ options=strategy_select_options,
149
+ )
150
+ yield Static("", id="strategy-info")
151
+ yield ConfigRow(
152
+ "Sentiment model",
153
+ "sentiment_model",
154
+ current_sentiment,
155
+ options=sentiment_select_options,
156
+ )
157
+
158
+ with Collapsible(title="⚖️ Strategy Weights", id="collapsible-weights"):
159
+ yield ConfigRow("Technical weight", "tech_weight", str(cfg.get("tech_weight", 0.6)))
160
+ yield ConfigRow("Sentiment weight", "sent_weight", str(cfg.get("sent_weight", 0.4)))
161
+
162
+ with Collapsible(title="🐛 Debug", id="collapsible-debug"):
163
+ with Horizontal():
164
+ yield Label("Fast cycle (10s polling):")
165
+ yield Switch(
166
+ value=cfg.get("debug_fast_cycle", False),
167
+ id="switch-debug-fast",
168
+ )
169
+
170
+ with Collapsible(title="📈 Technical Indicator Weights", id="collapsible-tech-weights"):
171
+ yield ConfigRow("SMA weight", "weight_sma", str(cfg.get("weight_sma", 0.25)))
172
+ yield ConfigRow("RSI weight", "weight_rsi", str(cfg.get("weight_rsi", 0.25)))
173
+ yield ConfigRow("Bollinger weight", "weight_bb", str(cfg.get("weight_bb", 0.20)))
174
+ yield ConfigRow("EMA weight", "weight_ema", str(cfg.get("weight_ema", 0.15)))
175
+ yield ConfigRow("Volume weight", "weight_volume", str(cfg.get("weight_volume", 0.15)))
176
+
177
+ with Collapsible(title="⚙️ Indicator Parameters", id="collapsible-params"):
178
+ yield ConfigRow("SMA short period", "sma_short", str(cfg.get("sma_short", 20)))
179
+ yield ConfigRow("SMA long period", "sma_long", str(cfg.get("sma_long", 50)))
180
+ yield ConfigRow("RSI period", "rsi_period", str(cfg.get("rsi_period", 14)))
181
+ yield ConfigRow("Bollinger window", "bb_window", str(cfg.get("bb_window", 20)))
182
+ yield ConfigRow("Bollinger std dev", "bb_std", str(cfg.get("bb_std", 2.0)))
183
+ yield ConfigRow("EMA fast", "ema_fast", str(cfg.get("ema_fast", 12)))
184
+ yield ConfigRow("EMA slow", "ema_slow", str(cfg.get("ema_slow", 26)))
185
+ yield ConfigRow("Volume SMA window", "volume_window", str(cfg.get("volume_window", 20)))
186
+
187
+ with Collapsible(title="📰 Sentiment Event Weights", id="collapsible-event-weights"):
188
+ yield ConfigRow("Earnings weight", "event_weight_earnings", str(cfg.get("event_weight_earnings", 1.5)))
189
+ yield ConfigRow("Executive weight", "event_weight_executive", str(cfg.get("event_weight_executive", 1.3)))
190
+ yield ConfigRow("Product weight", "event_weight_product", str(cfg.get("event_weight_product", 1.2)))
191
+ yield ConfigRow("Macro weight", "event_weight_macro", str(cfg.get("event_weight_macro", 1.4)))
192
+ yield ConfigRow("Generic weight", "event_weight_generic", str(cfg.get("event_weight_generic", 0.8)))
193
+ yield ConfigRow("Sentiment half-life (hrs)", "sentiment_half_life_hours", str(cfg.get("sentiment_half_life_hours", 24.0)))
194
+
195
+ with Collapsible(title="⏱️ Poll Intervals (seconds)", id="collapsible-poll"):
196
+ yield ConfigRow("Price poll", "poll_interval_prices", str(cfg.get("poll_interval_prices", 30)))
197
+ yield ConfigRow("News poll", "poll_interval_news", str(cfg.get("poll_interval_news", 900)))
198
+ yield ConfigRow("Signal poll", "poll_interval_signals", str(cfg.get("poll_interval_signals", 300)))
199
+ yield ConfigRow("Positions poll", "poll_interval_positions", str(cfg.get("poll_interval_positions", 60)))
200
+
201
+ with Collapsible(title="🤖 Auto-Trading", id="collapsible-auto"):
202
+ with Horizontal(id="auto-trade-row"):
203
+ yield Label("Enable auto-trading:")
204
+ yield Switch(
205
+ value=cfg.get("auto_trading", False),
206
+ id="switch-auto-trading",
207
+ )
208
+
209
+ with Horizontal(id="config-buttons"):
210
+ yield Button("💾 Save", id="btn-save", variant="success")
211
+ yield Button("💾🔄 Save & Restart", id="btn-restart", variant="warning")
212
+ yield Button("❌ Cancel", id="btn-cancel", variant="default")
213
+
214
+ yield OrderedFooter()
215
+
216
+ def on_button_pressed(self, event) -> None:
217
+ if event.button.id == "btn-save":
218
+ self.action_save()
219
+ elif event.button.id == "btn-restart":
220
+ self.action_save_restart()
221
+ elif event.button.id == "btn-cancel":
222
+ self.app.pop_screen()
223
+
224
+ def _read_config(self) -> dict:
225
+ """Read all config values from the form."""
226
+ app = self.app
227
+ cfg = dict(getattr(app, "config", {}))
228
+
229
+ str_keys = [
230
+ "alpaca_api_key", "alpaca_api_secret",
231
+ ]
232
+ float_keys = [
233
+ "risk_pct", "max_drawdown", "stop_loss_pct",
234
+ "signal_buy_threshold", "signal_sell_threshold",
235
+ "tech_weight", "sent_weight",
236
+ "weight_sma", "weight_rsi", "weight_bb", "weight_ema", "weight_volume",
237
+ "bb_std",
238
+ "event_weight_earnings", "event_weight_executive", "event_weight_product",
239
+ "event_weight_macro", "event_weight_generic",
240
+ "sentiment_half_life_hours",
241
+ ]
242
+ int_keys = [
243
+ "max_positions", "poll_interval_prices",
244
+ "poll_interval_news", "poll_interval_signals", "poll_interval_positions",
245
+ "sma_short", "sma_long", "rsi_period",
246
+ "bb_window", "ema_fast", "ema_slow", "volume_window",
247
+ ]
248
+
249
+ for key in str_keys:
250
+ try:
251
+ widget = self.query_one(f"#input-{key}", Input)
252
+ cfg[key] = widget.value.strip()
253
+ except Exception:
254
+ pass
255
+ for key in float_keys:
256
+ try:
257
+ widget = self.query_one(f"#input-{key}", Input)
258
+ cfg[key] = float(widget.value.strip())
259
+ except Exception:
260
+ pass
261
+ for key in int_keys:
262
+ try:
263
+ widget = self.query_one(f"#input-{key}", Input)
264
+ cfg[key] = int(widget.value.strip())
265
+ except Exception:
266
+ pass
267
+
268
+ # Strategy selector (Select widget)
269
+ try:
270
+ sel = self.query_one("#input-strategy_id", Select)
271
+ cfg["strategy_id"] = str(sel.value)
272
+ except Exception:
273
+ pass
274
+
275
+ # Exchange provider (Select widget)
276
+ try:
277
+ sel = self.query_one("#input-adapter_id", Select)
278
+ cfg["adapter_id"] = str(sel.value)
279
+ except Exception:
280
+ pass
281
+
282
+ # Sentiment model (Select widget)
283
+ try:
284
+ sel = self.query_one("#input-sentiment_model", Select)
285
+ cfg["sentiment_model"] = str(sel.value)
286
+ except Exception:
287
+ pass
288
+
289
+ try:
290
+ sw = self.query_one("#switch-auto-trading", Switch)
291
+ cfg["auto_trading"] = sw.value
292
+ except Exception:
293
+ pass
294
+
295
+ try:
296
+ sw = self.query_one("#switch-debug-fast", Switch)
297
+ cfg["debug_fast_cycle"] = sw.value
298
+ except Exception:
299
+ pass
300
+
301
+ return cfg
302
+
303
+ def action_save(self) -> None:
304
+ app = self.app
305
+ cfg = self._read_config()
306
+
307
+ save_config(cfg)
308
+ app.config = cfg
309
+ app.notify("Configuration saved ✓")
310
+ app.pop_screen()
311
+
312
+ def action_save_restart(self) -> None:
313
+ app = self.app
314
+ cfg = self._read_config()
315
+
316
+ save_config(cfg)
317
+ app.config = cfg
318
+ app.notify("Restarting with new config…")
319
+
320
+ import sys
321
+ import os
322
+ # Use os.execv to replace the current process
323
+ python = sys.executable
324
+ script = sys.argv[0]
325
+ os.execv(python, [python, script])
326
+
327
+ def on_select_changed(self, event: Select.Changed) -> None:
328
+ """Update info display when selection changes."""
329
+ if event.select.id == "input-strategy_id":
330
+ self._update_strategy_info(str(event.value))
331
+
332
+ def _update_strategy_info(self, strategy_id: str) -> None:
333
+ """Display strategy description."""
334
+ try:
335
+ from trading_cli.strategy.adapters.registry import get_strategy
336
+ strategy_cls = get_strategy(strategy_id)
337
+ if strategy_cls:
338
+ info = strategy_cls.__new__(strategy_cls).info()
339
+ info_widget = self.query_one("#strategy-info", Static)
340
+ info_widget.update(
341
+ f"[dim]{info.description}[/dim]"
342
+ )
343
+ except Exception:
344
+ pass
345
+
346
+ def on_mount(self) -> None:
347
+ """Initialize strategy info display."""
348
+ cfg = getattr(self.app, "config", {})
349
+ strategy_id = cfg.get("strategy_id", "hybrid")
350
+ self._update_strategy_info(strategy_id)
351
+
352
+ def action_cancel(self) -> None:
353
+ """Handle ESC to cancel without saving."""
354
+ self.app.pop_screen()
trading_cli/screens/dashboard.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dashboard screen — main view with positions, signals and account summary."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from textual.app import ComposeResult
6
+ from textual.binding import Binding
7
+ from textual.screen import Screen
8
+ from textual.widgets import Header, Static, Label, Rule
9
+ from textual.containers import Horizontal, Vertical, ScrollableContainer
10
+ from textual.reactive import reactive
11
+ from rich.text import Text
12
+ from rich.panel import Panel
13
+ from rich.table import Table
14
+ from rich import box
15
+
16
+ from trading_cli.widgets.positions_table import PositionsTable
17
+ from trading_cli.widgets.signal_log import SignalLog
18
+ from trading_cli.widgets.ordered_footer import OrderedFooter
19
+
20
+
21
+ class AccountBar(Static):
22
+ cash: reactive[float] = reactive(0.0)
23
+ equity: reactive[float] = reactive(0.0)
24
+ demo: reactive[bool] = reactive(False)
25
+ market_open: reactive[bool] = reactive(False)
26
+
27
+ def render(self) -> Text:
28
+ t = Text()
29
+ mode = "[DEMO] " if self.demo else ""
30
+ t.append(mode, style="bold yellow")
31
+ t.append(f"Cash: ${self.cash:,.2f} ", style="bold cyan")
32
+ t.append(f"Equity: ${self.equity:,.2f} ", style="bold white")
33
+ status_style = "bold green" if self.market_open else "bold red"
34
+ status_text = "● OPEN" if self.market_open else "● CLOSED"
35
+ t.append(status_text, style=status_style)
36
+ return t
37
+
38
+
39
+ class AutoTradeStatus(Static):
40
+ """Shows auto-trade status and last cycle time."""
41
+ enabled: reactive[bool] = reactive(False)
42
+ last_cycle: reactive[str] = reactive("--")
43
+ last_error: reactive[str] = reactive("")
44
+
45
+ def render(self) -> Text:
46
+ status = "[AUTO] ON" if self.enabled else "[AUTO] OFF"
47
+ style = "bold green" if self.enabled else "bold yellow"
48
+ t = Text(status, style=style)
49
+ t.append(f" Last: {self.last_cycle}", style="dim")
50
+ if self.last_error:
51
+ t.append(f" Error: {self.last_error}", style="bold red")
52
+ return t
53
+
54
+
55
+ class DashboardScreen(Screen):
56
+ """Screen ID 1 — main dashboard."""
57
+
58
+ BINDINGS = [
59
+ Binding("r", "refresh", "Refresh", show=False),
60
+ Binding("t", "toggle_autotrade", "Toggle Auto", show=True),
61
+ ]
62
+
63
+ def compose(self) -> ComposeResult:
64
+ yield Header(show_clock=True)
65
+ with Vertical():
66
+ yield AccountBar(id="account-bar")
67
+ yield Rule()
68
+ yield AutoTradeStatus(id="autotrade-status")
69
+ yield Rule()
70
+ with Horizontal(id="main-split"):
71
+ with Vertical(id="left-pane"):
72
+ yield Label("[bold]RECENT SIGNALS[/bold]", id="signals-label")
73
+ yield SignalLog(id="signal-log", max_lines=50, markup=True)
74
+ with Vertical(id="right-pane"):
75
+ yield Label("[bold]POSITIONS[/bold]", id="positions-label")
76
+ yield PositionsTable(id="positions-table")
77
+ yield OrderedFooter()
78
+
79
+ def on_mount(self) -> None:
80
+ self._refresh_from_app()
81
+
82
+ def action_refresh(self) -> None:
83
+ self._refresh_from_app()
84
+
85
+ def _refresh_from_app(self) -> None:
86
+ app = self.app
87
+ if not hasattr(app, "adapter"):
88
+ return
89
+ try:
90
+ acct = app.adapter.get_account()
91
+ bar = self.query_one("#account-bar", AccountBar)
92
+ bar.cash = acct.cash
93
+ bar.equity = acct.equity
94
+ bar.demo = app.demo_mode
95
+ bar.market_open = app.market_open
96
+
97
+ positions = app.adapter.get_positions()
98
+ self.query_one("#positions-table", PositionsTable).refresh_positions(positions)
99
+
100
+ # Initialize auto-trade status
101
+ auto_enabled = app.config.get("auto_trading", False)
102
+ self.update_autotrade_status(auto_enabled)
103
+ except Exception:
104
+ pass
105
+
106
+ # Called by app worker when new data arrives
107
+ def refresh_positions(self, positions: list) -> None:
108
+ try:
109
+ self.query_one("#positions-table", PositionsTable).refresh_positions(positions)
110
+ except Exception:
111
+ pass
112
+
113
+ def refresh_account(self, acct) -> None:
114
+ try:
115
+ bar = self.query_one("#account-bar", AccountBar)
116
+ bar.cash = acct.cash
117
+ bar.equity = acct.equity
118
+ bar.demo = self.app.demo_mode
119
+ bar.market_open = self.app.market_open
120
+ except Exception:
121
+ pass
122
+
123
+ def log_signal(self, signal: dict) -> None:
124
+ try:
125
+ self.query_one("#signal-log", SignalLog).log_signal(signal)
126
+ except Exception:
127
+ pass
128
+
129
+ def update_autotrade_status(self, enabled: bool, last_cycle: str = "", error: str = "") -> None:
130
+ """Update the auto-trade status indicator."""
131
+ try:
132
+ status = self.query_one("#autotrade-status", AutoTradeStatus)
133
+ status.enabled = enabled
134
+ if last_cycle:
135
+ status.last_cycle = last_cycle
136
+ if error:
137
+ status.last_error = error
138
+ except Exception:
139
+ pass
140
+
141
+ def action_toggle_autotrade(self) -> None:
142
+ """Toggle auto-trading on/off from dashboard."""
143
+ app = self.app
144
+ if not hasattr(app, "config"):
145
+ return
146
+
147
+ current = app.config.get("auto_trading", False)
148
+ new_value = not current
149
+ app.config["auto_trading"] = new_value
150
+
151
+ # Persist to disk
152
+ from trading_cli.config import save_config
153
+ save_config(app.config)
154
+
155
+ # Update status indicator
156
+ self.update_autotrade_status(new_value)
157
+
158
+ # Notify user
159
+ status = "enabled" if new_value else "disabled"
160
+ app.notify(f"Auto-trading {status}", severity="information" if new_value else "warning")
trading_cli/screens/portfolio.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Portfolio screen — detailed positions with close-position action."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from textual.app import ComposeResult
6
+ from textual.binding import Binding
7
+ from textual.screen import Screen
8
+ from textual.widgets import Header, DataTable, Label, Static, Button
9
+ from textual.containers import Vertical, Horizontal
10
+ from textual.reactive import reactive
11
+ from rich.text import Text
12
+
13
+ from trading_cli.widgets.positions_table import PositionsTable
14
+ from trading_cli.widgets.ordered_footer import OrderedFooter
15
+
16
+
17
+ class PortfolioSummary(Static):
18
+ equity: reactive[float] = reactive(0.0)
19
+ cash: reactive[float] = reactive(0.0)
20
+ total_pl: reactive[float] = reactive(0.0)
21
+
22
+ def render(self) -> Text:
23
+ t = Text()
24
+ t.append("Portfolio Value: ", style="bold")
25
+ t.append(f"${self.equity:,.2f} ", style="bold cyan")
26
+ t.append("Cash: ", style="bold")
27
+ t.append(f"${self.cash:,.2f} ", style="cyan")
28
+ pl_style = "bold green" if self.total_pl >= 0 else "bold red"
29
+ t.append("Total P&L: ", style="bold")
30
+ t.append(f"${self.total_pl:+,.2f}", style=pl_style)
31
+ return t
32
+
33
+
34
+ class PortfolioScreen(Screen):
35
+ """Screen ID 3 — full position details from Alpaca."""
36
+
37
+ BINDINGS = [
38
+ Binding("x", "close_position", "Close position", show=False),
39
+ Binding("r", "refresh_data", "Refresh", show=False),
40
+ ]
41
+
42
+ def compose(self) -> ComposeResult:
43
+ yield Header(show_clock=True)
44
+ with Vertical():
45
+ yield PortfolioSummary(id="portfolio-summary")
46
+ with Horizontal(id="portfolio-actions"):
47
+ yield Button("🔄 Refresh", id="btn-refresh", variant="primary")
48
+ yield Button("❌ Close Selected", id="btn-close", variant="error")
49
+ yield PositionsTable(id="portfolio-table")
50
+ yield OrderedFooter()
51
+
52
+ def on_mount(self) -> None:
53
+ self.action_refresh_data()
54
+
55
+ def on_button_pressed(self, event) -> None:
56
+ if event.button.id == "btn-refresh":
57
+ self.action_refresh_data()
58
+ elif event.button.id == "btn-close":
59
+ self.action_close_position()
60
+
61
+ def action_refresh_data(self) -> None:
62
+ app = self.app
63
+ if not hasattr(app, "client"):
64
+ return
65
+ try:
66
+ acct = app.client.get_account()
67
+ summary = self.query_one("#portfolio-summary", PortfolioSummary)
68
+ summary.equity = acct.equity
69
+ summary.cash = acct.cash
70
+
71
+ positions = app.client.get_positions()
72
+ total_pl = sum(p.unrealized_pl for p in positions)
73
+ summary.total_pl = total_pl
74
+
75
+ tbl = self.query_one("#portfolio-table", PositionsTable)
76
+ tbl.refresh_positions(positions)
77
+ except Exception as exc:
78
+ self.app.notify(f"Refresh failed: {exc}", severity="error")
79
+
80
+ def action_close_position(self) -> None:
81
+ tbl = self.query_one("#portfolio-table", PositionsTable)
82
+ if len(tbl.rows) == 0:
83
+ self.app.notify("No positions to close", severity="warning")
84
+ return
85
+ if tbl.cursor_row is None:
86
+ self.app.notify("Select a position first", severity="warning")
87
+ return
88
+ row = tbl.get_row_at(tbl.cursor_row)
89
+ if not row:
90
+ return
91
+ symbol = str(row[0])
92
+ self.app.push_screen(
93
+ ConfirmCloseScreen(symbol),
94
+ callback=self._on_close_confirmed,
95
+ )
96
+
97
+ def _on_close_confirmed(self, confirmed: bool) -> None:
98
+ if not confirmed:
99
+ return
100
+ if not hasattr(self, "_pending_close"):
101
+ return
102
+ symbol = self._pending_close
103
+ try:
104
+ result = self.app.client.close_position(symbol)
105
+ if result:
106
+ from trading_cli.data.db import save_trade
107
+ save_trade(
108
+ self.app.db_conn, symbol, "SELL",
109
+ result.filled_price or 0.0,
110
+ result.qty,
111
+ order_id=result.order_id,
112
+ reason="Manual close from Portfolio screen",
113
+ )
114
+ self.app.notify(f"Closed {symbol}: {result.status}")
115
+ except Exception as exc:
116
+ self.app.notify(f"Close failed: {exc}", severity="error")
117
+ self.action_refresh_data()
118
+
119
+
120
+ class ConfirmCloseScreen(Screen):
121
+ """Modal confirmation dialog for closing a position."""
122
+
123
+ def __init__(self, symbol: str) -> None:
124
+ super().__init__()
125
+ self._symbol = symbol
126
+
127
+ def compose(self) -> ComposeResult:
128
+ from textual.containers import Grid
129
+
130
+ with Grid(id="confirm-grid"):
131
+ yield Label(
132
+ f"[bold red]Close position in {self._symbol}?[/bold red]\n"
133
+ "This will submit a market SELL order.",
134
+ id="confirm-msg",
135
+ )
136
+ with Horizontal(id="confirm-buttons"):
137
+ yield Button("Yes, close", id="btn-yes", variant="error")
138
+ yield Button("Cancel", id="btn-no", variant="default")
139
+
140
+ def on_button_pressed(self, event) -> None:
141
+ self.dismiss(event.button.id == "btn-yes")
trading_cli/screens/sentiment.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sentiment analysis screen — interactive FinBERT analysis per symbol."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import threading
6
+
7
+ from textual.app import ComposeResult
8
+ from textual.binding import Binding
9
+ from textual.screen import Screen
10
+ from textual.widgets import Header, Input, Label, DataTable, Static
11
+ from textual.containers import Vertical
12
+ from textual.reactive import reactive
13
+ from textual import work
14
+ from rich.text import Text
15
+
16
+ from trading_cli.sentiment.aggregator import get_sentiment_summary
17
+ from trading_cli.widgets.ordered_footer import OrderedFooter
18
+
19
+
20
+ class SentimentScoreDisplay(Static):
21
+ """Displays sentiment score with a simple label."""
22
+
23
+ score: reactive[float] = reactive(0.0)
24
+ symbol: reactive[str] = reactive("")
25
+ positive_count: reactive[int] = reactive(0)
26
+ negative_count: reactive[int] = reactive(0)
27
+ neutral_count: reactive[int] = reactive(0)
28
+ dominant: reactive[str] = reactive("NEUTRAL")
29
+
30
+ def render(self) -> str:
31
+ if not self.symbol:
32
+ return ""
33
+ dom_style = {"POSITIVE": "green", "NEGATIVE": "red", "NEUTRAL": "yellow"}.get(self.dominant, "white")
34
+ return (
35
+ f"[bold]{self.symbol}[/bold] — "
36
+ f"[{dom_style}]{self.dominant}[/{dom_style}] "
37
+ f"(score: [bold]{self.score:+.3f}[/bold], "
38
+ f"+{self.positive_count} / −{self.negative_count} / ={self.neutral_count})"
39
+ )
40
+
41
+
42
+ class SentimentScreen(Screen):
43
+ """Screen ID 5 — on-demand FinBERT sentiment analysis."""
44
+
45
+ BINDINGS = [
46
+ Binding("r", "refresh_symbol", "Refresh", show=False),
47
+ ]
48
+
49
+ _current_symbol: str = ""
50
+ _analysis_task: str = "" # Track the latest symbol being analyzed
51
+
52
+ def compose(self) -> ComposeResult:
53
+ yield Header(show_clock=True)
54
+ with Vertical():
55
+ # Create asset autocomplete input
56
+ app = self.app
57
+ if hasattr(app, 'asset_search') and app.asset_search.is_ready:
58
+ from trading_cli.widgets.asset_autocomplete import create_asset_autocomplete
59
+ input_widget, autocomplete_widget = create_asset_autocomplete(
60
+ app.asset_search,
61
+ placeholder="Search by symbol or company name… (Tab to complete)",
62
+ id="sent-input",
63
+ )
64
+ yield input_widget
65
+ yield autocomplete_widget
66
+ else:
67
+ yield Input(placeholder="Search by symbol or company name…", id="sent-input")
68
+
69
+ yield Label("", id="sent-loading-status")
70
+ yield SentimentScoreDisplay(id="sent-summary")
71
+ yield DataTable(id="sent-table", cursor_type="row")
72
+ yield OrderedFooter()
73
+
74
+ def on_mount(self) -> None:
75
+ tbl = self.query_one("#sent-table", DataTable)
76
+ tbl.add_column("Headline", key="headline")
77
+ tbl.add_column("Label", key="label")
78
+ tbl.add_column("Score", key="score")
79
+ self.query_one("#sent-input", Input).focus()
80
+ self._clear_loading_status()
81
+
82
+ # ------------------------------------------------------------------
83
+ # Loading status helpers
84
+ # ------------------------------------------------------------------
85
+
86
+ def _set_loading_status(self, text: str) -> None:
87
+ """Update the status label text."""
88
+ def _update():
89
+ try:
90
+ self.query_one("#sent-loading-status", Label).update(f"[dim]{text}[/dim]")
91
+ except Exception:
92
+ pass
93
+
94
+ # Only use call_from_thread if we're in a background thread
95
+ if threading.get_ident() != self.app._thread_id:
96
+ self.app.call_from_thread(_update)
97
+ else:
98
+ _update()
99
+
100
+ def _clear_loading_status(self) -> None:
101
+ """Clear the status label."""
102
+ def _update():
103
+ try:
104
+ self.query_one("#sent-loading-status", Label).update("")
105
+ except Exception:
106
+ pass
107
+
108
+ # Only use call_from_thread if we're in a background thread
109
+ if threading.get_ident() != self.app._thread_id:
110
+ self.app.call_from_thread(_update)
111
+ else:
112
+ _update()
113
+
114
+ # ------------------------------------------------------------------
115
+ # Event handlers
116
+ # ------------------------------------------------------------------
117
+
118
+ def on_input_submitted(self, event: Input.Submitted) -> None:
119
+ value = event.value.strip()
120
+ if not value:
121
+ return
122
+
123
+ # Extract symbol from autocomplete format "SYMBOL — Company Name"
124
+ if " — " in value:
125
+ symbol = value.split(" — ")[0].strip().upper()
126
+ else:
127
+ symbol = value.upper()
128
+
129
+ if symbol:
130
+ self._current_symbol = symbol
131
+ self._run_analysis(symbol)
132
+
133
+ def action_refresh_symbol(self) -> None:
134
+ if self._current_symbol:
135
+ self._run_analysis(self._current_symbol)
136
+
137
+ # ------------------------------------------------------------------
138
+ # Analysis (background thread)
139
+ # ------------------------------------------------------------------
140
+
141
+ def _run_analysis(self, symbol: str) -> None:
142
+ """Kick off background analysis."""
143
+ # Update the task tracker to the latest symbol (cancels previous tasks)
144
+ self._analysis_task = symbol
145
+
146
+ # Clear the table to show we're working on a new request
147
+ tbl = self.query_one("#sent-table", DataTable)
148
+ tbl.clear()
149
+
150
+ # Reset summary display
151
+ lbl = self.query_one("#sent-summary", SentimentScoreDisplay)
152
+ lbl.symbol = ""
153
+ lbl.score = 0.0
154
+
155
+ self._do_analysis(symbol)
156
+
157
+ @work(thread=True, exclusive=False, description="Analyzing sentiment")
158
+ def _do_analysis(self, symbol: str) -> None:
159
+ """Analyze sentiment for a symbol (non-blocking, allows cancellation)."""
160
+ analyzer = getattr(self.app, "finbert", None)
161
+ db_conn = getattr(self.app, "db_conn", None)
162
+
163
+ # Check if this task has been superseded by a newer request
164
+ def is_cancelled() -> bool:
165
+ return self._analysis_task != symbol
166
+
167
+ # Attempt to reload FinBERT if not loaded
168
+ if analyzer and not analyzer.is_loaded:
169
+ self._set_loading_status("Loading FinBERT model…")
170
+ success = analyzer.reload(
171
+ progress_callback=lambda msg: self._set_loading_status(msg),
172
+ )
173
+ if not success:
174
+ error_msg = analyzer.load_error or "Unknown error"
175
+ self.app.call_from_thread(
176
+ self.app.notify,
177
+ f"FinBERT failed to load: {error_msg}",
178
+ severity="error",
179
+ )
180
+ self._set_loading_status(f"Failed: {error_msg}")
181
+ return
182
+
183
+ # Check cancellation after model loading
184
+ if is_cancelled():
185
+ return
186
+
187
+ self._set_loading_status(f"Fetching headlines for {symbol}…")
188
+
189
+ from trading_cli.data.news import fetch_headlines
190
+ headlines = fetch_headlines(symbol, max_articles=20)
191
+
192
+ # Check cancellation after network call
193
+ if is_cancelled():
194
+ return
195
+
196
+ if not headlines:
197
+ self.app.call_from_thread(
198
+ self.app.notify, f"No headlines found for {symbol}", severity="warning",
199
+ )
200
+ self._clear_loading_status()
201
+ return
202
+
203
+ self._set_loading_status("Running sentiment analysis…")
204
+
205
+ results = []
206
+ if analyzer and analyzer.is_loaded:
207
+ if db_conn:
208
+ results = analyzer.analyze_with_cache(headlines, db_conn)
209
+ else:
210
+ results = analyzer.analyze_batch(headlines)
211
+ else:
212
+ results = [{"label": "neutral", "score": 0.5}] * len(headlines)
213
+
214
+ # Check cancellation after heavy computation
215
+ if is_cancelled():
216
+ return
217
+
218
+ self._clear_loading_status()
219
+
220
+ # Only update UI if this is still the latest task
221
+ if not is_cancelled():
222
+ # Dispatch UI update back to main thread
223
+ self.app.call_from_thread(self._display_results, symbol, headlines, results)
224
+
225
+ # ------------------------------------------------------------------
226
+ # Display
227
+ # ------------------------------------------------------------------
228
+
229
+ def _display_results(self, symbol: str, headlines: list[str], results: list[dict]) -> None:
230
+ summary = get_sentiment_summary(results)
231
+
232
+ # Update summary
233
+ lbl = self.query_one("#sent-summary", SentimentScoreDisplay)
234
+ lbl.symbol = symbol
235
+ lbl.score = summary["score"]
236
+ lbl.positive_count = summary["positive_count"]
237
+ lbl.negative_count = summary["negative_count"]
238
+ lbl.neutral_count = summary["neutral_count"]
239
+ lbl.dominant = summary["dominant"].upper()
240
+
241
+ tbl = self.query_one("#sent-table", DataTable)
242
+ tbl.clear()
243
+ for headline, result in zip(headlines, results):
244
+ label = result.get("label", "neutral")
245
+ score_val = result.get("score", 0.5)
246
+ label_style = {"positive": "green", "negative": "red", "neutral": "yellow"}.get(label, "white")
247
+ tbl.add_row(
248
+ headline[:80],
249
+ Text(label.upper(), style=f"bold {label_style}"),
250
+ Text(f"{score_val:.3f}", style=label_style),
251
+ )
trading_cli/screens/trades.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Trade history screen — scrollable log with filter and CSV export."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import csv
6
+ import os
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+
10
+ from textual.app import ComposeResult
11
+ from textual.binding import Binding
12
+ from textual.screen import Screen
13
+ from textual.widgets import Header, DataTable, Input
14
+ from textual.containers import Vertical
15
+ from rich.text import Text
16
+
17
+ from trading_cli.widgets.ordered_footer import OrderedFooter
18
+
19
+
20
+ class TradesScreen(Screen):
21
+ """Screen ID 4 — all executed trades with filter and export."""
22
+
23
+ BINDINGS = [
24
+ Binding("e", "export_csv", "Export", show=True),
25
+ Binding("r", "refresh_data", "Refresh", show=True),
26
+ Binding("f", "focus_filter", "Filter", show=True),
27
+ ]
28
+
29
+ def compose(self) -> ComposeResult:
30
+ yield Header(show_clock=True)
31
+ with Vertical():
32
+ yield Input(placeholder="Filter by symbol or action…", id="trades-filter")
33
+ yield DataTable(id="trades-table", cursor_type="row")
34
+ yield OrderedFooter()
35
+
36
+ def on_mount(self) -> None:
37
+ tbl = self.query_one("#trades-table", DataTable)
38
+ tbl.add_column("Time", key="time")
39
+ tbl.add_column("Symbol", key="symbol")
40
+ tbl.add_column("Action", key="action")
41
+ tbl.add_column("Price $", key="price")
42
+ tbl.add_column("Qty", key="qty")
43
+ tbl.add_column("P&L $", key="pnl")
44
+ tbl.add_column("Order ID", key="order_id")
45
+ tbl.add_column("Reason", key="reason")
46
+ self.action_refresh_data()
47
+
48
+ def action_refresh_data(self, filter_text: str = "") -> None:
49
+ from trading_cli.data.db import get_trade_history
50
+
51
+ app = self.app
52
+ if not hasattr(app, "db_conn"):
53
+ return
54
+ trades = get_trade_history(app.db_conn, limit=200)
55
+ tbl = self.query_one("#trades-table", DataTable)
56
+ tbl.clear()
57
+ ft = filter_text.upper()
58
+ for trade in trades:
59
+ if ft and ft not in trade["symbol"] and ft not in trade["action"]:
60
+ continue
61
+ ts = trade["timestamp"][:19].replace("T", " ")
62
+ action = trade["action"]
63
+ action_style = {"BUY": "bold green", "SELL": "bold red"}.get(action, "yellow")
64
+ pnl = trade.get("pnl") or 0.0
65
+ pnl_str = Text(f"{pnl:+.2f}" if pnl != 0 else "—",
66
+ style="green" if pnl > 0 else ("red" if pnl < 0 else "dim"))
67
+ tbl.add_row(
68
+ ts,
69
+ trade["symbol"],
70
+ Text(action, style=action_style),
71
+ f"{trade['price']:.2f}",
72
+ str(trade["quantity"]),
73
+ pnl_str,
74
+ trade.get("order_id") or "—",
75
+ (trade.get("reason") or "")[:40],
76
+ )
77
+
78
+ def on_input_submitted(self, event: Input.Submitted) -> None:
79
+ self.action_refresh_data(event.value.strip())
80
+
81
+ def action_export_csv(self) -> None:
82
+ from trading_cli.data.db import get_trade_history
83
+
84
+ app = self.app
85
+ if not hasattr(app, "db_conn"):
86
+ return
87
+ trades = get_trade_history(app.db_conn, limit=10000)
88
+ export_dir = Path.home() / "Downloads"
89
+ export_dir.mkdir(exist_ok=True)
90
+ fname = export_dir / f"trades_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
91
+ with open(fname, "w", newline="") as f:
92
+ if not trades:
93
+ f.write("No trades\n")
94
+ else:
95
+ writer = csv.DictWriter(f, fieldnames=trades[0].keys())
96
+ writer.writeheader()
97
+ writer.writerows(trades)
98
+ app.notify(f"Exported to {fname}")
99
+
100
+ def action_focus_filter(self) -> None:
101
+ self.query_one("#trades-filter", Input).focus()
trading_cli/screens/watchlist.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Watchlist screen — add/remove symbols, live prices and signals."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from textual.app import ComposeResult
6
+ from textual.binding import Binding
7
+ from textual.screen import Screen
8
+ from textual.widgets import Header, DataTable, Input, Label, Static
9
+ from textual.containers import Vertical, Horizontal
10
+ from textual.reactive import reactive
11
+ from rich.text import Text
12
+
13
+ from trading_cli.widgets.ordered_footer import OrderedFooter
14
+
15
+
16
+ class WatchlistScreen(Screen):
17
+ """Screen ID 2 — symbol watchlist with live prices and signals."""
18
+
19
+ BINDINGS = [
20
+ Binding("a", "focus_add", "Add", show=True),
21
+ Binding("d", "delete_selected", "Delete", show=True),
22
+ Binding("r", "refresh", "Refresh", show=True),
23
+ ]
24
+
25
+ _prices: dict[str, float] = {}
26
+ _sentiments: dict[str, float] = {}
27
+ _signals: dict[str, str] = {}
28
+
29
+ def compose(self) -> ComposeResult:
30
+ yield Header(show_clock=True)
31
+ with Vertical():
32
+ # Primary search input (like sentiment screen)
33
+ app = self.app
34
+ if hasattr(app, 'asset_search') and app.asset_search.is_ready:
35
+ from trading_cli.widgets.asset_autocomplete import create_asset_autocomplete
36
+ input_widget, autocomplete_widget = create_asset_autocomplete(
37
+ app.asset_search,
38
+ placeholder="Search by symbol or company name… (Tab to complete)",
39
+ id="wl-input",
40
+ )
41
+ yield input_widget
42
+ yield autocomplete_widget
43
+ else:
44
+ yield Input(placeholder="Search by symbol or company name…", id="wl-input")
45
+
46
+ yield DataTable(id="wl-table", cursor_type="row")
47
+ yield OrderedFooter()
48
+
49
+ def on_mount(self) -> None:
50
+ tbl = self.query_one("#wl-table", DataTable)
51
+ tbl.add_column("Symbol", key="symbol")
52
+ tbl.add_column("Price $", key="price")
53
+ tbl.add_column("Sentiment", key="sentiment")
54
+ tbl.add_column("Signal", key="signal")
55
+ self._populate_table()
56
+
57
+ def _populate_table(self) -> None:
58
+ tbl = self.query_one("#wl-table", DataTable)
59
+ tbl.clear()
60
+ app = self.app
61
+ watchlist = getattr(app, "watchlist", [])
62
+ for sym in watchlist:
63
+ price = self._prices.get(sym, 0.0)
64
+ sent = self._sentiments.get(sym, 0.0)
65
+ sig = self._signals.get(sym, "HOLD")
66
+
67
+ price_str = f"${price:.2f}" if price else "—"
68
+ sent_str = Text(f"{sent:+.3f}", style="green" if sent > 0 else ("red" if sent < 0 else "dim"))
69
+ sig_style = {"BUY": "bold green", "SELL": "bold red", "HOLD": "yellow"}.get(sig, "white")
70
+ sig_str = Text(sig, style=sig_style)
71
+
72
+ tbl.add_row(sym, price_str, sent_str, sig_str, key=sym)
73
+
74
+ def update_data(
75
+ self,
76
+ prices: dict[str, float],
77
+ sentiments: dict[str, float],
78
+ signals: dict[str, str],
79
+ ) -> None:
80
+ self._prices = prices
81
+ self._sentiments = sentiments
82
+ self._signals = signals
83
+ self._populate_table()
84
+
85
+ def action_focus_add(self) -> None:
86
+ self.query_one("#wl-input", Input).focus()
87
+
88
+ def action_delete_selected(self) -> None:
89
+ tbl = self.query_one("#wl-table", DataTable)
90
+ if tbl.cursor_row is not None:
91
+ row_key = tbl.get_row_at(tbl.cursor_row)
92
+ if row_key:
93
+ symbol = str(row_key[0])
94
+ app = self.app
95
+ if hasattr(app, "remove_from_watchlist"):
96
+ app.remove_from_watchlist(symbol)
97
+ self._populate_table()
98
+
99
+ def action_refresh(self) -> None:
100
+ self._populate_table()
101
+
102
+ def on_input_submitted(self, event: Input.Submitted) -> None:
103
+ value = event.value.strip()
104
+ if not value:
105
+ return
106
+
107
+ # Extract symbol from autocomplete format "SYMBOL — Company Name"
108
+ # If it contains " — ", take the first part as the symbol
109
+ if " — " in value:
110
+ symbol = value.split(" — ")[0].strip().upper()
111
+ else:
112
+ symbol = value.upper()
113
+
114
+ if symbol:
115
+ app = self.app
116
+ if hasattr(app, "add_to_watchlist"):
117
+ app.add_to_watchlist(symbol)
118
+ event.input.value = ""
119
+ self._populate_table()
trading_cli/sentiment/aggregator.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Aggregate FinBERT per-headline results into a single symbol-level score.
2
+
3
+ Supports event-type weighting (earnings/executive/product/macro/generic)
4
+ and temporal decay (newer headlines have more impact).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import time
10
+ from datetime import datetime, timezone
11
+
12
+ from trading_cli.sentiment.news_classifier import EventType, EventClassification, DEFAULT_WEIGHTS
13
+
14
+ LABEL_DIRECTION = {"positive": 1.0, "negative": -1.0, "neutral": 0.0}
15
+
16
+
17
+ def aggregate_scores(results: list[dict]) -> float:
18
+ """
19
+ Weighted average of label directions, weighted by confidence score.
20
+
21
+ Returns float in [-1.0, +1.0]:
22
+ +1.0 = all headlines strongly positive
23
+ -1.0 = all headlines strongly negative
24
+ 0.0 = neutral or empty
25
+ """
26
+ if not results:
27
+ return 0.0
28
+ total_weight = 0.0
29
+ weighted_sum = 0.0
30
+ for r in results:
31
+ label = r.get("label", "neutral")
32
+ score = float(r.get("score", 0.5))
33
+ direction = LABEL_DIRECTION.get(label, 0.0)
34
+ weighted_sum += direction * score
35
+ total_weight += score
36
+ if total_weight == 0.0:
37
+ return 0.0
38
+ return max(-1.0, min(1.0, weighted_sum / total_weight))
39
+
40
+
41
+ def aggregate_scores_weighted(
42
+ results: list[dict],
43
+ classifications: list[EventClassification] | None = None,
44
+ timestamps: list[float] | None = None,
45
+ event_weights: dict[EventType, float] | None = None,
46
+ half_life_hours: float = 24.0,
47
+ ) -> float:
48
+ """
49
+ Weighted sentiment aggregation with event-type and temporal decay.
50
+
51
+ Args:
52
+ results: List of FinBERT results with "label" and "score" keys.
53
+ classifications: Optional event classifications for each headline.
54
+ timestamps: Optional Unix timestamps for each headline (for temporal decay).
55
+ event_weights: Custom event type weight multipliers.
56
+ half_life_hours: Hours for temporal half-life decay. Default 24h.
57
+
58
+ Returns float in [-1.0, +1.0].
59
+ """
60
+ if not results:
61
+ return 0.0
62
+
63
+ now = time.time()
64
+ total_weight = 0.0
65
+ weighted_sum = 0.0
66
+ weights = event_weights or DEFAULT_WEIGHTS
67
+
68
+ for i, r in enumerate(results):
69
+ label = r.get("label", "neutral")
70
+ score = float(r.get("score", 0.5))
71
+ direction = LABEL_DIRECTION.get(label, 0.0)
72
+
73
+ # Base weight from FinBERT confidence
74
+ w = score
75
+
76
+ # Event type weight multiplier
77
+ if classifications and i < len(classifications):
78
+ ec = classifications[i]
79
+ w *= weights.get(ec.event_type, 1.0)
80
+
81
+ # Temporal decay: newer headlines weight more
82
+ if timestamps and i < len(timestamps):
83
+ ts = timestamps[i]
84
+ age_hours = (now - ts) / 3600.0
85
+ # Exponential decay: weight halves every half_life_hours
86
+ decay = 0.5 ** (age_hours / half_life_hours)
87
+ w *= decay
88
+
89
+ weighted_sum += direction * w
90
+ total_weight += w
91
+
92
+ if total_weight == 0.0:
93
+ return 0.0
94
+ return max(-1.0, min(1.0, weighted_sum / total_weight))
95
+
96
+
97
+ def get_sentiment_summary(results: list[dict]) -> dict:
98
+ """Return counts, dominant label, and aggregate score."""
99
+ counts = {"positive": 0, "negative": 0, "neutral": 0}
100
+ for r in results:
101
+ label = r.get("label", "neutral")
102
+ if label in counts:
103
+ counts[label] += 1
104
+ dominant = max(counts, key=lambda k: counts[k]) if results else "neutral"
105
+ return {
106
+ "score": aggregate_scores(results),
107
+ "positive_count": counts["positive"],
108
+ "negative_count": counts["negative"],
109
+ "neutral_count": counts["neutral"],
110
+ "total": len(results),
111
+ "dominant": dominant,
112
+ }
113
+
114
+
115
+ def score_to_bar(score: float, width: int = 20) -> str:
116
+ """Render a text gauge like: ──────●────────── for display in terminals."""
117
+ clamped = max(-1.0, min(1.0, score))
118
+ mid = width // 2
119
+ pos = int(mid + clamped * mid)
120
+ pos = max(0, min(width - 1, pos))
121
+ bar = list("─" * width)
122
+ bar[mid] = "┼"
123
+ bar[pos] = "●"
124
+ return "".join(bar)
trading_cli/sentiment/finbert.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FinBERT sentiment analysis — lazy-loaded singleton, cached inference."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import threading
7
+ from typing import Callable
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # File descriptor limit is set in __main__.py at startup
12
+ # This module-level code is kept for backward compatibility when imported directly
13
+ try:
14
+ import resource
15
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
16
+ target_limit = 256
17
+ if soft > target_limit:
18
+ new_soft = min(target_limit, hard)
19
+ resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
20
+ logger.info(f"Auto-adjusted file descriptor limit from {soft} to {new_soft}")
21
+ except Exception as e:
22
+ if logger:
23
+ logger.debug(f"Could not adjust file descriptor limit: {e}")
24
+
25
+ _MODEL_NAME = "ProsusAI/finbert"
26
+ _LABELS = ["positive", "negative", "neutral"]
27
+
28
+
29
+ class FinBERTAnalyzer:
30
+ """
31
+ Lazy-loaded FinBERT wrapper.
32
+
33
+ Usage:
34
+ analyzer = FinBERTAnalyzer()
35
+ analyzer.load(progress_callback=lambda msg: print(msg))
36
+ results = analyzer.analyze_batch(["Apple beats earnings", "Market crashes"])
37
+ """
38
+
39
+ _instance: FinBERTAnalyzer | None = None
40
+ _lock = threading.Lock()
41
+
42
+ def __init__(self) -> None:
43
+ self._model = None
44
+ self._tokenizer = None
45
+ self._loaded = False
46
+ self._load_error: str | None = None
47
+ self._device: str = "cpu"
48
+ self._tried_fds_workaround: bool = False
49
+
50
+ @classmethod
51
+ def get_instance(cls) -> FinBERTAnalyzer:
52
+ if cls._instance is None:
53
+ with cls._lock:
54
+ if cls._instance is None:
55
+ cls._instance = FinBERTAnalyzer()
56
+ assert cls._instance is not None
57
+ return cls._instance
58
+
59
+ @property
60
+ def is_loaded(self) -> bool:
61
+ return self._loaded
62
+
63
+ @property
64
+ def load_error(self) -> str | None:
65
+ return self._load_error
66
+
67
+ def reload(self, progress_callback: Callable[[str], None] | None = None) -> bool:
68
+ """
69
+ Reset error state and attempt to load again.
70
+ Returns True on success, False on failure.
71
+ """
72
+ self._loaded = False
73
+ self._load_error = None # Will be set by load() if it fails
74
+ self._model = None
75
+ self._tokenizer = None
76
+ self._tried_fds_workaround = False # Reset workaround flag for fresh attempt
77
+ return self.load(progress_callback)
78
+
79
+ def load(self, progress_callback: Callable[[str], None] | None = None) -> bool:
80
+ """
81
+ Load model from HuggingFace Hub (or local cache).
82
+ Returns True on success, False on failure.
83
+ """
84
+ if self._loaded:
85
+ return True
86
+
87
+ def _cb(msg: str) -> None:
88
+ if progress_callback:
89
+ progress_callback(msg)
90
+ logger.info(msg)
91
+
92
+ try:
93
+ import os
94
+
95
+ # Suppress warnings
96
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
97
+ os.environ["TRANSFORMERS_VERBOSITY"] = "error"
98
+ os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
99
+ # Disable tqdm to avoid threading issues
100
+ os.environ["TQDM_DISABLE"] = "1"
101
+
102
+ import transformers
103
+ transformers.logging.set_verbosity_error()
104
+
105
+ # Auto-detect device
106
+ import torch
107
+ if torch.cuda.is_available():
108
+ self._device = "cuda"
109
+ _cb(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
110
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
111
+ self._device = "mps"
112
+ _cb("Using Apple Metal (MPS)")
113
+ elif hasattr(torch.version, 'hip') and torch.version.hip is not None:
114
+ self._device = "cuda" # ROCm uses cuda device type
115
+ _cb("Using AMD ROCm GPU")
116
+ else:
117
+ self._device = "cpu"
118
+ # Enable multi-threaded CPU inference for Intel/AMD CPUs
119
+ # Don't restrict threads - let PyTorch use available cores
120
+ _cb(f"Using CPU ({torch.get_num_threads()} threads)")
121
+
122
+ _cb("Loading FinBERT tokenizer...")
123
+ from transformers import AutoTokenizer
124
+
125
+ self._tokenizer = AutoTokenizer.from_pretrained(
126
+ _MODEL_NAME,
127
+ use_fast=True, # Fast tokenizer is much quicker
128
+ )
129
+
130
+ _cb("Loading FinBERT model weights (~500MB)...")
131
+ from transformers import AutoModelForSequenceClassification
132
+
133
+ # Use low_cpu_mem_usage for faster loading with meta tensors
134
+ # CRITICAL: Do NOT use device_map="auto" as it can trigger subprocess issues
135
+ # Instead, load on CPU first, then move to device manually
136
+ self._model = AutoModelForSequenceClassification.from_pretrained(
137
+ _MODEL_NAME,
138
+ low_cpu_mem_usage=True,
139
+ device_map=None, # Avoid subprocess spawning
140
+ # Disable features that might use subprocesses
141
+ trust_remote_code=False,
142
+ )
143
+ self._model.eval()
144
+
145
+ # Move to device after loading
146
+ self._model = self._model.to(self._device)
147
+
148
+ _cb(f"FinBERT ready on {self._device.upper()} ✓")
149
+ self._loaded = True
150
+ return True
151
+
152
+ except Exception as exc:
153
+ import traceback
154
+ import sys as sys_mod
155
+ full_traceback = traceback.format_exc()
156
+ msg = f"FinBERT load failed: {exc}"
157
+ logger.error(msg)
158
+ logger.error("Full traceback:\n%s", full_traceback)
159
+ self._load_error = msg
160
+ if progress_callback:
161
+ progress_callback(msg)
162
+
163
+ # If it's the fds_to_keep error, try once more with additional workarounds
164
+ if "fds_to_keep" in str(exc) and not getattr(self, '_tried_fds_workaround', False):
165
+ self._tried_fds_workaround = True
166
+ logger.info("Attempting retry with fds_to_keep workaround...")
167
+ logger.info("Original traceback:\n%s", full_traceback)
168
+ # Preserve original error if workaround also fails
169
+ original_error = msg
170
+ success = self._load_with_fds_workaround(progress_callback)
171
+ if not success and not self._load_error:
172
+ # Add helpful context about Python version
173
+ python_version = sys_mod.version
174
+ self._load_error = (
175
+ f"{original_error}\n"
176
+ f"\n"
177
+ f"This is a known issue with Python 3.12+ and transformers.\n"
178
+ f"Your Python version: {python_version}\n"
179
+ f"\n"
180
+ f"To fix this, consider:\n"
181
+ f" 1. Downgrade to Python 3.11 (recommended)\n"
182
+ f" 2. Or upgrade transformers: pip install -U transformers>=4.45.0\n"
183
+ f" 3. Or use the --no-sentiment flag to skip FinBERT loading"
184
+ )
185
+ return success
186
+
187
+ return False
188
+
189
+ def _load_with_fds_workaround(self, progress_callback) -> bool:
190
+ """Fallback loading method with additional workarounds for fds_to_keep error."""
191
+ if self._loaded:
192
+ return True
193
+
194
+ def _cb(msg: str) -> None:
195
+ if progress_callback:
196
+ progress_callback(msg)
197
+ logger.info(msg)
198
+
199
+ try:
200
+ import os
201
+
202
+ # Suppress warnings
203
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
204
+ os.environ["TRANSFORMERS_VERBOSITY"] = "error"
205
+ os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
206
+ os.environ["TQDM_DISABLE"] = "1"
207
+
208
+ # Try to lower file descriptor limit if it's very high
209
+ try:
210
+ import resource
211
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
212
+ _cb(f"Current file descriptor limit: soft={soft}, hard={hard}")
213
+ # Force lower limit for workaround attempt - must be very low for Python 3.14
214
+ target_limit = 128
215
+ if soft > target_limit:
216
+ new_soft = min(target_limit, hard)
217
+ resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
218
+ _cb(f"Lowered file descriptor limit from {soft} to {new_soft} (emergency fallback)")
219
+ except (ImportError, ValueError, OSError) as e:
220
+ logger.debug(f"Could not adjust file descriptor limit: {e}")
221
+
222
+ import transformers
223
+ transformers.logging.set_verbosity_error()
224
+
225
+ # Auto-detect device
226
+ import torch
227
+ if torch.cuda.is_available():
228
+ self._device = "cuda"
229
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
230
+ self._device = "mps"
231
+ else:
232
+ self._device = "cpu"
233
+ # Limit CPU threads for more stable loading
234
+ torch.set_num_threads(min(torch.get_num_threads(), 4))
235
+
236
+ _cb(f"Retrying FinBERT load on {self._device.upper()} ({torch.get_num_threads()} threads)...")
237
+
238
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
239
+
240
+ # Use fast tokenizer and optimized loading
241
+ # Disable subprocess-based tokenization
242
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
243
+ self._tokenizer = AutoTokenizer.from_pretrained(
244
+ _MODEL_NAME,
245
+ use_fast=True,
246
+ )
247
+
248
+ # Use device_map for auto placement
249
+ # For Python 3.14+, avoid using device_map="auto" which can trigger subprocess issues
250
+ device_map = None
251
+ self._model = AutoModelForSequenceClassification.from_pretrained(
252
+ _MODEL_NAME,
253
+ low_cpu_mem_usage=True,
254
+ device_map=device_map,
255
+ )
256
+ self._model.eval()
257
+
258
+ # Manually move to device
259
+ self._model = self._model.to(self._device)
260
+
261
+ _cb(f"FinBERT ready on {self._device.upper()} ✓")
262
+ self._loaded = True
263
+ return True
264
+
265
+ except Exception as exc:
266
+ msg = f"FinBERT load failed (workaround attempt): {exc}"
267
+ logger.error(msg)
268
+ self._load_error = msg
269
+ if progress_callback:
270
+ progress_callback(msg)
271
+ # Log additional context for debugging
272
+ import traceback
273
+ logger.debug("Workaround load traceback:\n%s", traceback.format_exc())
274
+
275
+ # If still failing with fds_to_keep, try one more time with subprocess isolation
276
+ if "fds_to_keep" in str(exc):
277
+ logger.info("Attempting final retry with subprocess isolation...")
278
+ return self._load_with_subprocess_isolation(progress_callback)
279
+
280
+ return False
281
+
282
+ def _load_with_subprocess_isolation(self, progress_callback) -> bool:
283
+ """Final attempt: load model with maximum subprocess isolation for Python 3.14+."""
284
+ if self._loaded:
285
+ return True
286
+
287
+ def _cb(msg: str) -> None:
288
+ if progress_callback:
289
+ progress_callback(msg)
290
+ logger.info(msg)
291
+
292
+ try:
293
+ import os
294
+ import subprocess
295
+ import sys
296
+
297
+ # Set maximum isolation before loading
298
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
299
+ os.environ["TRANSFORMERS_VERBOSITY"] = "error"
300
+ os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
301
+ os.environ["TQDM_DISABLE"] = "1"
302
+
303
+ # Additional isolation for Python 3.14
304
+ os.environ["RAYON_RS_NUM_CPUS"] = "1"
305
+ os.environ["OMP_NUM_THREADS"] = "1"
306
+
307
+ # Force file descriptor limit to minimum
308
+ try:
309
+ import resource
310
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
311
+ resource.setrlimit(resource.RLIMIT_NOFILE, (64, hard))
312
+ _cb("Set file descriptor limit to 64 (maximum isolation)")
313
+ except Exception:
314
+ pass
315
+
316
+ import transformers
317
+ transformers.logging.set_verbosity_error()
318
+
319
+ import torch
320
+ if torch.cuda.is_available():
321
+ self._device = "cuda"
322
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
323
+ self._device = "mps"
324
+ else:
325
+ self._device = "cpu"
326
+ torch.set_num_threads(1) # Single thread for maximum isolation
327
+
328
+ _cb(f"Loading with subprocess isolation on {self._device.upper()}...")
329
+
330
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
331
+
332
+ # Use slow tokenizer to avoid Rust subprocess issues
333
+ self._tokenizer = AutoTokenizer.from_pretrained(
334
+ _MODEL_NAME,
335
+ use_fast=False, # Use slow tokenizer
336
+ )
337
+
338
+ self._model = AutoModelForSequenceClassification.from_pretrained(
339
+ _MODEL_NAME,
340
+ low_cpu_mem_usage=True,
341
+ )
342
+ self._model.eval()
343
+ self._model = self._model.to(self._device)
344
+
345
+ _cb(f"FinBERT ready on {self._device.upper()} ✓")
346
+ self._loaded = True
347
+ return True
348
+
349
+ except Exception as exc:
350
+ msg = f"FinBERT load failed (subprocess isolation): {exc}"
351
+ logger.error(msg)
352
+ self._load_error = msg
353
+ if progress_callback:
354
+ progress_callback(msg)
355
+ import traceback
356
+ logger.debug("Subprocess isolation traceback:\n%s", traceback.format_exc())
357
+
358
+ # Add helpful context
359
+ import sys as sys_mod
360
+ python_version = sys_mod.version
361
+ self._load_error = (
362
+ f"{msg}\n"
363
+ f"\n"
364
+ f"This is a known compatibility issue between Python 3.12+ and the transformers library.\n"
365
+ f"Your Python version: {python_version}\n"
366
+ f"\n"
367
+ f"To resolve this issue:\n"
368
+ f" 1. Downgrade to Python 3.11 (most reliable solution)\n"
369
+ f" - Use pyenv: pyenv install 3.11 && pyenv local 3.11\n"
370
+ f" 2. Or upgrade to the latest transformers: pip install -U transformers\n"
371
+ f" - Note: As of now, you have transformers 5.5.0\n"
372
+ f" 3. Or run with sentiment disabled: trading-cli --no-sentiment\n"
373
+ f"\n"
374
+ f"The app will continue without sentiment analysis."
375
+ )
376
+ return False
377
+
378
+ def analyze_with_cache(self, headlines: list[str], conn) -> list[dict]:
379
+ """
380
+ Analyze headlines, checking SQLite cache first to avoid re-inference.
381
+ Uncached headlines are batch-processed and then stored in the cache.
382
+ """
383
+ from trading_cli.data.db import get_cached_sentiment, cache_sentiment
384
+
385
+ results: list[dict] = []
386
+ uncached_indices: list[int] = []
387
+ uncached_texts: list[str] = []
388
+
389
+ for i, text in enumerate(headlines):
390
+ cached = get_cached_sentiment(conn, text)
391
+ if cached:
392
+ results.append(cached)
393
+ else:
394
+ results.append(None) # placeholder
395
+ uncached_indices.append(i)
396
+ uncached_texts.append(text)
397
+
398
+ if uncached_texts:
399
+ fresh = self.analyze_batch(uncached_texts)
400
+ for idx, text, res in zip(uncached_indices, uncached_texts, fresh):
401
+ results[idx] = res
402
+ try:
403
+ cache_sentiment(conn, text, res["label"], res["score"])
404
+ except Exception:
405
+ pass
406
+
407
+ return [r or {"label": "neutral", "score": 0.5} for r in results]
408
+
409
+ def analyze_batch(
410
+ self,
411
+ headlines: list[str],
412
+ batch_size: int = 50,
413
+ ) -> list[dict]:
414
+ """
415
+ Run FinBERT inference on a list of headlines.
416
+
417
+ Returns list of {"label": str, "score": float} dicts,
418
+ one per input headline. Falls back to {"label": "neutral", "score": 0.5}
419
+ if model is not loaded.
420
+ """
421
+ if not headlines:
422
+ return []
423
+ if not self._loaded:
424
+ logger.warning("FinBERT not loaded — returning neutral for all headlines")
425
+ return [{"label": "neutral", "score": 0.5}] * len(headlines)
426
+
427
+ import torch
428
+
429
+ results: list[dict] = []
430
+ for i in range(0, len(headlines), batch_size):
431
+ batch = headlines[i : i + batch_size]
432
+ try:
433
+ inputs = self._tokenizer(
434
+ batch,
435
+ padding=True,
436
+ truncation=True,
437
+ max_length=512,
438
+ return_tensors="pt",
439
+ ).to(self._device) # Move inputs to correct device
440
+ with torch.no_grad():
441
+ outputs = self._model(**inputs)
442
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
443
+ for prob_row in probs:
444
+ idx = int(prob_row.argmax())
445
+ label = self._model.config.id2label[idx].lower()
446
+ # Normalise label variants (ProsusAI uses "positive","negative","neutral")
447
+ if label not in _LABELS:
448
+ label = "neutral"
449
+ results.append({"label": label, "score": float(prob_row[idx])})
450
+ except Exception as exc:
451
+ logger.error("FinBERT inference error on batch %d: %s", i, exc)
452
+ results.extend([{"label": "neutral", "score": 0.5}] * len(batch))
453
+ return results
trading_cli/sentiment/news_classifier.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """News event classifier — assigns importance weights to headlines by type.
2
+
3
+ Categorizes headlines into:
4
+ - earnings: earnings reports, guidance updates
5
+ - executive: CEO/CFO changes, board moves
6
+ - product: product launches, recalls, approvals
7
+ - macro: interest rates, CPI, unemployment, Fed policy
8
+ - generic: everything else (lower weight)
9
+
10
+ Each category has a configurable weight multiplier.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import re
16
+ from dataclasses import dataclass
17
+ from enum import Enum
18
+ from typing import Optional
19
+
20
+
21
+ class EventType(Enum):
22
+ EARNINGS = "earnings"
23
+ EXECUTIVE = "executive"
24
+ PRODUCT = "product"
25
+ MACRO = "macro"
26
+ GENERIC = "generic"
27
+
28
+
29
+ @dataclass
30
+ class EventClassification:
31
+ event_type: EventType
32
+ weight: float
33
+ confidence: float # 0.0-1.0 how confident we are in the classification
34
+
35
+
36
+ # Default weights — higher means the headline is more impactful
37
+ DEFAULT_WEIGHTS: dict[EventType, float] = {
38
+ EventType.EARNINGS: 1.5, # earnings reports move markets significantly
39
+ EventType.EXECUTIVE: 1.3, # leadership changes signal strategic shifts
40
+ EventType.PRODUCT: 1.2, # product news affects company outlook
41
+ EventType.MACRO: 1.4, # macro news affects entire market
42
+ EventType.GENERIC: 0.8, # generic news has lower impact
43
+ }
44
+
45
+ # Keyword patterns for classification
46
+ EARNINGS_KEYWORDS = [
47
+ r'\bearnings\b', r'\bprofit\b', r'\brevenue\b', r'\bloss\b',
48
+ r'\bEPS\b', r'\bper share\b', r'\bquarterly\b.*\bresult',
49
+ r'\bguidance\b', r'\bforecast\b', r'\boutlook\b',
50
+ r'\bbeat.*expect', r'\bmiss.*expect', r'\banalyst.*expect',
51
+ r'\breport.*earning', r'\bQ\d\b.*\bresult',
52
+ ]
53
+
54
+ EXECUTIVE_KEYWORDS = [
55
+ r'\bCEO\b', r'\bCFO\b', r'\bCOO\b', r'\bCTO\b',
56
+ r'\bchief\s+(executive|financial|operating|technology)',
57
+ r'\bresign', r'\bstep\s+down\b', r'\bappointed\b',
58
+ r'\bnew\s+CEO\b', r'\bboard\b', r'\bdirector',
59
+ r'\bleadership\b', r'\bexecutive\b',
60
+ ]
61
+
62
+ PRODUCT_KEYWORDS = [
63
+ r'\bproduct\s+launch', r'\brecall\b', r'\bFDA\b',
64
+ r'\bapproval\b', r'\brecalled\b', r'\bnew\s+product',
65
+ r'\biPhone\b', r'\biPad\b', r'\bTesla\b.*\bmodel',
66
+ r'\bpipeline\b', r'\btrial\b', r'\bclinical\b',
67
+ r'\bpatent\b', r'\binnovation\b',
68
+ ]
69
+
70
+ MACRO_KEYWORDS = [
71
+ r'\bFed\b', r'\bFederal\s+Reserve\b', r'\binterest\s+rate',
72
+ r'\bCPI\b', r'\binflation\b', r'\bunemployment\b',
73
+ r'\bjobs\s+report', r'\bGDP\b', r'\brecession\b',
74
+ r'\btariff\b', r'\btrade\s+war\b', r'\bsanction',
75
+ r'\bcentral\s+bank\b', r'\bmonetary\s+policy',
76
+ r'\bquantitative\s+(easing|tightening)',
77
+ ]
78
+
79
+
80
+ def classify_headline(headline: str, custom_weights: dict[EventType, float] | None = None) -> EventClassification:
81
+ """Classify a headline into an event type and return its weight.
82
+
83
+ Uses keyword matching with confidence based on how many keywords match.
84
+ """
85
+ text = headline.lower()
86
+ weights = custom_weights or DEFAULT_WEIGHTS
87
+
88
+ patterns = {
89
+ EventType.EARNINGS: EARNINGS_KEYWORDS,
90
+ EventType.EXECUTIVE: EXECUTIVE_KEYWORDS,
91
+ EventType.PRODUCT: PRODUCT_KEYWORDS,
92
+ EventType.MACRO: MACRO_KEYWORDS,
93
+ }
94
+
95
+ best_type = EventType.GENERIC
96
+ best_confidence = 0.0
97
+
98
+ for event_type, keyword_list in patterns.items():
99
+ matches = sum(1 for kw in keyword_list if re.search(kw, text))
100
+ if matches > 0:
101
+ confidence = min(1.0, matches / 3.0) # 3+ matches = high confidence
102
+ if confidence > best_confidence:
103
+ best_confidence = confidence
104
+ best_type = event_type
105
+
106
+ return EventClassification(
107
+ event_type=best_type,
108
+ weight=weights.get(best_type, 1.0),
109
+ confidence=best_confidence,
110
+ )
111
+
112
+
113
+ def classify_headlines(
114
+ headlines: list[str],
115
+ custom_weights: dict[EventType, float] | None = None,
116
+ ) -> list[EventClassification]:
117
+ """Classify multiple headlines at once."""
118
+ return [classify_headline(h, custom_weights) for h in headlines]
trading_cli/strategy/adapters/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Strategy adapters — pluggable trading strategy implementations."""
2
+
3
+ from trading_cli.strategy.adapters.base import SignalResult, StrategyAdapter, StrategyInfo
4
+ from trading_cli.strategy.adapters.registry import (
5
+ create_strategy,
6
+ get_strategy,
7
+ list_strategies,
8
+ register_strategy,
9
+ )
10
+
11
+ # Import all strategy implementations to trigger registration
12
+ from trading_cli.strategy.adapters.hybrid import HybridStrategy
13
+ from trading_cli.strategy.adapters.momentum import MomentumStrategy
14
+ from trading_cli.strategy.adapters.mean_reversion import MeanReversionStrategy
15
+ from trading_cli.strategy.adapters.mean_reversion_rsi2 import MeanReversionRSI2Strategy
16
+ from trading_cli.strategy.adapters.trend_following import TrendFollowingStrategy
17
+ from trading_cli.strategy.adapters.sentiment_driven import SentimentStrategy
18
+ from trading_cli.strategy.adapters.regime_aware import RegimeAwareStrategy
19
+ from trading_cli.strategy.adapters.super_strategy import SuperStrategy
20
+ from trading_cli.strategy.adapters.ai_fusion import AIFusionStrategy
21
+
22
+ __all__ = [
23
+ "StrategyAdapter",
24
+ "StrategyInfo",
25
+ "SignalResult",
26
+ "create_strategy",
27
+ "get_strategy",
28
+ "list_strategies",
29
+ "register_strategy",
30
+ "HybridStrategy",
31
+ "MomentumStrategy",
32
+ "MeanReversionStrategy",
33
+ "MeanReversionRSI2Strategy",
34
+ "TrendFollowingStrategy",
35
+ "SentimentStrategy",
36
+ "RegimeAwareStrategy",
37
+ "SuperStrategy",
38
+ "AIFusionStrategy",
39
+ ]