mahnerak committed on
Commit
ce00289
•
0 Parent(s):

Initial Commit 🚀

Files changed (42)
  1. .dockerignore +3 -0
  2. .flake8 +2 -0
  3. .gitignore +7 -0
  4. CODE_OF_CONDUCT.md +80 -0
  5. CONTRIBUTING.md +31 -0
  6. Dockerfile +42 -0
  7. LICENSE +399 -0
  8. README.md +88 -0
  9. config/docker_hosting.json +13 -0
  10. config/docker_local.json +25 -0
  11. config/local.json +47 -0
  12. env.yaml +27 -0
  13. llm_transparency_tool/__init__.py +5 -0
  14. llm_transparency_tool/components/__init__.py +111 -0
  15. llm_transparency_tool/components/frontend/.env +6 -0
  16. llm_transparency_tool/components/frontend/.prettierrc +5 -0
  17. llm_transparency_tool/components/frontend/package.json +39 -0
  18. llm_transparency_tool/components/frontend/public/index.html +15 -0
  19. llm_transparency_tool/components/frontend/src/ContributionGraph.tsx +517 -0
  20. llm_transparency_tool/components/frontend/src/LlmViewer.css +77 -0
  21. llm_transparency_tool/components/frontend/src/Selector.tsx +154 -0
  22. llm_transparency_tool/components/frontend/src/common.tsx +17 -0
  23. llm_transparency_tool/components/frontend/src/index.tsx +39 -0
  24. llm_transparency_tool/components/frontend/src/react-app-env.d.ts +1 -0
  25. llm_transparency_tool/components/frontend/tsconfig.json +19 -0
  26. llm_transparency_tool/models/__init__.py +5 -0
  27. llm_transparency_tool/models/test_tlens_model.py +162 -0
  28. llm_transparency_tool/models/tlens_model.py +303 -0
  29. llm_transparency_tool/models/transparent_llm.py +199 -0
  30. llm_transparency_tool/routes/__init__.py +5 -0
  31. llm_transparency_tool/routes/contributions.py +201 -0
  32. llm_transparency_tool/routes/graph.py +163 -0
  33. llm_transparency_tool/routes/graph_node.py +90 -0
  34. llm_transparency_tool/routes/test_contributions.py +148 -0
  35. llm_transparency_tool/server/app.py +659 -0
  36. llm_transparency_tool/server/graph_selection.py +56 -0
  37. llm_transparency_tool/server/monitor.py +99 -0
  38. llm_transparency_tool/server/styles.py +107 -0
  39. llm_transparency_tool/server/utils.py +133 -0
  40. pyproject.toml +2 -0
  41. sample_input.txt +3 -0
  42. setup.py +13 -0
.dockerignore ADDED
@@ -0,0 +1,3 @@
1
+ **/.git
2
+ **/node_modules
3
+ **/.mypy_cache
.flake8 ADDED
@@ -0,0 +1,2 @@
1
+ [flake8]
2
+ max-line-length = 120
.gitignore ADDED
@@ -0,0 +1,7 @@
1
+ **/frontend/node_modules*
2
+ **/frontend/build/
3
+ **/frontend/.yarn*
4
+ .vscode/
5
+ .mypy_cache/
6
+ __pycache__/
7
+ .DS_Store
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <opensource-conduct@meta.com>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
1
+ # Contributing to llm-transparency-tool
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to llm-transparency-tool, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
Dockerfile ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
8
+
9
+ RUN apt-get update && apt-get install -y \
10
+ wget \
11
+ git \
12
+ && apt-get clean \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ RUN useradd -m -u 1000 user
16
+ USER user
17
+
18
+ ENV HOME=/home/user
19
+
20
+ RUN wget -P /tmp \
21
+ "https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-Linux-x86_64.sh" \
22
+ && bash /tmp/Mambaforge-23.11.0-0-Linux-x86_64.sh -b -p $HOME/mambaforge3 \
23
+ && rm /tmp/Mambaforge-23.11.0-0-Linux-x86_64.sh
24
+ ENV PATH $HOME/mambaforge3/bin:$PATH
25
+
26
+ WORKDIR $HOME
27
+
28
+ ENV REPO=$HOME/llm-transparency-tool
29
+ COPY --chown=user . $REPO
30
+
31
+ WORKDIR $REPO
32
+
33
+ RUN mamba env create --name llmtt -f env.yaml -y
34
+ ENV PATH $HOME/mambaforge3/envs/llmtt/bin:$PATH
35
+ RUN pip install -e .
36
+
37
+ RUN cd llm_transparency_tool/components/frontend \
38
+ && yarn install \
39
+ && yarn build
40
+
41
+ EXPOSE 7860
42
+ CMD ["streamlit", "run", "llm_transparency_tool/server/app.py", "--server.port=7860", "--server.address=0.0.0.0", "--theme.font=Inconsolata", "--", "config/docker_hosting.json"]
LICENSE ADDED
@@ -0,0 +1,399 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
README.md ADDED
@@ -0,0 +1,88 @@
1
+ <h1>
2
+ <img width="500" alt="LLM Transparency Tool" src="https://github.com/facebookresearch/llm-transparency-tool/assets/1367529/4bbf2544-88de-4576-9622-6047a056c5c8">
3
+ </h1>
4
+
5
+ <img width="832" alt="screenshot" src="https://github.com/facebookresearch/llm-transparency-tool/assets/1367529/78f6f9e2-fe76-4ded-bb78-a57f64f4ac3a">
6
+
7
+
8
+ ## Key functionality
9
+
10
+ * Choose your model, choose or add your prompt, run the inference.
11
+ * Browse contribution graph.
12
+ * Select the token to build the graph from.
13
+ * Tune the contribution threshold.
14
+ * Select representation of any token after any block.
15
+ * For the representation, see its projection to the output vocabulary and which tokens
16
+ were promoted/suppressed by the previous block.
17
+ * The following things are clickable:
18
+ * Edges. That shows more info about the contributing attention head.
19
+ * Heads when an edge is selected. You can see what this head is promoting/suppressing.
20
+ * FFN blocks (little squares on the graph).
21
+ * Neurons when an FFN block is selected.
22
+
23
+
24
+ ## Installation
25
+
26
+ ### Dockerized running
27
+ ```bash
28
+ # From the repository root directory
29
+ docker build -t llm_transparency_tool .
30
+ docker run --rm -p 7860:7860 llm_transparency_tool
31
+ ```
32
+
33
+ ### Local Installation
34
+
35
+
36
+ ```bash
37
+ # download
38
+ git clone git@github.com:facebookresearch/llm-transparency-tool.git
39
+ cd llm-transparency-tool
40
+
41
+ # install the necessary packages
42
+ conda env create --name llmtt -f env.yaml
43
+ # install the `llm_transparency_tool` package
44
+ pip install -e .
45
+
46
+ # now, we need to build the frontend
47
+ # don't worry, even `yarn` is already installed by `env.yaml`
48
+ cd llm_transparency_tool/components/frontend
49
+ yarn install
50
+ yarn build
51
+ ```
52
+
53
+ ### Launch
54
+
55
+ ```bash
56
+ streamlit run llm_transparency_tool/server/app.py -- config/local.json
57
+ ```
58
+
59
+
60
+ ## Adding support for your LLM
61
+
62
+ Initially, the tool allows you to select from just a handful of models. Here are the
63
+ options for using your own model in the tool, ordered from least to most
64
+ effort.
65
+
66
+
67
+ ### The model is already supported by TransformerLens
68
+
69
+ The full list of supported models is [here](https://github.com/neelnanda-io/TransformerLens/blob/0825c5eb4196e7ad72d28bcf4e615306b3897490/transformer_lens/loading_from_pretrained.py#L18).
70
+ In this case, the model just needs to be added under the `models` key of the configuration JSON file (e.g. `config/local.json`).
71
+
72
+
73
+ ### Tuned version of a model supported by TransformerLens
74
+
75
+ Add the official name of the model to the config along with the location to read the
76
+ weights from.
77
+
78
+
79
+ ### The model is not supported by TransformerLens
80
+
81
+ In this case, the UI won't know how to create the proper hooks for the model. You'll need
82
+ to implement your own version of the [TransparentLlm](./llm_transparency_tool/models/transparent_llm.py#L28) class and adapt the
83
+ Streamlit app to use your implementation. A rough sketch of the idea is shown below.
84
+
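For illustration only, a minimal sketch of what such an implementation could start from. The real abstract interface lives in `llm_transparency_tool/models/transparent_llm.py`; the method shown here is a placeholder, not the actual set of abstract methods you must override.

```python
# Minimal sketch, assuming TransparentLlm is the abstract base class defined in
# llm_transparency_tool/models/transparent_llm.py. The method below is a
# placeholder: override whatever abstract methods that class actually declares
# (running inference, exposing residual streams, attention/FFN contributions, ...).
from llm_transparency_tool.models.transparent_llm import TransparentLlm


class MyTransparentLlm(TransparentLlm):
    def __init__(self, hf_model_name: str):
        # Load your model and tokenizer here, with whatever framework you use.
        ...

    def run(self, sentences):  # hypothetical signature, for illustration only
        # Run a forward pass and cache whatever the tool needs to query later.
        ...
```

The Streamlit app (`llm_transparency_tool/server/app.py`) would then need to construct this class instead of the default TransformerLens-backed implementation.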
85
+
86
+ ## License
87
+ This code is made available under a [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license, as found in the LICENSE file.
88
+ However, you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models.
config/docker_hosting.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "allow_loading_dataset_files": false,
3
+ "max_user_string_length": 100,
4
+ "preloaded_dataset_filename": "sample_input.txt",
5
+ "debug": false,
6
+ "demo_mode": true,
7
+ "models": {
8
+ "facebook/opt-125m": null,
9
+ "gpt2": null,
10
+ "distilgpt2": null
11
+ },
12
+ "default_model": "gpt2"
13
+ }
config/docker_local.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "allow_loading_dataset_files": true,
3
+ "preloaded_dataset_filename": "sample_input.txt",
4
+ "debug": true,
5
+ "models": {
6
+ "": null,
7
+ "facebook/opt-125m": null,
8
+ "facebook/opt-1.3b": null,
9
+ "facebook/opt-2.7b": null,
10
+ "facebook/opt-6.7b": null,
11
+ "facebook/opt-13b": null,
12
+ "facebook/opt-30b": null,
13
+ "meta-llama/Llama-2-7b-hf": null,
14
+ "meta-llama/Llama-2-7b-chat-hf": null,
15
+ "meta-llama/Llama-2-13b-hf": null,
16
+ "meta-llama/Llama-2-13b-chat-hf": null,
17
+ "gpt2": null,
18
+ "gpt2-medium": null,
19
+ "gpt2-large": null,
20
+ "gpt2-xl": null,
21
+ "distilgpt2": null
22
+ },
23
+ "default_model": "distilgpt2",
24
+ "demo_mode": false
25
+ }
config/local.json ADDED
@@ -0,0 +1,47 @@
1
+ {
2
+ "allow_loading_dataset_files": true,
3
+ "preloaded_dataset_filename": "sample_input.txt",
4
+ "debug": true,
5
+ "models": {
6
+ "": null,
7
+
8
+ "gpt2": null,
9
+ "distilgpt2": null,
10
+ "facebook/opt-125m": null,
11
+ "facebook/opt-1.3b": null,
12
+ "EleutherAI/gpt-neo-125M": null,
13
+ "Qwen/Qwen-1_8B": null,
14
+ "Qwen/Qwen1.5-0.5B": null,
15
+ "Qwen/Qwen1.5-0.5B-Chat": null,
16
+ "Qwen/Qwen1.5-1.8B": null,
17
+ "Qwen/Qwen1.5-1.8B-Chat": null,
18
+ "microsoft/phi-1": null,
19
+ "microsoft/phi-1_5": null,
20
+ "microsoft/phi-2": null,
21
+
22
+ "meta-llama/Llama-2-7b-hf": null,
23
+ "meta-llama/Llama-2-7b-chat-hf": null,
24
+
25
+ "meta-llama/Llama-2-13b-hf": null,
26
+ "meta-llama/Llama-2-13b-chat-hf": null,
27
+
28
+
29
+ "gpt2-medium": null,
30
+ "gpt2-large": null,
31
+ "gpt2-xl": null,
32
+
33
+ "mistralai/Mistral-7B-v0.1": null,
34
+ "mistralai/Mistral-7B-Instruct-v0.1": null,
35
+ "mistralai/Mistral-7B-Instruct-v0.2": null,
36
+
37
+ "google/gemma-7b": null,
38
+ "google/gemma-2b": null,
39
+
40
+ "facebook/opt-2.7b": null,
41
+ "facebook/opt-6.7b": null,
42
+ "facebook/opt-13b": null,
43
+ "facebook/opt-30b": null
44
+ },
45
+ "default_model": "",
46
+ "demo_mode": false
47
+ }
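Note that the server receives one of these JSON files as a positional argument after `--` (see the `streamlit run ... -- config/local.json` command in the README and the `CMD` in the Dockerfile). Below is a minimal sketch of how such a config can be consumed; it is illustrative only and not the actual loader in `llm_transparency_tool/server/app.py`.

```python
# Illustrative sketch only; the real parsing lives in
# llm_transparency_tool/server/app.py and may differ.
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("config", help="path to a JSON config, e.g. config/local.json")
args = parser.parse_args()

with open(args.config) as f:
    config = json.load(f)

# Keys visible in the configs above: "models" (model name -> optional location
# of the weights), "default_model", "debug", "demo_mode",
# "preloaded_dataset_filename", "allow_loading_dataset_files".
print(sorted(config["models"]))
print(config.get("default_model"))
```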
env.yaml ADDED
@@ -0,0 +1,27 @@
1
+ name: llmtt
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ dependencies:
7
+ - python
8
+ - pytorch
9
+ - pytorch-cuda=11.8
10
+ - nodejs
11
+ - yarn
12
+ - pip
13
+ - pip:
14
+ - datasets
15
+ - einops
16
+ - fancy_einsum
17
+ - jaxtyping
18
+ - networkx
19
+ - plotly
20
+ - pyinstrument
21
+ - setuptools
22
+ - streamlit
23
+ - streamlit_extras
24
+ - tokenizers
25
+ - transformer_lens
26
+ - transformers
27
+ - pytest # fixes wrong dependencies of transformer_lens
llm_transparency_tool/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
llm_transparency_tool/components/__init__.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from typing import List, Optional
9
+
10
+ import networkx as nx
11
+ import streamlit.components.v1 as components
12
+
13
+ from llm_transparency_tool.models.transparent_llm import ModelInfo
14
+ from llm_transparency_tool.server.graph_selection import GraphSelection, UiGraphNode
15
+
16
+ _RELEASE = True
17
+
18
+ if _RELEASE:
19
+ parent_dir = os.path.dirname(os.path.abspath(__file__))
20
+ config = {
21
+ "path": os.path.join(parent_dir, "frontend/build"),
22
+ }
23
+ else:
24
+ config = {
25
+ "url": "http://localhost:3001",
26
+ }
27
+
28
+ _component_func = components.declare_component("contribution_graph", **config)
29
+
30
+
31
+ def is_node_valid(node: UiGraphNode, n_layers: int, n_tokens: int):
32
+ return node.layer < n_layers and node.token < n_tokens
33
+
34
+
35
+ def is_selection_valid(s: GraphSelection, n_layers: int, n_tokens: int):
36
+ if not s:
37
+ return True
38
+ if s.node:
39
+ if not is_node_valid(s.node, n_layers, n_tokens):
40
+ return False
41
+ if s.edge:
42
+ for node in [s.edge.source, s.edge.target]:
43
+ if not is_node_valid(node, n_layers, n_tokens):
44
+ return False
45
+ return True
46
+
47
+
48
+ def contribution_graph(
49
+ model_info: ModelInfo,
50
+ tokens: List[str],
51
+ graphs: List[nx.Graph],
52
+ key: str,
53
+ ) -> Optional[GraphSelection]:
54
+ """Create a new instance of contribution graph.
55
+
56
+ Returns selected graph node or None if nothing was selected.
57
+ """
58
+ assert len(tokens) == len(graphs)
59
+
60
+ result = _component_func(
61
+ component="graph",
62
+ model_info=model_info.__dict__,
63
+ tokens=tokens,
64
+ edges_per_token=[nx.node_link_data(g)["links"] for g in graphs],
65
+ default=None,
66
+ key=key,
67
+ )
68
+
69
+ selection = GraphSelection.from_json(result)
70
+
71
+ n_tokens = len(tokens)
72
+ n_layers = model_info.n_layers
73
+ # We need this extra protection because even though the component has to check for
74
+ # the validity of the selection, sometimes it allows invalid output. It's some
75
+ # unexpected effect that has something to do with React and how the output value is
76
+ # set for the component.
77
+ if not is_selection_valid(selection, n_layers, n_tokens):
78
+ selection = None
79
+
80
+ return selection
81
+
82
+
83
+ def selector(
84
+ items: List[str],
85
+ indices: List[int],
86
+ temperatures: Optional[List[float]],
87
+ preselected_index: Optional[int],
88
+ key: str,
89
+ ) -> Optional[int]:
90
+ """Create a new instance of selector.
91
+
92
+ Returns selected item index.
93
+ """
94
+ n = len(items)
95
+ assert n == len(indices)
96
+ items = [{"index": i, "text": s} for s, i in zip(items, indices)]
97
+
98
+ if temperatures is not None:
99
+ assert n == len(temperatures)
100
+ for i, t in enumerate(temperatures):
101
+ items[i]["temperature"] = t
102
+
103
+ result = _component_func(
104
+ component="selector",
105
+ items=items,
106
+ preselected_index=preselected_index,
107
+ default=None,
108
+ key=key,
109
+ )
110
+
111
+ return None if result is None else int(result)
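To show how `contribution_graph` fits together with the frontend, here is an illustrative sketch (this is not how `server/app.py` drives it). It assumes the frontend has been built, that node names follow the `X/A/I/M{layer}_{token}` pattern parsed by `ContributionGraph.tsx`, and it uses a `SimpleNamespace` stand-in for `ModelInfo`, whose full field list is defined in `models/transparent_llm.py`.

```python
# Illustrative sketch only; run it as a Streamlit app, e.g.
#   streamlit run demo_graph.py   (demo_graph.py is a hypothetical file name)
from types import SimpleNamespace

import networkx as nx
import streamlit as st

from llm_transparency_tool.components import contribution_graph

tokens = ["The", " cat"]

# One graph per starting token; edges carry a "weight" attribute and node
# names follow the X/A/I/M{layer}_{token} convention from ContributionGraph.tsx.
g = nx.DiGraph()
g.add_edge("X0_0", "A0_1", weight=0.7)
graphs = [g for _ in tokens]

model_info = SimpleNamespace(n_layers=1)  # stand-in for a real ModelInfo
selection = contribution_graph(model_info, tokens, graphs, key="graph")
st.write(selection)
```

In the actual app, these graphs are presumably built by the code in `llm_transparency_tool/routes/` from the computed contributions.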
llm_transparency_tool/components/frontend/.env ADDED
@@ -0,0 +1,6 @@
1
+ # Run the component's dev server on :3001
2
+ # (The Streamlit dev server already runs on :3000)
3
+ PORT=3001
4
+
5
+ # Don't automatically open the web browser on `npm run start`.
6
+ BROWSER=none
llm_transparency_tool/components/frontend/.prettierrc ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "endOfLine": "lf",
3
+ "semi": false,
4
+ "trailingComma": "es5"
5
+ }
llm_transparency_tool/components/frontend/package.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "name": "contribution_graph",
3
+ "version": "0.1.0",
4
+ "private": true,
5
+ "dependencies": {
6
+ "@types/d3": "^7.4.0",
7
+ "d3": "^7.8.5",
8
+ "react": "^18.2.0",
9
+ "react-dom": "^18.2.0",
10
+ "streamlit-component-lib": "^2.0.0"
11
+ },
12
+ "scripts": {
13
+ "start": "react-scripts start",
14
+ "build": "react-scripts build",
15
+ "test": "react-scripts test",
16
+ "eject": "react-scripts eject"
17
+ },
18
+ "browserslist": {
19
+ "production": [
20
+ ">0.2%",
21
+ "not dead",
22
+ "not op_mini all"
23
+ ],
24
+ "development": [
25
+ "last 1 chrome version",
26
+ "last 1 firefox version",
27
+ "last 1 safari version"
28
+ ]
29
+ },
30
+ "homepage": ".",
31
+ "devDependencies": {
32
+ "@types/node": "^20.11.17",
33
+ "@types/react": "^18.2.55",
34
+ "@types/react-dom": "^18.2.19",
35
+ "eslint-config-react-app": "^7.0.1",
36
+ "react-scripts": "^5.0.1",
37
+ "typescript": "^5.3.3"
38
+ }
39
+ }
llm_transparency_tool/components/frontend/public/index.html ADDED
@@ -0,0 +1,15 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <title>Contribution Graph for Streamlit</title>
5
+ <meta charset="UTF-8" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
7
+ <meta name="theme-color" content="#000000" />
8
+ <meta name="description" content="Contribution Graph for Streamlit" />
9
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.5.0/dist/css/bootstrap.min.css" />
10
+ </head>
11
+ <body>
12
+ <noscript>You need to enable JavaScript to run this app.</noscript>
13
+ <div id="root"></div>
14
+ </body>
15
+ </html>
llm_transparency_tool/components/frontend/src/ContributionGraph.tsx ADDED
@@ -0,0 +1,517 @@
1
+ /**
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ import {
10
+ ComponentProps,
11
+ Streamlit,
12
+ withStreamlitConnection,
13
+ } from 'streamlit-component-lib'
14
+ import React, { useEffect, useMemo, useRef, useState } from 'react';
15
+ import * as d3 from 'd3';
16
+
17
+ import {
18
+ Label,
19
+ Point,
20
+ } from './common';
21
+ import './LlmViewer.css';
22
+
23
+ export const renderParams = {
24
+ cellH: 32,
25
+ cellW: 32,
26
+ attnSize: 8,
27
+ afterFfnSize: 8,
28
+ ffnSize: 6,
29
+ tokenSelectorSize: 16,
30
+ layerCornerRadius: 6,
31
+ }
32
+
33
+ interface Cell {
34
+ layer: number
35
+ token: number
36
+ }
37
+
38
+ enum CellItem {
39
+ AfterAttn = 'after_attn',
40
+ AfterFfn = 'after_ffn',
41
+ Ffn = 'ffn',
42
+ Original = 'original', // They will only be at level = 0
43
+ }
44
+
45
+ interface Node {
46
+ cell: Cell | null
47
+ item: CellItem | null
48
+ }
49
+
50
+ interface NodeProps {
51
+ node: Node
52
+ pos: Point
53
+ isActive: boolean
54
+ }
55
+
56
+ interface EdgeRaw {
57
+ weight: number
58
+ source: string
59
+ target: string
60
+ }
61
+
62
+ interface Edge {
63
+ weight: number
64
+ from: Node
65
+ to: Node
66
+ fromPos: Point
67
+ toPos: Point
68
+ isSelectable: boolean
69
+ isFfn: boolean
70
+ }
71
+
72
+ interface Selection {
73
+ node: Node | null
74
+ edge: Edge | null
75
+ }
76
+
77
+ function tokenPointerPolygon(origin: Point) {
78
+ const r = renderParams.tokenSelectorSize / 2
79
+ const dy = r / 2
80
+ const dx = r * Math.sqrt(3.0) / 2
81
+ // Draw an arrow looking down
82
+ return [
83
+ [origin.x, origin.y + r],
84
+ [origin.x + dx, origin.y - dy],
85
+ [origin.x - dx, origin.y - dy],
86
+ ].toString()
87
+ }
88
+
89
+ function isSameCell(cell1: Cell | null, cell2: Cell | null) {
90
+ if (cell1 == null || cell2 == null) {
91
+ return false
92
+ }
93
+ return cell1.layer === cell2.layer && cell1.token === cell2.token
94
+ }
95
+
96
+ function isSameNode(node1: Node | null, node2: Node | null) {
97
+ if (node1 === null || node2 === null) {
98
+ return false
99
+ }
100
+ return isSameCell(node1.cell, node2.cell)
101
+ && node1.item === node2.item;
102
+ }
103
+
104
+ function isSameEdge(edge1: Edge | null, edge2: Edge | null) {
105
+ if (edge1 === null || edge2 === null) {
106
+ return false
107
+ }
108
+ return isSameNode(edge1.from, edge2.from) && isSameNode(edge1.to, edge2.to);
109
+ }
110
+
111
+ function nodeFromString(name: string) {
112
+ const match = name.match(/([AIMX])(\d+)_(\d+)/)
113
+ if (match == null) {
114
+ return {
115
+ cell: null,
116
+ item: null,
117
+ }
118
+ }
119
+ const [, type, layerStr, tokenStr] = match
120
+ const layer = +layerStr
121
+ const token = +tokenStr
122
+
123
+ const typeToCellItem = new Map<string, CellItem>([
124
+ ['A', CellItem.AfterAttn],
125
+ ['I', CellItem.AfterFfn],
126
+ ['M', CellItem.Ffn],
127
+ ['X', CellItem.Original],
128
+ ])
129
+ return {
130
+ cell: {
131
+ layer: layer,
132
+ token: token,
133
+ },
134
+ item: typeToCellItem.get(type) ?? null,
135
+ }
136
+ }
137
+
138
+ function isValidNode(node: Node, nLayers: number, nTokens: number) {
139
+ if (node.cell === null) {
140
+ return true
141
+ }
142
+ return node.cell.layer < nLayers && node.cell.token < nTokens
143
+ }
144
+
145
+ function isValidSelection(selection: Selection, nLayers: number, nTokens: number) {
146
+ if (selection.node !== null) {
147
+ return isValidNode(selection.node, nLayers, nTokens)
148
+ }
149
+ if (selection.edge !== null) {
150
+ return isValidNode(selection.edge.from, nLayers, nTokens) &&
151
+ isValidNode(selection.edge.to, nLayers, nTokens)
152
+ }
153
+ return true
154
+ }
155
+
156
+ const ContributionGraph = ({ args }: ComponentProps) => {
157
+ const modelInfo = args['model_info']
158
+ const tokens = args['tokens']
159
+ const edgesRaw: EdgeRaw[][] = args['edges_per_token']
160
+
161
+ const nLayers = modelInfo === null ? 0 : modelInfo.n_layers
162
+ const nTokens = tokens === null ? 0 : tokens.length
163
+
164
+ const [selection, setSelection] = useState<Selection>({
165
+ node: null,
166
+ edge: null,
167
+ })
168
+ var curSelection = selection
169
+ if (!isValidSelection(selection, nLayers, nTokens)) {
170
+ curSelection = {
171
+ node: null,
172
+ edge: null,
173
+ }
174
+ setSelection(curSelection)
175
+ Streamlit.setComponentValue(curSelection)
176
+ }
177
+
178
+ const [startToken, setStartToken] = useState<number>(nTokens - 1)
179
+ // We have startToken state var, but it won't be updated till next render, so use
180
+ // this var in the current render.
181
+ var curStartToken = startToken
182
+ if (startToken >= nTokens) {
183
+ curStartToken = nTokens - 1
184
+ setStartToken(curStartToken)
185
+ }
186
+
187
+ const handleRepresentationClick = (node: Node) => {
188
+ const newSelection: Selection = {
189
+ node: node,
190
+ edge: null,
191
+ }
192
+ setSelection(newSelection)
193
+ Streamlit.setComponentValue(newSelection)
194
+ }
195
+
196
+ const handleEdgeClick = (edge: Edge) => {
197
+ if (!edge.isSelectable) {
198
+ return
199
+ }
200
+ const newSelection: Selection = {
201
+ node: edge.to,
202
+ edge: edge,
203
+ }
204
+ setSelection(newSelection)
205
+ Streamlit.setComponentValue(newSelection)
206
+ }
207
+
208
+ const handleTokenClick = (t: number) => {
209
+ setStartToken(t)
210
+ }
211
+
212
+ const [xScale, yScale] = useMemo(() => {
213
+ const x = d3.scaleLinear()
214
+ .domain([-2, nTokens - 1])
215
+ .range([0, renderParams.cellW * (nTokens + 2)])
216
+ const y = d3.scaleLinear()
217
+ .domain([-1, nLayers])
218
+ .range([renderParams.cellH * (nLayers + 2), 0])
219
+ return [x, y]
220
+ }, [nLayers, nTokens])
221
+
222
+ const cells = useMemo(() => {
223
+ let result: Cell[] = []
224
+ for (let l = 0; l < nLayers; l++) {
225
+ for (let t = 0; t < nTokens; t++) {
226
+ result.push({
227
+ layer: l,
228
+ token: t,
229
+ })
230
+ }
231
+ }
232
+ return result
233
+ }, [nLayers, nTokens])
234
+
235
+ const nodeCoords = useMemo(() => {
236
+ let result = new Map<string, Point>()
237
+ const w = renderParams.cellW
238
+ const h = renderParams.cellH
239
+ for (var cell of cells) {
240
+ const cx = xScale(cell.token + 0.5)
241
+ const cy = yScale(cell.layer - 0.5)
242
+ result.set(
243
+ JSON.stringify({ cell: cell, item: CellItem.AfterAttn }),
244
+ { x: cx, y: cy + h / 4 },
245
+ )
246
+ result.set(
247
+ JSON.stringify({ cell: cell, item: CellItem.AfterFfn }),
248
+ { x: cx, y: cy - h / 4 },
249
+ )
250
+ result.set(
251
+ JSON.stringify({ cell: cell, item: CellItem.Ffn }),
252
+ { x: cx + 5 * w / 16, y: cy },
253
+ )
254
+ }
255
+ for (let t = 0; t < nTokens; t++) {
256
+ cell = {
257
+ layer: 0,
258
+ token: t,
259
+ }
260
+ const cx = xScale(cell.token + 0.5)
261
+ const cy = yScale(cell.layer - 1.0)
262
+ result.set(
263
+ JSON.stringify({ cell: cell, item: CellItem.Original }),
264
+ { x: cx, y: cy + h / 4 },
265
+ )
266
+ }
267
+ return result
268
+ }, [cells, nTokens, xScale, yScale])
269
+
270
+ const edges: Edge[][] = useMemo(() => {
271
+ let result = []
272
+ for (var edgeList of edgesRaw) {
273
+ let edgesPerStartToken = []
274
+ for (var edge of edgeList) {
275
+ const u = nodeFromString(edge.source)
276
+ const v = nodeFromString(edge.target)
277
+ var isSelectable = (
278
+ u.cell !== null && v.cell !== null && v.item === CellItem.AfterAttn
279
+ )
280
+ var isFfn = (
281
+ u.cell !== null && v.cell !== null && (
282
+ u.item === CellItem.Ffn || v.item === CellItem.Ffn
283
+ )
284
+ )
285
+ edgesPerStartToken.push({
286
+ weight: edge.weight,
287
+ from: u,
288
+ to: v,
289
+ fromPos: nodeCoords.get(JSON.stringify(u)) ?? { 'x': 0, 'y': 0 },
290
+ toPos: nodeCoords.get(JSON.stringify(v)) ?? { 'x': 0, 'y': 0 },
291
+ isSelectable: isSelectable,
292
+ isFfn: isFfn,
293
+ })
294
+ }
295
+ result.push(edgesPerStartToken)
296
+ }
297
+ return result
298
+ }, [edgesRaw, nodeCoords])
299
+
300
+ const activeNodes = useMemo(() => {
301
+ let result = new Set<string>()
302
+ for (var edge of edges[curStartToken]) {
303
+ const u = JSON.stringify(edge.from)
304
+ const v = JSON.stringify(edge.to)
305
+ result.add(u)
306
+ result.add(v)
307
+ }
308
+ return result
309
+ }, [edges, curStartToken])
310
+
311
+ const nodeProps = useMemo(() => {
312
+ let result: Array<NodeProps> = []
313
+ nodeCoords.forEach((p: Point, node: string) => {
314
+ result.push({
315
+ node: JSON.parse(node),
316
+ pos: p,
317
+ isActive: activeNodes.has(node),
318
+ })
319
+ })
320
+ return result
321
+ }, [nodeCoords, activeNodes])
322
+
323
+ const tokenLabels: Label[] = useMemo(() => {
324
+ if (!tokens) {
325
+ return []
326
+ }
327
+ return tokens.map((s: string, i: number) => ({
328
+ text: s.replace(/ /g, '·'),
329
+ pos: {
330
+ x: xScale(i + 0.5),
331
+ y: yScale(-1.5),
332
+ },
333
+ }))
334
+ }, [tokens, xScale, yScale])
335
+
336
+ const layerLabels: Label[] = useMemo(() => {
337
+ return Array.from(Array(nLayers).keys()).map(i => ({
338
+ text: 'L' + i,
339
+ pos: {
340
+ x: xScale(-0.25),
341
+ y: yScale(i - 0.5),
342
+ },
343
+ }))
344
+ }, [nLayers, xScale, yScale])
345
+
346
+ const tokenSelectors: Array<[number, Point]> = useMemo(() => {
347
+ return Array.from(Array(nTokens).keys()).map(i => ([
348
+ i,
349
+ {
350
+ x: xScale(i + 0.5),
351
+ y: yScale(nLayers - 0.5),
352
+ }
353
+ ]))
354
+ }, [nTokens, nLayers, xScale, yScale])
355
+
356
+ const totalW = xScale(nTokens + 2)
357
+ const totalH = yScale(-4)
358
+ useEffect(() => {
359
+ Streamlit.setFrameHeight(totalH)
360
+ }, [totalH])
361
+
362
+ const colorScale = d3.scaleLinear(
363
+ [0.0, 0.5, 1.0],
364
+ ['#9eba66', 'darkolivegreen', 'darkolivegreen']
365
+ )
366
+ const ffnEdgeColorScale = d3.scaleLinear(
367
+ [0.0, 0.5, 1.0],
368
+ ['orchid', 'purple', 'purple']
369
+ )
370
+ const edgeWidthScale = d3.scaleLinear([0.0, 0.5, 1.0], [2.0, 3.0, 3.0])
371
+
372
+ const svgRef = useRef(null);
373
+
374
+ useEffect(() => {
375
+ const getNodeStyle = (p: NodeProps, type: string) => {
376
+ if (isSameNode(p.node, curSelection.node)) {
377
+ return 'selectable-item selection'
378
+ }
379
+ if (p.isActive) {
380
+ return 'selectable-item active-' + type + '-node'
381
+ }
382
+ return 'selectable-item inactive-node'
383
+ }
384
+
385
+ const svg = d3.select(svgRef.current)
386
+ svg.selectAll('*').remove()
387
+
388
+ svg
389
+ .selectAll('layers')
390
+ .data(Array.from(Array(nLayers).keys()).filter((x) => x % 2 === 1))
391
+ .enter()
392
+ .append('rect')
393
+ .attr('class', 'layer-highlight')
394
+ .attr('x', xScale(-1.0))
395
+ .attr('y', (layer) => yScale(layer))
396
+ .attr('width', xScale(nTokens + 0.25) - xScale(-1.0))
397
+ .attr('height', (layer) => yScale(layer) - yScale(layer + 1))
398
+ .attr('rx', renderParams.layerCornerRadius)
399
+
400
+ svg
401
+ .selectAll('edges')
402
+ .data(edges[curStartToken])
403
+ .enter()
404
+ .append('line')
405
+ .style('stroke', (edge: Edge) => {
406
+ if (isSameEdge(edge, curSelection.edge)) {
407
+ return 'orange'
408
+ }
409
+ if (edge.isFfn) {
410
+ return ffnEdgeColorScale(edge.weight)
411
+ }
412
+ return colorScale(edge.weight)
413
+ })
414
+ .attr('class', (edge: Edge) => edge.isSelectable ? 'selectable-edge' : '')
415
+ .style('stroke-width', (edge: Edge) => edgeWidthScale(edge.weight))
416
+ .attr('x1', (edge: Edge) => edge.fromPos.x)
417
+ .attr('y1', (edge: Edge) => edge.fromPos.y)
418
+ .attr('x2', (edge: Edge) => edge.toPos.x)
419
+ .attr('y2', (edge: Edge) => edge.toPos.y)
420
+ .on('click', (event: PointerEvent, edge) => {
421
+ handleEdgeClick(edge)
422
+ })
423
+
424
+ svg
425
+ .selectAll('residual')
426
+ .data(nodeProps)
427
+ .enter()
428
+ .filter((p) => {
429
+ return p.node.item === CellItem.AfterAttn
430
+ || p.node.item === CellItem.AfterFfn
431
+ })
432
+ .append('circle')
433
+ .attr('class', (p) => getNodeStyle(p, 'residual'))
434
+ .attr('cx', (p) => p.pos.x)
435
+ .attr('cy', (p) => p.pos.y)
436
+ .attr('r', renderParams.attnSize / 2)
437
+ .on('click', (event: PointerEvent, p) => {
438
+ handleRepresentationClick(p.node)
439
+ })
440
+
441
+ svg
442
+ .selectAll('ffn')
443
+ .data(nodeProps)
444
+ .enter()
445
+ .filter((p) => p.node.item === CellItem.Ffn && p.isActive)
446
+ .append('rect')
447
+ .attr('class', (p) => getNodeStyle(p, 'ffn'))
448
+ .attr('x', (p) => p.pos.x - renderParams.ffnSize / 2)
449
+ .attr('y', (p) => p.pos.y - renderParams.ffnSize / 2)
450
+ .attr('width', renderParams.ffnSize)
451
+ .attr('height', renderParams.ffnSize)
452
+ .on('click', (event: PointerEvent, p) => {
453
+ handleRepresentationClick(p.node)
454
+ })
455
+
456
+ svg
457
+ .selectAll('token_labels')
458
+ .data(tokenLabels)
459
+ .enter()
460
+ .append('text')
461
+ .attr('x', (label: Label) => label.pos.x)
462
+ .attr('y', (label: Label) => label.pos.y)
463
+ .attr('text-anchor', 'end')
464
+ .attr('dominant-baseline', 'middle')
465
+ .attr('alignment-baseline', 'top')
466
+ .attr('transform', (label: Label) =>
467
+ 'rotate(-40, ' + label.pos.x + ', ' + label.pos.y + ')')
468
+ .text((label: Label) => label.text)
469
+
470
+ svg
471
+ .selectAll('layer_labels')
472
+ .data(layerLabels)
473
+ .enter()
474
+ .append('text')
475
+ .attr('x', (label: Label) => label.pos.x)
476
+ .attr('y', (label: Label) => label.pos.y)
477
+ .attr('text-anchor', 'middle')
478
+ .attr('alignment-baseline', 'middle')
479
+ .text((label: Label) => label.text)
480
+
481
+ svg
482
+ .selectAll('token_selectors')
483
+ .data(tokenSelectors)
484
+ .enter()
485
+ .append('polygon')
486
+ .attr('class', ([i,]) => (
487
+ curStartToken === i
488
+ ? 'selectable-item selection'
489
+ : 'selectable-item token-selector'
490
+ ))
491
+ .attr('points', ([, p]) => tokenPointerPolygon(p))
492
+ .attr('r', renderParams.tokenSelectorSize / 2)
493
+ .on('click', (event: PointerEvent, [i,]) => {
494
+ handleTokenClick(i)
495
+ })
496
+ }, [
497
+ cells,
498
+ edges,
499
+ nodeProps,
500
+ tokenLabels,
501
+ layerLabels,
502
+ tokenSelectors,
503
+ curStartToken,
504
+ curSelection,
505
+ colorScale,
506
+ ffnEdgeColorScale,
507
+ edgeWidthScale,
508
+ nLayers,
509
+ nTokens,
510
+ xScale,
511
+ yScale
512
+ ])
513
+
514
+ return <svg ref={svgRef} width={totalW} height={totalH}></svg>
515
+ }
516
+
517
+ export default withStreamlitConnection(ContributionGraph)
llm_transparency_tool/components/frontend/src/LlmViewer.css ADDED
@@ -0,0 +1,77 @@
1
+ /**
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ .graph-container {
10
+ display: flex;
11
+ justify-content: center;
12
+ align-items: center;
13
+ height: 100vh;
14
+ }
15
+
16
+ .svg {
17
+ border: 1px solid #ccc;
18
+ }
19
+
20
+ .layer-highlight {
21
+ fill: #f0f5f0;
22
+ }
23
+
24
+ .selectable-item {
25
+ stroke: black;
26
+ cursor: pointer;
27
+ }
28
+
29
+ .selection,
30
+ .selection:hover {
31
+ fill: orange;
32
+ }
33
+
34
+ .active-residual-node {
35
+ fill: yellowgreen;
36
+ }
37
+
38
+ .active-residual-node:hover {
39
+ fill: olivedrab;
40
+ }
41
+
42
+ .active-ffn-node {
43
+ fill: orchid;
44
+ }
45
+
46
+ .active-ffn-node:hover {
47
+ fill: purple;
48
+ }
49
+
50
+ .inactive-node {
51
+ fill: lightgray;
52
+ stroke-width: 0.5px;
53
+ }
54
+
55
+ .inactive-node:hover {
56
+ fill: gray;
57
+ }
58
+
59
+ .selectable-edge {
60
+ cursor: pointer;
61
+ }
62
+
63
+ .token-selector {
64
+ fill: lightblue;
65
+ }
66
+
67
+ .token-selector:hover {
68
+ fill: cornflowerblue;
69
+ }
70
+
71
+ .selector-item {
72
+ fill: lightblue;
73
+ }
74
+
75
+ .selector-item:hover {
76
+ fill: cornflowerblue;
77
+ }
llm_transparency_tool/components/frontend/src/Selector.tsx ADDED
@@ -0,0 +1,154 @@
1
+ /**
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ import {
10
+ ComponentProps,
11
+ Streamlit,
12
+ withStreamlitConnection,
13
+ } from "streamlit-component-lib"
14
+ import React, { useEffect, useMemo, useRef, useState } from 'react';
15
+ import * as d3 from 'd3';
16
+
17
+ import {
18
+ Point,
19
+ } from './common';
20
+ import './LlmViewer.css';
21
+
22
+ export const renderParams = {
23
+ verticalGap: 24,
24
+ horizontalGap: 24,
25
+ itemSize: 8,
26
+ }
27
+
28
+ interface Item {
29
+ index: number
30
+ text: string
31
+ temperature: number
32
+ }
33
+
34
+ const Selector = ({ args }: ComponentProps) => {
35
+ const items: Item[] = args["items"]
36
+ const preselected_index: number | null = args["preselected_index"]
37
+ const n = items.length
38
+
39
+ const [selection, setSelection] = useState<number | null>(null)
40
+
41
+ // Ensure the preselected element takes effect only when the data is new.

42
+ var args_json = JSON.stringify(args)
43
+ useEffect(() => {
44
+ setSelection(preselected_index)
45
+ Streamlit.setComponentValue(preselected_index)
46
+ }, [args_json, preselected_index]);
47
+
48
+ const handleItemClick = (index: number) => {
49
+ setSelection(index)
50
+ Streamlit.setComponentValue(index)
51
+ }
52
+
53
+ const [xScale, yScale] = useMemo(() => {
54
+ const x = d3.scaleLinear()
55
+ .domain([0, 1])
56
+ .range([0, renderParams.horizontalGap])
57
+ const y = d3.scaleLinear()
58
+ .domain([0, n - 1])
59
+ .range([0, renderParams.verticalGap * (n - 1)])
60
+ return [x, y]
61
+ }, [n])
62
+
63
+ const itemCoords: Point[] = useMemo(() => {
64
+ return Array.from(Array(n).keys()).map(i => ({
65
+ x: xScale(0.5),
66
+ y: yScale(i + 0.5),
67
+ }))
68
+ }, [n, xScale, yScale])
69
+
70
+ var hasTemperature = false
71
+ if (n > 0) {
72
+ var t = items[0].temperature
73
+ hasTemperature = (t !== null && t !== undefined)
74
+ }
75
+ const colorScale = useMemo(() => {
76
+ var min_t = 0.0
77
+ var max_t = 1.0
78
+ if (hasTemperature) {
79
+ min_t = items[0].temperature
80
+ max_t = items[0].temperature
81
+ for (var i = 0; i < n; i++) {
82
+ const t = items[i].temperature
83
+ min_t = Math.min(min_t, t)
84
+ max_t = Math.max(max_t, t)
85
+ }
86
+ }
87
+ const norm = d3.scaleLinear([min_t, max_t], [0.0, 1.0])
88
+ const colorScale = d3.scaleSequential(d3.interpolateYlGn);
89
+ return d3.scaleSequential(value => colorScale(norm(value)))
90
+ }, [items, hasTemperature, n])
91
+
92
+ const totalW = 100
93
+ const totalH = yScale(n)
94
+ useEffect(() => {
95
+ Streamlit.setFrameHeight(totalH)
96
+ }, [totalH])
97
+
98
+ const svgRef = useRef(null);
99
+
100
+ useEffect(() => {
101
+ const svg = d3.select(svgRef.current)
102
+ svg.selectAll('*').remove()
103
+
104
+ const getItemClass = (index: number) => {
105
+ var style = 'selectable-item '
106
+ style += index === selection ? 'selection' : 'selector-item'
107
+ return style
108
+ }
109
+
110
+ const getItemColor = (item: Item) => {
111
+ var t = item.temperature ?? 0.0
112
+ return item.index === selection ? 'orange' : colorScale(t)
113
+ }
114
+
115
+ var icons = svg
116
+ .selectAll('items')
117
+ .data(Array.from(Array(n).keys()))
118
+ .enter()
119
+ .append('circle')
120
+ .attr('cx', (i) => itemCoords[i].x)
121
+ .attr('cy', (i) => itemCoords[i].y)
122
+ .attr('r', renderParams.itemSize / 2)
123
+ .on('click', (event: PointerEvent, i) => {
124
+ handleItemClick(items[i].index)
125
+ })
126
+ .attr('class', (i) => getItemClass(items[i].index))
127
+ if (hasTemperature) {
128
+ icons.style('fill', (i) => getItemColor(items[i]))
129
+ }
130
+
131
+ svg
132
+ .selectAll('labels')
133
+ .data(Array.from(Array(n).keys()))
134
+ .enter()
135
+ .append('text')
136
+ .attr('x', (i) => itemCoords[i].x + renderParams.horizontalGap / 2)
137
+ .attr('y', (i) => itemCoords[i].y)
138
+ .attr('text-anchor', 'left')
139
+ .attr('alignment-baseline', 'middle')
140
+ .text((i) => items[i].text)
141
+
142
+ }, [
143
+ items,
144
+ n,
145
+ itemCoords,
146
+ selection,
147
+ colorScale,
148
+ hasTemperature,
149
+ ])
150
+
151
+ return <svg ref={svgRef} width={totalW} height={totalH}></svg>
152
+ }
153
+
154
+ export default withStreamlitConnection(Selector)
llm_transparency_tool/components/frontend/src/common.tsx ADDED
@@ -0,0 +1,17 @@
1
+ /**
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ export interface Point {
10
+ x: number
11
+ y: number
12
+ }
13
+
14
+ export interface Label {
15
+ text: string
16
+ pos: Point
17
+ }
llm_transparency_tool/components/frontend/src/index.tsx ADDED
@@ -0,0 +1,39 @@
1
+ /**
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ import React from "react"
10
+ import ReactDOM from "react-dom"
11
+
12
+ import {
13
+ ComponentProps,
14
+ withStreamlitConnection,
15
+ } from "streamlit-component-lib"
16
+
17
+
18
+ import ContributionGraph from "./ContributionGraph"
19
+ import Selector from "./Selector"
20
+
21
+ const LlmViewerComponent = (props: ComponentProps) => {
22
+ switch (props.args['component']) {
23
+ case 'graph':
24
+ return <ContributionGraph />
25
+ case 'selector':
26
+ return <Selector />
27
+ default:
28
+ return <></>
29
+ }
30
+ };
31
+
32
+ const StreamlitLlmViewerComponent = withStreamlitConnection(LlmViewerComponent)
33
+
34
+ ReactDOM.render(
35
+ <React.StrictMode>
36
+ <StreamlitLlmViewerComponent />
37
+ </React.StrictMode>,
38
+ document.getElementById("root")
39
+ )
llm_transparency_tool/components/frontend/src/react-app-env.d.ts ADDED
@@ -0,0 +1 @@
1
+ /// <reference types="react-scripts" />
llm_transparency_tool/components/frontend/tsconfig.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "compilerOptions": {
3
+ "target": "es5",
4
+ "lib": ["dom", "dom.iterable", "esnext"],
5
+ "allowJs": true,
6
+ "skipLibCheck": true,
7
+ "esModuleInterop": true,
8
+ "allowSyntheticDefaultImports": true,
9
+ "strict": true,
10
+ "forceConsistentCasingInFileNames": true,
11
+ "module": "esnext",
12
+ "moduleResolution": "node",
13
+ "resolveJsonModule": true,
14
+ "isolatedModules": true,
15
+ "noEmit": true,
16
+ "jsx": "react"
17
+ },
18
+ "include": ["src"]
19
+ }
llm_transparency_tool/models/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
llm_transparency_tool/models/test_tlens_model.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import unittest
8
+
9
+ import torch
10
+
11
+ from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
12
+ from llm_transparency_tool.models.transparent_llm import ModelInfo
13
+
14
+
15
+ class TransparentLlmTestCase(unittest.TestCase):
16
+ @classmethod
17
+ def setUpClass(cls):
18
+ # Picking the smallest model possible so that the test runs faster. It's ok to
19
+ # change this model, but you'll need to update tokenization specifics in some
20
+ # tests.
21
+ cls._llm = TransformerLensTransparentLlm(
22
+ model_name="facebook/opt-125m",
23
+ device="cpu",
24
+ )
25
+
26
+ def setUp(self):
27
+ self._llm.run(["test", "test 1"])
28
+ self._eps = 1e-5
29
+
30
+ def test_model_info(self):
31
+ info = self._llm.model_info()
32
+ self.assertEqual(
33
+ info,
34
+ ModelInfo(
35
+ name="facebook/opt-125m",
36
+ n_params_estimate=84934656,
37
+ n_layers=12,
38
+ n_heads=12,
39
+ d_model=768,
40
+ d_vocab=50272,
41
+ ),
42
+ )
43
+
44
+ def test_tokens(self):
45
+ tokens = self._llm.tokens()
46
+
47
+ pad = 1
48
+ bos = 2
49
+ test = 21959
50
+ one = 112
51
+
52
+ self.assertEqual(tokens.tolist(), [[bos, test, pad], [bos, test, one]])
53
+
54
+ def test_tokens_to_strings(self):
55
+ s = self._llm.tokens_to_strings(torch.Tensor([2, 21959, 112]).to(torch.int))
56
+ self.assertEqual(s, ["</s>", "test", " 1"])
57
+
58
+ def test_manage_state(self):
59
+ # One llm.run call was made during setup. Call it once more and make sure the object
60
+ # returns values for the new state.
61
+ self._llm.run(["one", "two", "three", "four"])
62
+ self.assertEqual(self._llm.tokens().shape[0], 4)
63
+
64
+ def test_residual_in_and_out(self):
65
+ """
66
+ Test that residual_in matches residual_out of the previous layer.
67
+ """
68
+ for layer in range(1, 12):
69
+ prev_residual_out = self._llm.residual_out(layer - 1)
70
+ residual_in = self._llm.residual_in(layer)
71
+ diff = torch.max(torch.abs(residual_in - prev_residual_out)).item()
72
+ self.assertLess(diff, self._eps, f"layer {layer}")
73
+
74
+ def test_residual_plus_block(self):
75
+ """
76
+ Make sure that new residual = old residual + block output. Here, block is an ffn
77
+ or attention. It's not that obvious because it could be that layer norm is
78
+ applied after the block output, but before saving the result to residual.
79
+ Luckily, this is not the case in TransformerLens, and we're relying on that.
80
+ """
81
+ layer = 3
82
+ batch = 0
83
+ pos = 0
84
+
85
+ residual_in = self._llm.residual_in(layer)[batch][pos]
86
+ residual_mid = self._llm.residual_after_attn(layer)[batch][pos]
87
+ residual_out = self._llm.residual_out(layer)[batch][pos]
88
+ ffn_out = self._llm.ffn_out(layer)[batch][pos]
89
+ attn_out = self._llm.attention_output(batch, layer, pos)
90
+
91
+ a = residual_mid
92
+ b = residual_in + attn_out
93
+ diff = torch.max(torch.abs(a - b)).item()
94
+ self.assertLess(diff, self._eps, "attn")
95
+
96
+ a = residual_out
97
+ b = residual_mid + ffn_out
98
+ diff = torch.max(torch.abs(a - b)).item()
99
+ self.assertLess(diff, self._eps, "ffn")
100
+
101
+ def test_tensor_shapes(self):
102
+ # Not much we can do about the tensors, but at least check their shapes and
103
+ # that they don't contain NaNs.
104
+ vocab_size = 50272
105
+ n_batch = 2
106
+ n_tokens = 3
107
+ d_model = 768
108
+ d_hidden = d_model * 4
109
+ n_heads = 12
110
+ layer = 5
111
+
112
+ device = self._llm.residual_in(0).device
113
+
114
+ for name, tensor, expected_shape in [
115
+ ("r_in", self._llm.residual_in(layer), [n_batch, n_tokens, d_model]),
116
+ (
117
+ "r_mid",
118
+ self._llm.residual_after_attn(layer),
119
+ [n_batch, n_tokens, d_model],
120
+ ),
121
+ ("r_out", self._llm.residual_out(layer), [n_batch, n_tokens, d_model]),
122
+ ("logits", self._llm.logits(), [n_batch, n_tokens, vocab_size]),
123
+ ("ffn_out", self._llm.ffn_out(layer), [n_batch, n_tokens, d_model]),
124
+ (
125
+ "decomposed_ffn_out",
126
+ self._llm.decomposed_ffn_out(0, 0, 0),
127
+ [d_hidden, d_model],
128
+ ),
129
+ ("neuron_activations", self._llm.neuron_activations(0, 0, 0), [d_hidden]),
130
+ ("neuron_output", self._llm.neuron_output(0, 0), [d_model]),
131
+ (
132
+ "attention_matrix",
133
+ self._llm.attention_matrix(0, 0, 0),
134
+ [n_tokens, n_tokens],
135
+ ),
136
+ (
137
+ "attention_output_per_head",
138
+ self._llm.attention_output_per_head(0, 0, 0, 0),
139
+ [d_model],
140
+ ),
141
+ (
142
+ "attention_output",
143
+ self._llm.attention_output(0, 0, 0),
144
+ [d_model],
145
+ ),
146
+ (
147
+ "decomposed_attn",
148
+ self._llm.decomposed_attn(0, layer),
149
+ [n_tokens, n_tokens, n_heads, d_model],
150
+ ),
151
+ (
152
+ "unembed",
153
+ self._llm.unembed(torch.zeros([d_model]).to(device), normalize=True),
154
+ [vocab_size],
155
+ ),
156
+ ]:
157
+ self.assertEqual(list(tensor.shape), expected_shape, name)
158
+ self.assertFalse(torch.any(tensor.isnan()), name)
159
+
160
+
161
+ if __name__ == "__main__":
162
+ unittest.main()
llm_transparency_tool/models/tlens_model.py ADDED
@@ -0,0 +1,303 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import List, Optional
9
+
10
+ import torch
11
+ import transformer_lens
12
+ import transformers
13
+ from fancy_einsum import einsum
14
+ from jaxtyping import Float, Int
15
+ from typeguard import typechecked
16
+ import streamlit as st
17
+
18
+ from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm
19
+
20
+
21
+ @dataclass
22
+ class _RunInfo:
23
+ tokens: Int[torch.Tensor, "batch pos"]
24
+ logits: Float[torch.Tensor, "batch pos d_vocab"]
25
+ cache: transformer_lens.ActivationCache
26
+
27
+
28
+ @st.cache_resource(
29
+ max_entries=1,
30
+ show_spinner=True,
31
+ hash_funcs={
32
+ transformers.PreTrainedModel: id,
33
+ transformers.PreTrainedTokenizer: id
34
+ }
35
+ )
36
+ def load_hooked_transformer(
37
+ model_name: str,
38
+ hf_model: Optional[transformers.PreTrainedModel] = None,
39
+ tlens_device: str = "cuda",
40
+ dtype: torch.dtype = torch.float32,
41
+ ):
42
+ # if tlens_device == "cuda":
43
+ # n_devices = torch.cuda.device_count()
44
+ # else:
45
+ # n_devices = 1
46
+ tlens_model = transformer_lens.HookedTransformer.from_pretrained(
47
+ model_name,
48
+ hf_model=hf_model,
49
+ fold_ln=False, # Keep layer norm where it is.
50
+ center_writing_weights=False,
51
+ center_unembed=False,
52
+ device=tlens_device,
53
+ # n_devices=n_devices,
54
+ dtype=dtype,
55
+ )
56
+ tlens_model.eval()
57
+ return tlens_model
58
+
59
+
60
+ # TODO(igortufanov): If we want to scale the app to multiple users, we need a more careful
61
+ # thread-safe implementation. The simplest option could be to wrap the existing methods
62
+ # in mutexes.
63
+ class TransformerLensTransparentLlm(TransparentLlm):
64
+ """
65
+ Implementation of Transparent LLM based on transformer lens.
66
+
67
+ Args:
68
+ - model_name: The official name of the model from HuggingFace. Even if the model was
69
+ patched or loaded locally, the name should still be official because that's how
70
+ transformer_lens treats the model.
71
+ - hf_model: The language model as a HuggingFace class.
72
+ - tokenizer: optional HuggingFace tokenizer; if given, it overrides the model's default one.
73
+ - device: "gpu" or "cpu"
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ model_name: str,
79
+ hf_model: Optional[transformers.PreTrainedModel] = None,
80
+ tokenizer: Optional[transformers.PreTrainedTokenizer] = None,
81
+ device: str = "gpu",
82
+ dtype: torch.dtype = torch.float32,
83
+ ):
84
+ if device == "gpu":
85
+ self.device = "cuda"
86
+ if not torch.cuda.is_available():
87
+ raise RuntimeError("Asked to run on gpu, but torch couldn't find cuda")
88
+ elif device == "cpu":
89
+ self.device = "cpu"
90
+ else:
91
+ raise RuntimeError(f"Specified device {device} is not a valid option")
92
+
93
+ self.dtype = dtype
94
+ self.hf_tokenizer = tokenizer
95
+ self.hf_model = hf_model
96
+
97
+ # self._model = tlens_model
98
+ self._model_name = model_name
99
+ self._prepend_bos = True
100
+ self._last_run = None
101
+ self._run_exception = RuntimeError(
102
+ "Tried to use the model output before calling the `run` method"
103
+ )
104
+
105
+ def copy(self):
106
+ import copy
107
+ return copy.copy(self)
108
+
109
+ @property
110
+ def _model(self):
111
+ tlens_model = load_hooked_transformer(
112
+ self._model_name,
113
+ hf_model=self.hf_model,
114
+ tlens_device=self.device,
115
+ dtype=self.dtype,
116
+ )
117
+
118
+ if self.hf_tokenizer is not None:
119
+ tlens_model.set_tokenizer(self.hf_tokenizer, default_padding_side="left")
120
+
121
+ tlens_model.set_use_attn_result(True)
122
+ tlens_model.set_use_attn_in(False)
123
+ tlens_model.set_use_split_qkv_input(False)
124
+
125
+ return tlens_model
126
+
127
+ def model_info(self) -> ModelInfo:
128
+ cfg = self._model.cfg
129
+ return ModelInfo(
130
+ name=self._model_name,
131
+ n_params_estimate=cfg.n_params,
132
+ n_layers=cfg.n_layers,
133
+ n_heads=cfg.n_heads,
134
+ d_model=cfg.d_model,
135
+ d_vocab=cfg.d_vocab,
136
+ )
137
+
138
+ @torch.no_grad()
139
+ def run(self, sentences: List[str]) -> None:
140
+ tokens = self._model.to_tokens(sentences, prepend_bos=self._prepend_bos)
141
+ logits, cache = self._model.run_with_cache(tokens)
142
+
143
+ self._last_run = _RunInfo(
144
+ tokens=tokens,
145
+ logits=logits,
146
+ cache=cache,
147
+ )
148
+
149
+ def batch_size(self) -> int:
150
+ if not self._last_run:
151
+ raise self._run_exception
152
+ return self._last_run.logits.shape[0]
153
+
154
+ @typechecked
155
+ def tokens(self) -> Int[torch.Tensor, "batch pos"]:
156
+ if not self._last_run:
157
+ raise self._run_exception
158
+ return self._last_run.tokens
159
+
160
+ @typechecked
161
+ def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
162
+ return self._model.to_str_tokens(tokens)
163
+
164
+ @typechecked
165
+ def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
166
+ if not self._last_run:
167
+ raise self._run_exception
168
+ return self._last_run.logits
169
+
170
+ @torch.no_grad()
171
+ @typechecked
172
+ def unembed(
173
+ self,
174
+ t: Float[torch.Tensor, "d_model"],
175
+ normalize: bool,
176
+ ) -> Float[torch.Tensor, "vocab"]:
177
+ # t: [d_model] -> [batch, pos, d_model]
178
+ tdim = t.unsqueeze(0).unsqueeze(0)
179
+ if normalize:
180
+ normalized = self._model.ln_final(tdim)
181
+ result = self._model.unembed(normalized)
182
+ else:
183
+ result = self._model.unembed(tdim)
184
+ return result[0][0]
185
+
186
+ def _get_block(self, layer: int, block_name: str) -> str:
187
+ if not self._last_run:
188
+ raise self._run_exception
189
+ return self._last_run.cache[f"blocks.{layer}.{block_name}"]
190
+
191
+ # ================= Methods related to the residual stream =================
192
+
193
+ @typechecked
194
+ def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
195
+ if not self._last_run:
196
+ raise self._run_exception
197
+ return self._get_block(layer, "hook_resid_pre")
198
+
199
+ @typechecked
200
+ def residual_after_attn(
201
+ self, layer: int
202
+ ) -> Float[torch.Tensor, "batch pos d_model"]:
203
+ if not self._last_run:
204
+ raise self._run_exception
205
+ return self._get_block(layer, "hook_resid_mid")
206
+
207
+ @typechecked
208
+ def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
209
+ if not self._last_run:
210
+ raise self._run_exception
211
+ return self._get_block(layer, "hook_resid_post")
212
+
213
+ # ================ Methods related to the feed-forward layer ===============
214
+
215
+ @typechecked
216
+ def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
217
+ if not self._last_run:
218
+ raise self._run_exception
219
+ return self._get_block(layer, "hook_mlp_out")
220
+
221
+ @torch.no_grad()
222
+ @typechecked
223
+ def decomposed_ffn_out(
224
+ self,
225
+ batch_i: int,
226
+ layer: int,
227
+ pos: int,
228
+ ) -> Float[torch.Tensor, "hidden d_model"]:
229
+ # Take activations right before they're multiplied by W_out, i.e. non-linearity
230
+ # and layer norm are already applied.
231
+ processed_activations = self._get_block(layer, "mlp.hook_post")[batch_i][pos]
232
+ return torch.mul(processed_activations.unsqueeze(-1), self._model.W_out[layer])
233
+
234
+ @typechecked
235
+ def neuron_activations(
236
+ self,
237
+ batch_i: int,
238
+ layer: int,
239
+ pos: int,
240
+ ) -> Float[torch.Tensor, "hidden"]:
241
+ return self._get_block(layer, "mlp.hook_pre")[batch_i][pos]
242
+
243
+ @typechecked
244
+ def neuron_output(
245
+ self,
246
+ layer: int,
247
+ neuron: int,
248
+ ) -> Float[torch.Tensor, "d_model"]:
249
+ return self._model.W_out[layer][neuron]
250
+
251
+ # ==================== Methods related to the attention ====================
252
+
253
+ @typechecked
254
+ def attention_matrix(
255
+ self, batch_i: int, layer: int, head: int
256
+ ) -> Float[torch.Tensor, "query_pos key_pos"]:
257
+ return self._get_block(layer, "attn.hook_pattern")[batch_i][head]
258
+
259
+ @typechecked
260
+ def attention_output_per_head(
261
+ self,
262
+ batch_i: int,
263
+ layer: int,
264
+ pos: int,
265
+ head: int,
266
+ ) -> Float[torch.Tensor, "d_model"]:
267
+ return self._get_block(layer, "attn.hook_result")[batch_i][pos][head]
268
+
269
+ @typechecked
270
+ def attention_output(
271
+ self,
272
+ batch_i: int,
273
+ layer: int,
274
+ pos: int,
275
+ ) -> Float[torch.Tensor, "d_model"]:
276
+ return self._get_block(layer, "hook_attn_out")[batch_i][pos]
277
+
278
+ @torch.no_grad()
279
+ @typechecked
280
+ def decomposed_attn(
281
+ self, batch_i: int, layer: int
282
+ ) -> Float[torch.Tensor, "pos key_pos head d_model"]:
283
+ if not self._last_run:
284
+ raise self._run_exception
285
+ hook_v = self._get_block(layer, "attn.hook_v")[batch_i]
286
+ b_v = self._model.b_V[layer]
287
+ v = hook_v + b_v
288
+ pattern = self._get_block(layer, "attn.hook_pattern")[batch_i].to(v.dtype)
289
+ z = einsum(
290
+ "key_pos head d_head, "
291
+ "head query_pos key_pos -> "
292
+ "query_pos key_pos head d_head",
293
+ v,
294
+ pattern,
295
+ )
296
+ decomposed_attn = einsum(
297
+ "pos key_pos head d_head, "
298
+ "head d_head d_model -> "
299
+ "pos key_pos head d_model",
300
+ z,
301
+ self._model.W_O[layer],
302
+ )
303
+ return decomposed_attn
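
For orientation, here is a minimal sketch of how the wrapper above can be exercised outside the Streamlit app. It assumes the facebook/opt-125m checkpoint (the same model the unit tests use) can be downloaded and that a CPU run is acceptable; `st.cache_resource` still executes the wrapped loader when no Streamlit session is running.

```python
# Sketch: exercising TransformerLensTransparentLlm outside the app (assumed setup).
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm

llm = TransformerLensTransparentLlm(model_name="facebook/opt-125m", device="cpu")
llm.run(["When Mary and John went to the store, John gave a drink to"])

info = llm.model_info()
print(f"{info.n_layers} layers, d_model={info.d_model}")

# residual_in(layer) should match residual_out(layer - 1), as the unit test above checks.
gap = (llm.residual_in(1) - llm.residual_out(0)).abs().max().item()
print("max |residual_in(1) - residual_out(0)| =", gap)
```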
llm_transparency_tool/models/transparent_llm.py ADDED
@@ -0,0 +1,199 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass
9
+ from typing import List
10
+
11
+ import torch
12
+ from jaxtyping import Float, Int
13
+
14
+
15
+ @dataclass
16
+ class ModelInfo:
17
+ name: str
18
+
19
+ # Not the actual number of parameters, but rather the order of magnitude
20
+ n_params_estimate: int
21
+
22
+ n_layers: int
23
+ n_heads: int
24
+ d_model: int
25
+ d_vocab: int
26
+
27
+
28
+ class TransparentLlm(ABC):
29
+ """
30
+ An abstract stateful interface for a language model. The model is supposed to be
31
+ loaded at the class initialization.
32
+
33
+ The internal state consists of the tensors produced by the last call of the `run` method.
34
+ Most of the methods could return values based on the state, but some may do cheap
35
+ computations based on them.
36
+ """
37
+
38
+ @abstractmethod
39
+ def model_info(self) -> ModelInfo:
40
+ """
41
+ Gives general info about the model. This method must be available before any
42
+ calls of `run`.
43
+ """
44
+ pass
45
+
46
+ @abstractmethod
47
+ def run(self, sentences: List[str]) -> None:
48
+ """
49
+ Run the inference on the given sentences in a single batch and store all
50
+ necessary info in the internal state.
51
+ """
52
+ pass
53
+
54
+ @abstractmethod
55
+ def batch_size(self) -> int:
56
+ """
57
+ The size of the batch that was used for the last call of `run`.
58
+ """
59
+ pass
60
+
61
+ @abstractmethod
62
+ def tokens(self) -> Int[torch.Tensor, "batch pos"]:
63
+ pass
64
+
65
+ @abstractmethod
66
+ def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
67
+ pass
68
+
69
+ @abstractmethod
70
+ def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
71
+ pass
72
+
73
+ @abstractmethod
74
+ def unembed(
75
+ self,
76
+ t: Float[torch.Tensor, "d_model"],
77
+ normalize: bool,
78
+ ) -> Float[torch.Tensor, "vocab"]:
79
+ """
80
+ Project the given vector (for example, the state of the residual stream for a
81
+ layer and token) into the output vocabulary.
82
+
83
+ normalize: whether to apply the final normalization before the unembedding.
84
+ Setting it to True and applying to output of the last layer gives the output of
85
+ the model.
86
+ """
87
+ pass
88
+
89
+ # ================= Methods related to the residual stream =================
90
+
91
+ @abstractmethod
92
+ def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
93
+ """
94
+ The state of the residual stream before entering the layer. For example, when
95
+ layer == 0 these must be the embedded tokens (including positional embedding).
96
+ """
97
+ pass
98
+
99
+ @abstractmethod
100
+ def residual_after_attn(
101
+ self, layer: int
102
+ ) -> Float[torch.Tensor, "batch pos d_model"]:
103
+ """
104
+ The state of the residual stream after attention, but before the FFN in the
105
+ given layer.
106
+ """
107
+ pass
108
+
109
+ @abstractmethod
110
+ def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
111
+ """
112
+ The state of the residual stream after the given layer. This is equivalent to the
113
+ next layer's input.
114
+ """
115
+ pass
116
+
117
+ # ================ Methods related to the feed-forward layer ===============
118
+
119
+ @abstractmethod
120
+ def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
121
+ """
122
+ The output of the FFN layer, before it gets merged into the residual stream.
123
+ """
124
+ pass
125
+
126
+ @abstractmethod
127
+ def decomposed_ffn_out(
128
+ self,
129
+ batch_i: int,
130
+ layer: int,
131
+ pos: int,
132
+ ) -> Float[torch.Tensor, "hidden d_model"]:
133
+ """
134
+ A collection of vectors added to the residual stream by each neuron. It should
135
+ be the same as neuron activations multiplied by neuron outputs.
136
+ """
137
+ pass
138
+
139
+ @abstractmethod
140
+ def neuron_activations(
141
+ self,
142
+ batch_i: int,
143
+ layer: int,
144
+ pos: int,
145
+ ) -> Float[torch.Tensor, "d_ffn"]:
146
+ """
147
+ The content of the hidden layer right after the activation function was applied.
148
+ """
149
+ pass
150
+
151
+ @abstractmethod
152
+ def neuron_output(
153
+ self,
154
+ layer: int,
155
+ neuron: int,
156
+ ) -> Float[torch.Tensor, "d_model"]:
157
+ """
158
+ Return the value that the given neuron adds to the residual stream. It's a raw
159
+ vector from the model parameters, no activation involved.
160
+ """
161
+ pass
162
+
163
+ # ==================== Methods related to the attention ====================
164
+
165
+ @abstractmethod
166
+ def attention_matrix(
167
+ self, batch_i, layer: int, head: int
168
+ ) -> Float[torch.Tensor, "query_pos key_pos"]:
169
+ """
170
+ Return a lower-triangular attention matrix.
171
+ """
172
+ pass
173
+
174
+ @abstractmethod
175
+ def attention_output(
176
+ self,
177
+ batch_i: int,
178
+ layer: int,
179
+ pos: int,
180
+ head: int,
181
+ ) -> Float[torch.Tensor, "d_model"]:
182
+ """
183
+ Return what the given head at the given layer and pos added to the residual
184
+ stream.
185
+ """
186
+ pass
187
+
188
+ @abstractmethod
189
+ def decomposed_attn(
190
+ self, batch_i: int, layer: int
191
+ ) -> Float[torch.Tensor, "source target head d_model"]:
192
+ """
193
+ Here
194
+ - source: index of token from the previous layer
195
+ - target: index of token on the current layer
196
+ The decomposed attention tells what vector from source representation was used
197
+ in order to contribute to the taget representation.
198
+ """
199
+ pass
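
Since the interface exposes both the per-layer residual stream and `unembed`, a logit-lens style probe can be written purely against it. The helper below is a sketch, not part of the library; it assumes `llm` is any `TransparentLlm` implementation on which `run` has already been called.

```python
# Hypothetical helper built only on the TransparentLlm interface (sketch).
import torch

from llm_transparency_tool.models.transparent_llm import TransparentLlm


def top_tokens_per_layer(llm: TransparentLlm, batch_i: int, pos: int, k: int = 5):
    """Project each layer's residual output to the vocabulary and keep the top-k strings."""
    tops = []
    for layer in range(llm.model_info().n_layers):
        resid = llm.residual_out(layer)[batch_i, pos]
        logits = llm.unembed(resid, normalize=True)
        tops.append(llm.tokens_to_strings(torch.topk(logits, k).indices))
    return tops
```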
llm_transparency_tool/routes/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
llm_transparency_tool/routes/contributions.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Tuple
8
+
9
+ import einops
10
+ import torch
11
+ from jaxtyping import Float
12
+ from typeguard import typechecked
13
+
14
+
15
+ @torch.no_grad()
16
+ @typechecked
17
+ def get_contributions(
18
+ parts: torch.Tensor,
19
+ whole: torch.Tensor,
20
+ distance_norm: int = 1,
21
+ ) -> torch.Tensor:
22
+ """
23
+ Compute contributions of the `parts` vectors into the `whole` vector.
24
+
25
+ Shapes of the tensors are as follows:
26
+ parts: p_1 ... p_k, v_1 ... v_n, d
27
+ whole: v_1 ... v_n, d
28
+ result: p_1 ... p_k, v_1 ... v_n
29
+
30
+ Here
31
+ * `p_1 ... p_k`: dimensions for enumerating the parts
32
+ * `v_1 ... v_n`: dimensions listing the independent cases (batching),
33
+ * `d` is the dimension to compute the distances on.
34
+
35
+ The resulting contributions will be normalized so that
36
+ for each v_: sum(over p_ of result(p_, v_)) = 1.
37
+ """
38
+ EPS = 1e-5
39
+
40
+ k = len(parts.shape) - len(whole.shape)
41
+ assert k >= 0
42
+ assert parts.shape[k:] == whole.shape
43
+ bc_whole = whole.expand(parts.shape) # new dims p_1 ... p_k are added to the front
44
+
45
+ distance = torch.nn.functional.pairwise_distance(parts, bc_whole, p=distance_norm)
46
+
47
+ whole_norm = torch.norm(whole, p=distance_norm, dim=-1)
48
+ distance = (whole_norm - distance).clip(min=EPS)
49
+
50
+ sum = distance.sum(dim=tuple(range(k)), keepdim=True)
51
+
52
+ return distance / sum
53
+
54
+
55
+ @torch.no_grad()
56
+ @typechecked
57
+ def get_contributions_with_one_off_part(
58
+ parts: torch.Tensor,
59
+ one_off: torch.Tensor,
60
+ whole: torch.Tensor,
61
+ distance_norm: int = 1,
62
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
63
+ """
64
+ Same as computing the contributions, but there is one additional part. That's useful
65
+ because we always have the residual stream as one of the parts.
66
+
67
+ See `get_contributions` documentation about `parts` and `whole` dimensions. The
68
+ `one_off` should have the same dimensions as `whole`.
69
+
70
+ Returns a pair consisting of
71
+ 1. contributions tensor for the `parts`
72
+ 2. contributions tensor for the `one_off` vector
73
+ """
74
+ assert one_off.shape == whole.shape
75
+
76
+ k = len(parts.shape) - len(whole.shape)
77
+ assert k >= 0
78
+
79
+ # Flatten the p_ dimensions, get contributions for the list, unflatten.
80
+ flat = parts.flatten(start_dim=0, end_dim=k - 1)
81
+ flat = torch.cat([flat, one_off.unsqueeze(0)])
82
+ contributions = get_contributions(flat, whole, distance_norm)
83
+ parts_contributions, one_off_contributions = torch.split(
84
+ contributions, flat.shape[0] - 1
85
+ )
86
+ return (
87
+ parts_contributions.unflatten(0, parts.shape[0:k]),
88
+ one_off_contributions[0],
89
+ )
90
+
91
+
92
+ @torch.no_grad()
93
+ @typechecked
94
+ def get_attention_contributions(
95
+ resid_pre: Float[torch.Tensor, "batch pos d_model"],
96
+ resid_mid: Float[torch.Tensor, "batch pos d_model"],
97
+ decomposed_attn: Float[torch.Tensor, "batch pos key_pos head d_model"],
98
+ distance_norm: int = 1,
99
+ ) -> Tuple[
100
+ Float[torch.Tensor, "batch pos key_pos head"],
101
+ Float[torch.Tensor, "batch pos"],
102
+ ]:
103
+ """
104
+ Returns a pair of
105
+ - a tensor of contributions of each token via each head
106
+ - the contribution of the residual stream.
107
+ """
108
+
109
+ # part dimensions | batch dimensions | vector dimension
110
+ # ----------------+------------------+-----------------
111
+ # key_pos, head | batch, pos | d_model
112
+ parts = einops.rearrange(
113
+ decomposed_attn,
114
+ "batch pos key_pos head d_model -> key_pos head batch pos d_model",
115
+ )
116
+ attn_contribution, residual_contribution = get_contributions_with_one_off_part(
117
+ parts, resid_pre, resid_mid, distance_norm
118
+ )
119
+ return (
120
+ einops.rearrange(
121
+ attn_contribution, "key_pos head batch pos -> batch pos key_pos head"
122
+ ),
123
+ residual_contribution,
124
+ )
125
+
126
+
127
+ @torch.no_grad()
128
+ @typechecked
129
+ def get_mlp_contributions(
130
+ resid_mid: Float[torch.Tensor, "batch pos d_model"],
131
+ resid_post: Float[torch.Tensor, "batch pos d_model"],
132
+ mlp_out: Float[torch.Tensor, "batch pos d_model"],
133
+ distance_norm: int = 1,
134
+ ) -> Tuple[Float[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos"]]:
135
+ """
136
+ Returns a pair of (mlp, residual) contributions for each sentence and token.
137
+ """
138
+
139
+ contributions = get_contributions(
140
+ torch.stack((mlp_out, resid_mid)), resid_post, distance_norm
141
+ )
142
+ return contributions[0], contributions[1]
143
+
144
+
145
+ @torch.no_grad()
146
+ @typechecked
147
+ def get_decomposed_mlp_contributions(
148
+ resid_mid: Float[torch.Tensor, "d_model"],
149
+ resid_post: Float[torch.Tensor, "d_model"],
150
+ decomposed_mlp_out: Float[torch.Tensor, "hidden d_model"],
151
+ distance_norm: int = 1,
152
+ ) -> Tuple[Float[torch.Tensor, "hidden"], float]:
153
+ """
154
+ Similar to `get_mlp_contributions`, but it takes the MLP output for each neuron of
155
+ the hidden layer and thus computes a contribution per neuron.
156
+
157
+ Doesn't contain batch and token dimensions for sake of saving memory. But we may
158
+ consider adding them.
159
+ """
160
+
161
+ neuron_contributions, residual_contribution = get_contributions_with_one_off_part(
162
+ decomposed_mlp_out, resid_mid, resid_post, distance_norm
163
+ )
164
+ return neuron_contributions, residual_contribution.item()
165
+
166
+
167
+ @torch.no_grad()
168
+ def apply_threshold_and_renormalize(
169
+ threshold: float,
170
+ c_blocks: torch.Tensor,
171
+ c_residual: torch.Tensor,
172
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
173
+ """
174
+ Thresholding mechanism used in the original graphs paper. After the threshold is
175
+ applied, the remaining contributions are renormalized in order to sum up to 1 for
176
+ each representation.
177
+
178
+ threshold: The threshold.
179
+ c_residual: Contribution of the residual stream for each representation. This tensor
180
+ should contain 1 element per representation, i.e., its dimensions are all batch
181
+ dimensions.
182
+ c_blocks: Contributions of the blocks. Could be 1 block per representation, like
183
+ ffn, or heads*tokens blocks in case of attention. The shape of `c_residual`
184
+ must be a prefix if the shape of this tensor. The remaining dimensions are for
185
+ listing the blocks.
186
+ """
187
+
188
+ block_dims = len(c_blocks.shape)
189
+ resid_dims = len(c_residual.shape)
190
+ bound_dims = block_dims - resid_dims
191
+ assert bound_dims >= 0
192
+ assert c_blocks.shape[0:resid_dims] == c_residual.shape
193
+
194
+ c_blocks = c_blocks * (c_blocks > threshold)
195
+ c_residual = c_residual * (c_residual > threshold)
196
+
197
+ denom = c_residual + c_blocks.sum(dim=tuple(range(resid_dims, block_dims)))
198
+ return (
199
+ c_blocks / denom.reshape(denom.shape + (1,) * bound_dims),
200
+ c_residual / denom,
201
+ )
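
A quick way to sanity-check the functions above is to feed them toy tensors whose decomposition is known. The sketch below uses arbitrary shapes; the key property is that the returned contributions are non-negative and sum to 1 for every representation.

```python
# Sketch: toy sanity check for the contribution functions above.
import torch

from llm_transparency_tool.routes.contributions import (
    get_contributions,
    get_mlp_contributions,
)

# Two parts per "whole" vector, a batch of 3, d_model = 4.
parts = torch.rand(2, 3, 4)
whole = parts.sum(dim=0)
c = get_contributions(parts, whole)   # shape [2, 3]
print(c.sum(dim=0))                   # ~1.0 for each of the 3 batch elements

# MLP vs. residual split, using the fact that resid_post = resid_mid + mlp_out.
resid_mid = torch.rand(1, 5, 16)      # batch, pos, d_model
mlp_out = torch.rand(1, 5, 16)
c_mlp, c_resid = get_mlp_contributions(resid_mid, resid_mid + mlp_out, mlp_out)
print(c_mlp + c_resid)                # each entry is ~1.0
```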
llm_transparency_tool/routes/graph.py ADDED
@@ -0,0 +1,163 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional
8
+
9
+ import networkx as nx
10
+ import torch
11
+
12
+ import llm_transparency_tool.routes.contributions as contributions
13
+ from llm_transparency_tool.models.transparent_llm import TransparentLlm
14
+
15
+
16
+ class GraphBuilder:
17
+ """
18
+ Constructs the contributions graph with edges given one by one. The resulting graph
19
+ is a networkx graph that can be accessed via the `graph` field. It contains the
20
+ following types of nodes:
21
+
22
+ - X0_<token>: the original token.
23
+ - A<layer>_<token>: the residual stream after attention at the given layer for the
24
+ given token.
25
+ - M<layer>_<token>: the ffn block.
26
+ - I<layer>_<token>: the residual stream after the ffn block.
27
+ """
28
+
29
+ def __init__(self, n_layers: int, n_tokens: int):
30
+ self._n_layers = n_layers
31
+ self._n_tokens = n_tokens
32
+
33
+ self.graph = nx.DiGraph()
34
+ for layer in range(n_layers):
35
+ for token in range(n_tokens):
36
+ self.graph.add_node(f"A{layer}_{token}")
37
+ self.graph.add_node(f"I{layer}_{token}")
38
+ self.graph.add_node(f"M{layer}_{token}")
39
+ for token in range(n_tokens):
40
+ self.graph.add_node(f"X0_{token}")
41
+
42
+ def get_output_node(self, token: int):
43
+ return f"I{self._n_layers - 1}_{token}"
44
+
45
+ def _add_edge(self, u: str, v: str, weight: float):
46
+ # TODO(igortufanov): Here we sum up weights for multi-edges. It happens with
47
+ # attention from the current token and the residual edge. Ideally these need to
48
+ # be 2 separate edges, but then we need to do a MultiGraph. Multigraph is fine,
49
+ # but when we try to traverse it, we face some NetworkX issue with EDGE_OK
50
+ # receiving 3 arguments instead of 2.
51
+ if self.graph.has_edge(u, v):
52
+ self.graph[u][v]["weight"] += weight
53
+ else:
54
+ self.graph.add_edge(u, v, weight=weight)
55
+
56
+ def add_attention_edge(self, layer: int, token_from: int, token_to: int, w: float):
57
+ self._add_edge(
58
+ f"I{layer-1}_{token_from}" if layer > 0 else f"X0_{token_from}",
59
+ f"A{layer}_{token_to}",
60
+ w,
61
+ )
62
+
63
+ def add_residual_to_attn(self, layer: int, token: int, w: float):
64
+ self._add_edge(
65
+ f"I{layer-1}_{token}" if layer > 0 else f"X0_{token}",
66
+ f"A{layer}_{token}",
67
+ w,
68
+ )
69
+
70
+ def add_ffn_edge(self, layer: int, token: int, w: float):
71
+ self._add_edge(f"A{layer}_{token}", f"M{layer}_{token}", w)
72
+ self._add_edge(f"M{layer}_{token}", f"I{layer}_{token}", w)
73
+
74
+ def add_residual_to_ffn(self, layer: int, token: int, w: float):
75
+ self._add_edge(f"A{layer}_{token}", f"I{layer}_{token}", w)
76
+
77
+
78
+ @torch.no_grad()
79
+ def build_full_graph(
80
+ model: TransparentLlm,
81
+ batch_i: int = 0,
82
+ renormalizing_threshold: Optional[float] = None,
83
+ ) -> nx.Graph:
84
+ """
85
+ Build the contribution graph for all blocks of the model and all tokens.
86
+
87
+ model: The transparent llm which already did the inference.
88
+ batch_i: Which sentence to use from the batch that was given to the model.
89
+ renormalizing_threshold: If specified, will apply renormalizing thresholding to the
90
+ contributions. All contributions below the threshold will be erased and the rest
91
+ will be renormalized.
92
+ """
93
+ n_layers = model.model_info().n_layers
94
+ n_tokens = model.tokens()[batch_i].shape[0]
95
+
96
+ builder = GraphBuilder(n_layers, n_tokens)
97
+
98
+ for layer in range(n_layers):
99
+ c_attn, c_resid_attn = contributions.get_attention_contributions(
100
+ resid_pre=model.residual_in(layer)[batch_i].unsqueeze(0),
101
+ resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0),
102
+ decomposed_attn=model.decomposed_attn(batch_i, layer).unsqueeze(0),
103
+ )
104
+ if renormalizing_threshold is not None:
105
+ c_attn, c_resid_attn = contributions.apply_threshold_and_renormalize(
106
+ renormalizing_threshold, c_attn, c_resid_attn
107
+ )
108
+ for token_from in range(n_tokens):
109
+ for token_to in range(n_tokens):
110
+ # Sum attention contributions over heads.
111
+ c = c_attn[batch_i, token_to, token_from].sum().item()
112
+ builder.add_attention_edge(layer, token_from, token_to, c)
113
+ for token in range(n_tokens):
114
+ builder.add_residual_to_attn(
115
+ layer, token, c_resid_attn[batch_i, token].item()
116
+ )
117
+
118
+ c_ffn, c_resid_ffn = contributions.get_mlp_contributions(
119
+ resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0),
120
+ resid_post=model.residual_out(layer)[batch_i].unsqueeze(0),
121
+ mlp_out=model.ffn_out(layer)[batch_i].unsqueeze(0),
122
+ )
123
+ if renormalizing_threshold is not None:
124
+ c_ffn, c_resid_ffn = contributions.apply_threshold_and_renormalize(
125
+ renormalizing_threshold, c_ffn, c_resid_ffn
126
+ )
127
+ for token in range(n_tokens):
128
+ builder.add_ffn_edge(layer, token, c_ffn[batch_i, token].item())
129
+ builder.add_residual_to_ffn(
130
+ layer, token, c_resid_ffn[batch_i, token].item()
131
+ )
132
+
133
+ return builder.graph
134
+
135
+
136
+ def build_paths_to_predictions(
137
+ graph: nx.Graph,
138
+ n_layers: int,
139
+ n_tokens: int,
140
+ starting_tokens: List[int],
141
+ threshold: float,
142
+ ) -> List[nx.Graph]:
143
+ """
144
+ Given the full graph, this function returns only the trees leading to the specified
145
+ tokens. Edges with weight below `threshold` will be ignored.
146
+ """
147
+ builder = GraphBuilder(n_layers, n_tokens)
148
+
149
+ rgraph = graph.reverse()
150
+ search_graph = nx.subgraph_view(
151
+ rgraph, filter_edge=lambda u, v: rgraph[u][v]["weight"] > threshold
152
+ )
153
+
154
+ result = []
155
+ for start in starting_tokens:
156
+ assert start < n_tokens
157
+ assert start >= 0
158
+ edges = nx.edge_dfs(search_graph, source=builder.get_output_node(start))
159
+ tree = search_graph.edge_subgraph(edges)
160
+ # Reverse the edges because the dfs was going from upper layer downwards.
161
+ result.append(tree.reverse())
162
+
163
+ return result
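
A typical flow combines the two routines: run the model once, build the dense contribution graph, then prune it down to the paths that feed a chosen output position. In the sketch below, `llm` stands for a `TransparentLlm` that has already executed `run`, and the 0.04 threshold is only an illustrative value.

```python
# Sketch: from a finished forward pass to the pruned contribution tree.
from llm_transparency_tool.routes.graph import (
    build_full_graph,
    build_paths_to_predictions,
)

full_graph = build_full_graph(llm, batch_i=0, renormalizing_threshold=0.04)

n_layers = llm.model_info().n_layers
n_tokens = llm.tokens()[0].shape[0]

# Keep only the edges that lie on paths into the last token's final representation.
trees = build_paths_to_predictions(
    full_graph,
    n_layers,
    n_tokens,
    starting_tokens=[n_tokens - 1],
    threshold=0.04,
)
print(trees[0].number_of_edges(), "edges survive the threshold")
```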
llm_transparency_tool/routes/graph_node.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from enum import Enum
9
+ from typing import List, Optional
10
+
11
+
12
+ class NodeType(Enum):
13
+ AFTER_ATTN = "after_attn"
14
+ AFTER_FFN = "after_ffn"
15
+ FFN = "ffn"
16
+ ORIGINAL = "original" # The original tokens
17
+
18
+
19
+ def _format_block_hierachy_string(blocks: List[str]) -> str:
20
+ return " ▸ ".join(blocks)
21
+
22
+
23
+ @dataclass
24
+ class GraphNode:
25
+ layer: int
26
+ token: int
27
+ type: NodeType
28
+
29
+ def is_in_residual_stream(self) -> bool:
30
+ return self.type in [NodeType.AFTER_ATTN, NodeType.AFTER_FFN]
31
+
32
+ def get_residual_predecessor(self) -> Optional["GraphNode"]:
33
+ """
34
+ Get another graph node which points to the state of the residual stream before
35
+ this node.
36
+
37
+ Retun None if current representation is the first one in the residual stream.
38
+ """
39
+ scheme = {
40
+ NodeType.AFTER_ATTN: GraphNode(
41
+ layer=max(self.layer - 1, 0),
42
+ token=self.token,
43
+ type=NodeType.AFTER_FFN if self.layer > 0 else NodeType.ORIGINAL,
44
+ ),
45
+ NodeType.AFTER_FFN: GraphNode(
46
+ layer=self.layer,
47
+ token=self.token,
48
+ type=NodeType.AFTER_ATTN,
49
+ ),
50
+ NodeType.FFN: GraphNode(
51
+ layer=self.layer,
52
+ token=self.token,
53
+ type=NodeType.AFTER_ATTN,
54
+ ),
55
+ NodeType.ORIGINAL: None,
56
+ }
57
+ node = scheme[self.type]
58
+ if node.layer < 0:
59
+ return None
60
+ return node
61
+
62
+ def get_name(self) -> str:
63
+ return _format_block_hierachy_string(
64
+ [f"L{self.layer}", f"T{self.token}", str(self.type.value)]
65
+ )
66
+
67
+ def get_predecessor_block_name(self) -> str:
68
+ """
69
+ Return the name of the block standing between current node and its predecessor
70
+ in the residual stream.
71
+ """
72
+ scheme = {
73
+ NodeType.AFTER_ATTN: [f"L{self.layer}", "attn"],
74
+ NodeType.AFTER_FFN: [f"L{self.layer}", "ffn"],
75
+ NodeType.FFN: [f"L{self.layer}", "ffn"],
76
+ NodeType.ORIGINAL: ["Nothing"],
77
+ }
78
+ return _format_block_hierachy_string(scheme[self.type])
79
+
80
+ def get_head_name(self, head: Optional[int]) -> str:
81
+ path = [f"L{self.layer}", "attn"]
82
+ if head is not None:
83
+ path.append(f"H{head}")
84
+ return _format_block_hierachy_string(path)
85
+
86
+ def get_neuron_name(self, neuron: Optional[int]) -> str:
87
+ path = [f"L{self.layer}", "ffn"]
88
+ if neuron is not None:
89
+ path.append(f"N{neuron}")
90
+ return _format_block_hierachy_string(path)
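
The predecessor scheme above encodes a backward walk along the residual stream. A small sketch of that traversal (indices are arbitrary, and it assumes the predecessor lookup returns None once the original token node is reached):

```python
# Sketch: walking back along the residual stream from an "after FFN" node.
from llm_transparency_tool.routes.graph_node import GraphNode, NodeType

node = GraphNode(layer=2, token=3, type=NodeType.AFTER_FFN)
while node is not None:
    print(node.get_name(), "preceded by", node.get_predecessor_block_name())
    node = node.get_residual_predecessor()
# Expected to end at the original token node (L0 ▸ T3 ▸ original).
```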
llm_transparency_tool/routes/test_contributions.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import unittest
8
+ from typing import Any, List
9
+
10
+ import torch
11
+
12
+ import llm_transparency_tool.routes.contributions as contributions
13
+
14
+
15
+ class TestContributions(unittest.TestCase):
16
+ def setUp(self):
17
+ torch.manual_seed(123)
18
+
19
+ self.eps = 1e-4
20
+
21
+ # It may be useful to run the test on GPU in case there are any issues with
22
+ # creating temporary tensors on another device. But turn this off by default.
23
+ self.test_on_gpu = False
24
+
25
+ self.device = "cuda" if self.test_on_gpu else "cpu"
26
+
27
+ self.batch = 4
28
+ self.tokens = 5
29
+ self.heads = 6
30
+ self.d_model = 10
31
+
32
+ self.decomposed_attn = torch.rand(
33
+ self.batch,
34
+ self.tokens,
35
+ self.tokens,
36
+ self.heads,
37
+ self.d_model,
38
+ device=self.device,
39
+ )
40
+ self.mlp_out = torch.rand(
41
+ self.batch, self.tokens, self.d_model, device=self.device
42
+ )
43
+ self.resid_pre = torch.rand(
44
+ self.batch, self.tokens, self.d_model, device=self.device
45
+ )
46
+ self.resid_mid = torch.rand(
47
+ self.batch, self.tokens, self.d_model, device=self.device
48
+ )
49
+ self.resid_post = torch.rand(
50
+ self.batch, self.tokens, self.d_model, device=self.device
51
+ )
52
+
53
+ def _assert_tensor_eq(self, t: torch.Tensor, expected: List[Any]):
54
+ self.assertTrue(
55
+ torch.isclose(t, torch.Tensor(expected), atol=self.eps).all(),
56
+ t,
57
+ )
58
+
59
+ def test_mlp_contributions(self):
60
+ mlp_out = torch.tensor([[[1.0, 1.0]]])
61
+ resid_mid = torch.tensor([[[0.0, 0.0]]])
62
+ resid_post = torch.tensor([[[1.0, 1.0]]])
63
+
64
+ c_mlp, c_residual = contributions.get_mlp_contributions(
65
+ resid_mid, resid_post, mlp_out
66
+ )
67
+ self.assertAlmostEqual(c_mlp.item(), 1.0, delta=self.eps)
68
+ self.assertAlmostEqual(c_residual.item(), 0.0, delta=self.eps)
69
+
70
+ def test_decomposed_attn_contributions(self):
71
+ resid_pre = torch.tensor([[[2.0, 1.0]]])
72
+ resid_mid = torch.tensor([[[2.0, 2.0]]])
73
+ decomposed_attn = torch.tensor(
74
+ [
75
+ [
76
+ [
77
+ [
78
+ [1.0, 1.0],
79
+ [-1.0, 0.0],
80
+ ]
81
+ ]
82
+ ]
83
+ ]
84
+ )
85
+
86
+ c_attn, c_residual = contributions.get_attention_contributions(
87
+ resid_pre, resid_mid, decomposed_attn, distance_norm=2
88
+ )
89
+ self._assert_tensor_eq(c_attn, [[[[0.43613, 0]]]])
90
+ self.assertAlmostEqual(c_residual.item(), 0.56387, delta=self.eps)
91
+
92
+ def test_decomposed_mlp_contributions(self):
93
+ pre = torch.tensor([10.0, 10.0])
94
+ post = torch.tensor([-10.0, 10.0])
95
+ neuron_impacts = torch.tensor(
96
+ [
97
+ [0.0, 1.0],
98
+ [1.0, 0.0],
99
+ [-21.0, -1.0],
100
+ ]
101
+ )
102
+ c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
103
+ pre, post, neuron_impacts, distance_norm=2
104
+ )
105
+ # A bit counter-intuitive, but the only vector pointing from 0 towards the
106
+ # output is the first one.
107
+ self._assert_tensor_eq(c_mlp, [1, 0, 0])
108
+ self.assertAlmostEqual(c_residual, 0, delta=self.eps)
109
+
110
+ def test_decomposed_mlp_contributions_single_direction(self):
111
+ pre = torch.tensor([1.0, 1.0])
112
+ post = torch.tensor([4.0, 4.0])
113
+ neuron_impacts = torch.tensor(
114
+ [
115
+ [1.0, 1.0],
116
+ [2.0, 2.0],
117
+ ]
118
+ )
119
+ c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
120
+ pre, post, neuron_impacts, distance_norm=2
121
+ )
122
+ self._assert_tensor_eq(c_mlp, [0.25, 0.5])
123
+ self.assertAlmostEqual(c_residual, 0.25, delta=self.eps)
124
+
125
+ def test_attention_contributions_shape(self):
126
+ c_attn, c_residual = contributions.get_attention_contributions(
127
+ self.resid_pre, self.resid_mid, self.decomposed_attn
128
+ )
129
+ self.assertEqual(
130
+ list(c_attn.shape), [self.batch, self.tokens, self.tokens, self.heads]
131
+ )
132
+ self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])
133
+
134
+ def test_mlp_contributions_shape(self):
135
+ c_mlp, c_residual = contributions.get_mlp_contributions(
136
+ self.resid_mid, self.resid_post, self.mlp_out
137
+ )
138
+ self.assertEqual(list(c_mlp.shape), [self.batch, self.tokens])
139
+ self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])
140
+
141
+ def test_renormalizing_threshold(self):
142
+ c_blocks = torch.Tensor([[0.05, 0.15], [0.05, 0.05]])
143
+ c_residual = torch.Tensor([0.8, 0.9])
144
+ norm_blocks, norm_residual = contributions.apply_threshold_and_renormalize(
145
+ 0.1, c_blocks, c_residual
146
+ )
147
+ self._assert_tensor_eq(norm_blocks, [[0.0, 0.157894], [0.0, 0.0]])
148
+ self._assert_tensor_eq(norm_residual, [0.842105, 1.0])
llm_transparency_tool/server/app.py ADDED
@@ -0,0 +1,659 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ from dataclasses import dataclass, field
9
+ from typing import Dict, List, Optional, Tuple
10
+
11
+ import networkx as nx
12
+ import pandas as pd
13
+ import plotly.express
14
+ import plotly.graph_objects as go
15
+ import streamlit as st
16
+ import streamlit_extras.row as st_row
17
+ import torch
18
+ from jaxtyping import Float
19
+ from torch.amp import autocast
20
+ from transformers import HfArgumentParser
21
+
22
+ import llm_transparency_tool.components
23
+ from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
24
+ import llm_transparency_tool.routes.contributions as contributions
25
+ import llm_transparency_tool.routes.graph
26
+ from llm_transparency_tool.models.transparent_llm import TransparentLlm
27
+ from llm_transparency_tool.routes.graph_node import NodeType
28
+ from llm_transparency_tool.server.graph_selection import (
29
+ GraphSelection,
30
+ UiGraphEdge,
31
+ UiGraphNode,
32
+ )
33
+ from llm_transparency_tool.server.styles import (
34
+ RenderSettings,
35
+ logits_color_map,
36
+ margins_css,
37
+ string_to_display,
38
+ )
39
+ from llm_transparency_tool.server.utils import (
40
+ B0,
41
+ get_contribution_graph,
42
+ load_dataset,
43
+ load_model,
44
+ possible_devices,
45
+ run_model_with_session_caching,
46
+ st_placeholder,
47
+ )
48
+ from llm_transparency_tool.server.monitor import SystemMonitor
49
+
50
+ from networkx.classes.digraph import DiGraph
51
+
52
+
53
+ @st.cache_resource(
54
+ hash_funcs={
55
+ nx.Graph: id,
56
+ DiGraph: id
57
+ }
58
+ )
59
+ def cached_build_paths_to_predictions(
60
+ graph: nx.Graph,
61
+ n_layers: int,
62
+ n_tokens: int,
63
+ starting_tokens: List[int],
64
+ threshold: float,
65
+ ):
66
+ return llm_transparency_tool.routes.graph.build_paths_to_predictions(
67
+ graph, n_layers, n_tokens, starting_tokens, threshold
68
+ )
69
+
70
+ @st.cache_resource(
71
+ hash_funcs={
72
+ TransformerLensTransparentLlm: id
73
+ }
74
+ )
75
+ def cached_run_inference_and_populate_state(
76
+ stateless_model,
77
+ sentences,
78
+ ):
79
+ stateful_model = stateless_model.copy()
80
+ stateful_model.run(sentences)
81
+ return stateful_model
82
+
83
+
84
+ @dataclass
85
+ class LlmViewerConfig:
86
+ debug: bool = field(
87
+ default=False,
88
+ metadata={"help": "Show debugging information, like the time profile."},
89
+ )
90
+
91
+ preloaded_dataset_filename: Optional[str] = field(
92
+ default=None,
93
+ metadata={"help": "The name of the text file to load the lines from."},
94
+ )
95
+
96
+ demo_mode: bool = field(
97
+ default=False,
98
+ metadata={"help": "Whether the app should be in the demo mode."},
99
+ )
100
+
101
+ allow_loading_dataset_files: bool = field(
102
+ default=True,
103
+ metadata={"help": "Whether the app should be able to load the dataset files " "on the server side."},
104
+ )
105
+
106
+ max_user_string_length: Optional[int] = field(
107
+ default=None,
108
+ metadata={
109
+ "help": "Limit for the length of user-provided sentences (in characters), " "or None if there is no limit."
110
+ },
111
+ )
112
+
113
+ models: Dict[str, str] = field(
114
+ default_factory=dict,
115
+ metadata={
116
+ "help": "Locations of models which are stored locally. Dictionary: official "
117
+ "HuggingFace name -> path to dir. If None is specified, the model will be"
118
+ "downloaded from HuggingFace."
119
+ },
120
+ )
121
+
122
+ default_model: str = field(
123
+ default="",
124
+ metadata={"help": "The model to load once the UI is started."},
125
+ )
126
+
127
+
128
+ class App:
129
+ _stateful_model: TransparentLlm = None
130
+ render_settings = RenderSettings()
131
+ _graph: Optional[nx.Graph] = None
132
+ _contribution_threshold: float = 0.0
133
+ _renormalize_after_threshold: bool = False
134
+ _normalize_before_unembedding: bool = True
135
+
136
+ @property
137
+ def stateful_model(self) -> TransparentLlm:
138
+ return self._stateful_model
139
+
140
+ def __init__(self, config: LlmViewerConfig):
141
+ self._config = config
142
+ st.set_page_config(layout="wide")
143
+ st.markdown(margins_css, unsafe_allow_html=True)
144
+
145
+ def _get_representation(self, node: Optional[UiGraphNode]) -> Optional[Float[torch.Tensor, "d_model"]]:
146
+ if node is None:
147
+ return None
148
+ fn = {
149
+ NodeType.AFTER_ATTN: self.stateful_model.residual_after_attn,
150
+ NodeType.AFTER_FFN: self.stateful_model.residual_out,
151
+ NodeType.FFN: None,
152
+ NodeType.ORIGINAL: self.stateful_model.residual_in,
153
+ }
154
+ return fn[node.type](node.layer)[B0][node.token]
155
+
156
+ def draw_model_info(self):
157
+ info = self.stateful_model.model_info().__dict__
158
+ df = pd.DataFrame(
159
+ data=[str(x) for x in info.values()],
160
+ index=info.keys(),
161
+ columns=["Model parameter"],
162
+ )
163
+ st.dataframe(df, use_container_width=False)
164
+
165
+ def draw_dataset_selection(self) -> Optional[str]:
166
+ def update_dataset(filename: Optional[str]):
167
+ dataset = load_dataset(filename) if filename is not None else []
168
+ st.session_state["dataset"] = dataset
169
+ st.session_state["dataset_file"] = filename
170
+
171
+ if "dataset" not in st.session_state:
172
+ update_dataset(self._config.preloaded_dataset_filename)
173
+
174
+
175
+ if not self._config.demo_mode:
176
+ if self._config.allow_loading_dataset_files:
177
+ row_f = st_row.row([2, 1], vertical_align="bottom")
178
+ filename = row_f.text_input("Dataset", value=st.session_state.dataset_file or "")
179
+ if row_f.button("Load"):
180
+ update_dataset(filename)
181
+ row_s = st_row.row([2, 1], vertical_align="bottom")
182
+ new_sentence = row_s.text_input("New sentence")
183
+ new_sentence_added = False
184
+
185
+ if row_s.button("Add"):
186
+ max_len = self._config.max_user_string_length
187
+ n = len(new_sentence)
188
+ if max_len is None or n <= max_len:
189
+ st.session_state.dataset.append(new_sentence)
190
+ new_sentence_added = True
191
+ st.session_state.sentence_selector = new_sentence
192
+ else:
193
+ st.warning(f"Sentence length {n} is larger than " f"the configured limit of {max_len}")
194
+
195
+ sentences = st.session_state.dataset
196
+ selection = st.selectbox(
197
+ "Sentence",
198
+ sentences,
199
+ index=len(sentences) - 1,
200
+ key="sentence_selector",
201
+ )
202
+ return selection
203
+
204
+ def _unembed(
205
+ self,
206
+ representation: torch.Tensor,
207
+ ) -> torch.Tensor:
208
+ return self.stateful_model.unembed(representation, normalize=self._normalize_before_unembedding)
209
+
210
+ def draw_graph(self, contribution_threshold: float) -> Optional[GraphSelection]:
211
+ tokens = self.stateful_model.tokens()[B0]
212
+ n_tokens = tokens.shape[0]
213
+ model_info = self.stateful_model.model_info()
214
+
215
+ graphs = cached_build_paths_to_predictions(
216
+ self._graph,
217
+ model_info.n_layers,
218
+ n_tokens,
219
+ range(n_tokens),
220
+ contribution_threshold,
221
+ )
222
+
223
+ return llm_transparency_tool.components.contribution_graph(
224
+ model_info,
225
+ self.stateful_model.tokens_to_strings(tokens),
226
+ graphs,
227
+ key=f"graph_{hash(self.sentence)}",
228
+ )
229
+
230
+ def draw_token_matrix(
231
+ self,
232
+ values: Float[torch.Tensor, "t t"],
233
+ tokens: List[str],
234
+ value_name: str,
235
+ title: str,
236
+ ):
237
+ assert values.shape[0] == len(tokens)
238
+ labels = {
239
+ "x": "<b>src</b>",
240
+ "y": "<b>tgt</b>",
241
+ "color": value_name,
242
+ }
243
+
244
+ captions = [f"({i}){t}" for i, t in enumerate(tokens)]
245
+
246
+ fig = plotly.express.imshow(
247
+ values.cpu(),
248
+ title=f'<b>{title}</b>',
249
+ labels=labels,
250
+ x=captions,
251
+ y=captions,
252
+ color_continuous_scale=self.render_settings.attention_color_map,
253
+ aspect="equal",
254
+ )
255
+ fig.update_layout(
256
+ autosize=True,
257
+ margin=go.layout.Margin(
258
+ l=50, # left margin
259
+ r=0, # right margin
260
+ b=100, # bottom margin
261
+ t=100, # top margin
262
+ # pad=10 # padding
263
+ )
264
+ )
265
+ fig.update_xaxes(tickmode="linear")
266
+ fig.update_yaxes(tickmode="linear")
267
+ fig.update_coloraxes(showscale=False)
268
+
269
+ st.plotly_chart(fig, use_container_width=True, theme=None)
270
+
271
+ def draw_attn_info(self, edge: UiGraphEdge, container_attention_map) -> Optional[int]:
272
+ """
273
+ Returns: the index of the selected head.
274
+ """
275
+
276
+ n_heads = self.stateful_model.model_info().n_heads
277
+
278
+ layer = edge.target.layer
279
+
280
+ head_contrib, _ = contributions.get_attention_contributions(
281
+ resid_pre=self.stateful_model.residual_in(layer)[B0].unsqueeze(0),
282
+ resid_mid=self.stateful_model.residual_after_attn(layer)[B0].unsqueeze(0),
283
+ decomposed_attn=self.stateful_model.decomposed_attn(B0, layer).unsqueeze(0),
284
+ )
285
+
286
+ # [batch pos key_pos head] -> [head]
287
+ flat_contrib = head_contrib[0, edge.target.token, edge.source.token, :]
288
+ assert flat_contrib.shape[0] == n_heads, f"{flat_contrib.shape} vs {n_heads}"
289
+
290
+ selected_head = llm_transparency_tool.components.selector(
291
+ items=[f"H{h}" if h >= 0 else "All" for h in range(-1, n_heads)],
292
+ indices=range(-1, n_heads),
293
+ temperatures=[sum(flat_contrib).item()] + flat_contrib.tolist(),
294
+ preselected_index=flat_contrib.argmax().item(),
295
+ key=f"head_selector_layer_{layer}" #_from_tok_{edge.source.token}_to_tok_{edge.target.token}",
296
+ )
297
+ print(f"head_selector_layer_{layer}_from_tok_{edge.source.token}_to_tok_{edge.target.token}")
298
+ if selected_head == -1 or selected_head is None:
299
+ # selected_head = None
300
+ selected_head = flat_contrib.argmax().item()
301
+ print('****\n' * 3 + f"selected_head: {selected_head}" + '\n****\n' * 3)
302
+
303
+ # Draw attention matrix and contributions for the selected head.
304
+ if selected_head is not None:
305
+ tokens = [
306
+ string_to_display(s) for s in self.stateful_model.tokens_to_strings(self.stateful_model.tokens()[B0])
307
+ ]
308
+
309
+ with container_attention_map:
310
+ attn_container, contrib_container = st.columns([1, 1])
311
+ with attn_container:
312
+ attn = self.stateful_model.attention_matrix(B0, layer, selected_head)
313
+ self.draw_token_matrix(
314
+ attn,
315
+ tokens,
316
+ "attention",
317
+ f"Attention map L{layer} H{selected_head}",
318
+ )
319
+ with contrib_container:
320
+ contrib = head_contrib[B0, :, :, selected_head]
321
+ self.draw_token_matrix(
322
+ contrib,
323
+ tokens,
324
+ "contribution",
325
+ f"Contribution map L{layer} H{selected_head}",
326
+ )
327
+
328
+ return selected_head
329
+
330
+ def draw_ffn_info(self, node: UiGraphNode) -> Optional[int]:
331
+ """
332
+ Returns: the index of the selected neuron.
333
+ """
334
+
335
+ resid_mid = self.stateful_model.residual_after_attn(node.layer)[B0][node.token]
336
+ resid_post = self.stateful_model.residual_out(node.layer)[B0][node.token]
337
+ decomposed_ffn = self.stateful_model.decomposed_ffn_out(B0, node.layer, node.token)
338
+ c_ffn, _ = contributions.get_decomposed_mlp_contributions(resid_mid, resid_post, decomposed_ffn)
339
+
340
+ top_values, top_i = c_ffn.sort(descending=True)
341
+ n = min(self.render_settings.n_top_neurons, c_ffn.shape[0])
342
+ top_neurons = top_i[0:n].tolist()
343
+
344
+ selected_neuron = llm_transparency_tool.components.selector(
345
+ items=[f"{top_neurons[i]}" if i >= 0 else "All" for i in range(-1, n)],
346
+ indices=range(-1, n),
347
+ temperatures=[0.0] + top_values[0:n].tolist(),
348
+ preselected_index=-1,
349
+ key="neuron_selector",
350
+ )
351
+ if selected_neuron is None:
352
+ selected_neuron = -1
353
+ selected_neuron = None if selected_neuron == -1 else top_neurons[selected_neuron]
354
+
355
+ return selected_neuron
356
+
357
+ def _draw_token_table(
358
+ self,
359
+ n_top: int,
360
+ n_bottom: int,
361
+ representation: torch.Tensor,
362
+ predecessor: Optional[torch.Tensor] = None,
363
+ ):
364
+ n_total = n_top + n_bottom
365
+
366
+ logits = self._unembed(representation)
367
+ n_vocab = logits.shape[0]
368
+ scores, indices = torch.topk(logits, n_top, largest=True)
369
+ positions = list(range(n_top))
370
+
371
+ if n_bottom > 0:
372
+ low_scores, low_indices = torch.topk(logits, n_bottom, largest=False)
373
+ indices = torch.cat((indices, low_indices.flip(0)))
374
+ scores = torch.cat((scores, low_scores.flip(0)))
375
+ positions += range(n_vocab - n_bottom, n_vocab)
376
+
377
+ tokens = [string_to_display(w) for w in self.stateful_model.tokens_to_strings(indices)]
378
+
379
+ if predecessor is not None:
380
+ pre_logits = self._unembed(predecessor)
381
+ _, sorted_pre_indices = pre_logits.sort(descending=True)
382
+ pre_indices_dict = {index: pos for pos, index in enumerate(sorted_pre_indices.tolist())}
383
+ old_positions = [pre_indices_dict[i] for i in indices.tolist()]
384
+
385
+ def pos_gain_string(pos, old_pos):
386
+ if pos == old_pos:
387
+ return ""
388
+ sign = "↓" if pos > old_pos else "↑"
389
+ return f"({sign}{abs(pos - old_pos)})"
390
+
391
+ position_strings = [f"{i} {pos_gain_string(i, old_i)}" for (i, old_i) in zip(positions, old_positions)]
392
+ else:
393
+ position_strings = [str(pos) for pos in positions]
394
+
395
+ def pos_gain_color(s):
396
+ color = "black"
397
+ if isinstance(s, str):
398
+ if "↓" in s:
399
+ color = "red"
400
+ if "↑" in s:
401
+ color = "green"
402
+ return f"color: {color}"
403
+
404
+ top_df = pd.DataFrame(
405
+ data=zip(position_strings, tokens, scores.tolist()),
406
+ columns=["Pos", "Token", "Score"],
407
+ )
408
+
409
+ st.dataframe(
410
+ top_df.style.map(pos_gain_color)
411
+ .background_gradient(
412
+ axis=0,
413
+ cmap=logits_color_map(positive_and_negative=n_bottom > 0),
414
+ )
415
+ .format(precision=3),
416
+ hide_index=True,
417
+ height=self.render_settings.table_cell_height * (n_total + 1),
418
+ use_container_width=True,
419
+ )
420
+
421
+ def draw_token_dynamics(self, representation: torch.Tensor, block_name: str) -> None:
422
+ st.caption(block_name)
423
+ self._draw_token_table(
424
+ self.render_settings.n_promoted_tokens,
425
+ self.render_settings.n_suppressed_tokens,
426
+ representation,
427
+ None,
428
+ )
429
+
430
+ def draw_top_tokens(
431
+ self,
432
+ node: UiGraphNode,
433
+ container_top_tokens,
434
+ container_token_dynamics,
435
+ ) -> None:
436
+ pre_node = node.get_residual_predecessor()
437
+ if pre_node is None:
438
+ return
439
+
440
+ representation = self._get_representation(node)
441
+ predecessor = self._get_representation(pre_node)
442
+
443
+ with container_top_tokens:
444
+ st.caption(node.get_name())
445
+ self._draw_token_table(
446
+ self.render_settings.n_top_tokens,
447
+ 0,
448
+ representation,
449
+ predecessor,
450
+ )
451
+ if container_token_dynamics is not None:
452
+ with container_token_dynamics:
453
+ self.draw_token_dynamics(representation - predecessor, node.get_predecessor_block_name())
454
+
455
+ def draw_attention_dynamics(self, node: UiGraphNode, head: Optional[int]):
456
+ block_name = node.get_head_name(head)
457
+ block_output = (
458
+ self.stateful_model.attention_output_per_head(B0, node.layer, node.token, head)
459
+ if head is not None
460
+ else self.stateful_model.attention_output(B0, node.layer, node.token)
461
+ )
462
+ self.draw_token_dynamics(block_output, block_name)
463
+
464
+ def draw_ffn_dynamics(self, node: UiGraphNode, neuron: Optional[int]):
465
+ block_name = node.get_neuron_name(neuron)
466
+ block_output = (
467
+ self.stateful_model.neuron_output(node.layer, neuron)
468
+ if neuron is not None
469
+ else self.stateful_model.ffn_out(node.layer)[B0][node.token]
470
+ )
471
+ self.draw_token_dynamics(block_output, block_name)
472
+
473
+ def draw_precision_controls(self, device: str) -> Tuple[torch.dtype, bool]:
474
+ """
475
+ Draw fp16/fp32 switch and AMP control.
476
+
477
+ return: The selected precision and whether AMP should be enabled.
478
+ """
479
+
480
+ if device == "cpu":
481
+ dtype = torch.float32
482
+ else:
483
+ dtype = st.selectbox(
484
+ "Precision",
485
+ [torch.float16, torch.bfloat16, torch.float32],
486
+ index=0,
487
+ )
488
+
489
+ amp_enabled = dtype != torch.float32
490
+
491
+ return dtype, amp_enabled
492
+
493
+ def draw_controls(self):
494
+ # model_container, data_container = st.columns([1, 1])
495
+ with st.sidebar.expander("Model", expanded=True):
496
+ list_of_devices = possible_devices()
497
+ if len(list_of_devices) > 1:
498
+ self.device = st.selectbox(
499
+ "Device",
500
+ possible_devices(),
501
+ index=0,
502
+ )
503
+ else:
504
+ self.device = list_of_devices[0]
505
+
506
+ self.dtype, self.amp_enabled = self.draw_precision_controls(self.device)
507
+
508
+ model_list = list(self._config.models)
509
+ default_choice = model_list.index(self._config.default_model)
510
+
511
+ self.model_name = st.selectbox(
512
+ "Model",
513
+ model_list,
514
+ index=default_choice,
515
+ )
516
+
517
+ if self.model_name:
518
+ self._stateful_model = load_model(
519
+ model_name=self.model_name,
520
+ _model_path=self._config.models[self.model_name],
521
+ _device=self.device,
522
+ _dtype=self.dtype,
523
+ )
524
+ self.model_key = self.model_name # TODO maybe something else?
525
+ self.draw_model_info()
526
+
527
+ self.sentence = self.draw_dataset_selection()
528
+
529
+ with st.sidebar.expander("Graph", expanded=True):
530
+ self._contribution_threshold = st.slider(
531
+ min_value=0.01,
532
+ max_value=0.1,
533
+ step=0.01,
534
+ value=0.04,
535
+ format=r"%.3f",
536
+ label="Contribution threshold",
537
+ )
538
+ self._renormalize_after_threshold = st.checkbox("Renormalize after threshold", value=True)
539
+ self._normalize_before_unembedding = st.checkbox("Normalize before unembedding", value=True)
540
+
541
+ def run_inference(self):
542
+
543
+ with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
544
+ self._stateful_model = cached_run_inference_and_populate_state(self.stateful_model, [self.sentence])
545
+
546
+ with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
547
+ self._graph = get_contribution_graph(
548
+ self.stateful_model,
549
+ self.model_key,
550
+ self.stateful_model.tokens()[B0].tolist(),
551
+ (self._contribution_threshold if self._renormalize_after_threshold else 0.0),
552
+ )
553
+
554
+ def draw_graph_and_selection(
555
+ self,
556
+ ) -> None:
557
+ (
558
+ container_graph,
559
+ container_tokens,
560
+ ) = st.columns(self.render_settings.column_proportions)
561
+
562
+ container_graph_left, container_graph_right = container_graph.columns([5, 1])
563
+
564
+ container_graph_left.write('##### Graph')
565
+ heads_placeholder = container_graph_right.empty()
566
+ heads_placeholder.write('##### Blocks')
567
+ container_graph_right_used = False
568
+
569
+ container_top_tokens, container_token_dynamics = container_tokens.columns([1, 1])
570
+ container_top_tokens.write('##### Top Tokens')
571
+ container_top_tokens_used = False
572
+ container_token_dynamics.write('##### Promoted Tokens')
573
+ container_token_dynamics_used = False
574
+
575
+ try:
576
+
577
+ if self.sentence is None:
578
+ return
579
+
580
+ with container_graph_left:
581
+ selection = self.draw_graph(self._contribution_threshold if not self._renormalize_after_threshold else 0.0)
582
+
583
+ if selection is None:
584
+ return
585
+
586
+ node = selection.node
587
+ edge = selection.edge
588
+
589
+ if edge is not None and edge.target.type == NodeType.AFTER_ATTN:
590
+ with container_graph_right:
591
+ container_graph_right_used = True
592
+ heads_placeholder.write('##### Heads')
593
+ head = self.draw_attn_info(edge, container_graph)
594
+ with container_token_dynamics:
595
+ self.draw_attention_dynamics(edge.target, head)
596
+ container_token_dynamics_used = True
597
+ elif node is not None and node.type == NodeType.FFN:
598
+ with container_graph_right:
599
+ container_graph_right_used = True
600
+ heads_placeholder.write('##### Neurons')
601
+ neuron = self.draw_ffn_info(node)
602
+ with container_token_dynamics:
603
+ self.draw_ffn_dynamics(node, neuron)
604
+ container_token_dynamics_used = True
605
+
606
+ if node is not None and node.is_in_residual_stream():
607
+ self.draw_top_tokens(
608
+ node,
609
+ container_top_tokens,
610
+ container_token_dynamics if not container_token_dynamics_used else None,
611
+ )
612
+ container_top_tokens_used = True
613
+ container_token_dynamics_used = True
614
+ finally:
615
+ if not container_graph_right_used:
616
+ st_placeholder('Click on an edge to see head contributions. \n\n'
617
+ 'Or click on FFN to see individual neuron contributions.', container_graph_right, height=1100)
618
+ if not container_top_tokens_used:
619
+ st_placeholder('Select a node from residual stream to see its top tokens.', container_top_tokens, height=1100)
620
+ if not container_token_dynamics_used:
621
+ st_placeholder('Select a node to see its promoted tokens.', container_token_dynamics, height=1100)
622
+
623
+
624
+ def run(self):
625
+
626
+ with st.sidebar.expander("About", expanded=True):
627
+ if self._config.demo_mode:
628
+ st.caption("""
629
+ The app is deployed in Demo Mode, so only predefined models and inputs are available.\n
630
+ You can still install the app locally and use your own models and inputs.\n
631
+ See https://github.com/facebookresearch/llm-transparency-tool for more information.
632
+ """)
633
+
634
+ self.draw_controls()
635
+
636
+ if not self.model_name:
637
+ st.warning("No model selected")
638
+ st.stop()
639
+
640
+ if self.sentence is None:
641
+ st.warning("No sentence selected")
642
+ else:
643
+ with torch.inference_mode():
644
+ self.run_inference()
645
+
646
+ self.draw_graph_and_selection()
647
+
648
+
649
+ if __name__ == "__main__":
650
+ top_parser = argparse.ArgumentParser()
651
+ top_parser.add_argument("config_file")
652
+ args = top_parser.parse_args()
653
+
654
+ parser = HfArgumentParser([LlmViewerConfig])
655
+ config = parser.parse_json_file(args.config_file)[0]
656
+
657
+ with SystemMonitor(config.debug) as prof:
658
+ app = App(config)
659
+ app.run()
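
The __main__ block above expects a single positional argument: a JSON file that HfArgumentParser maps onto LlmViewerConfig. Below is a minimal sketch of such a file; the model list, dataset file, output filename and launch command are assumptions for illustration, not the repository's shipped configuration.

import json

config = {
    "debug": False,
    "demo_mode": False,
    "allow_loading_dataset_files": True,
    "preloaded_dataset_filename": "sample_input.txt",
    "max_user_string_length": 100,
    "models": {"gpt2": None},  # None -> let the tool download the model from HuggingFace
    "default_model": "gpt2",
}

with open("my_config.json", "w") as f:
    json.dump(config, f, indent=2)

# Assumed launch command (Streamlit passes everything after "--" to the script):
#   streamlit run llm_transparency_tool/server/app.py -- my_config.json
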
llm_transparency_tool/server/graph_selection.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Any, Dict, Optional
9
+
10
+ from llm_transparency_tool.routes.graph_node import GraphNode, NodeType
11
+
12
+
13
+ class UiGraphNode(GraphNode):
14
+ @staticmethod
15
+ def from_json(json: Dict[str, Any]) -> Optional["UiGraphNode"]:
16
+ try:
17
+ layer = json["cell"]["layer"]
18
+ token = json["cell"]["token"]
19
+ type = NodeType(json["item"])
20
+ return UiGraphNode(layer, token, type)
21
+ except (TypeError, KeyError):
22
+ return None
23
+
24
+
25
+ @dataclass
26
+ class UiGraphEdge:
27
+ source: UiGraphNode
28
+ target: UiGraphNode
29
+ weight: float
30
+
31
+ @staticmethod
32
+ def from_json(json: Dict[str, Any]) -> Optional["UiGraphEdge"]:
33
+ try:
34
+ source = UiGraphNode.from_json(json["from"])
35
+ target = UiGraphNode.from_json(json["to"])
36
+ if source is None or target is None:
37
+ return None
38
+ weight = float(json["weight"])
39
+ return UiGraphEdge(source, target, weight)
40
+ except (TypeError, KeyError):
41
+ return None
42
+
43
+
44
+ @dataclass
45
+ class GraphSelection:
46
+ node: Optional[UiGraphNode]
47
+ edge: Optional[UiGraphEdge]
48
+
49
+ @staticmethod
50
+ def from_json(json: Dict[str, Any]) -> Optional["GraphSelection"]:
51
+ try:
52
+ node = UiGraphNode.from_json(json["node"])
53
+ edge = UiGraphEdge.from_json(json["edge"])
54
+ return GraphSelection(node, edge)
55
+ except (TypeError, KeyError):
56
+ return None
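
As a sketch of the behaviour above: every from_json swallows TypeError/KeyError and returns None rather than raising, so a partially filled payload from the frontend degrades gracefully. The concrete strings accepted by NodeType live in graph_node.py and are not shown here.

from llm_transparency_tool.server.graph_selection import GraphSelection

# Selection payload with neither a node nor an edge selected.
selection = GraphSelection.from_json({"node": None, "edge": None})
assert selection == GraphSelection(node=None, edge=None)

# A malformed payload simply yields no selection at all.
assert GraphSelection.from_json("not a dict") is None
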
llm_transparency_tool/server/monitor.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import streamlit as st
9
+ from pyinstrument import Profiler
10
+ from typing import Dict
11
+ import pandas as pd
12
+
13
+
14
+ @st.cache_resource(max_entries=1, show_spinner=False)
15
+ def init_gpu_memory():
16
+ """
17
+ When CUDA is initialized, it occupies some memory on the GPU. This overhead
18
+ can sometimes make it difficult to understand how much memory is actually used
19
+ by the model.
20
+
21
+ This function is used to initialize CUDA and measure the overhead.
22
+ """
23
+ if not torch.cuda.is_available():
24
+ return {}
25
+
26
+ # Initialize CUDA on each device for a moment and record the memory overhead.
27
+ gpu_memory_overhead = {}
28
+ for i in range(torch.cuda.device_count()):
29
+ torch.ones(1).cuda(i)
30
+ free, total = torch.cuda.mem_get_info(i)
31
+ occupied = total - free
32
+ gpu_memory_overhead[i] = occupied
33
+
34
+ return gpu_memory_overhead
35
+
36
+
37
+ class SystemMonitor:
38
+ """
39
+ This class is used to monitor the system resources such as GPU memory and CPU
40
+ usage. It uses the pyinstrument library to profile the code and measure the
41
+ execution time of different parts of the code.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ enabled: bool = False,
47
+ ):
48
+ self.enabled = enabled
49
+ self.profiler = Profiler()
50
+ self.overhead: Dict[int, int]
51
+
52
+ def __enter__(self):
53
+ if not self.enabled:
54
+ return
55
+
56
+ self.overhead = init_gpu_memory()
57
+
58
+ self.profiler.__enter__()
59
+
60
+ def __exit__(self, exc_type, exc_value, traceback):
61
+ if not self.enabled:
62
+ return
63
+
64
+ self.profiler.__exit__(exc_type, exc_value, traceback)
65
+
66
+ self.report_gpu_usage()
67
+ self.report_profiler()
68
+
69
+ with st.expander("Session state"):
70
+ st.write(st.session_state)
71
+
72
+ return None
73
+
74
+ def report_gpu_usage(self):
75
+
76
+ if not torch.cuda.is_available():
77
+ return
78
+
79
+ data = []
80
+
81
+ for i in range(torch.cuda.device_count()):
82
+ free, total = torch.cuda.mem_get_info(i)
83
+ occupied = total - free
84
+ data.append({
85
+ 'overhead': self.overhead[i],
86
+ 'occupied': occupied - self.overhead[i],
87
+ 'free': free,
88
+ })
89
+ df = pd.DataFrame(data, columns=["overhead", "occupied", "free"])
90
+
91
+ with st.sidebar.expander("System"):
92
+ st.write("GPU memory on server")
93
+ df /= 1024 ** 3 # Convert to GB
94
+ st.bar_chart(df, width=200, height=200, color=["#fefefe", "#84c9ff", "#fe2b2b"])
95
+
96
+ def report_profiler(self):
97
+ html_code = self.profiler.output_html()
98
+ with st.expander("Profiler", expanded=False):
99
+ st.components.v1.html(html_code, height=1000, scrolling=True)
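
A sketch of how the overhead recorded by init_gpu_memory is meant to be used; it mirrors the subtraction done in report_gpu_usage, and the device index here is an assumption.

import torch

from llm_transparency_tool.server.monitor import init_gpu_memory

overhead = init_gpu_memory()  # bytes taken by the bare CUDA context, per device
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info(0)
    model_bytes = (total - free) - overhead[0]
    print(f"memory attributable to tensors on cuda:0: {model_bytes / 1024 ** 3:.2f} GiB")
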
llm_transparency_tool/server/styles.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+
9
+ import matplotlib
10
+
11
+ # Unofficial way to make the padding a bit smaller.
12
+ margins_css = """
13
+ <style>
14
+ .main > div {
15
+ padding: 1rem;
16
+ padding-top: 2rem; /* Still need this gap for the top bar */
17
+ gap: 0rem;
18
+ }
19
+
20
+ section[data-testid="stSidebar"] {
21
+ width: 300px !important; /* Set the width to your desired value */
22
+ }
23
+ </style>
24
+ """
25
+
26
+
27
+ @dataclass
28
+ class RenderSettings:
29
+ column_proportions = [50, 30]
30
+
31
+ # We don't know the actual height. This will be used in order to compute the table
32
+ # viewport height when needed.
33
+ table_cell_height = 36
34
+
35
+ n_top_tokens = 30
36
+ n_promoted_tokens = 15
37
+ n_suppressed_tokens = 15
38
+
39
+ n_top_neurons = 20
40
+
41
+ attention_color_map = "Blues"
42
+
43
+ no_model_alt_text = "<no model selected>"
44
+
45
+
46
+ def string_to_display(s: str) -> str:
47
+ return s.replace(" ", "·")
48
+
49
+
50
+ def logits_color_map(positive_and_negative: bool) -> matplotlib.colors.Colormap:
51
+ background_colors = {
52
+ "red": [
53
+ [0.0, 0.40, 0.40],
54
+ [0.1, 0.69, 0.69],
55
+ [0.2, 0.83, 0.83],
56
+ [0.3, 0.95, 0.95],
57
+ [0.4, 0.99, 0.99],
58
+ [0.5, 1.0, 1.0],
59
+ [0.6, 0.90, 0.90],
60
+ [0.7, 0.72, 0.72],
61
+ [0.8, 0.49, 0.49],
62
+ [0.9, 0.30, 0.30],
63
+ [1.0, 0.15, 0.15],
64
+ ],
65
+ "green": [
66
+ [0.0, 0.0, 0.0],
67
+ [0.1, 0.09, 0.09],
68
+ [0.2, 0.37, 0.37],
69
+ [0.3, 0.64, 0.64],
70
+ [0.4, 0.85, 0.85],
71
+ [0.5, 1.0, 1.0],
72
+ [0.6, 0.96, 0.96],
73
+ [0.7, 0.88, 0.88],
74
+ [0.8, 0.73, 0.73],
75
+ [0.9, 0.57, 0.57],
76
+ [1.0, 0.39, 0.39],
77
+ ],
78
+ "blue": [
79
+ [0.0, 0.12, 0.12],
80
+ [0.1, 0.16, 0.16],
81
+ [0.2, 0.30, 0.30],
82
+ [0.3, 0.50, 0.50],
83
+ [0.4, 0.78, 0.78],
84
+ [0.5, 1.0, 1.0],
85
+ [0.6, 0.81, 0.81],
86
+ [0.7, 0.52, 0.52],
87
+ [0.8, 0.25, 0.25],
88
+ [0.9, 0.12, 0.12],
89
+ [1.0, 0.09, 0.09],
90
+ ],
91
+ }
92
+
93
+ if not positive_and_negative:
94
+ # Stretch the top part to the whole range
95
+ new_colors = {}
96
+ for channel, colors in background_colors.items():
97
+ new_colors[channel] = [
98
+ [(value - 0.5) * 2, color, color]
99
+ for value, color, _ in colors
100
+ if value >= 0.5
101
+ ]
102
+ background_colors = new_colors
103
+
104
+ return matplotlib.colors.LinearSegmentedColormap(
105
+ f"RdYG-{positive_and_negative}",
106
+ background_colors,
107
+ )
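
The colormap above is consumed through pandas styling, as _draw_token_table in app.py does. A small sketch with made-up scores:

import pandas as pd

from llm_transparency_tool.server.styles import logits_color_map

df = pd.DataFrame({"Score": [9.1, 7.4, 0.2, -3.8]})
styled = df.style.background_gradient(
    axis=0,
    cmap=logits_color_map(positive_and_negative=True),  # red for suppressed, green for promoted
).format(precision=3)
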
llm_transparency_tool/server/utils.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import uuid
8
+ from typing import List, Optional, Tuple
9
+
10
+ import networkx as nx
11
+ import streamlit as st
12
+ import torch
13
+ import transformers
14
+
15
+ import llm_transparency_tool.routes.graph
16
+ from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
17
+ from llm_transparency_tool.models.transparent_llm import TransparentLlm
18
+
19
+ GPU = "gpu"
20
+ CPU = "cpu"
21
+
22
+ # This variable is for expressing the idea that batch_id = 0, but make it more
23
+ # readable than just 0.
24
+ B0 = 0
25
+
26
+
27
+ def possible_devices() -> List[str]:
28
+ devices = []
29
+ if torch.cuda.is_available():
30
+ devices.append("gpu")
31
+ devices.append("cpu")
32
+ return devices
33
+
34
+
35
+ def load_dataset(filename) -> List[str]:
36
+ with open(filename) as f:
37
+ dataset = [s.strip("\n") for s in f.readlines()]
38
+ print(f"Loaded {len(dataset)} sentences from {filename}")
39
+ return dataset
40
+
41
+
42
+ @st.cache_resource(
43
+ hash_funcs={
44
+ TransformerLensTransparentLlm: id
45
+ }
46
+ )
47
+ def load_model(
48
+ model_name: str,
49
+ _device: str,
50
+ _model_path: Optional[str] = None,
51
+ _dtype: torch.dtype = torch.float32,
52
+ ) -> TransparentLlm:
53
+ """
54
+ Returns the loaded model. The model name can serve as a key: a unique string
55
+ which can be used later to identify if the model has changed.
56
+ """
57
+ assert _device in possible_devices()
58
+
59
+ causal_lm = None
60
+ tokenizer = None
61
+
62
+ tl_lm = TransformerLensTransparentLlm(
63
+ model_name=model_name,
64
+ hf_model=causal_lm,
65
+ tokenizer=tokenizer,
66
+ device=_device,
67
+ dtype=_dtype,
68
+ )
69
+
70
+ return tl_lm
71
+
72
+
73
+ def run_model(model: TransparentLlm, sentence: str) -> None:
74
+ print(f"Running inference for '{sentence}'")
75
+ model.run([sentence])
76
+
77
+
78
+ def load_model_with_session_caching(
79
+ **kwargs,
80
+ ) -> TransparentLlm:
81
+ return load_model(**kwargs)
82
+
83
+ def run_model_with_session_caching(
84
+ _model: TransparentLlm,
85
+ model_key: str,
86
+ sentence: str,
87
+ ):
88
+ LAST_RUN_MODEL_KEY = "last_run_model_key"
89
+ LAST_RUN_SENTENCE = "last_run_sentence"
90
+ state = st.session_state
91
+
92
+ if (
93
+ state.get(LAST_RUN_MODEL_KEY, None) == model_key
94
+ and state.get(LAST_RUN_SENTENCE, None) == sentence
95
+ ):
96
+ return
97
+
98
+ run_model(_model, sentence)
99
+ state[LAST_RUN_MODEL_KEY] = model_key
100
+ state[LAST_RUN_SENTENCE] = sentence
101
+
102
+
103
+ @st.cache_resource(
104
+ hash_funcs={
105
+ TransformerLensTransparentLlm: id
106
+ }
107
+ )
108
+ def get_contribution_graph(
109
+ model: TransparentLlm, # TODO bug here
110
+ model_key: str,
111
+ tokens: List[str],
112
+ threshold: float,
113
+ ) -> nx.Graph:
114
+ """
115
+ The `model_key` and `tokens` are used only for caching. The model itself is not
116
+ hashed: `hash_funcs` above tells `st.cache_resource` to hash it by `id`.
117
+ """
118
+ return llm_transparency_tool.routes.graph.build_full_graph(
119
+ model,
120
+ B0,
121
+ threshold,
122
+ )
123
+
124
+
125
+ def st_placeholder(
126
+ text: str,
127
+ container=st,
128
+ border: bool = True,
129
+ height: Optional[int] = 500,
130
+ ):
131
+ empty = container.empty()
132
+ empty.container(border=border, height=height).write(f'<small>{text}</small>', unsafe_allow_html=True)
133
+ return empty
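
A sketch tying load_dataset to the bundled sample_input.txt: each line of the file becomes one prompt offered in the sentence selector.

from llm_transparency_tool.server.utils import load_dataset

sentences = load_dataset("sample_input.txt")
print(sentences[0])  # "The war lasted from the year 1732 to the year 17"
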
pyproject.toml ADDED
@@ -0,0 +1,2 @@
1
+ [tool.black]
2
+ line-length = 120
sample_input.txt ADDED
@@ -0,0 +1,3 @@
1
+ The war lasted from the year 1732 to the year 17
2
+ 5 + 4 = 9, 2 + 3 =
3
+ When Mary and John went to the store, John gave a drink to
setup.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from setuptools import setup
8
+
9
+ setup(
10
+ name="llm_transparency_tool",
11
+ version="0.1",
12
+ packages=["llm_transparency_tool"],
13
+ )