coolmanx commited on
Commit
76aa260
1 Parent(s): 014a66e
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +18 -18
  2. .env.example +12 -12
  3. .eslintignore +13 -13
  4. .eslintrc.cjs +31 -31
  5. .gitattributes +2 -2
  6. .gitignore +309 -309
  7. .npmrc +1 -1
  8. .prettierignore +315 -315
  9. .prettierrc +9 -9
  10. CHANGELOG.md +0 -0
  11. CODE_OF_CONDUCT.md +77 -77
  12. Caddyfile.localhost +64 -64
  13. Dockerfile +176 -176
  14. INSTALLATION.md +35 -35
  15. LICENSE +21 -21
  16. Makefile +33 -33
  17. README.md +231 -231
  18. TROUBLESHOOTING.md +36 -36
  19. backend/.dockerignore +13 -13
  20. backend/.gitignore +11 -11
  21. backend/open_webui/__init__.py +77 -77
  22. backend/open_webui/alembic.ini +114 -114
  23. backend/open_webui/apps/audio/main.py +639 -639
  24. backend/open_webui/apps/images/main.py +597 -597
  25. backend/open_webui/apps/images/utils/comfyui.py +186 -186
  26. backend/open_webui/apps/ollama/main.py +1121 -1120
  27. backend/open_webui/apps/openai/main.py +554 -554
  28. backend/open_webui/apps/retrieval/loaders/main.py +190 -190
  29. backend/open_webui/apps/retrieval/main.py +1326 -1326
  30. backend/open_webui/apps/retrieval/models/colbert.py +81 -81
  31. backend/open_webui/apps/retrieval/utils.py +573 -573
  32. backend/open_webui/apps/retrieval/vector/connector.py +14 -14
  33. backend/open_webui/apps/retrieval/vector/dbs/chroma.py +161 -161
  34. backend/open_webui/apps/retrieval/vector/dbs/milvus.py +286 -286
  35. backend/open_webui/apps/retrieval/vector/dbs/qdrant.py +179 -179
  36. backend/open_webui/apps/retrieval/vector/main.py +19 -19
  37. backend/open_webui/apps/retrieval/web/brave.py +42 -42
  38. backend/open_webui/apps/retrieval/web/duckduckgo.py +50 -50
  39. backend/open_webui/apps/retrieval/web/google_pse.py +50 -50
  40. backend/open_webui/apps/retrieval/web/jina_search.py +41 -41
  41. backend/open_webui/apps/retrieval/web/main.py +22 -22
  42. backend/open_webui/apps/retrieval/web/searchapi.py +48 -48
  43. backend/open_webui/apps/retrieval/web/searxng.py +91 -91
  44. backend/open_webui/apps/retrieval/web/serper.py +43 -43
  45. backend/open_webui/apps/retrieval/web/serply.py +69 -69
  46. backend/open_webui/apps/retrieval/web/serpstack.py +48 -48
  47. backend/open_webui/apps/retrieval/web/tavily.py +38 -38
  48. backend/open_webui/apps/retrieval/web/testdata/brave.json +0 -0
  49. backend/open_webui/apps/retrieval/web/testdata/google_pse.json +442 -442
  50. backend/open_webui/apps/retrieval/web/testdata/searchapi.json +0 -0
.dockerignore CHANGED
@@ -1,19 +1,19 @@
1
- .github
2
- .DS_Store
3
- docs
4
- kubernetes
5
- node_modules
6
- /.svelte-kit
7
- /package
8
- .env
9
- .env.*
10
- vite.config.js.timestamp-*
11
- vite.config.ts.timestamp-*
12
- __pycache__
13
- .idea
14
- venv
15
- _old
16
- uploads
17
- .ipynb_checkpoints
18
- **/*.db
19
  _test
 
1
+ .github
2
+ .DS_Store
3
+ docs
4
+ kubernetes
5
+ node_modules
6
+ /.svelte-kit
7
+ /package
8
+ .env
9
+ .env.*
10
+ vite.config.js.timestamp-*
11
+ vite.config.ts.timestamp-*
12
+ __pycache__
13
+ .idea
14
+ venv
15
+ _old
16
+ uploads
17
+ .ipynb_checkpoints
18
+ **/*.db
19
  _test
.env.example CHANGED
@@ -1,13 +1,13 @@
1
- # Ollama URL for the backend to connect
2
- # The path '/ollama' will be redirected to the specified backend URL
3
- OLLAMA_BASE_URL='http://localhost:11434'
4
-
5
- OPENAI_API_BASE_URL=''
6
- OPENAI_API_KEY=''
7
-
8
- # AUTOMATIC1111_BASE_URL="http://localhost:7860"
9
-
10
- # DO NOT TRACK
11
- SCARF_NO_ANALYTICS=true
12
- DO_NOT_TRACK=true
13
  ANONYMIZED_TELEMETRY=false
 
1
+ # Ollama URL for the backend to connect
2
+ # The path '/ollama' will be redirected to the specified backend URL
3
+ OLLAMA_BASE_URL='http://localhost:11434'
4
+
5
+ OPENAI_API_BASE_URL=''
6
+ OPENAI_API_KEY=''
7
+
8
+ # AUTOMATIC1111_BASE_URL="http://localhost:7860"
9
+
10
+ # DO NOT TRACK
11
+ SCARF_NO_ANALYTICS=true
12
+ DO_NOT_TRACK=true
13
  ANONYMIZED_TELEMETRY=false
.eslintignore CHANGED
@@ -1,13 +1,13 @@
1
- .DS_Store
2
- node_modules
3
- /build
4
- /.svelte-kit
5
- /package
6
- .env
7
- .env.*
8
- !.env.example
9
-
10
- # Ignore files for PNPM, NPM and YARN
11
- pnpm-lock.yaml
12
- package-lock.json
13
- yarn.lock
 
1
+ .DS_Store
2
+ node_modules
3
+ /build
4
+ /.svelte-kit
5
+ /package
6
+ .env
7
+ .env.*
8
+ !.env.example
9
+
10
+ # Ignore files for PNPM, NPM and YARN
11
+ pnpm-lock.yaml
12
+ package-lock.json
13
+ yarn.lock
.eslintrc.cjs CHANGED
@@ -1,31 +1,31 @@
1
- module.exports = {
2
- root: true,
3
- extends: [
4
- 'eslint:recommended',
5
- 'plugin:@typescript-eslint/recommended',
6
- 'plugin:svelte/recommended',
7
- 'plugin:cypress/recommended',
8
- 'prettier'
9
- ],
10
- parser: '@typescript-eslint/parser',
11
- plugins: ['@typescript-eslint'],
12
- parserOptions: {
13
- sourceType: 'module',
14
- ecmaVersion: 2020,
15
- extraFileExtensions: ['.svelte']
16
- },
17
- env: {
18
- browser: true,
19
- es2017: true,
20
- node: true
21
- },
22
- overrides: [
23
- {
24
- files: ['*.svelte'],
25
- parser: 'svelte-eslint-parser',
26
- parserOptions: {
27
- parser: '@typescript-eslint/parser'
28
- }
29
- }
30
- ]
31
- };
 
1
+ module.exports = {
2
+ root: true,
3
+ extends: [
4
+ 'eslint:recommended',
5
+ 'plugin:@typescript-eslint/recommended',
6
+ 'plugin:svelte/recommended',
7
+ 'plugin:cypress/recommended',
8
+ 'prettier'
9
+ ],
10
+ parser: '@typescript-eslint/parser',
11
+ plugins: ['@typescript-eslint'],
12
+ parserOptions: {
13
+ sourceType: 'module',
14
+ ecmaVersion: 2020,
15
+ extraFileExtensions: ['.svelte']
16
+ },
17
+ env: {
18
+ browser: true,
19
+ es2017: true,
20
+ node: true
21
+ },
22
+ overrides: [
23
+ {
24
+ files: ['*.svelte'],
25
+ parser: 'svelte-eslint-parser',
26
+ parserOptions: {
27
+ parser: '@typescript-eslint/parser'
28
+ }
29
+ }
30
+ ]
31
+ };
.gitattributes CHANGED
@@ -1,2 +1,2 @@
1
- *.sh text eol=lf
2
- *.ttf filter=lfs diff=lfs merge=lfs -text
 
1
+ *.sh text eol=lf
2
+ *.ttf filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,309 +1,309 @@
1
- .DS_Store
2
- node_modules
3
- /build
4
- /.svelte-kit
5
- /package
6
- .env
7
- .env.*
8
- !.env.example
9
- vite.config.js.timestamp-*
10
- vite.config.ts.timestamp-*
11
- # Byte-compiled / optimized / DLL files
12
- __pycache__/
13
- *.py[cod]
14
- *$py.class
15
-
16
- # C extensions
17
- *.so
18
-
19
- # Pyodide distribution
20
- static/pyodide/*
21
- !static/pyodide/pyodide-lock.json
22
-
23
- # Distribution / packaging
24
- .Python
25
- build/
26
- develop-eggs/
27
- dist/
28
- downloads/
29
- eggs/
30
- .eggs/
31
- lib64/
32
- parts/
33
- sdist/
34
- var/
35
- wheels/
36
- share/python-wheels/
37
- *.egg-info/
38
- .installed.cfg
39
- *.egg
40
- MANIFEST
41
-
42
- # PyInstaller
43
- # Usually these files are written by a python script from a template
44
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
45
- *.manifest
46
- *.spec
47
-
48
- # Installer logs
49
- pip-log.txt
50
- pip-delete-this-directory.txt
51
-
52
- # Unit test / coverage reports
53
- htmlcov/
54
- .tox/
55
- .nox/
56
- .coverage
57
- .coverage.*
58
- .cache
59
- nosetests.xml
60
- coverage.xml
61
- *.cover
62
- *.py,cover
63
- .hypothesis/
64
- .pytest_cache/
65
- cover/
66
-
67
- # Translations
68
- *.mo
69
- *.pot
70
-
71
- # Django stuff:
72
- *.log
73
- local_settings.py
74
- db.sqlite3
75
- db.sqlite3-journal
76
-
77
- # Flask stuff:
78
- instance/
79
- .webassets-cache
80
-
81
- # Scrapy stuff:
82
- .scrapy
83
-
84
- # Sphinx documentation
85
- docs/_build/
86
-
87
- # PyBuilder
88
- .pybuilder/
89
- target/
90
-
91
- # Jupyter Notebook
92
- .ipynb_checkpoints
93
-
94
- # IPython
95
- profile_default/
96
- ipython_config.py
97
-
98
- # pyenv
99
- # For a library or package, you might want to ignore these files since the code is
100
- # intended to run in multiple environments; otherwise, check them in:
101
- # .python-version
102
-
103
- # pipenv
104
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
105
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
106
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
107
- # install all needed dependencies.
108
- #Pipfile.lock
109
-
110
- # poetry
111
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
112
- # This is especially recommended for binary packages to ensure reproducibility, and is more
113
- # commonly ignored for libraries.
114
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
115
- #poetry.lock
116
-
117
- # pdm
118
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
119
- #pdm.lock
120
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
121
- # in version control.
122
- # https://pdm.fming.dev/#use-with-ide
123
- .pdm.toml
124
-
125
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
126
- __pypackages__/
127
-
128
- # Celery stuff
129
- celerybeat-schedule
130
- celerybeat.pid
131
-
132
- # SageMath parsed files
133
- *.sage.py
134
-
135
- # Environments
136
- .env
137
- .venv
138
- env/
139
- venv/
140
- ENV/
141
- env.bak/
142
- venv.bak/
143
-
144
- # Spyder project settings
145
- .spyderproject
146
- .spyproject
147
-
148
- # Rope project settings
149
- .ropeproject
150
-
151
- # mkdocs documentation
152
- /site
153
-
154
- # mypy
155
- .mypy_cache/
156
- .dmypy.json
157
- dmypy.json
158
-
159
- # Pyre type checker
160
- .pyre/
161
-
162
- # pytype static type analyzer
163
- .pytype/
164
-
165
- # Cython debug symbols
166
- cython_debug/
167
-
168
- # PyCharm
169
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
170
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
171
- # and can be added to the global gitignore or merged into this file. For a more nuclear
172
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
173
- .idea/
174
-
175
- # Logs
176
- logs
177
- *.log
178
- npm-debug.log*
179
- yarn-debug.log*
180
- yarn-error.log*
181
- lerna-debug.log*
182
- .pnpm-debug.log*
183
-
184
- # Diagnostic reports (https://nodejs.org/api/report.html)
185
- report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
186
-
187
- # Runtime data
188
- pids
189
- *.pid
190
- *.seed
191
- *.pid.lock
192
-
193
- # Directory for instrumented libs generated by jscoverage/JSCover
194
- lib-cov
195
-
196
- # Coverage directory used by tools like istanbul
197
- coverage
198
- *.lcov
199
-
200
- # nyc test coverage
201
- .nyc_output
202
-
203
- # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
204
- .grunt
205
-
206
- # Bower dependency directory (https://bower.io/)
207
- bower_components
208
-
209
- # node-waf configuration
210
- .lock-wscript
211
-
212
- # Compiled binary addons (https://nodejs.org/api/addons.html)
213
- build/Release
214
-
215
- # Dependency directories
216
- node_modules/
217
- jspm_packages/
218
-
219
- # Snowpack dependency directory (https://snowpack.dev/)
220
- web_modules/
221
-
222
- # TypeScript cache
223
- *.tsbuildinfo
224
-
225
- # Optional npm cache directory
226
- .npm
227
-
228
- # Optional eslint cache
229
- .eslintcache
230
-
231
- # Optional stylelint cache
232
- .stylelintcache
233
-
234
- # Microbundle cache
235
- .rpt2_cache/
236
- .rts2_cache_cjs/
237
- .rts2_cache_es/
238
- .rts2_cache_umd/
239
-
240
- # Optional REPL history
241
- .node_repl_history
242
-
243
- # Output of 'npm pack'
244
- *.tgz
245
-
246
- # Yarn Integrity file
247
- .yarn-integrity
248
-
249
- # dotenv environment variable files
250
- .env
251
- .env.development.local
252
- .env.test.local
253
- .env.production.local
254
- .env.local
255
-
256
- # parcel-bundler cache (https://parceljs.org/)
257
- .cache
258
- .parcel-cache
259
-
260
- # Next.js build output
261
- .next
262
- out
263
-
264
- # Nuxt.js build / generate output
265
- .nuxt
266
- dist
267
-
268
- # Gatsby files
269
- .cache/
270
- # Comment in the public line in if your project uses Gatsby and not Next.js
271
- # https://nextjs.org/blog/next-9-1#public-directory-support
272
- # public
273
-
274
- # vuepress build output
275
- .vuepress/dist
276
-
277
- # vuepress v2.x temp and cache directory
278
- .temp
279
- .cache
280
-
281
- # Docusaurus cache and generated files
282
- .docusaurus
283
-
284
- # Serverless directories
285
- .serverless/
286
-
287
- # FuseBox cache
288
- .fusebox/
289
-
290
- # DynamoDB Local files
291
- .dynamodb/
292
-
293
- # TernJS port file
294
- .tern-port
295
-
296
- # Stores VSCode versions used for testing VSCode extensions
297
- .vscode-test
298
-
299
- # yarn v2
300
- .yarn/cache
301
- .yarn/unplugged
302
- .yarn/build-state.yml
303
- .yarn/install-state.gz
304
- .pnp.*
305
-
306
- # cypress artifacts
307
- cypress/videos
308
- cypress/screenshots
309
- .vscode/settings.json
 
1
+ .DS_Store
2
+ node_modules
3
+ /build
4
+ /.svelte-kit
5
+ /package
6
+ .env
7
+ .env.*
8
+ !.env.example
9
+ vite.config.js.timestamp-*
10
+ vite.config.ts.timestamp-*
11
+ # Byte-compiled / optimized / DLL files
12
+ __pycache__/
13
+ *.py[cod]
14
+ *$py.class
15
+
16
+ # C extensions
17
+ *.so
18
+
19
+ # Pyodide distribution
20
+ static/pyodide/*
21
+ !static/pyodide/pyodide-lock.json
22
+
23
+ # Distribution / packaging
24
+ .Python
25
+ build/
26
+ develop-eggs/
27
+ dist/
28
+ downloads/
29
+ eggs/
30
+ .eggs/
31
+ lib64/
32
+ parts/
33
+ sdist/
34
+ var/
35
+ wheels/
36
+ share/python-wheels/
37
+ *.egg-info/
38
+ .installed.cfg
39
+ *.egg
40
+ MANIFEST
41
+
42
+ # PyInstaller
43
+ # Usually these files are written by a python script from a template
44
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
45
+ *.manifest
46
+ *.spec
47
+
48
+ # Installer logs
49
+ pip-log.txt
50
+ pip-delete-this-directory.txt
51
+
52
+ # Unit test / coverage reports
53
+ htmlcov/
54
+ .tox/
55
+ .nox/
56
+ .coverage
57
+ .coverage.*
58
+ .cache
59
+ nosetests.xml
60
+ coverage.xml
61
+ *.cover
62
+ *.py,cover
63
+ .hypothesis/
64
+ .pytest_cache/
65
+ cover/
66
+
67
+ # Translations
68
+ *.mo
69
+ *.pot
70
+
71
+ # Django stuff:
72
+ *.log
73
+ local_settings.py
74
+ db.sqlite3
75
+ db.sqlite3-journal
76
+
77
+ # Flask stuff:
78
+ instance/
79
+ .webassets-cache
80
+
81
+ # Scrapy stuff:
82
+ .scrapy
83
+
84
+ # Sphinx documentation
85
+ docs/_build/
86
+
87
+ # PyBuilder
88
+ .pybuilder/
89
+ target/
90
+
91
+ # Jupyter Notebook
92
+ .ipynb_checkpoints
93
+
94
+ # IPython
95
+ profile_default/
96
+ ipython_config.py
97
+
98
+ # pyenv
99
+ # For a library or package, you might want to ignore these files since the code is
100
+ # intended to run in multiple environments; otherwise, check them in:
101
+ # .python-version
102
+
103
+ # pipenv
104
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
105
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
106
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
107
+ # install all needed dependencies.
108
+ #Pipfile.lock
109
+
110
+ # poetry
111
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
112
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
113
+ # commonly ignored for libraries.
114
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
115
+ #poetry.lock
116
+
117
+ # pdm
118
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
119
+ #pdm.lock
120
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
121
+ # in version control.
122
+ # https://pdm.fming.dev/#use-with-ide
123
+ .pdm.toml
124
+
125
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
126
+ __pypackages__/
127
+
128
+ # Celery stuff
129
+ celerybeat-schedule
130
+ celerybeat.pid
131
+
132
+ # SageMath parsed files
133
+ *.sage.py
134
+
135
+ # Environments
136
+ .env
137
+ .venv
138
+ env/
139
+ venv/
140
+ ENV/
141
+ env.bak/
142
+ venv.bak/
143
+
144
+ # Spyder project settings
145
+ .spyderproject
146
+ .spyproject
147
+
148
+ # Rope project settings
149
+ .ropeproject
150
+
151
+ # mkdocs documentation
152
+ /site
153
+
154
+ # mypy
155
+ .mypy_cache/
156
+ .dmypy.json
157
+ dmypy.json
158
+
159
+ # Pyre type checker
160
+ .pyre/
161
+
162
+ # pytype static type analyzer
163
+ .pytype/
164
+
165
+ # Cython debug symbols
166
+ cython_debug/
167
+
168
+ # PyCharm
169
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
170
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
171
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
172
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
173
+ .idea/
174
+
175
+ # Logs
176
+ logs
177
+ *.log
178
+ npm-debug.log*
179
+ yarn-debug.log*
180
+ yarn-error.log*
181
+ lerna-debug.log*
182
+ .pnpm-debug.log*
183
+
184
+ # Diagnostic reports (https://nodejs.org/api/report.html)
185
+ report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
186
+
187
+ # Runtime data
188
+ pids
189
+ *.pid
190
+ *.seed
191
+ *.pid.lock
192
+
193
+ # Directory for instrumented libs generated by jscoverage/JSCover
194
+ lib-cov
195
+
196
+ # Coverage directory used by tools like istanbul
197
+ coverage
198
+ *.lcov
199
+
200
+ # nyc test coverage
201
+ .nyc_output
202
+
203
+ # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
204
+ .grunt
205
+
206
+ # Bower dependency directory (https://bower.io/)
207
+ bower_components
208
+
209
+ # node-waf configuration
210
+ .lock-wscript
211
+
212
+ # Compiled binary addons (https://nodejs.org/api/addons.html)
213
+ build/Release
214
+
215
+ # Dependency directories
216
+ node_modules/
217
+ jspm_packages/
218
+
219
+ # Snowpack dependency directory (https://snowpack.dev/)
220
+ web_modules/
221
+
222
+ # TypeScript cache
223
+ *.tsbuildinfo
224
+
225
+ # Optional npm cache directory
226
+ .npm
227
+
228
+ # Optional eslint cache
229
+ .eslintcache
230
+
231
+ # Optional stylelint cache
232
+ .stylelintcache
233
+
234
+ # Microbundle cache
235
+ .rpt2_cache/
236
+ .rts2_cache_cjs/
237
+ .rts2_cache_es/
238
+ .rts2_cache_umd/
239
+
240
+ # Optional REPL history
241
+ .node_repl_history
242
+
243
+ # Output of 'npm pack'
244
+ *.tgz
245
+
246
+ # Yarn Integrity file
247
+ .yarn-integrity
248
+
249
+ # dotenv environment variable files
250
+ .env
251
+ .env.development.local
252
+ .env.test.local
253
+ .env.production.local
254
+ .env.local
255
+
256
+ # parcel-bundler cache (https://parceljs.org/)
257
+ .cache
258
+ .parcel-cache
259
+
260
+ # Next.js build output
261
+ .next
262
+ out
263
+
264
+ # Nuxt.js build / generate output
265
+ .nuxt
266
+ dist
267
+
268
+ # Gatsby files
269
+ .cache/
270
+ # Comment in the public line in if your project uses Gatsby and not Next.js
271
+ # https://nextjs.org/blog/next-9-1#public-directory-support
272
+ # public
273
+
274
+ # vuepress build output
275
+ .vuepress/dist
276
+
277
+ # vuepress v2.x temp and cache directory
278
+ .temp
279
+ .cache
280
+
281
+ # Docusaurus cache and generated files
282
+ .docusaurus
283
+
284
+ # Serverless directories
285
+ .serverless/
286
+
287
+ # FuseBox cache
288
+ .fusebox/
289
+
290
+ # DynamoDB Local files
291
+ .dynamodb/
292
+
293
+ # TernJS port file
294
+ .tern-port
295
+
296
+ # Stores VSCode versions used for testing VSCode extensions
297
+ .vscode-test
298
+
299
+ # yarn v2
300
+ .yarn/cache
301
+ .yarn/unplugged
302
+ .yarn/build-state.yml
303
+ .yarn/install-state.gz
304
+ .pnp.*
305
+
306
+ # cypress artifacts
307
+ cypress/videos
308
+ cypress/screenshots
309
+ .vscode/settings.json
.npmrc CHANGED
@@ -1 +1 @@
1
- engine-strict=true
 
1
+ engine-strict=true
.prettierignore CHANGED
@@ -1,316 +1,316 @@
1
- # Ignore files for PNPM, NPM and YARN
2
- pnpm-lock.yaml
3
- package-lock.json
4
- yarn.lock
5
-
6
- kubernetes/
7
-
8
- # Copy of .gitignore
9
- .DS_Store
10
- node_modules
11
- /build
12
- /.svelte-kit
13
- /package
14
- .env
15
- .env.*
16
- !.env.example
17
- vite.config.js.timestamp-*
18
- vite.config.ts.timestamp-*
19
- # Byte-compiled / optimized / DLL files
20
- __pycache__/
21
- *.py[cod]
22
- *$py.class
23
-
24
- # C extensions
25
- *.so
26
-
27
- # Distribution / packaging
28
- .Python
29
- build/
30
- develop-eggs/
31
- dist/
32
- downloads/
33
- eggs/
34
- .eggs/
35
- lib64/
36
- parts/
37
- sdist/
38
- var/
39
- wheels/
40
- share/python-wheels/
41
- *.egg-info/
42
- .installed.cfg
43
- *.egg
44
- MANIFEST
45
-
46
- # PyInstaller
47
- # Usually these files are written by a python script from a template
48
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
- *.manifest
50
- *.spec
51
-
52
- # Installer logs
53
- pip-log.txt
54
- pip-delete-this-directory.txt
55
-
56
- # Unit test / coverage reports
57
- htmlcov/
58
- .tox/
59
- .nox/
60
- .coverage
61
- .coverage.*
62
- .cache
63
- nosetests.xml
64
- coverage.xml
65
- *.cover
66
- *.py,cover
67
- .hypothesis/
68
- .pytest_cache/
69
- cover/
70
-
71
- # Translations
72
- *.mo
73
- *.pot
74
-
75
- # Django stuff:
76
- *.log
77
- local_settings.py
78
- db.sqlite3
79
- db.sqlite3-journal
80
-
81
- # Flask stuff:
82
- instance/
83
- .webassets-cache
84
-
85
- # Scrapy stuff:
86
- .scrapy
87
-
88
- # Sphinx documentation
89
- docs/_build/
90
-
91
- # PyBuilder
92
- .pybuilder/
93
- target/
94
-
95
- # Jupyter Notebook
96
- .ipynb_checkpoints
97
-
98
- # IPython
99
- profile_default/
100
- ipython_config.py
101
-
102
- # pyenv
103
- # For a library or package, you might want to ignore these files since the code is
104
- # intended to run in multiple environments; otherwise, check them in:
105
- # .python-version
106
-
107
- # pipenv
108
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
110
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
111
- # install all needed dependencies.
112
- #Pipfile.lock
113
-
114
- # poetry
115
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
116
- # This is especially recommended for binary packages to ensure reproducibility, and is more
117
- # commonly ignored for libraries.
118
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
119
- #poetry.lock
120
-
121
- # pdm
122
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
123
- #pdm.lock
124
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
125
- # in version control.
126
- # https://pdm.fming.dev/#use-with-ide
127
- .pdm.toml
128
-
129
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
130
- __pypackages__/
131
-
132
- # Celery stuff
133
- celerybeat-schedule
134
- celerybeat.pid
135
-
136
- # SageMath parsed files
137
- *.sage.py
138
-
139
- # Environments
140
- .env
141
- .venv
142
- env/
143
- venv/
144
- ENV/
145
- env.bak/
146
- venv.bak/
147
-
148
- # Spyder project settings
149
- .spyderproject
150
- .spyproject
151
-
152
- # Rope project settings
153
- .ropeproject
154
-
155
- # mkdocs documentation
156
- /site
157
-
158
- # mypy
159
- .mypy_cache/
160
- .dmypy.json
161
- dmypy.json
162
-
163
- # Pyre type checker
164
- .pyre/
165
-
166
- # pytype static type analyzer
167
- .pytype/
168
-
169
- # Cython debug symbols
170
- cython_debug/
171
-
172
- # PyCharm
173
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
174
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
175
- # and can be added to the global gitignore or merged into this file. For a more nuclear
176
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
177
- .idea/
178
-
179
- # Logs
180
- logs
181
- *.log
182
- npm-debug.log*
183
- yarn-debug.log*
184
- yarn-error.log*
185
- lerna-debug.log*
186
- .pnpm-debug.log*
187
-
188
- # Diagnostic reports (https://nodejs.org/api/report.html)
189
- report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
190
-
191
- # Runtime data
192
- pids
193
- *.pid
194
- *.seed
195
- *.pid.lock
196
-
197
- # Directory for instrumented libs generated by jscoverage/JSCover
198
- lib-cov
199
-
200
- # Coverage directory used by tools like istanbul
201
- coverage
202
- *.lcov
203
-
204
- # nyc test coverage
205
- .nyc_output
206
-
207
- # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
208
- .grunt
209
-
210
- # Bower dependency directory (https://bower.io/)
211
- bower_components
212
-
213
- # node-waf configuration
214
- .lock-wscript
215
-
216
- # Compiled binary addons (https://nodejs.org/api/addons.html)
217
- build/Release
218
-
219
- # Dependency directories
220
- node_modules/
221
- jspm_packages/
222
-
223
- # Snowpack dependency directory (https://snowpack.dev/)
224
- web_modules/
225
-
226
- # TypeScript cache
227
- *.tsbuildinfo
228
-
229
- # Optional npm cache directory
230
- .npm
231
-
232
- # Optional eslint cache
233
- .eslintcache
234
-
235
- # Optional stylelint cache
236
- .stylelintcache
237
-
238
- # Microbundle cache
239
- .rpt2_cache/
240
- .rts2_cache_cjs/
241
- .rts2_cache_es/
242
- .rts2_cache_umd/
243
-
244
- # Optional REPL history
245
- .node_repl_history
246
-
247
- # Output of 'npm pack'
248
- *.tgz
249
-
250
- # Yarn Integrity file
251
- .yarn-integrity
252
-
253
- # dotenv environment variable files
254
- .env
255
- .env.development.local
256
- .env.test.local
257
- .env.production.local
258
- .env.local
259
-
260
- # parcel-bundler cache (https://parceljs.org/)
261
- .cache
262
- .parcel-cache
263
-
264
- # Next.js build output
265
- .next
266
- out
267
-
268
- # Nuxt.js build / generate output
269
- .nuxt
270
- dist
271
-
272
- # Gatsby files
273
- .cache/
274
- # Comment in the public line in if your project uses Gatsby and not Next.js
275
- # https://nextjs.org/blog/next-9-1#public-directory-support
276
- # public
277
-
278
- # vuepress build output
279
- .vuepress/dist
280
-
281
- # vuepress v2.x temp and cache directory
282
- .temp
283
- .cache
284
-
285
- # Docusaurus cache and generated files
286
- .docusaurus
287
-
288
- # Serverless directories
289
- .serverless/
290
-
291
- # FuseBox cache
292
- .fusebox/
293
-
294
- # DynamoDB Local files
295
- .dynamodb/
296
-
297
- # TernJS port file
298
- .tern-port
299
-
300
- # Stores VSCode versions used for testing VSCode extensions
301
- .vscode-test
302
-
303
- # yarn v2
304
- .yarn/cache
305
- .yarn/unplugged
306
- .yarn/build-state.yml
307
- .yarn/install-state.gz
308
- .pnp.*
309
-
310
- # cypress artifacts
311
- cypress/videos
312
- cypress/screenshots
313
-
314
-
315
-
316
  /static/*
 
1
+ # Ignore files for PNPM, NPM and YARN
2
+ pnpm-lock.yaml
3
+ package-lock.json
4
+ yarn.lock
5
+
6
+ kubernetes/
7
+
8
+ # Copy of .gitignore
9
+ .DS_Store
10
+ node_modules
11
+ /build
12
+ /.svelte-kit
13
+ /package
14
+ .env
15
+ .env.*
16
+ !.env.example
17
+ vite.config.js.timestamp-*
18
+ vite.config.ts.timestamp-*
19
+ # Byte-compiled / optimized / DLL files
20
+ __pycache__/
21
+ *.py[cod]
22
+ *$py.class
23
+
24
+ # C extensions
25
+ *.so
26
+
27
+ # Distribution / packaging
28
+ .Python
29
+ build/
30
+ develop-eggs/
31
+ dist/
32
+ downloads/
33
+ eggs/
34
+ .eggs/
35
+ lib64/
36
+ parts/
37
+ sdist/
38
+ var/
39
+ wheels/
40
+ share/python-wheels/
41
+ *.egg-info/
42
+ .installed.cfg
43
+ *.egg
44
+ MANIFEST
45
+
46
+ # PyInstaller
47
+ # Usually these files are written by a python script from a template
48
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
+ *.manifest
50
+ *.spec
51
+
52
+ # Installer logs
53
+ pip-log.txt
54
+ pip-delete-this-directory.txt
55
+
56
+ # Unit test / coverage reports
57
+ htmlcov/
58
+ .tox/
59
+ .nox/
60
+ .coverage
61
+ .coverage.*
62
+ .cache
63
+ nosetests.xml
64
+ coverage.xml
65
+ *.cover
66
+ *.py,cover
67
+ .hypothesis/
68
+ .pytest_cache/
69
+ cover/
70
+
71
+ # Translations
72
+ *.mo
73
+ *.pot
74
+
75
+ # Django stuff:
76
+ *.log
77
+ local_settings.py
78
+ db.sqlite3
79
+ db.sqlite3-journal
80
+
81
+ # Flask stuff:
82
+ instance/
83
+ .webassets-cache
84
+
85
+ # Scrapy stuff:
86
+ .scrapy
87
+
88
+ # Sphinx documentation
89
+ docs/_build/
90
+
91
+ # PyBuilder
92
+ .pybuilder/
93
+ target/
94
+
95
+ # Jupyter Notebook
96
+ .ipynb_checkpoints
97
+
98
+ # IPython
99
+ profile_default/
100
+ ipython_config.py
101
+
102
+ # pyenv
103
+ # For a library or package, you might want to ignore these files since the code is
104
+ # intended to run in multiple environments; otherwise, check them in:
105
+ # .python-version
106
+
107
+ # pipenv
108
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
110
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
111
+ # install all needed dependencies.
112
+ #Pipfile.lock
113
+
114
+ # poetry
115
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
116
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
117
+ # commonly ignored for libraries.
118
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
119
+ #poetry.lock
120
+
121
+ # pdm
122
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
123
+ #pdm.lock
124
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
125
+ # in version control.
126
+ # https://pdm.fming.dev/#use-with-ide
127
+ .pdm.toml
128
+
129
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
130
+ __pypackages__/
131
+
132
+ # Celery stuff
133
+ celerybeat-schedule
134
+ celerybeat.pid
135
+
136
+ # SageMath parsed files
137
+ *.sage.py
138
+
139
+ # Environments
140
+ .env
141
+ .venv
142
+ env/
143
+ venv/
144
+ ENV/
145
+ env.bak/
146
+ venv.bak/
147
+
148
+ # Spyder project settings
149
+ .spyderproject
150
+ .spyproject
151
+
152
+ # Rope project settings
153
+ .ropeproject
154
+
155
+ # mkdocs documentation
156
+ /site
157
+
158
+ # mypy
159
+ .mypy_cache/
160
+ .dmypy.json
161
+ dmypy.json
162
+
163
+ # Pyre type checker
164
+ .pyre/
165
+
166
+ # pytype static type analyzer
167
+ .pytype/
168
+
169
+ # Cython debug symbols
170
+ cython_debug/
171
+
172
+ # PyCharm
173
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
174
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
175
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
176
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
177
+ .idea/
178
+
179
+ # Logs
180
+ logs
181
+ *.log
182
+ npm-debug.log*
183
+ yarn-debug.log*
184
+ yarn-error.log*
185
+ lerna-debug.log*
186
+ .pnpm-debug.log*
187
+
188
+ # Diagnostic reports (https://nodejs.org/api/report.html)
189
+ report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
190
+
191
+ # Runtime data
192
+ pids
193
+ *.pid
194
+ *.seed
195
+ *.pid.lock
196
+
197
+ # Directory for instrumented libs generated by jscoverage/JSCover
198
+ lib-cov
199
+
200
+ # Coverage directory used by tools like istanbul
201
+ coverage
202
+ *.lcov
203
+
204
+ # nyc test coverage
205
+ .nyc_output
206
+
207
+ # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
208
+ .grunt
209
+
210
+ # Bower dependency directory (https://bower.io/)
211
+ bower_components
212
+
213
+ # node-waf configuration
214
+ .lock-wscript
215
+
216
+ # Compiled binary addons (https://nodejs.org/api/addons.html)
217
+ build/Release
218
+
219
+ # Dependency directories
220
+ node_modules/
221
+ jspm_packages/
222
+
223
+ # Snowpack dependency directory (https://snowpack.dev/)
224
+ web_modules/
225
+
226
+ # TypeScript cache
227
+ *.tsbuildinfo
228
+
229
+ # Optional npm cache directory
230
+ .npm
231
+
232
+ # Optional eslint cache
233
+ .eslintcache
234
+
235
+ # Optional stylelint cache
236
+ .stylelintcache
237
+
238
+ # Microbundle cache
239
+ .rpt2_cache/
240
+ .rts2_cache_cjs/
241
+ .rts2_cache_es/
242
+ .rts2_cache_umd/
243
+
244
+ # Optional REPL history
245
+ .node_repl_history
246
+
247
+ # Output of 'npm pack'
248
+ *.tgz
249
+
250
+ # Yarn Integrity file
251
+ .yarn-integrity
252
+
253
+ # dotenv environment variable files
254
+ .env
255
+ .env.development.local
256
+ .env.test.local
257
+ .env.production.local
258
+ .env.local
259
+
260
+ # parcel-bundler cache (https://parceljs.org/)
261
+ .cache
262
+ .parcel-cache
263
+
264
+ # Next.js build output
265
+ .next
266
+ out
267
+
268
+ # Nuxt.js build / generate output
269
+ .nuxt
270
+ dist
271
+
272
+ # Gatsby files
273
+ .cache/
274
+ # Comment in the public line in if your project uses Gatsby and not Next.js
275
+ # https://nextjs.org/blog/next-9-1#public-directory-support
276
+ # public
277
+
278
+ # vuepress build output
279
+ .vuepress/dist
280
+
281
+ # vuepress v2.x temp and cache directory
282
+ .temp
283
+ .cache
284
+
285
+ # Docusaurus cache and generated files
286
+ .docusaurus
287
+
288
+ # Serverless directories
289
+ .serverless/
290
+
291
+ # FuseBox cache
292
+ .fusebox/
293
+
294
+ # DynamoDB Local files
295
+ .dynamodb/
296
+
297
+ # TernJS port file
298
+ .tern-port
299
+
300
+ # Stores VSCode versions used for testing VSCode extensions
301
+ .vscode-test
302
+
303
+ # yarn v2
304
+ .yarn/cache
305
+ .yarn/unplugged
306
+ .yarn/build-state.yml
307
+ .yarn/install-state.gz
308
+ .pnp.*
309
+
310
+ # cypress artifacts
311
+ cypress/videos
312
+ cypress/screenshots
313
+
314
+
315
+
316
  /static/*
.prettierrc CHANGED
@@ -1,9 +1,9 @@
1
- {
2
- "useTabs": true,
3
- "singleQuote": true,
4
- "trailingComma": "none",
5
- "printWidth": 100,
6
- "plugins": ["prettier-plugin-svelte"],
7
- "pluginSearchDirs": ["."],
8
- "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }]
9
- }
 
1
+ {
2
+ "useTabs": true,
3
+ "singleQuote": true,
4
+ "trailingComma": "none",
5
+ "printWidth": 100,
6
+ "plugins": ["prettier-plugin-svelte"],
7
+ "pluginSearchDirs": ["."],
8
+ "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }]
9
+ }
CHANGELOG.md CHANGED
The diff for this file is too large to render. See raw diff
 
CODE_OF_CONDUCT.md CHANGED
@@ -1,77 +1,77 @@
1
- # Contributor Covenant Code of Conduct
2
-
3
- ## Our Pledge
4
-
5
- We as members, contributors, and leaders pledge to make participation in our
6
- community a harassment-free experience for everyone, regardless of age, body
7
- size, visible or invisible disability, ethnicity, sex characteristics, gender
8
- identity and expression, level of experience, education, socio-economic status,
9
- nationality, personal appearance, race, religion, or sexual identity
10
- and orientation.
11
-
12
- We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
13
-
14
- ## Our Standards
15
-
16
- Examples of behavior that contribute to a positive environment for our community include:
17
-
18
- - Demonstrating empathy and kindness toward other people
19
- - Being respectful of differing opinions, viewpoints, and experiences
20
- - Giving and gracefully accepting constructive feedback
21
- - Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
22
- - Focusing on what is best not just for us as individuals, but for the overall community
23
-
24
- Examples of unacceptable behavior include:
25
-
26
- - The use of sexualized language or imagery, and sexual attention or advances of any kind
27
- - Trolling, insulting or derogatory comments, and personal or political attacks
28
- - Public or private harassment
29
- - Publishing others' private information, such as a physical or email address, without their explicit permission
30
- - **Spamming of any kind**
31
- - Aggressive sales tactics targeting our community members are strictly prohibited. You can mention your product if it's relevant to the discussion, but under no circumstances should you push it forcefully
32
- - Other conduct which could reasonably be considered inappropriate in a professional setting
33
-
34
- ## Enforcement Responsibilities
35
-
36
- Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
37
-
38
- ## Scope
39
-
40
- This Code of Conduct applies within all community spaces and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
41
-
42
- ## Enforcement
43
-
44
- Instances of abusive, harassing, spamming, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at hello@openwebui.com. All complaints will be reviewed and investigated promptly and fairly.
45
-
46
- All community leaders are obligated to respect the privacy and security of the reporter of any incident.
47
-
48
- ## Enforcement Guidelines
49
-
50
- Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
51
-
52
- ### 1. Temporary Ban
53
-
54
- **Community Impact**: Any violation of community standards, including but not limited to inappropriate language, unprofessional behavior, harassment, or spamming.
55
-
56
- **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
57
-
58
- ### 2. Permanent Ban
59
-
60
- **Community Impact**: Repeated or severe violations of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
61
-
62
- **Consequence**: A permanent ban from any sort of public interaction within the community.
63
-
64
- ## Attribution
65
-
66
- This Code of Conduct is adapted from the [Contributor Covenant][homepage],
67
- version 2.0, available at
68
- https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
69
-
70
- Community Impact Guidelines were inspired by [Mozilla's code of conduct
71
- enforcement ladder](https://github.com/mozilla/diversity).
72
-
73
- [homepage]: https://www.contributor-covenant.org
74
-
75
- For answers to common questions about this code of conduct, see the FAQ at
76
- https://www.contributor-covenant.org/faq. Translations are available at
77
- https://www.contributor-covenant.org/translations.
 
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ We as members, contributors, and leaders pledge to make participation in our
6
+ community a harassment-free experience for everyone, regardless of age, body
7
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
8
+ identity and expression, level of experience, education, socio-economic status,
9
+ nationality, personal appearance, race, religion, or sexual identity
10
+ and orientation.
11
+
12
+ We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
13
+
14
+ ## Our Standards
15
+
16
+ Examples of behavior that contribute to a positive environment for our community include:
17
+
18
+ - Demonstrating empathy and kindness toward other people
19
+ - Being respectful of differing opinions, viewpoints, and experiences
20
+ - Giving and gracefully accepting constructive feedback
21
+ - Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
22
+ - Focusing on what is best not just for us as individuals, but for the overall community
23
+
24
+ Examples of unacceptable behavior include:
25
+
26
+ - The use of sexualized language or imagery, and sexual attention or advances of any kind
27
+ - Trolling, insulting or derogatory comments, and personal or political attacks
28
+ - Public or private harassment
29
+ - Publishing others' private information, such as a physical or email address, without their explicit permission
30
+ - **Spamming of any kind**
31
+ - Aggressive sales tactics targeting our community members are strictly prohibited. You can mention your product if it's relevant to the discussion, but under no circumstances should you push it forcefully
32
+ - Other conduct which could reasonably be considered inappropriate in a professional setting
33
+
34
+ ## Enforcement Responsibilities
35
+
36
+ Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
37
+
38
+ ## Scope
39
+
40
+ This Code of Conduct applies within all community spaces and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
41
+
42
+ ## Enforcement
43
+
44
+ Instances of abusive, harassing, spamming, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at hello@openwebui.com. All complaints will be reviewed and investigated promptly and fairly.
45
+
46
+ All community leaders are obligated to respect the privacy and security of the reporter of any incident.
47
+
48
+ ## Enforcement Guidelines
49
+
50
+ Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
51
+
52
+ ### 1. Temporary Ban
53
+
54
+ **Community Impact**: Any violation of community standards, including but not limited to inappropriate language, unprofessional behavior, harassment, or spamming.
55
+
56
+ **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
57
+
58
+ ### 2. Permanent Ban
59
+
60
+ **Community Impact**: Repeated or severe violations of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
61
+
62
+ **Consequence**: A permanent ban from any sort of public interaction within the community.
63
+
64
+ ## Attribution
65
+
66
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
67
+ version 2.0, available at
68
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
69
+
70
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
71
+ enforcement ladder](https://github.com/mozilla/diversity).
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see the FAQ at
76
+ https://www.contributor-covenant.org/faq. Translations are available at
77
+ https://www.contributor-covenant.org/translations.
Caddyfile.localhost CHANGED
@@ -1,64 +1,64 @@
1
- # Run with
2
- # caddy run --envfile ./example.env --config ./Caddyfile.localhost
3
- #
4
- # This is configured for
5
- # - Automatic HTTPS (even for localhost)
6
- # - Reverse Proxying to Ollama API Base URL (http://localhost:11434/api)
7
- # - CORS
8
- # - HTTP Basic Auth API Tokens (uncomment basicauth section)
9
-
10
-
11
- # CORS Preflight (OPTIONS) + Request (GET, POST, PATCH, PUT, DELETE)
12
- (cors-api) {
13
- @match-cors-api-preflight method OPTIONS
14
- handle @match-cors-api-preflight {
15
- header {
16
- Access-Control-Allow-Origin "{http.request.header.origin}"
17
- Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
18
- Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
19
- Access-Control-Allow-Credentials "true"
20
- Access-Control-Max-Age "3600"
21
- defer
22
- }
23
- respond "" 204
24
- }
25
-
26
- @match-cors-api-request {
27
- not {
28
- header Origin "{http.request.scheme}://{http.request.host}"
29
- }
30
- header Origin "{http.request.header.origin}"
31
- }
32
- handle @match-cors-api-request {
33
- header {
34
- Access-Control-Allow-Origin "{http.request.header.origin}"
35
- Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
36
- Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
37
- Access-Control-Allow-Credentials "true"
38
- Access-Control-Max-Age "3600"
39
- defer
40
- }
41
- }
42
- }
43
-
44
- # replace localhost with example.com or whatever
45
- localhost {
46
- ## HTTP Basic Auth
47
- ## (uncomment to enable)
48
- # basicauth {
49
- # # see .example.env for how to generate tokens
50
- # {env.OLLAMA_API_ID} {env.OLLAMA_API_TOKEN_DIGEST}
51
- # }
52
-
53
- handle /api/* {
54
- # Comment to disable CORS
55
- import cors-api
56
-
57
- reverse_proxy localhost:11434
58
- }
59
-
60
- # Same-Origin Static Web Server
61
- file_server {
62
- root ./build/
63
- }
64
- }
 
1
+ # Run with
2
+ # caddy run --envfile ./example.env --config ./Caddyfile.localhost
3
+ #
4
+ # This is configured for
5
+ # - Automatic HTTPS (even for localhost)
6
+ # - Reverse Proxying to Ollama API Base URL (http://localhost:11434/api)
7
+ # - CORS
8
+ # - HTTP Basic Auth API Tokens (uncomment basicauth section)
9
+
10
+
11
+ # CORS Preflight (OPTIONS) + Request (GET, POST, PATCH, PUT, DELETE)
12
+ (cors-api) {
13
+ @match-cors-api-preflight method OPTIONS
14
+ handle @match-cors-api-preflight {
15
+ header {
16
+ Access-Control-Allow-Origin "{http.request.header.origin}"
17
+ Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
18
+ Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
19
+ Access-Control-Allow-Credentials "true"
20
+ Access-Control-Max-Age "3600"
21
+ defer
22
+ }
23
+ respond "" 204
24
+ }
25
+
26
+ @match-cors-api-request {
27
+ not {
28
+ header Origin "{http.request.scheme}://{http.request.host}"
29
+ }
30
+ header Origin "{http.request.header.origin}"
31
+ }
32
+ handle @match-cors-api-request {
33
+ header {
34
+ Access-Control-Allow-Origin "{http.request.header.origin}"
35
+ Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
36
+ Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
37
+ Access-Control-Allow-Credentials "true"
38
+ Access-Control-Max-Age "3600"
39
+ defer
40
+ }
41
+ }
42
+ }
43
+
44
+ # replace localhost with example.com or whatever
45
+ localhost {
46
+ ## HTTP Basic Auth
47
+ ## (uncomment to enable)
48
+ # basicauth {
49
+ # # see .example.env for how to generate tokens
50
+ # {env.OLLAMA_API_ID} {env.OLLAMA_API_TOKEN_DIGEST}
51
+ # }
52
+
53
+ handle /api/* {
54
+ # Comment to disable CORS
55
+ import cors-api
56
+
57
+ reverse_proxy localhost:11434
58
+ }
59
+
60
+ # Same-Origin Static Web Server
61
+ file_server {
62
+ root ./build/
63
+ }
64
+ }
Dockerfile CHANGED
@@ -1,176 +1,176 @@
1
- # syntax=docker/dockerfile:1
2
- # Initialize device type args
3
- # use build args in the docker build command with --build-arg="BUILDARG=true"
4
- ARG USE_CUDA=false
5
- ARG USE_OLLAMA=false
6
- # Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
7
- ARG USE_CUDA_VER=cu121
8
- # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
9
- # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
10
- # for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
11
- # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
12
- ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
13
- ARG USE_RERANKING_MODEL=""
14
-
15
- # Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
16
- ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
17
-
18
- ARG BUILD_HASH=dev-build
19
- # Override at your own risk - non-root configurations are untested
20
- ARG UID=0
21
- ARG GID=0
22
-
23
- ######## WebUI frontend ########
24
- FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
25
- ARG BUILD_HASH
26
-
27
- WORKDIR /app
28
-
29
- COPY package.json package-lock.json ./
30
- RUN npm ci
31
-
32
- COPY . .
33
- ENV APP_BUILD_HASH=${BUILD_HASH}
34
- RUN npm run build
35
-
36
- ######## WebUI backend ########
37
- FROM python:3.11-slim-bookworm AS base
38
-
39
- # Use args
40
- ARG USE_CUDA
41
- ARG USE_OLLAMA
42
- ARG USE_CUDA_VER
43
- ARG USE_EMBEDDING_MODEL
44
- ARG USE_RERANKING_MODEL
45
- ARG UID
46
- ARG GID
47
-
48
- ## Basis ##
49
- ENV ENV=prod \
50
- PORT=8080 \
51
- # pass build args to the build
52
- USE_OLLAMA_DOCKER=${USE_OLLAMA} \
53
- USE_CUDA_DOCKER=${USE_CUDA} \
54
- USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
55
- USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
56
- USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
57
-
58
- ## Basis URL Config ##
59
- ENV OLLAMA_BASE_URL="/ollama" \
60
- OPENAI_API_BASE_URL=""
61
-
62
- ## API Key and Security Config ##
63
- ENV OPENAI_API_KEY="" \
64
- WEBUI_SECRET_KEY="" \
65
- SCARF_NO_ANALYTICS=true \
66
- DO_NOT_TRACK=true \
67
- ANONYMIZED_TELEMETRY=false
68
-
69
- #### Other models #########################################################
70
- ## whisper TTS model settings ##
71
- ENV WHISPER_MODEL="base" \
72
- WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
73
-
74
- ## RAG Embedding model settings ##
75
- ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
76
- RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
77
- SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
78
-
79
- ## Tiktoken model settings ##
80
- ENV TIKTOKEN_ENCODING_NAME="$USE_TIKTOKEN_ENCODING_NAME" \
81
- TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
82
-
83
- ## Hugging Face download cache ##
84
- ENV HF_HOME="/app/backend/data/cache/embedding/models"
85
-
86
- ## Torch Extensions ##
87
- # ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions"
88
-
89
- #### Other models ##########################################################
90
-
91
- WORKDIR /app/backend
92
-
93
- ENV HOME=/root
94
- # Create user and group if not root
95
- RUN if [ $UID -ne 0 ]; then \
96
- if [ $GID -ne 0 ]; then \
97
- addgroup --gid $GID app; \
98
- fi; \
99
- adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
100
- fi
101
-
102
- RUN mkdir -p $HOME/.cache/chroma
103
- RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
104
-
105
- # Make sure the user has access to the app and root directory
106
- RUN chown -R $UID:$GID /app $HOME
107
-
108
- RUN if [ "$USE_OLLAMA" = "true" ]; then \
109
- apt-get update && \
110
- # Install pandoc and netcat
111
- apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
112
- apt-get install -y --no-install-recommends gcc python3-dev && \
113
- # for RAG OCR
114
- apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
115
- # install helper tools
116
- apt-get install -y --no-install-recommends curl jq && \
117
- # install ollama
118
- curl -fsSL https://ollama.com/install.sh | sh && \
119
- # cleanup
120
- rm -rf /var/lib/apt/lists/*; \
121
- else \
122
- apt-get update && \
123
- # Install pandoc, netcat and gcc
124
- apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
125
- apt-get install -y --no-install-recommends gcc python3-dev && \
126
- # for RAG OCR
127
- apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
128
- # cleanup
129
- rm -rf /var/lib/apt/lists/*; \
130
- fi
131
-
132
- # install python dependencies
133
- COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
134
-
135
- RUN pip3 install uv && \
136
- if [ "$USE_CUDA" = "true" ]; then \
137
- # If you use CUDA the whisper and embedding model will be downloaded on first use
138
- pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
139
- uv pip install --system -r requirements.txt --no-cache-dir && \
140
- python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
141
- python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
142
- python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
143
- else \
144
- pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
145
- uv pip install --system -r requirements.txt --no-cache-dir && \
146
- python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
147
- python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
148
- python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
149
- fi; \
150
- chown -R $UID:$GID /app/backend/data/
151
-
152
-
153
-
154
- # copy embedding weight from build
155
- # RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
156
- # COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
157
-
158
- # copy built frontend files
159
- COPY --chown=$UID:$GID --from=build /app/build /app/build
160
- COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md
161
- COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json
162
-
163
- # copy backend files
164
- COPY --chown=$UID:$GID ./backend .
165
-
166
- EXPOSE 8080
167
-
168
- HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
169
-
170
- USER $UID:$GID
171
-
172
- ARG BUILD_HASH
173
- ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
174
- ENV DOCKER=true
175
-
176
- CMD [ "bash", "start.sh"]
 
1
+ # syntax=docker/dockerfile:1
2
+ # Initialize device type args
3
+ # use build args in the docker build command with --build-arg="BUILDARG=true"
4
+ ARG USE_CUDA=false
5
+ ARG USE_OLLAMA=false
6
+ # Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
7
+ ARG USE_CUDA_VER=cu121
8
+ # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
9
+ # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
10
+ # for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
11
+ # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
12
+ ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
13
+ ARG USE_RERANKING_MODEL=""
14
+
15
+ # Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
16
+ ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
17
+
18
+ ARG BUILD_HASH=dev-build
19
+ # Override at your own risk - non-root configurations are untested
20
+ ARG UID=0
21
+ ARG GID=0
22
+
23
+ ######## WebUI frontend ########
24
+ FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
25
+ ARG BUILD_HASH
26
+
27
+ WORKDIR /app
28
+
29
+ COPY package.json package-lock.json ./
30
+ RUN npm ci
31
+
32
+ COPY . .
33
+ ENV APP_BUILD_HASH=${BUILD_HASH}
34
+ RUN npm run build
35
+
36
+ ######## WebUI backend ########
37
+ FROM python:3.11-slim-bookworm AS base
38
+
39
+ # Use args
40
+ ARG USE_CUDA
41
+ ARG USE_OLLAMA
42
+ ARG USE_CUDA_VER
43
+ ARG USE_EMBEDDING_MODEL
44
+ ARG USE_RERANKING_MODEL
45
+ ARG UID
46
+ ARG GID
47
+
48
+ ## Basis ##
49
+ ENV ENV=prod \
50
+ PORT=8080 \
51
+ # pass build args to the build
52
+ USE_OLLAMA_DOCKER=${USE_OLLAMA} \
53
+ USE_CUDA_DOCKER=${USE_CUDA} \
54
+ USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
55
+ USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
56
+ USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
57
+
58
+ ## Basis URL Config ##
59
+ ENV OLLAMA_BASE_URL="/ollama" \
60
+ OPENAI_API_BASE_URL=""
61
+
62
+ ## API Key and Security Config ##
63
+ ENV OPENAI_API_KEY="" \
64
+ WEBUI_SECRET_KEY="" \
65
+ SCARF_NO_ANALYTICS=true \
66
+ DO_NOT_TRACK=true \
67
+ ANONYMIZED_TELEMETRY=false
68
+
69
+ #### Other models #########################################################
70
+ ## whisper TTS model settings ##
71
+ ENV WHISPER_MODEL="base" \
72
+ WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
73
+
74
+ ## RAG Embedding model settings ##
75
+ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
76
+ RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
77
+ SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
78
+
79
+ ## Tiktoken model settings ##
80
+ ENV TIKTOKEN_ENCODING_NAME="$USE_TIKTOKEN_ENCODING_NAME" \
81
+ TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
82
+
83
+ ## Hugging Face download cache ##
84
+ ENV HF_HOME="/app/backend/data/cache/embedding/models"
85
+
86
+ ## Torch Extensions ##
87
+ # ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions"
88
+
89
+ #### Other models ##########################################################
90
+
91
+ WORKDIR /app/backend
92
+
93
+ ENV HOME=/root
94
+ # Create user and group if not root
95
+ RUN if [ $UID -ne 0 ]; then \
96
+ if [ $GID -ne 0 ]; then \
97
+ addgroup --gid $GID app; \
98
+ fi; \
99
+ adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
100
+ fi
101
+
102
+ RUN mkdir -p $HOME/.cache/chroma
103
+ RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
104
+
105
+ # Make sure the user has access to the app and root directory
106
+ RUN chown -R $UID:$GID /app $HOME
107
+
108
+ RUN if [ "$USE_OLLAMA" = "true" ]; then \
109
+ apt-get update && \
110
+ # Install pandoc and netcat
111
+ apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
112
+ apt-get install -y --no-install-recommends gcc python3-dev && \
113
+ # for RAG OCR
114
+ apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
115
+ # install helper tools
116
+ apt-get install -y --no-install-recommends curl jq && \
117
+ # install ollama
118
+ curl -fsSL https://ollama.com/install.sh | sh && \
119
+ # cleanup
120
+ rm -rf /var/lib/apt/lists/*; \
121
+ else \
122
+ apt-get update && \
123
+ # Install pandoc, netcat and gcc
124
+ apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
125
+ apt-get install -y --no-install-recommends gcc python3-dev && \
126
+ # for RAG OCR
127
+ apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
128
+ # cleanup
129
+ rm -rf /var/lib/apt/lists/*; \
130
+ fi
131
+
132
+ # install python dependencies
133
+ COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
134
+
135
+ RUN pip3 install uv && \
136
+ if [ "$USE_CUDA" = "true" ]; then \
137
+ # If you use CUDA the whisper and embedding model will be downloaded on first use
138
+ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
139
+ uv pip install --system -r requirements.txt --no-cache-dir && \
140
+ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
141
+ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
142
+ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
143
+ else \
144
+ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
145
+ uv pip install --system -r requirements.txt --no-cache-dir && \
146
+ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
147
+ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
148
+ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
149
+ fi; \
150
+ chown -R $UID:$GID /app/backend/data/
151
+
152
+
153
+
154
+ # copy embedding weight from build
155
+ # RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
156
+ # COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
157
+
158
+ # copy built frontend files
159
+ COPY --chown=$UID:$GID --from=build /app/build /app/build
160
+ COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md
161
+ COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json
162
+
163
+ # copy backend files
164
+ COPY --chown=$UID:$GID ./backend .
165
+
166
+ EXPOSE 8080
167
+
168
+ HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
169
+
170
+ USER $UID:$GID
171
+
172
+ ARG BUILD_HASH
173
+ ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
174
+ ENV DOCKER=true
175
+
176
+ CMD [ "bash", "start.sh"]
INSTALLATION.md CHANGED
@@ -1,35 +1,35 @@
1
- ### Installing Both Ollama and Open WebUI Using Kustomize
2
-
3
- For cpu-only pod
4
-
5
- ```bash
6
- kubectl apply -f ./kubernetes/manifest/base
7
- ```
8
-
9
- For gpu-enabled pod
10
-
11
- ```bash
12
- kubectl apply -k ./kubernetes/manifest
13
- ```
14
-
15
- ### Installing Both Ollama and Open WebUI Using Helm
16
-
17
- Package Helm file first
18
-
19
- ```bash
20
- helm package ./kubernetes/helm/
21
- ```
22
-
23
- For cpu-only pod
24
-
25
- ```bash
26
- helm install ollama-webui ./ollama-webui-*.tgz
27
- ```
28
-
29
- For gpu-enabled pod
30
-
31
- ```bash
32
- helm install ollama-webui ./ollama-webui-*.tgz --set ollama.resources.limits.nvidia.com/gpu="1"
33
- ```
34
-
35
- Check the `kubernetes/helm/values.yaml` file to know which parameters are available for customization
 
1
+ ### Installing Both Ollama and Open WebUI Using Kustomize
2
+
3
+ For cpu-only pod
4
+
5
+ ```bash
6
+ kubectl apply -f ./kubernetes/manifest/base
7
+ ```
8
+
9
+ For gpu-enabled pod
10
+
11
+ ```bash
12
+ kubectl apply -k ./kubernetes/manifest
13
+ ```
14
+
15
+ ### Installing Both Ollama and Open WebUI Using Helm
16
+
17
+ Package Helm file first
18
+
19
+ ```bash
20
+ helm package ./kubernetes/helm/
21
+ ```
22
+
23
+ For cpu-only pod
24
+
25
+ ```bash
26
+ helm install ollama-webui ./ollama-webui-*.tgz
27
+ ```
28
+
29
+ For gpu-enabled pod
30
+
31
+ ```bash
32
+ helm install ollama-webui ./ollama-webui-*.tgz --set ollama.resources.limits.nvidia.com/gpu="1"
33
+ ```
34
+
35
+ Check the `kubernetes/helm/values.yaml` file to know which parameters are available for customization
LICENSE CHANGED
@@ -1,21 +1,21 @@
1
- MIT License
2
-
3
- Copyright (c) 2023 Timothy Jaeryang Baek
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Timothy Jaeryang Baek
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Makefile CHANGED
@@ -1,33 +1,33 @@
1
-
2
- ifneq ($(shell which docker-compose 2>/dev/null),)
3
- DOCKER_COMPOSE := docker-compose
4
- else
5
- DOCKER_COMPOSE := docker compose
6
- endif
7
-
8
- install:
9
- $(DOCKER_COMPOSE) up -d
10
-
11
- remove:
12
- @chmod +x confirm_remove.sh
13
- @./confirm_remove.sh
14
-
15
- start:
16
- $(DOCKER_COMPOSE) start
17
- startAndBuild:
18
- $(DOCKER_COMPOSE) up -d --build
19
-
20
- stop:
21
- $(DOCKER_COMPOSE) stop
22
-
23
- update:
24
- # Calls the LLM update script
25
- chmod +x update_ollama_models.sh
26
- @./update_ollama_models.sh
27
- @git pull
28
- $(DOCKER_COMPOSE) down
29
- # Make sure the ollama-webui container is stopped before rebuilding
30
- @docker stop open-webui || true
31
- $(DOCKER_COMPOSE) up --build -d
32
- $(DOCKER_COMPOSE) start
33
-
 
1
+
2
+ ifneq ($(shell which docker-compose 2>/dev/null),)
3
+ DOCKER_COMPOSE := docker-compose
4
+ else
5
+ DOCKER_COMPOSE := docker compose
6
+ endif
7
+
8
+ install:
9
+ $(DOCKER_COMPOSE) up -d
10
+
11
+ remove:
12
+ @chmod +x confirm_remove.sh
13
+ @./confirm_remove.sh
14
+
15
+ start:
16
+ $(DOCKER_COMPOSE) start
17
+ startAndBuild:
18
+ $(DOCKER_COMPOSE) up -d --build
19
+
20
+ stop:
21
+ $(DOCKER_COMPOSE) stop
22
+
23
+ update:
24
+ # Calls the LLM update script
25
+ chmod +x update_ollama_models.sh
26
+ @./update_ollama_models.sh
27
+ @git pull
28
+ $(DOCKER_COMPOSE) down
29
+ # Make sure the ollama-webui container is stopped before rebuilding
30
+ @docker stop open-webui || true
31
+ $(DOCKER_COMPOSE) up --build -d
32
+ $(DOCKER_COMPOSE) start
33
+
README.md CHANGED
@@ -1,231 +1,231 @@
1
- ---
2
- title: Open WebUI
3
- emoji: 🐳
4
- colorFrom: purple
5
- colorTo: gray
6
- sdk: docker
7
- app_port: 8080
8
- ---
9
- # Open WebUI 👋
10
-
11
- ![GitHub stars](https://img.shields.io/github/stars/open-webui/open-webui?style=social)
12
- ![GitHub forks](https://img.shields.io/github/forks/open-webui/open-webui?style=social)
13
- ![GitHub watchers](https://img.shields.io/github/watchers/open-webui/open-webui?style=social)
14
- ![GitHub repo size](https://img.shields.io/github/repo-size/open-webui/open-webui)
15
- ![GitHub language count](https://img.shields.io/github/languages/count/open-webui/open-webui)
16
- ![GitHub top language](https://img.shields.io/github/languages/top/open-webui/open-webui)
17
- ![GitHub last commit](https://img.shields.io/github/last-commit/open-webui/open-webui?color=red)
18
- ![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Follama-webui%2Follama-wbui&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)
19
- [![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s)
20
- [![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck)
21
-
22
- Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-rich, and user-friendly self-hosted WebUI designed to operate entirely offline. It supports various LLM runners, including Ollama and OpenAI-compatible APIs. For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
23
-
24
- ![Open WebUI Demo](./demo.gif)
25
-
26
- ## Key Features of Open WebUI ⭐
27
-
28
- - 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images.
29
-
30
- - 🤝 **Ollama/OpenAI API Integration**: Effortlessly integrate OpenAI-compatible APIs for versatile conversations alongside Ollama models. Customize the OpenAI API URL to link with **LMStudio, GroqCloud, Mistral, OpenRouter, and more**.
31
-
32
- - 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
33
-
34
- - 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
35
-
36
- - 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.
37
-
38
- - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
39
-
40
- - 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
41
-
42
- - 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
43
-
44
- - 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
45
-
46
- - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
47
-
48
- - 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch` and `SearchApi` and inject the results directly into your chat experience.
49
-
50
- - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
51
-
52
- - 🎨 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API or ComfyUI (local), and OpenAI's DALL-E (external), enriching your chat experience with dynamic visual content.
53
-
54
- - ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
55
-
56
- - 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
57
-
58
- - 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
59
-
60
- - 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates, fixes, and new features.
61
-
62
- Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
63
-
64
- ## 🔗 Also Check Out Open WebUI Community!
65
-
66
- Don't forget to explore our sibling project, [Open WebUI Community](https://openwebui.com/), where you can discover, download, and explore customized Modelfiles. Open WebUI Community offers a wide range of exciting possibilities for enhancing your chat interactions with Open WebUI! 🚀
67
-
68
- ## How to Install 🚀
69
-
70
- ### Installation via Python pip 🐍
71
-
72
- Open WebUI can be installed using pip, the Python package installer. Before proceeding, ensure you're using **Python 3.11** to avoid compatibility issues.
73
-
74
- 1. **Install Open WebUI**:
75
- Open your terminal and run the following command to install Open WebUI:
76
-
77
- ```bash
78
- pip install open-webui
79
- ```
80
-
81
- 2. **Running Open WebUI**:
82
- After installation, you can start Open WebUI by executing:
83
-
84
- ```bash
85
- open-webui serve
86
- ```
87
-
88
- This will start the Open WebUI server, which you can access at [http://localhost:8080](http://localhost:8080)
89
-
90
- ### Quick Start with Docker 🐳
91
-
92
- > [!NOTE]
93
- > Please note that for certain Docker environments, additional configurations might be needed. If you encounter any connection issues, our detailed guide on [Open WebUI Documentation](https://docs.openwebui.com/) is ready to assist you.
94
-
95
- > [!WARNING]
96
- > When using Docker to install Open WebUI, make sure to include the `-v open-webui:/app/backend/data` in your Docker command. This step is crucial as it ensures your database is properly mounted and prevents any loss of data.
97
-
98
- > [!TIP]
99
- > If you wish to utilize Open WebUI with Ollama included or CUDA acceleration, we recommend utilizing our official images tagged with either `:cuda` or `:ollama`. To enable CUDA, you must install the [Nvidia CUDA container toolkit](https://docs.nvidia.com/dgx/nvidia-container-runtime-upgrade/) on your Linux/WSL system.
100
-
101
- ### Installation with Default Configuration
102
-
103
- - **If Ollama is on your computer**, use this command:
104
-
105
- ```bash
106
- docker run -d -p 3000:8080 --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
107
- ```
108
-
109
- - **If Ollama is on a Different Server**, use this command:
110
-
111
- To connect to Ollama on another server, change the `OLLAMA_BASE_URL` to the server's URL:
112
-
113
- ```bash
114
- docker run -d -p 3000:8080 -e OLLAMA_BASE_URL=https://example.com -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
115
- ```
116
-
117
- - **To run Open WebUI with Nvidia GPU support**, use this command:
118
-
119
- ```bash
120
- docker run -d -p 3000:8080 --gpus all --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:cuda
121
- ```
122
-
123
- ### Installation for OpenAI API Usage Only
124
-
125
- - **If you're only using OpenAI API**, use this command:
126
-
127
- ```bash
128
- docker run -d -p 3000:8080 -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
129
- ```
130
-
131
- ### Installing Open WebUI with Bundled Ollama Support
132
-
133
- This installation method uses a single container image that bundles Open WebUI with Ollama, allowing for a streamlined setup via a single command. Choose the appropriate command based on your hardware setup:
134
-
135
- - **With GPU Support**:
136
- Utilize GPU resources by running the following command:
137
-
138
- ```bash
139
- docker run -d -p 3000:8080 --gpus=all -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama
140
- ```
141
-
142
- - **For CPU Only**:
143
- If you're not using a GPU, use this command instead:
144
-
145
- ```bash
146
- docker run -d -p 3000:8080 -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama
147
- ```
148
-
149
- Both commands facilitate a built-in, hassle-free installation of both Open WebUI and Ollama, ensuring that you can get everything up and running swiftly.
150
-
151
- After installation, you can access Open WebUI at [http://localhost:3000](http://localhost:3000). Enjoy! 😄
152
-
153
- ### Other Installation Methods
154
-
155
- We offer various installation alternatives, including non-Docker native installation methods, Docker Compose, Kustomize, and Helm. Visit our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/) or join our [Discord community](https://discord.gg/5rJgQTnV4s) for comprehensive guidance.
156
-
157
- ### Troubleshooting
158
-
159
- Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).
160
-
161
- #### Open WebUI: Server Connection Error
162
-
163
- If you're experiencing connection issues, it’s often due to the WebUI docker container not being able to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) inside the container . Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`.
164
-
165
- **Example Docker Command**:
166
-
167
- ```bash
168
- docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
169
- ```
170
-
171
- ### Keeping Your Docker Installation Up-to-Date
172
-
173
- In case you want to update your local Docker installation to the latest version, you can do it with [Watchtower](https://containrrr.dev/watchtower/):
174
-
175
- ```bash
176
- docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower --run-once open-webui
177
- ```
178
-
179
- In the last part of the command, replace `open-webui` with your container name if it is different.
180
-
181
- Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/tutorials/migration/).
182
-
183
- ### Using the Dev Branch 🌙
184
-
185
- > [!WARNING]
186
- > The `:dev` branch contains the latest unstable features and changes. Use it at your own risk as it may have bugs or incomplete features.
187
-
188
- If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this:
189
-
190
- ```bash
191
- docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --add-host=host.docker.internal:host-gateway --restart always ghcr.io/open-webui/open-webui:dev
192
- ```
193
-
194
- ## What's Next? 🌟
195
-
196
- Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/).
197
-
198
- ## Supporters ✨
199
-
200
- A big shoutout to our amazing supporters who's helping to make this project possible! 🙏
201
-
202
- ### Platinum Sponsors 🤍
203
-
204
- - We're looking for Sponsors!
205
-
206
- ### Acknowledgments
207
-
208
- Special thanks to [Prof. Lawrence Kim](https://www.lhkim.com/) and [Prof. Nick Vincent](https://www.nickmvincent.com/) for their invaluable support and guidance in shaping this project into a research endeavor. Grateful for your mentorship throughout the journey! 🙌
209
-
210
- ## License 📜
211
-
212
- This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄
213
-
214
- ## Support 💬
215
-
216
- If you have any questions, suggestions, or need assistance, please open an issue or join our
217
- [Open WebUI Discord community](https://discord.gg/5rJgQTnV4s) to connect with us! 🤝
218
-
219
- ## Star History
220
-
221
- <a href="https://star-history.com/#open-webui/open-webui&Date">
222
- <picture>
223
- <source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date&theme=dark" />
224
- <source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date" />
225
- <img alt="Star History Chart" src="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date" />
226
- </picture>
227
- </a>
228
-
229
- ---
230
-
231
- Created by [Timothy Jaeryang Baek](https://github.com/tjbck) - Let's make Open WebUI even more amazing together! 💪
 
1
+ ---
2
+ title: Open WebUI
3
+ emoji: 🐳
4
+ colorFrom: purple
5
+ colorTo: gray
6
+ sdk: docker
7
+ app_port: 8080
8
+ ---
9
+ # Open WebUI 👋
10
+
11
+ ![GitHub stars](https://img.shields.io/github/stars/open-webui/open-webui?style=social)
12
+ ![GitHub forks](https://img.shields.io/github/forks/open-webui/open-webui?style=social)
13
+ ![GitHub watchers](https://img.shields.io/github/watchers/open-webui/open-webui?style=social)
14
+ ![GitHub repo size](https://img.shields.io/github/repo-size/open-webui/open-webui)
15
+ ![GitHub language count](https://img.shields.io/github/languages/count/open-webui/open-webui)
16
+ ![GitHub top language](https://img.shields.io/github/languages/top/open-webui/open-webui)
17
+ ![GitHub last commit](https://img.shields.io/github/last-commit/open-webui/open-webui?color=red)
18
+ ![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Follama-webui%2Follama-wbui&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)
19
+ [![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s)
20
+ [![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck)
21
+
22
+ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-rich, and user-friendly self-hosted WebUI designed to operate entirely offline. It supports various LLM runners, including Ollama and OpenAI-compatible APIs. For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
23
+
24
+ ![Open WebUI Demo](./demo.gif)
25
+
26
+ ## Key Features of Open WebUI ⭐
27
+
28
+ - 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images.
29
+
30
+ - 🤝 **Ollama/OpenAI API Integration**: Effortlessly integrate OpenAI-compatible APIs for versatile conversations alongside Ollama models. Customize the OpenAI API URL to link with **LMStudio, GroqCloud, Mistral, OpenRouter, and more**.
31
+
32
+ - 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
33
+
34
+ - 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
35
+
36
+ - 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.
37
+
38
+ - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
39
+
40
+ - 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
41
+
42
+ - 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
43
+
44
+ - 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
45
+
46
+ - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
47
+
48
+ - 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch` and `SearchApi` and inject the results directly into your chat experience.
49
+
50
+ - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
51
+
52
+ - 🎨 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API or ComfyUI (local), and OpenAI's DALL-E (external), enriching your chat experience with dynamic visual content.
53
+
54
+ - ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
55
+
56
+ - 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
57
+
58
+ - 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
59
+
60
+ - 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates, fixes, and new features.
61
+
62
+ Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
63
+
64
+ ## 🔗 Also Check Out Open WebUI Community!
65
+
66
+ Don't forget to explore our sibling project, [Open WebUI Community](https://openwebui.com/), where you can discover, download, and explore customized Modelfiles. Open WebUI Community offers a wide range of exciting possibilities for enhancing your chat interactions with Open WebUI! 🚀
67
+
68
+ ## How to Install 🚀
69
+
70
+ ### Installation via Python pip 🐍
71
+
72
+ Open WebUI can be installed using pip, the Python package installer. Before proceeding, ensure you're using **Python 3.11** to avoid compatibility issues.
73
+
74
+ 1. **Install Open WebUI**:
75
+ Open your terminal and run the following command to install Open WebUI:
76
+
77
+ ```bash
78
+ pip install open-webui
79
+ ```
80
+
81
+ 2. **Running Open WebUI**:
82
+ After installation, you can start Open WebUI by executing:
83
+
84
+ ```bash
85
+ open-webui serve
86
+ ```
87
+
88
+ This will start the Open WebUI server, which you can access at [http://localhost:8080](http://localhost:8080)
89
+
90
+ ### Quick Start with Docker 🐳
91
+
92
+ > [!NOTE]
93
+ > Please note that for certain Docker environments, additional configurations might be needed. If you encounter any connection issues, our detailed guide on [Open WebUI Documentation](https://docs.openwebui.com/) is ready to assist you.
94
+
95
+ > [!WARNING]
96
+ > When using Docker to install Open WebUI, make sure to include the `-v open-webui:/app/backend/data` in your Docker command. This step is crucial as it ensures your database is properly mounted and prevents any loss of data.
97
+
98
+ > [!TIP]
99
+ > If you wish to utilize Open WebUI with Ollama included or CUDA acceleration, we recommend utilizing our official images tagged with either `:cuda` or `:ollama`. To enable CUDA, you must install the [Nvidia CUDA container toolkit](https://docs.nvidia.com/dgx/nvidia-container-runtime-upgrade/) on your Linux/WSL system.
100
+
101
+ ### Installation with Default Configuration
102
+
103
+ - **If Ollama is on your computer**, use this command:
104
+
105
+ ```bash
106
+ docker run -d -p 3000:8080 --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
107
+ ```
108
+
109
+ - **If Ollama is on a Different Server**, use this command:
110
+
111
+ To connect to Ollama on another server, change the `OLLAMA_BASE_URL` to the server's URL:
112
+
113
+ ```bash
114
+ docker run -d -p 3000:8080 -e OLLAMA_BASE_URL=https://example.com -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
115
+ ```
116
+
117
+ - **To run Open WebUI with Nvidia GPU support**, use this command:
118
+
119
+ ```bash
120
+ docker run -d -p 3000:8080 --gpus all --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:cuda
121
+ ```
122
+
123
+ ### Installation for OpenAI API Usage Only
124
+
125
+ - **If you're only using OpenAI API**, use this command:
126
+
127
+ ```bash
128
+ docker run -d -p 3000:8080 -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
129
+ ```
130
+
131
+ ### Installing Open WebUI with Bundled Ollama Support
132
+
133
+ This installation method uses a single container image that bundles Open WebUI with Ollama, allowing for a streamlined setup via a single command. Choose the appropriate command based on your hardware setup:
134
+
135
+ - **With GPU Support**:
136
+ Utilize GPU resources by running the following command:
137
+
138
+ ```bash
139
+ docker run -d -p 3000:8080 --gpus=all -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama
140
+ ```
141
+
142
+ - **For CPU Only**:
143
+ If you're not using a GPU, use this command instead:
144
+
145
+ ```bash
146
+ docker run -d -p 3000:8080 -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama
147
+ ```
148
+
149
+ Both commands facilitate a built-in, hassle-free installation of both Open WebUI and Ollama, ensuring that you can get everything up and running swiftly.
150
+
151
+ After installation, you can access Open WebUI at [http://localhost:3000](http://localhost:3000). Enjoy! 😄
152
+
153
+ ### Other Installation Methods
154
+
155
+ We offer various installation alternatives, including non-Docker native installation methods, Docker Compose, Kustomize, and Helm. Visit our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/) or join our [Discord community](https://discord.gg/5rJgQTnV4s) for comprehensive guidance.
156
+
157
+ ### Troubleshooting
158
+
159
+ Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).
160
+
161
+ #### Open WebUI: Server Connection Error
162
+
163
+ If you're experiencing connection issues, it’s often due to the WebUI docker container not being able to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) inside the container . Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`.
164
+
165
+ **Example Docker Command**:
166
+
167
+ ```bash
168
+ docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
169
+ ```
170
+
171
+ ### Keeping Your Docker Installation Up-to-Date
172
+
173
+ In case you want to update your local Docker installation to the latest version, you can do it with [Watchtower](https://containrrr.dev/watchtower/):
174
+
175
+ ```bash
176
+ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower --run-once open-webui
177
+ ```
178
+
179
+ In the last part of the command, replace `open-webui` with your container name if it is different.
180
+
181
+ Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/tutorials/migration/).
182
+
183
+ ### Using the Dev Branch 🌙
184
+
185
+ > [!WARNING]
186
+ > The `:dev` branch contains the latest unstable features and changes. Use it at your own risk as it may have bugs or incomplete features.
187
+
188
+ If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this:
189
+
190
+ ```bash
191
+ docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --add-host=host.docker.internal:host-gateway --restart always ghcr.io/open-webui/open-webui:dev
192
+ ```
193
+
194
+ ## What's Next? 🌟
195
+
196
+ Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/).
197
+
198
+ ## Supporters ✨
199
+
200
+ A big shoutout to our amazing supporters who's helping to make this project possible! 🙏
201
+
202
+ ### Platinum Sponsors 🤍
203
+
204
+ - We're looking for Sponsors!
205
+
206
+ ### Acknowledgments
207
+
208
+ Special thanks to [Prof. Lawrence Kim](https://www.lhkim.com/) and [Prof. Nick Vincent](https://www.nickmvincent.com/) for their invaluable support and guidance in shaping this project into a research endeavor. Grateful for your mentorship throughout the journey! 🙌
209
+
210
+ ## License 📜
211
+
212
+ This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄
213
+
214
+ ## Support 💬
215
+
216
+ If you have any questions, suggestions, or need assistance, please open an issue or join our
217
+ [Open WebUI Discord community](https://discord.gg/5rJgQTnV4s) to connect with us! 🤝
218
+
219
+ ## Star History
220
+
221
+ <a href="https://star-history.com/#open-webui/open-webui&Date">
222
+ <picture>
223
+ <source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date&theme=dark" />
224
+ <source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date" />
225
+ <img alt="Star History Chart" src="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date" />
226
+ </picture>
227
+ </a>
228
+
229
+ ---
230
+
231
+ Created by [Timothy Jaeryang Baek](https://github.com/tjbck) - Let's make Open WebUI even more amazing together! 💪
TROUBLESHOOTING.md CHANGED
@@ -1,36 +1,36 @@
1
- # Open WebUI Troubleshooting Guide
2
-
3
- ## Understanding the Open WebUI Architecture
4
-
5
- The Open WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues.
6
-
7
- - **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_BASE_URL` environment variable. Therefore, a request made to `/ollama` in the WebUI is effectively the same as making a request to `OLLAMA_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_BASE_URL/api/tags` in the backend.
8
-
9
- - **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer.
10
-
11
- ## Open WebUI: Server Connection Error
12
-
13
- If you're experiencing connection issues, it’s often due to the WebUI docker container not being able to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) inside the container . Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`.
14
-
15
- **Example Docker Command**:
16
-
17
- ```bash
18
- docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
19
- ```
20
-
21
- ### Error on Slow Responses for Ollama
22
-
23
- Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds.
24
-
25
- ### General Connection Errors
26
-
27
- **Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates.
28
-
29
- **Troubleshooting Steps**:
30
-
31
- 1. **Verify Ollama URL Format**:
32
- - When running the Web UI container, ensure the `OLLAMA_BASE_URL` is correctly set. (e.g., `http://192.168.1.1:11434` for different host setups).
33
- - In the Open WebUI, navigate to "Settings" > "General".
34
- - Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]` (e.g., `http://localhost:11434`).
35
-
36
- By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord.
 
1
+ # Open WebUI Troubleshooting Guide
2
+
3
+ ## Understanding the Open WebUI Architecture
4
+
5
+ The Open WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues.
6
+
7
+ - **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_BASE_URL` environment variable. Therefore, a request made to `/ollama` in the WebUI is effectively the same as making a request to `OLLAMA_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_BASE_URL/api/tags` in the backend.
8
+
9
+ - **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer.
10
+
11
+ ## Open WebUI: Server Connection Error
12
+
13
+ If you're experiencing connection issues, it’s often due to the WebUI docker container not being able to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) inside the container . Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`.
14
+
15
+ **Example Docker Command**:
16
+
17
+ ```bash
18
+ docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
19
+ ```
20
+
21
+ ### Error on Slow Responses for Ollama
22
+
23
+ Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds.
24
+
25
+ ### General Connection Errors
26
+
27
+ **Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates.
28
+
29
+ **Troubleshooting Steps**:
30
+
31
+ 1. **Verify Ollama URL Format**:
32
+ - When running the Web UI container, ensure the `OLLAMA_BASE_URL` is correctly set. (e.g., `http://192.168.1.1:11434` for different host setups).
33
+ - In the Open WebUI, navigate to "Settings" > "General".
34
+ - Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]` (e.g., `http://localhost:11434`).
35
+
36
+ By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord.
backend/.dockerignore CHANGED
@@ -1,14 +1,14 @@
1
- __pycache__
2
- .env
3
- _old
4
- uploads
5
- .ipynb_checkpoints
6
- *.db
7
- _test
8
- !/data
9
- /data/*
10
- !/data/litellm
11
- /data/litellm/*
12
- !data/litellm/config.yaml
13
-
14
  !data/config.json
 
1
+ __pycache__
2
+ .env
3
+ _old
4
+ uploads
5
+ .ipynb_checkpoints
6
+ *.db
7
+ _test
8
+ !/data
9
+ /data/*
10
+ !/data/litellm
11
+ /data/litellm/*
12
+ !data/litellm/config.yaml
13
+
14
  !data/config.json
backend/.gitignore CHANGED
@@ -1,12 +1,12 @@
1
- __pycache__
2
- .env
3
- _old
4
- uploads
5
- .ipynb_checkpoints
6
- *.db
7
- _test
8
- Pipfile
9
- !/data
10
- /data/*
11
- /open_webui/data/*
12
  .webui_secret_key
 
1
+ __pycache__
2
+ .env
3
+ _old
4
+ uploads
5
+ .ipynb_checkpoints
6
+ *.db
7
+ _test
8
+ Pipfile
9
+ !/data
10
+ /data/*
11
+ /open_webui/data/*
12
  .webui_secret_key
backend/open_webui/__init__.py CHANGED
@@ -1,77 +1,77 @@
1
- import base64
2
- import os
3
- import random
4
- from pathlib import Path
5
-
6
- import typer
7
- import uvicorn
8
-
9
- app = typer.Typer()
10
-
11
- KEY_FILE = Path.cwd() / ".webui_secret_key"
12
-
13
-
14
- @app.command()
15
- def serve(
16
- host: str = "0.0.0.0",
17
- port: int = 8080,
18
- ):
19
- os.environ["FROM_INIT_PY"] = "true"
20
- if os.getenv("WEBUI_SECRET_KEY") is None:
21
- typer.echo(
22
- "Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
23
- )
24
- if not KEY_FILE.exists():
25
- typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}")
26
- KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12)))
27
- typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}")
28
- os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text()
29
-
30
- if os.getenv("USE_CUDA_DOCKER", "false") == "true":
31
- typer.echo(
32
- "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
33
- )
34
- LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":")
35
- os.environ["LD_LIBRARY_PATH"] = ":".join(
36
- LD_LIBRARY_PATH
37
- + [
38
- "/usr/local/lib/python3.11/site-packages/torch/lib",
39
- "/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
40
- ]
41
- )
42
- try:
43
- import torch
44
-
45
- assert torch.cuda.is_available(), "CUDA not available"
46
- typer.echo("CUDA seems to be working")
47
- except Exception as e:
48
- typer.echo(
49
- "Error when testing CUDA but USE_CUDA_DOCKER is true. "
50
- "Resetting USE_CUDA_DOCKER to false and removing "
51
- f"LD_LIBRARY_PATH modifications: {e}"
52
- )
53
- os.environ["USE_CUDA_DOCKER"] = "false"
54
- os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
55
-
56
- import open_webui.main # we need set environment variables before importing main
57
-
58
- uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
59
-
60
-
61
- @app.command()
62
- def dev(
63
- host: str = "0.0.0.0",
64
- port: int = 8080,
65
- reload: bool = True,
66
- ):
67
- uvicorn.run(
68
- "open_webui.main:app",
69
- host=host,
70
- port=port,
71
- reload=reload,
72
- forwarded_allow_ips="*",
73
- )
74
-
75
-
76
- if __name__ == "__main__":
77
- app()
 
1
+ import base64
2
+ import os
3
+ import random
4
+ from pathlib import Path
5
+
6
+ import typer
7
+ import uvicorn
8
+
9
+ app = typer.Typer()
10
+
11
+ KEY_FILE = Path.cwd() / ".webui_secret_key"
12
+
13
+
14
+ @app.command()
15
+ def serve(
16
+ host: str = "0.0.0.0",
17
+ port: int = 8080,
18
+ ):
19
+ os.environ["FROM_INIT_PY"] = "true"
20
+ if os.getenv("WEBUI_SECRET_KEY") is None:
21
+ typer.echo(
22
+ "Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
23
+ )
24
+ if not KEY_FILE.exists():
25
+ typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}")
26
+ KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12)))
27
+ typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}")
28
+ os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text()
29
+
30
+ if os.getenv("USE_CUDA_DOCKER", "false") == "true":
31
+ typer.echo(
32
+ "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
33
+ )
34
+ LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":")
35
+ os.environ["LD_LIBRARY_PATH"] = ":".join(
36
+ LD_LIBRARY_PATH
37
+ + [
38
+ "/usr/local/lib/python3.11/site-packages/torch/lib",
39
+ "/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
40
+ ]
41
+ )
42
+ try:
43
+ import torch
44
+
45
+ assert torch.cuda.is_available(), "CUDA not available"
46
+ typer.echo("CUDA seems to be working")
47
+ except Exception as e:
48
+ typer.echo(
49
+ "Error when testing CUDA but USE_CUDA_DOCKER is true. "
50
+ "Resetting USE_CUDA_DOCKER to false and removing "
51
+ f"LD_LIBRARY_PATH modifications: {e}"
52
+ )
53
+ os.environ["USE_CUDA_DOCKER"] = "false"
54
+ os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
55
+
56
+ import open_webui.main # we need set environment variables before importing main
57
+
58
+ uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
59
+
60
+
61
+ @app.command()
62
+ def dev(
63
+ host: str = "0.0.0.0",
64
+ port: int = 8080,
65
+ reload: bool = True,
66
+ ):
67
+ uvicorn.run(
68
+ "open_webui.main:app",
69
+ host=host,
70
+ port=port,
71
+ reload=reload,
72
+ forwarded_allow_ips="*",
73
+ )
74
+
75
+
76
+ if __name__ == "__main__":
77
+ app()
backend/open_webui/alembic.ini CHANGED
@@ -1,114 +1,114 @@
1
- # A generic, single database configuration.
2
-
3
- [alembic]
4
- # path to migration scripts
5
- script_location = migrations
6
-
7
- # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
8
- # Uncomment the line below if you want the files to be prepended with date and time
9
- # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
10
-
11
- # sys.path path, will be prepended to sys.path if present.
12
- # defaults to the current working directory.
13
- prepend_sys_path = .
14
-
15
- # timezone to use when rendering the date within the migration file
16
- # as well as the filename.
17
- # If specified, requires the python>=3.9 or backports.zoneinfo library.
18
- # Any required deps can installed by adding `alembic[tz]` to the pip requirements
19
- # string value is passed to ZoneInfo()
20
- # leave blank for localtime
21
- # timezone =
22
-
23
- # max length of characters to apply to the
24
- # "slug" field
25
- # truncate_slug_length = 40
26
-
27
- # set to 'true' to run the environment during
28
- # the 'revision' command, regardless of autogenerate
29
- # revision_environment = false
30
-
31
- # set to 'true' to allow .pyc and .pyo files without
32
- # a source .py file to be detected as revisions in the
33
- # versions/ directory
34
- # sourceless = false
35
-
36
- # version location specification; This defaults
37
- # to migrations/versions. When using multiple version
38
- # directories, initial revisions must be specified with --version-path.
39
- # The path separator used here should be the separator specified by "version_path_separator" below.
40
- # version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
41
-
42
- # version path separator; As mentioned above, this is the character used to split
43
- # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
44
- # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
45
- # Valid values for version_path_separator are:
46
- #
47
- # version_path_separator = :
48
- # version_path_separator = ;
49
- # version_path_separator = space
50
- version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
51
-
52
- # set to 'true' to search source files recursively
53
- # in each "version_locations" directory
54
- # new in Alembic version 1.10
55
- # recursive_version_locations = false
56
-
57
- # the output encoding used when revision files
58
- # are written from script.py.mako
59
- # output_encoding = utf-8
60
-
61
- # sqlalchemy.url = REPLACE_WITH_DATABASE_URL
62
-
63
-
64
- [post_write_hooks]
65
- # post_write_hooks defines scripts or Python functions that are run
66
- # on newly generated revision scripts. See the documentation for further
67
- # detail and examples
68
-
69
- # format using "black" - use the console_scripts runner, against the "black" entrypoint
70
- # hooks = black
71
- # black.type = console_scripts
72
- # black.entrypoint = black
73
- # black.options = -l 79 REVISION_SCRIPT_FILENAME
74
-
75
- # lint with attempts to fix using "ruff" - use the exec runner, execute a binary
76
- # hooks = ruff
77
- # ruff.type = exec
78
- # ruff.executable = %(here)s/.venv/bin/ruff
79
- # ruff.options = --fix REVISION_SCRIPT_FILENAME
80
-
81
- # Logging configuration
82
- [loggers]
83
- keys = root,sqlalchemy,alembic
84
-
85
- [handlers]
86
- keys = console
87
-
88
- [formatters]
89
- keys = generic
90
-
91
- [logger_root]
92
- level = WARN
93
- handlers = console
94
- qualname =
95
-
96
- [logger_sqlalchemy]
97
- level = WARN
98
- handlers =
99
- qualname = sqlalchemy.engine
100
-
101
- [logger_alembic]
102
- level = INFO
103
- handlers =
104
- qualname = alembic
105
-
106
- [handler_console]
107
- class = StreamHandler
108
- args = (sys.stderr,)
109
- level = NOTSET
110
- formatter = generic
111
-
112
- [formatter_generic]
113
- format = %(levelname)-5.5s [%(name)s] %(message)s
114
- datefmt = %H:%M:%S
 
1
+ # A generic, single database configuration.
2
+
3
+ [alembic]
4
+ # path to migration scripts
5
+ script_location = migrations
6
+
7
+ # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
8
+ # Uncomment the line below if you want the files to be prepended with date and time
9
+ # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
10
+
11
+ # sys.path path, will be prepended to sys.path if present.
12
+ # defaults to the current working directory.
13
+ prepend_sys_path = .
14
+
15
+ # timezone to use when rendering the date within the migration file
16
+ # as well as the filename.
17
+ # If specified, requires the python>=3.9 or backports.zoneinfo library.
18
+ # Any required deps can installed by adding `alembic[tz]` to the pip requirements
19
+ # string value is passed to ZoneInfo()
20
+ # leave blank for localtime
21
+ # timezone =
22
+
23
+ # max length of characters to apply to the
24
+ # "slug" field
25
+ # truncate_slug_length = 40
26
+
27
+ # set to 'true' to run the environment during
28
+ # the 'revision' command, regardless of autogenerate
29
+ # revision_environment = false
30
+
31
+ # set to 'true' to allow .pyc and .pyo files without
32
+ # a source .py file to be detected as revisions in the
33
+ # versions/ directory
34
+ # sourceless = false
35
+
36
+ # version location specification; This defaults
37
+ # to migrations/versions. When using multiple version
38
+ # directories, initial revisions must be specified with --version-path.
39
+ # The path separator used here should be the separator specified by "version_path_separator" below.
40
+ # version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
41
+
42
+ # version path separator; As mentioned above, this is the character used to split
43
+ # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
44
+ # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
45
+ # Valid values for version_path_separator are:
46
+ #
47
+ # version_path_separator = :
48
+ # version_path_separator = ;
49
+ # version_path_separator = space
50
+ version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
51
+
52
+ # set to 'true' to search source files recursively
53
+ # in each "version_locations" directory
54
+ # new in Alembic version 1.10
55
+ # recursive_version_locations = false
56
+
57
+ # the output encoding used when revision files
58
+ # are written from script.py.mako
59
+ # output_encoding = utf-8
60
+
61
+ # sqlalchemy.url = REPLACE_WITH_DATABASE_URL
62
+
63
+
64
+ [post_write_hooks]
65
+ # post_write_hooks defines scripts or Python functions that are run
66
+ # on newly generated revision scripts. See the documentation for further
67
+ # detail and examples
68
+
69
+ # format using "black" - use the console_scripts runner, against the "black" entrypoint
70
+ # hooks = black
71
+ # black.type = console_scripts
72
+ # black.entrypoint = black
73
+ # black.options = -l 79 REVISION_SCRIPT_FILENAME
74
+
75
+ # lint with attempts to fix using "ruff" - use the exec runner, execute a binary
76
+ # hooks = ruff
77
+ # ruff.type = exec
78
+ # ruff.executable = %(here)s/.venv/bin/ruff
79
+ # ruff.options = --fix REVISION_SCRIPT_FILENAME
80
+
81
+ # Logging configuration
82
+ [loggers]
83
+ keys = root,sqlalchemy,alembic
84
+
85
+ [handlers]
86
+ keys = console
87
+
88
+ [formatters]
89
+ keys = generic
90
+
91
+ [logger_root]
92
+ level = WARN
93
+ handlers = console
94
+ qualname =
95
+
96
+ [logger_sqlalchemy]
97
+ level = WARN
98
+ handlers =
99
+ qualname = sqlalchemy.engine
100
+
101
+ [logger_alembic]
102
+ level = INFO
103
+ handlers =
104
+ qualname = alembic
105
+
106
+ [handler_console]
107
+ class = StreamHandler
108
+ args = (sys.stderr,)
109
+ level = NOTSET
110
+ formatter = generic
111
+
112
+ [formatter_generic]
113
+ format = %(levelname)-5.5s [%(name)s] %(message)s
114
+ datefmt = %H:%M:%S
backend/open_webui/apps/audio/main.py CHANGED
@@ -1,639 +1,639 @@
1
- import hashlib
2
- import json
3
- import logging
4
- import os
5
- import uuid
6
- from functools import lru_cache
7
- from pathlib import Path
8
- from pydub import AudioSegment
9
- from pydub.silence import split_on_silence
10
-
11
- import requests
12
- from open_webui.config import (
13
- AUDIO_STT_ENGINE,
14
- AUDIO_STT_MODEL,
15
- AUDIO_STT_OPENAI_API_BASE_URL,
16
- AUDIO_STT_OPENAI_API_KEY,
17
- AUDIO_TTS_API_KEY,
18
- AUDIO_TTS_ENGINE,
19
- AUDIO_TTS_MODEL,
20
- AUDIO_TTS_OPENAI_API_BASE_URL,
21
- AUDIO_TTS_OPENAI_API_KEY,
22
- AUDIO_TTS_SPLIT_ON,
23
- AUDIO_TTS_VOICE,
24
- AUDIO_TTS_AZURE_SPEECH_REGION,
25
- AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
26
- CACHE_DIR,
27
- CORS_ALLOW_ORIGIN,
28
- WHISPER_MODEL,
29
- WHISPER_MODEL_AUTO_UPDATE,
30
- WHISPER_MODEL_DIR,
31
- AppConfig,
32
- )
33
-
34
- from open_webui.constants import ERROR_MESSAGES
35
- from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE
36
- from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
37
- from fastapi.middleware.cors import CORSMiddleware
38
- from fastapi.responses import FileResponse
39
- from pydantic import BaseModel
40
- from open_webui.utils.utils import get_admin_user, get_verified_user
41
-
42
- # Constants
43
- MAX_FILE_SIZE_MB = 25
44
- MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
45
-
46
-
47
- log = logging.getLogger(__name__)
48
- log.setLevel(SRC_LOG_LEVELS["AUDIO"])
49
-
50
- app = FastAPI()
51
- app.add_middleware(
52
- CORSMiddleware,
53
- allow_origins=CORS_ALLOW_ORIGIN,
54
- allow_credentials=True,
55
- allow_methods=["*"],
56
- allow_headers=["*"],
57
- )
58
-
59
- app.state.config = AppConfig()
60
-
61
- app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
62
- app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
63
- app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
64
- app.state.config.STT_MODEL = AUDIO_STT_MODEL
65
-
66
- app.state.config.WHISPER_MODEL = WHISPER_MODEL
67
- app.state.faster_whisper_model = None
68
-
69
- app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
70
- app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
71
- app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
72
- app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
73
- app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
74
- app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
75
- app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
76
-
77
- app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
78
- app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
79
-
80
- # setting device type for whisper model
81
- whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
82
- log.info(f"whisper_device_type: {whisper_device_type}")
83
-
84
- SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
85
- SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
86
-
87
-
88
- def set_faster_whisper_model(model: str, auto_update: bool = False):
89
- if model and app.state.config.STT_ENGINE == "":
90
- from faster_whisper import WhisperModel
91
-
92
- faster_whisper_kwargs = {
93
- "model_size_or_path": model,
94
- "device": whisper_device_type,
95
- "compute_type": "int8",
96
- "download_root": WHISPER_MODEL_DIR,
97
- "local_files_only": not auto_update,
98
- }
99
-
100
- try:
101
- app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
102
- except Exception:
103
- log.warning(
104
- "WhisperModel initialization failed, attempting download with local_files_only=False"
105
- )
106
- faster_whisper_kwargs["local_files_only"] = False
107
- app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
108
-
109
- else:
110
- app.state.faster_whisper_model = None
111
-
112
-
113
- class TTSConfigForm(BaseModel):
114
- OPENAI_API_BASE_URL: str
115
- OPENAI_API_KEY: str
116
- API_KEY: str
117
- ENGINE: str
118
- MODEL: str
119
- VOICE: str
120
- SPLIT_ON: str
121
- AZURE_SPEECH_REGION: str
122
- AZURE_SPEECH_OUTPUT_FORMAT: str
123
-
124
-
125
- class STTConfigForm(BaseModel):
126
- OPENAI_API_BASE_URL: str
127
- OPENAI_API_KEY: str
128
- ENGINE: str
129
- MODEL: str
130
- WHISPER_MODEL: str
131
-
132
-
133
- class AudioConfigUpdateForm(BaseModel):
134
- tts: TTSConfigForm
135
- stt: STTConfigForm
136
-
137
-
138
- from pydub import AudioSegment
139
- from pydub.utils import mediainfo
140
-
141
-
142
- def is_mp4_audio(file_path):
143
- """Check if the given file is an MP4 audio file."""
144
- if not os.path.isfile(file_path):
145
- print(f"File not found: {file_path}")
146
- return False
147
-
148
- info = mediainfo(file_path)
149
- if (
150
- info.get("codec_name") == "aac"
151
- and info.get("codec_type") == "audio"
152
- and info.get("codec_tag_string") == "mp4a"
153
- ):
154
- return True
155
- return False
156
-
157
-
158
- def convert_mp4_to_wav(file_path, output_path):
159
- """Convert MP4 audio file to WAV format."""
160
- audio = AudioSegment.from_file(file_path, format="mp4")
161
- audio.export(output_path, format="wav")
162
- print(f"Converted {file_path} to {output_path}")
163
-
164
-
165
- @app.get("/config")
166
- async def get_audio_config(user=Depends(get_admin_user)):
167
- return {
168
- "tts": {
169
- "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
170
- "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
171
- "API_KEY": app.state.config.TTS_API_KEY,
172
- "ENGINE": app.state.config.TTS_ENGINE,
173
- "MODEL": app.state.config.TTS_MODEL,
174
- "VOICE": app.state.config.TTS_VOICE,
175
- "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
176
- "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
177
- "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
178
- },
179
- "stt": {
180
- "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
181
- "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
182
- "ENGINE": app.state.config.STT_ENGINE,
183
- "MODEL": app.state.config.STT_MODEL,
184
- "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
185
- },
186
- }
187
-
188
-
189
- @app.post("/config/update")
190
- async def update_audio_config(
191
- form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
192
- ):
193
- app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
194
- app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
195
- app.state.config.TTS_API_KEY = form_data.tts.API_KEY
196
- app.state.config.TTS_ENGINE = form_data.tts.ENGINE
197
- app.state.config.TTS_MODEL = form_data.tts.MODEL
198
- app.state.config.TTS_VOICE = form_data.tts.VOICE
199
- app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
200
- app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
201
- app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
202
- form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
203
- )
204
-
205
- app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
206
- app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
207
- app.state.config.STT_ENGINE = form_data.stt.ENGINE
208
- app.state.config.STT_MODEL = form_data.stt.MODEL
209
- app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
210
- set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
211
-
212
- return {
213
- "tts": {
214
- "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
215
- "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
216
- "API_KEY": app.state.config.TTS_API_KEY,
217
- "ENGINE": app.state.config.TTS_ENGINE,
218
- "MODEL": app.state.config.TTS_MODEL,
219
- "VOICE": app.state.config.TTS_VOICE,
220
- "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
221
- "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
222
- "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
223
- },
224
- "stt": {
225
- "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
226
- "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
227
- "ENGINE": app.state.config.STT_ENGINE,
228
- "MODEL": app.state.config.STT_MODEL,
229
- "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
230
- },
231
- }
232
-
233
-
234
- @app.post("/speech")
235
- async def speech(request: Request, user=Depends(get_verified_user)):
236
- body = await request.body()
237
- name = hashlib.sha256(body).hexdigest()
238
-
239
- file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
240
- file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
241
-
242
- # Check if the file already exists in the cache
243
- if file_path.is_file():
244
- return FileResponse(file_path)
245
-
246
- if app.state.config.TTS_ENGINE == "openai":
247
- headers = {}
248
- headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
249
- headers["Content-Type"] = "application/json"
250
-
251
- try:
252
- body = body.decode("utf-8")
253
- body = json.loads(body)
254
- body["model"] = app.state.config.TTS_MODEL
255
- body = json.dumps(body).encode("utf-8")
256
- except Exception:
257
- pass
258
-
259
- r = None
260
- try:
261
- r = requests.post(
262
- url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
263
- data=body,
264
- headers=headers,
265
- stream=True,
266
- )
267
-
268
- r.raise_for_status()
269
-
270
- # Save the streaming content to a file
271
- with open(file_path, "wb") as f:
272
- for chunk in r.iter_content(chunk_size=8192):
273
- f.write(chunk)
274
-
275
- with open(file_body_path, "w") as f:
276
- json.dump(json.loads(body.decode("utf-8")), f)
277
-
278
- # Return the saved file
279
- return FileResponse(file_path)
280
-
281
- except Exception as e:
282
- log.exception(e)
283
- error_detail = "Open WebUI: Server Connection Error"
284
- if r is not None:
285
- try:
286
- res = r.json()
287
- if "error" in res:
288
- error_detail = f"External: {res['error']['message']}"
289
- except Exception:
290
- error_detail = f"External: {e}"
291
-
292
- raise HTTPException(
293
- status_code=r.status_code if r != None else 500,
294
- detail=error_detail,
295
- )
296
-
297
- elif app.state.config.TTS_ENGINE == "elevenlabs":
298
- payload = None
299
- try:
300
- payload = json.loads(body.decode("utf-8"))
301
- except Exception as e:
302
- log.exception(e)
303
- raise HTTPException(status_code=400, detail="Invalid JSON payload")
304
-
305
- voice_id = payload.get("voice", "")
306
-
307
- if voice_id not in get_available_voices():
308
- raise HTTPException(
309
- status_code=400,
310
- detail="Invalid voice id",
311
- )
312
-
313
- url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
314
-
315
- headers = {
316
- "Accept": "audio/mpeg",
317
- "Content-Type": "application/json",
318
- "xi-api-key": app.state.config.TTS_API_KEY,
319
- }
320
-
321
- data = {
322
- "text": payload["input"],
323
- "model_id": app.state.config.TTS_MODEL,
324
- "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
325
- }
326
-
327
- try:
328
- r = requests.post(url, json=data, headers=headers)
329
-
330
- r.raise_for_status()
331
-
332
- # Save the streaming content to a file
333
- with open(file_path, "wb") as f:
334
- for chunk in r.iter_content(chunk_size=8192):
335
- f.write(chunk)
336
-
337
- with open(file_body_path, "w") as f:
338
- json.dump(json.loads(body.decode("utf-8")), f)
339
-
340
- # Return the saved file
341
- return FileResponse(file_path)
342
-
343
- except Exception as e:
344
- log.exception(e)
345
- error_detail = "Open WebUI: Server Connection Error"
346
- if r is not None:
347
- try:
348
- res = r.json()
349
- if "error" in res:
350
- error_detail = f"External: {res['error']['message']}"
351
- except Exception:
352
- error_detail = f"External: {e}"
353
-
354
- raise HTTPException(
355
- status_code=r.status_code if r != None else 500,
356
- detail=error_detail,
357
- )
358
-
359
- elif app.state.config.TTS_ENGINE == "azure":
360
- payload = None
361
- try:
362
- payload = json.loads(body.decode("utf-8"))
363
- except Exception as e:
364
- log.exception(e)
365
- raise HTTPException(status_code=400, detail="Invalid JSON payload")
366
-
367
- region = app.state.config.TTS_AZURE_SPEECH_REGION
368
- language = app.state.config.TTS_VOICE
369
- locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
370
- output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
371
- url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
372
-
373
- headers = {
374
- "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
375
- "Content-Type": "application/ssml+xml",
376
- "X-Microsoft-OutputFormat": output_format,
377
- }
378
-
379
- data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
380
- <voice name="{language}">{payload["input"]}</voice>
381
- </speak>"""
382
-
383
- response = requests.post(url, headers=headers, data=data)
384
-
385
- if response.status_code == 200:
386
- with open(file_path, "wb") as f:
387
- f.write(response.content)
388
- return FileResponse(file_path)
389
- else:
390
- log.error(f"Error synthesizing speech - {response.reason}")
391
- raise HTTPException(
392
- status_code=500, detail=f"Error synthesizing speech - {response.reason}"
393
- )
394
-
395
-
396
- def transcribe(file_path):
397
- print("transcribe", file_path)
398
- filename = os.path.basename(file_path)
399
- file_dir = os.path.dirname(file_path)
400
- id = filename.split(".")[0]
401
-
402
- if app.state.config.STT_ENGINE == "":
403
- if app.state.faster_whisper_model is None:
404
- set_faster_whisper_model(app.state.config.WHISPER_MODEL)
405
-
406
- model = app.state.faster_whisper_model
407
- segments, info = model.transcribe(file_path, beam_size=5)
408
- log.info(
409
- "Detected language '%s' with probability %f"
410
- % (info.language, info.language_probability)
411
- )
412
-
413
- transcript = "".join([segment.text for segment in list(segments)])
414
- data = {"text": transcript.strip()}
415
-
416
- # save the transcript to a json file
417
- transcript_file = f"{file_dir}/{id}.json"
418
- with open(transcript_file, "w") as f:
419
- json.dump(data, f)
420
-
421
- log.debug(data)
422
- return data
423
- elif app.state.config.STT_ENGINE == "openai":
424
- if is_mp4_audio(file_path):
425
- print("is_mp4_audio")
426
- os.rename(file_path, file_path.replace(".wav", ".mp4"))
427
- # Convert MP4 audio file to WAV format
428
- convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
429
-
430
- headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
431
-
432
- files = {"file": (filename, open(file_path, "rb"))}
433
- data = {"model": app.state.config.STT_MODEL}
434
-
435
- log.debug(files, data)
436
-
437
- r = None
438
- try:
439
- r = requests.post(
440
- url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
441
- headers=headers,
442
- files=files,
443
- data=data,
444
- )
445
-
446
- r.raise_for_status()
447
-
448
- data = r.json()
449
-
450
- # save the transcript to a json file
451
- transcript_file = f"{file_dir}/{id}.json"
452
- with open(transcript_file, "w") as f:
453
- json.dump(data, f)
454
-
455
- print(data)
456
- return data
457
- except Exception as e:
458
- log.exception(e)
459
- error_detail = "Open WebUI: Server Connection Error"
460
- if r is not None:
461
- try:
462
- res = r.json()
463
- if "error" in res:
464
- error_detail = f"External: {res['error']['message']}"
465
- except Exception:
466
- error_detail = f"External: {e}"
467
-
468
- raise Exception(error_detail)
469
-
470
-
471
- @app.post("/transcriptions")
472
- def transcription(
473
- file: UploadFile = File(...),
474
- user=Depends(get_verified_user),
475
- ):
476
- log.info(f"file.content_type: {file.content_type}")
477
-
478
- if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
479
- raise HTTPException(
480
- status_code=status.HTTP_400_BAD_REQUEST,
481
- detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
482
- )
483
-
484
- try:
485
- ext = file.filename.split(".")[-1]
486
- id = uuid.uuid4()
487
-
488
- filename = f"{id}.{ext}"
489
- contents = file.file.read()
490
-
491
- file_dir = f"{CACHE_DIR}/audio/transcriptions"
492
- os.makedirs(file_dir, exist_ok=True)
493
- file_path = f"{file_dir}/{filename}"
494
-
495
- with open(file_path, "wb") as f:
496
- f.write(contents)
497
-
498
- try:
499
- if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB
500
- log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
501
- audio = AudioSegment.from_file(file_path)
502
- audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
503
- compressed_path = f"{file_dir}/{id}_compressed.opus"
504
- audio.export(compressed_path, format="opus", bitrate="32k")
505
- log.debug(f"Compressed audio to {compressed_path}")
506
- file_path = compressed_path
507
-
508
- if (
509
- os.path.getsize(file_path) > MAX_FILE_SIZE
510
- ): # Still larger than 25MB after compression
511
- log.debug(
512
- f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
513
- )
514
- raise HTTPException(
515
- status_code=status.HTTP_400_BAD_REQUEST,
516
- detail=ERROR_MESSAGES.FILE_TOO_LARGE(
517
- size=f"{MAX_FILE_SIZE_MB}MB"
518
- ),
519
- )
520
-
521
- data = transcribe(file_path)
522
- else:
523
- data = transcribe(file_path)
524
-
525
- return data
526
- except Exception as e:
527
- log.exception(e)
528
- raise HTTPException(
529
- status_code=status.HTTP_400_BAD_REQUEST,
530
- detail=ERROR_MESSAGES.DEFAULT(e),
531
- )
532
-
533
- except Exception as e:
534
- log.exception(e)
535
-
536
- raise HTTPException(
537
- status_code=status.HTTP_400_BAD_REQUEST,
538
- detail=ERROR_MESSAGES.DEFAULT(e),
539
- )
540
-
541
-
542
- def get_available_models() -> list[dict]:
543
- if app.state.config.TTS_ENGINE == "openai":
544
- return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
545
- elif app.state.config.TTS_ENGINE == "elevenlabs":
546
- headers = {
547
- "xi-api-key": app.state.config.TTS_API_KEY,
548
- "Content-Type": "application/json",
549
- }
550
-
551
- try:
552
- response = requests.get(
553
- "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
554
- )
555
- response.raise_for_status()
556
- models = response.json()
557
- return [
558
- {"name": model["name"], "id": model["model_id"]} for model in models
559
- ]
560
- except requests.RequestException as e:
561
- log.error(f"Error fetching voices: {str(e)}")
562
- return []
563
-
564
-
565
- @app.get("/models")
566
- async def get_models(user=Depends(get_verified_user)):
567
- return {"models": get_available_models()}
568
-
569
-
570
- def get_available_voices() -> dict:
571
- """Returns {voice_id: voice_name} dict"""
572
- ret = {}
573
- if app.state.config.TTS_ENGINE == "openai":
574
- ret = {
575
- "alloy": "alloy",
576
- "echo": "echo",
577
- "fable": "fable",
578
- "onyx": "onyx",
579
- "nova": "nova",
580
- "shimmer": "shimmer",
581
- }
582
- elif app.state.config.TTS_ENGINE == "elevenlabs":
583
- try:
584
- ret = get_elevenlabs_voices()
585
- except Exception:
586
- # Avoided @lru_cache with exception
587
- pass
588
- elif app.state.config.TTS_ENGINE == "azure":
589
- try:
590
- region = app.state.config.TTS_AZURE_SPEECH_REGION
591
- url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
592
- headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
593
-
594
- response = requests.get(url, headers=headers)
595
- response.raise_for_status()
596
- voices = response.json()
597
- for voice in voices:
598
- ret[voice["ShortName"]] = (
599
- f"{voice['DisplayName']} ({voice['ShortName']})"
600
- )
601
- except requests.RequestException as e:
602
- log.error(f"Error fetching voices: {str(e)}")
603
-
604
- return ret
605
-
606
-
607
- @lru_cache
608
- def get_elevenlabs_voices() -> dict:
609
- """
610
- Note, set the following in your .env file to use Elevenlabs:
611
- AUDIO_TTS_ENGINE=elevenlabs
612
- AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
613
- AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
614
- AUDIO_TTS_MODEL=eleven_multilingual_v2
615
- """
616
- headers = {
617
- "xi-api-key": app.state.config.TTS_API_KEY,
618
- "Content-Type": "application/json",
619
- }
620
- try:
621
- # TODO: Add retries
622
- response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
623
- response.raise_for_status()
624
- voices_data = response.json()
625
-
626
- voices = {}
627
- for voice in voices_data.get("voices", []):
628
- voices[voice["voice_id"]] = voice["name"]
629
- except requests.RequestException as e:
630
- # Avoid @lru_cache with exception
631
- log.error(f"Error fetching voices: {str(e)}")
632
- raise RuntimeError(f"Error fetching voices: {str(e)}")
633
-
634
- return voices
635
-
636
-
637
- @app.get("/voices")
638
- async def get_voices(user=Depends(get_verified_user)):
639
- return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}
 
1
+ import hashlib
2
+ import json
3
+ import logging
4
+ import os
5
+ import uuid
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+ from pydub import AudioSegment
9
+ from pydub.silence import split_on_silence
10
+
11
+ import requests
12
+ from open_webui.config import (
13
+ AUDIO_STT_ENGINE,
14
+ AUDIO_STT_MODEL,
15
+ AUDIO_STT_OPENAI_API_BASE_URL,
16
+ AUDIO_STT_OPENAI_API_KEY,
17
+ AUDIO_TTS_API_KEY,
18
+ AUDIO_TTS_ENGINE,
19
+ AUDIO_TTS_MODEL,
20
+ AUDIO_TTS_OPENAI_API_BASE_URL,
21
+ AUDIO_TTS_OPENAI_API_KEY,
22
+ AUDIO_TTS_SPLIT_ON,
23
+ AUDIO_TTS_VOICE,
24
+ AUDIO_TTS_AZURE_SPEECH_REGION,
25
+ AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
26
+ CACHE_DIR,
27
+ CORS_ALLOW_ORIGIN,
28
+ WHISPER_MODEL,
29
+ WHISPER_MODEL_AUTO_UPDATE,
30
+ WHISPER_MODEL_DIR,
31
+ AppConfig,
32
+ )
33
+
34
+ from open_webui.constants import ERROR_MESSAGES
35
+ from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE
36
+ from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
37
+ from fastapi.middleware.cors import CORSMiddleware
38
+ from fastapi.responses import FileResponse
39
+ from pydantic import BaseModel
40
+ from open_webui.utils.utils import get_admin_user, get_verified_user
41
+
42
+ # Constants
43
+ MAX_FILE_SIZE_MB = 25
44
+ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
45
+
46
+
47
+ log = logging.getLogger(__name__)
48
+ log.setLevel(SRC_LOG_LEVELS["AUDIO"])
49
+
50
+ app = FastAPI()
51
+ app.add_middleware(
52
+ CORSMiddleware,
53
+ allow_origins=CORS_ALLOW_ORIGIN,
54
+ allow_credentials=True,
55
+ allow_methods=["*"],
56
+ allow_headers=["*"],
57
+ )
58
+
59
+ app.state.config = AppConfig()
60
+
61
+ app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
62
+ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
63
+ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
64
+ app.state.config.STT_MODEL = AUDIO_STT_MODEL
65
+
66
+ app.state.config.WHISPER_MODEL = WHISPER_MODEL
67
+ app.state.faster_whisper_model = None
68
+
69
+ app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
70
+ app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
71
+ app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
72
+ app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
73
+ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
74
+ app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
75
+ app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
76
+
77
+ app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
78
+ app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
79
+
80
+ # setting device type for whisper model
81
+ whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
82
+ log.info(f"whisper_device_type: {whisper_device_type}")
83
+
84
+ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
85
+ SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
86
+
87
+
88
+ def set_faster_whisper_model(model: str, auto_update: bool = False):
89
+ if model and app.state.config.STT_ENGINE == "":
90
+ from faster_whisper import WhisperModel
91
+
92
+ faster_whisper_kwargs = {
93
+ "model_size_or_path": model,
94
+ "device": whisper_device_type,
95
+ "compute_type": "int8",
96
+ "download_root": WHISPER_MODEL_DIR,
97
+ "local_files_only": not auto_update,
98
+ }
99
+
100
+ try:
101
+ app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
102
+ except Exception:
103
+ log.warning(
104
+ "WhisperModel initialization failed, attempting download with local_files_only=False"
105
+ )
106
+ faster_whisper_kwargs["local_files_only"] = False
107
+ app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
108
+
109
+ else:
110
+ app.state.faster_whisper_model = None
111
+
112
+
113
+ class TTSConfigForm(BaseModel):
114
+ OPENAI_API_BASE_URL: str
115
+ OPENAI_API_KEY: str
116
+ API_KEY: str
117
+ ENGINE: str
118
+ MODEL: str
119
+ VOICE: str
120
+ SPLIT_ON: str
121
+ AZURE_SPEECH_REGION: str
122
+ AZURE_SPEECH_OUTPUT_FORMAT: str
123
+
124
+
125
+ class STTConfigForm(BaseModel):
126
+ OPENAI_API_BASE_URL: str
127
+ OPENAI_API_KEY: str
128
+ ENGINE: str
129
+ MODEL: str
130
+ WHISPER_MODEL: str
131
+
132
+
133
+ class AudioConfigUpdateForm(BaseModel):
134
+ tts: TTSConfigForm
135
+ stt: STTConfigForm
136
+
137
+
138
+ from pydub import AudioSegment
139
+ from pydub.utils import mediainfo
140
+
141
+
142
+ def is_mp4_audio(file_path):
143
+ """Check if the given file is an MP4 audio file."""
144
+ if not os.path.isfile(file_path):
145
+ print(f"File not found: {file_path}")
146
+ return False
147
+
148
+ info = mediainfo(file_path)
149
+ if (
150
+ info.get("codec_name") == "aac"
151
+ and info.get("codec_type") == "audio"
152
+ and info.get("codec_tag_string") == "mp4a"
153
+ ):
154
+ return True
155
+ return False
156
+
157
+
158
+ def convert_mp4_to_wav(file_path, output_path):
159
+ """Convert MP4 audio file to WAV format."""
160
+ audio = AudioSegment.from_file(file_path, format="mp4")
161
+ audio.export(output_path, format="wav")
162
+ print(f"Converted {file_path} to {output_path}")
163
+
164
+
165
+ @app.get("/config")
166
+ async def get_audio_config(user=Depends(get_admin_user)):
167
+ return {
168
+ "tts": {
169
+ "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
170
+ "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
171
+ "API_KEY": app.state.config.TTS_API_KEY,
172
+ "ENGINE": app.state.config.TTS_ENGINE,
173
+ "MODEL": app.state.config.TTS_MODEL,
174
+ "VOICE": app.state.config.TTS_VOICE,
175
+ "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
176
+ "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
177
+ "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
178
+ },
179
+ "stt": {
180
+ "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
181
+ "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
182
+ "ENGINE": app.state.config.STT_ENGINE,
183
+ "MODEL": app.state.config.STT_MODEL,
184
+ "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
185
+ },
186
+ }
187
+
188
+
189
+ @app.post("/config/update")
190
+ async def update_audio_config(
191
+ form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
192
+ ):
193
+ app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
194
+ app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
195
+ app.state.config.TTS_API_KEY = form_data.tts.API_KEY
196
+ app.state.config.TTS_ENGINE = form_data.tts.ENGINE
197
+ app.state.config.TTS_MODEL = form_data.tts.MODEL
198
+ app.state.config.TTS_VOICE = form_data.tts.VOICE
199
+ app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
200
+ app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
201
+ app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
202
+ form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
203
+ )
204
+
205
+ app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
206
+ app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
207
+ app.state.config.STT_ENGINE = form_data.stt.ENGINE
208
+ app.state.config.STT_MODEL = form_data.stt.MODEL
209
+ app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
210
+ set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
211
+
212
+ return {
213
+ "tts": {
214
+ "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
215
+ "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
216
+ "API_KEY": app.state.config.TTS_API_KEY,
217
+ "ENGINE": app.state.config.TTS_ENGINE,
218
+ "MODEL": app.state.config.TTS_MODEL,
219
+ "VOICE": app.state.config.TTS_VOICE,
220
+ "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
221
+ "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
222
+ "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
223
+ },
224
+ "stt": {
225
+ "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
226
+ "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
227
+ "ENGINE": app.state.config.STT_ENGINE,
228
+ "MODEL": app.state.config.STT_MODEL,
229
+ "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
230
+ },
231
+ }
232
+
233
+
234
+ @app.post("/speech")
235
+ async def speech(request: Request, user=Depends(get_verified_user)):
236
+ body = await request.body()
237
+ name = hashlib.sha256(body).hexdigest()
238
+
239
+ file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
240
+ file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
241
+
242
+ # Check if the file already exists in the cache
243
+ if file_path.is_file():
244
+ return FileResponse(file_path)
245
+
246
+ if app.state.config.TTS_ENGINE == "openai":
247
+ headers = {}
248
+ headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
249
+ headers["Content-Type"] = "application/json"
250
+
251
+ try:
252
+ body = body.decode("utf-8")
253
+ body = json.loads(body)
254
+ body["model"] = app.state.config.TTS_MODEL
255
+ body = json.dumps(body).encode("utf-8")
256
+ except Exception:
257
+ pass
258
+
259
+ r = None
260
+ try:
261
+ r = requests.post(
262
+ url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
263
+ data=body,
264
+ headers=headers,
265
+ stream=True,
266
+ )
267
+
268
+ r.raise_for_status()
269
+
270
+ # Save the streaming content to a file
271
+ with open(file_path, "wb") as f:
272
+ for chunk in r.iter_content(chunk_size=8192):
273
+ f.write(chunk)
274
+
275
+ with open(file_body_path, "w") as f:
276
+ json.dump(json.loads(body.decode("utf-8")), f)
277
+
278
+ # Return the saved file
279
+ return FileResponse(file_path)
280
+
281
+ except Exception as e:
282
+ log.exception(e)
283
+ error_detail = "Open WebUI: Server Connection Error"
284
+ if r is not None:
285
+ try:
286
+ res = r.json()
287
+ if "error" in res:
288
+ error_detail = f"External: {res['error']['message']}"
289
+ except Exception:
290
+ error_detail = f"External: {e}"
291
+
292
+ raise HTTPException(
293
+ status_code=r.status_code if r != None else 500,
294
+ detail=error_detail,
295
+ )
296
+
297
+ elif app.state.config.TTS_ENGINE == "elevenlabs":
298
+ payload = None
299
+ try:
300
+ payload = json.loads(body.decode("utf-8"))
301
+ except Exception as e:
302
+ log.exception(e)
303
+ raise HTTPException(status_code=400, detail="Invalid JSON payload")
304
+
305
+ voice_id = payload.get("voice", "")
306
+
307
+ if voice_id not in get_available_voices():
308
+ raise HTTPException(
309
+ status_code=400,
310
+ detail="Invalid voice id",
311
+ )
312
+
313
+ url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
314
+
315
+ headers = {
316
+ "Accept": "audio/mpeg",
317
+ "Content-Type": "application/json",
318
+ "xi-api-key": app.state.config.TTS_API_KEY,
319
+ }
320
+
321
+ data = {
322
+ "text": payload["input"],
323
+ "model_id": app.state.config.TTS_MODEL,
324
+ "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
325
+ }
326
+
327
+ try:
328
+ r = requests.post(url, json=data, headers=headers)
329
+
330
+ r.raise_for_status()
331
+
332
+ # Save the streaming content to a file
333
+ with open(file_path, "wb") as f:
334
+ for chunk in r.iter_content(chunk_size=8192):
335
+ f.write(chunk)
336
+
337
+ with open(file_body_path, "w") as f:
338
+ json.dump(json.loads(body.decode("utf-8")), f)
339
+
340
+ # Return the saved file
341
+ return FileResponse(file_path)
342
+
343
+ except Exception as e:
344
+ log.exception(e)
345
+ error_detail = "Open WebUI: Server Connection Error"
346
+ if r is not None:
347
+ try:
348
+ res = r.json()
349
+ if "error" in res:
350
+ error_detail = f"External: {res['error']['message']}"
351
+ except Exception:
352
+ error_detail = f"External: {e}"
353
+
354
+ raise HTTPException(
355
+ status_code=r.status_code if r != None else 500,
356
+ detail=error_detail,
357
+ )
358
+
359
+ elif app.state.config.TTS_ENGINE == "azure":
360
+ payload = None
361
+ try:
362
+ payload = json.loads(body.decode("utf-8"))
363
+ except Exception as e:
364
+ log.exception(e)
365
+ raise HTTPException(status_code=400, detail="Invalid JSON payload")
366
+
367
+ region = app.state.config.TTS_AZURE_SPEECH_REGION
368
+ language = app.state.config.TTS_VOICE
369
+ locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
370
+ output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
371
+ url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
372
+
373
+ headers = {
374
+ "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
375
+ "Content-Type": "application/ssml+xml",
376
+ "X-Microsoft-OutputFormat": output_format,
377
+ }
378
+
379
+ data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
380
+ <voice name="{language}">{payload["input"]}</voice>
381
+ </speak>"""
382
+
383
+ response = requests.post(url, headers=headers, data=data)
384
+
385
+ if response.status_code == 200:
386
+ with open(file_path, "wb") as f:
387
+ f.write(response.content)
388
+ return FileResponse(file_path)
389
+ else:
390
+ log.error(f"Error synthesizing speech - {response.reason}")
391
+ raise HTTPException(
392
+ status_code=500, detail=f"Error synthesizing speech - {response.reason}"
393
+ )
394
+
395
+
396
+ def transcribe(file_path):
397
+ print("transcribe", file_path)
398
+ filename = os.path.basename(file_path)
399
+ file_dir = os.path.dirname(file_path)
400
+ id = filename.split(".")[0]
401
+
402
+ if app.state.config.STT_ENGINE == "":
403
+ if app.state.faster_whisper_model is None:
404
+ set_faster_whisper_model(app.state.config.WHISPER_MODEL)
405
+
406
+ model = app.state.faster_whisper_model
407
+ segments, info = model.transcribe(file_path, beam_size=5)
408
+ log.info(
409
+ "Detected language '%s' with probability %f"
410
+ % (info.language, info.language_probability)
411
+ )
412
+
413
+ transcript = "".join([segment.text for segment in list(segments)])
414
+ data = {"text": transcript.strip()}
415
+
416
+ # save the transcript to a json file
417
+ transcript_file = f"{file_dir}/{id}.json"
418
+ with open(transcript_file, "w") as f:
419
+ json.dump(data, f)
420
+
421
+ log.debug(data)
422
+ return data
423
+ elif app.state.config.STT_ENGINE == "openai":
424
+ if is_mp4_audio(file_path):
425
+ print("is_mp4_audio")
426
+ os.rename(file_path, file_path.replace(".wav", ".mp4"))
427
+ # Convert MP4 audio file to WAV format
428
+ convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
429
+
430
+ headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
431
+
432
+ files = {"file": (filename, open(file_path, "rb"))}
433
+ data = {"model": app.state.config.STT_MODEL}
434
+
435
+ log.debug(files, data)
436
+
437
+ r = None
438
+ try:
439
+ r = requests.post(
440
+ url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
441
+ headers=headers,
442
+ files=files,
443
+ data=data,
444
+ )
445
+
446
+ r.raise_for_status()
447
+
448
+ data = r.json()
449
+
450
+ # save the transcript to a json file
451
+ transcript_file = f"{file_dir}/{id}.json"
452
+ with open(transcript_file, "w") as f:
453
+ json.dump(data, f)
454
+
455
+ print(data)
456
+ return data
457
+ except Exception as e:
458
+ log.exception(e)
459
+ error_detail = "Open WebUI: Server Connection Error"
460
+ if r is not None:
461
+ try:
462
+ res = r.json()
463
+ if "error" in res:
464
+ error_detail = f"External: {res['error']['message']}"
465
+ except Exception:
466
+ error_detail = f"External: {e}"
467
+
468
+ raise Exception(error_detail)
469
+
470
+
471
+ @app.post("/transcriptions")
472
+ def transcription(
473
+ file: UploadFile = File(...),
474
+ user=Depends(get_verified_user),
475
+ ):
476
+ log.info(f"file.content_type: {file.content_type}")
477
+
478
+ if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
479
+ raise HTTPException(
480
+ status_code=status.HTTP_400_BAD_REQUEST,
481
+ detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
482
+ )
483
+
484
+ try:
485
+ ext = file.filename.split(".")[-1]
486
+ id = uuid.uuid4()
487
+
488
+ filename = f"{id}.{ext}"
489
+ contents = file.file.read()
490
+
491
+ file_dir = f"{CACHE_DIR}/audio/transcriptions"
492
+ os.makedirs(file_dir, exist_ok=True)
493
+ file_path = f"{file_dir}/{filename}"
494
+
495
+ with open(file_path, "wb") as f:
496
+ f.write(contents)
497
+
498
+ try:
499
+ if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB
500
+ log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
501
+ audio = AudioSegment.from_file(file_path)
502
+ audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
503
+ compressed_path = f"{file_dir}/{id}_compressed.opus"
504
+ audio.export(compressed_path, format="opus", bitrate="32k")
505
+ log.debug(f"Compressed audio to {compressed_path}")
506
+ file_path = compressed_path
507
+
508
+ if (
509
+ os.path.getsize(file_path) > MAX_FILE_SIZE
510
+ ): # Still larger than 25MB after compression
511
+ log.debug(
512
+ f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
513
+ )
514
+ raise HTTPException(
515
+ status_code=status.HTTP_400_BAD_REQUEST,
516
+ detail=ERROR_MESSAGES.FILE_TOO_LARGE(
517
+ size=f"{MAX_FILE_SIZE_MB}MB"
518
+ ),
519
+ )
520
+
521
+ data = transcribe(file_path)
522
+ else:
523
+ data = transcribe(file_path)
524
+
525
+ return data
526
+ except Exception as e:
527
+ log.exception(e)
528
+ raise HTTPException(
529
+ status_code=status.HTTP_400_BAD_REQUEST,
530
+ detail=ERROR_MESSAGES.DEFAULT(e),
531
+ )
532
+
533
+ except Exception as e:
534
+ log.exception(e)
535
+
536
+ raise HTTPException(
537
+ status_code=status.HTTP_400_BAD_REQUEST,
538
+ detail=ERROR_MESSAGES.DEFAULT(e),
539
+ )
540
+
541
+
542
+ def get_available_models() -> list[dict]:
543
+ if app.state.config.TTS_ENGINE == "openai":
544
+ return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
545
+ elif app.state.config.TTS_ENGINE == "elevenlabs":
546
+ headers = {
547
+ "xi-api-key": app.state.config.TTS_API_KEY,
548
+ "Content-Type": "application/json",
549
+ }
550
+
551
+ try:
552
+ response = requests.get(
553
+ "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
554
+ )
555
+ response.raise_for_status()
556
+ models = response.json()
557
+ return [
558
+ {"name": model["name"], "id": model["model_id"]} for model in models
559
+ ]
560
+ except requests.RequestException as e:
561
+ log.error(f"Error fetching voices: {str(e)}")
562
+ return []
563
+
564
+
565
+ @app.get("/models")
566
+ async def get_models(user=Depends(get_verified_user)):
567
+ return {"models": get_available_models()}
568
+
569
+
570
+ def get_available_voices() -> dict:
571
+ """Returns {voice_id: voice_name} dict"""
572
+ ret = {}
573
+ if app.state.config.TTS_ENGINE == "openai":
574
+ ret = {
575
+ "alloy": "alloy",
576
+ "echo": "echo",
577
+ "fable": "fable",
578
+ "onyx": "onyx",
579
+ "nova": "nova",
580
+ "shimmer": "shimmer",
581
+ }
582
+ elif app.state.config.TTS_ENGINE == "elevenlabs":
583
+ try:
584
+ ret = get_elevenlabs_voices()
585
+ except Exception:
586
+ # Avoided @lru_cache with exception
587
+ pass
588
+ elif app.state.config.TTS_ENGINE == "azure":
589
+ try:
590
+ region = app.state.config.TTS_AZURE_SPEECH_REGION
591
+ url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
592
+ headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
593
+
594
+ response = requests.get(url, headers=headers)
595
+ response.raise_for_status()
596
+ voices = response.json()
597
+ for voice in voices:
598
+ ret[voice["ShortName"]] = (
599
+ f"{voice['DisplayName']} ({voice['ShortName']})"
600
+ )
601
+ except requests.RequestException as e:
602
+ log.error(f"Error fetching voices: {str(e)}")
603
+
604
+ return ret
605
+
606
+
607
+ @lru_cache
608
+ def get_elevenlabs_voices() -> dict:
609
+ """
610
+ Note, set the following in your .env file to use Elevenlabs:
611
+ AUDIO_TTS_ENGINE=elevenlabs
612
+ AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
613
+ AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
614
+ AUDIO_TTS_MODEL=eleven_multilingual_v2
615
+ """
616
+ headers = {
617
+ "xi-api-key": app.state.config.TTS_API_KEY,
618
+ "Content-Type": "application/json",
619
+ }
620
+ try:
621
+ # TODO: Add retries
622
+ response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
623
+ response.raise_for_status()
624
+ voices_data = response.json()
625
+
626
+ voices = {}
627
+ for voice in voices_data.get("voices", []):
628
+ voices[voice["voice_id"]] = voice["name"]
629
+ except requests.RequestException as e:
630
+ # Avoid @lru_cache with exception
631
+ log.error(f"Error fetching voices: {str(e)}")
632
+ raise RuntimeError(f"Error fetching voices: {str(e)}")
633
+
634
+ return voices
635
+
636
+
637
+ @app.get("/voices")
638
+ async def get_voices(user=Depends(get_verified_user)):
639
+ return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}
backend/open_webui/apps/images/main.py CHANGED
@@ -1,597 +1,597 @@
1
- import asyncio
2
- import base64
3
- import json
4
- import logging
5
- import mimetypes
6
- import re
7
- import uuid
8
- from pathlib import Path
9
- from typing import Optional
10
-
11
- import requests
12
- from open_webui.apps.images.utils.comfyui import (
13
- ComfyUIGenerateImageForm,
14
- ComfyUIWorkflow,
15
- comfyui_generate_image,
16
- )
17
- from open_webui.config import (
18
- AUTOMATIC1111_API_AUTH,
19
- AUTOMATIC1111_BASE_URL,
20
- AUTOMATIC1111_CFG_SCALE,
21
- AUTOMATIC1111_SAMPLER,
22
- AUTOMATIC1111_SCHEDULER,
23
- CACHE_DIR,
24
- COMFYUI_BASE_URL,
25
- COMFYUI_WORKFLOW,
26
- COMFYUI_WORKFLOW_NODES,
27
- CORS_ALLOW_ORIGIN,
28
- ENABLE_IMAGE_GENERATION,
29
- IMAGE_GENERATION_ENGINE,
30
- IMAGE_GENERATION_MODEL,
31
- IMAGE_SIZE,
32
- IMAGE_STEPS,
33
- IMAGES_OPENAI_API_BASE_URL,
34
- IMAGES_OPENAI_API_KEY,
35
- AppConfig,
36
- )
37
- from open_webui.constants import ERROR_MESSAGES
38
- from open_webui.env import SRC_LOG_LEVELS
39
- from fastapi import Depends, FastAPI, HTTPException, Request
40
- from fastapi.middleware.cors import CORSMiddleware
41
- from pydantic import BaseModel
42
- from open_webui.utils.utils import get_admin_user, get_verified_user
43
-
44
- log = logging.getLogger(__name__)
45
- log.setLevel(SRC_LOG_LEVELS["IMAGES"])
46
-
47
- IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
48
- IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
49
-
50
- app = FastAPI()
51
- app.add_middleware(
52
- CORSMiddleware,
53
- allow_origins=CORS_ALLOW_ORIGIN,
54
- allow_credentials=True,
55
- allow_methods=["*"],
56
- allow_headers=["*"],
57
- )
58
-
59
- app.state.config = AppConfig()
60
-
61
- app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
62
- app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
63
-
64
- app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
65
- app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
66
-
67
- app.state.config.MODEL = IMAGE_GENERATION_MODEL
68
-
69
- app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
70
- app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
71
- app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
72
- app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
73
- app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
74
- app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
75
- app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
76
- app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
77
-
78
- app.state.config.IMAGE_SIZE = IMAGE_SIZE
79
- app.state.config.IMAGE_STEPS = IMAGE_STEPS
80
-
81
-
82
- @app.get("/config")
83
- async def get_config(request: Request, user=Depends(get_admin_user)):
84
- return {
85
- "enabled": app.state.config.ENABLED,
86
- "engine": app.state.config.ENGINE,
87
- "openai": {
88
- "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
89
- "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
90
- },
91
- "automatic1111": {
92
- "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
93
- "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
94
- "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
95
- "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
96
- "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
97
- },
98
- "comfyui": {
99
- "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
100
- "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
101
- "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
102
- },
103
- }
104
-
105
-
106
- class OpenAIConfigForm(BaseModel):
107
- OPENAI_API_BASE_URL: str
108
- OPENAI_API_KEY: str
109
-
110
-
111
- class Automatic1111ConfigForm(BaseModel):
112
- AUTOMATIC1111_BASE_URL: str
113
- AUTOMATIC1111_API_AUTH: str
114
- AUTOMATIC1111_CFG_SCALE: Optional[str]
115
- AUTOMATIC1111_SAMPLER: Optional[str]
116
- AUTOMATIC1111_SCHEDULER: Optional[str]
117
-
118
-
119
- class ComfyUIConfigForm(BaseModel):
120
- COMFYUI_BASE_URL: str
121
- COMFYUI_WORKFLOW: str
122
- COMFYUI_WORKFLOW_NODES: list[dict]
123
-
124
-
125
- class ConfigForm(BaseModel):
126
- enabled: bool
127
- engine: str
128
- openai: OpenAIConfigForm
129
- automatic1111: Automatic1111ConfigForm
130
- comfyui: ComfyUIConfigForm
131
-
132
-
133
- @app.post("/config/update")
134
- async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
135
- app.state.config.ENGINE = form_data.engine
136
- app.state.config.ENABLED = form_data.enabled
137
-
138
- app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
139
- app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
140
-
141
- app.state.config.AUTOMATIC1111_BASE_URL = (
142
- form_data.automatic1111.AUTOMATIC1111_BASE_URL
143
- )
144
- app.state.config.AUTOMATIC1111_API_AUTH = (
145
- form_data.automatic1111.AUTOMATIC1111_API_AUTH
146
- )
147
-
148
- app.state.config.AUTOMATIC1111_CFG_SCALE = (
149
- float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
150
- if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
151
- else None
152
- )
153
- app.state.config.AUTOMATIC1111_SAMPLER = (
154
- form_data.automatic1111.AUTOMATIC1111_SAMPLER
155
- if form_data.automatic1111.AUTOMATIC1111_SAMPLER
156
- else None
157
- )
158
- app.state.config.AUTOMATIC1111_SCHEDULER = (
159
- form_data.automatic1111.AUTOMATIC1111_SCHEDULER
160
- if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
161
- else None
162
- )
163
-
164
- app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
165
- app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
166
- app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
167
-
168
- return {
169
- "enabled": app.state.config.ENABLED,
170
- "engine": app.state.config.ENGINE,
171
- "openai": {
172
- "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
173
- "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
174
- },
175
- "automatic1111": {
176
- "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
177
- "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
178
- "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
179
- "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
180
- "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
181
- },
182
- "comfyui": {
183
- "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
184
- "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
185
- "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
186
- },
187
- }
188
-
189
-
190
- def get_automatic1111_api_auth():
191
- if app.state.config.AUTOMATIC1111_API_AUTH is None:
192
- return ""
193
- else:
194
- auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
195
- auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
196
- auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
197
- return f"Basic {auth1111_base64_encoded_string}"
198
-
199
-
200
- @app.get("/config/url/verify")
201
- async def verify_url(user=Depends(get_admin_user)):
202
- if app.state.config.ENGINE == "automatic1111":
203
- try:
204
- r = requests.get(
205
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
206
- headers={"authorization": get_automatic1111_api_auth()},
207
- )
208
- r.raise_for_status()
209
- return True
210
- except Exception:
211
- app.state.config.ENABLED = False
212
- raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
213
- elif app.state.config.ENGINE == "comfyui":
214
- try:
215
- r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
216
- r.raise_for_status()
217
- return True
218
- except Exception:
219
- app.state.config.ENABLED = False
220
- raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
221
- else:
222
- return True
223
-
224
-
225
- def set_image_model(model: str):
226
- log.info(f"Setting image model to {model}")
227
- app.state.config.MODEL = model
228
- if app.state.config.ENGINE in ["", "automatic1111"]:
229
- api_auth = get_automatic1111_api_auth()
230
- r = requests.get(
231
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
232
- headers={"authorization": api_auth},
233
- )
234
- options = r.json()
235
- if model != options["sd_model_checkpoint"]:
236
- options["sd_model_checkpoint"] = model
237
- r = requests.post(
238
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
239
- json=options,
240
- headers={"authorization": api_auth},
241
- )
242
- return app.state.config.MODEL
243
-
244
-
245
- def get_image_model():
246
- if app.state.config.ENGINE == "openai":
247
- return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
248
- elif app.state.config.ENGINE == "comfyui":
249
- return app.state.config.MODEL if app.state.config.MODEL else ""
250
- elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
251
- try:
252
- r = requests.get(
253
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
254
- headers={"authorization": get_automatic1111_api_auth()},
255
- )
256
- options = r.json()
257
- return options["sd_model_checkpoint"]
258
- except Exception as e:
259
- app.state.config.ENABLED = False
260
- raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
261
-
262
-
263
- class ImageConfigForm(BaseModel):
264
- MODEL: str
265
- IMAGE_SIZE: str
266
- IMAGE_STEPS: int
267
-
268
-
269
- @app.get("/image/config")
270
- async def get_image_config(user=Depends(get_admin_user)):
271
- return {
272
- "MODEL": app.state.config.MODEL,
273
- "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
274
- "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
275
- }
276
-
277
-
278
- @app.post("/image/config/update")
279
- async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
280
-
281
- set_image_model(form_data.MODEL)
282
-
283
- pattern = r"^\d+x\d+$"
284
- if re.match(pattern, form_data.IMAGE_SIZE):
285
- app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
286
- else:
287
- raise HTTPException(
288
- status_code=400,
289
- detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
290
- )
291
-
292
- if form_data.IMAGE_STEPS >= 0:
293
- app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
294
- else:
295
- raise HTTPException(
296
- status_code=400,
297
- detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
298
- )
299
-
300
- return {
301
- "MODEL": app.state.config.MODEL,
302
- "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
303
- "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
304
- }
305
-
306
-
307
- @app.get("/models")
308
- def get_models(user=Depends(get_verified_user)):
309
- try:
310
- if app.state.config.ENGINE == "openai":
311
- return [
312
- {"id": "dall-e-2", "name": "DALL·E 2"},
313
- {"id": "dall-e-3", "name": "DALL·E 3"},
314
- ]
315
- elif app.state.config.ENGINE == "comfyui":
316
- # TODO - get models from comfyui
317
- r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
318
- info = r.json()
319
-
320
- workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
321
- model_node_id = None
322
-
323
- for node in app.state.config.COMFYUI_WORKFLOW_NODES:
324
- if node["type"] == "model":
325
- if node["node_ids"]:
326
- model_node_id = node["node_ids"][0]
327
- break
328
-
329
- if model_node_id:
330
- model_list_key = None
331
-
332
- print(workflow[model_node_id]["class_type"])
333
- for key in info[workflow[model_node_id]["class_type"]]["input"][
334
- "required"
335
- ]:
336
- if "_name" in key:
337
- model_list_key = key
338
- break
339
-
340
- if model_list_key:
341
- return list(
342
- map(
343
- lambda model: {"id": model, "name": model},
344
- info[workflow[model_node_id]["class_type"]]["input"][
345
- "required"
346
- ][model_list_key][0],
347
- )
348
- )
349
- else:
350
- return list(
351
- map(
352
- lambda model: {"id": model, "name": model},
353
- info["CheckpointLoaderSimple"]["input"]["required"][
354
- "ckpt_name"
355
- ][0],
356
- )
357
- )
358
- elif (
359
- app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
360
- ):
361
- r = requests.get(
362
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
363
- headers={"authorization": get_automatic1111_api_auth()},
364
- )
365
- models = r.json()
366
- return list(
367
- map(
368
- lambda model: {"id": model["title"], "name": model["model_name"]},
369
- models,
370
- )
371
- )
372
- except Exception as e:
373
- app.state.config.ENABLED = False
374
- raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
375
-
376
-
377
- class GenerateImageForm(BaseModel):
378
- model: Optional[str] = None
379
- prompt: str
380
- size: Optional[str] = None
381
- n: int = 1
382
- negative_prompt: Optional[str] = None
383
-
384
-
385
- def save_b64_image(b64_str):
386
- try:
387
- image_id = str(uuid.uuid4())
388
-
389
- if "," in b64_str:
390
- header, encoded = b64_str.split(",", 1)
391
- mime_type = header.split(";")[0]
392
-
393
- img_data = base64.b64decode(encoded)
394
- image_format = mimetypes.guess_extension(mime_type)
395
-
396
- image_filename = f"{image_id}{image_format}"
397
- file_path = IMAGE_CACHE_DIR / f"{image_filename}"
398
- with open(file_path, "wb") as f:
399
- f.write(img_data)
400
- return image_filename
401
- else:
402
- image_filename = f"{image_id}.png"
403
- file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
404
-
405
- img_data = base64.b64decode(b64_str)
406
-
407
- # Write the image data to a file
408
- with open(file_path, "wb") as f:
409
- f.write(img_data)
410
- return image_filename
411
-
412
- except Exception as e:
413
- log.exception(f"Error saving image: {e}")
414
- return None
415
-
416
-
417
- def save_url_image(url):
418
- image_id = str(uuid.uuid4())
419
- try:
420
- r = requests.get(url)
421
- r.raise_for_status()
422
- if r.headers["content-type"].split("/")[0] == "image":
423
- mime_type = r.headers["content-type"]
424
- image_format = mimetypes.guess_extension(mime_type)
425
-
426
- if not image_format:
427
- raise ValueError("Could not determine image type from MIME type")
428
-
429
- image_filename = f"{image_id}{image_format}"
430
-
431
- file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
432
- with open(file_path, "wb") as image_file:
433
- for chunk in r.iter_content(chunk_size=8192):
434
- image_file.write(chunk)
435
- return image_filename
436
- else:
437
- log.error("Url does not point to an image.")
438
- return None
439
-
440
- except Exception as e:
441
- log.exception(f"Error saving image: {e}")
442
- return None
443
-
444
-
445
- @app.post("/generations")
446
- async def image_generations(
447
- form_data: GenerateImageForm,
448
- user=Depends(get_verified_user),
449
- ):
450
- width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
451
-
452
- r = None
453
- try:
454
- if app.state.config.ENGINE == "openai":
455
- headers = {}
456
- headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
457
- headers["Content-Type"] = "application/json"
458
-
459
- data = {
460
- "model": (
461
- app.state.config.MODEL
462
- if app.state.config.MODEL != ""
463
- else "dall-e-2"
464
- ),
465
- "prompt": form_data.prompt,
466
- "n": form_data.n,
467
- "size": (
468
- form_data.size if form_data.size else app.state.config.IMAGE_SIZE
469
- ),
470
- "response_format": "b64_json",
471
- }
472
-
473
- # Use asyncio.to_thread for the requests.post call
474
- r = await asyncio.to_thread(
475
- requests.post,
476
- url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
477
- json=data,
478
- headers=headers,
479
- )
480
-
481
- r.raise_for_status()
482
- res = r.json()
483
-
484
- images = []
485
-
486
- for image in res["data"]:
487
- image_filename = save_b64_image(image["b64_json"])
488
- images.append({"url": f"/cache/image/generations/{image_filename}"})
489
- file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
490
-
491
- with open(file_body_path, "w") as f:
492
- json.dump(data, f)
493
-
494
- return images
495
-
496
- elif app.state.config.ENGINE == "comfyui":
497
- data = {
498
- "prompt": form_data.prompt,
499
- "width": width,
500
- "height": height,
501
- "n": form_data.n,
502
- }
503
-
504
- if app.state.config.IMAGE_STEPS is not None:
505
- data["steps"] = app.state.config.IMAGE_STEPS
506
-
507
- if form_data.negative_prompt is not None:
508
- data["negative_prompt"] = form_data.negative_prompt
509
-
510
- form_data = ComfyUIGenerateImageForm(
511
- **{
512
- "workflow": ComfyUIWorkflow(
513
- **{
514
- "workflow": app.state.config.COMFYUI_WORKFLOW,
515
- "nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
516
- }
517
- ),
518
- **data,
519
- }
520
- )
521
- res = await comfyui_generate_image(
522
- app.state.config.MODEL,
523
- form_data,
524
- user.id,
525
- app.state.config.COMFYUI_BASE_URL,
526
- )
527
- log.debug(f"res: {res}")
528
-
529
- images = []
530
-
531
- for image in res["data"]:
532
- image_filename = save_url_image(image["url"])
533
- images.append({"url": f"/cache/image/generations/{image_filename}"})
534
- file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
535
-
536
- with open(file_body_path, "w") as f:
537
- json.dump(form_data.model_dump(exclude_none=True), f)
538
-
539
- log.debug(f"images: {images}")
540
- return images
541
- elif (
542
- app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
543
- ):
544
- if form_data.model:
545
- set_image_model(form_data.model)
546
-
547
- data = {
548
- "prompt": form_data.prompt,
549
- "batch_size": form_data.n,
550
- "width": width,
551
- "height": height,
552
- }
553
-
554
- if app.state.config.IMAGE_STEPS is not None:
555
- data["steps"] = app.state.config.IMAGE_STEPS
556
-
557
- if form_data.negative_prompt is not None:
558
- data["negative_prompt"] = form_data.negative_prompt
559
-
560
- if app.state.config.AUTOMATIC1111_CFG_SCALE:
561
- data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
562
-
563
- if app.state.config.AUTOMATIC1111_SAMPLER:
564
- data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
565
-
566
- if app.state.config.AUTOMATIC1111_SCHEDULER:
567
- data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
568
-
569
- # Use asyncio.to_thread for the requests.post call
570
- r = await asyncio.to_thread(
571
- requests.post,
572
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
573
- json=data,
574
- headers={"authorization": get_automatic1111_api_auth()},
575
- )
576
-
577
- res = r.json()
578
- log.debug(f"res: {res}")
579
-
580
- images = []
581
-
582
- for image in res["images"]:
583
- image_filename = save_b64_image(image)
584
- images.append({"url": f"/cache/image/generations/{image_filename}"})
585
- file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
586
-
587
- with open(file_body_path, "w") as f:
588
- json.dump({**data, "info": res["info"]}, f)
589
-
590
- return images
591
- except Exception as e:
592
- error = e
593
- if r != None:
594
- data = r.json()
595
- if "error" in data:
596
- error = data["error"]["message"]
597
- raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
 
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ import logging
5
+ import mimetypes
6
+ import re
7
+ import uuid
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ import requests
12
+ from open_webui.apps.images.utils.comfyui import (
13
+ ComfyUIGenerateImageForm,
14
+ ComfyUIWorkflow,
15
+ comfyui_generate_image,
16
+ )
17
+ from open_webui.config import (
18
+ AUTOMATIC1111_API_AUTH,
19
+ AUTOMATIC1111_BASE_URL,
20
+ AUTOMATIC1111_CFG_SCALE,
21
+ AUTOMATIC1111_SAMPLER,
22
+ AUTOMATIC1111_SCHEDULER,
23
+ CACHE_DIR,
24
+ COMFYUI_BASE_URL,
25
+ COMFYUI_WORKFLOW,
26
+ COMFYUI_WORKFLOW_NODES,
27
+ CORS_ALLOW_ORIGIN,
28
+ ENABLE_IMAGE_GENERATION,
29
+ IMAGE_GENERATION_ENGINE,
30
+ IMAGE_GENERATION_MODEL,
31
+ IMAGE_SIZE,
32
+ IMAGE_STEPS,
33
+ IMAGES_OPENAI_API_BASE_URL,
34
+ IMAGES_OPENAI_API_KEY,
35
+ AppConfig,
36
+ )
37
+ from open_webui.constants import ERROR_MESSAGES
38
+ from open_webui.env import SRC_LOG_LEVELS
39
+ from fastapi import Depends, FastAPI, HTTPException, Request
40
+ from fastapi.middleware.cors import CORSMiddleware
41
+ from pydantic import BaseModel
42
+ from open_webui.utils.utils import get_admin_user, get_verified_user
43
+
44
+ log = logging.getLogger(__name__)
45
+ log.setLevel(SRC_LOG_LEVELS["IMAGES"])
46
+
47
+ IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
48
+ IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
49
+
50
+ app = FastAPI()
51
+ app.add_middleware(
52
+ CORSMiddleware,
53
+ allow_origins=CORS_ALLOW_ORIGIN,
54
+ allow_credentials=True,
55
+ allow_methods=["*"],
56
+ allow_headers=["*"],
57
+ )
58
+
59
+ app.state.config = AppConfig()
60
+
61
+ app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
62
+ app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
63
+
64
+ app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
65
+ app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
66
+
67
+ app.state.config.MODEL = IMAGE_GENERATION_MODEL
68
+
69
+ app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
70
+ app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
71
+ app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
72
+ app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
73
+ app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
74
+ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
75
+ app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
76
+ app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
77
+
78
+ app.state.config.IMAGE_SIZE = IMAGE_SIZE
79
+ app.state.config.IMAGE_STEPS = IMAGE_STEPS
80
+
81
+
82
+ @app.get("/config")
83
+ async def get_config(request: Request, user=Depends(get_admin_user)):
84
+ return {
85
+ "enabled": app.state.config.ENABLED,
86
+ "engine": app.state.config.ENGINE,
87
+ "openai": {
88
+ "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
89
+ "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
90
+ },
91
+ "automatic1111": {
92
+ "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
93
+ "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
94
+ "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
95
+ "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
96
+ "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
97
+ },
98
+ "comfyui": {
99
+ "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
100
+ "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
101
+ "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
102
+ },
103
+ }
104
+
105
+
106
+ class OpenAIConfigForm(BaseModel):
107
+ OPENAI_API_BASE_URL: str
108
+ OPENAI_API_KEY: str
109
+
110
+
111
+ class Automatic1111ConfigForm(BaseModel):
112
+ AUTOMATIC1111_BASE_URL: str
113
+ AUTOMATIC1111_API_AUTH: str
114
+ AUTOMATIC1111_CFG_SCALE: Optional[str]
115
+ AUTOMATIC1111_SAMPLER: Optional[str]
116
+ AUTOMATIC1111_SCHEDULER: Optional[str]
117
+
118
+
119
+ class ComfyUIConfigForm(BaseModel):
120
+ COMFYUI_BASE_URL: str
121
+ COMFYUI_WORKFLOW: str
122
+ COMFYUI_WORKFLOW_NODES: list[dict]
123
+
124
+
125
+ class ConfigForm(BaseModel):
126
+ enabled: bool
127
+ engine: str
128
+ openai: OpenAIConfigForm
129
+ automatic1111: Automatic1111ConfigForm
130
+ comfyui: ComfyUIConfigForm
131
+
132
+
133
+ @app.post("/config/update")
134
+ async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
135
+ app.state.config.ENGINE = form_data.engine
136
+ app.state.config.ENABLED = form_data.enabled
137
+
138
+ app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
139
+ app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
140
+
141
+ app.state.config.AUTOMATIC1111_BASE_URL = (
142
+ form_data.automatic1111.AUTOMATIC1111_BASE_URL
143
+ )
144
+ app.state.config.AUTOMATIC1111_API_AUTH = (
145
+ form_data.automatic1111.AUTOMATIC1111_API_AUTH
146
+ )
147
+
148
+ app.state.config.AUTOMATIC1111_CFG_SCALE = (
149
+ float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
150
+ if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
151
+ else None
152
+ )
153
+ app.state.config.AUTOMATIC1111_SAMPLER = (
154
+ form_data.automatic1111.AUTOMATIC1111_SAMPLER
155
+ if form_data.automatic1111.AUTOMATIC1111_SAMPLER
156
+ else None
157
+ )
158
+ app.state.config.AUTOMATIC1111_SCHEDULER = (
159
+ form_data.automatic1111.AUTOMATIC1111_SCHEDULER
160
+ if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
161
+ else None
162
+ )
163
+
164
+ app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
165
+ app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
166
+ app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
167
+
168
+ return {
169
+ "enabled": app.state.config.ENABLED,
170
+ "engine": app.state.config.ENGINE,
171
+ "openai": {
172
+ "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
173
+ "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
174
+ },
175
+ "automatic1111": {
176
+ "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
177
+ "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
178
+ "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
179
+ "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
180
+ "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
181
+ },
182
+ "comfyui": {
183
+ "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
184
+ "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
185
+ "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
186
+ },
187
+ }
188
+
189
+
190
+ def get_automatic1111_api_auth():
191
+ if app.state.config.AUTOMATIC1111_API_AUTH is None:
192
+ return ""
193
+ else:
194
+ auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
195
+ auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
196
+ auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
197
+ return f"Basic {auth1111_base64_encoded_string}"
198
+
199
+
200
+ @app.get("/config/url/verify")
201
+ async def verify_url(user=Depends(get_admin_user)):
202
+ if app.state.config.ENGINE == "automatic1111":
203
+ try:
204
+ r = requests.get(
205
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
206
+ headers={"authorization": get_automatic1111_api_auth()},
207
+ )
208
+ r.raise_for_status()
209
+ return True
210
+ except Exception:
211
+ app.state.config.ENABLED = False
212
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
213
+ elif app.state.config.ENGINE == "comfyui":
214
+ try:
215
+ r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
216
+ r.raise_for_status()
217
+ return True
218
+ except Exception:
219
+ app.state.config.ENABLED = False
220
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
221
+ else:
222
+ return True
223
+
224
+
225
+ def set_image_model(model: str):
226
+ log.info(f"Setting image model to {model}")
227
+ app.state.config.MODEL = model
228
+ if app.state.config.ENGINE in ["", "automatic1111"]:
229
+ api_auth = get_automatic1111_api_auth()
230
+ r = requests.get(
231
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
232
+ headers={"authorization": api_auth},
233
+ )
234
+ options = r.json()
235
+ if model != options["sd_model_checkpoint"]:
236
+ options["sd_model_checkpoint"] = model
237
+ r = requests.post(
238
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
239
+ json=options,
240
+ headers={"authorization": api_auth},
241
+ )
242
+ return app.state.config.MODEL
243
+
244
+
245
+ def get_image_model():
246
+ if app.state.config.ENGINE == "openai":
247
+ return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
248
+ elif app.state.config.ENGINE == "comfyui":
249
+ return app.state.config.MODEL if app.state.config.MODEL else ""
250
+ elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
251
+ try:
252
+ r = requests.get(
253
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
254
+ headers={"authorization": get_automatic1111_api_auth()},
255
+ )
256
+ options = r.json()
257
+ return options["sd_model_checkpoint"]
258
+ except Exception as e:
259
+ app.state.config.ENABLED = False
260
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
261
+
262
+
263
+ class ImageConfigForm(BaseModel):
264
+ MODEL: str
265
+ IMAGE_SIZE: str
266
+ IMAGE_STEPS: int
267
+
268
+
269
+ @app.get("/image/config")
270
+ async def get_image_config(user=Depends(get_admin_user)):
271
+ return {
272
+ "MODEL": app.state.config.MODEL,
273
+ "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
274
+ "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
275
+ }
276
+
277
+
278
+ @app.post("/image/config/update")
279
+ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
280
+
281
+ set_image_model(form_data.MODEL)
282
+
283
+ pattern = r"^\d+x\d+$"
284
+ if re.match(pattern, form_data.IMAGE_SIZE):
285
+ app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
286
+ else:
287
+ raise HTTPException(
288
+ status_code=400,
289
+ detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
290
+ )
291
+
292
+ if form_data.IMAGE_STEPS >= 0:
293
+ app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
294
+ else:
295
+ raise HTTPException(
296
+ status_code=400,
297
+ detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
298
+ )
299
+
300
+ return {
301
+ "MODEL": app.state.config.MODEL,
302
+ "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
303
+ "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
304
+ }
305
+
306
+
307
+ @app.get("/models")
308
+ def get_models(user=Depends(get_verified_user)):
309
+ try:
310
+ if app.state.config.ENGINE == "openai":
311
+ return [
312
+ {"id": "dall-e-2", "name": "DALL·E 2"},
313
+ {"id": "dall-e-3", "name": "DALL·E 3"},
314
+ ]
315
+ elif app.state.config.ENGINE == "comfyui":
316
+ # TODO - get models from comfyui
317
+ r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
318
+ info = r.json()
319
+
320
+ workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
321
+ model_node_id = None
322
+
323
+ for node in app.state.config.COMFYUI_WORKFLOW_NODES:
324
+ if node["type"] == "model":
325
+ if node["node_ids"]:
326
+ model_node_id = node["node_ids"][0]
327
+ break
328
+
329
+ if model_node_id:
330
+ model_list_key = None
331
+
332
+ print(workflow[model_node_id]["class_type"])
333
+ for key in info[workflow[model_node_id]["class_type"]]["input"][
334
+ "required"
335
+ ]:
336
+ if "_name" in key:
337
+ model_list_key = key
338
+ break
339
+
340
+ if model_list_key:
341
+ return list(
342
+ map(
343
+ lambda model: {"id": model, "name": model},
344
+ info[workflow[model_node_id]["class_type"]]["input"][
345
+ "required"
346
+ ][model_list_key][0],
347
+ )
348
+ )
349
+ else:
350
+ return list(
351
+ map(
352
+ lambda model: {"id": model, "name": model},
353
+ info["CheckpointLoaderSimple"]["input"]["required"][
354
+ "ckpt_name"
355
+ ][0],
356
+ )
357
+ )
358
+ elif (
359
+ app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
360
+ ):
361
+ r = requests.get(
362
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
363
+ headers={"authorization": get_automatic1111_api_auth()},
364
+ )
365
+ models = r.json()
366
+ return list(
367
+ map(
368
+ lambda model: {"id": model["title"], "name": model["model_name"]},
369
+ models,
370
+ )
371
+ )
372
+ except Exception as e:
373
+ app.state.config.ENABLED = False
374
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
375
+
376
+
377
+ class GenerateImageForm(BaseModel):
378
+ model: Optional[str] = None
379
+ prompt: str
380
+ size: Optional[str] = None
381
+ n: int = 1
382
+ negative_prompt: Optional[str] = None
383
+
384
+
385
+ def save_b64_image(b64_str):
386
+ try:
387
+ image_id = str(uuid.uuid4())
388
+
389
+ if "," in b64_str:
390
+ header, encoded = b64_str.split(",", 1)
391
+ mime_type = header.split(";")[0]
392
+
393
+ img_data = base64.b64decode(encoded)
394
+ image_format = mimetypes.guess_extension(mime_type)
395
+
396
+ image_filename = f"{image_id}{image_format}"
397
+ file_path = IMAGE_CACHE_DIR / f"{image_filename}"
398
+ with open(file_path, "wb") as f:
399
+ f.write(img_data)
400
+ return image_filename
401
+ else:
402
+ image_filename = f"{image_id}.png"
403
+ file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
404
+
405
+ img_data = base64.b64decode(b64_str)
406
+
407
+ # Write the image data to a file
408
+ with open(file_path, "wb") as f:
409
+ f.write(img_data)
410
+ return image_filename
411
+
412
+ except Exception as e:
413
+ log.exception(f"Error saving image: {e}")
414
+ return None
415
+
416
+
417
+ def save_url_image(url):
418
+ image_id = str(uuid.uuid4())
419
+ try:
420
+ r = requests.get(url)
421
+ r.raise_for_status()
422
+ if r.headers["content-type"].split("/")[0] == "image":
423
+ mime_type = r.headers["content-type"]
424
+ image_format = mimetypes.guess_extension(mime_type)
425
+
426
+ if not image_format:
427
+ raise ValueError("Could not determine image type from MIME type")
428
+
429
+ image_filename = f"{image_id}{image_format}"
430
+
431
+ file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
432
+ with open(file_path, "wb") as image_file:
433
+ for chunk in r.iter_content(chunk_size=8192):
434
+ image_file.write(chunk)
435
+ return image_filename
436
+ else:
437
+ log.error("Url does not point to an image.")
438
+ return None
439
+
440
+ except Exception as e:
441
+ log.exception(f"Error saving image: {e}")
442
+ return None
443
+
444
+
445
+ @app.post("/generations")
446
+ async def image_generations(
447
+ form_data: GenerateImageForm,
448
+ user=Depends(get_verified_user),
449
+ ):
450
+ width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
451
+
452
+ r = None
453
+ try:
454
+ if app.state.config.ENGINE == "openai":
455
+ headers = {}
456
+ headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
457
+ headers["Content-Type"] = "application/json"
458
+
459
+ data = {
460
+ "model": (
461
+ app.state.config.MODEL
462
+ if app.state.config.MODEL != ""
463
+ else "dall-e-2"
464
+ ),
465
+ "prompt": form_data.prompt,
466
+ "n": form_data.n,
467
+ "size": (
468
+ form_data.size if form_data.size else app.state.config.IMAGE_SIZE
469
+ ),
470
+ "response_format": "b64_json",
471
+ }
472
+
473
+ # Use asyncio.to_thread for the requests.post call
474
+ r = await asyncio.to_thread(
475
+ requests.post,
476
+ url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
477
+ json=data,
478
+ headers=headers,
479
+ )
480
+
481
+ r.raise_for_status()
482
+ res = r.json()
483
+
484
+ images = []
485
+
486
+ for image in res["data"]:
487
+ image_filename = save_b64_image(image["b64_json"])
488
+ images.append({"url": f"/cache/image/generations/{image_filename}"})
489
+ file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
490
+
491
+ with open(file_body_path, "w") as f:
492
+ json.dump(data, f)
493
+
494
+ return images
495
+
496
+ elif app.state.config.ENGINE == "comfyui":
497
+ data = {
498
+ "prompt": form_data.prompt,
499
+ "width": width,
500
+ "height": height,
501
+ "n": form_data.n,
502
+ }
503
+
504
+ if app.state.config.IMAGE_STEPS is not None:
505
+ data["steps"] = app.state.config.IMAGE_STEPS
506
+
507
+ if form_data.negative_prompt is not None:
508
+ data["negative_prompt"] = form_data.negative_prompt
509
+
510
+ form_data = ComfyUIGenerateImageForm(
511
+ **{
512
+ "workflow": ComfyUIWorkflow(
513
+ **{
514
+ "workflow": app.state.config.COMFYUI_WORKFLOW,
515
+ "nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
516
+ }
517
+ ),
518
+ **data,
519
+ }
520
+ )
521
+ res = await comfyui_generate_image(
522
+ app.state.config.MODEL,
523
+ form_data,
524
+ user.id,
525
+ app.state.config.COMFYUI_BASE_URL,
526
+ )
527
+ log.debug(f"res: {res}")
528
+
529
+ images = []
530
+
531
+ for image in res["data"]:
532
+ image_filename = save_url_image(image["url"])
533
+ images.append({"url": f"/cache/image/generations/{image_filename}"})
534
+ file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
535
+
536
+ with open(file_body_path, "w") as f:
537
+ json.dump(form_data.model_dump(exclude_none=True), f)
538
+
539
+ log.debug(f"images: {images}")
540
+ return images
541
+ elif (
542
+ app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
543
+ ):
544
+ if form_data.model:
545
+ set_image_model(form_data.model)
546
+
547
+ data = {
548
+ "prompt": form_data.prompt,
549
+ "batch_size": form_data.n,
550
+ "width": width,
551
+ "height": height,
552
+ }
553
+
554
+ if app.state.config.IMAGE_STEPS is not None:
555
+ data["steps"] = app.state.config.IMAGE_STEPS
556
+
557
+ if form_data.negative_prompt is not None:
558
+ data["negative_prompt"] = form_data.negative_prompt
559
+
560
+ if app.state.config.AUTOMATIC1111_CFG_SCALE:
561
+ data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
562
+
563
+ if app.state.config.AUTOMATIC1111_SAMPLER:
564
+ data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
565
+
566
+ if app.state.config.AUTOMATIC1111_SCHEDULER:
567
+ data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
568
+
569
+ # Use asyncio.to_thread for the requests.post call
570
+ r = await asyncio.to_thread(
571
+ requests.post,
572
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
573
+ json=data,
574
+ headers={"authorization": get_automatic1111_api_auth()},
575
+ )
576
+
577
+ res = r.json()
578
+ log.debug(f"res: {res}")
579
+
580
+ images = []
581
+
582
+ for image in res["images"]:
583
+ image_filename = save_b64_image(image)
584
+ images.append({"url": f"/cache/image/generations/{image_filename}"})
585
+ file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
586
+
587
+ with open(file_body_path, "w") as f:
588
+ json.dump({**data, "info": res["info"]}, f)
589
+
590
+ return images
591
+ except Exception as e:
592
+ error = e
593
+ if r != None:
594
+ data = r.json()
595
+ if "error" in data:
596
+ error = data["error"]["message"]
597
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
backend/open_webui/apps/images/utils/comfyui.py CHANGED
@@ -1,186 +1,186 @@
1
- import asyncio
2
- import json
3
- import logging
4
- import random
5
- import urllib.parse
6
- import urllib.request
7
- from typing import Optional
8
-
9
- import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
10
- from open_webui.env import SRC_LOG_LEVELS
11
- from pydantic import BaseModel
12
-
13
- log = logging.getLogger(__name__)
14
- log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
15
-
16
- default_headers = {"User-Agent": "Mozilla/5.0"}
17
-
18
-
19
- def queue_prompt(prompt, client_id, base_url):
20
- log.info("queue_prompt")
21
- p = {"prompt": prompt, "client_id": client_id}
22
- data = json.dumps(p).encode("utf-8")
23
- log.debug(f"queue_prompt data: {data}")
24
- try:
25
- req = urllib.request.Request(
26
- f"{base_url}/prompt", data=data, headers=default_headers
27
- )
28
- response = urllib.request.urlopen(req).read()
29
- return json.loads(response)
30
- except Exception as e:
31
- log.exception(f"Error while queuing prompt: {e}")
32
- raise e
33
-
34
-
35
- def get_image(filename, subfolder, folder_type, base_url):
36
- log.info("get_image")
37
- data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
38
- url_values = urllib.parse.urlencode(data)
39
- req = urllib.request.Request(
40
- f"{base_url}/view?{url_values}", headers=default_headers
41
- )
42
- with urllib.request.urlopen(req) as response:
43
- return response.read()
44
-
45
-
46
- def get_image_url(filename, subfolder, folder_type, base_url):
47
- log.info("get_image")
48
- data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
49
- url_values = urllib.parse.urlencode(data)
50
- return f"{base_url}/view?{url_values}"
51
-
52
-
53
- def get_history(prompt_id, base_url):
54
- log.info("get_history")
55
-
56
- req = urllib.request.Request(
57
- f"{base_url}/history/{prompt_id}", headers=default_headers
58
- )
59
- with urllib.request.urlopen(req) as response:
60
- return json.loads(response.read())
61
-
62
-
63
- def get_images(ws, prompt, client_id, base_url):
64
- prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
65
- output_images = []
66
- while True:
67
- out = ws.recv()
68
- if isinstance(out, str):
69
- message = json.loads(out)
70
- if message["type"] == "executing":
71
- data = message["data"]
72
- if data["node"] is None and data["prompt_id"] == prompt_id:
73
- break # Execution is done
74
- else:
75
- continue # previews are binary data
76
-
77
- history = get_history(prompt_id, base_url)[prompt_id]
78
- for o in history["outputs"]:
79
- for node_id in history["outputs"]:
80
- node_output = history["outputs"][node_id]
81
- if "images" in node_output:
82
- for image in node_output["images"]:
83
- url = get_image_url(
84
- image["filename"], image["subfolder"], image["type"], base_url
85
- )
86
- output_images.append({"url": url})
87
- return {"data": output_images}
88
-
89
-
90
- class ComfyUINodeInput(BaseModel):
91
- type: Optional[str] = None
92
- node_ids: list[str] = []
93
- key: Optional[str] = "text"
94
- value: Optional[str] = None
95
-
96
-
97
- class ComfyUIWorkflow(BaseModel):
98
- workflow: str
99
- nodes: list[ComfyUINodeInput]
100
-
101
-
102
- class ComfyUIGenerateImageForm(BaseModel):
103
- workflow: ComfyUIWorkflow
104
-
105
- prompt: str
106
- negative_prompt: Optional[str] = None
107
- width: int
108
- height: int
109
- n: int = 1
110
-
111
- steps: Optional[int] = None
112
- seed: Optional[int] = None
113
-
114
-
115
- async def comfyui_generate_image(
116
- model: str, payload: ComfyUIGenerateImageForm, client_id, base_url
117
- ):
118
- ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
119
- workflow = json.loads(payload.workflow.workflow)
120
-
121
- for node in payload.workflow.nodes:
122
- if node.type:
123
- if node.type == "model":
124
- for node_id in node.node_ids:
125
- workflow[node_id]["inputs"][node.key] = model
126
- elif node.type == "prompt":
127
- for node_id in node.node_ids:
128
- workflow[node_id]["inputs"][
129
- node.key if node.key else "text"
130
- ] = payload.prompt
131
- elif node.type == "negative_prompt":
132
- for node_id in node.node_ids:
133
- workflow[node_id]["inputs"][
134
- node.key if node.key else "text"
135
- ] = payload.negative_prompt
136
- elif node.type == "width":
137
- for node_id in node.node_ids:
138
- workflow[node_id]["inputs"][
139
- node.key if node.key else "width"
140
- ] = payload.width
141
- elif node.type == "height":
142
- for node_id in node.node_ids:
143
- workflow[node_id]["inputs"][
144
- node.key if node.key else "height"
145
- ] = payload.height
146
- elif node.type == "n":
147
- for node_id in node.node_ids:
148
- workflow[node_id]["inputs"][
149
- node.key if node.key else "batch_size"
150
- ] = payload.n
151
- elif node.type == "steps":
152
- for node_id in node.node_ids:
153
- workflow[node_id]["inputs"][
154
- node.key if node.key else "steps"
155
- ] = payload.steps
156
- elif node.type == "seed":
157
- seed = (
158
- payload.seed
159
- if payload.seed
160
- else random.randint(0, 18446744073709551614)
161
- )
162
- for node_id in node.node_ids:
163
- workflow[node_id]["inputs"][node.key] = seed
164
- else:
165
- for node_id in node.node_ids:
166
- workflow[node_id]["inputs"][node.key] = node.value
167
-
168
- try:
169
- ws = websocket.WebSocket()
170
- ws.connect(f"{ws_url}/ws?clientId={client_id}")
171
- log.info("WebSocket connection established.")
172
- except Exception as e:
173
- log.exception(f"Failed to connect to WebSocket server: {e}")
174
- return None
175
-
176
- try:
177
- log.info("Sending workflow to WebSocket server.")
178
- log.info(f"Workflow: {workflow}")
179
- images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url)
180
- except Exception as e:
181
- log.exception(f"Error while receiving images: {e}")
182
- images = None
183
-
184
- ws.close()
185
-
186
- return images
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import random
5
+ import urllib.parse
6
+ import urllib.request
7
+ from typing import Optional
8
+
9
+ import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
10
+ from open_webui.env import SRC_LOG_LEVELS
11
+ from pydantic import BaseModel
12
+
13
+ log = logging.getLogger(__name__)
14
+ log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
15
+
16
+ default_headers = {"User-Agent": "Mozilla/5.0"}
17
+
18
+
19
+ def queue_prompt(prompt, client_id, base_url):
20
+ log.info("queue_prompt")
21
+ p = {"prompt": prompt, "client_id": client_id}
22
+ data = json.dumps(p).encode("utf-8")
23
+ log.debug(f"queue_prompt data: {data}")
24
+ try:
25
+ req = urllib.request.Request(
26
+ f"{base_url}/prompt", data=data, headers=default_headers
27
+ )
28
+ response = urllib.request.urlopen(req).read()
29
+ return json.loads(response)
30
+ except Exception as e:
31
+ log.exception(f"Error while queuing prompt: {e}")
32
+ raise e
33
+
34
+
35
+ def get_image(filename, subfolder, folder_type, base_url):
36
+ log.info("get_image")
37
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
38
+ url_values = urllib.parse.urlencode(data)
39
+ req = urllib.request.Request(
40
+ f"{base_url}/view?{url_values}", headers=default_headers
41
+ )
42
+ with urllib.request.urlopen(req) as response:
43
+ return response.read()
44
+
45
+
46
+ def get_image_url(filename, subfolder, folder_type, base_url):
47
+ log.info("get_image")
48
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
49
+ url_values = urllib.parse.urlencode(data)
50
+ return f"{base_url}/view?{url_values}"
51
+
52
+
53
+ def get_history(prompt_id, base_url):
54
+ log.info("get_history")
55
+
56
+ req = urllib.request.Request(
57
+ f"{base_url}/history/{prompt_id}", headers=default_headers
58
+ )
59
+ with urllib.request.urlopen(req) as response:
60
+ return json.loads(response.read())
61
+
62
+
63
+ def get_images(ws, prompt, client_id, base_url):
64
+ prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
65
+ output_images = []
66
+ while True:
67
+ out = ws.recv()
68
+ if isinstance(out, str):
69
+ message = json.loads(out)
70
+ if message["type"] == "executing":
71
+ data = message["data"]
72
+ if data["node"] is None and data["prompt_id"] == prompt_id:
73
+ break # Execution is done
74
+ else:
75
+ continue # previews are binary data
76
+
77
+ history = get_history(prompt_id, base_url)[prompt_id]
78
+ for o in history["outputs"]:
79
+ for node_id in history["outputs"]:
80
+ node_output = history["outputs"][node_id]
81
+ if "images" in node_output:
82
+ for image in node_output["images"]:
83
+ url = get_image_url(
84
+ image["filename"], image["subfolder"], image["type"], base_url
85
+ )
86
+ output_images.append({"url": url})
87
+ return {"data": output_images}
88
+
89
+
90
+ class ComfyUINodeInput(BaseModel):
91
+ type: Optional[str] = None
92
+ node_ids: list[str] = []
93
+ key: Optional[str] = "text"
94
+ value: Optional[str] = None
95
+
96
+
97
+ class ComfyUIWorkflow(BaseModel):
98
+ workflow: str
99
+ nodes: list[ComfyUINodeInput]
100
+
101
+
102
+ class ComfyUIGenerateImageForm(BaseModel):
103
+ workflow: ComfyUIWorkflow
104
+
105
+ prompt: str
106
+ negative_prompt: Optional[str] = None
107
+ width: int
108
+ height: int
109
+ n: int = 1
110
+
111
+ steps: Optional[int] = None
112
+ seed: Optional[int] = None
113
+
114
+
115
+ async def comfyui_generate_image(
116
+ model: str, payload: ComfyUIGenerateImageForm, client_id, base_url
117
+ ):
118
+ ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
119
+ workflow = json.loads(payload.workflow.workflow)
120
+
121
+ for node in payload.workflow.nodes:
122
+ if node.type:
123
+ if node.type == "model":
124
+ for node_id in node.node_ids:
125
+ workflow[node_id]["inputs"][node.key] = model
126
+ elif node.type == "prompt":
127
+ for node_id in node.node_ids:
128
+ workflow[node_id]["inputs"][
129
+ node.key if node.key else "text"
130
+ ] = payload.prompt
131
+ elif node.type == "negative_prompt":
132
+ for node_id in node.node_ids:
133
+ workflow[node_id]["inputs"][
134
+ node.key if node.key else "text"
135
+ ] = payload.negative_prompt
136
+ elif node.type == "width":
137
+ for node_id in node.node_ids:
138
+ workflow[node_id]["inputs"][
139
+ node.key if node.key else "width"
140
+ ] = payload.width
141
+ elif node.type == "height":
142
+ for node_id in node.node_ids:
143
+ workflow[node_id]["inputs"][
144
+ node.key if node.key else "height"
145
+ ] = payload.height
146
+ elif node.type == "n":
147
+ for node_id in node.node_ids:
148
+ workflow[node_id]["inputs"][
149
+ node.key if node.key else "batch_size"
150
+ ] = payload.n
151
+ elif node.type == "steps":
152
+ for node_id in node.node_ids:
153
+ workflow[node_id]["inputs"][
154
+ node.key if node.key else "steps"
155
+ ] = payload.steps
156
+ elif node.type == "seed":
157
+ seed = (
158
+ payload.seed
159
+ if payload.seed
160
+ else random.randint(0, 18446744073709551614)
161
+ )
162
+ for node_id in node.node_ids:
163
+ workflow[node_id]["inputs"][node.key] = seed
164
+ else:
165
+ for node_id in node.node_ids:
166
+ workflow[node_id]["inputs"][node.key] = node.value
167
+
168
+ try:
169
+ ws = websocket.WebSocket()
170
+ ws.connect(f"{ws_url}/ws?clientId={client_id}")
171
+ log.info("WebSocket connection established.")
172
+ except Exception as e:
173
+ log.exception(f"Failed to connect to WebSocket server: {e}")
174
+ return None
175
+
176
+ try:
177
+ log.info("Sending workflow to WebSocket server.")
178
+ log.info(f"Workflow: {workflow}")
179
+ images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url)
180
+ except Exception as e:
181
+ log.exception(f"Error while receiving images: {e}")
182
+ images = None
183
+
184
+ ws.close()
185
+
186
+ return images
backend/open_webui/apps/ollama/main.py CHANGED
@@ -1,1120 +1,1121 @@
1
- import asyncio
2
- import json
3
- import logging
4
- import os
5
- import random
6
- import re
7
- import time
8
- from typing import Optional, Union
9
- from urllib.parse import urlparse
10
-
11
- import aiohttp
12
- import requests
13
- from open_webui.apps.webui.models.models import Models
14
- from open_webui.config import (
15
- CORS_ALLOW_ORIGIN,
16
- ENABLE_MODEL_FILTER,
17
- ENABLE_OLLAMA_API,
18
- MODEL_FILTER_LIST,
19
- OLLAMA_BASE_URLS,
20
- UPLOAD_DIR,
21
- AppConfig,
22
- )
23
- from open_webui.env import AIOHTTP_CLIENT_TIMEOUT
24
-
25
-
26
- from open_webui.constants import ERROR_MESSAGES
27
- from open_webui.env import SRC_LOG_LEVELS
28
- from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
29
- from fastapi.middleware.cors import CORSMiddleware
30
- from fastapi.responses import StreamingResponse
31
- from pydantic import BaseModel, ConfigDict
32
- from starlette.background import BackgroundTask
33
-
34
-
35
- from open_webui.utils.misc import (
36
- calculate_sha256,
37
- )
38
- from open_webui.utils.payload import (
39
- apply_model_params_to_body_ollama,
40
- apply_model_params_to_body_openai,
41
- apply_model_system_prompt_to_body,
42
- )
43
- from open_webui.utils.utils import get_admin_user, get_verified_user
44
-
45
- log = logging.getLogger(__name__)
46
- log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
47
-
48
- app = FastAPI()
49
- app.add_middleware(
50
- CORSMiddleware,
51
- allow_origins=CORS_ALLOW_ORIGIN,
52
- allow_credentials=True,
53
- allow_methods=["*"],
54
- allow_headers=["*"],
55
- )
56
-
57
- app.state.config = AppConfig()
58
-
59
- app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
60
- app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
61
-
62
- app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
63
- app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
64
- app.state.MODELS = {}
65
-
66
-
67
- # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
68
- # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
69
- # least connections, or least response time for better resource utilization and performance optimization.
70
-
71
-
72
- @app.middleware("http")
73
- async def check_url(request: Request, call_next):
74
- if len(app.state.MODELS) == 0:
75
- await get_all_models()
76
- else:
77
- pass
78
-
79
- response = await call_next(request)
80
- return response
81
-
82
-
83
- @app.head("/")
84
- @app.get("/")
85
- async def get_status():
86
- return {"status": True}
87
-
88
-
89
- @app.get("/config")
90
- async def get_config(user=Depends(get_admin_user)):
91
- return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
92
-
93
-
94
- class OllamaConfigForm(BaseModel):
95
- enable_ollama_api: Optional[bool] = None
96
-
97
-
98
- @app.post("/config/update")
99
- async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
100
- app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api
101
- return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
102
-
103
-
104
- @app.get("/urls")
105
- async def get_ollama_api_urls(user=Depends(get_admin_user)):
106
- return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
107
-
108
-
109
- class UrlUpdateForm(BaseModel):
110
- urls: list[str]
111
-
112
-
113
- @app.post("/urls/update")
114
- async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
115
- app.state.config.OLLAMA_BASE_URLS = form_data.urls
116
-
117
- log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
118
- return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
119
-
120
-
121
- async def fetch_url(url):
122
- timeout = aiohttp.ClientTimeout(total=3)
123
- try:
124
- async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
125
- async with session.get(url) as response:
126
- return await response.json()
127
- except Exception as e:
128
- # Handle connection error here
129
- log.error(f"Connection error: {e}")
130
- return None
131
-
132
-
133
- async def cleanup_response(
134
- response: Optional[aiohttp.ClientResponse],
135
- session: Optional[aiohttp.ClientSession],
136
- ):
137
- if response:
138
- response.close()
139
- if session:
140
- await session.close()
141
-
142
-
143
- async def post_streaming_url(
144
- url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
145
- ):
146
- r = None
147
- try:
148
- session = aiohttp.ClientSession(
149
- trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
150
- )
151
- r = await session.post(
152
- url,
153
- data=payload,
154
- headers={"Content-Type": "application/json"},
155
- )
156
- r.raise_for_status()
157
-
158
- if stream:
159
- headers = dict(r.headers)
160
- if content_type:
161
- headers["Content-Type"] = content_type
162
- return StreamingResponse(
163
- r.content,
164
- status_code=r.status,
165
- headers=headers,
166
- background=BackgroundTask(
167
- cleanup_response, response=r, session=session
168
- ),
169
- )
170
- else:
171
- res = await r.json()
172
- await cleanup_response(r, session)
173
- return res
174
-
175
- except Exception as e:
176
- error_detail = "Open WebUI: Server Connection Error"
177
- if r is not None:
178
- try:
179
- res = await r.json()
180
- if "error" in res:
181
- error_detail = f"Ollama: {res['error']}"
182
- except Exception:
183
- error_detail = f"Ollama: {e}"
184
-
185
- raise HTTPException(
186
- status_code=r.status if r else 500,
187
- detail=error_detail,
188
- )
189
-
190
-
191
- def merge_models_lists(model_lists):
192
- merged_models = {}
193
-
194
- for idx, model_list in enumerate(model_lists):
195
- if model_list is not None:
196
- for model in model_list:
197
- digest = model["digest"]
198
- if digest not in merged_models:
199
- model["urls"] = [idx]
200
- merged_models[digest] = model
201
- else:
202
- merged_models[digest]["urls"].append(idx)
203
-
204
- return list(merged_models.values())
205
-
206
-
207
- async def get_all_models():
208
- log.info("get_all_models()")
209
-
210
- if app.state.config.ENABLE_OLLAMA_API:
211
- tasks = [
212
- fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
213
- ]
214
- responses = await asyncio.gather(*tasks)
215
-
216
- models = {
217
- "models": merge_models_lists(
218
- map(
219
- lambda response: response["models"] if response else None, responses
220
- )
221
- )
222
- }
223
-
224
- else:
225
- models = {"models": []}
226
-
227
- app.state.MODELS = {model["model"]: model for model in models["models"]}
228
-
229
- return models
230
-
231
-
232
- @app.get("/api/tags")
233
- @app.get("/api/tags/{url_idx}")
234
- async def get_ollama_tags(
235
- url_idx: Optional[int] = None, user=Depends(get_verified_user)
236
- ):
237
- if url_idx is None:
238
- models = await get_all_models()
239
-
240
- if app.state.config.ENABLE_MODEL_FILTER:
241
- if user.role == "user":
242
- models["models"] = list(
243
- filter(
244
- lambda model: model["name"]
245
- in app.state.config.MODEL_FILTER_LIST,
246
- models["models"],
247
- )
248
- )
249
- return models
250
- return models
251
- else:
252
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
253
-
254
- r = None
255
- try:
256
- r = requests.request(method="GET", url=f"{url}/api/tags")
257
- r.raise_for_status()
258
-
259
- return r.json()
260
- except Exception as e:
261
- log.exception(e)
262
- error_detail = "Open WebUI: Server Connection Error"
263
- if r is not None:
264
- try:
265
- res = r.json()
266
- if "error" in res:
267
- error_detail = f"Ollama: {res['error']}"
268
- except Exception:
269
- error_detail = f"Ollama: {e}"
270
-
271
- raise HTTPException(
272
- status_code=r.status_code if r else 500,
273
- detail=error_detail,
274
- )
275
-
276
-
277
- @app.get("/api/version")
278
- @app.get("/api/version/{url_idx}")
279
- async def get_ollama_versions(url_idx: Optional[int] = None):
280
- if app.state.config.ENABLE_OLLAMA_API:
281
- if url_idx is None:
282
- # returns lowest version
283
- tasks = [
284
- fetch_url(f"{url}/api/version")
285
- for url in app.state.config.OLLAMA_BASE_URLS
286
- ]
287
- responses = await asyncio.gather(*tasks)
288
- responses = list(filter(lambda x: x is not None, responses))
289
-
290
- if len(responses) > 0:
291
- lowest_version = min(
292
- responses,
293
- key=lambda x: tuple(
294
- map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
295
- ),
296
- )
297
-
298
- return {"version": lowest_version["version"]}
299
- else:
300
- raise HTTPException(
301
- status_code=500,
302
- detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
303
- )
304
- else:
305
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
306
-
307
- r = None
308
- try:
309
- r = requests.request(method="GET", url=f"{url}/api/version")
310
- r.raise_for_status()
311
-
312
- return r.json()
313
- except Exception as e:
314
- log.exception(e)
315
- error_detail = "Open WebUI: Server Connection Error"
316
- if r is not None:
317
- try:
318
- res = r.json()
319
- if "error" in res:
320
- error_detail = f"Ollama: {res['error']}"
321
- except Exception:
322
- error_detail = f"Ollama: {e}"
323
-
324
- raise HTTPException(
325
- status_code=r.status_code if r else 500,
326
- detail=error_detail,
327
- )
328
- else:
329
- return {"version": False}
330
-
331
-
332
- class ModelNameForm(BaseModel):
333
- name: str
334
-
335
-
336
- @app.post("/api/pull")
337
- @app.post("/api/pull/{url_idx}")
338
- async def pull_model(
339
- form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
340
- ):
341
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
342
- log.info(f"url: {url}")
343
-
344
- # Admin should be able to pull models from any source
345
- payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
346
-
347
- return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
348
-
349
-
350
- class PushModelForm(BaseModel):
351
- name: str
352
- insecure: Optional[bool] = None
353
- stream: Optional[bool] = None
354
-
355
-
356
- @app.delete("/api/push")
357
- @app.delete("/api/push/{url_idx}")
358
- async def push_model(
359
- form_data: PushModelForm,
360
- url_idx: Optional[int] = None,
361
- user=Depends(get_admin_user),
362
- ):
363
- if url_idx is None:
364
- if form_data.name in app.state.MODELS:
365
- url_idx = app.state.MODELS[form_data.name]["urls"][0]
366
- else:
367
- raise HTTPException(
368
- status_code=400,
369
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
370
- )
371
-
372
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
373
- log.debug(f"url: {url}")
374
-
375
- return await post_streaming_url(
376
- f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
377
- )
378
-
379
-
380
- class CreateModelForm(BaseModel):
381
- name: str
382
- modelfile: Optional[str] = None
383
- stream: Optional[bool] = None
384
- path: Optional[str] = None
385
-
386
-
387
- @app.post("/api/create")
388
- @app.post("/api/create/{url_idx}")
389
- async def create_model(
390
- form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
391
- ):
392
- log.debug(f"form_data: {form_data}")
393
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
394
- log.info(f"url: {url}")
395
-
396
- return await post_streaming_url(
397
- f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
398
- )
399
-
400
-
401
- class CopyModelForm(BaseModel):
402
- source: str
403
- destination: str
404
-
405
-
406
- @app.post("/api/copy")
407
- @app.post("/api/copy/{url_idx}")
408
- async def copy_model(
409
- form_data: CopyModelForm,
410
- url_idx: Optional[int] = None,
411
- user=Depends(get_admin_user),
412
- ):
413
- if url_idx is None:
414
- if form_data.source in app.state.MODELS:
415
- url_idx = app.state.MODELS[form_data.source]["urls"][0]
416
- else:
417
- raise HTTPException(
418
- status_code=400,
419
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
420
- )
421
-
422
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
423
- log.info(f"url: {url}")
424
- r = requests.request(
425
- method="POST",
426
- url=f"{url}/api/copy",
427
- headers={"Content-Type": "application/json"},
428
- data=form_data.model_dump_json(exclude_none=True).encode(),
429
- )
430
-
431
- try:
432
- r.raise_for_status()
433
-
434
- log.debug(f"r.text: {r.text}")
435
-
436
- return True
437
- except Exception as e:
438
- log.exception(e)
439
- error_detail = "Open WebUI: Server Connection Error"
440
- if r is not None:
441
- try:
442
- res = r.json()
443
- if "error" in res:
444
- error_detail = f"Ollama: {res['error']}"
445
- except Exception:
446
- error_detail = f"Ollama: {e}"
447
-
448
- raise HTTPException(
449
- status_code=r.status_code if r else 500,
450
- detail=error_detail,
451
- )
452
-
453
-
454
- @app.delete("/api/delete")
455
- @app.delete("/api/delete/{url_idx}")
456
- async def delete_model(
457
- form_data: ModelNameForm,
458
- url_idx: Optional[int] = None,
459
- user=Depends(get_admin_user),
460
- ):
461
- if url_idx is None:
462
- if form_data.name in app.state.MODELS:
463
- url_idx = app.state.MODELS[form_data.name]["urls"][0]
464
- else:
465
- raise HTTPException(
466
- status_code=400,
467
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
468
- )
469
-
470
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
471
- log.info(f"url: {url}")
472
-
473
- r = requests.request(
474
- method="DELETE",
475
- url=f"{url}/api/delete",
476
- headers={"Content-Type": "application/json"},
477
- data=form_data.model_dump_json(exclude_none=True).encode(),
478
- )
479
- try:
480
- r.raise_for_status()
481
-
482
- log.debug(f"r.text: {r.text}")
483
-
484
- return True
485
- except Exception as e:
486
- log.exception(e)
487
- error_detail = "Open WebUI: Server Connection Error"
488
- if r is not None:
489
- try:
490
- res = r.json()
491
- if "error" in res:
492
- error_detail = f"Ollama: {res['error']}"
493
- except Exception:
494
- error_detail = f"Ollama: {e}"
495
-
496
- raise HTTPException(
497
- status_code=r.status_code if r else 500,
498
- detail=error_detail,
499
- )
500
-
501
-
502
- @app.post("/api/show")
503
- async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
504
- if form_data.name not in app.state.MODELS:
505
- raise HTTPException(
506
- status_code=400,
507
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
508
- )
509
-
510
- url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
511
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
512
- log.info(f"url: {url}")
513
-
514
- r = requests.request(
515
- method="POST",
516
- url=f"{url}/api/show",
517
- headers={"Content-Type": "application/json"},
518
- data=form_data.model_dump_json(exclude_none=True).encode(),
519
- )
520
- try:
521
- r.raise_for_status()
522
-
523
- return r.json()
524
- except Exception as e:
525
- log.exception(e)
526
- error_detail = "Open WebUI: Server Connection Error"
527
- if r is not None:
528
- try:
529
- res = r.json()
530
- if "error" in res:
531
- error_detail = f"Ollama: {res['error']}"
532
- except Exception:
533
- error_detail = f"Ollama: {e}"
534
-
535
- raise HTTPException(
536
- status_code=r.status_code if r else 500,
537
- detail=error_detail,
538
- )
539
-
540
-
541
- class GenerateEmbeddingsForm(BaseModel):
542
- model: str
543
- prompt: str
544
- options: Optional[dict] = None
545
- keep_alive: Optional[Union[int, str]] = None
546
-
547
-
548
- class GenerateEmbedForm(BaseModel):
549
- model: str
550
- input: list[str] | str
551
- truncate: Optional[bool] = None
552
- options: Optional[dict] = None
553
- keep_alive: Optional[Union[int, str]] = None
554
-
555
-
556
- @app.post("/api/embed")
557
- @app.post("/api/embed/{url_idx}")
558
- async def generate_embeddings(
559
- form_data: GenerateEmbedForm,
560
- url_idx: Optional[int] = None,
561
- user=Depends(get_verified_user),
562
- ):
563
- return generate_ollama_batch_embeddings(form_data, url_idx)
564
-
565
-
566
- @app.post("/api/embeddings")
567
- @app.post("/api/embeddings/{url_idx}")
568
- async def generate_embeddings(
569
- form_data: GenerateEmbeddingsForm,
570
- url_idx: Optional[int] = None,
571
- user=Depends(get_verified_user),
572
- ):
573
- return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
574
-
575
-
576
- def generate_ollama_embeddings(
577
- form_data: GenerateEmbeddingsForm,
578
- url_idx: Optional[int] = None,
579
- ):
580
- log.info(f"generate_ollama_embeddings {form_data}")
581
-
582
- if url_idx is None:
583
- model = form_data.model
584
-
585
- if ":" not in model:
586
- model = f"{model}:latest"
587
-
588
- if model in app.state.MODELS:
589
- url_idx = random.choice(app.state.MODELS[model]["urls"])
590
- else:
591
- raise HTTPException(
592
- status_code=400,
593
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
594
- )
595
-
596
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
597
- log.info(f"url: {url}")
598
-
599
- r = requests.request(
600
- method="POST",
601
- url=f"{url}/api/embeddings",
602
- headers={"Content-Type": "application/json"},
603
- data=form_data.model_dump_json(exclude_none=True).encode(),
604
- )
605
- try:
606
- r.raise_for_status()
607
-
608
- data = r.json()
609
-
610
- log.info(f"generate_ollama_embeddings {data}")
611
-
612
- if "embedding" in data:
613
- return data
614
- else:
615
- raise Exception("Something went wrong :/")
616
- except Exception as e:
617
- log.exception(e)
618
- error_detail = "Open WebUI: Server Connection Error"
619
- if r is not None:
620
- try:
621
- res = r.json()
622
- if "error" in res:
623
- error_detail = f"Ollama: {res['error']}"
624
- except Exception:
625
- error_detail = f"Ollama: {e}"
626
-
627
- raise HTTPException(
628
- status_code=r.status_code if r else 500,
629
- detail=error_detail,
630
- )
631
-
632
-
633
- def generate_ollama_batch_embeddings(
634
- form_data: GenerateEmbedForm,
635
- url_idx: Optional[int] = None,
636
- ):
637
- log.info(f"generate_ollama_batch_embeddings {form_data}")
638
-
639
- if url_idx is None:
640
- model = form_data.model
641
-
642
- if ":" not in model:
643
- model = f"{model}:latest"
644
-
645
- if model in app.state.MODELS:
646
- url_idx = random.choice(app.state.MODELS[model]["urls"])
647
- else:
648
- raise HTTPException(
649
- status_code=400,
650
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
651
- )
652
-
653
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
654
- log.info(f"url: {url}")
655
-
656
- r = requests.request(
657
- method="POST",
658
- url=f"{url}/api/embed",
659
- headers={"Content-Type": "application/json"},
660
- data=form_data.model_dump_json(exclude_none=True).encode(),
661
- )
662
- try:
663
- r.raise_for_status()
664
-
665
- data = r.json()
666
-
667
- log.info(f"generate_ollama_batch_embeddings {data}")
668
-
669
- if "embeddings" in data:
670
- return data
671
- else:
672
- raise Exception("Something went wrong :/")
673
- except Exception as e:
674
- log.exception(e)
675
- error_detail = "Open WebUI: Server Connection Error"
676
- if r is not None:
677
- try:
678
- res = r.json()
679
- if "error" in res:
680
- error_detail = f"Ollama: {res['error']}"
681
- except Exception:
682
- error_detail = f"Ollama: {e}"
683
-
684
- raise Exception(error_detail)
685
-
686
-
687
- class GenerateCompletionForm(BaseModel):
688
- model: str
689
- prompt: str
690
- images: Optional[list[str]] = None
691
- format: Optional[str] = None
692
- options: Optional[dict] = None
693
- system: Optional[str] = None
694
- template: Optional[str] = None
695
- context: Optional[str] = None
696
- stream: Optional[bool] = True
697
- raw: Optional[bool] = None
698
- keep_alive: Optional[Union[int, str]] = None
699
-
700
-
701
- @app.post("/api/generate")
702
- @app.post("/api/generate/{url_idx}")
703
- async def generate_completion(
704
- form_data: GenerateCompletionForm,
705
- url_idx: Optional[int] = None,
706
- user=Depends(get_verified_user),
707
- ):
708
- if url_idx is None:
709
- model = form_data.model
710
-
711
- if ":" not in model:
712
- model = f"{model}:latest"
713
-
714
- if model in app.state.MODELS:
715
- url_idx = random.choice(app.state.MODELS[model]["urls"])
716
- else:
717
- raise HTTPException(
718
- status_code=400,
719
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
720
- )
721
-
722
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
723
- log.info(f"url: {url}")
724
-
725
- return await post_streaming_url(
726
- f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
727
- )
728
-
729
-
730
- class ChatMessage(BaseModel):
731
- role: str
732
- content: str
733
- images: Optional[list[str]] = None
734
-
735
-
736
- class GenerateChatCompletionForm(BaseModel):
737
- model: str
738
- messages: list[ChatMessage]
739
- format: Optional[str] = None
740
- options: Optional[dict] = None
741
- template: Optional[str] = None
742
- stream: Optional[bool] = None
743
- keep_alive: Optional[Union[int, str]] = None
744
-
745
-
746
- def get_ollama_url(url_idx: Optional[int], model: str):
747
- if url_idx is None:
748
- if model not in app.state.MODELS:
749
- raise HTTPException(
750
- status_code=400,
751
- detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
752
- )
753
- url_idx = random.choice(app.state.MODELS[model]["urls"])
754
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
755
- return url
756
-
757
-
758
- @app.post("/api/chat")
759
- @app.post("/api/chat/{url_idx}")
760
- async def generate_chat_completion(
761
- form_data: GenerateChatCompletionForm,
762
- url_idx: Optional[int] = None,
763
- user=Depends(get_verified_user),
764
- ):
765
- payload = {**form_data.model_dump(exclude_none=True)}
766
- log.debug(f"generate_chat_completion() - 1.payload = {payload}")
767
- if "metadata" in payload:
768
- del payload["metadata"]
769
-
770
- model_id = form_data.model
771
-
772
- if app.state.config.ENABLE_MODEL_FILTER:
773
- if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
774
- raise HTTPException(
775
- status_code=403,
776
- detail="Model not found",
777
- )
778
-
779
- model_info = Models.get_model_by_id(model_id)
780
-
781
- if model_info:
782
- if model_info.base_model_id:
783
- payload["model"] = model_info.base_model_id
784
-
785
- params = model_info.params.model_dump()
786
-
787
- if params:
788
- if payload.get("options") is None:
789
- payload["options"] = {}
790
-
791
- payload["options"] = apply_model_params_to_body_ollama(
792
- params, payload["options"]
793
- )
794
- payload = apply_model_system_prompt_to_body(params, payload, user)
795
-
796
- if ":" not in payload["model"]:
797
- payload["model"] = f"{payload['model']}:latest"
798
-
799
- url = get_ollama_url(url_idx, payload["model"])
800
- log.info(f"url: {url}")
801
- log.debug(f"generate_chat_completion() - 2.payload = {payload}")
802
-
803
- return await post_streaming_url(
804
- f"{url}/api/chat",
805
- json.dumps(payload),
806
- stream=form_data.stream,
807
- content_type="application/x-ndjson",
808
- )
809
-
810
-
811
- # TODO: we should update this part once Ollama supports other types
812
- class OpenAIChatMessageContent(BaseModel):
813
- type: str
814
- model_config = ConfigDict(extra="allow")
815
-
816
-
817
- class OpenAIChatMessage(BaseModel):
818
- role: str
819
- content: Union[str, OpenAIChatMessageContent]
820
-
821
- model_config = ConfigDict(extra="allow")
822
-
823
-
824
- class OpenAIChatCompletionForm(BaseModel):
825
- model: str
826
- messages: list[OpenAIChatMessage]
827
-
828
- model_config = ConfigDict(extra="allow")
829
-
830
-
831
- @app.post("/v1/chat/completions")
832
- @app.post("/v1/chat/completions/{url_idx}")
833
- async def generate_openai_chat_completion(
834
- form_data: dict,
835
- url_idx: Optional[int] = None,
836
- user=Depends(get_verified_user),
837
- ):
838
- completion_form = OpenAIChatCompletionForm(**form_data)
839
- payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
840
- if "metadata" in payload:
841
- del payload["metadata"]
842
-
843
- model_id = completion_form.model
844
-
845
- if app.state.config.ENABLE_MODEL_FILTER:
846
- if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
847
- raise HTTPException(
848
- status_code=403,
849
- detail="Model not found",
850
- )
851
-
852
- model_info = Models.get_model_by_id(model_id)
853
-
854
- if model_info:
855
- if model_info.base_model_id:
856
- payload["model"] = model_info.base_model_id
857
-
858
- params = model_info.params.model_dump()
859
-
860
- if params:
861
- payload = apply_model_params_to_body_openai(params, payload)
862
- payload = apply_model_system_prompt_to_body(params, payload, user)
863
-
864
- if ":" not in payload["model"]:
865
- payload["model"] = f"{payload['model']}:latest"
866
-
867
- url = get_ollama_url(url_idx, payload["model"])
868
- log.info(f"url: {url}")
869
-
870
- return await post_streaming_url(
871
- f"{url}/v1/chat/completions",
872
- json.dumps(payload),
873
- stream=payload.get("stream", False),
874
- )
875
-
876
-
877
- @app.get("/v1/models")
878
- @app.get("/v1/models/{url_idx}")
879
- async def get_openai_models(
880
- url_idx: Optional[int] = None,
881
- user=Depends(get_verified_user),
882
- ):
883
- if url_idx is None:
884
- models = await get_all_models()
885
-
886
- if app.state.config.ENABLE_MODEL_FILTER:
887
- if user.role == "user":
888
- models["models"] = list(
889
- filter(
890
- lambda model: model["name"]
891
- in app.state.config.MODEL_FILTER_LIST,
892
- models["models"],
893
- )
894
- )
895
-
896
- return {
897
- "data": [
898
- {
899
- "id": model["model"],
900
- "object": "model",
901
- "created": int(time.time()),
902
- "owned_by": "openai",
903
- }
904
- for model in models["models"]
905
- ],
906
- "object": "list",
907
- }
908
-
909
- else:
910
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
911
- try:
912
- r = requests.request(method="GET", url=f"{url}/api/tags")
913
- r.raise_for_status()
914
-
915
- models = r.json()
916
-
917
- return {
918
- "data": [
919
- {
920
- "id": model["model"],
921
- "object": "model",
922
- "created": int(time.time()),
923
- "owned_by": "openai",
924
- }
925
- for model in models["models"]
926
- ],
927
- "object": "list",
928
- }
929
-
930
- except Exception as e:
931
- log.exception(e)
932
- error_detail = "Open WebUI: Server Connection Error"
933
- if r is not None:
934
- try:
935
- res = r.json()
936
- if "error" in res:
937
- error_detail = f"Ollama: {res['error']}"
938
- except Exception:
939
- error_detail = f"Ollama: {e}"
940
-
941
- raise HTTPException(
942
- status_code=r.status_code if r else 500,
943
- detail=error_detail,
944
- )
945
-
946
-
947
- class UrlForm(BaseModel):
948
- url: str
949
-
950
-
951
- class UploadBlobForm(BaseModel):
952
- filename: str
953
-
954
-
955
- def parse_huggingface_url(hf_url):
956
- try:
957
- # Parse the URL
958
- parsed_url = urlparse(hf_url)
959
-
960
- # Get the path and split it into components
961
- path_components = parsed_url.path.split("/")
962
-
963
- # Extract the desired output
964
- model_file = path_components[-1]
965
-
966
- return model_file
967
- except ValueError:
968
- return None
969
-
970
-
971
- async def download_file_stream(
972
- ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
973
- ):
974
- done = False
975
-
976
- if os.path.exists(file_path):
977
- current_size = os.path.getsize(file_path)
978
- else:
979
- current_size = 0
980
-
981
- headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
982
-
983
- timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
984
-
985
- async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
986
- async with session.get(file_url, headers=headers) as response:
987
- total_size = int(response.headers.get("content-length", 0)) + current_size
988
-
989
- with open(file_path, "ab+") as file:
990
- async for data in response.content.iter_chunked(chunk_size):
991
- current_size += len(data)
992
- file.write(data)
993
-
994
- done = current_size == total_size
995
- progress = round((current_size / total_size) * 100, 2)
996
-
997
- yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
998
-
999
- if done:
1000
- file.seek(0)
1001
- hashed = calculate_sha256(file)
1002
- file.seek(0)
1003
-
1004
- url = f"{ollama_url}/api/blobs/sha256:{hashed}"
1005
- response = requests.post(url, data=file)
1006
-
1007
- if response.ok:
1008
- res = {
1009
- "done": done,
1010
- "blob": f"sha256:{hashed}",
1011
- "name": file_name,
1012
- }
1013
- os.remove(file_path)
1014
-
1015
- yield f"data: {json.dumps(res)}\n\n"
1016
- else:
1017
- raise "Ollama: Could not create blob, Please try again."
1018
-
1019
-
1020
- # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
1021
- @app.post("/models/download")
1022
- @app.post("/models/download/{url_idx}")
1023
- async def download_model(
1024
- form_data: UrlForm,
1025
- url_idx: Optional[int] = None,
1026
- user=Depends(get_admin_user),
1027
- ):
1028
- allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
1029
-
1030
- if not any(form_data.url.startswith(host) for host in allowed_hosts):
1031
- raise HTTPException(
1032
- status_code=400,
1033
- detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
1034
- )
1035
-
1036
- if url_idx is None:
1037
- url_idx = 0
1038
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
1039
-
1040
- file_name = parse_huggingface_url(form_data.url)
1041
-
1042
- if file_name:
1043
- file_path = f"{UPLOAD_DIR}/{file_name}"
1044
-
1045
- return StreamingResponse(
1046
- download_file_stream(url, form_data.url, file_path, file_name),
1047
- )
1048
- else:
1049
- return None
1050
-
1051
-
1052
- @app.post("/models/upload")
1053
- @app.post("/models/upload/{url_idx}")
1054
- def upload_model(
1055
- file: UploadFile = File(...),
1056
- url_idx: Optional[int] = None,
1057
- user=Depends(get_admin_user),
1058
- ):
1059
- if url_idx is None:
1060
- url_idx = 0
1061
- ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
1062
-
1063
- file_path = f"{UPLOAD_DIR}/{file.filename}"
1064
-
1065
- # Save file in chunks
1066
- with open(file_path, "wb+") as f:
1067
- for chunk in file.file:
1068
- f.write(chunk)
1069
-
1070
- def file_process_stream():
1071
- nonlocal ollama_url
1072
- total_size = os.path.getsize(file_path)
1073
- chunk_size = 1024 * 1024
1074
- try:
1075
- with open(file_path, "rb") as f:
1076
- total = 0
1077
- done = False
1078
-
1079
- while not done:
1080
- chunk = f.read(chunk_size)
1081
- if not chunk:
1082
- done = True
1083
- continue
1084
-
1085
- total += len(chunk)
1086
- progress = round((total / total_size) * 100, 2)
1087
-
1088
- res = {
1089
- "progress": progress,
1090
- "total": total_size,
1091
- "completed": total,
1092
- }
1093
- yield f"data: {json.dumps(res)}\n\n"
1094
-
1095
- if done:
1096
- f.seek(0)
1097
- hashed = calculate_sha256(f)
1098
- f.seek(0)
1099
-
1100
- url = f"{ollama_url}/api/blobs/sha256:{hashed}"
1101
- response = requests.post(url, data=f)
1102
-
1103
- if response.ok:
1104
- res = {
1105
- "done": done,
1106
- "blob": f"sha256:{hashed}",
1107
- "name": file.filename,
1108
- }
1109
- os.remove(file_path)
1110
- yield f"data: {json.dumps(res)}\n\n"
1111
- else:
1112
- raise Exception(
1113
- "Ollama: Could not create blob, Please try again."
1114
- )
1115
-
1116
- except Exception as e:
1117
- res = {"error": str(e)}
1118
- yield f"data: {json.dumps(res)}\n\n"
1119
-
1120
- return StreamingResponse(file_process_stream(), media_type="text/event-stream")
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import os
5
+ import random
6
+ import re
7
+ import time
8
+ from typing import Optional, Union
9
+ from urllib.parse import urlparse
10
+
11
+ import aiohttp
12
+ import requests
13
+ from open_webui.apps.webui.models.models import Models
14
+ from open_webui.config import (
15
+ CORS_ALLOW_ORIGIN,
16
+ ENABLE_MODEL_FILTER,
17
+ ENABLE_OLLAMA_API,
18
+ MODEL_FILTER_LIST,
19
+ OLLAMA_BASE_URLS,
20
+ UPLOAD_DIR,
21
+ AppConfig,
22
+ )
23
+ from open_webui.env import AIOHTTP_CLIENT_TIMEOUT
24
+
25
+
26
+ from open_webui.constants import ERROR_MESSAGES
27
+ from open_webui.env import SRC_LOG_LEVELS
28
+ from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
29
+ from fastapi.middleware.cors import CORSMiddleware
30
+ from fastapi.responses import StreamingResponse
31
+ from pydantic import BaseModel, ConfigDict
32
+ from starlette.background import BackgroundTask
33
+
34
+
35
+ from open_webui.utils.misc import (
36
+ calculate_sha256,
37
+ )
38
+ from open_webui.utils.payload import (
39
+ apply_model_params_to_body_ollama,
40
+ apply_model_params_to_body_openai,
41
+ apply_model_system_prompt_to_body,
42
+ )
43
+ from open_webui.utils.utils import get_admin_user, get_verified_user
44
+
45
+ log = logging.getLogger(__name__)
46
+ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
47
+
48
+ app = FastAPI()
49
+ app.add_middleware(
50
+ CORSMiddleware,
51
+ allow_origins=CORS_ALLOW_ORIGIN,
52
+ allow_credentials=True,
53
+ allow_methods=["*"],
54
+ allow_headers=["*"],
55
+ )
56
+
57
+ app.state.config = AppConfig()
58
+
59
+ app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
60
+ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
61
+
62
+ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
63
+ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
64
+ app.state.MODELS = {}
65
+
66
+
67
+ # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
68
+ # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
69
+ # least connections, or least response time for better resource utilization and performance optimization.
70
+
71
+
72
+ @app.middleware("http")
73
+ async def check_url(request: Request, call_next):
74
+ if len(app.state.MODELS) == 0:
75
+ await get_all_models()
76
+ else:
77
+ pass
78
+
79
+ response = await call_next(request)
80
+ return response
81
+
82
+
83
+ @app.head("/")
84
+ @app.get("/")
85
+ async def get_status():
86
+ return {"status": True}
87
+
88
+
89
+ @app.get("/config")
90
+ async def get_config(user=Depends(get_admin_user)):
91
+ return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
92
+
93
+
94
+ class OllamaConfigForm(BaseModel):
95
+ enable_ollama_api: Optional[bool] = None
96
+
97
+
98
+ @app.post("/config/update")
99
+ async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
100
+ app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api
101
+ return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
102
+
103
+
104
+ @app.get("/urls")
105
+ async def get_ollama_api_urls(user=Depends(get_admin_user)):
106
+ return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
107
+
108
+
109
+ class UrlUpdateForm(BaseModel):
110
+ urls: list[str]
111
+
112
+
113
+ @app.post("/urls/update")
114
+ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
115
+ app.state.config.OLLAMA_BASE_URLS = form_data.urls
116
+
117
+ log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
118
+ return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
119
+
120
+
121
+ async def fetch_url(url):
122
+ timeout = aiohttp.ClientTimeout(total=3)
123
+ try:
124
+ async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
125
+ async with session.get(url) as response:
126
+ return await response.json()
127
+ except Exception as e:
128
+ # Handle connection error here
129
+ log.error(f"Connection error: {e}")
130
+ return None
131
+
132
+
133
+ async def cleanup_response(
134
+ response: Optional[aiohttp.ClientResponse],
135
+ session: Optional[aiohttp.ClientSession],
136
+ ):
137
+ if response:
138
+ response.close()
139
+ if session:
140
+ await session.close()
141
+
142
+
143
+ async def post_streaming_url(
144
+ url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
145
+ ):
146
+ r = None
147
+ try:
148
+ session = aiohttp.ClientSession(
149
+ trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
150
+ )
151
+ r = await session.post(
152
+ url,
153
+ data=payload,
154
+ headers={"Content-Type": "application/json"},
155
+ )
156
+ r.raise_for_status()
157
+
158
+ if stream:
159
+ headers = dict(r.headers)
160
+ if content_type:
161
+ headers["Content-Type"] = content_type
162
+ return StreamingResponse(
163
+ r.content,
164
+ status_code=r.status,
165
+ headers=headers,
166
+ background=BackgroundTask(
167
+ cleanup_response, response=r, session=session
168
+ ),
169
+ )
170
+ else:
171
+ res = await r.json()
172
+ await cleanup_response(r, session)
173
+ return res
174
+
175
+ except Exception as e:
176
+ error_detail = "Open WebUI: Server Connection Error"
177
+ if r is not None:
178
+ try:
179
+ res = await r.json()
180
+ if "error" in res:
181
+ error_detail = f"Ollama: {res['error']}"
182
+ except Exception:
183
+ error_detail = f"Ollama: {e}"
184
+
185
+ raise HTTPException(
186
+ status_code=r.status if r else 500,
187
+ detail=error_detail,
188
+ )
189
+
190
+
191
+ def merge_models_lists(model_lists):
192
+ merged_models = {}
193
+
194
+ for idx, model_list in enumerate(model_lists):
195
+ if model_list is not None:
196
+ for model in model_list:
197
+ digest = model["digest"]
198
+ if digest not in merged_models:
199
+ model["urls"] = [idx]
200
+ merged_models[digest] = model
201
+ else:
202
+ merged_models[digest]["urls"].append(idx)
203
+
204
+ return list(merged_models.values())
205
+
206
+
207
+ async def get_all_models():
208
+ log.info("get_all_models()")
209
+
210
+ if app.state.config.ENABLE_OLLAMA_API:
211
+ tasks = [
212
+ fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
213
+ ]
214
+ responses = await asyncio.gather(*tasks)
215
+
216
+ models = {
217
+ "models": merge_models_lists(
218
+ map(
219
+ lambda response: response["models"] if response else None, responses
220
+ )
221
+ )
222
+ }
223
+
224
+ else:
225
+ models = {"models": []}
226
+
227
+ app.state.MODELS = {model["model"]: model for model in models["models"]}
228
+
229
+ return models
230
+
231
+
232
+ @app.get("/api/tags")
233
+ @app.get("/api/tags/{url_idx}")
234
+ async def get_ollama_tags(
235
+ url_idx: Optional[int] = None, user=Depends(get_verified_user)
236
+ ):
237
+ if url_idx is None:
238
+ models = await get_all_models()
239
+
240
+ if app.state.config.ENABLE_MODEL_FILTER:
241
+ if user.role == "user":
242
+ models["models"] = list(
243
+ filter(
244
+ lambda model: model["name"]
245
+ in app.state.config.MODEL_FILTER_LIST,
246
+ models["models"],
247
+ )
248
+ )
249
+ return models
250
+ return models
251
+ else:
252
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
253
+
254
+ r = None
255
+ try:
256
+ r = requests.request(method="GET", url=f"{url}/api/tags")
257
+ r.raise_for_status()
258
+
259
+ return r.json()
260
+ except Exception as e:
261
+ log.exception(e)
262
+ error_detail = "Open WebUI: Server Connection Error"
263
+ if r is not None:
264
+ try:
265
+ res = r.json()
266
+ if "error" in res:
267
+ error_detail = f"Ollama: {res['error']}"
268
+ except Exception:
269
+ error_detail = f"Ollama: {e}"
270
+
271
+ raise HTTPException(
272
+ status_code=r.status_code if r else 500,
273
+ detail=error_detail,
274
+ )
275
+
276
+
277
+ @app.get("/api/version")
278
+ @app.get("/api/version/{url_idx}")
279
+ async def get_ollama_versions(url_idx: Optional[int] = None):
280
+ if app.state.config.ENABLE_OLLAMA_API:
281
+ if url_idx is None:
282
+ # returns lowest version
283
+ tasks = [
284
+ fetch_url(f"{url}/api/version")
285
+ for url in app.state.config.OLLAMA_BASE_URLS
286
+ ]
287
+ responses = await asyncio.gather(*tasks)
288
+ responses = list(filter(lambda x: x is not None, responses))
289
+
290
+ if len(responses) > 0:
291
+ lowest_version = min(
292
+ responses,
293
+ key=lambda x: tuple(
294
+ map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
295
+ ),
296
+ )
297
+
298
+ return {"version": lowest_version["version"]}
299
+ else:
300
+ raise HTTPException(
301
+ status_code=500,
302
+ detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
303
+ )
304
+ else:
305
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
306
+
307
+ r = None
308
+ try:
309
+ r = requests.request(method="GET", url=f"{url}/api/version")
310
+ r.raise_for_status()
311
+
312
+ return r.json()
313
+ except Exception as e:
314
+ log.exception(e)
315
+ error_detail = "Open WebUI: Server Connection Error"
316
+ if r is not None:
317
+ try:
318
+ res = r.json()
319
+ if "error" in res:
320
+ error_detail = f"Ollama: {res['error']}"
321
+ except Exception:
322
+ error_detail = f"Ollama: {e}"
323
+
324
+ raise HTTPException(
325
+ status_code=r.status_code if r else 500,
326
+ detail=error_detail,
327
+ )
328
+ else:
329
+ return {"version": False}
330
+
331
+
332
+ class ModelNameForm(BaseModel):
333
+ name: str
334
+
335
+
336
+ @app.post("/api/pull")
337
+ @app.post("/api/pull/{url_idx}")
338
+ async def pull_model(
339
+ form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
340
+ ):
341
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
342
+ log.info(f"url: {url}")
343
+
344
+ # Admin should be able to pull models from any source
345
+ payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
346
+
347
+ return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
348
+
349
+
350
+ class PushModelForm(BaseModel):
351
+ name: str
352
+ insecure: Optional[bool] = None
353
+ stream: Optional[bool] = None
354
+
355
+
356
+ @app.delete("/api/push")
357
+ @app.delete("/api/push/{url_idx}")
358
+ async def push_model(
359
+ form_data: PushModelForm,
360
+ url_idx: Optional[int] = None,
361
+ user=Depends(get_admin_user),
362
+ ):
363
+ if url_idx is None:
364
+ if form_data.name in app.state.MODELS:
365
+ url_idx = app.state.MODELS[form_data.name]["urls"][0]
366
+ else:
367
+ raise HTTPException(
368
+ status_code=400,
369
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
370
+ )
371
+
372
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
373
+ log.debug(f"url: {url}")
374
+
375
+ return await post_streaming_url(
376
+ f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
377
+ )
378
+
379
+
380
+ class CreateModelForm(BaseModel):
381
+ name: str
382
+ modelfile: Optional[str] = None
383
+ stream: Optional[bool] = None
384
+ path: Optional[str] = None
385
+
386
+
387
+ @app.post("/api/create")
388
+ @app.post("/api/create/{url_idx}")
389
+ async def create_model(
390
+ form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
391
+ ):
392
+ log.debug(f"form_data: {form_data}")
393
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
394
+ log.info(f"url: {url}")
395
+
396
+ return await post_streaming_url(
397
+ f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
398
+ )
399
+
400
+
401
+ class CopyModelForm(BaseModel):
402
+ source: str
403
+ destination: str
404
+
405
+
406
+ @app.post("/api/copy")
407
+ @app.post("/api/copy/{url_idx}")
408
+ async def copy_model(
409
+ form_data: CopyModelForm,
410
+ url_idx: Optional[int] = None,
411
+ user=Depends(get_admin_user),
412
+ ):
413
+ if url_idx is None:
414
+ if form_data.source in app.state.MODELS:
415
+ url_idx = app.state.MODELS[form_data.source]["urls"][0]
416
+ else:
417
+ raise HTTPException(
418
+ status_code=400,
419
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
420
+ )
421
+
422
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
423
+ log.info(f"url: {url}")
424
+ r = requests.request(
425
+ method="POST",
426
+ url=f"{url}/api/copy",
427
+ headers={"Content-Type": "application/json"},
428
+ data=form_data.model_dump_json(exclude_none=True).encode(),
429
+ )
430
+
431
+ try:
432
+ r.raise_for_status()
433
+
434
+ log.debug(f"r.text: {r.text}")
435
+
436
+ return True
437
+ except Exception as e:
438
+ log.exception(e)
439
+ error_detail = "Open WebUI: Server Connection Error"
440
+ if r is not None:
441
+ try:
442
+ res = r.json()
443
+ if "error" in res:
444
+ error_detail = f"Ollama: {res['error']}"
445
+ except Exception:
446
+ error_detail = f"Ollama: {e}"
447
+
448
+ raise HTTPException(
449
+ status_code=r.status_code if r else 500,
450
+ detail=error_detail,
451
+ )
452
+
453
+
454
+ @app.delete("/api/delete")
455
+ @app.delete("/api/delete/{url_idx}")
456
+ async def delete_model(
457
+ form_data: ModelNameForm,
458
+ url_idx: Optional[int] = None,
459
+ user=Depends(get_admin_user),
460
+ ):
461
+ if url_idx is None:
462
+ if form_data.name in app.state.MODELS:
463
+ url_idx = app.state.MODELS[form_data.name]["urls"][0]
464
+ else:
465
+ raise HTTPException(
466
+ status_code=400,
467
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
468
+ )
469
+
470
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
471
+ log.info(f"url: {url}")
472
+
473
+ r = requests.request(
474
+ method="DELETE",
475
+ url=f"{url}/api/delete",
476
+ headers={"Content-Type": "application/json"},
477
+ data=form_data.model_dump_json(exclude_none=True).encode(),
478
+ )
479
+ try:
480
+ r.raise_for_status()
481
+
482
+ log.debug(f"r.text: {r.text}")
483
+
484
+ return True
485
+ except Exception as e:
486
+ log.exception(e)
487
+ error_detail = "Open WebUI: Server Connection Error"
488
+ if r is not None:
489
+ try:
490
+ res = r.json()
491
+ if "error" in res:
492
+ error_detail = f"Ollama: {res['error']}"
493
+ except Exception:
494
+ error_detail = f"Ollama: {e}"
495
+
496
+ raise HTTPException(
497
+ status_code=r.status_code if r else 500,
498
+ detail=error_detail,
499
+ )
500
+
501
+
502
+ @app.post("/api/show")
503
+ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
504
+ if form_data.name not in app.state.MODELS:
505
+ raise HTTPException(
506
+ status_code=400,
507
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
508
+ )
509
+
510
+ url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
511
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
512
+ log.info(f"url: {url}")
513
+
514
+ r = requests.request(
515
+ method="POST",
516
+ url=f"{url}/api/show",
517
+ headers={"Content-Type": "application/json"},
518
+ data=form_data.model_dump_json(exclude_none=True).encode(),
519
+ )
520
+ try:
521
+ r.raise_for_status()
522
+
523
+ return r.json()
524
+ except Exception as e:
525
+ log.exception(e)
526
+ error_detail = "Open WebUI: Server Connection Error"
527
+ if r is not None:
528
+ try:
529
+ res = r.json()
530
+ if "error" in res:
531
+ error_detail = f"Ollama: {res['error']}"
532
+ except Exception:
533
+ error_detail = f"Ollama: {e}"
534
+
535
+ raise HTTPException(
536
+ status_code=r.status_code if r else 500,
537
+ detail=error_detail,
538
+ )
539
+
540
+
541
+ class GenerateEmbeddingsForm(BaseModel):
542
+ model: str
543
+ prompt: str
544
+ options: Optional[dict] = None
545
+ keep_alive: Optional[Union[int, str]] = None
546
+
547
+
548
+ class GenerateEmbedForm(BaseModel):
549
+ model: str
550
+ input: list[str] | str
551
+ truncate: Optional[bool] = None
552
+ options: Optional[dict] = None
553
+ keep_alive: Optional[Union[int, str]] = None
554
+
555
+
556
+ @app.post("/api/embed")
557
+ @app.post("/api/embed/{url_idx}")
558
+ async def generate_embeddings(
559
+ form_data: GenerateEmbedForm,
560
+ url_idx: Optional[int] = None,
561
+ user=Depends(get_verified_user),
562
+ ):
563
+ return generate_ollama_batch_embeddings(form_data, url_idx)
564
+
565
+
566
+ @app.post("/api/embeddings")
567
+ @app.post("/api/embeddings/{url_idx}")
568
+ async def generate_embeddings(
569
+ form_data: GenerateEmbeddingsForm,
570
+ url_idx: Optional[int] = None,
571
+ user=Depends(get_verified_user),
572
+ ):
573
+ return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
574
+
575
+
576
+ def generate_ollama_embeddings(
577
+ form_data: GenerateEmbeddingsForm,
578
+ url_idx: Optional[int] = None,
579
+ ):
580
+ log.info(f"generate_ollama_embeddings {form_data}")
581
+
582
+ if url_idx is None:
583
+ model = form_data.model
584
+
585
+ if ":" not in model:
586
+ model = f"{model}:latest"
587
+
588
+ if model in app.state.MODELS:
589
+ url_idx = random.choice(app.state.MODELS[model]["urls"])
590
+ else:
591
+ raise HTTPException(
592
+ status_code=400,
593
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
594
+ )
595
+
596
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
597
+ log.info(f"url: {url}")
598
+
599
+ r = requests.request(
600
+ method="POST",
601
+ url=f"{url}/api/embeddings",
602
+ headers={"Content-Type": "application/json"},
603
+ data=form_data.model_dump_json(exclude_none=True).encode(),
604
+ )
605
+ try:
606
+ r.raise_for_status()
607
+
608
+ data = r.json()
609
+
610
+ log.info(f"generate_ollama_embeddings {data}")
611
+
612
+ if "embedding" in data:
613
+ return data
614
+ else:
615
+ raise Exception("Something went wrong :/")
616
+ except Exception as e:
617
+ log.exception(e)
618
+ error_detail = "Open WebUI: Server Connection Error"
619
+ if r is not None:
620
+ try:
621
+ res = r.json()
622
+ if "error" in res:
623
+ error_detail = f"Ollama: {res['error']}"
624
+ except Exception:
625
+ error_detail = f"Ollama: {e}"
626
+
627
+ raise HTTPException(
628
+ status_code=r.status_code if r else 500,
629
+ detail=error_detail,
630
+ )
631
+
632
+
633
+ def generate_ollama_batch_embeddings(
634
+ form_data: GenerateEmbedForm,
635
+ url_idx: Optional[int] = None,
636
+ ):
637
+ log.info(f"generate_ollama_batch_embeddings {form_data}")
638
+
639
+ if url_idx is None:
640
+ model = form_data.model
641
+
642
+ if ":" not in model:
643
+ model = f"{model}:latest"
644
+
645
+ if model in app.state.MODELS:
646
+ url_idx = random.choice(app.state.MODELS[model]["urls"])
647
+ else:
648
+ raise HTTPException(
649
+ status_code=400,
650
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
651
+ )
652
+
653
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
654
+ log.info(f"url: {url}")
655
+
656
+ r = requests.request(
657
+ method="POST",
658
+ url=f"{url}/api/embed",
659
+ headers={"Content-Type": "application/json"},
660
+ data=form_data.model_dump_json(exclude_none=True).encode(),
661
+ )
662
+ try:
663
+ r.raise_for_status()
664
+
665
+ data = r.json()
666
+
667
+ log.info(f"generate_ollama_batch_embeddings {data}")
668
+
669
+ if "embeddings" in data:
670
+ return data
671
+ else:
672
+ raise Exception("Something went wrong :/")
673
+ except Exception as e:
674
+ log.exception(e)
675
+ error_detail = "Open WebUI: Server Connection Error"
676
+ if r is not None:
677
+ try:
678
+ res = r.json()
679
+ if "error" in res:
680
+ error_detail = f"Ollama: {res['error']}"
681
+ except Exception:
682
+ error_detail = f"Ollama: {e}"
683
+
684
+ raise Exception(error_detail)
685
+
686
+
687
+ class GenerateCompletionForm(BaseModel):
688
+ model: str
689
+ prompt: str
690
+ images: Optional[list[str]] = None
691
+ format: Optional[str] = None
692
+ options: Optional[dict] = None
693
+ system: Optional[str] = None
694
+ template: Optional[str] = None
695
+ context: Optional[str] = None
696
+ stream: Optional[bool] = True
697
+ raw: Optional[bool] = None
698
+ keep_alive: Optional[Union[int, str]] = None
699
+
700
+
701
+ @app.post("/api/generate")
702
+ @app.post("/api/generate/{url_idx}")
703
+ async def generate_completion(
704
+ form_data: GenerateCompletionForm,
705
+ url_idx: Optional[int] = None,
706
+ user=Depends(get_verified_user),
707
+ ):
708
+ if url_idx is None:
709
+ model = form_data.model
710
+
711
+ if ":" not in model:
712
+ model = f"{model}:latest"
713
+
714
+ if model in app.state.MODELS:
715
+ url_idx = random.choice(app.state.MODELS[model]["urls"])
716
+ else:
717
+ raise HTTPException(
718
+ status_code=400,
719
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
720
+ )
721
+
722
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
723
+ log.info(f"url: {url}")
724
+
725
+ return await post_streaming_url(
726
+ f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
727
+ )
728
+
729
+
730
+ class ChatMessage(BaseModel):
731
+ role: str
732
+ content: str
733
+ images: Optional[list[str]] = None
734
+
735
+
736
+ class GenerateChatCompletionForm(BaseModel):
737
+ model: str
738
+ messages: list[ChatMessage]
739
+ format: Optional[str] = None
740
+ options: Optional[dict] = None
741
+ template: Optional[str] = None
742
+ stream: Optional[bool] = None
743
+ keep_alive: Optional[Union[int, str]] = None
744
+
745
+
746
+ def get_ollama_url(url_idx: Optional[int], model: str):
747
+ if url_idx is None:
748
+ if model not in app.state.MODELS:
749
+ raise HTTPException(
750
+ status_code=400,
751
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
752
+ )
753
+ url_idx = random.choice(app.state.MODELS[model]["urls"])
754
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
755
+ return url
756
+
757
+
758
+ @app.post("/api/chat")
759
+ @app.post("/api/chat/{url_idx}")
760
+ async def generate_chat_completion(
761
+ form_data: GenerateChatCompletionForm,
762
+ url_idx: Optional[int] = None,
763
+ user=Depends(get_verified_user),
764
+ bypass_filter: Optional[bool] = False,
765
+ ):
766
+ payload = {**form_data.model_dump(exclude_none=True)}
767
+ log.debug(f"generate_chat_completion() - 1.payload = {payload}")
768
+ if "metadata" in payload:
769
+ del payload["metadata"]
770
+
771
+ model_id = form_data.model
772
+
773
+ if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
774
+ if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
775
+ raise HTTPException(
776
+ status_code=403,
777
+ detail="Model not found",
778
+ )
779
+
780
+ model_info = Models.get_model_by_id(model_id)
781
+
782
+ if model_info:
783
+ if model_info.base_model_id:
784
+ payload["model"] = model_info.base_model_id
785
+
786
+ params = model_info.params.model_dump()
787
+
788
+ if params:
789
+ if payload.get("options") is None:
790
+ payload["options"] = {}
791
+
792
+ payload["options"] = apply_model_params_to_body_ollama(
793
+ params, payload["options"]
794
+ )
795
+ payload = apply_model_system_prompt_to_body(params, payload, user)
796
+
797
+ if ":" not in payload["model"]:
798
+ payload["model"] = f"{payload['model']}:latest"
799
+
800
+ url = get_ollama_url(url_idx, payload["model"])
801
+ log.info(f"url: {url}")
802
+ log.debug(f"generate_chat_completion() - 2.payload = {payload}")
803
+
804
+ return await post_streaming_url(
805
+ f"{url}/api/chat",
806
+ json.dumps(payload),
807
+ stream=form_data.stream,
808
+ content_type="application/x-ndjson",
809
+ )
810
+
811
+
812
+ # TODO: we should update this part once Ollama supports other types
813
+ class OpenAIChatMessageContent(BaseModel):
814
+ type: str
815
+ model_config = ConfigDict(extra="allow")
816
+
817
+
818
+ class OpenAIChatMessage(BaseModel):
819
+ role: str
820
+ content: Union[str, OpenAIChatMessageContent]
821
+
822
+ model_config = ConfigDict(extra="allow")
823
+
824
+
825
+ class OpenAIChatCompletionForm(BaseModel):
826
+ model: str
827
+ messages: list[OpenAIChatMessage]
828
+
829
+ model_config = ConfigDict(extra="allow")
830
+
831
+
832
+ @app.post("/v1/chat/completions")
833
+ @app.post("/v1/chat/completions/{url_idx}")
834
+ async def generate_openai_chat_completion(
835
+ form_data: dict,
836
+ url_idx: Optional[int] = None,
837
+ user=Depends(get_verified_user),
838
+ ):
839
+ completion_form = OpenAIChatCompletionForm(**form_data)
840
+ payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
841
+ if "metadata" in payload:
842
+ del payload["metadata"]
843
+
844
+ model_id = completion_form.model
845
+
846
+ if app.state.config.ENABLE_MODEL_FILTER:
847
+ if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
848
+ raise HTTPException(
849
+ status_code=403,
850
+ detail="Model not found",
851
+ )
852
+
853
+ model_info = Models.get_model_by_id(model_id)
854
+
855
+ if model_info:
856
+ if model_info.base_model_id:
857
+ payload["model"] = model_info.base_model_id
858
+
859
+ params = model_info.params.model_dump()
860
+
861
+ if params:
862
+ payload = apply_model_params_to_body_openai(params, payload)
863
+ payload = apply_model_system_prompt_to_body(params, payload, user)
864
+
865
+ if ":" not in payload["model"]:
866
+ payload["model"] = f"{payload['model']}:latest"
867
+
868
+ url = get_ollama_url(url_idx, payload["model"])
869
+ log.info(f"url: {url}")
870
+
871
+ return await post_streaming_url(
872
+ f"{url}/v1/chat/completions",
873
+ json.dumps(payload),
874
+ stream=payload.get("stream", False),
875
+ )
876
+
877
+
878
+ @app.get("/v1/models")
879
+ @app.get("/v1/models/{url_idx}")
880
+ async def get_openai_models(
881
+ url_idx: Optional[int] = None,
882
+ user=Depends(get_verified_user),
883
+ ):
884
+ if url_idx is None:
885
+ models = await get_all_models()
886
+
887
+ if app.state.config.ENABLE_MODEL_FILTER:
888
+ if user.role == "user":
889
+ models["models"] = list(
890
+ filter(
891
+ lambda model: model["name"]
892
+ in app.state.config.MODEL_FILTER_LIST,
893
+ models["models"],
894
+ )
895
+ )
896
+
897
+ return {
898
+ "data": [
899
+ {
900
+ "id": model["model"],
901
+ "object": "model",
902
+ "created": int(time.time()),
903
+ "owned_by": "openai",
904
+ }
905
+ for model in models["models"]
906
+ ],
907
+ "object": "list",
908
+ }
909
+
910
+ else:
911
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
912
+ try:
913
+ r = requests.request(method="GET", url=f"{url}/api/tags")
914
+ r.raise_for_status()
915
+
916
+ models = r.json()
917
+
918
+ return {
919
+ "data": [
920
+ {
921
+ "id": model["model"],
922
+ "object": "model",
923
+ "created": int(time.time()),
924
+ "owned_by": "openai",
925
+ }
926
+ for model in models["models"]
927
+ ],
928
+ "object": "list",
929
+ }
930
+
931
+ except Exception as e:
932
+ log.exception(e)
933
+ error_detail = "Open WebUI: Server Connection Error"
934
+ if r is not None:
935
+ try:
936
+ res = r.json()
937
+ if "error" in res:
938
+ error_detail = f"Ollama: {res['error']}"
939
+ except Exception:
940
+ error_detail = f"Ollama: {e}"
941
+
942
+ raise HTTPException(
943
+ status_code=r.status_code if r else 500,
944
+ detail=error_detail,
945
+ )
946
+
947
+
948
+ class UrlForm(BaseModel):
949
+ url: str
950
+
951
+
952
+ class UploadBlobForm(BaseModel):
953
+ filename: str
954
+
955
+
956
+ def parse_huggingface_url(hf_url):
957
+ try:
958
+ # Parse the URL
959
+ parsed_url = urlparse(hf_url)
960
+
961
+ # Get the path and split it into components
962
+ path_components = parsed_url.path.split("/")
963
+
964
+ # Extract the desired output
965
+ model_file = path_components[-1]
966
+
967
+ return model_file
968
+ except ValueError:
969
+ return None
970
+
971
+
972
+ async def download_file_stream(
973
+ ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
974
+ ):
975
+ done = False
976
+
977
+ if os.path.exists(file_path):
978
+ current_size = os.path.getsize(file_path)
979
+ else:
980
+ current_size = 0
981
+
982
+ headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
983
+
984
+ timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
985
+
986
+ async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
987
+ async with session.get(file_url, headers=headers) as response:
988
+ total_size = int(response.headers.get("content-length", 0)) + current_size
989
+
990
+ with open(file_path, "ab+") as file:
991
+ async for data in response.content.iter_chunked(chunk_size):
992
+ current_size += len(data)
993
+ file.write(data)
994
+
995
+ done = current_size == total_size
996
+ progress = round((current_size / total_size) * 100, 2)
997
+
998
+ yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
999
+
1000
+ if done:
1001
+ file.seek(0)
1002
+ hashed = calculate_sha256(file)
1003
+ file.seek(0)
1004
+
1005
+ url = f"{ollama_url}/api/blobs/sha256:{hashed}"
1006
+ response = requests.post(url, data=file)
1007
+
1008
+ if response.ok:
1009
+ res = {
1010
+ "done": done,
1011
+ "blob": f"sha256:{hashed}",
1012
+ "name": file_name,
1013
+ }
1014
+ os.remove(file_path)
1015
+
1016
+ yield f"data: {json.dumps(res)}\n\n"
1017
+ else:
1018
+ raise "Ollama: Could not create blob, Please try again."
1019
+
1020
+
1021
+ # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
1022
+ @app.post("/models/download")
1023
+ @app.post("/models/download/{url_idx}")
1024
+ async def download_model(
1025
+ form_data: UrlForm,
1026
+ url_idx: Optional[int] = None,
1027
+ user=Depends(get_admin_user),
1028
+ ):
1029
+ allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
1030
+
1031
+ if not any(form_data.url.startswith(host) for host in allowed_hosts):
1032
+ raise HTTPException(
1033
+ status_code=400,
1034
+ detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
1035
+ )
1036
+
1037
+ if url_idx is None:
1038
+ url_idx = 0
1039
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
1040
+
1041
+ file_name = parse_huggingface_url(form_data.url)
1042
+
1043
+ if file_name:
1044
+ file_path = f"{UPLOAD_DIR}/{file_name}"
1045
+
1046
+ return StreamingResponse(
1047
+ download_file_stream(url, form_data.url, file_path, file_name),
1048
+ )
1049
+ else:
1050
+ return None
1051
+
1052
+
1053
+ @app.post("/models/upload")
1054
+ @app.post("/models/upload/{url_idx}")
1055
+ def upload_model(
1056
+ file: UploadFile = File(...),
1057
+ url_idx: Optional[int] = None,
1058
+ user=Depends(get_admin_user),
1059
+ ):
1060
+ if url_idx is None:
1061
+ url_idx = 0
1062
+ ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
1063
+
1064
+ file_path = f"{UPLOAD_DIR}/{file.filename}"
1065
+
1066
+ # Save file in chunks
1067
+ with open(file_path, "wb+") as f:
1068
+ for chunk in file.file:
1069
+ f.write(chunk)
1070
+
1071
+ def file_process_stream():
1072
+ nonlocal ollama_url
1073
+ total_size = os.path.getsize(file_path)
1074
+ chunk_size = 1024 * 1024
1075
+ try:
1076
+ with open(file_path, "rb") as f:
1077
+ total = 0
1078
+ done = False
1079
+
1080
+ while not done:
1081
+ chunk = f.read(chunk_size)
1082
+ if not chunk:
1083
+ done = True
1084
+ continue
1085
+
1086
+ total += len(chunk)
1087
+ progress = round((total / total_size) * 100, 2)
1088
+
1089
+ res = {
1090
+ "progress": progress,
1091
+ "total": total_size,
1092
+ "completed": total,
1093
+ }
1094
+ yield f"data: {json.dumps(res)}\n\n"
1095
+
1096
+ if done:
1097
+ f.seek(0)
1098
+ hashed = calculate_sha256(f)
1099
+ f.seek(0)
1100
+
1101
+ url = f"{ollama_url}/api/blobs/sha256:{hashed}"
1102
+ response = requests.post(url, data=f)
1103
+
1104
+ if response.ok:
1105
+ res = {
1106
+ "done": done,
1107
+ "blob": f"sha256:{hashed}",
1108
+ "name": file.filename,
1109
+ }
1110
+ os.remove(file_path)
1111
+ yield f"data: {json.dumps(res)}\n\n"
1112
+ else:
1113
+ raise Exception(
1114
+ "Ollama: Could not create blob, Please try again."
1115
+ )
1116
+
1117
+ except Exception as e:
1118
+ res = {"error": str(e)}
1119
+ yield f"data: {json.dumps(res)}\n\n"
1120
+
1121
+ return StreamingResponse(file_process_stream(), media_type="text/event-stream")
backend/open_webui/apps/openai/main.py CHANGED
@@ -1,554 +1,554 @@
1
- import asyncio
2
- import hashlib
3
- import json
4
- import logging
5
- from pathlib import Path
6
- from typing import Literal, Optional, overload
7
-
8
- import aiohttp
9
- import requests
10
- from open_webui.apps.webui.models.models import Models
11
- from open_webui.config import (
12
- CACHE_DIR,
13
- CORS_ALLOW_ORIGIN,
14
- ENABLE_MODEL_FILTER,
15
- ENABLE_OPENAI_API,
16
- MODEL_FILTER_LIST,
17
- OPENAI_API_BASE_URLS,
18
- OPENAI_API_KEYS,
19
- AppConfig,
20
- )
21
- from open_webui.env import (
22
- AIOHTTP_CLIENT_TIMEOUT,
23
- AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
24
- )
25
-
26
- from open_webui.constants import ERROR_MESSAGES
27
- from open_webui.env import SRC_LOG_LEVELS
28
- from fastapi import Depends, FastAPI, HTTPException, Request
29
- from fastapi.middleware.cors import CORSMiddleware
30
- from fastapi.responses import FileResponse, StreamingResponse
31
- from pydantic import BaseModel
32
- from starlette.background import BackgroundTask
33
-
34
- from open_webui.utils.payload import (
35
- apply_model_params_to_body_openai,
36
- apply_model_system_prompt_to_body,
37
- )
38
-
39
- from open_webui.utils.utils import get_admin_user, get_verified_user
40
-
41
- log = logging.getLogger(__name__)
42
- log.setLevel(SRC_LOG_LEVELS["OPENAI"])
43
-
44
- app = FastAPI()
45
- app.add_middleware(
46
- CORSMiddleware,
47
- allow_origins=CORS_ALLOW_ORIGIN,
48
- allow_credentials=True,
49
- allow_methods=["*"],
50
- allow_headers=["*"],
51
- )
52
-
53
- app.state.config = AppConfig()
54
-
55
- app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
56
- app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
57
-
58
- app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
59
- app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
60
- app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
61
-
62
- app.state.MODELS = {}
63
-
64
-
65
- @app.middleware("http")
66
- async def check_url(request: Request, call_next):
67
- if len(app.state.MODELS) == 0:
68
- await get_all_models()
69
-
70
- response = await call_next(request)
71
- return response
72
-
73
-
74
- @app.get("/config")
75
- async def get_config(user=Depends(get_admin_user)):
76
- return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
77
-
78
-
79
- class OpenAIConfigForm(BaseModel):
80
- enable_openai_api: Optional[bool] = None
81
-
82
-
83
- @app.post("/config/update")
84
- async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
85
- app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
86
- return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
87
-
88
-
89
- class UrlsUpdateForm(BaseModel):
90
- urls: list[str]
91
-
92
-
93
- class KeysUpdateForm(BaseModel):
94
- keys: list[str]
95
-
96
-
97
- @app.get("/urls")
98
- async def get_openai_urls(user=Depends(get_admin_user)):
99
- return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
100
-
101
-
102
- @app.post("/urls/update")
103
- async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
104
- await get_all_models()
105
- app.state.config.OPENAI_API_BASE_URLS = form_data.urls
106
- return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
107
-
108
-
109
- @app.get("/keys")
110
- async def get_openai_keys(user=Depends(get_admin_user)):
111
- return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
112
-
113
-
114
- @app.post("/keys/update")
115
- async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
116
- app.state.config.OPENAI_API_KEYS = form_data.keys
117
- return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
118
-
119
-
120
- @app.post("/audio/speech")
121
- async def speech(request: Request, user=Depends(get_verified_user)):
122
- idx = None
123
- try:
124
- idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
125
- body = await request.body()
126
- name = hashlib.sha256(body).hexdigest()
127
-
128
- SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
129
- SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
130
- file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
131
- file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
132
-
133
- # Check if the file already exists in the cache
134
- if file_path.is_file():
135
- return FileResponse(file_path)
136
-
137
- headers = {}
138
- headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
139
- headers["Content-Type"] = "application/json"
140
- if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
141
- headers["HTTP-Referer"] = "https://openwebui.com/"
142
- headers["X-Title"] = "Open WebUI"
143
- r = None
144
- try:
145
- r = requests.post(
146
- url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
147
- data=body,
148
- headers=headers,
149
- stream=True,
150
- )
151
-
152
- r.raise_for_status()
153
-
154
- # Save the streaming content to a file
155
- with open(file_path, "wb") as f:
156
- for chunk in r.iter_content(chunk_size=8192):
157
- f.write(chunk)
158
-
159
- with open(file_body_path, "w") as f:
160
- json.dump(json.loads(body.decode("utf-8")), f)
161
-
162
- # Return the saved file
163
- return FileResponse(file_path)
164
-
165
- except Exception as e:
166
- log.exception(e)
167
- error_detail = "Open WebUI: Server Connection Error"
168
- if r is not None:
169
- try:
170
- res = r.json()
171
- if "error" in res:
172
- error_detail = f"External: {res['error']}"
173
- except Exception:
174
- error_detail = f"External: {e}"
175
-
176
- raise HTTPException(
177
- status_code=r.status_code if r else 500, detail=error_detail
178
- )
179
-
180
- except ValueError:
181
- raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
182
-
183
-
184
- async def fetch_url(url, key):
185
- timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
186
- try:
187
- headers = {"Authorization": f"Bearer {key}"}
188
- async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
189
- async with session.get(url, headers=headers) as response:
190
- return await response.json()
191
- except Exception as e:
192
- # Handle connection error here
193
- log.error(f"Connection error: {e}")
194
- return None
195
-
196
-
197
- async def cleanup_response(
198
- response: Optional[aiohttp.ClientResponse],
199
- session: Optional[aiohttp.ClientSession],
200
- ):
201
- if response:
202
- response.close()
203
- if session:
204
- await session.close()
205
-
206
-
207
- def merge_models_lists(model_lists):
208
- log.debug(f"merge_models_lists {model_lists}")
209
- merged_list = []
210
-
211
- for idx, models in enumerate(model_lists):
212
- if models is not None and "error" not in models:
213
- merged_list.extend(
214
- [
215
- {
216
- **model,
217
- "name": model.get("name", model["id"]),
218
- "owned_by": "openai",
219
- "openai": model,
220
- "urlIdx": idx,
221
- }
222
- for model in models
223
- if "api.openai.com"
224
- not in app.state.config.OPENAI_API_BASE_URLS[idx]
225
- or not any(
226
- name in model["id"]
227
- for name in [
228
- "babbage",
229
- "dall-e",
230
- "davinci",
231
- "embedding",
232
- "tts",
233
- "whisper",
234
- ]
235
- )
236
- ]
237
- )
238
-
239
- return merged_list
240
-
241
-
242
- def is_openai_api_disabled():
243
- return not app.state.config.ENABLE_OPENAI_API
244
-
245
-
246
- async def get_all_models_raw() -> list:
247
- if is_openai_api_disabled():
248
- return []
249
-
250
- # Check if API KEYS length is same than API URLS length
251
- num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
252
- num_keys = len(app.state.config.OPENAI_API_KEYS)
253
-
254
- if num_keys != num_urls:
255
- # if there are more keys than urls, remove the extra keys
256
- if num_keys > num_urls:
257
- new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
258
- app.state.config.OPENAI_API_KEYS = new_keys
259
- # if there are more urls than keys, add empty keys
260
- else:
261
- app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
262
-
263
- tasks = [
264
- fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
265
- for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
266
- ]
267
-
268
- responses = await asyncio.gather(*tasks)
269
- log.debug(f"get_all_models:responses() {responses}")
270
-
271
- return responses
272
-
273
-
274
- @overload
275
- async def get_all_models(raw: Literal[True]) -> list: ...
276
-
277
-
278
- @overload
279
- async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
280
-
281
-
282
- async def get_all_models(raw=False) -> dict[str, list] | list:
283
- log.info("get_all_models()")
284
- if is_openai_api_disabled():
285
- return [] if raw else {"data": []}
286
-
287
- responses = await get_all_models_raw()
288
- if raw:
289
- return responses
290
-
291
- def extract_data(response):
292
- if response and "data" in response:
293
- return response["data"]
294
- if isinstance(response, list):
295
- return response
296
- return None
297
-
298
- models = {"data": merge_models_lists(map(extract_data, responses))}
299
-
300
- log.debug(f"models: {models}")
301
- app.state.MODELS = {model["id"]: model for model in models["data"]}
302
-
303
- return models
304
-
305
-
306
- @app.get("/models")
307
- @app.get("/models/{url_idx}")
308
- async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
309
- if url_idx is None:
310
- models = await get_all_models()
311
- if app.state.config.ENABLE_MODEL_FILTER:
312
- if user.role == "user":
313
- models["data"] = list(
314
- filter(
315
- lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
316
- models["data"],
317
- )
318
- )
319
- return models
320
- return models
321
- else:
322
- url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
323
- key = app.state.config.OPENAI_API_KEYS[url_idx]
324
-
325
- headers = {}
326
- headers["Authorization"] = f"Bearer {key}"
327
- headers["Content-Type"] = "application/json"
328
-
329
- r = None
330
-
331
- try:
332
- r = requests.request(method="GET", url=f"{url}/models", headers=headers)
333
- r.raise_for_status()
334
-
335
- response_data = r.json()
336
-
337
- if "api.openai.com" in url:
338
- # Filter the response data
339
- response_data["data"] = [
340
- model
341
- for model in response_data["data"]
342
- if not any(
343
- name in model["id"]
344
- for name in [
345
- "babbage",
346
- "dall-e",
347
- "davinci",
348
- "embedding",
349
- "tts",
350
- "whisper",
351
- ]
352
- )
353
- ]
354
-
355
- return response_data
356
- except Exception as e:
357
- log.exception(e)
358
- error_detail = "Open WebUI: Server Connection Error"
359
- if r is not None:
360
- try:
361
- res = r.json()
362
- if "error" in res:
363
- error_detail = f"External: {res['error']}"
364
- except Exception:
365
- error_detail = f"External: {e}"
366
-
367
- raise HTTPException(
368
- status_code=r.status_code if r else 500,
369
- detail=error_detail,
370
- )
371
-
372
-
373
- @app.post("/chat/completions")
374
- @app.post("/chat/completions/{url_idx}")
375
- async def generate_chat_completion(
376
- form_data: dict,
377
- url_idx: Optional[int] = None,
378
- user=Depends(get_verified_user),
379
- ):
380
- idx = 0
381
- payload = {**form_data}
382
-
383
- if "metadata" in payload:
384
- del payload["metadata"]
385
-
386
- model_id = form_data.get("model")
387
- model_info = Models.get_model_by_id(model_id)
388
-
389
- if model_info:
390
- if model_info.base_model_id:
391
- payload["model"] = model_info.base_model_id
392
-
393
- params = model_info.params.model_dump()
394
- payload = apply_model_params_to_body_openai(params, payload)
395
- payload = apply_model_system_prompt_to_body(params, payload, user)
396
-
397
- model = app.state.MODELS[payload.get("model")]
398
- idx = model["urlIdx"]
399
-
400
- if "pipeline" in model and model.get("pipeline"):
401
- payload["user"] = {
402
- "name": user.name,
403
- "id": user.id,
404
- "email": user.email,
405
- "role": user.role,
406
- }
407
-
408
- url = app.state.config.OPENAI_API_BASE_URLS[idx]
409
- key = app.state.config.OPENAI_API_KEYS[idx]
410
- is_o1 = payload["model"].lower().startswith("o1-")
411
-
412
- # Change max_completion_tokens to max_tokens (Backward compatible)
413
- if "api.openai.com" not in url and not is_o1:
414
- if "max_completion_tokens" in payload:
415
- # Remove "max_completion_tokens" from the payload
416
- payload["max_tokens"] = payload["max_completion_tokens"]
417
- del payload["max_completion_tokens"]
418
- else:
419
- if is_o1 and "max_tokens" in payload:
420
- payload["max_completion_tokens"] = payload["max_tokens"]
421
- del payload["max_tokens"]
422
- if "max_tokens" in payload and "max_completion_tokens" in payload:
423
- del payload["max_tokens"]
424
-
425
- # Fix: O1 does not support the "system" parameter, Modify "system" to "user"
426
- if is_o1 and payload["messages"][0]["role"] == "system":
427
- payload["messages"][0]["role"] = "user"
428
-
429
- # Convert the modified body back to JSON
430
- payload = json.dumps(payload)
431
-
432
- log.debug(payload)
433
-
434
- headers = {}
435
- headers["Authorization"] = f"Bearer {key}"
436
- headers["Content-Type"] = "application/json"
437
- if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
438
- headers["HTTP-Referer"] = "https://openwebui.com/"
439
- headers["X-Title"] = "Open WebUI"
440
-
441
- r = None
442
- session = None
443
- streaming = False
444
- response = None
445
-
446
- try:
447
- session = aiohttp.ClientSession(
448
- trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
449
- )
450
- r = await session.request(
451
- method="POST",
452
- url=f"{url}/chat/completions",
453
- data=payload,
454
- headers=headers,
455
- )
456
-
457
- # Check if response is SSE
458
- if "text/event-stream" in r.headers.get("Content-Type", ""):
459
- streaming = True
460
- return StreamingResponse(
461
- r.content,
462
- status_code=r.status,
463
- headers=dict(r.headers),
464
- background=BackgroundTask(
465
- cleanup_response, response=r, session=session
466
- ),
467
- )
468
- else:
469
- try:
470
- response = await r.json()
471
- except Exception as e:
472
- log.error(e)
473
- response = await r.text()
474
-
475
- r.raise_for_status()
476
- return response
477
- except Exception as e:
478
- log.exception(e)
479
- error_detail = "Open WebUI: Server Connection Error"
480
- if isinstance(response, dict):
481
- if "error" in response:
482
- error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
483
- elif isinstance(response, str):
484
- error_detail = response
485
-
486
- raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
487
- finally:
488
- if not streaming and session:
489
- if r:
490
- r.close()
491
- await session.close()
492
-
493
-
494
- @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
495
- async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
496
- idx = 0
497
-
498
- body = await request.body()
499
-
500
- url = app.state.config.OPENAI_API_BASE_URLS[idx]
501
- key = app.state.config.OPENAI_API_KEYS[idx]
502
-
503
- target_url = f"{url}/{path}"
504
-
505
- headers = {}
506
- headers["Authorization"] = f"Bearer {key}"
507
- headers["Content-Type"] = "application/json"
508
-
509
- r = None
510
- session = None
511
- streaming = False
512
-
513
- try:
514
- session = aiohttp.ClientSession(trust_env=True)
515
- r = await session.request(
516
- method=request.method,
517
- url=target_url,
518
- data=body,
519
- headers=headers,
520
- )
521
-
522
- r.raise_for_status()
523
-
524
- # Check if response is SSE
525
- if "text/event-stream" in r.headers.get("Content-Type", ""):
526
- streaming = True
527
- return StreamingResponse(
528
- r.content,
529
- status_code=r.status,
530
- headers=dict(r.headers),
531
- background=BackgroundTask(
532
- cleanup_response, response=r, session=session
533
- ),
534
- )
535
- else:
536
- response_data = await r.json()
537
- return response_data
538
- except Exception as e:
539
- log.exception(e)
540
- error_detail = "Open WebUI: Server Connection Error"
541
- if r is not None:
542
- try:
543
- res = await r.json()
544
- print(res)
545
- if "error" in res:
546
- error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
547
- except Exception:
548
- error_detail = f"External: {e}"
549
- raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
550
- finally:
551
- if not streaming and session:
552
- if r:
553
- r.close()
554
- await session.close()
 
1
+ import asyncio
2
+ import hashlib
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Literal, Optional, overload
7
+
8
+ import aiohttp
9
+ import requests
10
+ from open_webui.apps.webui.models.models import Models
11
+ from open_webui.config import (
12
+ CACHE_DIR,
13
+ CORS_ALLOW_ORIGIN,
14
+ ENABLE_MODEL_FILTER,
15
+ ENABLE_OPENAI_API,
16
+ MODEL_FILTER_LIST,
17
+ OPENAI_API_BASE_URLS,
18
+ OPENAI_API_KEYS,
19
+ AppConfig,
20
+ )
21
+ from open_webui.env import (
22
+ AIOHTTP_CLIENT_TIMEOUT,
23
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
24
+ )
25
+
26
+ from open_webui.constants import ERROR_MESSAGES
27
+ from open_webui.env import SRC_LOG_LEVELS
28
+ from fastapi import Depends, FastAPI, HTTPException, Request
29
+ from fastapi.middleware.cors import CORSMiddleware
30
+ from fastapi.responses import FileResponse, StreamingResponse
31
+ from pydantic import BaseModel
32
+ from starlette.background import BackgroundTask
33
+
34
+ from open_webui.utils.payload import (
35
+ apply_model_params_to_body_openai,
36
+ apply_model_system_prompt_to_body,
37
+ )
38
+
39
+ from open_webui.utils.utils import get_admin_user, get_verified_user
40
+
41
+ log = logging.getLogger(__name__)
42
+ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
43
+
44
+ app = FastAPI()
45
+ app.add_middleware(
46
+ CORSMiddleware,
47
+ allow_origins=CORS_ALLOW_ORIGIN,
48
+ allow_credentials=True,
49
+ allow_methods=["*"],
50
+ allow_headers=["*"],
51
+ )
52
+
53
+ app.state.config = AppConfig()
54
+
55
+ app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
56
+ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
57
+
58
+ app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
59
+ app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
60
+ app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
61
+
62
+ app.state.MODELS = {}
63
+
64
+
65
+ @app.middleware("http")
66
+ async def check_url(request: Request, call_next):
67
+ if len(app.state.MODELS) == 0:
68
+ await get_all_models()
69
+
70
+ response = await call_next(request)
71
+ return response
72
+
73
+
74
+ @app.get("/config")
75
+ async def get_config(user=Depends(get_admin_user)):
76
+ return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
77
+
78
+
79
+ class OpenAIConfigForm(BaseModel):
80
+ enable_openai_api: Optional[bool] = None
81
+
82
+
83
+ @app.post("/config/update")
84
+ async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
85
+ app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
86
+ return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
87
+
88
+
89
+ class UrlsUpdateForm(BaseModel):
90
+ urls: list[str]
91
+
92
+
93
+ class KeysUpdateForm(BaseModel):
94
+ keys: list[str]
95
+
96
+
97
+ @app.get("/urls")
98
+ async def get_openai_urls(user=Depends(get_admin_user)):
99
+ return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
100
+
101
+
102
+ @app.post("/urls/update")
103
+ async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
104
+ await get_all_models()
105
+ app.state.config.OPENAI_API_BASE_URLS = form_data.urls
106
+ return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
107
+
108
+
109
+ @app.get("/keys")
110
+ async def get_openai_keys(user=Depends(get_admin_user)):
111
+ return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
112
+
113
+
114
+ @app.post("/keys/update")
115
+ async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
116
+ app.state.config.OPENAI_API_KEYS = form_data.keys
117
+ return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
118
+
119
+
120
+ @app.post("/audio/speech")
121
+ async def speech(request: Request, user=Depends(get_verified_user)):
122
+ idx = None
123
+ try:
124
+ idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
125
+ body = await request.body()
126
+ name = hashlib.sha256(body).hexdigest()
127
+
128
+ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
129
+ SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
130
+ file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
131
+ file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
132
+
133
+ # Check if the file already exists in the cache
134
+ if file_path.is_file():
135
+ return FileResponse(file_path)
136
+
137
+ headers = {}
138
+ headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
139
+ headers["Content-Type"] = "application/json"
140
+ if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
141
+ headers["HTTP-Referer"] = "https://openwebui.com/"
142
+ headers["X-Title"] = "Open WebUI"
143
+ r = None
144
+ try:
145
+ r = requests.post(
146
+ url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
147
+ data=body,
148
+ headers=headers,
149
+ stream=True,
150
+ )
151
+
152
+ r.raise_for_status()
153
+
154
+ # Save the streaming content to a file
155
+ with open(file_path, "wb") as f:
156
+ for chunk in r.iter_content(chunk_size=8192):
157
+ f.write(chunk)
158
+
159
+ with open(file_body_path, "w") as f:
160
+ json.dump(json.loads(body.decode("utf-8")), f)
161
+
162
+ # Return the saved file
163
+ return FileResponse(file_path)
164
+
165
+ except Exception as e:
166
+ log.exception(e)
167
+ error_detail = "Open WebUI: Server Connection Error"
168
+ if r is not None:
169
+ try:
170
+ res = r.json()
171
+ if "error" in res:
172
+ error_detail = f"External: {res['error']}"
173
+ except Exception:
174
+ error_detail = f"External: {e}"
175
+
176
+ raise HTTPException(
177
+ status_code=r.status_code if r else 500, detail=error_detail
178
+ )
179
+
180
+ except ValueError:
181
+ raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
182
+
183
+
184
+ async def fetch_url(url, key):
185
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
186
+ try:
187
+ headers = {"Authorization": f"Bearer {key}"}
188
+ async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
189
+ async with session.get(url, headers=headers) as response:
190
+ return await response.json()
191
+ except Exception as e:
192
+ # Handle connection error here
193
+ log.error(f"Connection error: {e}")
194
+ return None
195
+
196
+
197
+ async def cleanup_response(
198
+ response: Optional[aiohttp.ClientResponse],
199
+ session: Optional[aiohttp.ClientSession],
200
+ ):
201
+ if response:
202
+ response.close()
203
+ if session:
204
+ await session.close()
205
+
206
+
207
+ def merge_models_lists(model_lists):
208
+ log.debug(f"merge_models_lists {model_lists}")
209
+ merged_list = []
210
+
211
+ for idx, models in enumerate(model_lists):
212
+ if models is not None and "error" not in models:
213
+ merged_list.extend(
214
+ [
215
+ {
216
+ **model,
217
+ "name": model.get("name", model["id"]),
218
+ "owned_by": "openai",
219
+ "openai": model,
220
+ "urlIdx": idx,
221
+ }
222
+ for model in models
223
+ if "api.openai.com"
224
+ not in app.state.config.OPENAI_API_BASE_URLS[idx]
225
+ or not any(
226
+ name in model["id"]
227
+ for name in [
228
+ "babbage",
229
+ "dall-e",
230
+ "davinci",
231
+ "embedding",
232
+ "tts",
233
+ "whisper",
234
+ ]
235
+ )
236
+ ]
237
+ )
238
+
239
+ return merged_list
240
+
241
+
242
+ def is_openai_api_disabled():
243
+ return not app.state.config.ENABLE_OPENAI_API
244
+
245
+
246
+ async def get_all_models_raw() -> list:
247
+ if is_openai_api_disabled():
248
+ return []
249
+
250
+ # Check if API KEYS length is same than API URLS length
251
+ num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
252
+ num_keys = len(app.state.config.OPENAI_API_KEYS)
253
+
254
+ if num_keys != num_urls:
255
+ # if there are more keys than urls, remove the extra keys
256
+ if num_keys > num_urls:
257
+ new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
258
+ app.state.config.OPENAI_API_KEYS = new_keys
259
+ # if there are more urls than keys, add empty keys
260
+ else:
261
+ app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
262
+
263
+ tasks = [
264
+ fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
265
+ for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
266
+ ]
267
+
268
+ responses = await asyncio.gather(*tasks)
269
+ log.debug(f"get_all_models:responses() {responses}")
270
+
271
+ return responses
272
+
273
+
274
+ @overload
275
+ async def get_all_models(raw: Literal[True]) -> list: ...
276
+
277
+
278
+ @overload
279
+ async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
280
+
281
+
282
+ async def get_all_models(raw=False) -> dict[str, list] | list:
283
+ log.info("get_all_models()")
284
+ if is_openai_api_disabled():
285
+ return [] if raw else {"data": []}
286
+
287
+ responses = await get_all_models_raw()
288
+ if raw:
289
+ return responses
290
+
291
+ def extract_data(response):
292
+ if response and "data" in response:
293
+ return response["data"]
294
+ if isinstance(response, list):
295
+ return response
296
+ return None
297
+
298
+ models = {"data": merge_models_lists(map(extract_data, responses))}
299
+
300
+ log.debug(f"models: {models}")
301
+ app.state.MODELS = {model["id"]: model for model in models["data"]}
302
+
303
+ return models
304
+
305
+
306
+ @app.get("/models")
307
+ @app.get("/models/{url_idx}")
308
+ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
309
+ if url_idx is None:
310
+ models = await get_all_models()
311
+ if app.state.config.ENABLE_MODEL_FILTER:
312
+ if user.role == "user":
313
+ models["data"] = list(
314
+ filter(
315
+ lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
316
+ models["data"],
317
+ )
318
+ )
319
+ return models
320
+ return models
321
+ else:
322
+ url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
323
+ key = app.state.config.OPENAI_API_KEYS[url_idx]
324
+
325
+ headers = {}
326
+ headers["Authorization"] = f"Bearer {key}"
327
+ headers["Content-Type"] = "application/json"
328
+
329
+ r = None
330
+
331
+ try:
332
+ r = requests.request(method="GET", url=f"{url}/models", headers=headers)
333
+ r.raise_for_status()
334
+
335
+ response_data = r.json()
336
+
337
+ if "api.openai.com" in url:
338
+ # Filter the response data
339
+ response_data["data"] = [
340
+ model
341
+ for model in response_data["data"]
342
+ if not any(
343
+ name in model["id"]
344
+ for name in [
345
+ "babbage",
346
+ "dall-e",
347
+ "davinci",
348
+ "embedding",
349
+ "tts",
350
+ "whisper",
351
+ ]
352
+ )
353
+ ]
354
+
355
+ return response_data
356
+ except Exception as e:
357
+ log.exception(e)
358
+ error_detail = "Open WebUI: Server Connection Error"
359
+ if r is not None:
360
+ try:
361
+ res = r.json()
362
+ if "error" in res:
363
+ error_detail = f"External: {res['error']}"
364
+ except Exception:
365
+ error_detail = f"External: {e}"
366
+
367
+ raise HTTPException(
368
+ status_code=r.status_code if r else 500,
369
+ detail=error_detail,
370
+ )
371
+
372
+
373
+ @app.post("/chat/completions")
374
+ @app.post("/chat/completions/{url_idx}")
375
+ async def generate_chat_completion(
376
+ form_data: dict,
377
+ url_idx: Optional[int] = None,
378
+ user=Depends(get_verified_user),
379
+ ):
380
+ idx = 0
381
+ payload = {**form_data}
382
+
383
+ if "metadata" in payload:
384
+ del payload["metadata"]
385
+
386
+ model_id = form_data.get("model")
387
+ model_info = Models.get_model_by_id(model_id)
388
+
389
+ if model_info:
390
+ if model_info.base_model_id:
391
+ payload["model"] = model_info.base_model_id
392
+
393
+ params = model_info.params.model_dump()
394
+ payload = apply_model_params_to_body_openai(params, payload)
395
+ payload = apply_model_system_prompt_to_body(params, payload, user)
396
+
397
+ model = app.state.MODELS[payload.get("model")]
398
+ idx = model["urlIdx"]
399
+
400
+ if "pipeline" in model and model.get("pipeline"):
401
+ payload["user"] = {
402
+ "name": user.name,
403
+ "id": user.id,
404
+ "email": user.email,
405
+ "role": user.role,
406
+ }
407
+
408
+ url = app.state.config.OPENAI_API_BASE_URLS[idx]
409
+ key = app.state.config.OPENAI_API_KEYS[idx]
410
+ is_o1 = payload["model"].lower().startswith("o1-")
411
+
412
+ # Change max_completion_tokens to max_tokens (Backward compatible)
413
+ if "api.openai.com" not in url and not is_o1:
414
+ if "max_completion_tokens" in payload:
415
+ # Remove "max_completion_tokens" from the payload
416
+ payload["max_tokens"] = payload["max_completion_tokens"]
417
+ del payload["max_completion_tokens"]
418
+ else:
419
+ if is_o1 and "max_tokens" in payload:
420
+ payload["max_completion_tokens"] = payload["max_tokens"]
421
+ del payload["max_tokens"]
422
+ if "max_tokens" in payload and "max_completion_tokens" in payload:
423
+ del payload["max_tokens"]
424
+
425
+ # Fix: O1 does not support the "system" parameter, Modify "system" to "user"
426
+ if is_o1 and payload["messages"][0]["role"] == "system":
427
+ payload["messages"][0]["role"] = "user"
428
+
429
+ # Convert the modified body back to JSON
430
+ payload = json.dumps(payload)
431
+
432
+ log.debug(payload)
433
+
434
+ headers = {}
435
+ headers["Authorization"] = f"Bearer {key}"
436
+ headers["Content-Type"] = "application/json"
437
+ if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
438
+ headers["HTTP-Referer"] = "https://openwebui.com/"
439
+ headers["X-Title"] = "Open WebUI"
440
+
441
+ r = None
442
+ session = None
443
+ streaming = False
444
+ response = None
445
+
446
+ try:
447
+ session = aiohttp.ClientSession(
448
+ trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
449
+ )
450
+ r = await session.request(
451
+ method="POST",
452
+ url=f"{url}/chat/completions",
453
+ data=payload,
454
+ headers=headers,
455
+ )
456
+
457
+ # Check if response is SSE
458
+ if "text/event-stream" in r.headers.get("Content-Type", ""):
459
+ streaming = True
460
+ return StreamingResponse(
461
+ r.content,
462
+ status_code=r.status,
463
+ headers=dict(r.headers),
464
+ background=BackgroundTask(
465
+ cleanup_response, response=r, session=session
466
+ ),
467
+ )
468
+ else:
469
+ try:
470
+ response = await r.json()
471
+ except Exception as e:
472
+ log.error(e)
473
+ response = await r.text()
474
+
475
+ r.raise_for_status()
476
+ return response
477
+ except Exception as e:
478
+ log.exception(e)
479
+ error_detail = "Open WebUI: Server Connection Error"
480
+ if isinstance(response, dict):
481
+ if "error" in response:
482
+ error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
483
+ elif isinstance(response, str):
484
+ error_detail = response
485
+
486
+ raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
487
+ finally:
488
+ if not streaming and session:
489
+ if r:
490
+ r.close()
491
+ await session.close()
492
+
493
+
494
+ @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
495
+ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
496
+ idx = 0
497
+
498
+ body = await request.body()
499
+
500
+ url = app.state.config.OPENAI_API_BASE_URLS[idx]
501
+ key = app.state.config.OPENAI_API_KEYS[idx]
502
+
503
+ target_url = f"{url}/{path}"
504
+
505
+ headers = {}
506
+ headers["Authorization"] = f"Bearer {key}"
507
+ headers["Content-Type"] = "application/json"
508
+
509
+ r = None
510
+ session = None
511
+ streaming = False
512
+
513
+ try:
514
+ session = aiohttp.ClientSession(trust_env=True)
515
+ r = await session.request(
516
+ method=request.method,
517
+ url=target_url,
518
+ data=body,
519
+ headers=headers,
520
+ )
521
+
522
+ r.raise_for_status()
523
+
524
+ # Check if response is SSE
525
+ if "text/event-stream" in r.headers.get("Content-Type", ""):
526
+ streaming = True
527
+ return StreamingResponse(
528
+ r.content,
529
+ status_code=r.status,
530
+ headers=dict(r.headers),
531
+ background=BackgroundTask(
532
+ cleanup_response, response=r, session=session
533
+ ),
534
+ )
535
+ else:
536
+ response_data = await r.json()
537
+ return response_data
538
+ except Exception as e:
539
+ log.exception(e)
540
+ error_detail = "Open WebUI: Server Connection Error"
541
+ if r is not None:
542
+ try:
543
+ res = await r.json()
544
+ print(res)
545
+ if "error" in res:
546
+ error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
547
+ except Exception:
548
+ error_detail = f"External: {e}"
549
+ raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
550
+ finally:
551
+ if not streaming and session:
552
+ if r:
553
+ r.close()
554
+ await session.close()
backend/open_webui/apps/retrieval/loaders/main.py CHANGED
@@ -1,190 +1,190 @@
1
- import requests
2
- import logging
3
- import ftfy
4
-
5
- from langchain_community.document_loaders import (
6
- BSHTMLLoader,
7
- CSVLoader,
8
- Docx2txtLoader,
9
- OutlookMessageLoader,
10
- PyPDFLoader,
11
- TextLoader,
12
- UnstructuredEPubLoader,
13
- UnstructuredExcelLoader,
14
- UnstructuredMarkdownLoader,
15
- UnstructuredPowerPointLoader,
16
- UnstructuredRSTLoader,
17
- UnstructuredXMLLoader,
18
- YoutubeLoader,
19
- )
20
- from langchain_core.documents import Document
21
- from open_webui.env import SRC_LOG_LEVELS
22
-
23
- log = logging.getLogger(__name__)
24
- log.setLevel(SRC_LOG_LEVELS["RAG"])
25
-
26
- known_source_ext = [
27
- "go",
28
- "py",
29
- "java",
30
- "sh",
31
- "bat",
32
- "ps1",
33
- "cmd",
34
- "js",
35
- "ts",
36
- "css",
37
- "cpp",
38
- "hpp",
39
- "h",
40
- "c",
41
- "cs",
42
- "sql",
43
- "log",
44
- "ini",
45
- "pl",
46
- "pm",
47
- "r",
48
- "dart",
49
- "dockerfile",
50
- "env",
51
- "php",
52
- "hs",
53
- "hsc",
54
- "lua",
55
- "nginxconf",
56
- "conf",
57
- "m",
58
- "mm",
59
- "plsql",
60
- "perl",
61
- "rb",
62
- "rs",
63
- "db2",
64
- "scala",
65
- "bash",
66
- "swift",
67
- "vue",
68
- "svelte",
69
- "msg",
70
- "ex",
71
- "exs",
72
- "erl",
73
- "tsx",
74
- "jsx",
75
- "hs",
76
- "lhs",
77
- ]
78
-
79
-
80
- class TikaLoader:
81
- def __init__(self, url, file_path, mime_type=None):
82
- self.url = url
83
- self.file_path = file_path
84
- self.mime_type = mime_type
85
-
86
- def load(self) -> list[Document]:
87
- with open(self.file_path, "rb") as f:
88
- data = f.read()
89
-
90
- if self.mime_type is not None:
91
- headers = {"Content-Type": self.mime_type}
92
- else:
93
- headers = {}
94
-
95
- endpoint = self.url
96
- if not endpoint.endswith("/"):
97
- endpoint += "/"
98
- endpoint += "tika/text"
99
-
100
- r = requests.put(endpoint, data=data, headers=headers)
101
-
102
- if r.ok:
103
- raw_metadata = r.json()
104
- text = raw_metadata.get("X-TIKA:content", "<No text content found>")
105
-
106
- if "Content-Type" in raw_metadata:
107
- headers["Content-Type"] = raw_metadata["Content-Type"]
108
-
109
- log.info("Tika extracted text: %s", text)
110
-
111
- return [Document(page_content=text, metadata=headers)]
112
- else:
113
- raise Exception(f"Error calling Tika: {r.reason}")
114
-
115
-
116
- class Loader:
117
- def __init__(self, engine: str = "", **kwargs):
118
- self.engine = engine
119
- self.kwargs = kwargs
120
-
121
- def load(
122
- self, filename: str, file_content_type: str, file_path: str
123
- ) -> list[Document]:
124
- loader = self._get_loader(filename, file_content_type, file_path)
125
- docs = loader.load()
126
-
127
- return [
128
- Document(
129
- page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
130
- )
131
- for doc in docs
132
- ]
133
-
134
- def _get_loader(self, filename: str, file_content_type: str, file_path: str):
135
- file_ext = filename.split(".")[-1].lower()
136
-
137
- if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
138
- if file_ext in known_source_ext or (
139
- file_content_type and file_content_type.find("text/") >= 0
140
- ):
141
- loader = TextLoader(file_path, autodetect_encoding=True)
142
- else:
143
- loader = TikaLoader(
144
- url=self.kwargs.get("TIKA_SERVER_URL"),
145
- file_path=file_path,
146
- mime_type=file_content_type,
147
- )
148
- else:
149
- if file_ext == "pdf":
150
- loader = PyPDFLoader(
151
- file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
152
- )
153
- elif file_ext == "csv":
154
- loader = CSVLoader(file_path)
155
- elif file_ext == "rst":
156
- loader = UnstructuredRSTLoader(file_path, mode="elements")
157
- elif file_ext == "xml":
158
- loader = UnstructuredXMLLoader(file_path)
159
- elif file_ext in ["htm", "html"]:
160
- loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
161
- elif file_ext == "md":
162
- loader = UnstructuredMarkdownLoader(file_path)
163
- elif file_content_type == "application/epub+zip":
164
- loader = UnstructuredEPubLoader(file_path)
165
- elif (
166
- file_content_type
167
- == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
168
- or file_ext == "docx"
169
- ):
170
- loader = Docx2txtLoader(file_path)
171
- elif file_content_type in [
172
- "application/vnd.ms-excel",
173
- "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
174
- ] or file_ext in ["xls", "xlsx"]:
175
- loader = UnstructuredExcelLoader(file_path)
176
- elif file_content_type in [
177
- "application/vnd.ms-powerpoint",
178
- "application/vnd.openxmlformats-officedocument.presentationml.presentation",
179
- ] or file_ext in ["ppt", "pptx"]:
180
- loader = UnstructuredPowerPointLoader(file_path)
181
- elif file_ext == "msg":
182
- loader = OutlookMessageLoader(file_path)
183
- elif file_ext in known_source_ext or (
184
- file_content_type and file_content_type.find("text/") >= 0
185
- ):
186
- loader = TextLoader(file_path, autodetect_encoding=True)
187
- else:
188
- loader = TextLoader(file_path, autodetect_encoding=True)
189
-
190
- return loader
 
1
+ import requests
2
+ import logging
3
+ import ftfy
4
+
5
+ from langchain_community.document_loaders import (
6
+ BSHTMLLoader,
7
+ CSVLoader,
8
+ Docx2txtLoader,
9
+ OutlookMessageLoader,
10
+ PyPDFLoader,
11
+ TextLoader,
12
+ UnstructuredEPubLoader,
13
+ UnstructuredExcelLoader,
14
+ UnstructuredMarkdownLoader,
15
+ UnstructuredPowerPointLoader,
16
+ UnstructuredRSTLoader,
17
+ UnstructuredXMLLoader,
18
+ YoutubeLoader,
19
+ )
20
+ from langchain_core.documents import Document
21
+ from open_webui.env import SRC_LOG_LEVELS
22
+
23
+ log = logging.getLogger(__name__)
24
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
25
+
26
+ known_source_ext = [
27
+ "go",
28
+ "py",
29
+ "java",
30
+ "sh",
31
+ "bat",
32
+ "ps1",
33
+ "cmd",
34
+ "js",
35
+ "ts",
36
+ "css",
37
+ "cpp",
38
+ "hpp",
39
+ "h",
40
+ "c",
41
+ "cs",
42
+ "sql",
43
+ "log",
44
+ "ini",
45
+ "pl",
46
+ "pm",
47
+ "r",
48
+ "dart",
49
+ "dockerfile",
50
+ "env",
51
+ "php",
52
+ "hs",
53
+ "hsc",
54
+ "lua",
55
+ "nginxconf",
56
+ "conf",
57
+ "m",
58
+ "mm",
59
+ "plsql",
60
+ "perl",
61
+ "rb",
62
+ "rs",
63
+ "db2",
64
+ "scala",
65
+ "bash",
66
+ "swift",
67
+ "vue",
68
+ "svelte",
69
+ "msg",
70
+ "ex",
71
+ "exs",
72
+ "erl",
73
+ "tsx",
74
+ "jsx",
75
+ "hs",
76
+ "lhs",
77
+ ]
78
+
79
+
80
+ class TikaLoader:
81
+ def __init__(self, url, file_path, mime_type=None):
82
+ self.url = url
83
+ self.file_path = file_path
84
+ self.mime_type = mime_type
85
+
86
+ def load(self) -> list[Document]:
87
+ with open(self.file_path, "rb") as f:
88
+ data = f.read()
89
+
90
+ if self.mime_type is not None:
91
+ headers = {"Content-Type": self.mime_type}
92
+ else:
93
+ headers = {}
94
+
95
+ endpoint = self.url
96
+ if not endpoint.endswith("/"):
97
+ endpoint += "/"
98
+ endpoint += "tika/text"
99
+
100
+ r = requests.put(endpoint, data=data, headers=headers)
101
+
102
+ if r.ok:
103
+ raw_metadata = r.json()
104
+ text = raw_metadata.get("X-TIKA:content", "<No text content found>")
105
+
106
+ if "Content-Type" in raw_metadata:
107
+ headers["Content-Type"] = raw_metadata["Content-Type"]
108
+
109
+ log.info("Tika extracted text: %s", text)
110
+
111
+ return [Document(page_content=text, metadata=headers)]
112
+ else:
113
+ raise Exception(f"Error calling Tika: {r.reason}")
114
+
115
+
116
+ class Loader:
117
+ def __init__(self, engine: str = "", **kwargs):
118
+ self.engine = engine
119
+ self.kwargs = kwargs
120
+
121
+ def load(
122
+ self, filename: str, file_content_type: str, file_path: str
123
+ ) -> list[Document]:
124
+ loader = self._get_loader(filename, file_content_type, file_path)
125
+ docs = loader.load()
126
+
127
+ return [
128
+ Document(
129
+ page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
130
+ )
131
+ for doc in docs
132
+ ]
133
+
134
+ def _get_loader(self, filename: str, file_content_type: str, file_path: str):
135
+ file_ext = filename.split(".")[-1].lower()
136
+
137
+ if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
138
+ if file_ext in known_source_ext or (
139
+ file_content_type and file_content_type.find("text/") >= 0
140
+ ):
141
+ loader = TextLoader(file_path, autodetect_encoding=True)
142
+ else:
143
+ loader = TikaLoader(
144
+ url=self.kwargs.get("TIKA_SERVER_URL"),
145
+ file_path=file_path,
146
+ mime_type=file_content_type,
147
+ )
148
+ else:
149
+ if file_ext == "pdf":
150
+ loader = PyPDFLoader(
151
+ file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
152
+ )
153
+ elif file_ext == "csv":
154
+ loader = CSVLoader(file_path)
155
+ elif file_ext == "rst":
156
+ loader = UnstructuredRSTLoader(file_path, mode="elements")
157
+ elif file_ext == "xml":
158
+ loader = UnstructuredXMLLoader(file_path)
159
+ elif file_ext in ["htm", "html"]:
160
+ loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
161
+ elif file_ext == "md":
162
+ loader = UnstructuredMarkdownLoader(file_path)
163
+ elif file_content_type == "application/epub+zip":
164
+ loader = UnstructuredEPubLoader(file_path)
165
+ elif (
166
+ file_content_type
167
+ == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
168
+ or file_ext == "docx"
169
+ ):
170
+ loader = Docx2txtLoader(file_path)
171
+ elif file_content_type in [
172
+ "application/vnd.ms-excel",
173
+ "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
174
+ ] or file_ext in ["xls", "xlsx"]:
175
+ loader = UnstructuredExcelLoader(file_path)
176
+ elif file_content_type in [
177
+ "application/vnd.ms-powerpoint",
178
+ "application/vnd.openxmlformats-officedocument.presentationml.presentation",
179
+ ] or file_ext in ["ppt", "pptx"]:
180
+ loader = UnstructuredPowerPointLoader(file_path)
181
+ elif file_ext == "msg":
182
+ loader = OutlookMessageLoader(file_path)
183
+ elif file_ext in known_source_ext or (
184
+ file_content_type and file_content_type.find("text/") >= 0
185
+ ):
186
+ loader = TextLoader(file_path, autodetect_encoding=True)
187
+ else:
188
+ loader = TextLoader(file_path, autodetect_encoding=True)
189
+
190
+ return loader
backend/open_webui/apps/retrieval/main.py CHANGED
@@ -1,1326 +1,1326 @@
1
- # TODO: Merge this with the webui_app and make it a single app
2
-
3
- import json
4
- import logging
5
- import mimetypes
6
- import os
7
- import shutil
8
-
9
- import uuid
10
- from datetime import datetime
11
- from pathlib import Path
12
- from typing import Iterator, Optional, Sequence, Union
13
-
14
- from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
15
- from fastapi.middleware.cors import CORSMiddleware
16
- from pydantic import BaseModel
17
-
18
-
19
- from open_webui.storage.provider import Storage
20
- from open_webui.apps.webui.models.knowledge import Knowledges
21
- from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
22
-
23
- # Document loaders
24
- from open_webui.apps.retrieval.loaders.main import Loader
25
-
26
- # Web search engines
27
- from open_webui.apps.retrieval.web.main import SearchResult
28
- from open_webui.apps.retrieval.web.utils import get_web_loader
29
- from open_webui.apps.retrieval.web.brave import search_brave
30
- from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo
31
- from open_webui.apps.retrieval.web.google_pse import search_google_pse
32
- from open_webui.apps.retrieval.web.jina_search import search_jina
33
- from open_webui.apps.retrieval.web.searchapi import search_searchapi
34
- from open_webui.apps.retrieval.web.searxng import search_searxng
35
- from open_webui.apps.retrieval.web.serper import search_serper
36
- from open_webui.apps.retrieval.web.serply import search_serply
37
- from open_webui.apps.retrieval.web.serpstack import search_serpstack
38
- from open_webui.apps.retrieval.web.tavily import search_tavily
39
-
40
-
41
- from open_webui.apps.retrieval.utils import (
42
- get_embedding_function,
43
- get_model_path,
44
- query_collection,
45
- query_collection_with_hybrid_search,
46
- query_doc,
47
- query_doc_with_hybrid_search,
48
- )
49
-
50
- from open_webui.apps.webui.models.files import Files
51
- from open_webui.config import (
52
- BRAVE_SEARCH_API_KEY,
53
- TIKTOKEN_ENCODING_NAME,
54
- RAG_TEXT_SPLITTER,
55
- CHUNK_OVERLAP,
56
- CHUNK_SIZE,
57
- CONTENT_EXTRACTION_ENGINE,
58
- CORS_ALLOW_ORIGIN,
59
- ENABLE_RAG_HYBRID_SEARCH,
60
- ENABLE_RAG_LOCAL_WEB_FETCH,
61
- ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
62
- ENABLE_RAG_WEB_SEARCH,
63
- ENV,
64
- GOOGLE_PSE_API_KEY,
65
- GOOGLE_PSE_ENGINE_ID,
66
- PDF_EXTRACT_IMAGES,
67
- RAG_EMBEDDING_ENGINE,
68
- RAG_EMBEDDING_MODEL,
69
- RAG_EMBEDDING_MODEL_AUTO_UPDATE,
70
- RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
71
- RAG_EMBEDDING_BATCH_SIZE,
72
- RAG_FILE_MAX_COUNT,
73
- RAG_FILE_MAX_SIZE,
74
- RAG_OPENAI_API_BASE_URL,
75
- RAG_OPENAI_API_KEY,
76
- RAG_RELEVANCE_THRESHOLD,
77
- RAG_RERANKING_MODEL,
78
- RAG_RERANKING_MODEL_AUTO_UPDATE,
79
- RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
80
- DEFAULT_RAG_TEMPLATE,
81
- RAG_TEMPLATE,
82
- RAG_TOP_K,
83
- RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
84
- RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
85
- RAG_WEB_SEARCH_ENGINE,
86
- RAG_WEB_SEARCH_RESULT_COUNT,
87
- SEARCHAPI_API_KEY,
88
- SEARCHAPI_ENGINE,
89
- SEARXNG_QUERY_URL,
90
- SERPER_API_KEY,
91
- SERPLY_API_KEY,
92
- SERPSTACK_API_KEY,
93
- SERPSTACK_HTTPS,
94
- TAVILY_API_KEY,
95
- TIKA_SERVER_URL,
96
- UPLOAD_DIR,
97
- YOUTUBE_LOADER_LANGUAGE,
98
- AppConfig,
99
- )
100
- from open_webui.constants import ERROR_MESSAGES
101
- from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER
102
- from open_webui.utils.misc import (
103
- calculate_sha256,
104
- calculate_sha256_string,
105
- extract_folders_after_data_docs,
106
- sanitize_filename,
107
- )
108
- from open_webui.utils.utils import get_admin_user, get_verified_user
109
-
110
- from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
111
- from langchain_community.document_loaders import (
112
- YoutubeLoader,
113
- )
114
- from langchain_core.documents import Document
115
-
116
-
117
- log = logging.getLogger(__name__)
118
- log.setLevel(SRC_LOG_LEVELS["RAG"])
119
-
120
- app = FastAPI()
121
-
122
- app.state.config = AppConfig()
123
-
124
- app.state.config.TOP_K = RAG_TOP_K
125
- app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
126
- app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
127
- app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
128
-
129
- app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
130
- app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
131
- ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
132
- )
133
-
134
- app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
135
- app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
136
-
137
- app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
138
- app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
139
-
140
- app.state.config.CHUNK_SIZE = CHUNK_SIZE
141
- app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
142
-
143
- app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
144
- app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
145
- app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
146
- app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
147
- app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
148
-
149
- app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
150
- app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
151
-
152
- app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
153
-
154
- app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
155
- app.state.YOUTUBE_LOADER_TRANSLATION = None
156
-
157
-
158
- app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
159
- app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
160
- app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
161
-
162
- app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
163
- app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
164
- app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
165
- app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
166
- app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
167
- app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
168
- app.state.config.SERPER_API_KEY = SERPER_API_KEY
169
- app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
170
- app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
171
- app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
172
- app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
173
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
174
- app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
175
-
176
-
177
- def update_embedding_model(
178
- embedding_model: str,
179
- auto_update: bool = False,
180
- ):
181
- if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
182
- from sentence_transformers import SentenceTransformer
183
-
184
- app.state.sentence_transformer_ef = SentenceTransformer(
185
- get_model_path(embedding_model, auto_update),
186
- device=DEVICE_TYPE,
187
- trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
188
- )
189
- else:
190
- app.state.sentence_transformer_ef = None
191
-
192
-
193
- def update_reranking_model(
194
- reranking_model: str,
195
- auto_update: bool = False,
196
- ):
197
- if reranking_model:
198
- if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
199
- try:
200
- from open_webui.apps.retrieval.models.colbert import ColBERT
201
-
202
- app.state.sentence_transformer_rf = ColBERT(
203
- get_model_path(reranking_model, auto_update),
204
- env="docker" if DOCKER else None,
205
- )
206
- except Exception as e:
207
- log.error(f"ColBERT: {e}")
208
- app.state.sentence_transformer_rf = None
209
- app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
210
- else:
211
- import sentence_transformers
212
-
213
- try:
214
- app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
215
- get_model_path(reranking_model, auto_update),
216
- device=DEVICE_TYPE,
217
- trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
218
- )
219
- except:
220
- log.error("CrossEncoder error")
221
- app.state.sentence_transformer_rf = None
222
- app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
223
- else:
224
- app.state.sentence_transformer_rf = None
225
-
226
-
227
- update_embedding_model(
228
- app.state.config.RAG_EMBEDDING_MODEL,
229
- RAG_EMBEDDING_MODEL_AUTO_UPDATE,
230
- )
231
-
232
- update_reranking_model(
233
- app.state.config.RAG_RERANKING_MODEL,
234
- RAG_RERANKING_MODEL_AUTO_UPDATE,
235
- )
236
-
237
-
238
- app.state.EMBEDDING_FUNCTION = get_embedding_function(
239
- app.state.config.RAG_EMBEDDING_ENGINE,
240
- app.state.config.RAG_EMBEDDING_MODEL,
241
- app.state.sentence_transformer_ef,
242
- app.state.config.OPENAI_API_KEY,
243
- app.state.config.OPENAI_API_BASE_URL,
244
- app.state.config.RAG_EMBEDDING_BATCH_SIZE,
245
- )
246
-
247
- app.add_middleware(
248
- CORSMiddleware,
249
- allow_origins=CORS_ALLOW_ORIGIN,
250
- allow_credentials=True,
251
- allow_methods=["*"],
252
- allow_headers=["*"],
253
- )
254
-
255
-
256
- class CollectionNameForm(BaseModel):
257
- collection_name: Optional[str] = None
258
-
259
-
260
- class ProcessUrlForm(CollectionNameForm):
261
- url: str
262
-
263
-
264
- class SearchForm(CollectionNameForm):
265
- query: str
266
-
267
-
268
- @app.get("/")
269
- async def get_status():
270
- return {
271
- "status": True,
272
- "chunk_size": app.state.config.CHUNK_SIZE,
273
- "chunk_overlap": app.state.config.CHUNK_OVERLAP,
274
- "template": app.state.config.RAG_TEMPLATE,
275
- "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
276
- "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
277
- "reranking_model": app.state.config.RAG_RERANKING_MODEL,
278
- "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
279
- }
280
-
281
-
282
- @app.get("/embedding")
283
- async def get_embedding_config(user=Depends(get_admin_user)):
284
- return {
285
- "status": True,
286
- "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
287
- "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
288
- "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
289
- "openai_config": {
290
- "url": app.state.config.OPENAI_API_BASE_URL,
291
- "key": app.state.config.OPENAI_API_KEY,
292
- },
293
- }
294
-
295
-
296
- @app.get("/reranking")
297
- async def get_reraanking_config(user=Depends(get_admin_user)):
298
- return {
299
- "status": True,
300
- "reranking_model": app.state.config.RAG_RERANKING_MODEL,
301
- }
302
-
303
-
304
- class OpenAIConfigForm(BaseModel):
305
- url: str
306
- key: str
307
-
308
-
309
- class EmbeddingModelUpdateForm(BaseModel):
310
- openai_config: Optional[OpenAIConfigForm] = None
311
- embedding_engine: str
312
- embedding_model: str
313
- embedding_batch_size: Optional[int] = 1
314
-
315
-
316
- @app.post("/embedding/update")
317
- async def update_embedding_config(
318
- form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
319
- ):
320
- log.info(
321
- f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
322
- )
323
- try:
324
- app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
325
- app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
326
-
327
- if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
328
- if form_data.openai_config is not None:
329
- app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
330
- app.state.config.OPENAI_API_KEY = form_data.openai_config.key
331
- app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
332
-
333
- update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
334
-
335
- app.state.EMBEDDING_FUNCTION = get_embedding_function(
336
- app.state.config.RAG_EMBEDDING_ENGINE,
337
- app.state.config.RAG_EMBEDDING_MODEL,
338
- app.state.sentence_transformer_ef,
339
- app.state.config.OPENAI_API_KEY,
340
- app.state.config.OPENAI_API_BASE_URL,
341
- app.state.config.RAG_EMBEDDING_BATCH_SIZE,
342
- )
343
-
344
- return {
345
- "status": True,
346
- "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
347
- "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
348
- "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
349
- "openai_config": {
350
- "url": app.state.config.OPENAI_API_BASE_URL,
351
- "key": app.state.config.OPENAI_API_KEY,
352
- },
353
- }
354
- except Exception as e:
355
- log.exception(f"Problem updating embedding model: {e}")
356
- raise HTTPException(
357
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
358
- detail=ERROR_MESSAGES.DEFAULT(e),
359
- )
360
-
361
-
362
- class RerankingModelUpdateForm(BaseModel):
363
- reranking_model: str
364
-
365
-
366
- @app.post("/reranking/update")
367
- async def update_reranking_config(
368
- form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
369
- ):
370
- log.info(
371
- f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
372
- )
373
- try:
374
- app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
375
-
376
- update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)
377
-
378
- return {
379
- "status": True,
380
- "reranking_model": app.state.config.RAG_RERANKING_MODEL,
381
- }
382
- except Exception as e:
383
- log.exception(f"Problem updating reranking model: {e}")
384
- raise HTTPException(
385
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
386
- detail=ERROR_MESSAGES.DEFAULT(e),
387
- )
388
-
389
-
390
- @app.get("/config")
391
- async def get_rag_config(user=Depends(get_admin_user)):
392
- return {
393
- "status": True,
394
- "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
395
- "content_extraction": {
396
- "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
397
- "tika_server_url": app.state.config.TIKA_SERVER_URL,
398
- },
399
- "chunk": {
400
- "text_splitter": app.state.config.TEXT_SPLITTER,
401
- "chunk_size": app.state.config.CHUNK_SIZE,
402
- "chunk_overlap": app.state.config.CHUNK_OVERLAP,
403
- },
404
- "file": {
405
- "max_size": app.state.config.FILE_MAX_SIZE,
406
- "max_count": app.state.config.FILE_MAX_COUNT,
407
- },
408
- "youtube": {
409
- "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
410
- "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
411
- },
412
- "web": {
413
- "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
414
- "search": {
415
- "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
416
- "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
417
- "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
418
- "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
419
- "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
420
- "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
421
- "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
422
- "serpstack_https": app.state.config.SERPSTACK_HTTPS,
423
- "serper_api_key": app.state.config.SERPER_API_KEY,
424
- "serply_api_key": app.state.config.SERPLY_API_KEY,
425
- "tavily_api_key": app.state.config.TAVILY_API_KEY,
426
- "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
427
- "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE,
428
- "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
429
- "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
430
- },
431
- },
432
- }
433
-
434
-
435
- class FileConfig(BaseModel):
436
- max_size: Optional[int] = None
437
- max_count: Optional[int] = None
438
-
439
-
440
- class ContentExtractionConfig(BaseModel):
441
- engine: str = ""
442
- tika_server_url: Optional[str] = None
443
-
444
-
445
- class ChunkParamUpdateForm(BaseModel):
446
- text_splitter: Optional[str] = None
447
- chunk_size: int
448
- chunk_overlap: int
449
-
450
-
451
- class YoutubeLoaderConfig(BaseModel):
452
- language: list[str]
453
- translation: Optional[str] = None
454
-
455
-
456
- class WebSearchConfig(BaseModel):
457
- enabled: bool
458
- engine: Optional[str] = None
459
- searxng_query_url: Optional[str] = None
460
- google_pse_api_key: Optional[str] = None
461
- google_pse_engine_id: Optional[str] = None
462
- brave_search_api_key: Optional[str] = None
463
- serpstack_api_key: Optional[str] = None
464
- serpstack_https: Optional[bool] = None
465
- serper_api_key: Optional[str] = None
466
- serply_api_key: Optional[str] = None
467
- tavily_api_key: Optional[str] = None
468
- searchapi_api_key: Optional[str] = None
469
- searchapi_engine: Optional[str] = None
470
- result_count: Optional[int] = None
471
- concurrent_requests: Optional[int] = None
472
-
473
-
474
- class WebConfig(BaseModel):
475
- search: WebSearchConfig
476
- web_loader_ssl_verification: Optional[bool] = None
477
-
478
-
479
- class ConfigUpdateForm(BaseModel):
480
- pdf_extract_images: Optional[bool] = None
481
- file: Optional[FileConfig] = None
482
- content_extraction: Optional[ContentExtractionConfig] = None
483
- chunk: Optional[ChunkParamUpdateForm] = None
484
- youtube: Optional[YoutubeLoaderConfig] = None
485
- web: Optional[WebConfig] = None
486
-
487
-
488
- @app.post("/config/update")
489
- async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
490
- app.state.config.PDF_EXTRACT_IMAGES = (
491
- form_data.pdf_extract_images
492
- if form_data.pdf_extract_images is not None
493
- else app.state.config.PDF_EXTRACT_IMAGES
494
- )
495
-
496
- if form_data.file is not None:
497
- app.state.config.FILE_MAX_SIZE = form_data.file.max_size
498
- app.state.config.FILE_MAX_COUNT = form_data.file.max_count
499
-
500
- if form_data.content_extraction is not None:
501
- log.info(f"Updating text settings: {form_data.content_extraction}")
502
- app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
503
- app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
504
-
505
- if form_data.chunk is not None:
506
- app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
507
- app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
508
- app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
509
-
510
- if form_data.youtube is not None:
511
- app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
512
- app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
513
-
514
- if form_data.web is not None:
515
- app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
516
- form_data.web.web_loader_ssl_verification
517
- )
518
-
519
- app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
520
- app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
521
- app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
522
- app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
523
- app.state.config.GOOGLE_PSE_ENGINE_ID = (
524
- form_data.web.search.google_pse_engine_id
525
- )
526
- app.state.config.BRAVE_SEARCH_API_KEY = (
527
- form_data.web.search.brave_search_api_key
528
- )
529
- app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
530
- app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
531
- app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
532
- app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
533
- app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
534
- app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
535
- app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
536
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
537
- app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
538
- form_data.web.search.concurrent_requests
539
- )
540
-
541
- return {
542
- "status": True,
543
- "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
544
- "file": {
545
- "max_size": app.state.config.FILE_MAX_SIZE,
546
- "max_count": app.state.config.FILE_MAX_COUNT,
547
- },
548
- "content_extraction": {
549
- "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
550
- "tika_server_url": app.state.config.TIKA_SERVER_URL,
551
- },
552
- "chunk": {
553
- "text_splitter": app.state.config.TEXT_SPLITTER,
554
- "chunk_size": app.state.config.CHUNK_SIZE,
555
- "chunk_overlap": app.state.config.CHUNK_OVERLAP,
556
- },
557
- "youtube": {
558
- "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
559
- "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
560
- },
561
- "web": {
562
- "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
563
- "search": {
564
- "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
565
- "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
566
- "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
567
- "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
568
- "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
569
- "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
570
- "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
571
- "serpstack_https": app.state.config.SERPSTACK_HTTPS,
572
- "serper_api_key": app.state.config.SERPER_API_KEY,
573
- "serply_api_key": app.state.config.SERPLY_API_KEY,
574
- "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY,
575
- "searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
576
- "tavily_api_key": app.state.config.TAVILY_API_KEY,
577
- "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
578
- "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
579
- },
580
- },
581
- }
582
-
583
-
584
- @app.get("/template")
585
- async def get_rag_template(user=Depends(get_verified_user)):
586
- return {
587
- "status": True,
588
- "template": app.state.config.RAG_TEMPLATE,
589
- }
590
-
591
-
592
- @app.get("/query/settings")
593
- async def get_query_settings(user=Depends(get_admin_user)):
594
- return {
595
- "status": True,
596
- "template": app.state.config.RAG_TEMPLATE,
597
- "k": app.state.config.TOP_K,
598
- "r": app.state.config.RELEVANCE_THRESHOLD,
599
- "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
600
- }
601
-
602
-
603
- class QuerySettingsForm(BaseModel):
604
- k: Optional[int] = None
605
- r: Optional[float] = None
606
- template: Optional[str] = None
607
- hybrid: Optional[bool] = None
608
-
609
-
610
- @app.post("/query/settings/update")
611
- async def update_query_settings(
612
- form_data: QuerySettingsForm, user=Depends(get_admin_user)
613
- ):
614
- app.state.config.RAG_TEMPLATE = form_data.template
615
- app.state.config.TOP_K = form_data.k if form_data.k else 4
616
- app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
617
-
618
- app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
619
- form_data.hybrid if form_data.hybrid else False
620
- )
621
-
622
- return {
623
- "status": True,
624
- "template": app.state.config.RAG_TEMPLATE,
625
- "k": app.state.config.TOP_K,
626
- "r": app.state.config.RELEVANCE_THRESHOLD,
627
- "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
628
- }
629
-
630
-
631
- ####################################
632
- #
633
- # Document process and retrieval
634
- #
635
- ####################################
636
-
637
-
638
- def save_docs_to_vector_db(
639
- docs,
640
- collection_name,
641
- metadata: Optional[dict] = None,
642
- overwrite: bool = False,
643
- split: bool = True,
644
- add: bool = False,
645
- ) -> bool:
646
- log.info(f"save_docs_to_vector_db {docs} {collection_name}")
647
-
648
- # Check if entries with the same hash (metadata.hash) already exist
649
- if metadata and "hash" in metadata:
650
- result = VECTOR_DB_CLIENT.query(
651
- collection_name=collection_name,
652
- filter={"hash": metadata["hash"]},
653
- )
654
-
655
- if result is not None:
656
- existing_doc_ids = result.ids[0]
657
- if existing_doc_ids:
658
- log.info(f"Document with hash {metadata['hash']} already exists")
659
- raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
660
-
661
- if split:
662
- if app.state.config.TEXT_SPLITTER in ["", "character"]:
663
- text_splitter = RecursiveCharacterTextSplitter(
664
- chunk_size=app.state.config.CHUNK_SIZE,
665
- chunk_overlap=app.state.config.CHUNK_OVERLAP,
666
- add_start_index=True,
667
- )
668
- elif app.state.config.TEXT_SPLITTER == "token":
669
- text_splitter = TokenTextSplitter(
670
- encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME,
671
- chunk_size=app.state.config.CHUNK_SIZE,
672
- chunk_overlap=app.state.config.CHUNK_OVERLAP,
673
- add_start_index=True,
674
- )
675
- else:
676
- raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
677
-
678
- docs = text_splitter.split_documents(docs)
679
-
680
- if len(docs) == 0:
681
- raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
682
-
683
- texts = [doc.page_content for doc in docs]
684
- metadatas = [
685
- {
686
- **doc.metadata,
687
- **(metadata if metadata else {}),
688
- "embedding_config": json.dumps(
689
- {
690
- "engine": app.state.config.RAG_EMBEDDING_ENGINE,
691
- "model": app.state.config.RAG_EMBEDDING_MODEL,
692
- }
693
- ),
694
- }
695
- for doc in docs
696
- ]
697
-
698
- # ChromaDB does not like datetime formats
699
- # for meta-data so convert them to string.
700
- for metadata in metadatas:
701
- for key, value in metadata.items():
702
- if isinstance(value, datetime):
703
- metadata[key] = str(value)
704
-
705
- try:
706
- if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
707
- log.info(f"collection {collection_name} already exists")
708
-
709
- if overwrite:
710
- VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
711
- log.info(f"deleting existing collection {collection_name}")
712
- elif add is False:
713
- log.info(
714
- f"collection {collection_name} already exists, overwrite is False and add is False"
715
- )
716
- return True
717
-
718
- log.info(f"adding to collection {collection_name}")
719
- embedding_function = get_embedding_function(
720
- app.state.config.RAG_EMBEDDING_ENGINE,
721
- app.state.config.RAG_EMBEDDING_MODEL,
722
- app.state.sentence_transformer_ef,
723
- app.state.config.OPENAI_API_KEY,
724
- app.state.config.OPENAI_API_BASE_URL,
725
- app.state.config.RAG_EMBEDDING_BATCH_SIZE,
726
- )
727
-
728
- embeddings = embedding_function(
729
- list(map(lambda x: x.replace("\n", " "), texts))
730
- )
731
-
732
- items = [
733
- {
734
- "id": str(uuid.uuid4()),
735
- "text": text,
736
- "vector": embeddings[idx],
737
- "metadata": metadatas[idx],
738
- }
739
- for idx, text in enumerate(texts)
740
- ]
741
-
742
- VECTOR_DB_CLIENT.insert(
743
- collection_name=collection_name,
744
- items=items,
745
- )
746
-
747
- return True
748
- except Exception as e:
749
- log.exception(e)
750
- return False
751
-
752
-
753
- class ProcessFileForm(BaseModel):
754
- file_id: str
755
- content: Optional[str] = None
756
- collection_name: Optional[str] = None
757
-
758
-
759
- @app.post("/process/file")
760
- def process_file(
761
- form_data: ProcessFileForm,
762
- user=Depends(get_verified_user),
763
- ):
764
- try:
765
- file = Files.get_file_by_id(form_data.file_id)
766
-
767
- collection_name = form_data.collection_name
768
-
769
- if collection_name is None:
770
- collection_name = f"file-{file.id}"
771
-
772
- if form_data.content:
773
- # Update the content in the file
774
- # Usage: /files/{file_id}/data/content/update
775
-
776
- VECTOR_DB_CLIENT.delete(
777
- collection_name=f"file-{file.id}",
778
- filter={"file_id": file.id},
779
- )
780
-
781
- docs = [
782
- Document(
783
- page_content=form_data.content,
784
- metadata={
785
- "name": file.meta.get("name", file.filename),
786
- "created_by": file.user_id,
787
- "file_id": file.id,
788
- **file.meta,
789
- },
790
- )
791
- ]
792
-
793
- text_content = form_data.content
794
- elif form_data.collection_name:
795
- # Check if the file has already been processed and save the content
796
- # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
797
-
798
- result = VECTOR_DB_CLIENT.query(
799
- collection_name=f"file-{file.id}", filter={"file_id": file.id}
800
- )
801
-
802
- if result is not None and len(result.ids[0]) > 0:
803
- docs = [
804
- Document(
805
- page_content=result.documents[0][idx],
806
- metadata=result.metadatas[0][idx],
807
- )
808
- for idx, id in enumerate(result.ids[0])
809
- ]
810
- else:
811
- docs = [
812
- Document(
813
- page_content=file.data.get("content", ""),
814
- metadata={
815
- "name": file.meta.get("name", file.filename),
816
- "created_by": file.user_id,
817
- "file_id": file.id,
818
- **file.meta,
819
- },
820
- )
821
- ]
822
-
823
- text_content = file.data.get("content", "")
824
- else:
825
- # Process the file and save the content
826
- # Usage: /files/
827
- file_path = file.path
828
- if file_path:
829
- file_path = Storage.get_file(file_path)
830
- loader = Loader(
831
- engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
832
- TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
833
- PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
834
- )
835
- docs = loader.load(
836
- file.filename, file.meta.get("content_type"), file_path
837
- )
838
- else:
839
- docs = [
840
- Document(
841
- page_content=file.data.get("content", ""),
842
- metadata={
843
- "name": file.filename,
844
- "created_by": file.user_id,
845
- "file_id": file.id,
846
- **file.meta,
847
- },
848
- )
849
- ]
850
- text_content = " ".join([doc.page_content for doc in docs])
851
-
852
- log.debug(f"text_content: {text_content}")
853
- Files.update_file_data_by_id(
854
- file.id,
855
- {"content": text_content},
856
- )
857
-
858
- hash = calculate_sha256_string(text_content)
859
- Files.update_file_hash_by_id(file.id, hash)
860
-
861
- try:
862
- result = save_docs_to_vector_db(
863
- docs=docs,
864
- collection_name=collection_name,
865
- metadata={
866
- "file_id": file.id,
867
- "name": file.meta.get("name", file.filename),
868
- "hash": hash,
869
- },
870
- add=(True if form_data.collection_name else False),
871
- )
872
-
873
- if result:
874
- Files.update_file_metadata_by_id(
875
- file.id,
876
- {
877
- "collection_name": collection_name,
878
- },
879
- )
880
-
881
- return {
882
- "status": True,
883
- "collection_name": collection_name,
884
- "filename": file.meta.get("name", file.filename),
885
- "content": text_content,
886
- }
887
- except Exception as e:
888
- raise e
889
- except Exception as e:
890
- log.exception(e)
891
- if "No pandoc was found" in str(e):
892
- raise HTTPException(
893
- status_code=status.HTTP_400_BAD_REQUEST,
894
- detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
895
- )
896
- else:
897
- raise HTTPException(
898
- status_code=status.HTTP_400_BAD_REQUEST,
899
- detail=str(e),
900
- )
901
-
902
-
903
- class ProcessTextForm(BaseModel):
904
- name: str
905
- content: str
906
- collection_name: Optional[str] = None
907
-
908
-
909
- @app.post("/process/text")
910
- def process_text(
911
- form_data: ProcessTextForm,
912
- user=Depends(get_verified_user),
913
- ):
914
- collection_name = form_data.collection_name
915
- if collection_name is None:
916
- collection_name = calculate_sha256_string(form_data.content)
917
-
918
- docs = [
919
- Document(
920
- page_content=form_data.content,
921
- metadata={"name": form_data.name, "created_by": user.id},
922
- )
923
- ]
924
- text_content = form_data.content
925
- log.debug(f"text_content: {text_content}")
926
-
927
- result = save_docs_to_vector_db(docs, collection_name)
928
-
929
- if result:
930
- return {
931
- "status": True,
932
- "collection_name": collection_name,
933
- "content": text_content,
934
- }
935
- else:
936
- raise HTTPException(
937
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
938
- detail=ERROR_MESSAGES.DEFAULT(),
939
- )
940
-
941
-
942
- @app.post("/process/youtube")
943
- def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
944
- try:
945
- collection_name = form_data.collection_name
946
- if not collection_name:
947
- collection_name = calculate_sha256_string(form_data.url)[:63]
948
-
949
- loader = YoutubeLoader.from_youtube_url(
950
- form_data.url,
951
- add_video_info=True,
952
- language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
953
- translation=app.state.YOUTUBE_LOADER_TRANSLATION,
954
- )
955
- docs = loader.load()
956
- content = " ".join([doc.page_content for doc in docs])
957
- log.debug(f"text_content: {content}")
958
- save_docs_to_vector_db(docs, collection_name, overwrite=True)
959
-
960
- return {
961
- "status": True,
962
- "collection_name": collection_name,
963
- "filename": form_data.url,
964
- "file": {
965
- "data": {
966
- "content": content,
967
- },
968
- "meta": {
969
- "name": form_data.url,
970
- },
971
- },
972
- }
973
- except Exception as e:
974
- log.exception(e)
975
- raise HTTPException(
976
- status_code=status.HTTP_400_BAD_REQUEST,
977
- detail=ERROR_MESSAGES.DEFAULT(e),
978
- )
979
-
980
-
981
- @app.post("/process/web")
982
- def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
983
- try:
984
- collection_name = form_data.collection_name
985
- if not collection_name:
986
- collection_name = calculate_sha256_string(form_data.url)[:63]
987
-
988
- loader = get_web_loader(
989
- form_data.url,
990
- verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
991
- requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
992
- )
993
- docs = loader.load()
994
- content = " ".join([doc.page_content for doc in docs])
995
- log.debug(f"text_content: {content}")
996
- save_docs_to_vector_db(docs, collection_name, overwrite=True)
997
-
998
- return {
999
- "status": True,
1000
- "collection_name": collection_name,
1001
- "filename": form_data.url,
1002
- "file": {
1003
- "data": {
1004
- "content": content,
1005
- },
1006
- "meta": {
1007
- "name": form_data.url,
1008
- },
1009
- },
1010
- }
1011
- except Exception as e:
1012
- log.exception(e)
1013
- raise HTTPException(
1014
- status_code=status.HTTP_400_BAD_REQUEST,
1015
- detail=ERROR_MESSAGES.DEFAULT(e),
1016
- )
1017
-
1018
-
1019
- def search_web(engine: str, query: str) -> list[SearchResult]:
1020
- """Search the web using a search engine and return the results as a list of SearchResult objects.
1021
- Will look for a search engine API key in environment variables in the following order:
1022
- - SEARXNG_QUERY_URL
1023
- - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
1024
- - BRAVE_SEARCH_API_KEY
1025
- - SERPSTACK_API_KEY
1026
- - SERPER_API_KEY
1027
- - SERPLY_API_KEY
1028
- - TAVILY_API_KEY
1029
- - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
1030
- Args:
1031
- query (str): The query to search for
1032
- """
1033
-
1034
- # TODO: add playwright to search the web
1035
- if engine == "searxng":
1036
- if app.state.config.SEARXNG_QUERY_URL:
1037
- return search_searxng(
1038
- app.state.config.SEARXNG_QUERY_URL,
1039
- query,
1040
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1041
- app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1042
- )
1043
- else:
1044
- raise Exception("No SEARXNG_QUERY_URL found in environment variables")
1045
- elif engine == "google_pse":
1046
- if (
1047
- app.state.config.GOOGLE_PSE_API_KEY
1048
- and app.state.config.GOOGLE_PSE_ENGINE_ID
1049
- ):
1050
- return search_google_pse(
1051
- app.state.config.GOOGLE_PSE_API_KEY,
1052
- app.state.config.GOOGLE_PSE_ENGINE_ID,
1053
- query,
1054
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1055
- app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1056
- )
1057
- else:
1058
- raise Exception(
1059
- "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
1060
- )
1061
- elif engine == "brave":
1062
- if app.state.config.BRAVE_SEARCH_API_KEY:
1063
- return search_brave(
1064
- app.state.config.BRAVE_SEARCH_API_KEY,
1065
- query,
1066
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1067
- app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1068
- )
1069
- else:
1070
- raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
1071
- elif engine == "serpstack":
1072
- if app.state.config.SERPSTACK_API_KEY:
1073
- return search_serpstack(
1074
- app.state.config.SERPSTACK_API_KEY,
1075
- query,
1076
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1077
- app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1078
- https_enabled=app.state.config.SERPSTACK_HTTPS,
1079
- )
1080
- else:
1081
- raise Exception("No SERPSTACK_API_KEY found in environment variables")
1082
- elif engine == "serper":
1083
- if app.state.config.SERPER_API_KEY:
1084
- return search_serper(
1085
- app.state.config.SERPER_API_KEY,
1086
- query,
1087
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1088
- app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1089
- )
1090
- else:
1091
- raise Exception("No SERPER_API_KEY found in environment variables")
1092
- elif engine == "serply":
1093
- if app.state.config.SERPLY_API_KEY:
1094
- return search_serply(
1095
- app.state.config.SERPLY_API_KEY,
1096
- query,
1097
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1098
- app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1099
- )
1100
- else:
1101
- raise Exception("No SERPLY_API_KEY found in environment variables")
1102
- elif engine == "duckduckgo":
1103
- return search_duckduckgo(
1104
- query,
1105
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1106
- app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1107
- )
1108
- elif engine == "tavily":
1109
- if app.state.config.TAVILY_API_KEY:
1110
- return search_tavily(
1111
- app.state.config.TAVILY_API_KEY,
1112
- query,
1113
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1114
- )
1115
- else:
1116
- raise Exception("No TAVILY_API_KEY found in environment variables")
1117
- elif engine == "searchapi":
1118
- if app.state.config.SEARCHAPI_API_KEY:
1119
- return search_searchapi(
1120
- app.state.config.SEARCHAPI_API_KEY,
1121
- app.state.config.SEARCHAPI_ENGINE,
1122
- query,
1123
- app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1124
- app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1125
- )
1126
- else:
1127
- raise Exception("No SEARCHAPI_API_KEY found in environment variables")
1128
- elif engine == "jina":
1129
- return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
1130
- else:
1131
- raise Exception("No search engine API key found in environment variables")
1132
-
1133
-
1134
- @app.post("/process/web/search")
1135
- def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
1136
- try:
1137
- logging.info(
1138
- f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
1139
- )
1140
- web_results = search_web(
1141
- app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
1142
- )
1143
- except Exception as e:
1144
- log.exception(e)
1145
-
1146
- print(e)
1147
- raise HTTPException(
1148
- status_code=status.HTTP_400_BAD_REQUEST,
1149
- detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
1150
- )
1151
-
1152
- try:
1153
- collection_name = form_data.collection_name
1154
- if collection_name == "":
1155
- collection_name = calculate_sha256_string(form_data.query)[:63]
1156
-
1157
- urls = [result.link for result in web_results]
1158
-
1159
- loader = get_web_loader(urls)
1160
- docs = loader.load()
1161
-
1162
- save_docs_to_vector_db(docs, collection_name, overwrite=True)
1163
-
1164
- return {
1165
- "status": True,
1166
- "collection_name": collection_name,
1167
- "filenames": urls,
1168
- }
1169
- except Exception as e:
1170
- log.exception(e)
1171
- raise HTTPException(
1172
- status_code=status.HTTP_400_BAD_REQUEST,
1173
- detail=ERROR_MESSAGES.DEFAULT(e),
1174
- )
1175
-
1176
-
1177
- class QueryDocForm(BaseModel):
1178
- collection_name: str
1179
- query: str
1180
- k: Optional[int] = None
1181
- r: Optional[float] = None
1182
- hybrid: Optional[bool] = None
1183
-
1184
-
1185
- @app.post("/query/doc")
1186
- def query_doc_handler(
1187
- form_data: QueryDocForm,
1188
- user=Depends(get_verified_user),
1189
- ):
1190
- try:
1191
- if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
1192
- return query_doc_with_hybrid_search(
1193
- collection_name=form_data.collection_name,
1194
- query=form_data.query,
1195
- embedding_function=app.state.EMBEDDING_FUNCTION,
1196
- k=form_data.k if form_data.k else app.state.config.TOP_K,
1197
- reranking_function=app.state.sentence_transformer_rf,
1198
- r=(
1199
- form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
1200
- ),
1201
- )
1202
- else:
1203
- return query_doc(
1204
- collection_name=form_data.collection_name,
1205
- query=form_data.query,
1206
- embedding_function=app.state.EMBEDDING_FUNCTION,
1207
- k=form_data.k if form_data.k else app.state.config.TOP_K,
1208
- )
1209
- except Exception as e:
1210
- log.exception(e)
1211
- raise HTTPException(
1212
- status_code=status.HTTP_400_BAD_REQUEST,
1213
- detail=ERROR_MESSAGES.DEFAULT(e),
1214
- )
1215
-
1216
-
1217
- class QueryCollectionsForm(BaseModel):
1218
- collection_names: list[str]
1219
- query: str
1220
- k: Optional[int] = None
1221
- r: Optional[float] = None
1222
- hybrid: Optional[bool] = None
1223
-
1224
-
1225
- @app.post("/query/collection")
1226
- def query_collection_handler(
1227
- form_data: QueryCollectionsForm,
1228
- user=Depends(get_verified_user),
1229
- ):
1230
- try:
1231
- if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
1232
- return query_collection_with_hybrid_search(
1233
- collection_names=form_data.collection_names,
1234
- query=form_data.query,
1235
- embedding_function=app.state.EMBEDDING_FUNCTION,
1236
- k=form_data.k if form_data.k else app.state.config.TOP_K,
1237
- reranking_function=app.state.sentence_transformer_rf,
1238
- r=(
1239
- form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
1240
- ),
1241
- )
1242
- else:
1243
- return query_collection(
1244
- collection_names=form_data.collection_names,
1245
- query=form_data.query,
1246
- embedding_function=app.state.EMBEDDING_FUNCTION,
1247
- k=form_data.k if form_data.k else app.state.config.TOP_K,
1248
- )
1249
-
1250
- except Exception as e:
1251
- log.exception(e)
1252
- raise HTTPException(
1253
- status_code=status.HTTP_400_BAD_REQUEST,
1254
- detail=ERROR_MESSAGES.DEFAULT(e),
1255
- )
1256
-
1257
-
1258
- ####################################
1259
- #
1260
- # Vector DB operations
1261
- #
1262
- ####################################
1263
-
1264
-
1265
- class DeleteForm(BaseModel):
1266
- collection_name: str
1267
- file_id: str
1268
-
1269
-
1270
- @app.post("/delete")
1271
- def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
1272
- try:
1273
- if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
1274
- file = Files.get_file_by_id(form_data.file_id)
1275
- hash = file.hash
1276
-
1277
- VECTOR_DB_CLIENT.delete(
1278
- collection_name=form_data.collection_name,
1279
- metadata={"hash": hash},
1280
- )
1281
- return {"status": True}
1282
- else:
1283
- return {"status": False}
1284
- except Exception as e:
1285
- log.exception(e)
1286
- return {"status": False}
1287
-
1288
-
1289
- @app.post("/reset/db")
1290
- def reset_vector_db(user=Depends(get_admin_user)):
1291
- VECTOR_DB_CLIENT.reset()
1292
- Knowledges.delete_all_knowledge()
1293
-
1294
-
1295
- @app.post("/reset/uploads")
1296
- def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
1297
- folder = f"{UPLOAD_DIR}"
1298
- try:
1299
- # Check if the directory exists
1300
- if os.path.exists(folder):
1301
- # Iterate over all the files and directories in the specified directory
1302
- for filename in os.listdir(folder):
1303
- file_path = os.path.join(folder, filename)
1304
- try:
1305
- if os.path.isfile(file_path) or os.path.islink(file_path):
1306
- os.unlink(file_path) # Remove the file or link
1307
- elif os.path.isdir(file_path):
1308
- shutil.rmtree(file_path) # Remove the directory
1309
- except Exception as e:
1310
- print(f"Failed to delete {file_path}. Reason: {e}")
1311
- else:
1312
- print(f"The directory {folder} does not exist")
1313
- except Exception as e:
1314
- print(f"Failed to process the directory {folder}. Reason: {e}")
1315
- return True
1316
-
1317
-
1318
- if ENV == "dev":
1319
-
1320
- @app.get("/ef")
1321
- async def get_embeddings():
1322
- return {"result": app.state.EMBEDDING_FUNCTION("hello world")}
1323
-
1324
- @app.get("/ef/{text}")
1325
- async def get_embeddings_text(text: str):
1326
- return {"result": app.state.EMBEDDING_FUNCTION(text)}
 
1
+ # TODO: Merge this with the webui_app and make it a single app
2
+
3
+ import json
4
+ import logging
5
+ import mimetypes
6
+ import os
7
+ import shutil
8
+
9
+ import uuid
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Iterator, Optional, Sequence, Union
13
+
14
+ from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from pydantic import BaseModel
17
+
18
+
19
+ from open_webui.storage.provider import Storage
20
+ from open_webui.apps.webui.models.knowledge import Knowledges
21
+ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
22
+
23
+ # Document loaders
24
+ from open_webui.apps.retrieval.loaders.main import Loader
25
+
26
+ # Web search engines
27
+ from open_webui.apps.retrieval.web.main import SearchResult
28
+ from open_webui.apps.retrieval.web.utils import get_web_loader
29
+ from open_webui.apps.retrieval.web.brave import search_brave
30
+ from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo
31
+ from open_webui.apps.retrieval.web.google_pse import search_google_pse
32
+ from open_webui.apps.retrieval.web.jina_search import search_jina
33
+ from open_webui.apps.retrieval.web.searchapi import search_searchapi
34
+ from open_webui.apps.retrieval.web.searxng import search_searxng
35
+ from open_webui.apps.retrieval.web.serper import search_serper
36
+ from open_webui.apps.retrieval.web.serply import search_serply
37
+ from open_webui.apps.retrieval.web.serpstack import search_serpstack
38
+ from open_webui.apps.retrieval.web.tavily import search_tavily
39
+
40
+
41
+ from open_webui.apps.retrieval.utils import (
42
+ get_embedding_function,
43
+ get_model_path,
44
+ query_collection,
45
+ query_collection_with_hybrid_search,
46
+ query_doc,
47
+ query_doc_with_hybrid_search,
48
+ )
49
+
50
+ from open_webui.apps.webui.models.files import Files
51
+ from open_webui.config import (
52
+ BRAVE_SEARCH_API_KEY,
53
+ TIKTOKEN_ENCODING_NAME,
54
+ RAG_TEXT_SPLITTER,
55
+ CHUNK_OVERLAP,
56
+ CHUNK_SIZE,
57
+ CONTENT_EXTRACTION_ENGINE,
58
+ CORS_ALLOW_ORIGIN,
59
+ ENABLE_RAG_HYBRID_SEARCH,
60
+ ENABLE_RAG_LOCAL_WEB_FETCH,
61
+ ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
62
+ ENABLE_RAG_WEB_SEARCH,
63
+ ENV,
64
+ GOOGLE_PSE_API_KEY,
65
+ GOOGLE_PSE_ENGINE_ID,
66
+ PDF_EXTRACT_IMAGES,
67
+ RAG_EMBEDDING_ENGINE,
68
+ RAG_EMBEDDING_MODEL,
69
+ RAG_EMBEDDING_MODEL_AUTO_UPDATE,
70
+ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
71
+ RAG_EMBEDDING_BATCH_SIZE,
72
+ RAG_FILE_MAX_COUNT,
73
+ RAG_FILE_MAX_SIZE,
74
+ RAG_OPENAI_API_BASE_URL,
75
+ RAG_OPENAI_API_KEY,
76
+ RAG_RELEVANCE_THRESHOLD,
77
+ RAG_RERANKING_MODEL,
78
+ RAG_RERANKING_MODEL_AUTO_UPDATE,
79
+ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
80
+ DEFAULT_RAG_TEMPLATE,
81
+ RAG_TEMPLATE,
82
+ RAG_TOP_K,
83
+ RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
84
+ RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
85
+ RAG_WEB_SEARCH_ENGINE,
86
+ RAG_WEB_SEARCH_RESULT_COUNT,
87
+ SEARCHAPI_API_KEY,
88
+ SEARCHAPI_ENGINE,
89
+ SEARXNG_QUERY_URL,
90
+ SERPER_API_KEY,
91
+ SERPLY_API_KEY,
92
+ SERPSTACK_API_KEY,
93
+ SERPSTACK_HTTPS,
94
+ TAVILY_API_KEY,
95
+ TIKA_SERVER_URL,
96
+ UPLOAD_DIR,
97
+ YOUTUBE_LOADER_LANGUAGE,
98
+ AppConfig,
99
+ )
100
+ from open_webui.constants import ERROR_MESSAGES
101
+ from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER
102
+ from open_webui.utils.misc import (
103
+ calculate_sha256,
104
+ calculate_sha256_string,
105
+ extract_folders_after_data_docs,
106
+ sanitize_filename,
107
+ )
108
+ from open_webui.utils.utils import get_admin_user, get_verified_user
109
+
110
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
111
+ from langchain_community.document_loaders import (
112
+ YoutubeLoader,
113
+ )
114
+ from langchain_core.documents import Document
115
+
116
+
117
+ log = logging.getLogger(__name__)
118
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
119
+
120
+ app = FastAPI()
121
+
122
+ app.state.config = AppConfig()
123
+
124
+ app.state.config.TOP_K = RAG_TOP_K
125
+ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
126
+ app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
127
+ app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
128
+
129
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
130
+ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
131
+ ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
132
+ )
133
+
134
+ app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
135
+ app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
136
+
137
+ app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
138
+ app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
139
+
140
+ app.state.config.CHUNK_SIZE = CHUNK_SIZE
141
+ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
142
+
143
+ app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
144
+ app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
145
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
146
+ app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
147
+ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
148
+
149
+ app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
150
+ app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
151
+
152
+ app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
153
+
154
+ app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
155
+ app.state.YOUTUBE_LOADER_TRANSLATION = None
156
+
157
+
158
+ app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
159
+ app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
160
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
161
+
162
+ app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
163
+ app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
164
+ app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
165
+ app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
166
+ app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
167
+ app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
168
+ app.state.config.SERPER_API_KEY = SERPER_API_KEY
169
+ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
170
+ app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
171
+ app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
172
+ app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
173
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
174
+ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
175
+
176
+
177
+ def update_embedding_model(
178
+ embedding_model: str,
179
+ auto_update: bool = False,
180
+ ):
181
+ if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
182
+ from sentence_transformers import SentenceTransformer
183
+
184
+ app.state.sentence_transformer_ef = SentenceTransformer(
185
+ get_model_path(embedding_model, auto_update),
186
+ device=DEVICE_TYPE,
187
+ trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
188
+ )
189
+ else:
190
+ app.state.sentence_transformer_ef = None
191
+
192
+
193
+ def update_reranking_model(
194
+ reranking_model: str,
195
+ auto_update: bool = False,
196
+ ):
197
+ if reranking_model:
198
+ if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
199
+ try:
200
+ from open_webui.apps.retrieval.models.colbert import ColBERT
201
+
202
+ app.state.sentence_transformer_rf = ColBERT(
203
+ get_model_path(reranking_model, auto_update),
204
+ env="docker" if DOCKER else None,
205
+ )
206
+ except Exception as e:
207
+ log.error(f"ColBERT: {e}")
208
+ app.state.sentence_transformer_rf = None
209
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
210
+ else:
211
+ import sentence_transformers
212
+
213
+ try:
214
+ app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
215
+ get_model_path(reranking_model, auto_update),
216
+ device=DEVICE_TYPE,
217
+ trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
218
+ )
219
+ except:
220
+ log.error("CrossEncoder error")
221
+ app.state.sentence_transformer_rf = None
222
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
223
+ else:
224
+ app.state.sentence_transformer_rf = None
225
+
226
+
227
+ update_embedding_model(
228
+ app.state.config.RAG_EMBEDDING_MODEL,
229
+ RAG_EMBEDDING_MODEL_AUTO_UPDATE,
230
+ )
231
+
232
+ update_reranking_model(
233
+ app.state.config.RAG_RERANKING_MODEL,
234
+ RAG_RERANKING_MODEL_AUTO_UPDATE,
235
+ )
236
+
237
+
238
+ app.state.EMBEDDING_FUNCTION = get_embedding_function(
239
+ app.state.config.RAG_EMBEDDING_ENGINE,
240
+ app.state.config.RAG_EMBEDDING_MODEL,
241
+ app.state.sentence_transformer_ef,
242
+ app.state.config.OPENAI_API_KEY,
243
+ app.state.config.OPENAI_API_BASE_URL,
244
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
245
+ )
246
+
247
+ app.add_middleware(
248
+ CORSMiddleware,
249
+ allow_origins=CORS_ALLOW_ORIGIN,
250
+ allow_credentials=True,
251
+ allow_methods=["*"],
252
+ allow_headers=["*"],
253
+ )
254
+
255
+
256
+ class CollectionNameForm(BaseModel):
257
+ collection_name: Optional[str] = None
258
+
259
+
260
+ class ProcessUrlForm(CollectionNameForm):
261
+ url: str
262
+
263
+
264
+ class SearchForm(CollectionNameForm):
265
+ query: str
266
+
267
+
268
+ @app.get("/")
269
+ async def get_status():
270
+ return {
271
+ "status": True,
272
+ "chunk_size": app.state.config.CHUNK_SIZE,
273
+ "chunk_overlap": app.state.config.CHUNK_OVERLAP,
274
+ "template": app.state.config.RAG_TEMPLATE,
275
+ "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
276
+ "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
277
+ "reranking_model": app.state.config.RAG_RERANKING_MODEL,
278
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
279
+ }
280
+
281
+
282
+ @app.get("/embedding")
283
+ async def get_embedding_config(user=Depends(get_admin_user)):
284
+ return {
285
+ "status": True,
286
+ "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
287
+ "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
288
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
289
+ "openai_config": {
290
+ "url": app.state.config.OPENAI_API_BASE_URL,
291
+ "key": app.state.config.OPENAI_API_KEY,
292
+ },
293
+ }
294
+
295
+
296
+ @app.get("/reranking")
297
+ async def get_reraanking_config(user=Depends(get_admin_user)):
298
+ return {
299
+ "status": True,
300
+ "reranking_model": app.state.config.RAG_RERANKING_MODEL,
301
+ }
302
+
303
+
304
+ class OpenAIConfigForm(BaseModel):
305
+ url: str
306
+ key: str
307
+
308
+
309
+ class EmbeddingModelUpdateForm(BaseModel):
310
+ openai_config: Optional[OpenAIConfigForm] = None
311
+ embedding_engine: str
312
+ embedding_model: str
313
+ embedding_batch_size: Optional[int] = 1
314
+
315
+
316
+ @app.post("/embedding/update")
317
+ async def update_embedding_config(
318
+ form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
319
+ ):
320
+ log.info(
321
+ f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
322
+ )
323
+ try:
324
+ app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
325
+ app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
326
+
327
+ if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
328
+ if form_data.openai_config is not None:
329
+ app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
330
+ app.state.config.OPENAI_API_KEY = form_data.openai_config.key
331
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
332
+
333
+ update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
334
+
335
+ app.state.EMBEDDING_FUNCTION = get_embedding_function(
336
+ app.state.config.RAG_EMBEDDING_ENGINE,
337
+ app.state.config.RAG_EMBEDDING_MODEL,
338
+ app.state.sentence_transformer_ef,
339
+ app.state.config.OPENAI_API_KEY,
340
+ app.state.config.OPENAI_API_BASE_URL,
341
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
342
+ )
343
+
344
+ return {
345
+ "status": True,
346
+ "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
347
+ "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
348
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
349
+ "openai_config": {
350
+ "url": app.state.config.OPENAI_API_BASE_URL,
351
+ "key": app.state.config.OPENAI_API_KEY,
352
+ },
353
+ }
354
+ except Exception as e:
355
+ log.exception(f"Problem updating embedding model: {e}")
356
+ raise HTTPException(
357
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
358
+ detail=ERROR_MESSAGES.DEFAULT(e),
359
+ )
360
+
361
+
362
+ class RerankingModelUpdateForm(BaseModel):
363
+ reranking_model: str
364
+
365
+
366
+ @app.post("/reranking/update")
367
+ async def update_reranking_config(
368
+ form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
369
+ ):
370
+ log.info(
371
+ f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
372
+ )
373
+ try:
374
+ app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
375
+
376
+ update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)
377
+
378
+ return {
379
+ "status": True,
380
+ "reranking_model": app.state.config.RAG_RERANKING_MODEL,
381
+ }
382
+ except Exception as e:
383
+ log.exception(f"Problem updating reranking model: {e}")
384
+ raise HTTPException(
385
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
386
+ detail=ERROR_MESSAGES.DEFAULT(e),
387
+ )
388
+
389
+
390
+ @app.get("/config")
391
+ async def get_rag_config(user=Depends(get_admin_user)):
392
+ return {
393
+ "status": True,
394
+ "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
395
+ "content_extraction": {
396
+ "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
397
+ "tika_server_url": app.state.config.TIKA_SERVER_URL,
398
+ },
399
+ "chunk": {
400
+ "text_splitter": app.state.config.TEXT_SPLITTER,
401
+ "chunk_size": app.state.config.CHUNK_SIZE,
402
+ "chunk_overlap": app.state.config.CHUNK_OVERLAP,
403
+ },
404
+ "file": {
405
+ "max_size": app.state.config.FILE_MAX_SIZE,
406
+ "max_count": app.state.config.FILE_MAX_COUNT,
407
+ },
408
+ "youtube": {
409
+ "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
410
+ "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
411
+ },
412
+ "web": {
413
+ "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
414
+ "search": {
415
+ "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
416
+ "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
417
+ "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
418
+ "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
419
+ "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
420
+ "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
421
+ "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
422
+ "serpstack_https": app.state.config.SERPSTACK_HTTPS,
423
+ "serper_api_key": app.state.config.SERPER_API_KEY,
424
+ "serply_api_key": app.state.config.SERPLY_API_KEY,
425
+ "tavily_api_key": app.state.config.TAVILY_API_KEY,
426
+ "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
427
+ "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE,
428
+ "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
429
+ "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
430
+ },
431
+ },
432
+ }
433
+
434
+
435
+ class FileConfig(BaseModel):
436
+ max_size: Optional[int] = None
437
+ max_count: Optional[int] = None
438
+
439
+
440
+ class ContentExtractionConfig(BaseModel):
441
+ engine: str = ""
442
+ tika_server_url: Optional[str] = None
443
+
444
+
445
+ class ChunkParamUpdateForm(BaseModel):
446
+ text_splitter: Optional[str] = None
447
+ chunk_size: int
448
+ chunk_overlap: int
449
+
450
+
451
+ class YoutubeLoaderConfig(BaseModel):
452
+ language: list[str]
453
+ translation: Optional[str] = None
454
+
455
+
456
+ class WebSearchConfig(BaseModel):
457
+ enabled: bool
458
+ engine: Optional[str] = None
459
+ searxng_query_url: Optional[str] = None
460
+ google_pse_api_key: Optional[str] = None
461
+ google_pse_engine_id: Optional[str] = None
462
+ brave_search_api_key: Optional[str] = None
463
+ serpstack_api_key: Optional[str] = None
464
+ serpstack_https: Optional[bool] = None
465
+ serper_api_key: Optional[str] = None
466
+ serply_api_key: Optional[str] = None
467
+ tavily_api_key: Optional[str] = None
468
+ searchapi_api_key: Optional[str] = None
469
+ searchapi_engine: Optional[str] = None
470
+ result_count: Optional[int] = None
471
+ concurrent_requests: Optional[int] = None
472
+
473
+
474
+ class WebConfig(BaseModel):
475
+ search: WebSearchConfig
476
+ web_loader_ssl_verification: Optional[bool] = None
477
+
478
+
479
+ class ConfigUpdateForm(BaseModel):
480
+ pdf_extract_images: Optional[bool] = None
481
+ file: Optional[FileConfig] = None
482
+ content_extraction: Optional[ContentExtractionConfig] = None
483
+ chunk: Optional[ChunkParamUpdateForm] = None
484
+ youtube: Optional[YoutubeLoaderConfig] = None
485
+ web: Optional[WebConfig] = None
486
+
487
+
488
+ @app.post("/config/update")
489
+ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
490
+ app.state.config.PDF_EXTRACT_IMAGES = (
491
+ form_data.pdf_extract_images
492
+ if form_data.pdf_extract_images is not None
493
+ else app.state.config.PDF_EXTRACT_IMAGES
494
+ )
495
+
496
+ if form_data.file is not None:
497
+ app.state.config.FILE_MAX_SIZE = form_data.file.max_size
498
+ app.state.config.FILE_MAX_COUNT = form_data.file.max_count
499
+
500
+ if form_data.content_extraction is not None:
501
+ log.info(f"Updating text settings: {form_data.content_extraction}")
502
+ app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
503
+ app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
504
+
505
+ if form_data.chunk is not None:
506
+ app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
507
+ app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
508
+ app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
509
+
510
+ if form_data.youtube is not None:
511
+ app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
512
+ app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
513
+
514
+ if form_data.web is not None:
515
+ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
516
+ form_data.web.web_loader_ssl_verification
517
+ )
518
+
519
+ app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
520
+ app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
521
+ app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
522
+ app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
523
+ app.state.config.GOOGLE_PSE_ENGINE_ID = (
524
+ form_data.web.search.google_pse_engine_id
525
+ )
526
+ app.state.config.BRAVE_SEARCH_API_KEY = (
527
+ form_data.web.search.brave_search_api_key
528
+ )
529
+ app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
530
+ app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
531
+ app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
532
+ app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
533
+ app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
534
+ app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
535
+ app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
536
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
537
+ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
538
+ form_data.web.search.concurrent_requests
539
+ )
540
+
541
+ return {
542
+ "status": True,
543
+ "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
544
+ "file": {
545
+ "max_size": app.state.config.FILE_MAX_SIZE,
546
+ "max_count": app.state.config.FILE_MAX_COUNT,
547
+ },
548
+ "content_extraction": {
549
+ "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
550
+ "tika_server_url": app.state.config.TIKA_SERVER_URL,
551
+ },
552
+ "chunk": {
553
+ "text_splitter": app.state.config.TEXT_SPLITTER,
554
+ "chunk_size": app.state.config.CHUNK_SIZE,
555
+ "chunk_overlap": app.state.config.CHUNK_OVERLAP,
556
+ },
557
+ "youtube": {
558
+ "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
559
+ "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
560
+ },
561
+ "web": {
562
+ "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
563
+ "search": {
564
+ "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
565
+ "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
566
+ "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
567
+ "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
568
+ "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
569
+ "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
570
+ "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
571
+ "serpstack_https": app.state.config.SERPSTACK_HTTPS,
572
+ "serper_api_key": app.state.config.SERPER_API_KEY,
573
+ "serply_api_key": app.state.config.SERPLY_API_KEY,
574
+ "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY,
575
+ "searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
576
+ "tavily_api_key": app.state.config.TAVILY_API_KEY,
577
+ "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
578
+ "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
579
+ },
580
+ },
581
+ }
582
+
583
+
584
+ @app.get("/template")
585
+ async def get_rag_template(user=Depends(get_verified_user)):
586
+ return {
587
+ "status": True,
588
+ "template": app.state.config.RAG_TEMPLATE,
589
+ }
590
+
591
+
592
+ @app.get("/query/settings")
593
+ async def get_query_settings(user=Depends(get_admin_user)):
594
+ return {
595
+ "status": True,
596
+ "template": app.state.config.RAG_TEMPLATE,
597
+ "k": app.state.config.TOP_K,
598
+ "r": app.state.config.RELEVANCE_THRESHOLD,
599
+ "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
600
+ }
601
+
602
+
603
+ class QuerySettingsForm(BaseModel):
604
+ k: Optional[int] = None
605
+ r: Optional[float] = None
606
+ template: Optional[str] = None
607
+ hybrid: Optional[bool] = None
608
+
609
+
610
+ @app.post("/query/settings/update")
611
+ async def update_query_settings(
612
+ form_data: QuerySettingsForm, user=Depends(get_admin_user)
613
+ ):
614
+ app.state.config.RAG_TEMPLATE = form_data.template
615
+ app.state.config.TOP_K = form_data.k if form_data.k else 4
616
+ app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
617
+
618
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
619
+ form_data.hybrid if form_data.hybrid else False
620
+ )
621
+
622
+ return {
623
+ "status": True,
624
+ "template": app.state.config.RAG_TEMPLATE,
625
+ "k": app.state.config.TOP_K,
626
+ "r": app.state.config.RELEVANCE_THRESHOLD,
627
+ "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
628
+ }
629
+
630
+
631
+ ####################################
632
+ #
633
+ # Document process and retrieval
634
+ #
635
+ ####################################
636
+
637
+
638
+ def save_docs_to_vector_db(
639
+ docs,
640
+ collection_name,
641
+ metadata: Optional[dict] = None,
642
+ overwrite: bool = False,
643
+ split: bool = True,
644
+ add: bool = False,
645
+ ) -> bool:
646
+ log.info(f"save_docs_to_vector_db {docs} {collection_name}")
647
+
648
+ # Check if entries with the same hash (metadata.hash) already exist
649
+ if metadata and "hash" in metadata:
650
+ result = VECTOR_DB_CLIENT.query(
651
+ collection_name=collection_name,
652
+ filter={"hash": metadata["hash"]},
653
+ )
654
+
655
+ if result is not None:
656
+ existing_doc_ids = result.ids[0]
657
+ if existing_doc_ids:
658
+ log.info(f"Document with hash {metadata['hash']} already exists")
659
+ raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
660
+
661
+ if split:
662
+ if app.state.config.TEXT_SPLITTER in ["", "character"]:
663
+ text_splitter = RecursiveCharacterTextSplitter(
664
+ chunk_size=app.state.config.CHUNK_SIZE,
665
+ chunk_overlap=app.state.config.CHUNK_OVERLAP,
666
+ add_start_index=True,
667
+ )
668
+ elif app.state.config.TEXT_SPLITTER == "token":
669
+ text_splitter = TokenTextSplitter(
670
+ encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME,
671
+ chunk_size=app.state.config.CHUNK_SIZE,
672
+ chunk_overlap=app.state.config.CHUNK_OVERLAP,
673
+ add_start_index=True,
674
+ )
675
+ else:
676
+ raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
677
+
678
+ docs = text_splitter.split_documents(docs)
679
+
680
+ if len(docs) == 0:
681
+ raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
682
+
683
+ texts = [doc.page_content for doc in docs]
684
+ metadatas = [
685
+ {
686
+ **doc.metadata,
687
+ **(metadata if metadata else {}),
688
+ "embedding_config": json.dumps(
689
+ {
690
+ "engine": app.state.config.RAG_EMBEDDING_ENGINE,
691
+ "model": app.state.config.RAG_EMBEDDING_MODEL,
692
+ }
693
+ ),
694
+ }
695
+ for doc in docs
696
+ ]
697
+
698
+ # ChromaDB does not like datetime formats
699
+ # for meta-data so convert them to string.
700
+ for metadata in metadatas:
701
+ for key, value in metadata.items():
702
+ if isinstance(value, datetime):
703
+ metadata[key] = str(value)
704
+
705
+ try:
706
+ if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
707
+ log.info(f"collection {collection_name} already exists")
708
+
709
+ if overwrite:
710
+ VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
711
+ log.info(f"deleting existing collection {collection_name}")
712
+ elif add is False:
713
+ log.info(
714
+ f"collection {collection_name} already exists, overwrite is False and add is False"
715
+ )
716
+ return True
717
+
718
+ log.info(f"adding to collection {collection_name}")
719
+ embedding_function = get_embedding_function(
720
+ app.state.config.RAG_EMBEDDING_ENGINE,
721
+ app.state.config.RAG_EMBEDDING_MODEL,
722
+ app.state.sentence_transformer_ef,
723
+ app.state.config.OPENAI_API_KEY,
724
+ app.state.config.OPENAI_API_BASE_URL,
725
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
726
+ )
727
+
728
+ embeddings = embedding_function(
729
+ list(map(lambda x: x.replace("\n", " "), texts))
730
+ )
731
+
732
+ items = [
733
+ {
734
+ "id": str(uuid.uuid4()),
735
+ "text": text,
736
+ "vector": embeddings[idx],
737
+ "metadata": metadatas[idx],
738
+ }
739
+ for idx, text in enumerate(texts)
740
+ ]
741
+
742
+ VECTOR_DB_CLIENT.insert(
743
+ collection_name=collection_name,
744
+ items=items,
745
+ )
746
+
747
+ return True
748
+ except Exception as e:
749
+ log.exception(e)
750
+ return False
751
+
752
+
753
+ class ProcessFileForm(BaseModel):
754
+ file_id: str
755
+ content: Optional[str] = None
756
+ collection_name: Optional[str] = None
757
+
758
+
759
+ @app.post("/process/file")
760
+ def process_file(
761
+ form_data: ProcessFileForm,
762
+ user=Depends(get_verified_user),
763
+ ):
764
+ try:
765
+ file = Files.get_file_by_id(form_data.file_id)
766
+
767
+ collection_name = form_data.collection_name
768
+
769
+ if collection_name is None:
770
+ collection_name = f"file-{file.id}"
771
+
772
+ if form_data.content:
773
+ # Update the content in the file
774
+ # Usage: /files/{file_id}/data/content/update
775
+
776
+ VECTOR_DB_CLIENT.delete(
777
+ collection_name=f"file-{file.id}",
778
+ filter={"file_id": file.id},
779
+ )
780
+
781
+ docs = [
782
+ Document(
783
+ page_content=form_data.content,
784
+ metadata={
785
+ "name": file.meta.get("name", file.filename),
786
+ "created_by": file.user_id,
787
+ "file_id": file.id,
788
+ **file.meta,
789
+ },
790
+ )
791
+ ]
792
+
793
+ text_content = form_data.content
794
+ elif form_data.collection_name:
795
+ # Check if the file has already been processed and save the content
796
+ # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
797
+
798
+ result = VECTOR_DB_CLIENT.query(
799
+ collection_name=f"file-{file.id}", filter={"file_id": file.id}
800
+ )
801
+
802
+ if result is not None and len(result.ids[0]) > 0:
803
+ docs = [
804
+ Document(
805
+ page_content=result.documents[0][idx],
806
+ metadata=result.metadatas[0][idx],
807
+ )
808
+ for idx, id in enumerate(result.ids[0])
809
+ ]
810
+ else:
811
+ docs = [
812
+ Document(
813
+ page_content=file.data.get("content", ""),
814
+ metadata={
815
+ "name": file.meta.get("name", file.filename),
816
+ "created_by": file.user_id,
817
+ "file_id": file.id,
818
+ **file.meta,
819
+ },
820
+ )
821
+ ]
822
+
823
+ text_content = file.data.get("content", "")
824
+ else:
825
+ # Process the file and save the content
826
+ # Usage: /files/
827
+ file_path = file.path
828
+ if file_path:
829
+ file_path = Storage.get_file(file_path)
830
+ loader = Loader(
831
+ engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
832
+ TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
833
+ PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
834
+ )
835
+ docs = loader.load(
836
+ file.filename, file.meta.get("content_type"), file_path
837
+ )
838
+ else:
839
+ docs = [
840
+ Document(
841
+ page_content=file.data.get("content", ""),
842
+ metadata={
843
+ "name": file.filename,
844
+ "created_by": file.user_id,
845
+ "file_id": file.id,
846
+ **file.meta,
847
+ },
848
+ )
849
+ ]
850
+ text_content = " ".join([doc.page_content for doc in docs])
851
+
852
+ log.debug(f"text_content: {text_content}")
853
+ Files.update_file_data_by_id(
854
+ file.id,
855
+ {"content": text_content},
856
+ )
857
+
858
+ hash = calculate_sha256_string(text_content)
859
+ Files.update_file_hash_by_id(file.id, hash)
860
+
861
+ try:
862
+ result = save_docs_to_vector_db(
863
+ docs=docs,
864
+ collection_name=collection_name,
865
+ metadata={
866
+ "file_id": file.id,
867
+ "name": file.meta.get("name", file.filename),
868
+ "hash": hash,
869
+ },
870
+ add=(True if form_data.collection_name else False),
871
+ )
872
+
873
+ if result:
874
+ Files.update_file_metadata_by_id(
875
+ file.id,
876
+ {
877
+ "collection_name": collection_name,
878
+ },
879
+ )
880
+
881
+ return {
882
+ "status": True,
883
+ "collection_name": collection_name,
884
+ "filename": file.meta.get("name", file.filename),
885
+ "content": text_content,
886
+ }
887
+ except Exception as e:
888
+ raise e
889
+ except Exception as e:
890
+ log.exception(e)
891
+ if "No pandoc was found" in str(e):
892
+ raise HTTPException(
893
+ status_code=status.HTTP_400_BAD_REQUEST,
894
+ detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
895
+ )
896
+ else:
897
+ raise HTTPException(
898
+ status_code=status.HTTP_400_BAD_REQUEST,
899
+ detail=str(e),
900
+ )
901
+
902
+
903
+ class ProcessTextForm(BaseModel):
904
+ name: str
905
+ content: str
906
+ collection_name: Optional[str] = None
907
+
908
+
909
+ @app.post("/process/text")
910
+ def process_text(
911
+ form_data: ProcessTextForm,
912
+ user=Depends(get_verified_user),
913
+ ):
914
+ collection_name = form_data.collection_name
915
+ if collection_name is None:
916
+ collection_name = calculate_sha256_string(form_data.content)
917
+
918
+ docs = [
919
+ Document(
920
+ page_content=form_data.content,
921
+ metadata={"name": form_data.name, "created_by": user.id},
922
+ )
923
+ ]
924
+ text_content = form_data.content
925
+ log.debug(f"text_content: {text_content}")
926
+
927
+ result = save_docs_to_vector_db(docs, collection_name)
928
+
929
+ if result:
930
+ return {
931
+ "status": True,
932
+ "collection_name": collection_name,
933
+ "content": text_content,
934
+ }
935
+ else:
936
+ raise HTTPException(
937
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
938
+ detail=ERROR_MESSAGES.DEFAULT(),
939
+ )
940
+
941
+
942
+ @app.post("/process/youtube")
943
+ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
944
+ try:
945
+ collection_name = form_data.collection_name
946
+ if not collection_name:
947
+ collection_name = calculate_sha256_string(form_data.url)[:63]
948
+
949
+ loader = YoutubeLoader.from_youtube_url(
950
+ form_data.url,
951
+ add_video_info=True,
952
+ language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
953
+ translation=app.state.YOUTUBE_LOADER_TRANSLATION,
954
+ )
955
+ docs = loader.load()
956
+ content = " ".join([doc.page_content for doc in docs])
957
+ log.debug(f"text_content: {content}")
958
+ save_docs_to_vector_db(docs, collection_name, overwrite=True)
959
+
960
+ return {
961
+ "status": True,
962
+ "collection_name": collection_name,
963
+ "filename": form_data.url,
964
+ "file": {
965
+ "data": {
966
+ "content": content,
967
+ },
968
+ "meta": {
969
+ "name": form_data.url,
970
+ },
971
+ },
972
+ }
973
+ except Exception as e:
974
+ log.exception(e)
975
+ raise HTTPException(
976
+ status_code=status.HTTP_400_BAD_REQUEST,
977
+ detail=ERROR_MESSAGES.DEFAULT(e),
978
+ )
979
+
980
+
981
+ @app.post("/process/web")
982
+ def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
983
+ try:
984
+ collection_name = form_data.collection_name
985
+ if not collection_name:
986
+ collection_name = calculate_sha256_string(form_data.url)[:63]
987
+
988
+ loader = get_web_loader(
989
+ form_data.url,
990
+ verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
991
+ requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
992
+ )
993
+ docs = loader.load()
994
+ content = " ".join([doc.page_content for doc in docs])
995
+ log.debug(f"text_content: {content}")
996
+ save_docs_to_vector_db(docs, collection_name, overwrite=True)
997
+
998
+ return {
999
+ "status": True,
1000
+ "collection_name": collection_name,
1001
+ "filename": form_data.url,
1002
+ "file": {
1003
+ "data": {
1004
+ "content": content,
1005
+ },
1006
+ "meta": {
1007
+ "name": form_data.url,
1008
+ },
1009
+ },
1010
+ }
1011
+ except Exception as e:
1012
+ log.exception(e)
1013
+ raise HTTPException(
1014
+ status_code=status.HTTP_400_BAD_REQUEST,
1015
+ detail=ERROR_MESSAGES.DEFAULT(e),
1016
+ )
1017
+
1018
+
1019
+ def search_web(engine: str, query: str) -> list[SearchResult]:
1020
+ """Search the web using a search engine and return the results as a list of SearchResult objects.
1021
+ Will look for a search engine API key in environment variables in the following order:
1022
+ - SEARXNG_QUERY_URL
1023
+ - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
1024
+ - BRAVE_SEARCH_API_KEY
1025
+ - SERPSTACK_API_KEY
1026
+ - SERPER_API_KEY
1027
+ - SERPLY_API_KEY
1028
+ - TAVILY_API_KEY
1029
+ - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
1030
+ Args:
1031
+ query (str): The query to search for
1032
+ """
1033
+
1034
+ # TODO: add playwright to search the web
1035
+ if engine == "searxng":
1036
+ if app.state.config.SEARXNG_QUERY_URL:
1037
+ return search_searxng(
1038
+ app.state.config.SEARXNG_QUERY_URL,
1039
+ query,
1040
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1041
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1042
+ )
1043
+ else:
1044
+ raise Exception("No SEARXNG_QUERY_URL found in environment variables")
1045
+ elif engine == "google_pse":
1046
+ if (
1047
+ app.state.config.GOOGLE_PSE_API_KEY
1048
+ and app.state.config.GOOGLE_PSE_ENGINE_ID
1049
+ ):
1050
+ return search_google_pse(
1051
+ app.state.config.GOOGLE_PSE_API_KEY,
1052
+ app.state.config.GOOGLE_PSE_ENGINE_ID,
1053
+ query,
1054
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1055
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1056
+ )
1057
+ else:
1058
+ raise Exception(
1059
+ "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
1060
+ )
1061
+ elif engine == "brave":
1062
+ if app.state.config.BRAVE_SEARCH_API_KEY:
1063
+ return search_brave(
1064
+ app.state.config.BRAVE_SEARCH_API_KEY,
1065
+ query,
1066
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1067
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1068
+ )
1069
+ else:
1070
+ raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
1071
+ elif engine == "serpstack":
1072
+ if app.state.config.SERPSTACK_API_KEY:
1073
+ return search_serpstack(
1074
+ app.state.config.SERPSTACK_API_KEY,
1075
+ query,
1076
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1077
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1078
+ https_enabled=app.state.config.SERPSTACK_HTTPS,
1079
+ )
1080
+ else:
1081
+ raise Exception("No SERPSTACK_API_KEY found in environment variables")
1082
+ elif engine == "serper":
1083
+ if app.state.config.SERPER_API_KEY:
1084
+ return search_serper(
1085
+ app.state.config.SERPER_API_KEY,
1086
+ query,
1087
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1088
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1089
+ )
1090
+ else:
1091
+ raise Exception("No SERPER_API_KEY found in environment variables")
1092
+ elif engine == "serply":
1093
+ if app.state.config.SERPLY_API_KEY:
1094
+ return search_serply(
1095
+ app.state.config.SERPLY_API_KEY,
1096
+ query,
1097
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1098
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1099
+ )
1100
+ else:
1101
+ raise Exception("No SERPLY_API_KEY found in environment variables")
1102
+ elif engine == "duckduckgo":
1103
+ return search_duckduckgo(
1104
+ query,
1105
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1106
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1107
+ )
1108
+ elif engine == "tavily":
1109
+ if app.state.config.TAVILY_API_KEY:
1110
+ return search_tavily(
1111
+ app.state.config.TAVILY_API_KEY,
1112
+ query,
1113
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1114
+ )
1115
+ else:
1116
+ raise Exception("No TAVILY_API_KEY found in environment variables")
1117
+ elif engine == "searchapi":
1118
+ if app.state.config.SEARCHAPI_API_KEY:
1119
+ return search_searchapi(
1120
+ app.state.config.SEARCHAPI_API_KEY,
1121
+ app.state.config.SEARCHAPI_ENGINE,
1122
+ query,
1123
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1124
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1125
+ )
1126
+ else:
1127
+ raise Exception("No SEARCHAPI_API_KEY found in environment variables")
1128
+ elif engine == "jina":
1129
+ return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
1130
+ else:
1131
+ raise Exception("No search engine API key found in environment variables")
1132
+
1133
+
1134
+ @app.post("/process/web/search")
1135
+ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
1136
+ try:
1137
+ logging.info(
1138
+ f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
1139
+ )
1140
+ web_results = search_web(
1141
+ app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
1142
+ )
1143
+ except Exception as e:
1144
+ log.exception(e)
1145
+
1146
+ print(e)
1147
+ raise HTTPException(
1148
+ status_code=status.HTTP_400_BAD_REQUEST,
1149
+ detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
1150
+ )
1151
+
1152
+ try:
1153
+ collection_name = form_data.collection_name
1154
+ if collection_name == "":
1155
+ collection_name = calculate_sha256_string(form_data.query)[:63]
1156
+
1157
+ urls = [result.link for result in web_results]
1158
+
1159
+ loader = get_web_loader(urls)
1160
+ docs = loader.load()
1161
+
1162
+ save_docs_to_vector_db(docs, collection_name, overwrite=True)
1163
+
1164
+ return {
1165
+ "status": True,
1166
+ "collection_name": collection_name,
1167
+ "filenames": urls,
1168
+ }
1169
+ except Exception as e:
1170
+ log.exception(e)
1171
+ raise HTTPException(
1172
+ status_code=status.HTTP_400_BAD_REQUEST,
1173
+ detail=ERROR_MESSAGES.DEFAULT(e),
1174
+ )
1175
+
1176
+
1177
+ class QueryDocForm(BaseModel):
1178
+ collection_name: str
1179
+ query: str
1180
+ k: Optional[int] = None
1181
+ r: Optional[float] = None
1182
+ hybrid: Optional[bool] = None
1183
+
1184
+
1185
+ @app.post("/query/doc")
1186
+ def query_doc_handler(
1187
+ form_data: QueryDocForm,
1188
+ user=Depends(get_verified_user),
1189
+ ):
1190
+ try:
1191
+ if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
1192
+ return query_doc_with_hybrid_search(
1193
+ collection_name=form_data.collection_name,
1194
+ query=form_data.query,
1195
+ embedding_function=app.state.EMBEDDING_FUNCTION,
1196
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
1197
+ reranking_function=app.state.sentence_transformer_rf,
1198
+ r=(
1199
+ form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
1200
+ ),
1201
+ )
1202
+ else:
1203
+ return query_doc(
1204
+ collection_name=form_data.collection_name,
1205
+ query=form_data.query,
1206
+ embedding_function=app.state.EMBEDDING_FUNCTION,
1207
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
1208
+ )
1209
+ except Exception as e:
1210
+ log.exception(e)
1211
+ raise HTTPException(
1212
+ status_code=status.HTTP_400_BAD_REQUEST,
1213
+ detail=ERROR_MESSAGES.DEFAULT(e),
1214
+ )
1215
+
1216
+
1217
+ class QueryCollectionsForm(BaseModel):
1218
+ collection_names: list[str]
1219
+ query: str
1220
+ k: Optional[int] = None
1221
+ r: Optional[float] = None
1222
+ hybrid: Optional[bool] = None
1223
+
1224
+
1225
+ @app.post("/query/collection")
1226
+ def query_collection_handler(
1227
+ form_data: QueryCollectionsForm,
1228
+ user=Depends(get_verified_user),
1229
+ ):
1230
+ try:
1231
+ if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
1232
+ return query_collection_with_hybrid_search(
1233
+ collection_names=form_data.collection_names,
1234
+ query=form_data.query,
1235
+ embedding_function=app.state.EMBEDDING_FUNCTION,
1236
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
1237
+ reranking_function=app.state.sentence_transformer_rf,
1238
+ r=(
1239
+ form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
1240
+ ),
1241
+ )
1242
+ else:
1243
+ return query_collection(
1244
+ collection_names=form_data.collection_names,
1245
+ query=form_data.query,
1246
+ embedding_function=app.state.EMBEDDING_FUNCTION,
1247
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
1248
+ )
1249
+
1250
+ except Exception as e:
1251
+ log.exception(e)
1252
+ raise HTTPException(
1253
+ status_code=status.HTTP_400_BAD_REQUEST,
1254
+ detail=ERROR_MESSAGES.DEFAULT(e),
1255
+ )
1256
+
1257
+
1258
+ ####################################
1259
+ #
1260
+ # Vector DB operations
1261
+ #
1262
+ ####################################
1263
+
1264
+
1265
+ class DeleteForm(BaseModel):
1266
+ collection_name: str
1267
+ file_id: str
1268
+
1269
+
1270
+ @app.post("/delete")
1271
+ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
1272
+ try:
1273
+ if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
1274
+ file = Files.get_file_by_id(form_data.file_id)
1275
+ hash = file.hash
1276
+
1277
+ VECTOR_DB_CLIENT.delete(
1278
+ collection_name=form_data.collection_name,
1279
+ metadata={"hash": hash},
1280
+ )
1281
+ return {"status": True}
1282
+ else:
1283
+ return {"status": False}
1284
+ except Exception as e:
1285
+ log.exception(e)
1286
+ return {"status": False}
1287
+
1288
+
1289
+ @app.post("/reset/db")
1290
+ def reset_vector_db(user=Depends(get_admin_user)):
1291
+ VECTOR_DB_CLIENT.reset()
1292
+ Knowledges.delete_all_knowledge()
1293
+
1294
+
1295
+ @app.post("/reset/uploads")
1296
+ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
1297
+ folder = f"{UPLOAD_DIR}"
1298
+ try:
1299
+ # Check if the directory exists
1300
+ if os.path.exists(folder):
1301
+ # Iterate over all the files and directories in the specified directory
1302
+ for filename in os.listdir(folder):
1303
+ file_path = os.path.join(folder, filename)
1304
+ try:
1305
+ if os.path.isfile(file_path) or os.path.islink(file_path):
1306
+ os.unlink(file_path) # Remove the file or link
1307
+ elif os.path.isdir(file_path):
1308
+ shutil.rmtree(file_path) # Remove the directory
1309
+ except Exception as e:
1310
+ print(f"Failed to delete {file_path}. Reason: {e}")
1311
+ else:
1312
+ print(f"The directory {folder} does not exist")
1313
+ except Exception as e:
1314
+ print(f"Failed to process the directory {folder}. Reason: {e}")
1315
+ return True
1316
+
1317
+
1318
+ if ENV == "dev":
1319
+
1320
+ @app.get("/ef")
1321
+ async def get_embeddings():
1322
+ return {"result": app.state.EMBEDDING_FUNCTION("hello world")}
1323
+
1324
+ @app.get("/ef/{text}")
1325
+ async def get_embeddings_text(text: str):
1326
+ return {"result": app.state.EMBEDDING_FUNCTION(text)}
backend/open_webui/apps/retrieval/models/colbert.py CHANGED
@@ -1,81 +1,81 @@
1
- import os
2
- import torch
3
- import numpy as np
4
- from colbert.infra import ColBERTConfig
5
- from colbert.modeling.checkpoint import Checkpoint
6
-
7
-
8
- class ColBERT:
9
- def __init__(self, name, **kwargs) -> None:
10
- print("ColBERT: Loading model", name)
11
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
-
13
- DOCKER = kwargs.get("env") == "docker"
14
- if DOCKER:
15
- # This is a workaround for the issue with the docker container
16
- # where the torch extension is not loaded properly
17
- # and the following error is thrown:
18
- # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
19
-
20
- lock_file = (
21
- "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
22
- )
23
- if os.path.exists(lock_file):
24
- os.remove(lock_file)
25
-
26
- self.ckpt = Checkpoint(
27
- name,
28
- colbert_config=ColBERTConfig(model_name=name),
29
- ).to(self.device)
30
- pass
31
-
32
- def calculate_similarity_scores(self, query_embeddings, document_embeddings):
33
-
34
- query_embeddings = query_embeddings.to(self.device)
35
- document_embeddings = document_embeddings.to(self.device)
36
-
37
- # Validate dimensions to ensure compatibility
38
- if query_embeddings.dim() != 3:
39
- raise ValueError(
40
- f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
41
- )
42
- if document_embeddings.dim() != 3:
43
- raise ValueError(
44
- f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
45
- )
46
- if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
47
- raise ValueError(
48
- "There should be either one query or queries equal to the number of documents."
49
- )
50
-
51
- # Transpose the query embeddings to align for matrix multiplication
52
- transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
53
- # Compute similarity scores using batch matrix multiplication
54
- computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)
55
- # Apply max pooling to extract the highest semantic similarity across each document's sequence
56
- maximum_scores = torch.max(computed_scores, dim=1).values
57
-
58
- # Sum up the maximum scores across features to get the overall document relevance scores
59
- final_scores = maximum_scores.sum(dim=1)
60
-
61
- normalized_scores = torch.softmax(final_scores, dim=0)
62
-
63
- return normalized_scores.detach().cpu().numpy().astype(np.float32)
64
-
65
- def predict(self, sentences):
66
-
67
- query = sentences[0][0]
68
- docs = [i[1] for i in sentences]
69
-
70
- # Embedding the documents
71
- embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
72
- # Embedding the queries
73
- embedded_queries = self.ckpt.queryFromText([query], bsize=32)
74
- embedded_query = embedded_queries[0]
75
-
76
- # Calculate retrieval scores for the query against all documents
77
- scores = self.calculate_similarity_scores(
78
- embedded_query.unsqueeze(0), embedded_docs
79
- )
80
-
81
- return scores
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from colbert.infra import ColBERTConfig
5
+ from colbert.modeling.checkpoint import Checkpoint
6
+
7
+
8
+ class ColBERT:
9
+ def __init__(self, name, **kwargs) -> None:
10
+ print("ColBERT: Loading model", name)
11
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
+
13
+ DOCKER = kwargs.get("env") == "docker"
14
+ if DOCKER:
15
+ # This is a workaround for the issue with the docker container
16
+ # where the torch extension is not loaded properly
17
+ # and the following error is thrown:
18
+ # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
19
+
20
+ lock_file = (
21
+ "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
22
+ )
23
+ if os.path.exists(lock_file):
24
+ os.remove(lock_file)
25
+
26
+ self.ckpt = Checkpoint(
27
+ name,
28
+ colbert_config=ColBERTConfig(model_name=name),
29
+ ).to(self.device)
30
+ pass
31
+
32
+ def calculate_similarity_scores(self, query_embeddings, document_embeddings):
33
+
34
+ query_embeddings = query_embeddings.to(self.device)
35
+ document_embeddings = document_embeddings.to(self.device)
36
+
37
+ # Validate dimensions to ensure compatibility
38
+ if query_embeddings.dim() != 3:
39
+ raise ValueError(
40
+ f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
41
+ )
42
+ if document_embeddings.dim() != 3:
43
+ raise ValueError(
44
+ f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
45
+ )
46
+ if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
47
+ raise ValueError(
48
+ "There should be either one query or queries equal to the number of documents."
49
+ )
50
+
51
+ # Transpose the query embeddings to align for matrix multiplication
52
+ transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
53
+ # Compute similarity scores using batch matrix multiplication
54
+ computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)
55
+ # Apply max pooling to extract the highest semantic similarity across each document's sequence
56
+ maximum_scores = torch.max(computed_scores, dim=1).values
57
+
58
+ # Sum up the maximum scores across features to get the overall document relevance scores
59
+ final_scores = maximum_scores.sum(dim=1)
60
+
61
+ normalized_scores = torch.softmax(final_scores, dim=0)
62
+
63
+ return normalized_scores.detach().cpu().numpy().astype(np.float32)
64
+
65
+ def predict(self, sentences):
66
+
67
+ query = sentences[0][0]
68
+ docs = [i[1] for i in sentences]
69
+
70
+ # Embedding the documents
71
+ embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
72
+ # Embedding the queries
73
+ embedded_queries = self.ckpt.queryFromText([query], bsize=32)
74
+ embedded_query = embedded_queries[0]
75
+
76
+ # Calculate retrieval scores for the query against all documents
77
+ scores = self.calculate_similarity_scores(
78
+ embedded_query.unsqueeze(0), embedded_docs
79
+ )
80
+
81
+ return scores
backend/open_webui/apps/retrieval/utils.py CHANGED
@@ -1,573 +1,573 @@
1
- import logging
2
- import os
3
- import uuid
4
- from typing import Optional, Union
5
-
6
- import requests
7
-
8
- from huggingface_hub import snapshot_download
9
- from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
10
- from langchain_community.retrievers import BM25Retriever
11
- from langchain_core.documents import Document
12
-
13
-
14
- from open_webui.apps.ollama.main import (
15
- GenerateEmbedForm,
16
- generate_ollama_batch_embeddings,
17
- )
18
- from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
19
- from open_webui.utils.misc import get_last_user_message
20
-
21
- from open_webui.env import SRC_LOG_LEVELS
22
- from open_webui.config import DEFAULT_RAG_TEMPLATE
23
-
24
-
25
- log = logging.getLogger(__name__)
26
- log.setLevel(SRC_LOG_LEVELS["RAG"])
27
-
28
-
29
- from typing import Any
30
-
31
- from langchain_core.callbacks import CallbackManagerForRetrieverRun
32
- from langchain_core.retrievers import BaseRetriever
33
-
34
-
35
- class VectorSearchRetriever(BaseRetriever):
36
- collection_name: Any
37
- embedding_function: Any
38
- top_k: int
39
-
40
- def _get_relevant_documents(
41
- self,
42
- query: str,
43
- *,
44
- run_manager: CallbackManagerForRetrieverRun,
45
- ) -> list[Document]:
46
- result = VECTOR_DB_CLIENT.search(
47
- collection_name=self.collection_name,
48
- vectors=[self.embedding_function(query)],
49
- limit=self.top_k,
50
- )
51
-
52
- ids = result.ids[0]
53
- metadatas = result.metadatas[0]
54
- documents = result.documents[0]
55
-
56
- results = []
57
- for idx in range(len(ids)):
58
- results.append(
59
- Document(
60
- metadata=metadatas[idx],
61
- page_content=documents[idx],
62
- )
63
- )
64
- return results
65
-
66
-
67
- def query_doc(
68
- collection_name: str,
69
- query_embedding: list[float],
70
- k: int,
71
- ):
72
- try:
73
- result = VECTOR_DB_CLIENT.search(
74
- collection_name=collection_name,
75
- vectors=[query_embedding],
76
- limit=k,
77
- )
78
-
79
- log.info(f"query_doc:result {result}")
80
- return result
81
- except Exception as e:
82
- print(e)
83
- raise e
84
-
85
-
86
- def query_doc_with_hybrid_search(
87
- collection_name: str,
88
- query: str,
89
- embedding_function,
90
- k: int,
91
- reranking_function,
92
- r: float,
93
- ) -> dict:
94
- try:
95
- result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
96
-
97
- bm25_retriever = BM25Retriever.from_texts(
98
- texts=result.documents[0],
99
- metadatas=result.metadatas[0],
100
- )
101
- bm25_retriever.k = k
102
-
103
- vector_search_retriever = VectorSearchRetriever(
104
- collection_name=collection_name,
105
- embedding_function=embedding_function,
106
- top_k=k,
107
- )
108
-
109
- ensemble_retriever = EnsembleRetriever(
110
- retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
111
- )
112
- compressor = RerankCompressor(
113
- embedding_function=embedding_function,
114
- top_n=k,
115
- reranking_function=reranking_function,
116
- r_score=r,
117
- )
118
-
119
- compression_retriever = ContextualCompressionRetriever(
120
- base_compressor=compressor, base_retriever=ensemble_retriever
121
- )
122
-
123
- result = compression_retriever.invoke(query)
124
- result = {
125
- "distances": [[d.metadata.get("score") for d in result]],
126
- "documents": [[d.page_content for d in result]],
127
- "metadatas": [[d.metadata for d in result]],
128
- }
129
-
130
- log.info(f"query_doc_with_hybrid_search:result {result}")
131
- return result
132
- except Exception as e:
133
- raise e
134
-
135
-
136
- def merge_and_sort_query_results(
137
- query_results: list[dict], k: int, reverse: bool = False
138
- ) -> list[dict]:
139
- # Initialize lists to store combined data
140
- combined_distances = []
141
- combined_documents = []
142
- combined_metadatas = []
143
-
144
- for data in query_results:
145
- combined_distances.extend(data["distances"][0])
146
- combined_documents.extend(data["documents"][0])
147
- combined_metadatas.extend(data["metadatas"][0])
148
-
149
- # Create a list of tuples (distance, document, metadata)
150
- combined = list(zip(combined_distances, combined_documents, combined_metadatas))
151
-
152
- # Sort the list based on distances
153
- combined.sort(key=lambda x: x[0], reverse=reverse)
154
-
155
- # We don't have anything :-(
156
- if not combined:
157
- sorted_distances = []
158
- sorted_documents = []
159
- sorted_metadatas = []
160
- else:
161
- # Unzip the sorted list
162
- sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
163
-
164
- # Slicing the lists to include only k elements
165
- sorted_distances = list(sorted_distances)[:k]
166
- sorted_documents = list(sorted_documents)[:k]
167
- sorted_metadatas = list(sorted_metadatas)[:k]
168
-
169
- # Create the output dictionary
170
- result = {
171
- "distances": [sorted_distances],
172
- "documents": [sorted_documents],
173
- "metadatas": [sorted_metadatas],
174
- }
175
-
176
- return result
177
-
178
-
179
- def query_collection(
180
- collection_names: list[str],
181
- query: str,
182
- embedding_function,
183
- k: int,
184
- ) -> dict:
185
-
186
- results = []
187
- query_embedding = embedding_function(query)
188
-
189
- for collection_name in collection_names:
190
- if collection_name:
191
- try:
192
- result = query_doc(
193
- collection_name=collection_name,
194
- k=k,
195
- query_embedding=query_embedding,
196
- )
197
- if result is not None:
198
- results.append(result.model_dump())
199
- except Exception as e:
200
- log.exception(f"Error when querying the collection: {e}")
201
- else:
202
- pass
203
-
204
- return merge_and_sort_query_results(results, k=k)
205
-
206
-
207
- def query_collection_with_hybrid_search(
208
- collection_names: list[str],
209
- query: str,
210
- embedding_function,
211
- k: int,
212
- reranking_function,
213
- r: float,
214
- ) -> dict:
215
- results = []
216
- error = False
217
- for collection_name in collection_names:
218
- try:
219
- result = query_doc_with_hybrid_search(
220
- collection_name=collection_name,
221
- query=query,
222
- embedding_function=embedding_function,
223
- k=k,
224
- reranking_function=reranking_function,
225
- r=r,
226
- )
227
- results.append(result)
228
- except Exception as e:
229
- log.exception(
230
- "Error when querying the collection with " f"hybrid_search: {e}"
231
- )
232
- error = True
233
-
234
- if error:
235
- raise Exception(
236
- "Hybrid search failed for all collections. Using Non hybrid search as fallback."
237
- )
238
-
239
- return merge_and_sort_query_results(results, k=k, reverse=True)
240
-
241
-
242
- def rag_template(template: str, context: str, query: str):
243
- if template == "":
244
- template = DEFAULT_RAG_TEMPLATE
245
-
246
- if "[context]" not in template and "{{CONTEXT}}" not in template:
247
- log.debug(
248
- "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
249
- )
250
-
251
- if "<context>" in context and "</context>" in context:
252
- log.debug(
253
- "WARNING: Potential prompt injection attack: the RAG "
254
- "context contains '<context>' and '</context>'. This might be "
255
- "nothing, or the user might be trying to hack something."
256
- )
257
-
258
- query_placeholders = []
259
- if "[query]" in context:
260
- query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
261
- template = template.replace("[query]", query_placeholder)
262
- query_placeholders.append(query_placeholder)
263
-
264
- if "{{QUERY}}" in context:
265
- query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
266
- template = template.replace("{{QUERY}}", query_placeholder)
267
- query_placeholders.append(query_placeholder)
268
-
269
- template = template.replace("[context]", context)
270
- template = template.replace("{{CONTEXT}}", context)
271
- template = template.replace("[query]", query)
272
- template = template.replace("{{QUERY}}", query)
273
-
274
- for query_placeholder in query_placeholders:
275
- template = template.replace(query_placeholder, query)
276
-
277
- return template
278
-
279
-
280
- def get_embedding_function(
281
- embedding_engine,
282
- embedding_model,
283
- embedding_function,
284
- openai_key,
285
- openai_url,
286
- embedding_batch_size,
287
- ):
288
- if embedding_engine == "":
289
- return lambda query: embedding_function.encode(query).tolist()
290
- elif embedding_engine in ["ollama", "openai"]:
291
- func = lambda query: generate_embeddings(
292
- engine=embedding_engine,
293
- model=embedding_model,
294
- text=query,
295
- key=openai_key if embedding_engine == "openai" else "",
296
- url=openai_url if embedding_engine == "openai" else "",
297
- )
298
-
299
- def generate_multiple(query, func):
300
- if isinstance(query, list):
301
- embeddings = []
302
- for i in range(0, len(query), embedding_batch_size):
303
- embeddings.extend(func(query[i : i + embedding_batch_size]))
304
- return embeddings
305
- else:
306
- return func(query)
307
-
308
- return lambda query: generate_multiple(query, func)
309
-
310
-
311
- def get_rag_context(
312
- files,
313
- messages,
314
- embedding_function,
315
- k,
316
- reranking_function,
317
- r,
318
- hybrid_search,
319
- ):
320
- log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
321
- query = get_last_user_message(messages)
322
-
323
- extracted_collections = []
324
- relevant_contexts = []
325
-
326
- for file in files:
327
- if file.get("context") == "full":
328
- context = {
329
- "documents": [[file.get("file").get("data", {}).get("content")]],
330
- "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
331
- }
332
- else:
333
- context = None
334
-
335
- collection_names = []
336
- if file.get("type") == "collection":
337
- if file.get("legacy"):
338
- collection_names = file.get("collection_names", [])
339
- else:
340
- collection_names.append(file["id"])
341
- elif file.get("collection_name"):
342
- collection_names.append(file["collection_name"])
343
- elif file.get("id"):
344
- if file.get("legacy"):
345
- collection_names.append(f"{file['id']}")
346
- else:
347
- collection_names.append(f"file-{file['id']}")
348
-
349
- collection_names = set(collection_names).difference(extracted_collections)
350
- if not collection_names:
351
- log.debug(f"skipping {file} as it has already been extracted")
352
- continue
353
-
354
- try:
355
- context = None
356
- if file.get("type") == "text":
357
- context = file["content"]
358
- else:
359
- if hybrid_search:
360
- try:
361
- context = query_collection_with_hybrid_search(
362
- collection_names=collection_names,
363
- query=query,
364
- embedding_function=embedding_function,
365
- k=k,
366
- reranking_function=reranking_function,
367
- r=r,
368
- )
369
- except Exception as e:
370
- log.debug(
371
- "Error when using hybrid search, using"
372
- " non hybrid search as fallback."
373
- )
374
-
375
- if (not hybrid_search) or (context is None):
376
- context = query_collection(
377
- collection_names=collection_names,
378
- query=query,
379
- embedding_function=embedding_function,
380
- k=k,
381
- )
382
- except Exception as e:
383
- log.exception(e)
384
-
385
- extracted_collections.extend(collection_names)
386
-
387
- if context:
388
- if "data" in file:
389
- del file["data"]
390
- relevant_contexts.append({**context, "file": file})
391
-
392
- contexts = []
393
- citations = []
394
- for context in relevant_contexts:
395
- try:
396
- if "documents" in context:
397
- file_names = list(
398
- set(
399
- [
400
- metadata["name"]
401
- for metadata in context["metadatas"][0]
402
- if metadata is not None and "name" in metadata
403
- ]
404
- )
405
- )
406
- contexts.append(
407
- ((", ".join(file_names) + ":\n\n") if file_names else "")
408
- + "\n\n".join(
409
- [text for text in context["documents"][0] if text is not None]
410
- )
411
- )
412
-
413
- if "metadatas" in context:
414
- citation = {
415
- "source": context["file"],
416
- "document": context["documents"][0],
417
- "metadata": context["metadatas"][0],
418
- }
419
- if "distances" in context and context["distances"]:
420
- citation["distances"] = context["distances"][0]
421
- citations.append(citation)
422
- except Exception as e:
423
- log.exception(e)
424
-
425
- print("contexts", contexts)
426
- print("citations", citations)
427
-
428
- return contexts, citations
429
-
430
-
431
- def get_model_path(model: str, update_model: bool = False):
432
- # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
433
- cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
434
-
435
- local_files_only = not update_model
436
-
437
- snapshot_kwargs = {
438
- "cache_dir": cache_dir,
439
- "local_files_only": local_files_only,
440
- }
441
-
442
- log.debug(f"model: {model}")
443
- log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
444
-
445
- # Inspiration from upstream sentence_transformers
446
- if (
447
- os.path.exists(model)
448
- or ("\\" in model or model.count("/") > 1)
449
- and local_files_only
450
- ):
451
- # If fully qualified path exists, return input, else set repo_id
452
- return model
453
- elif "/" not in model:
454
- # Set valid repo_id for model short-name
455
- model = "sentence-transformers" + "/" + model
456
-
457
- snapshot_kwargs["repo_id"] = model
458
-
459
- # Attempt to query the huggingface_hub library to determine the local path and/or to update
460
- try:
461
- model_repo_path = snapshot_download(**snapshot_kwargs)
462
- log.debug(f"model_repo_path: {model_repo_path}")
463
- return model_repo_path
464
- except Exception as e:
465
- log.exception(f"Cannot determine model snapshot path: {e}")
466
- return model
467
-
468
-
469
- def generate_openai_batch_embeddings(
470
- model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
471
- ) -> Optional[list[list[float]]]:
472
- try:
473
- r = requests.post(
474
- f"{url}/embeddings",
475
- headers={
476
- "Content-Type": "application/json",
477
- "Authorization": f"Bearer {key}",
478
- },
479
- json={"input": texts, "model": model},
480
- )
481
- r.raise_for_status()
482
- data = r.json()
483
- if "data" in data:
484
- return [elem["embedding"] for elem in data["data"]]
485
- else:
486
- raise "Something went wrong :/"
487
- except Exception as e:
488
- print(e)
489
- return None
490
-
491
-
492
- def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
493
- if engine == "ollama":
494
- if isinstance(text, list):
495
- embeddings = generate_ollama_batch_embeddings(
496
- GenerateEmbedForm(**{"model": model, "input": text})
497
- )
498
- else:
499
- embeddings = generate_ollama_batch_embeddings(
500
- GenerateEmbedForm(**{"model": model, "input": [text]})
501
- )
502
- return (
503
- embeddings["embeddings"][0]
504
- if isinstance(text, str)
505
- else embeddings["embeddings"]
506
- )
507
- elif engine == "openai":
508
- key = kwargs.get("key", "")
509
- url = kwargs.get("url", "https://api.openai.com/v1")
510
-
511
- if isinstance(text, list):
512
- embeddings = generate_openai_batch_embeddings(model, text, key, url)
513
- else:
514
- embeddings = generate_openai_batch_embeddings(model, [text], key, url)
515
-
516
- return embeddings[0] if isinstance(text, str) else embeddings
517
-
518
-
519
- import operator
520
- from typing import Optional, Sequence
521
-
522
- from langchain_core.callbacks import Callbacks
523
- from langchain_core.documents import BaseDocumentCompressor, Document
524
-
525
-
526
- class RerankCompressor(BaseDocumentCompressor):
527
- embedding_function: Any
528
- top_n: int
529
- reranking_function: Any
530
- r_score: float
531
-
532
- class Config:
533
- extra = "forbid"
534
- arbitrary_types_allowed = True
535
-
536
- def compress_documents(
537
- self,
538
- documents: Sequence[Document],
539
- query: str,
540
- callbacks: Optional[Callbacks] = None,
541
- ) -> Sequence[Document]:
542
- reranking = self.reranking_function is not None
543
-
544
- if reranking:
545
- scores = self.reranking_function.predict(
546
- [(query, doc.page_content) for doc in documents]
547
- )
548
- else:
549
- from sentence_transformers import util
550
-
551
- query_embedding = self.embedding_function(query)
552
- document_embedding = self.embedding_function(
553
- [doc.page_content for doc in documents]
554
- )
555
- scores = util.cos_sim(query_embedding, document_embedding)[0]
556
-
557
- docs_with_scores = list(zip(documents, scores.tolist()))
558
- if self.r_score:
559
- docs_with_scores = [
560
- (d, s) for d, s in docs_with_scores if s >= self.r_score
561
- ]
562
-
563
- result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
564
- final_results = []
565
- for doc, doc_score in result[: self.top_n]:
566
- metadata = doc.metadata
567
- metadata["score"] = doc_score
568
- doc = Document(
569
- page_content=doc.page_content,
570
- metadata=metadata,
571
- )
572
- final_results.append(doc)
573
- return final_results
 
1
+ import logging
2
+ import os
3
+ import uuid
4
+ from typing import Optional, Union
5
+
6
+ import requests
7
+
8
+ from huggingface_hub import snapshot_download
9
+ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
10
+ from langchain_community.retrievers import BM25Retriever
11
+ from langchain_core.documents import Document
12
+
13
+
14
+ from open_webui.apps.ollama.main import (
15
+ GenerateEmbedForm,
16
+ generate_ollama_batch_embeddings,
17
+ )
18
+ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
19
+ from open_webui.utils.misc import get_last_user_message
20
+
21
+ from open_webui.env import SRC_LOG_LEVELS
22
+ from open_webui.config import DEFAULT_RAG_TEMPLATE
23
+
24
+
25
+ log = logging.getLogger(__name__)
26
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
27
+
28
+
29
+ from typing import Any
30
+
31
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
32
+ from langchain_core.retrievers import BaseRetriever
33
+
34
+
35
+ class VectorSearchRetriever(BaseRetriever):
36
+ collection_name: Any
37
+ embedding_function: Any
38
+ top_k: int
39
+
40
+ def _get_relevant_documents(
41
+ self,
42
+ query: str,
43
+ *,
44
+ run_manager: CallbackManagerForRetrieverRun,
45
+ ) -> list[Document]:
46
+ result = VECTOR_DB_CLIENT.search(
47
+ collection_name=self.collection_name,
48
+ vectors=[self.embedding_function(query)],
49
+ limit=self.top_k,
50
+ )
51
+
52
+ ids = result.ids[0]
53
+ metadatas = result.metadatas[0]
54
+ documents = result.documents[0]
55
+
56
+ results = []
57
+ for idx in range(len(ids)):
58
+ results.append(
59
+ Document(
60
+ metadata=metadatas[idx],
61
+ page_content=documents[idx],
62
+ )
63
+ )
64
+ return results
65
+
66
+
67
+ def query_doc(
68
+ collection_name: str,
69
+ query_embedding: list[float],
70
+ k: int,
71
+ ):
72
+ try:
73
+ result = VECTOR_DB_CLIENT.search(
74
+ collection_name=collection_name,
75
+ vectors=[query_embedding],
76
+ limit=k,
77
+ )
78
+
79
+ log.info(f"query_doc:result {result}")
80
+ return result
81
+ except Exception as e:
82
+ print(e)
83
+ raise e
84
+
85
+
86
+ def query_doc_with_hybrid_search(
87
+ collection_name: str,
88
+ query: str,
89
+ embedding_function,
90
+ k: int,
91
+ reranking_function,
92
+ r: float,
93
+ ) -> dict:
94
+ try:
95
+ result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
96
+
97
+ bm25_retriever = BM25Retriever.from_texts(
98
+ texts=result.documents[0],
99
+ metadatas=result.metadatas[0],
100
+ )
101
+ bm25_retriever.k = k
102
+
103
+ vector_search_retriever = VectorSearchRetriever(
104
+ collection_name=collection_name,
105
+ embedding_function=embedding_function,
106
+ top_k=k,
107
+ )
108
+
109
+ ensemble_retriever = EnsembleRetriever(
110
+ retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
111
+ )
112
+ compressor = RerankCompressor(
113
+ embedding_function=embedding_function,
114
+ top_n=k,
115
+ reranking_function=reranking_function,
116
+ r_score=r,
117
+ )
118
+
119
+ compression_retriever = ContextualCompressionRetriever(
120
+ base_compressor=compressor, base_retriever=ensemble_retriever
121
+ )
122
+
123
+ result = compression_retriever.invoke(query)
124
+ result = {
125
+ "distances": [[d.metadata.get("score") for d in result]],
126
+ "documents": [[d.page_content for d in result]],
127
+ "metadatas": [[d.metadata for d in result]],
128
+ }
129
+
130
+ log.info(f"query_doc_with_hybrid_search:result {result}")
131
+ return result
132
+ except Exception as e:
133
+ raise e
134
+
135
+
136
+ def merge_and_sort_query_results(
137
+ query_results: list[dict], k: int, reverse: bool = False
138
+ ) -> list[dict]:
139
+ # Initialize lists to store combined data
140
+ combined_distances = []
141
+ combined_documents = []
142
+ combined_metadatas = []
143
+
144
+ for data in query_results:
145
+ combined_distances.extend(data["distances"][0])
146
+ combined_documents.extend(data["documents"][0])
147
+ combined_metadatas.extend(data["metadatas"][0])
148
+
149
+ # Create a list of tuples (distance, document, metadata)
150
+ combined = list(zip(combined_distances, combined_documents, combined_metadatas))
151
+
152
+ # Sort the list based on distances
153
+ combined.sort(key=lambda x: x[0], reverse=reverse)
154
+
155
+ # We don't have anything :-(
156
+ if not combined:
157
+ sorted_distances = []
158
+ sorted_documents = []
159
+ sorted_metadatas = []
160
+ else:
161
+ # Unzip the sorted list
162
+ sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
163
+
164
+ # Slicing the lists to include only k elements
165
+ sorted_distances = list(sorted_distances)[:k]
166
+ sorted_documents = list(sorted_documents)[:k]
167
+ sorted_metadatas = list(sorted_metadatas)[:k]
168
+
169
+ # Create the output dictionary
170
+ result = {
171
+ "distances": [sorted_distances],
172
+ "documents": [sorted_documents],
173
+ "metadatas": [sorted_metadatas],
174
+ }
175
+
176
+ return result
177
+
178
+
179
+ def query_collection(
180
+ collection_names: list[str],
181
+ query: str,
182
+ embedding_function,
183
+ k: int,
184
+ ) -> dict:
185
+
186
+ results = []
187
+ query_embedding = embedding_function(query)
188
+
189
+ for collection_name in collection_names:
190
+ if collection_name:
191
+ try:
192
+ result = query_doc(
193
+ collection_name=collection_name,
194
+ k=k,
195
+ query_embedding=query_embedding,
196
+ )
197
+ if result is not None:
198
+ results.append(result.model_dump())
199
+ except Exception as e:
200
+ log.exception(f"Error when querying the collection: {e}")
201
+ else:
202
+ pass
203
+
204
+ return merge_and_sort_query_results(results, k=k)
205
+
206
+
207
+ def query_collection_with_hybrid_search(
208
+ collection_names: list[str],
209
+ query: str,
210
+ embedding_function,
211
+ k: int,
212
+ reranking_function,
213
+ r: float,
214
+ ) -> dict:
215
+ results = []
216
+ error = False
217
+ for collection_name in collection_names:
218
+ try:
219
+ result = query_doc_with_hybrid_search(
220
+ collection_name=collection_name,
221
+ query=query,
222
+ embedding_function=embedding_function,
223
+ k=k,
224
+ reranking_function=reranking_function,
225
+ r=r,
226
+ )
227
+ results.append(result)
228
+ except Exception as e:
229
+ log.exception(
230
+ "Error when querying the collection with " f"hybrid_search: {e}"
231
+ )
232
+ error = True
233
+
234
+ if error:
235
+ raise Exception(
236
+ "Hybrid search failed for all collections. Using Non hybrid search as fallback."
237
+ )
238
+
239
+ return merge_and_sort_query_results(results, k=k, reverse=True)
240
+
241
+
242
+ def rag_template(template: str, context: str, query: str):
243
+ if template == "":
244
+ template = DEFAULT_RAG_TEMPLATE
245
+
246
+ if "[context]" not in template and "{{CONTEXT}}" not in template:
247
+ log.debug(
248
+ "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
249
+ )
250
+
251
+ if "<context>" in context and "</context>" in context:
252
+ log.debug(
253
+ "WARNING: Potential prompt injection attack: the RAG "
254
+ "context contains '<context>' and '</context>'. This might be "
255
+ "nothing, or the user might be trying to hack something."
256
+ )
257
+
258
+ query_placeholders = []
259
+ if "[query]" in context:
260
+ query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
261
+ template = template.replace("[query]", query_placeholder)
262
+ query_placeholders.append(query_placeholder)
263
+
264
+ if "{{QUERY}}" in context:
265
+ query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
266
+ template = template.replace("{{QUERY}}", query_placeholder)
267
+ query_placeholders.append(query_placeholder)
268
+
269
+ template = template.replace("[context]", context)
270
+ template = template.replace("{{CONTEXT}}", context)
271
+ template = template.replace("[query]", query)
272
+ template = template.replace("{{QUERY}}", query)
273
+
274
+ for query_placeholder in query_placeholders:
275
+ template = template.replace(query_placeholder, query)
276
+
277
+ return template
278
+
279
+
280
+ def get_embedding_function(
281
+ embedding_engine,
282
+ embedding_model,
283
+ embedding_function,
284
+ openai_key,
285
+ openai_url,
286
+ embedding_batch_size,
287
+ ):
288
+ if embedding_engine == "":
289
+ return lambda query: embedding_function.encode(query).tolist()
290
+ elif embedding_engine in ["ollama", "openai"]:
291
+ func = lambda query: generate_embeddings(
292
+ engine=embedding_engine,
293
+ model=embedding_model,
294
+ text=query,
295
+ key=openai_key if embedding_engine == "openai" else "",
296
+ url=openai_url if embedding_engine == "openai" else "",
297
+ )
298
+
299
+ def generate_multiple(query, func):
300
+ if isinstance(query, list):
301
+ embeddings = []
302
+ for i in range(0, len(query), embedding_batch_size):
303
+ embeddings.extend(func(query[i : i + embedding_batch_size]))
304
+ return embeddings
305
+ else:
306
+ return func(query)
307
+
308
+ return lambda query: generate_multiple(query, func)
309
+
310
+
311
+ def get_rag_context(
312
+ files,
313
+ messages,
314
+ embedding_function,
315
+ k,
316
+ reranking_function,
317
+ r,
318
+ hybrid_search,
319
+ ):
320
+ log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
321
+ query = get_last_user_message(messages)
322
+
323
+ extracted_collections = []
324
+ relevant_contexts = []
325
+
326
+ for file in files:
327
+ if file.get("context") == "full":
328
+ context = {
329
+ "documents": [[file.get("file").get("data", {}).get("content")]],
330
+ "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
331
+ }
332
+ else:
333
+ context = None
334
+
335
+ collection_names = []
336
+ if file.get("type") == "collection":
337
+ if file.get("legacy"):
338
+ collection_names = file.get("collection_names", [])
339
+ else:
340
+ collection_names.append(file["id"])
341
+ elif file.get("collection_name"):
342
+ collection_names.append(file["collection_name"])
343
+ elif file.get("id"):
344
+ if file.get("legacy"):
345
+ collection_names.append(f"{file['id']}")
346
+ else:
347
+ collection_names.append(f"file-{file['id']}")
348
+
349
+ collection_names = set(collection_names).difference(extracted_collections)
350
+ if not collection_names:
351
+ log.debug(f"skipping {file} as it has already been extracted")
352
+ continue
353
+
354
+ try:
355
+ context = None
356
+ if file.get("type") == "text":
357
+ context = file["content"]
358
+ else:
359
+ if hybrid_search:
360
+ try:
361
+ context = query_collection_with_hybrid_search(
362
+ collection_names=collection_names,
363
+ query=query,
364
+ embedding_function=embedding_function,
365
+ k=k,
366
+ reranking_function=reranking_function,
367
+ r=r,
368
+ )
369
+ except Exception as e:
370
+ log.debug(
371
+ "Error when using hybrid search, using"
372
+ " non hybrid search as fallback."
373
+ )
374
+
375
+ if (not hybrid_search) or (context is None):
376
+ context = query_collection(
377
+ collection_names=collection_names,
378
+ query=query,
379
+ embedding_function=embedding_function,
380
+ k=k,
381
+ )
382
+ except Exception as e:
383
+ log.exception(e)
384
+
385
+ extracted_collections.extend(collection_names)
386
+
387
+ if context:
388
+ if "data" in file:
389
+ del file["data"]
390
+ relevant_contexts.append({**context, "file": file})
391
+
392
+ contexts = []
393
+ citations = []
394
+ for context in relevant_contexts:
395
+ try:
396
+ if "documents" in context:
397
+ file_names = list(
398
+ set(
399
+ [
400
+ metadata["name"]
401
+ for metadata in context["metadatas"][0]
402
+ if metadata is not None and "name" in metadata
403
+ ]
404
+ )
405
+ )
406
+ contexts.append(
407
+ ((", ".join(file_names) + ":\n\n") if file_names else "")
408
+ + "\n\n".join(
409
+ [text for text in context["documents"][0] if text is not None]
410
+ )
411
+ )
412
+
413
+ if "metadatas" in context:
414
+ citation = {
415
+ "source": context["file"],
416
+ "document": context["documents"][0],
417
+ "metadata": context["metadatas"][0],
418
+ }
419
+ if "distances" in context and context["distances"]:
420
+ citation["distances"] = context["distances"][0]
421
+ citations.append(citation)
422
+ except Exception as e:
423
+ log.exception(e)
424
+
425
+ print("contexts", contexts)
426
+ print("citations", citations)
427
+
428
+ return contexts, citations
429
+
430
+
431
+ def get_model_path(model: str, update_model: bool = False):
432
+ # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
433
+ cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
434
+
435
+ local_files_only = not update_model
436
+
437
+ snapshot_kwargs = {
438
+ "cache_dir": cache_dir,
439
+ "local_files_only": local_files_only,
440
+ }
441
+
442
+ log.debug(f"model: {model}")
443
+ log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
444
+
445
+ # Inspiration from upstream sentence_transformers
446
+ if (
447
+ os.path.exists(model)
448
+ or ("\\" in model or model.count("/") > 1)
449
+ and local_files_only
450
+ ):
451
+ # If fully qualified path exists, return input, else set repo_id
452
+ return model
453
+ elif "/" not in model:
454
+ # Set valid repo_id for model short-name
455
+ model = "sentence-transformers" + "/" + model
456
+
457
+ snapshot_kwargs["repo_id"] = model
458
+
459
+ # Attempt to query the huggingface_hub library to determine the local path and/or to update
460
+ try:
461
+ model_repo_path = snapshot_download(**snapshot_kwargs)
462
+ log.debug(f"model_repo_path: {model_repo_path}")
463
+ return model_repo_path
464
+ except Exception as e:
465
+ log.exception(f"Cannot determine model snapshot path: {e}")
466
+ return model
467
+
468
+
469
+ def generate_openai_batch_embeddings(
470
+ model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
471
+ ) -> Optional[list[list[float]]]:
472
+ try:
473
+ r = requests.post(
474
+ f"{url}/embeddings",
475
+ headers={
476
+ "Content-Type": "application/json",
477
+ "Authorization": f"Bearer {key}",
478
+ },
479
+ json={"input": texts, "model": model},
480
+ )
481
+ r.raise_for_status()
482
+ data = r.json()
483
+ if "data" in data:
484
+ return [elem["embedding"] for elem in data["data"]]
485
+ else:
486
+ raise "Something went wrong :/"
487
+ except Exception as e:
488
+ print(e)
489
+ return None
490
+
491
+
492
+ def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
493
+ if engine == "ollama":
494
+ if isinstance(text, list):
495
+ embeddings = generate_ollama_batch_embeddings(
496
+ GenerateEmbedForm(**{"model": model, "input": text})
497
+ )
498
+ else:
499
+ embeddings = generate_ollama_batch_embeddings(
500
+ GenerateEmbedForm(**{"model": model, "input": [text]})
501
+ )
502
+ return (
503
+ embeddings["embeddings"][0]
504
+ if isinstance(text, str)
505
+ else embeddings["embeddings"]
506
+ )
507
+ elif engine == "openai":
508
+ key = kwargs.get("key", "")
509
+ url = kwargs.get("url", "https://api.openai.com/v1")
510
+
511
+ if isinstance(text, list):
512
+ embeddings = generate_openai_batch_embeddings(model, text, key, url)
513
+ else:
514
+ embeddings = generate_openai_batch_embeddings(model, [text], key, url)
515
+
516
+ return embeddings[0] if isinstance(text, str) else embeddings
517
+
518
+
519
+ import operator
520
+ from typing import Optional, Sequence
521
+
522
+ from langchain_core.callbacks import Callbacks
523
+ from langchain_core.documents import BaseDocumentCompressor, Document
524
+
525
+
526
+ class RerankCompressor(BaseDocumentCompressor):
527
+ embedding_function: Any
528
+ top_n: int
529
+ reranking_function: Any
530
+ r_score: float
531
+
532
+ class Config:
533
+ extra = "forbid"
534
+ arbitrary_types_allowed = True
535
+
536
+ def compress_documents(
537
+ self,
538
+ documents: Sequence[Document],
539
+ query: str,
540
+ callbacks: Optional[Callbacks] = None,
541
+ ) -> Sequence[Document]:
542
+ reranking = self.reranking_function is not None
543
+
544
+ if reranking:
545
+ scores = self.reranking_function.predict(
546
+ [(query, doc.page_content) for doc in documents]
547
+ )
548
+ else:
549
+ from sentence_transformers import util
550
+
551
+ query_embedding = self.embedding_function(query)
552
+ document_embedding = self.embedding_function(
553
+ [doc.page_content for doc in documents]
554
+ )
555
+ scores = util.cos_sim(query_embedding, document_embedding)[0]
556
+
557
+ docs_with_scores = list(zip(documents, scores.tolist()))
558
+ if self.r_score:
559
+ docs_with_scores = [
560
+ (d, s) for d, s in docs_with_scores if s >= self.r_score
561
+ ]
562
+
563
+ result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
564
+ final_results = []
565
+ for doc, doc_score in result[: self.top_n]:
566
+ metadata = doc.metadata
567
+ metadata["score"] = doc_score
568
+ doc = Document(
569
+ page_content=doc.page_content,
570
+ metadata=metadata,
571
+ )
572
+ final_results.append(doc)
573
+ return final_results
backend/open_webui/apps/retrieval/vector/connector.py CHANGED
@@ -1,14 +1,14 @@
1
- from open_webui.config import VECTOR_DB
2
-
3
- if VECTOR_DB == "milvus":
4
- from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
5
-
6
- VECTOR_DB_CLIENT = MilvusClient()
7
- elif VECTOR_DB == "qdrant":
8
- from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
9
-
10
- VECTOR_DB_CLIENT = QdrantClient()
11
- else:
12
- from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
13
-
14
- VECTOR_DB_CLIENT = ChromaClient()
 
1
+ from open_webui.config import VECTOR_DB
2
+
3
+ if VECTOR_DB == "milvus":
4
+ from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
5
+
6
+ VECTOR_DB_CLIENT = MilvusClient()
7
+ elif VECTOR_DB == "qdrant":
8
+ from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
9
+
10
+ VECTOR_DB_CLIENT = QdrantClient()
11
+ else:
12
+ from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
13
+
14
+ VECTOR_DB_CLIENT = ChromaClient()
backend/open_webui/apps/retrieval/vector/dbs/chroma.py CHANGED
@@ -1,161 +1,161 @@
1
- import chromadb
2
- from chromadb import Settings
3
- from chromadb.utils.batch_utils import create_batches
4
-
5
- from typing import Optional
6
-
7
- from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
8
- from open_webui.config import (
9
- CHROMA_DATA_PATH,
10
- CHROMA_HTTP_HOST,
11
- CHROMA_HTTP_PORT,
12
- CHROMA_HTTP_HEADERS,
13
- CHROMA_HTTP_SSL,
14
- CHROMA_TENANT,
15
- CHROMA_DATABASE,
16
- )
17
-
18
-
19
- class ChromaClient:
20
- def __init__(self):
21
- if CHROMA_HTTP_HOST != "":
22
- self.client = chromadb.HttpClient(
23
- host=CHROMA_HTTP_HOST,
24
- port=CHROMA_HTTP_PORT,
25
- headers=CHROMA_HTTP_HEADERS,
26
- ssl=CHROMA_HTTP_SSL,
27
- tenant=CHROMA_TENANT,
28
- database=CHROMA_DATABASE,
29
- settings=Settings(allow_reset=True, anonymized_telemetry=False),
30
- )
31
- else:
32
- self.client = chromadb.PersistentClient(
33
- path=CHROMA_DATA_PATH,
34
- settings=Settings(allow_reset=True, anonymized_telemetry=False),
35
- tenant=CHROMA_TENANT,
36
- database=CHROMA_DATABASE,
37
- )
38
-
39
- def has_collection(self, collection_name: str) -> bool:
40
- # Check if the collection exists based on the collection name.
41
- collections = self.client.list_collections()
42
- return collection_name in [collection.name for collection in collections]
43
-
44
- def delete_collection(self, collection_name: str):
45
- # Delete the collection based on the collection name.
46
- return self.client.delete_collection(name=collection_name)
47
-
48
- def search(
49
- self, collection_name: str, vectors: list[list[float | int]], limit: int
50
- ) -> Optional[SearchResult]:
51
- # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
52
- try:
53
- collection = self.client.get_collection(name=collection_name)
54
- if collection:
55
- result = collection.query(
56
- query_embeddings=vectors,
57
- n_results=limit,
58
- )
59
-
60
- return SearchResult(
61
- **{
62
- "ids": result["ids"],
63
- "distances": result["distances"],
64
- "documents": result["documents"],
65
- "metadatas": result["metadatas"],
66
- }
67
- )
68
- return None
69
- except Exception as e:
70
- return None
71
-
72
- def query(
73
- self, collection_name: str, filter: dict, limit: Optional[int] = None
74
- ) -> Optional[GetResult]:
75
- # Query the items from the collection based on the filter.
76
- try:
77
- collection = self.client.get_collection(name=collection_name)
78
- if collection:
79
- result = collection.get(
80
- where=filter,
81
- limit=limit,
82
- )
83
-
84
- return GetResult(
85
- **{
86
- "ids": [result["ids"]],
87
- "documents": [result["documents"]],
88
- "metadatas": [result["metadatas"]],
89
- }
90
- )
91
- return None
92
- except Exception as e:
93
- print(e)
94
- return None
95
-
96
- def get(self, collection_name: str) -> Optional[GetResult]:
97
- # Get all the items in the collection.
98
- collection = self.client.get_collection(name=collection_name)
99
- if collection:
100
- result = collection.get()
101
- return GetResult(
102
- **{
103
- "ids": [result["ids"]],
104
- "documents": [result["documents"]],
105
- "metadatas": [result["metadatas"]],
106
- }
107
- )
108
- return None
109
-
110
- def insert(self, collection_name: str, items: list[VectorItem]):
111
- # Insert the items into the collection, if the collection does not exist, it will be created.
112
- collection = self.client.get_or_create_collection(
113
- name=collection_name, metadata={"hnsw:space": "cosine"}
114
- )
115
-
116
- ids = [item["id"] for item in items]
117
- documents = [item["text"] for item in items]
118
- embeddings = [item["vector"] for item in items]
119
- metadatas = [item["metadata"] for item in items]
120
-
121
- for batch in create_batches(
122
- api=self.client,
123
- documents=documents,
124
- embeddings=embeddings,
125
- ids=ids,
126
- metadatas=metadatas,
127
- ):
128
- collection.add(*batch)
129
-
130
- def upsert(self, collection_name: str, items: list[VectorItem]):
131
- # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
132
- collection = self.client.get_or_create_collection(
133
- name=collection_name, metadata={"hnsw:space": "cosine"}
134
- )
135
-
136
- ids = [item["id"] for item in items]
137
- documents = [item["text"] for item in items]
138
- embeddings = [item["vector"] for item in items]
139
- metadatas = [item["metadata"] for item in items]
140
-
141
- collection.upsert(
142
- ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
143
- )
144
-
145
- def delete(
146
- self,
147
- collection_name: str,
148
- ids: Optional[list[str]] = None,
149
- filter: Optional[dict] = None,
150
- ):
151
- # Delete the items from the collection based on the ids.
152
- collection = self.client.get_collection(name=collection_name)
153
- if collection:
154
- if ids:
155
- collection.delete(ids=ids)
156
- elif filter:
157
- collection.delete(where=filter)
158
-
159
- def reset(self):
160
- # Resets the database. This will delete all collections and item entries.
161
- return self.client.reset()
 
1
+ import chromadb
2
+ from chromadb import Settings
3
+ from chromadb.utils.batch_utils import create_batches
4
+
5
+ from typing import Optional
6
+
7
+ from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
8
+ from open_webui.config import (
9
+ CHROMA_DATA_PATH,
10
+ CHROMA_HTTP_HOST,
11
+ CHROMA_HTTP_PORT,
12
+ CHROMA_HTTP_HEADERS,
13
+ CHROMA_HTTP_SSL,
14
+ CHROMA_TENANT,
15
+ CHROMA_DATABASE,
16
+ )
17
+
18
+
19
+ class ChromaClient:
20
+ def __init__(self):
21
+ if CHROMA_HTTP_HOST != "":
22
+ self.client = chromadb.HttpClient(
23
+ host=CHROMA_HTTP_HOST,
24
+ port=CHROMA_HTTP_PORT,
25
+ headers=CHROMA_HTTP_HEADERS,
26
+ ssl=CHROMA_HTTP_SSL,
27
+ tenant=CHROMA_TENANT,
28
+ database=CHROMA_DATABASE,
29
+ settings=Settings(allow_reset=True, anonymized_telemetry=False),
30
+ )
31
+ else:
32
+ self.client = chromadb.PersistentClient(
33
+ path=CHROMA_DATA_PATH,
34
+ settings=Settings(allow_reset=True, anonymized_telemetry=False),
35
+ tenant=CHROMA_TENANT,
36
+ database=CHROMA_DATABASE,
37
+ )
38
+
39
+ def has_collection(self, collection_name: str) -> bool:
40
+ # Check if the collection exists based on the collection name.
41
+ collections = self.client.list_collections()
42
+ return collection_name in [collection.name for collection in collections]
43
+
44
+ def delete_collection(self, collection_name: str):
45
+ # Delete the collection based on the collection name.
46
+ return self.client.delete_collection(name=collection_name)
47
+
48
+ def search(
49
+ self, collection_name: str, vectors: list[list[float | int]], limit: int
50
+ ) -> Optional[SearchResult]:
51
+ # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
52
+ try:
53
+ collection = self.client.get_collection(name=collection_name)
54
+ if collection:
55
+ result = collection.query(
56
+ query_embeddings=vectors,
57
+ n_results=limit,
58
+ )
59
+
60
+ return SearchResult(
61
+ **{
62
+ "ids": result["ids"],
63
+ "distances": result["distances"],
64
+ "documents": result["documents"],
65
+ "metadatas": result["metadatas"],
66
+ }
67
+ )
68
+ return None
69
+ except Exception as e:
70
+ return None
71
+
72
+ def query(
73
+ self, collection_name: str, filter: dict, limit: Optional[int] = None
74
+ ) -> Optional[GetResult]:
75
+ # Query the items from the collection based on the filter.
76
+ try:
77
+ collection = self.client.get_collection(name=collection_name)
78
+ if collection:
79
+ result = collection.get(
80
+ where=filter,
81
+ limit=limit,
82
+ )
83
+
84
+ return GetResult(
85
+ **{
86
+ "ids": [result["ids"]],
87
+ "documents": [result["documents"]],
88
+ "metadatas": [result["metadatas"]],
89
+ }
90
+ )
91
+ return None
92
+ except Exception as e:
93
+ print(e)
94
+ return None
95
+
96
+ def get(self, collection_name: str) -> Optional[GetResult]:
97
+ # Get all the items in the collection.
98
+ collection = self.client.get_collection(name=collection_name)
99
+ if collection:
100
+ result = collection.get()
101
+ return GetResult(
102
+ **{
103
+ "ids": [result["ids"]],
104
+ "documents": [result["documents"]],
105
+ "metadatas": [result["metadatas"]],
106
+ }
107
+ )
108
+ return None
109
+
110
+ def insert(self, collection_name: str, items: list[VectorItem]):
111
+ # Insert the items into the collection, if the collection does not exist, it will be created.
112
+ collection = self.client.get_or_create_collection(
113
+ name=collection_name, metadata={"hnsw:space": "cosine"}
114
+ )
115
+
116
+ ids = [item["id"] for item in items]
117
+ documents = [item["text"] for item in items]
118
+ embeddings = [item["vector"] for item in items]
119
+ metadatas = [item["metadata"] for item in items]
120
+
121
+ for batch in create_batches(
122
+ api=self.client,
123
+ documents=documents,
124
+ embeddings=embeddings,
125
+ ids=ids,
126
+ metadatas=metadatas,
127
+ ):
128
+ collection.add(*batch)
129
+
130
+ def upsert(self, collection_name: str, items: list[VectorItem]):
131
+ # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
132
+ collection = self.client.get_or_create_collection(
133
+ name=collection_name, metadata={"hnsw:space": "cosine"}
134
+ )
135
+
136
+ ids = [item["id"] for item in items]
137
+ documents = [item["text"] for item in items]
138
+ embeddings = [item["vector"] for item in items]
139
+ metadatas = [item["metadata"] for item in items]
140
+
141
+ collection.upsert(
142
+ ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
143
+ )
144
+
145
+ def delete(
146
+ self,
147
+ collection_name: str,
148
+ ids: Optional[list[str]] = None,
149
+ filter: Optional[dict] = None,
150
+ ):
151
+ # Delete the items from the collection based on the ids.
152
+ collection = self.client.get_collection(name=collection_name)
153
+ if collection:
154
+ if ids:
155
+ collection.delete(ids=ids)
156
+ elif filter:
157
+ collection.delete(where=filter)
158
+
159
+ def reset(self):
160
+ # Resets the database. This will delete all collections and item entries.
161
+ return self.client.reset()
backend/open_webui/apps/retrieval/vector/dbs/milvus.py CHANGED
@@ -1,286 +1,286 @@
1
- from pymilvus import MilvusClient as Client
2
- from pymilvus import FieldSchema, DataType
3
- import json
4
-
5
- from typing import Optional
6
-
7
- from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
8
- from open_webui.config import (
9
- MILVUS_URI,
10
- )
11
-
12
-
13
- class MilvusClient:
14
- def __init__(self):
15
- self.collection_prefix = "open_webui"
16
- self.client = Client(uri=MILVUS_URI)
17
-
18
- def _result_to_get_result(self, result) -> GetResult:
19
- ids = []
20
- documents = []
21
- metadatas = []
22
-
23
- for match in result:
24
- _ids = []
25
- _documents = []
26
- _metadatas = []
27
- for item in match:
28
- _ids.append(item.get("id"))
29
- _documents.append(item.get("data", {}).get("text"))
30
- _metadatas.append(item.get("metadata"))
31
-
32
- ids.append(_ids)
33
- documents.append(_documents)
34
- metadatas.append(_metadatas)
35
-
36
- return GetResult(
37
- **{
38
- "ids": ids,
39
- "documents": documents,
40
- "metadatas": metadatas,
41
- }
42
- )
43
-
44
- def _result_to_search_result(self, result) -> SearchResult:
45
- ids = []
46
- distances = []
47
- documents = []
48
- metadatas = []
49
-
50
- for match in result:
51
- _ids = []
52
- _distances = []
53
- _documents = []
54
- _metadatas = []
55
-
56
- for item in match:
57
- _ids.append(item.get("id"))
58
- _distances.append(item.get("distance"))
59
- _documents.append(item.get("entity", {}).get("data", {}).get("text"))
60
- _metadatas.append(item.get("entity", {}).get("metadata"))
61
-
62
- ids.append(_ids)
63
- distances.append(_distances)
64
- documents.append(_documents)
65
- metadatas.append(_metadatas)
66
-
67
- return SearchResult(
68
- **{
69
- "ids": ids,
70
- "distances": distances,
71
- "documents": documents,
72
- "metadatas": metadatas,
73
- }
74
- )
75
-
76
- def _create_collection(self, collection_name: str, dimension: int):
77
- schema = self.client.create_schema(
78
- auto_id=False,
79
- enable_dynamic_field=True,
80
- )
81
- schema.add_field(
82
- field_name="id",
83
- datatype=DataType.VARCHAR,
84
- is_primary=True,
85
- max_length=65535,
86
- )
87
- schema.add_field(
88
- field_name="vector",
89
- datatype=DataType.FLOAT_VECTOR,
90
- dim=dimension,
91
- description="vector",
92
- )
93
- schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
94
- schema.add_field(
95
- field_name="metadata", datatype=DataType.JSON, description="metadata"
96
- )
97
-
98
- index_params = self.client.prepare_index_params()
99
- index_params.add_index(
100
- field_name="vector",
101
- index_type="HNSW",
102
- metric_type="COSINE",
103
- params={"M": 16, "efConstruction": 100},
104
- )
105
-
106
- self.client.create_collection(
107
- collection_name=f"{self.collection_prefix}_{collection_name}",
108
- schema=schema,
109
- index_params=index_params,
110
- )
111
-
112
- def has_collection(self, collection_name: str) -> bool:
113
- # Check if the collection exists based on the collection name.
114
- collection_name = collection_name.replace("-", "_")
115
- return self.client.has_collection(
116
- collection_name=f"{self.collection_prefix}_{collection_name}"
117
- )
118
-
119
- def delete_collection(self, collection_name: str):
120
- # Delete the collection based on the collection name.
121
- collection_name = collection_name.replace("-", "_")
122
- return self.client.drop_collection(
123
- collection_name=f"{self.collection_prefix}_{collection_name}"
124
- )
125
-
126
- def search(
127
- self, collection_name: str, vectors: list[list[float | int]], limit: int
128
- ) -> Optional[SearchResult]:
129
- # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
130
- collection_name = collection_name.replace("-", "_")
131
- result = self.client.search(
132
- collection_name=f"{self.collection_prefix}_{collection_name}",
133
- data=vectors,
134
- limit=limit,
135
- output_fields=["data", "metadata"],
136
- )
137
-
138
- return self._result_to_search_result(result)
139
-
140
- def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
141
- # Construct the filter string for querying
142
- collection_name = collection_name.replace("-", "_")
143
- if not self.has_collection(collection_name):
144
- return None
145
-
146
- filter_string = " && ".join(
147
- [
148
- f'metadata["{key}"] == {json.dumps(value)}'
149
- for key, value in filter.items()
150
- ]
151
- )
152
-
153
- max_limit = 16383 # The maximum number of records per request
154
- all_results = []
155
-
156
- if limit is None:
157
- limit = float("inf") # Use infinity as a placeholder for no limit
158
-
159
- # Initialize offset and remaining to handle pagination
160
- offset = 0
161
- remaining = limit
162
-
163
- try:
164
- # Loop until there are no more items to fetch or the desired limit is reached
165
- while remaining > 0:
166
- print("remaining", remaining)
167
- current_fetch = min(
168
- max_limit, remaining
169
- ) # Determine how many items to fetch in this iteration
170
-
171
- results = self.client.query(
172
- collection_name=f"{self.collection_prefix}_{collection_name}",
173
- filter=filter_string,
174
- output_fields=["*"],
175
- limit=current_fetch,
176
- offset=offset,
177
- )
178
-
179
- if not results:
180
- break
181
-
182
- all_results.extend(results)
183
- results_count = len(results)
184
- remaining -= (
185
- results_count # Decrease remaining by the number of items fetched
186
- )
187
- offset += results_count
188
-
189
- # Break the loop if the results returned are less than the requested fetch count
190
- if results_count < current_fetch:
191
- break
192
-
193
- print(all_results)
194
- return self._result_to_get_result([all_results])
195
- except Exception as e:
196
- print(e)
197
- return None
198
-
199
- def get(self, collection_name: str) -> Optional[GetResult]:
200
- # Get all the items in the collection.
201
- collection_name = collection_name.replace("-", "_")
202
- result = self.client.query(
203
- collection_name=f"{self.collection_prefix}_{collection_name}",
204
- filter='id != ""',
205
- )
206
- return self._result_to_get_result([result])
207
-
208
- def insert(self, collection_name: str, items: list[VectorItem]):
209
- # Insert the items into the collection, if the collection does not exist, it will be created.
210
- collection_name = collection_name.replace("-", "_")
211
- if not self.client.has_collection(
212
- collection_name=f"{self.collection_prefix}_{collection_name}"
213
- ):
214
- self._create_collection(
215
- collection_name=collection_name, dimension=len(items[0]["vector"])
216
- )
217
-
218
- return self.client.insert(
219
- collection_name=f"{self.collection_prefix}_{collection_name}",
220
- data=[
221
- {
222
- "id": item["id"],
223
- "vector": item["vector"],
224
- "data": {"text": item["text"]},
225
- "metadata": item["metadata"],
226
- }
227
- for item in items
228
- ],
229
- )
230
-
231
- def upsert(self, collection_name: str, items: list[VectorItem]):
232
- # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
233
- collection_name = collection_name.replace("-", "_")
234
- if not self.client.has_collection(
235
- collection_name=f"{self.collection_prefix}_{collection_name}"
236
- ):
237
- self._create_collection(
238
- collection_name=collection_name, dimension=len(items[0]["vector"])
239
- )
240
-
241
- return self.client.upsert(
242
- collection_name=f"{self.collection_prefix}_{collection_name}",
243
- data=[
244
- {
245
- "id": item["id"],
246
- "vector": item["vector"],
247
- "data": {"text": item["text"]},
248
- "metadata": item["metadata"],
249
- }
250
- for item in items
251
- ],
252
- )
253
-
254
- def delete(
255
- self,
256
- collection_name: str,
257
- ids: Optional[list[str]] = None,
258
- filter: Optional[dict] = None,
259
- ):
260
- # Delete the items from the collection based on the ids.
261
- collection_name = collection_name.replace("-", "_")
262
- if ids:
263
- return self.client.delete(
264
- collection_name=f"{self.collection_prefix}_{collection_name}",
265
- ids=ids,
266
- )
267
- elif filter:
268
- # Convert the filter dictionary to a string using JSON_CONTAINS.
269
- filter_string = " && ".join(
270
- [
271
- f'metadata["{key}"] == {json.dumps(value)}'
272
- for key, value in filter.items()
273
- ]
274
- )
275
-
276
- return self.client.delete(
277
- collection_name=f"{self.collection_prefix}_{collection_name}",
278
- filter=filter_string,
279
- )
280
-
281
- def reset(self):
282
- # Resets the database. This will delete all collections and item entries.
283
- collection_names = self.client.list_collections()
284
- for collection_name in collection_names:
285
- if collection_name.startswith(self.collection_prefix):
286
- self.client.drop_collection(collection_name=collection_name)
 
1
+ from pymilvus import MilvusClient as Client
2
+ from pymilvus import FieldSchema, DataType
3
+ import json
4
+
5
+ from typing import Optional
6
+
7
+ from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
8
+ from open_webui.config import (
9
+ MILVUS_URI,
10
+ )
11
+
12
+
13
+ class MilvusClient:
14
+ def __init__(self):
15
+ self.collection_prefix = "open_webui"
16
+ self.client = Client(uri=MILVUS_URI)
17
+
18
+ def _result_to_get_result(self, result) -> GetResult:
19
+ ids = []
20
+ documents = []
21
+ metadatas = []
22
+
23
+ for match in result:
24
+ _ids = []
25
+ _documents = []
26
+ _metadatas = []
27
+ for item in match:
28
+ _ids.append(item.get("id"))
29
+ _documents.append(item.get("data", {}).get("text"))
30
+ _metadatas.append(item.get("metadata"))
31
+
32
+ ids.append(_ids)
33
+ documents.append(_documents)
34
+ metadatas.append(_metadatas)
35
+
36
+ return GetResult(
37
+ **{
38
+ "ids": ids,
39
+ "documents": documents,
40
+ "metadatas": metadatas,
41
+ }
42
+ )
43
+
44
+ def _result_to_search_result(self, result) -> SearchResult:
45
+ ids = []
46
+ distances = []
47
+ documents = []
48
+ metadatas = []
49
+
50
+ for match in result:
51
+ _ids = []
52
+ _distances = []
53
+ _documents = []
54
+ _metadatas = []
55
+
56
+ for item in match:
57
+ _ids.append(item.get("id"))
58
+ _distances.append(item.get("distance"))
59
+ _documents.append(item.get("entity", {}).get("data", {}).get("text"))
60
+ _metadatas.append(item.get("entity", {}).get("metadata"))
61
+
62
+ ids.append(_ids)
63
+ distances.append(_distances)
64
+ documents.append(_documents)
65
+ metadatas.append(_metadatas)
66
+
67
+ return SearchResult(
68
+ **{
69
+ "ids": ids,
70
+ "distances": distances,
71
+ "documents": documents,
72
+ "metadatas": metadatas,
73
+ }
74
+ )
75
+
76
+ def _create_collection(self, collection_name: str, dimension: int):
77
+ schema = self.client.create_schema(
78
+ auto_id=False,
79
+ enable_dynamic_field=True,
80
+ )
81
+ schema.add_field(
82
+ field_name="id",
83
+ datatype=DataType.VARCHAR,
84
+ is_primary=True,
85
+ max_length=65535,
86
+ )
87
+ schema.add_field(
88
+ field_name="vector",
89
+ datatype=DataType.FLOAT_VECTOR,
90
+ dim=dimension,
91
+ description="vector",
92
+ )
93
+ schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
94
+ schema.add_field(
95
+ field_name="metadata", datatype=DataType.JSON, description="metadata"
96
+ )
97
+
98
+ index_params = self.client.prepare_index_params()
99
+ index_params.add_index(
100
+ field_name="vector",
101
+ index_type="HNSW",
102
+ metric_type="COSINE",
103
+ params={"M": 16, "efConstruction": 100},
104
+ )
105
+
106
+ self.client.create_collection(
107
+ collection_name=f"{self.collection_prefix}_{collection_name}",
108
+ schema=schema,
109
+ index_params=index_params,
110
+ )
111
+
112
+ def has_collection(self, collection_name: str) -> bool:
113
+ # Check if the collection exists based on the collection name.
114
+ collection_name = collection_name.replace("-", "_")
115
+ return self.client.has_collection(
116
+ collection_name=f"{self.collection_prefix}_{collection_name}"
117
+ )
118
+
119
+ def delete_collection(self, collection_name: str):
120
+ # Delete the collection based on the collection name.
121
+ collection_name = collection_name.replace("-", "_")
122
+ return self.client.drop_collection(
123
+ collection_name=f"{self.collection_prefix}_{collection_name}"
124
+ )
125
+
126
+ def search(
127
+ self, collection_name: str, vectors: list[list[float | int]], limit: int
128
+ ) -> Optional[SearchResult]:
129
+ # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
130
+ collection_name = collection_name.replace("-", "_")
131
+ result = self.client.search(
132
+ collection_name=f"{self.collection_prefix}_{collection_name}",
133
+ data=vectors,
134
+ limit=limit,
135
+ output_fields=["data", "metadata"],
136
+ )
137
+
138
+ return self._result_to_search_result(result)
139
+
140
+ def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
141
+ # Construct the filter string for querying
142
+ collection_name = collection_name.replace("-", "_")
143
+ if not self.has_collection(collection_name):
144
+ return None
145
+
146
+ filter_string = " && ".join(
147
+ [
148
+ f'metadata["{key}"] == {json.dumps(value)}'
149
+ for key, value in filter.items()
150
+ ]
151
+ )
152
+
153
+ max_limit = 16383 # The maximum number of records per request
154
+ all_results = []
155
+
156
+ if limit is None:
157
+ limit = float("inf") # Use infinity as a placeholder for no limit
158
+
159
+ # Initialize offset and remaining to handle pagination
160
+ offset = 0
161
+ remaining = limit
162
+
163
+ try:
164
+ # Loop until there are no more items to fetch or the desired limit is reached
165
+ while remaining > 0:
166
+ print("remaining", remaining)
167
+ current_fetch = min(
168
+ max_limit, remaining
169
+ ) # Determine how many items to fetch in this iteration
170
+
171
+ results = self.client.query(
172
+ collection_name=f"{self.collection_prefix}_{collection_name}",
173
+ filter=filter_string,
174
+ output_fields=["*"],
175
+ limit=current_fetch,
176
+ offset=offset,
177
+ )
178
+
179
+ if not results:
180
+ break
181
+
182
+ all_results.extend(results)
183
+ results_count = len(results)
184
+ remaining -= (
185
+ results_count # Decrease remaining by the number of items fetched
186
+ )
187
+ offset += results_count
188
+
189
+ # Break the loop if the results returned are less than the requested fetch count
190
+ if results_count < current_fetch:
191
+ break
192
+
193
+ print(all_results)
194
+ return self._result_to_get_result([all_results])
195
+ except Exception as e:
196
+ print(e)
197
+ return None
198
+
199
+ def get(self, collection_name: str) -> Optional[GetResult]:
200
+ # Get all the items in the collection.
201
+ collection_name = collection_name.replace("-", "_")
202
+ result = self.client.query(
203
+ collection_name=f"{self.collection_prefix}_{collection_name}",
204
+ filter='id != ""',
205
+ )
206
+ return self._result_to_get_result([result])
207
+
208
+ def insert(self, collection_name: str, items: list[VectorItem]):
209
+ # Insert the items into the collection, if the collection does not exist, it will be created.
210
+ collection_name = collection_name.replace("-", "_")
211
+ if not self.client.has_collection(
212
+ collection_name=f"{self.collection_prefix}_{collection_name}"
213
+ ):
214
+ self._create_collection(
215
+ collection_name=collection_name, dimension=len(items[0]["vector"])
216
+ )
217
+
218
+ return self.client.insert(
219
+ collection_name=f"{self.collection_prefix}_{collection_name}",
220
+ data=[
221
+ {
222
+ "id": item["id"],
223
+ "vector": item["vector"],
224
+ "data": {"text": item["text"]},
225
+ "metadata": item["metadata"],
226
+ }
227
+ for item in items
228
+ ],
229
+ )
230
+
231
+ def upsert(self, collection_name: str, items: list[VectorItem]):
232
+ # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
233
+ collection_name = collection_name.replace("-", "_")
234
+ if not self.client.has_collection(
235
+ collection_name=f"{self.collection_prefix}_{collection_name}"
236
+ ):
237
+ self._create_collection(
238
+ collection_name=collection_name, dimension=len(items[0]["vector"])
239
+ )
240
+
241
+ return self.client.upsert(
242
+ collection_name=f"{self.collection_prefix}_{collection_name}",
243
+ data=[
244
+ {
245
+ "id": item["id"],
246
+ "vector": item["vector"],
247
+ "data": {"text": item["text"]},
248
+ "metadata": item["metadata"],
249
+ }
250
+ for item in items
251
+ ],
252
+ )
253
+
254
+ def delete(
255
+ self,
256
+ collection_name: str,
257
+ ids: Optional[list[str]] = None,
258
+ filter: Optional[dict] = None,
259
+ ):
260
+ # Delete the items from the collection based on the ids.
261
+ collection_name = collection_name.replace("-", "_")
262
+ if ids:
263
+ return self.client.delete(
264
+ collection_name=f"{self.collection_prefix}_{collection_name}",
265
+ ids=ids,
266
+ )
267
+ elif filter:
268
+ # Convert the filter dictionary to a string using JSON_CONTAINS.
269
+ filter_string = " && ".join(
270
+ [
271
+ f'metadata["{key}"] == {json.dumps(value)}'
272
+ for key, value in filter.items()
273
+ ]
274
+ )
275
+
276
+ return self.client.delete(
277
+ collection_name=f"{self.collection_prefix}_{collection_name}",
278
+ filter=filter_string,
279
+ )
280
+
281
+ def reset(self):
282
+ # Resets the database. This will delete all collections and item entries.
283
+ collection_names = self.client.list_collections()
284
+ for collection_name in collection_names:
285
+ if collection_name.startswith(self.collection_prefix):
286
+ self.client.drop_collection(collection_name=collection_name)
backend/open_webui/apps/retrieval/vector/dbs/qdrant.py CHANGED
@@ -1,179 +1,179 @@
1
- from typing import Optional
2
-
3
- from qdrant_client import QdrantClient as Qclient
4
- from qdrant_client.http.models import PointStruct
5
- from qdrant_client.models import models
6
-
7
- from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
8
- from open_webui.config import QDRANT_URI
9
-
10
- NO_LIMIT = 999999999
11
-
12
-
13
- class QdrantClient:
14
- def __init__(self):
15
- self.collection_prefix = "open-webui"
16
- self.QDRANT_URI = QDRANT_URI
17
- self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
18
-
19
- def _result_to_get_result(self, points) -> GetResult:
20
- ids = []
21
- documents = []
22
- metadatas = []
23
-
24
- for point in points:
25
- payload = point.payload
26
- ids.append(point.id)
27
- documents.append(payload["text"])
28
- metadatas.append(payload["metadata"])
29
-
30
- return GetResult(
31
- **{
32
- "ids": [ids],
33
- "documents": [documents],
34
- "metadatas": [metadatas],
35
- }
36
- )
37
-
38
- def _create_collection(self, collection_name: str, dimension: int):
39
- collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
40
- self.client.create_collection(
41
- collection_name=collection_name_with_prefix,
42
- vectors_config=models.VectorParams(
43
- size=dimension, distance=models.Distance.COSINE
44
- ),
45
- )
46
-
47
- print(f"collection {collection_name_with_prefix} successfully created!")
48
-
49
- def _create_collection_if_not_exists(self, collection_name, dimension):
50
- if not self.has_collection(collection_name=collection_name):
51
- self._create_collection(
52
- collection_name=collection_name, dimension=dimension
53
- )
54
-
55
- def _create_points(self, items: list[VectorItem]):
56
- return [
57
- PointStruct(
58
- id=item["id"],
59
- vector=item["vector"],
60
- payload={"text": item["text"], "metadata": item["metadata"]},
61
- )
62
- for item in items
63
- ]
64
-
65
- def has_collection(self, collection_name: str) -> bool:
66
- return self.client.collection_exists(
67
- f"{self.collection_prefix}_{collection_name}"
68
- )
69
-
70
- def delete_collection(self, collection_name: str):
71
- return self.client.delete_collection(
72
- collection_name=f"{self.collection_prefix}_{collection_name}"
73
- )
74
-
75
- def search(
76
- self, collection_name: str, vectors: list[list[float | int]], limit: int
77
- ) -> Optional[SearchResult]:
78
- # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
79
- if limit is None:
80
- limit = NO_LIMIT # otherwise qdrant would set limit to 10!
81
-
82
- query_response = self.client.query_points(
83
- collection_name=f"{self.collection_prefix}_{collection_name}",
84
- query=vectors[0],
85
- limit=limit,
86
- )
87
- get_result = self._result_to_get_result(query_response.points)
88
- return SearchResult(
89
- ids=get_result.ids,
90
- documents=get_result.documents,
91
- metadatas=get_result.metadatas,
92
- distances=[[point.score for point in query_response.points]],
93
- )
94
-
95
- def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
96
- # Construct the filter string for querying
97
- if not self.has_collection(collection_name):
98
- return None
99
- try:
100
- if limit is None:
101
- limit = NO_LIMIT # otherwise qdrant would set limit to 10!
102
-
103
- field_conditions = []
104
- for key, value in filter.items():
105
- field_conditions.append(
106
- models.FieldCondition(
107
- key=f"metadata.{key}", match=models.MatchValue(value=value)
108
- )
109
- )
110
-
111
- points = self.client.query_points(
112
- collection_name=f"{self.collection_prefix}_{collection_name}",
113
- query_filter=models.Filter(should=field_conditions),
114
- limit=limit,
115
- )
116
- return self._result_to_get_result(points.points)
117
- except Exception as e:
118
- print(e)
119
- return None
120
-
121
- def get(self, collection_name: str) -> Optional[GetResult]:
122
- # Get all the items in the collection.
123
- points = self.client.query_points(
124
- collection_name=f"{self.collection_prefix}_{collection_name}",
125
- limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
126
- )
127
- return self._result_to_get_result(points.points)
128
-
129
- def insert(self, collection_name: str, items: list[VectorItem]):
130
- # Insert the items into the collection, if the collection does not exist, it will be created.
131
- self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
132
- points = self._create_points(items)
133
- self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
134
-
135
- def upsert(self, collection_name: str, items: list[VectorItem]):
136
- # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
137
- self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
138
- points = self._create_points(items)
139
- return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
140
-
141
- def delete(
142
- self,
143
- collection_name: str,
144
- ids: Optional[list[str]] = None,
145
- filter: Optional[dict] = None,
146
- ):
147
- # Delete the items from the collection based on the ids.
148
- field_conditions = []
149
-
150
- if ids:
151
- for id_value in ids:
152
- field_conditions.append(
153
- models.FieldCondition(
154
- key="metadata.id",
155
- match=models.MatchValue(value=id_value),
156
- ),
157
- ),
158
- elif filter:
159
- for key, value in filter.items():
160
- field_conditions.append(
161
- models.FieldCondition(
162
- key=f"metadata.{key}",
163
- match=models.MatchValue(value=value),
164
- ),
165
- ),
166
-
167
- return self.client.delete(
168
- collection_name=f"{self.collection_prefix}_{collection_name}",
169
- points_selector=models.FilterSelector(
170
- filter=models.Filter(must=field_conditions)
171
- ),
172
- )
173
-
174
- def reset(self):
175
- # Resets the database. This will delete all collections and item entries.
176
- collection_names = self.client.get_collections().collections
177
- for collection_name in collection_names:
178
- if collection_name.name.startswith(self.collection_prefix):
179
- self.client.delete_collection(collection_name=collection_name.name)
 
1
+ from typing import Optional
2
+
3
+ from qdrant_client import QdrantClient as Qclient
4
+ from qdrant_client.http.models import PointStruct
5
+ from qdrant_client.models import models
6
+
7
+ from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
8
+ from open_webui.config import QDRANT_URI
9
+
10
+ NO_LIMIT = 999999999
11
+
12
+
13
+ class QdrantClient:
14
+ def __init__(self):
15
+ self.collection_prefix = "open-webui"
16
+ self.QDRANT_URI = QDRANT_URI
17
+ self.client = Qclient(url=self.QDRANT_URI) if self.QDRANT_URI else None
18
+
19
+ def _result_to_get_result(self, points) -> GetResult:
20
+ ids = []
21
+ documents = []
22
+ metadatas = []
23
+
24
+ for point in points:
25
+ payload = point.payload
26
+ ids.append(point.id)
27
+ documents.append(payload["text"])
28
+ metadatas.append(payload["metadata"])
29
+
30
+ return GetResult(
31
+ **{
32
+ "ids": [ids],
33
+ "documents": [documents],
34
+ "metadatas": [metadatas],
35
+ }
36
+ )
37
+
38
+ def _create_collection(self, collection_name: str, dimension: int):
39
+ collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
40
+ self.client.create_collection(
41
+ collection_name=collection_name_with_prefix,
42
+ vectors_config=models.VectorParams(
43
+ size=dimension, distance=models.Distance.COSINE
44
+ ),
45
+ )
46
+
47
+ print(f"collection {collection_name_with_prefix} successfully created!")
48
+
49
+ def _create_collection_if_not_exists(self, collection_name, dimension):
50
+ if not self.has_collection(collection_name=collection_name):
51
+ self._create_collection(
52
+ collection_name=collection_name, dimension=dimension
53
+ )
54
+
55
+ def _create_points(self, items: list[VectorItem]):
56
+ return [
57
+ PointStruct(
58
+ id=item["id"],
59
+ vector=item["vector"],
60
+ payload={"text": item["text"], "metadata": item["metadata"]},
61
+ )
62
+ for item in items
63
+ ]
64
+
65
+ def has_collection(self, collection_name: str) -> bool:
66
+ return self.client.collection_exists(
67
+ f"{self.collection_prefix}_{collection_name}"
68
+ )
69
+
70
+ def delete_collection(self, collection_name: str):
71
+ return self.client.delete_collection(
72
+ collection_name=f"{self.collection_prefix}_{collection_name}"
73
+ )
74
+
75
+ def search(
76
+ self, collection_name: str, vectors: list[list[float | int]], limit: int
77
+ ) -> Optional[SearchResult]:
78
+ # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
79
+ if limit is None:
80
+ limit = NO_LIMIT # otherwise qdrant would set limit to 10!
81
+
82
+ query_response = self.client.query_points(
83
+ collection_name=f"{self.collection_prefix}_{collection_name}",
84
+ query=vectors[0],
85
+ limit=limit,
86
+ )
87
+ get_result = self._result_to_get_result(query_response.points)
88
+ return SearchResult(
89
+ ids=get_result.ids,
90
+ documents=get_result.documents,
91
+ metadatas=get_result.metadatas,
92
+ distances=[[point.score for point in query_response.points]],
93
+ )
94
+
95
+ def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
96
+ # Construct the filter string for querying
97
+ if not self.has_collection(collection_name):
98
+ return None
99
+ try:
100
+ if limit is None:
101
+ limit = NO_LIMIT # otherwise qdrant would set limit to 10!
102
+
103
+ field_conditions = []
104
+ for key, value in filter.items():
105
+ field_conditions.append(
106
+ models.FieldCondition(
107
+ key=f"metadata.{key}", match=models.MatchValue(value=value)
108
+ )
109
+ )
110
+
111
+ points = self.client.query_points(
112
+ collection_name=f"{self.collection_prefix}_{collection_name}",
113
+ query_filter=models.Filter(should=field_conditions),
114
+ limit=limit,
115
+ )
116
+ return self._result_to_get_result(points.points)
117
+ except Exception as e:
118
+ print(e)
119
+ return None
120
+
121
+ def get(self, collection_name: str) -> Optional[GetResult]:
122
+ # Get all the items in the collection.
123
+ points = self.client.query_points(
124
+ collection_name=f"{self.collection_prefix}_{collection_name}",
125
+ limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
126
+ )
127
+ return self._result_to_get_result(points.points)
128
+
129
+ def insert(self, collection_name: str, items: list[VectorItem]):
130
+ # Insert the items into the collection, if the collection does not exist, it will be created.
131
+ self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
132
+ points = self._create_points(items)
133
+ self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
134
+
135
+ def upsert(self, collection_name: str, items: list[VectorItem]):
136
+ # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
137
+ self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
138
+ points = self._create_points(items)
139
+ return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
140
+
141
+ def delete(
142
+ self,
143
+ collection_name: str,
144
+ ids: Optional[list[str]] = None,
145
+ filter: Optional[dict] = None,
146
+ ):
147
+ # Delete the items from the collection based on the ids.
148
+ field_conditions = []
149
+
150
+ if ids:
151
+ for id_value in ids:
152
+ field_conditions.append(
153
+ models.FieldCondition(
154
+ key="metadata.id",
155
+ match=models.MatchValue(value=id_value),
156
+ ),
157
+ ),
158
+ elif filter:
159
+ for key, value in filter.items():
160
+ field_conditions.append(
161
+ models.FieldCondition(
162
+ key=f"metadata.{key}",
163
+ match=models.MatchValue(value=value),
164
+ ),
165
+ ),
166
+
167
+ return self.client.delete(
168
+ collection_name=f"{self.collection_prefix}_{collection_name}",
169
+ points_selector=models.FilterSelector(
170
+ filter=models.Filter(must=field_conditions)
171
+ ),
172
+ )
173
+
174
+ def reset(self):
175
+ # Resets the database. This will delete all collections and item entries.
176
+ collection_names = self.client.get_collections().collections
177
+ for collection_name in collection_names:
178
+ if collection_name.name.startswith(self.collection_prefix):
179
+ self.client.delete_collection(collection_name=collection_name.name)
backend/open_webui/apps/retrieval/vector/main.py CHANGED
@@ -1,19 +1,19 @@
1
- from pydantic import BaseModel
2
- from typing import Optional, List, Any
3
-
4
-
5
- class VectorItem(BaseModel):
6
- id: str
7
- text: str
8
- vector: List[float | int]
9
- metadata: Any
10
-
11
-
12
- class GetResult(BaseModel):
13
- ids: Optional[List[List[str]]]
14
- documents: Optional[List[List[str]]]
15
- metadatas: Optional[List[List[Any]]]
16
-
17
-
18
- class SearchResult(GetResult):
19
- distances: Optional[List[List[float | int]]]
 
1
+ from pydantic import BaseModel
2
+ from typing import Optional, List, Any
3
+
4
+
5
+ class VectorItem(BaseModel):
6
+ id: str
7
+ text: str
8
+ vector: List[float | int]
9
+ metadata: Any
10
+
11
+
12
+ class GetResult(BaseModel):
13
+ ids: Optional[List[List[str]]]
14
+ documents: Optional[List[List[str]]]
15
+ metadatas: Optional[List[List[Any]]]
16
+
17
+
18
+ class SearchResult(GetResult):
19
+ distances: Optional[List[List[float | int]]]
backend/open_webui/apps/retrieval/web/brave.py CHANGED
@@ -1,42 +1,42 @@
1
- import logging
2
- from typing import Optional
3
-
4
- import requests
5
- from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
6
- from open_webui.env import SRC_LOG_LEVELS
7
-
8
- log = logging.getLogger(__name__)
9
- log.setLevel(SRC_LOG_LEVELS["RAG"])
10
-
11
-
12
- def search_brave(
13
- api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
14
- ) -> list[SearchResult]:
15
- """Search using Brave's Search API and return the results as a list of SearchResult objects.
16
-
17
- Args:
18
- api_key (str): A Brave Search API key
19
- query (str): The query to search for
20
- """
21
- url = "https://api.search.brave.com/res/v1/web/search"
22
- headers = {
23
- "Accept": "application/json",
24
- "Accept-Encoding": "gzip",
25
- "X-Subscription-Token": api_key,
26
- }
27
- params = {"q": query, "count": count}
28
-
29
- response = requests.get(url, headers=headers, params=params)
30
- response.raise_for_status()
31
-
32
- json_response = response.json()
33
- results = json_response.get("web", {}).get("results", [])
34
- if filter_list:
35
- results = get_filtered_results(results, filter_list)
36
-
37
- return [
38
- SearchResult(
39
- link=result["url"], title=result.get("title"), snippet=result.get("snippet")
40
- )
41
- for result in results[:count]
42
- ]
 
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import requests
5
+ from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
6
+ from open_webui.env import SRC_LOG_LEVELS
7
+
8
+ log = logging.getLogger(__name__)
9
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
10
+
11
+
12
+ def search_brave(
13
+ api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
14
+ ) -> list[SearchResult]:
15
+ """Search using Brave's Search API and return the results as a list of SearchResult objects.
16
+
17
+ Args:
18
+ api_key (str): A Brave Search API key
19
+ query (str): The query to search for
20
+ """
21
+ url = "https://api.search.brave.com/res/v1/web/search"
22
+ headers = {
23
+ "Accept": "application/json",
24
+ "Accept-Encoding": "gzip",
25
+ "X-Subscription-Token": api_key,
26
+ }
27
+ params = {"q": query, "count": count}
28
+
29
+ response = requests.get(url, headers=headers, params=params)
30
+ response.raise_for_status()
31
+
32
+ json_response = response.json()
33
+ results = json_response.get("web", {}).get("results", [])
34
+ if filter_list:
35
+ results = get_filtered_results(results, filter_list)
36
+
37
+ return [
38
+ SearchResult(
39
+ link=result["url"], title=result.get("title"), snippet=result.get("snippet")
40
+ )
41
+ for result in results[:count]
42
+ ]
backend/open_webui/apps/retrieval/web/duckduckgo.py CHANGED
@@ -1,50 +1,50 @@
1
- import logging
2
- from typing import Optional
3
-
4
- from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
5
- from duckduckgo_search import DDGS
6
- from open_webui.env import SRC_LOG_LEVELS
7
-
8
- log = logging.getLogger(__name__)
9
- log.setLevel(SRC_LOG_LEVELS["RAG"])
10
-
11
-
12
- def search_duckduckgo(
13
- query: str, count: int, filter_list: Optional[list[str]] = None
14
- ) -> list[SearchResult]:
15
- """
16
- Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
17
- Args:
18
- query (str): The query to search for
19
- count (int): The number of results to return
20
-
21
- Returns:
22
- list[SearchResult]: A list of search results
23
- """
24
- # Use the DDGS context manager to create a DDGS object
25
- with DDGS() as ddgs:
26
- # Use the ddgs.text() method to perform the search
27
- ddgs_gen = ddgs.text(
28
- query, safesearch="moderate", max_results=count, backend="api"
29
- )
30
- # Check if there are search results
31
- if ddgs_gen:
32
- # Convert the search results into a list
33
- search_results = [r for r in ddgs_gen]
34
-
35
- # Create an empty list to store the SearchResult objects
36
- results = []
37
- # Iterate over each search result
38
- for result in search_results:
39
- # Create a SearchResult object and append it to the results list
40
- results.append(
41
- SearchResult(
42
- link=result["href"],
43
- title=result.get("title"),
44
- snippet=result.get("body"),
45
- )
46
- )
47
- if filter_list:
48
- results = get_filtered_results(results, filter_list)
49
- # Return the list of search results
50
- return results
 
1
+ import logging
2
+ from typing import Optional
3
+
4
+ from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
5
+ from duckduckgo_search import DDGS
6
+ from open_webui.env import SRC_LOG_LEVELS
7
+
8
+ log = logging.getLogger(__name__)
9
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
10
+
11
+
12
+ def search_duckduckgo(
13
+ query: str, count: int, filter_list: Optional[list[str]] = None
14
+ ) -> list[SearchResult]:
15
+ """
16
+ Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
17
+ Args:
18
+ query (str): The query to search for
19
+ count (int): The number of results to return
20
+
21
+ Returns:
22
+ list[SearchResult]: A list of search results
23
+ """
24
+ # Use the DDGS context manager to create a DDGS object
25
+ with DDGS() as ddgs:
26
+ # Use the ddgs.text() method to perform the search
27
+ ddgs_gen = ddgs.text(
28
+ query, safesearch="moderate", max_results=count, backend="api"
29
+ )
30
+ # Check if there are search results
31
+ if ddgs_gen:
32
+ # Convert the search results into a list
33
+ search_results = [r for r in ddgs_gen]
34
+
35
+ # Create an empty list to store the SearchResult objects
36
+ results = []
37
+ # Iterate over each search result
38
+ for result in search_results:
39
+ # Create a SearchResult object and append it to the results list
40
+ results.append(
41
+ SearchResult(
42
+ link=result["href"],
43
+ title=result.get("title"),
44
+ snippet=result.get("body"),
45
+ )
46
+ )
47
+ if filter_list:
48
+ results = get_filtered_results(results, filter_list)
49
+ # Return the list of search results
50
+ return results
backend/open_webui/apps/retrieval/web/google_pse.py CHANGED
@@ -1,50 +1,50 @@
1
- import logging
2
- from typing import Optional
3
-
4
- import requests
5
- from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
6
- from open_webui.env import SRC_LOG_LEVELS
7
-
8
- log = logging.getLogger(__name__)
9
- log.setLevel(SRC_LOG_LEVELS["RAG"])
10
-
11
-
12
- def search_google_pse(
13
- api_key: str,
14
- search_engine_id: str,
15
- query: str,
16
- count: int,
17
- filter_list: Optional[list[str]] = None,
18
- ) -> list[SearchResult]:
19
- """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
20
-
21
- Args:
22
- api_key (str): A Programmable Search Engine API key
23
- search_engine_id (str): A Programmable Search Engine ID
24
- query (str): The query to search for
25
- """
26
- url = "https://www.googleapis.com/customsearch/v1"
27
-
28
- headers = {"Content-Type": "application/json"}
29
- params = {
30
- "cx": search_engine_id,
31
- "q": query,
32
- "key": api_key,
33
- "num": count,
34
- }
35
-
36
- response = requests.request("GET", url, headers=headers, params=params)
37
- response.raise_for_status()
38
-
39
- json_response = response.json()
40
- results = json_response.get("items", [])
41
- if filter_list:
42
- results = get_filtered_results(results, filter_list)
43
- return [
44
- SearchResult(
45
- link=result["link"],
46
- title=result.get("title"),
47
- snippet=result.get("snippet"),
48
- )
49
- for result in results
50
- ]
 
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import requests
5
+ from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
6
+ from open_webui.env import SRC_LOG_LEVELS
7
+
8
+ log = logging.getLogger(__name__)
9
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
10
+
11
+
12
+ def search_google_pse(
13
+ api_key: str,
14
+ search_engine_id: str,
15
+ query: str,
16
+ count: int,
17
+ filter_list: Optional[list[str]] = None,
18
+ ) -> list[SearchResult]:
19
+ """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
20
+
21
+ Args:
22
+ api_key (str): A Programmable Search Engine API key
23
+ search_engine_id (str): A Programmable Search Engine ID
24
+ query (str): The query to search for
25
+ """
26
+ url = "https://www.googleapis.com/customsearch/v1"
27
+
28
+ headers = {"Content-Type": "application/json"}
29
+ params = {
30
+ "cx": search_engine_id,
31
+ "q": query,
32
+ "key": api_key,
33
+ "num": count,
34
+ }
35
+
36
+ response = requests.request("GET", url, headers=headers, params=params)
37
+ response.raise_for_status()
38
+
39
+ json_response = response.json()
40
+ results = json_response.get("items", [])
41
+ if filter_list:
42
+ results = get_filtered_results(results, filter_list)
43
+ return [
44
+ SearchResult(
45
+ link=result["link"],
46
+ title=result.get("title"),
47
+ snippet=result.get("snippet"),
48
+ )
49
+ for result in results
50
+ ]
backend/open_webui/apps/retrieval/web/jina_search.py CHANGED
@@ -1,41 +1,41 @@
1
- import logging
2
-
3
- import requests
4
- from open_webui.apps.retrieval.web.main import SearchResult
5
- from open_webui.env import SRC_LOG_LEVELS
6
- from yarl import URL
7
-
8
- log = logging.getLogger(__name__)
9
- log.setLevel(SRC_LOG_LEVELS["RAG"])
10
-
11
-
12
- def search_jina(query: str, count: int) -> list[SearchResult]:
13
- """
14
- Search using Jina's Search API and return the results as a list of SearchResult objects.
15
- Args:
16
- query (str): The query to search for
17
- count (int): The number of results to return
18
-
19
- Returns:
20
- list[SearchResult]: A list of search results
21
- """
22
- jina_search_endpoint = "https://s.jina.ai/"
23
- headers = {
24
- "Accept": "application/json",
25
- }
26
- url = str(URL(jina_search_endpoint + query))
27
- response = requests.get(url, headers=headers)
28
- response.raise_for_status()
29
- data = response.json()
30
-
31
- results = []
32
- for result in data["data"][:count]:
33
- results.append(
34
- SearchResult(
35
- link=result["url"],
36
- title=result.get("title"),
37
- snippet=result.get("content"),
38
- )
39
- )
40
-
41
- return results
 
1
+ import logging
2
+
3
+ import requests
4
+ from open_webui.apps.retrieval.web.main import SearchResult
5
+ from open_webui.env import SRC_LOG_LEVELS
6
+ from yarl import URL
7
+
8
+ log = logging.getLogger(__name__)
9
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
10
+
11
+
12
+ def search_jina(query: str, count: int) -> list[SearchResult]:
13
+ """
14
+ Search using Jina's Search API and return the results as a list of SearchResult objects.
15
+ Args:
16
+ query (str): The query to search for
17
+ count (int): The number of results to return
18
+
19
+ Returns:
20
+ list[SearchResult]: A list of search results
21
+ """
22
+ jina_search_endpoint = "https://s.jina.ai/"
23
+ headers = {
24
+ "Accept": "application/json",
25
+ }
26
+ url = str(URL(jina_search_endpoint + query))
27
+ response = requests.get(url, headers=headers)
28
+ response.raise_for_status()
29
+ data = response.json()
30
+
31
+ results = []
32
+ for result in data["data"][:count]:
33
+ results.append(
34
+ SearchResult(
35
+ link=result["url"],
36
+ title=result.get("title"),
37
+ snippet=result.get("content"),
38
+ )
39
+ )
40
+
41
+ return results
backend/open_webui/apps/retrieval/web/main.py CHANGED
@@ -1,22 +1,22 @@
1
- from typing import Optional
2
- from urllib.parse import urlparse
3
-
4
- from pydantic import BaseModel
5
-
6
-
7
- def get_filtered_results(results, filter_list):
8
- if not filter_list:
9
- return results
10
- filtered_results = []
11
- for result in results:
12
- url = result.get("url") or result.get("link", "")
13
- domain = urlparse(url).netloc
14
- if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
15
- filtered_results.append(result)
16
- return filtered_results
17
-
18
-
19
- class SearchResult(BaseModel):
20
- link: str
21
- title: Optional[str]
22
- snippet: Optional[str]
 
1
+ from typing import Optional
2
+ from urllib.parse import urlparse
3
+
4
+ from pydantic import BaseModel
5
+
6
+
7
+ def get_filtered_results(results, filter_list):
8
+ if not filter_list:
9
+ return results
10
+ filtered_results = []
11
+ for result in results:
12
+ url = result.get("url") or result.get("link", "")
13
+ domain = urlparse(url).netloc
14
+ if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
15
+ filtered_results.append(result)
16
+ return filtered_results
17
+
18
+
19
+ class SearchResult(BaseModel):
20
+ link: str
21
+ title: Optional[str]
22
+ snippet: Optional[str]
backend/open_webui/apps/retrieval/web/searchapi.py CHANGED
@@ -1,48 +1,48 @@
1
- import logging
2
- from typing import Optional
3
- from urllib.parse import urlencode
4
-
5
- import requests
6
- from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
7
- from open_webui.env import SRC_LOG_LEVELS
8
-
9
- log = logging.getLogger(__name__)
10
- log.setLevel(SRC_LOG_LEVELS["RAG"])
11
-
12
-
13
- def search_searchapi(
14
- api_key: str,
15
- engine: str,
16
- query: str,
17
- count: int,
18
- filter_list: Optional[list[str]] = None,
19
- ) -> list[SearchResult]:
20
- """Search using searchapi.io's API and return the results as a list of SearchResult objects.
21
-
22
- Args:
23
- api_key (str): A searchapi.io API key
24
- query (str): The query to search for
25
- """
26
- url = "https://www.searchapi.io/api/v1/search"
27
-
28
- engine = engine or "google"
29
-
30
- payload = {"engine": engine, "q": query, "api_key": api_key}
31
-
32
- url = f"{url}?{urlencode(payload)}"
33
- response = requests.request("GET", url)
34
-
35
- json_response = response.json()
36
- log.info(f"results from searchapi search: {json_response}")
37
-
38
- results = sorted(
39
- json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
40
- )
41
- if filter_list:
42
- results = get_filtered_results(results, filter_list)
43
- return [
44
- SearchResult(
45
- link=result["link"], title=result["title"], snippet=result["snippet"]
46
- )
47
- for result in results[:count]
48
- ]
 
1
+ import logging
2
+ from typing import Optional
3
+ from urllib.parse import urlencode
4
+
5
+ import requests
6
+ from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
7
+ from open_webui.env import SRC_LOG_LEVELS
8
+
9
+ log = logging.getLogger(__name__)
10
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
11
+
12
+
13
+ def search_searchapi(
14
+ api_key: str,
15
+ engine: str,
16
+ query: str,
17
+ count: int,
18
+ filter_list: Optional[list[str]] = None,
19
+ ) -> list[SearchResult]:
20
+ """Search using searchapi.io's API and return the results as a list of SearchResult objects.
21
+
22
+ Args:
23
+ api_key (str): A searchapi.io API key
24
+ query (str): The query to search for
25
+ """
26
+ url = "https://www.searchapi.io/api/v1/search"
27
+
28
+ engine = engine or "google"
29
+
30
+ payload = {"engine": engine, "q": query, "api_key": api_key}
31
+
32
+ url = f"{url}?{urlencode(payload)}"
33
+ response = requests.request("GET", url)
34
+
35
+ json_response = response.json()
36
+ log.info(f"results from searchapi search: {json_response}")
37
+
38
+ results = sorted(
39
+ json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
40
+ )
41
+ if filter_list:
42
+ results = get_filtered_results(results, filter_list)
43
+ return [
44
+ SearchResult(
45
+ link=result["link"], title=result["title"], snippet=result["snippet"]
46
+ )
47
+ for result in results[:count]
48
+ ]
backend/open_webui/apps/retrieval/web/searxng.py CHANGED
@@ -1,91 +1,91 @@
1
- import logging
2
- from typing import Optional
3
-
4
- import requests
5
- from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
6
- from open_webui.env import SRC_LOG_LEVELS
7
-
8
- log = logging.getLogger(__name__)
9
- log.setLevel(SRC_LOG_LEVELS["RAG"])
10
-
11
-
12
- def search_searxng(
13
- query_url: str,
14
- query: str,
15
- count: int,
16
- filter_list: Optional[list[str]] = None,
17
- **kwargs,
18
- ) -> list[SearchResult]:
19
- """
20
- Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
21
-
22
- The function allows passing additional parameters such as language or time_range to tailor the search result.
23
-
24
- Args:
25
- query_url (str): The base URL of the SearXNG server.
26
- query (str): The search term or question to find in the SearXNG database.
27
- count (int): The maximum number of results to retrieve from the search.
28
-
29
- Keyword Args:
30
- language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
31
- safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
32
- time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
33
- categories: (Optional[list[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
34
-
35
- Returns:
36
- list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
37
-
38
- Raise:
39
- requests.exceptions.RequestException: If a request error occurs during the search process.
40
- """
41
-
42
- # Default values for optional parameters are provided as empty strings or None when not specified.
43
- language = kwargs.get("language", "en-US")
44
- safesearch = kwargs.get("safesearch", "1")
45
- time_range = kwargs.get("time_range", "")
46
- categories = "".join(kwargs.get("categories", []))
47
-
48
- params = {
49
- "q": query,
50
- "format": "json",
51
- "pageno": 1,
52
- "safesearch": safesearch,
53
- "language": language,
54
- "time_range": time_range,
55
- "categories": categories,
56
- "theme": "simple",
57
- "image_proxy": 0,
58
- }
59
-
60
- # Legacy query format
61
- if "<query>" in query_url:
62
- # Strip all query parameters from the URL
63
- query_url = query_url.split("?")[0]
64
-
65
- log.debug(f"searching {query_url}")
66
-
67
- response = requests.get(
68
- query_url,
69
- headers={
70
- "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
71
- "Accept": "text/html",
72
- "Accept-Encoding": "gzip, deflate",
73
- "Accept-Language": "en-US,en;q=0.5",
74
- "Connection": "keep-alive",
75
- },
76
- params=params,
77
- )
78
-
79
- response.raise_for_status() # Raise an exception for HTTP errors.
80
-
81
- json_response = response.json()
82
- results = json_response.get("results", [])
83
- sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
84
- if filter_list:
85
- sorted_results = get_filtered_results(sorted_results, filter_list)
86
- return [
87
- SearchResult(
88
- link=result["url"], title=result.get("title"), snippet=result.get("content")
89
- )
90
- for result in sorted_results[:count]
91
- ]
 
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import requests
5
+ from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
6
+ from open_webui.env import SRC_LOG_LEVELS
7
+
8
+ log = logging.getLogger(__name__)
9
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
10
+
11
+
12
+ def search_searxng(
13
+ query_url: str,
14
+ query: str,
15
+ count: int,
16
+ filter_list: Optional[list[str]] = None,
17
+ **kwargs,
18
+ ) -> list[SearchResult]:
19
+ """
20
+ Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
21
+
22
+ The function allows passing additional parameters such as language or time_range to tailor the search result.
23
+
24
+ Args:
25
+ query_url (str): The base URL of the SearXNG server.
26
+ query (str): The search term or question to find in the SearXNG database.
27
+ count (int): The maximum number of results to retrieve from the search.
28
+
29
+ Keyword Args:
30
+ language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
31
+ safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
32
+ time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
33
+ categories: (Optional[list[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
34
+
35
+ Returns:
36
+ list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
37
+
38
+ Raise:
39
+ requests.exceptions.RequestException: If a request error occurs during the search process.
40
+ """
41
+
42
+ # Default values for optional parameters are provided as empty strings or None when not specified.
43
+ language = kwargs.get("language", "en-US")
44
+ safesearch = kwargs.get("safesearch", "1")
45
+ time_range = kwargs.get("time_range", "")
46
+ categories = "".join(kwargs.get("categories", []))
47
+
48
+ params = {
49
+ "q": query,
50
+ "format": "json",
51
+ "pageno": 1,
52
+ "safesearch": safesearch,
53
+ "language": language,
54
+ "time_range": time_range,
55
+ "categories": categories,
56
+ "theme": "simple",
57
+ "image_proxy": 0,
58
+ }
59
+
60
+ # Legacy query format
61
+ if "<query>" in query_url:
62
+ # Strip all query parameters from the URL
63
+ query_url = query_url.split("?")[0]
64
+
65
+ log.debug(f"searching {query_url}")
66
+
67
+ response = requests.get(
68
+ query_url,
69
+ headers={
70
+ "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
71
+ "Accept": "text/html",
72
+ "Accept-Encoding": "gzip, deflate",
73
+ "Accept-Language": "en-US,en;q=0.5",
74
+ "Connection": "keep-alive",
75
+ },
76
+ params=params,
77
+ )
78
+
79
+ response.raise_for_status() # Raise an exception for HTTP errors.
80
+
81
+ json_response = response.json()
82
+ results = json_response.get("results", [])
83
+ sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
84
+ if filter_list:
85
+ sorted_results = get_filtered_results(sorted_results, filter_list)
86
+ return [
87
+ SearchResult(
88
+ link=result["url"], title=result.get("title"), snippet=result.get("content")
89
+ )
90
+ for result in sorted_results[:count]
91
+ ]
backend/open_webui/apps/retrieval/web/serper.py CHANGED
@@ -1,43 +1,43 @@
1
- import json
2
- import logging
3
- from typing import Optional
4
-
5
- import requests
6
- from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
7
- from open_webui.env import SRC_LOG_LEVELS
8
-
9
- log = logging.getLogger(__name__)
10
- log.setLevel(SRC_LOG_LEVELS["RAG"])
11
-
12
-
13
- def search_serper(
14
- api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
15
- ) -> list[SearchResult]:
16
- """Search using serper.dev's API and return the results as a list of SearchResult objects.
17
-
18
- Args:
19
- api_key (str): A serper.dev API key
20
- query (str): The query to search for
21
- """
22
- url = "https://google.serper.dev/search"
23
-
24
- payload = json.dumps({"q": query})
25
- headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
26
-
27
- response = requests.request("POST", url, headers=headers, data=payload)
28
- response.raise_for_status()
29
-
30
- json_response = response.json()
31
- results = sorted(
32
- json_response.get("organic", []), key=lambda x: x.get("position", 0)
33
- )
34
- if filter_list:
35
- results = get_filtered_results(results, filter_list)
36
- return [
37
- SearchResult(
38
- link=result["link"],
39
- title=result.get("title"),
40
- snippet=result.get("description"),
41
- )
42
- for result in results[:count]
43
- ]
 
1
+ import json
2
+ import logging
3
+ from typing import Optional
4
+
5
+ import requests
6
+ from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
7
+ from open_webui.env import SRC_LOG_LEVELS
8
+
9
+ log = logging.getLogger(__name__)
10
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
11
+
12
+
13
+ def search_serper(
14
+ api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
15
+ ) -> list[SearchResult]:
16
+ """Search using serper.dev's API and return the results as a list of SearchResult objects.
17
+
18
+ Args:
19
+ api_key (str): A serper.dev API key
20
+ query (str): The query to search for
21
+ """
22
+ url = "https://google.serper.dev/search"
23
+
24
+ payload = json.dumps({"q": query})
25
+ headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
26
+
27
+ response = requests.request("POST", url, headers=headers, data=payload)
28
+ response.raise_for_status()
29
+
30
+ json_response = response.json()
31
+ results = sorted(
32
+ json_response.get("organic", []), key=lambda x: x.get("position", 0)
33
+ )
34
+ if filter_list:
35
+ results = get_filtered_results(results, filter_list)
36
+ return [
37
+ SearchResult(
38
+ link=result["link"],
39
+ title=result.get("title"),
40
+ snippet=result.get("description"),
41
+ )
42
+ for result in results[:count]
43
+ ]
backend/open_webui/apps/retrieval/web/serply.py CHANGED
@@ -1,69 +1,69 @@
1
- import logging
2
- from typing import Optional
3
- from urllib.parse import urlencode
4
-
5
- import requests
6
- from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
7
- from open_webui.env import SRC_LOG_LEVELS
8
-
9
- log = logging.getLogger(__name__)
10
- log.setLevel(SRC_LOG_LEVELS["RAG"])
11
-
12
-
13
- def search_serply(
14
- api_key: str,
15
- query: str,
16
- count: int,
17
- hl: str = "us",
18
- limit: int = 10,
19
- device_type: str = "desktop",
20
- proxy_location: str = "US",
21
- filter_list: Optional[list[str]] = None,
22
- ) -> list[SearchResult]:
23
- """Search using serper.dev's API and return the results as a list of SearchResult objects.
24
-
25
- Args:
26
- api_key (str): A serply.io API key
27
- query (str): The query to search for
28
- hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
29
- limit (int): The maximum number of results to return [10-100, defaults to 10]
30
- """
31
- log.info("Searching with Serply")
32
-
33
- url = "https://api.serply.io/v1/search/"
34
-
35
- query_payload = {
36
- "q": query,
37
- "language": "en",
38
- "num": limit,
39
- "gl": proxy_location.upper(),
40
- "hl": hl.lower(),
41
- }
42
-
43
- url = f"{url}{urlencode(query_payload)}"
44
- headers = {
45
- "X-API-KEY": api_key,
46
- "X-User-Agent": device_type,
47
- "User-Agent": "open-webui",
48
- "X-Proxy-Location": proxy_location,
49
- }
50
-
51
- response = requests.request("GET", url, headers=headers)
52
- response.raise_for_status()
53
-
54
- json_response = response.json()
55
- log.info(f"results from serply search: {json_response}")
56
-
57
- results = sorted(
58
- json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
59
- )
60
- if filter_list:
61
- results = get_filtered_results(results, filter_list)
62
- return [
63
- SearchResult(
64
- link=result["link"],
65
- title=result.get("title"),
66
- snippet=result.get("description"),
67
- )
68
- for result in results[:count]
69
- ]
 
1
+ import logging
2
+ from typing import Optional
3
+ from urllib.parse import urlencode
4
+
5
+ import requests
6
+ from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
7
+ from open_webui.env import SRC_LOG_LEVELS
8
+
9
+ log = logging.getLogger(__name__)
10
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
11
+
12
+
13
+ def search_serply(
14
+ api_key: str,
15
+ query: str,
16
+ count: int,
17
+ hl: str = "us",
18
+ limit: int = 10,
19
+ device_type: str = "desktop",
20
+ proxy_location: str = "US",
21
+ filter_list: Optional[list[str]] = None,
22
+ ) -> list[SearchResult]:
23
+ """Search using serper.dev's API and return the results as a list of SearchResult objects.
24
+
25
+ Args:
26
+ api_key (str): A serply.io API key
27
+ query (str): The query to search for
28
+ hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
29
+ limit (int): The maximum number of results to return [10-100, defaults to 10]
30
+ """
31
+ log.info("Searching with Serply")
32
+
33
+ url = "https://api.serply.io/v1/search/"
34
+
35
+ query_payload = {
36
+ "q": query,
37
+ "language": "en",
38
+ "num": limit,
39
+ "gl": proxy_location.upper(),
40
+ "hl": hl.lower(),
41
+ }
42
+
43
+ url = f"{url}{urlencode(query_payload)}"
44
+ headers = {
45
+ "X-API-KEY": api_key,
46
+ "X-User-Agent": device_type,
47
+ "User-Agent": "open-webui",
48
+ "X-Proxy-Location": proxy_location,
49
+ }
50
+
51
+ response = requests.request("GET", url, headers=headers)
52
+ response.raise_for_status()
53
+
54
+ json_response = response.json()
55
+ log.info(f"results from serply search: {json_response}")
56
+
57
+ results = sorted(
58
+ json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
59
+ )
60
+ if filter_list:
61
+ results = get_filtered_results(results, filter_list)
62
+ return [
63
+ SearchResult(
64
+ link=result["link"],
65
+ title=result.get("title"),
66
+ snippet=result.get("description"),
67
+ )
68
+ for result in results[:count]
69
+ ]
backend/open_webui/apps/retrieval/web/serpstack.py CHANGED
@@ -1,48 +1,48 @@
1
- import logging
2
- from typing import Optional
3
-
4
- import requests
5
- from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
6
- from open_webui.env import SRC_LOG_LEVELS
7
-
8
- log = logging.getLogger(__name__)
9
- log.setLevel(SRC_LOG_LEVELS["RAG"])
10
-
11
-
12
- def search_serpstack(
13
- api_key: str,
14
- query: str,
15
- count: int,
16
- filter_list: Optional[list[str]] = None,
17
- https_enabled: bool = True,
18
- ) -> list[SearchResult]:
19
- """Search using serpstack.com's and return the results as a list of SearchResult objects.
20
-
21
- Args:
22
- api_key (str): A serpstack.com API key
23
- query (str): The query to search for
24
- https_enabled (bool): Whether to use HTTPS or HTTP for the API request
25
- """
26
- url = f"{'https' if https_enabled else 'http'}://api.serpstack.com/search"
27
-
28
- headers = {"Content-Type": "application/json"}
29
- params = {
30
- "access_key": api_key,
31
- "query": query,
32
- }
33
-
34
- response = requests.request("POST", url, headers=headers, params=params)
35
- response.raise_for_status()
36
-
37
- json_response = response.json()
38
- results = sorted(
39
- json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
40
- )
41
- if filter_list:
42
- results = get_filtered_results(results, filter_list)
43
- return [
44
- SearchResult(
45
- link=result["url"], title=result.get("title"), snippet=result.get("snippet")
46
- )
47
- for result in results[:count]
48
- ]
 
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import requests
5
+ from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
6
+ from open_webui.env import SRC_LOG_LEVELS
7
+
8
+ log = logging.getLogger(__name__)
9
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
10
+
11
+
12
+ def search_serpstack(
13
+ api_key: str,
14
+ query: str,
15
+ count: int,
16
+ filter_list: Optional[list[str]] = None,
17
+ https_enabled: bool = True,
18
+ ) -> list[SearchResult]:
19
+ """Search using serpstack.com's and return the results as a list of SearchResult objects.
20
+
21
+ Args:
22
+ api_key (str): A serpstack.com API key
23
+ query (str): The query to search for
24
+ https_enabled (bool): Whether to use HTTPS or HTTP for the API request
25
+ """
26
+ url = f"{'https' if https_enabled else 'http'}://api.serpstack.com/search"
27
+
28
+ headers = {"Content-Type": "application/json"}
29
+ params = {
30
+ "access_key": api_key,
31
+ "query": query,
32
+ }
33
+
34
+ response = requests.request("POST", url, headers=headers, params=params)
35
+ response.raise_for_status()
36
+
37
+ json_response = response.json()
38
+ results = sorted(
39
+ json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
40
+ )
41
+ if filter_list:
42
+ results = get_filtered_results(results, filter_list)
43
+ return [
44
+ SearchResult(
45
+ link=result["url"], title=result.get("title"), snippet=result.get("snippet")
46
+ )
47
+ for result in results[:count]
48
+ ]
backend/open_webui/apps/retrieval/web/tavily.py CHANGED
@@ -1,38 +1,38 @@
1
- import logging
2
-
3
- import requests
4
- from open_webui.apps.retrieval.web.main import SearchResult
5
- from open_webui.env import SRC_LOG_LEVELS
6
-
7
- log = logging.getLogger(__name__)
8
- log.setLevel(SRC_LOG_LEVELS["RAG"])
9
-
10
-
11
- def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
12
- """Search using Tavily's Search API and return the results as a list of SearchResult objects.
13
-
14
- Args:
15
- api_key (str): A Tavily Search API key
16
- query (str): The query to search for
17
-
18
- Returns:
19
- list[SearchResult]: A list of search results
20
- """
21
- url = "https://api.tavily.com/search"
22
- data = {"query": query, "api_key": api_key}
23
-
24
- response = requests.post(url, json=data)
25
- response.raise_for_status()
26
-
27
- json_response = response.json()
28
-
29
- raw_search_results = json_response.get("results", [])
30
-
31
- return [
32
- SearchResult(
33
- link=result["url"],
34
- title=result.get("title", ""),
35
- snippet=result.get("content"),
36
- )
37
- for result in raw_search_results[:count]
38
- ]
 
1
+ import logging
2
+
3
+ import requests
4
+ from open_webui.apps.retrieval.web.main import SearchResult
5
+ from open_webui.env import SRC_LOG_LEVELS
6
+
7
+ log = logging.getLogger(__name__)
8
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
9
+
10
+
11
+ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
12
+ """Search using Tavily's Search API and return the results as a list of SearchResult objects.
13
+
14
+ Args:
15
+ api_key (str): A Tavily Search API key
16
+ query (str): The query to search for
17
+
18
+ Returns:
19
+ list[SearchResult]: A list of search results
20
+ """
21
+ url = "https://api.tavily.com/search"
22
+ data = {"query": query, "api_key": api_key}
23
+
24
+ response = requests.post(url, json=data)
25
+ response.raise_for_status()
26
+
27
+ json_response = response.json()
28
+
29
+ raw_search_results = json_response.get("results", [])
30
+
31
+ return [
32
+ SearchResult(
33
+ link=result["url"],
34
+ title=result.get("title", ""),
35
+ snippet=result.get("content"),
36
+ )
37
+ for result in raw_search_results[:count]
38
+ ]
backend/open_webui/apps/retrieval/web/testdata/brave.json CHANGED
The diff for this file is too large to render. See raw diff
 
backend/open_webui/apps/retrieval/web/testdata/google_pse.json CHANGED
@@ -1,442 +1,442 @@
1
- {
2
- "kind": "customsearch#search",
3
- "url": {
4
- "type": "application/json",
5
- "template": "https://www.googleapis.com/customsearch/v1?q={searchTerms}&num={count?}&start={startIndex?}&lr={language?}&safe={safe?}&cx={cx?}&sort={sort?}&filter={filter?}&gl={gl?}&cr={cr?}&googlehost={googleHost?}&c2coff={disableCnTwTranslation?}&hq={hq?}&hl={hl?}&siteSearch={siteSearch?}&siteSearchFilter={siteSearchFilter?}&exactTerms={exactTerms?}&excludeTerms={excludeTerms?}&linkSite={linkSite?}&orTerms={orTerms?}&dateRestrict={dateRestrict?}&lowRange={lowRange?}&highRange={highRange?}&searchType={searchType}&fileType={fileType?}&rights={rights?}&imgSize={imgSize?}&imgType={imgType?}&imgColorType={imgColorType?}&imgDominantColor={imgDominantColor?}&alt=json"
6
- },
7
- "queries": {
8
- "request": [
9
- {
10
- "title": "Google Custom Search - lectures",
11
- "totalResults": "2450000000",
12
- "searchTerms": "lectures",
13
- "count": 10,
14
- "startIndex": 1,
15
- "inputEncoding": "utf8",
16
- "outputEncoding": "utf8",
17
- "safe": "off",
18
- "cx": "0473ef98502d44e18"
19
- }
20
- ],
21
- "nextPage": [
22
- {
23
- "title": "Google Custom Search - lectures",
24
- "totalResults": "2450000000",
25
- "searchTerms": "lectures",
26
- "count": 10,
27
- "startIndex": 11,
28
- "inputEncoding": "utf8",
29
- "outputEncoding": "utf8",
30
- "safe": "off",
31
- "cx": "0473ef98502d44e18"
32
- }
33
- ]
34
- },
35
- "context": {
36
- "title": "LLM Search"
37
- },
38
- "searchInformation": {
39
- "searchTime": 0.445959,
40
- "formattedSearchTime": "0.45",
41
- "totalResults": "2450000000",
42
- "formattedTotalResults": "2,450,000,000"
43
- },
44
- "items": [
45
- {
46
- "kind": "customsearch#result",
47
- "title": "The Feynman Lectures on Physics",
48
- "htmlTitle": "The Feynman \u003cb\u003eLectures\u003c/b\u003e on Physics",
49
- "link": "https://www.feynmanlectures.caltech.edu/",
50
- "displayLink": "www.feynmanlectures.caltech.edu",
51
- "snippet": "This edition has been designed for ease of reading on devices of any size or shape; text, figures and equations can all be zoomed without degradation.",
52
- "htmlSnippet": "This edition has been designed for ease of reading on devices of any size or shape; text, figures and equations can all be zoomed without degradation.",
53
- "cacheId": "CyXMWYWs9UEJ",
54
- "formattedUrl": "https://www.feynmanlectures.caltech.edu/",
55
- "htmlFormattedUrl": "https://www.feynman\u003cb\u003electures\u003c/b\u003e.caltech.edu/",
56
- "pagemap": {
57
- "metatags": [
58
- {
59
- "viewport": "width=device-width, initial-scale=1.0"
60
- }
61
- ]
62
- }
63
- },
64
- {
65
- "kind": "customsearch#result",
66
- "title": "Video Lectures",
67
- "htmlTitle": "Video \u003cb\u003eLectures\u003c/b\u003e",
68
- "link": "https://www.reddit.com/r/lectures/",
69
- "displayLink": "www.reddit.com",
70
- "snippet": "r/lectures: This subreddit is all about video lectures, talks and interesting public speeches. The topics include mathematics, physics, computer…",
71
- "htmlSnippet": "r/\u003cb\u003electures\u003c/b\u003e: This subreddit is all about video \u003cb\u003electures\u003c/b\u003e, talks and interesting public speeches. The topics include mathematics, physics, computer…",
72
- "formattedUrl": "https://www.reddit.com/r/lectures/",
73
- "htmlFormattedUrl": "https://www.reddit.com/r/\u003cb\u003electures\u003c/b\u003e/",
74
- "pagemap": {
75
- "cse_thumbnail": [
76
- {
77
- "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTZtOjhfkgUKQbL3DZxe5F6OVsgeDNffleObjJ7n9RllKQTSsimax7VIaY&s",
78
- "width": "192",
79
- "height": "192"
80
- }
81
- ],
82
- "metatags": [
83
- {
84
- "og:image": "https://www.redditstatic.com/shreddit/assets/favicon/192x192.png",
85
- "theme-color": "#000000",
86
- "og:image:width": "256",
87
- "og:type": "website",
88
- "twitter:card": "summary",
89
- "twitter:title": "r/lectures",
90
- "og:site_name": "Reddit",
91
- "og:title": "r/lectures",
92
- "og:image:height": "256",
93
- "bingbot": "noarchive",
94
- "msapplication-navbutton-color": "#000000",
95
- "og:description": "This subreddit is all about video lectures, talks and interesting public speeches.\n\nThe topics include mathematics, physics, computer science, programming, engineering, biology, medicine, economics, politics, social sciences, and any other subjects!",
96
- "twitter:image": "https://www.redditstatic.com/shreddit/assets/favicon/192x192.png",
97
- "apple-mobile-web-app-status-bar-style": "black",
98
- "twitter:site": "@reddit",
99
- "viewport": "width=device-width, initial-scale=1, viewport-fit=cover",
100
- "apple-mobile-web-app-capable": "yes",
101
- "og:ttl": "600",
102
- "og:url": "https://www.reddit.com/r/lectures/"
103
- }
104
- ],
105
- "cse_image": [
106
- {
107
- "src": "https://www.redditstatic.com/shreddit/assets/favicon/192x192.png"
108
- }
109
- ]
110
- }
111
- },
112
- {
113
- "kind": "customsearch#result",
114
- "title": "Lectures & Discussions | Flint Institute of Arts",
115
- "htmlTitle": "\u003cb\u003eLectures\u003c/b\u003e &amp; Discussions | Flint Institute of Arts",
116
- "link": "https://flintarts.org/events/lectures",
117
- "displayLink": "flintarts.org",
118
- "snippet": "It will trace the intricate relationship between jewelry, attire, and the expression of personal identity, social hierarchy, and spiritual belief systems that ...",
119
- "htmlSnippet": "It will trace the intricate relationship between jewelry, attire, and the expression of personal identity, social hierarchy, and spiritual belief systems that&nbsp;...",
120
- "cacheId": "jvpb9DxrfxoJ",
121
- "formattedUrl": "https://flintarts.org/events/lectures",
122
- "htmlFormattedUrl": "https://flintarts.org/events/\u003cb\u003electures\u003c/b\u003e",
123
- "pagemap": {
124
- "cse_thumbnail": [
125
- {
126
- "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcS23tMtAeNhJbOWdGxShYsmnyzFdzOC9Hb7lRykA9Pw72z1IlKTkjTdZw&s",
127
- "width": "447",
128
- "height": "113"
129
- }
130
- ],
131
- "metatags": [
132
- {
133
- "og:image": "https://flintarts.org/uploads/images/page-headers/_headerImage/nightshot.jpg",
134
- "og:type": "website",
135
- "viewport": "width=device-width, initial-scale=1",
136
- "og:title": "Lectures & Discussions | Flint Institute of Arts",
137
- "og:description": "The Flint Institute of Arts is the second largest art museum in Michigan and one of the largest museum art schools in the nation."
138
- }
139
- ],
140
- "cse_image": [
141
- {
142
- "src": "https://flintarts.org/uploads/images/page-headers/_headerImage/nightshot.jpg"
143
- }
144
- ]
145
- }
146
- },
147
- {
148
- "kind": "customsearch#result",
149
- "title": "Mandel Lectures | Mandel Center for the Humanities ... - Waltham",
150
- "htmlTitle": "Mandel \u003cb\u003eLectures\u003c/b\u003e | Mandel Center for the Humanities ... - Waltham",
151
- "link": "https://www.brandeis.edu/mandel-center-humanities/mandel-lectures.html",
152
- "displayLink": "www.brandeis.edu",
153
- "snippet": "Past Lectures · Lecture 1: \"Invisible Music: The Sonic Idea of Black Revolution From Captivity to Reconstruction\" · Lecture 2: \"Solidarity in Sound: Grassroots ...",
154
- "htmlSnippet": "Past \u003cb\u003eLectures\u003c/b\u003e &middot; \u003cb\u003eLecture\u003c/b\u003e 1: &quot;Invisible Music: The Sonic Idea of Black Revolution From Captivity to Reconstruction&quot; &middot; \u003cb\u003eLecture\u003c/b\u003e 2: &quot;Solidarity in Sound: Grassroots&nbsp;...",
155
- "cacheId": "cQLOZr0kgEEJ",
156
- "formattedUrl": "https://www.brandeis.edu/mandel-center-humanities/mandel-lectures.html",
157
- "htmlFormattedUrl": "https://www.brandeis.edu/mandel-center-humanities/mandel-\u003cb\u003electures\u003c/b\u003e.html",
158
- "pagemap": {
159
- "cse_thumbnail": [
160
- {
161
- "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQWlU7bcJ5pIHk7RBCk2QKE-48ejF7hyPV0pr-20_cBt2BGdfKtiYXBuyw&s",
162
- "width": "275",
163
- "height": "183"
164
- }
165
- ],
166
- "metatags": [
167
- {
168
- "og:image": "https://www.brandeis.edu/mandel-center-humanities/events/events-images/mlhzumba",
169
- "twitter:card": "summary_large_image",
170
- "viewport": "width=device-width,initial-scale=1,minimum-scale=1",
171
- "og:title": "Mandel Lectures in the Humanities",
172
- "og:url": "https://www.brandeis.edu/mandel-center-humanities/mandel-lectures.html",
173
- "og:description": "Annual Lecture Series",
174
- "twitter:image": "https://www.brandeis.edu/mandel-center-humanities/events/events-images/mlhzumba"
175
- }
176
- ],
177
- "cse_image": [
178
- {
179
- "src": "https://www.brandeis.edu/mandel-center-humanities/events/events-images/mlhzumba"
180
- }
181
- ]
182
- }
183
- },
184
- {
185
- "kind": "customsearch#result",
186
- "title": "Brian Douglas - YouTube",
187
- "htmlTitle": "Brian Douglas - YouTube",
188
- "link": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
189
- "displayLink": "www.youtube.com",
190
- "snippet": "Welcome to Control Systems Lectures! This collection of videos is intended to supplement a first year controls class, not replace it.",
191
- "htmlSnippet": "Welcome to Control Systems \u003cb\u003eLectures\u003c/b\u003e! This collection of videos is intended to supplement a first year controls class, not replace it.",
192
- "cacheId": "NEROyBHolL0J",
193
- "formattedUrl": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
194
- "htmlFormattedUrl": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
195
- "pagemap": {
196
- "hcard": [
197
- {
198
- "fn": "Brian Douglas",
199
- "url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg"
200
- }
201
- ],
202
- "cse_thumbnail": [
203
- {
204
- "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcR7G0CeCBz_wVTZgjnhEr2QbiKP7f3uYzKitZYn74Mi32cDmVxvsegJoLI&s",
205
- "width": "225",
206
- "height": "225"
207
- }
208
- ],
209
- "imageobject": [
210
- {
211
- "width": "900",
212
- "url": "https://yt3.googleusercontent.com/ytc/AIdro_nLo68wetImbwGUYP3stve_iKmAEccjhqB-q4o79xdInN4=s900-c-k-c0x00ffffff-no-rj",
213
- "height": "900"
214
- }
215
- ],
216
- "person": [
217
- {
218
- "name": "Brian Douglas",
219
- "url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg"
220
- }
221
- ],
222
- "metatags": [
223
- {
224
- "apple-itunes-app": "app-id=544007664, app-argument=https://m.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg?referring_app=com.apple.mobilesafari-smartbanner, affiliate-data=ct=smart_app_banner_polymer&pt=9008",
225
- "og:image": "https://yt3.googleusercontent.com/ytc/AIdro_nLo68wetImbwGUYP3stve_iKmAEccjhqB-q4o79xdInN4=s900-c-k-c0x00ffffff-no-rj",
226
- "twitter:app:url:iphone": "vnd.youtube://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
227
- "twitter:app:id:googleplay": "com.google.android.youtube",
228
- "theme-color": "rgb(255, 255, 255)",
229
- "og:image:width": "900",
230
- "twitter:card": "summary",
231
- "og:site_name": "YouTube",
232
- "twitter:url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
233
- "twitter:app:url:ipad": "vnd.youtube://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
234
- "al:android:package": "com.google.android.youtube",
235
- "twitter:app:name:googleplay": "YouTube",
236
- "al:ios:url": "vnd.youtube://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
237
- "twitter:app:id:iphone": "544007664",
238
- "og:description": "Welcome to Control Systems Lectures! This collection of videos is intended to supplement a first year controls class, not replace it. My goal is to take specific concepts in controls and expand on them in order to provide an intuitive understanding which will ultimately make you a better controls engineer. \n\nI'm glad you made it to my channel and I hope you find it useful.\n\nShoot me a message at controlsystemlectures@gmail.com, leave a comment or question and I'll get back to you if I can. Don't forget to subscribe!\n \nTwitter: @BrianBDouglas for engineering tweets and announcement of new videos.\nWebpage: http://engineeringmedia.com\n\nHere is the hardware/software I use: http://www.youtube.com/watch?v=m-M5_mIyHe4\n\nHere's a list of my favorite references: http://bit.ly/2skvmWd\n\n--Brian",
239
- "al:ios:app_store_id": "544007664",
240
- "twitter:image": "https://yt3.googleusercontent.com/ytc/AIdro_nLo68wetImbwGUYP3stve_iKmAEccjhqB-q4o79xdInN4=s900-c-k-c0x00ffffff-no-rj",
241
- "twitter:site": "@youtube",
242
- "og:type": "profile",
243
- "twitter:title": "Brian Douglas",
244
- "al:ios:app_name": "YouTube",
245
- "og:title": "Brian Douglas",
246
- "og:image:height": "900",
247
- "twitter:app:id:ipad": "544007664",
248
- "al:web:url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg?feature=applinks",
249
- "al:android:url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg?feature=applinks",
250
- "fb:app_id": "87741124305",
251
- "twitter:app:url:googleplay": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
252
- "twitter:app:name:ipad": "YouTube",
253
- "viewport": "width=device-width, initial-scale=1.0, minimum-scale=1.0, maximum-scale=1.0, user-scalable=no,",
254
- "twitter:description": "Welcome to Control Systems Lectures! This collection of videos is intended to supplement a first year controls class, not replace it. My goal is to take specific concepts in controls and expand on them in order to provide an intuitive understanding which will ultimately make you a better controls engineer. \n\nI'm glad you made it to my channel and I hope you find it useful.\n\nShoot me a message at controlsystemlectures@gmail.com, leave a comment or question and I'll get back to you if I can. Don't forget to subscribe!\n \nTwitter: @BrianBDouglas for engineering tweets and announcement of new videos.\nWebpage: http://engineeringmedia.com\n\nHere is the hardware/software I use: http://www.youtube.com/watch?v=m-M5_mIyHe4\n\nHere's a list of my favorite references: http://bit.ly/2skvmWd\n\n--Brian",
255
- "og:url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
256
- "al:android:app_name": "YouTube",
257
- "twitter:app:name:iphone": "YouTube"
258
- }
259
- ],
260
- "cse_image": [
261
- {
262
- "src": "https://yt3.googleusercontent.com/ytc/AIdro_nLo68wetImbwGUYP3stve_iKmAEccjhqB-q4o79xdInN4=s900-c-k-c0x00ffffff-no-rj"
263
- }
264
- ]
265
- }
266
- },
267
- {
268
- "kind": "customsearch#result",
269
- "title": "Lecture - Wikipedia",
270
- "htmlTitle": "\u003cb\u003eLecture\u003c/b\u003e - Wikipedia",
271
- "link": "https://en.wikipedia.org/wiki/Lecture",
272
- "displayLink": "en.wikipedia.org",
273
- "snippet": "Lecture ... For the academic rank, see Lecturer. A lecture (from Latin: lēctūra 'reading') is an oral presentation intended to present information or teach people ...",
274
- "htmlSnippet": "\u003cb\u003eLecture\u003c/b\u003e ... For the academic rank, see \u003cb\u003eLecturer\u003c/b\u003e. A \u003cb\u003electure\u003c/b\u003e (from Latin: lēctūra &#39;reading&#39;) is an oral presentation intended to present information or teach people&nbsp;...",
275
- "cacheId": "d9Pjta02fmgJ",
276
- "formattedUrl": "https://en.wikipedia.org/wiki/Lecture",
277
- "htmlFormattedUrl": "https://en.wikipedia.org/wiki/Lecture",
278
- "pagemap": {
279
- "metatags": [
280
- {
281
- "referrer": "origin",
282
- "og:image": "https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/ADFA_Lecture_Theatres.jpg/1200px-ADFA_Lecture_Theatres.jpg",
283
- "theme-color": "#eaecf0",
284
- "og:image:width": "1200",
285
- "og:type": "website",
286
- "viewport": "width=device-width, initial-scale=1.0, user-scalable=yes, minimum-scale=0.25, maximum-scale=5.0",
287
- "og:title": "Lecture - Wikipedia",
288
- "og:image:height": "799",
289
- "format-detection": "telephone=no"
290
- }
291
- ]
292
- }
293
- },
294
- {
295
- "kind": "customsearch#result",
296
- "title": "Mount Wilson Observatory | Lectures",
297
- "htmlTitle": "Mount Wilson Observatory | \u003cb\u003eLectures\u003c/b\u003e",
298
- "link": "https://www.mtwilson.edu/lectures/",
299
- "displayLink": "www.mtwilson.edu",
300
- "snippet": "Talks & Telescopes: August 24, 2024 – Panel: The Triumph of Hubble ... Compelling talks followed by picnicking and convivial stargazing through both the big ...",
301
- "htmlSnippet": "Talks &amp; Telescopes: August 24, 2024 – Panel: The Triumph of Hubble ... Compelling talks followed by picnicking and convivial stargazing through both the big&nbsp;...",
302
- "cacheId": "wdXI0azqx5UJ",
303
- "formattedUrl": "https://www.mtwilson.edu/lectures/",
304
- "htmlFormattedUrl": "https://www.mtwilson.edu/\u003cb\u003electures\u003c/b\u003e/",
305
- "pagemap": {
306
- "metatags": [
307
- {
308
- "viewport": "width=device-width,initial-scale=1,user-scalable=no"
309
- }
310
- ],
311
- "webpage": [
312
- {
313
- "image": "http://www.mtwilson.edu/wp-content/uploads/2016/09/Logo.jpg",
314
- "url": "https://www.facebook.com/WilsonObs"
315
- }
316
- ]
317
- }
318
- },
319
- {
320
- "kind": "customsearch#result",
321
- "title": "Lectures | NBER",
322
- "htmlTitle": "\u003cb\u003eLectures\u003c/b\u003e | NBER",
323
- "link": "https://www.nber.org/research/lectures",
324
- "displayLink": "www.nber.org",
325
- "snippet": "Results 1 - 50 of 354 ... Among featured events at the NBER Summer Institute are the Martin Feldstein Lecture, which examines a current issue involving economic ...",
326
- "htmlSnippet": "Results 1 - 50 of 354 \u003cb\u003e...\u003c/b\u003e Among featured events at the NBER Summer Institute are the Martin Feldstein \u003cb\u003eLecture\u003c/b\u003e, which examines a current issue involving economic&nbsp;...",
327
- "cacheId": "CvvP3U3nb44J",
328
- "formattedUrl": "https://www.nber.org/research/lectures",
329
- "htmlFormattedUrl": "https://www.nber.org/research/\u003cb\u003electures\u003c/b\u003e",
330
- "pagemap": {
331
- "cse_thumbnail": [
332
- {
333
- "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTmeViEZyV1YmFEFLhcA6WdgAG3v3RV6tB93ncyxSJ5JPst_p2aWrL7D1k&s",
334
- "width": "310",
335
- "height": "163"
336
- }
337
- ],
338
- "metatags": [
339
- {
340
- "og:image": "https://www.nber.org/sites/default/files/2022-06/NBER-FB-Share-Tile-1200.jpg",
341
- "og:site_name": "NBER",
342
- "handheldfriendly": "true",
343
- "viewport": "width=device-width, initial-scale=1.0",
344
- "og:title": "Lectures",
345
- "mobileoptimized": "width",
346
- "og:url": "https://www.nber.org/research/lectures"
347
- }
348
- ],
349
- "cse_image": [
350
- {
351
- "src": "https://www.nber.org/sites/default/files/2022-06/NBER-FB-Share-Tile-1200.jpg"
352
- }
353
- ]
354
- }
355
- },
356
- {
357
- "kind": "customsearch#result",
358
- "title": "STUDENTS CANNOT ACCESS RECORDED LECTURES ... - Solved",
359
- "htmlTitle": "STUDENTS CANNOT ACCESS RECORDED LECTURES ... - Solved",
360
- "link": "https://community.canvaslms.com/t5/Canvas-Question-Forum/STUDENTS-CANNOT-ACCESS-RECORDED-LECTURES/td-p/190358",
361
- "displayLink": "community.canvaslms.com",
362
- "snippet": "Mar 19, 2020 ... I believe the issue is that students were not invited. Are you trying to capture your screen? If not, there is an option to just record your web ...",
363
- "htmlSnippet": "Mar 19, 2020 \u003cb\u003e...\u003c/b\u003e I believe the issue is that students were not invited. Are you trying to capture your screen? If not, there is an option to just record your web&nbsp;...",
364
- "cacheId": "wqrynQXX61sJ",
365
- "formattedUrl": "https://community.canvaslms.com/t5/Canvas...LECTURES/td-p/190358",
366
- "htmlFormattedUrl": "https://community.canvaslms.com/t5/Canvas...\u003cb\u003eLECTURES\u003c/b\u003e/td-p/190358",
367
- "pagemap": {
368
- "cse_thumbnail": [
369
- {
370
- "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRUqXau3N8LfKgSD7OJOvV7xzGarLKRU-ckWXy1ZQ1p4CLPsedvLKmLMhk&s",
371
- "width": "310",
372
- "height": "163"
373
- }
374
- ],
375
- "metatags": [
376
- {
377
- "og:image": "https://community.canvaslms.com/html/@6A1FDD4D5FF35E4BBB4083A1022FA0DB/assets/CommunityPreview23.png",
378
- "og:type": "article",
379
- "article:section": "Canvas Question Forum",
380
- "article:published_time": "2020-03-19T15:50:03.409Z",
381
- "og:site_name": "Instructure Community",
382
- "article:modified_time": "2020-03-19T13:55:53-07:00",
383
- "viewport": "width=device-width, initial-scale=1.0, user-scalable=yes",
384
- "og:title": "STUDENTS CANNOT ACCESS RECORDED LECTURES",
385
- "og:url": "https://community.canvaslms.com/t5/Canvas-Question-Forum/STUDENTS-CANNOT-ACCESS-RECORDED-LECTURES/m-p/190358#M93667",
386
- "og:description": "I can access and see my recorded lectures but my students can't. They have an error message when they try to open the recorded presentation or notes.",
387
- "article:author": "https://community.canvaslms.com/t5/user/viewprofilepage/user-id/794287",
388
- "twitter:image": "https://community.canvaslms.com/html/@6A1FDD4D5FF35E4BBB4083A1022FA0DB/assets/CommunityPreview23.png"
389
- }
390
- ],
391
- "cse_image": [
392
- {
393
- "src": "https://community.canvaslms.com/html/@6A1FDD4D5FF35E4BBB4083A1022FA0DB/assets/CommunityPreview23.png"
394
- }
395
- ]
396
- }
397
- },
398
- {
399
- "kind": "customsearch#result",
400
- "title": "Public Lecture Series - Sam Fox School of Design & Visual Arts",
401
- "htmlTitle": "Public \u003cb\u003eLecture\u003c/b\u003e Series - Sam Fox School of Design &amp; Visual Arts",
402
- "link": "https://samfoxschool.wustl.edu/calendar/series/2-public-lecture-series",
403
- "displayLink": "samfoxschool.wustl.edu",
404
- "snippet": "The Sam Fox School's Spring 2024 Public Lecture Series highlights design and art as catalysts for change. Renowned speakers will delve into themes like ...",
405
- "htmlSnippet": "The Sam Fox School&#39;s Spring 2024 Public \u003cb\u003eLecture\u003c/b\u003e Series highlights design and art as catalysts for change. Renowned speakers will delve into themes like&nbsp;...",
406
- "cacheId": "B-cgQG0j6tUJ",
407
- "formattedUrl": "https://samfoxschool.wustl.edu/calendar/series/2-public-lecture-series",
408
- "htmlFormattedUrl": "https://samfoxschool.wustl.edu/calendar/series/2-public-lecture-series",
409
- "pagemap": {
410
- "cse_thumbnail": [
411
- {
412
- "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQSmHaGianm-64m-qauYjkPK_Q0JKWe-7yom4m1ogFYTmpWArA7k6dmk0sR&s",
413
- "width": "307",
414
- "height": "164"
415
- }
416
- ],
417
- "website": [
418
- {
419
- "name": "Public Lecture Series - Sam Fox School of Design & Visual Arts — Washington University in St. Louis"
420
- }
421
- ],
422
- "metatags": [
423
- {
424
- "og:image": "https://dvsp0hlm0xrn3.cloudfront.net/assets/default_og_image-44e73dee4b9d1e2c6a6295901371270c8ec5899eaed48ee8167a9b12f1b0f8b3.jpg",
425
- "og:type": "website",
426
- "og:site_name": "Sam Fox School of Design & Visual Arts — Washington University in St. Louis",
427
- "viewport": "width=device-width, initial-scale=1.0",
428
- "og:title": "Public Lecture Series - Sam Fox School of Design & Visual Arts — Washington University in St. Louis",
429
- "csrf-token": "jBQsfZGY3RH8NVs0-KVDBYB-2N2kib4UYZHYdrShfTdLkvzfSvGeOaMrRKTRdYBPRKzdcGIuP7zwm9etqX_uvg",
430
- "csrf-param": "authenticity_token",
431
- "og:description": "The Sam Fox School's Spring 2024 Public Lecture Series highlights design and art as catalysts for change. Renowned speakers will delve into themes like social equity, resilient cities, and the impact of emerging technologies on contemporary life. Speakers include artists, architects, designers, and critics of the highest caliber, widely recognized for their research-based practices and multidisciplinary approaches to their fields."
432
- }
433
- ],
434
- "cse_image": [
435
- {
436
- "src": "https://dvsp0hlm0xrn3.cloudfront.net/assets/default_og_image-44e73dee4b9d1e2c6a6295901371270c8ec5899eaed48ee8167a9b12f1b0f8b3.jpg"
437
- }
438
- ]
439
- }
440
- }
441
- ]
442
- }
 
1
+ {
2
+ "kind": "customsearch#search",
3
+ "url": {
4
+ "type": "application/json",
5
+ "template": "https://www.googleapis.com/customsearch/v1?q={searchTerms}&num={count?}&start={startIndex?}&lr={language?}&safe={safe?}&cx={cx?}&sort={sort?}&filter={filter?}&gl={gl?}&cr={cr?}&googlehost={googleHost?}&c2coff={disableCnTwTranslation?}&hq={hq?}&hl={hl?}&siteSearch={siteSearch?}&siteSearchFilter={siteSearchFilter?}&exactTerms={exactTerms?}&excludeTerms={excludeTerms?}&linkSite={linkSite?}&orTerms={orTerms?}&dateRestrict={dateRestrict?}&lowRange={lowRange?}&highRange={highRange?}&searchType={searchType}&fileType={fileType?}&rights={rights?}&imgSize={imgSize?}&imgType={imgType?}&imgColorType={imgColorType?}&imgDominantColor={imgDominantColor?}&alt=json"
6
+ },
7
+ "queries": {
8
+ "request": [
9
+ {
10
+ "title": "Google Custom Search - lectures",
11
+ "totalResults": "2450000000",
12
+ "searchTerms": "lectures",
13
+ "count": 10,
14
+ "startIndex": 1,
15
+ "inputEncoding": "utf8",
16
+ "outputEncoding": "utf8",
17
+ "safe": "off",
18
+ "cx": "0473ef98502d44e18"
19
+ }
20
+ ],
21
+ "nextPage": [
22
+ {
23
+ "title": "Google Custom Search - lectures",
24
+ "totalResults": "2450000000",
25
+ "searchTerms": "lectures",
26
+ "count": 10,
27
+ "startIndex": 11,
28
+ "inputEncoding": "utf8",
29
+ "outputEncoding": "utf8",
30
+ "safe": "off",
31
+ "cx": "0473ef98502d44e18"
32
+ }
33
+ ]
34
+ },
35
+ "context": {
36
+ "title": "LLM Search"
37
+ },
38
+ "searchInformation": {
39
+ "searchTime": 0.445959,
40
+ "formattedSearchTime": "0.45",
41
+ "totalResults": "2450000000",
42
+ "formattedTotalResults": "2,450,000,000"
43
+ },
44
+ "items": [
45
+ {
46
+ "kind": "customsearch#result",
47
+ "title": "The Feynman Lectures on Physics",
48
+ "htmlTitle": "The Feynman \u003cb\u003eLectures\u003c/b\u003e on Physics",
49
+ "link": "https://www.feynmanlectures.caltech.edu/",
50
+ "displayLink": "www.feynmanlectures.caltech.edu",
51
+ "snippet": "This edition has been designed for ease of reading on devices of any size or shape; text, figures and equations can all be zoomed without degradation.",
52
+ "htmlSnippet": "This edition has been designed for ease of reading on devices of any size or shape; text, figures and equations can all be zoomed without degradation.",
53
+ "cacheId": "CyXMWYWs9UEJ",
54
+ "formattedUrl": "https://www.feynmanlectures.caltech.edu/",
55
+ "htmlFormattedUrl": "https://www.feynman\u003cb\u003electures\u003c/b\u003e.caltech.edu/",
56
+ "pagemap": {
57
+ "metatags": [
58
+ {
59
+ "viewport": "width=device-width, initial-scale=1.0"
60
+ }
61
+ ]
62
+ }
63
+ },
64
+ {
65
+ "kind": "customsearch#result",
66
+ "title": "Video Lectures",
67
+ "htmlTitle": "Video \u003cb\u003eLectures\u003c/b\u003e",
68
+ "link": "https://www.reddit.com/r/lectures/",
69
+ "displayLink": "www.reddit.com",
70
+ "snippet": "r/lectures: This subreddit is all about video lectures, talks and interesting public speeches. The topics include mathematics, physics, computer…",
71
+ "htmlSnippet": "r/\u003cb\u003electures\u003c/b\u003e: This subreddit is all about video \u003cb\u003electures\u003c/b\u003e, talks and interesting public speeches. The topics include mathematics, physics, computer…",
72
+ "formattedUrl": "https://www.reddit.com/r/lectures/",
73
+ "htmlFormattedUrl": "https://www.reddit.com/r/\u003cb\u003electures\u003c/b\u003e/",
74
+ "pagemap": {
75
+ "cse_thumbnail": [
76
+ {
77
+ "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTZtOjhfkgUKQbL3DZxe5F6OVsgeDNffleObjJ7n9RllKQTSsimax7VIaY&s",
78
+ "width": "192",
79
+ "height": "192"
80
+ }
81
+ ],
82
+ "metatags": [
83
+ {
84
+ "og:image": "https://www.redditstatic.com/shreddit/assets/favicon/192x192.png",
85
+ "theme-color": "#000000",
86
+ "og:image:width": "256",
87
+ "og:type": "website",
88
+ "twitter:card": "summary",
89
+ "twitter:title": "r/lectures",
90
+ "og:site_name": "Reddit",
91
+ "og:title": "r/lectures",
92
+ "og:image:height": "256",
93
+ "bingbot": "noarchive",
94
+ "msapplication-navbutton-color": "#000000",
95
+ "og:description": "This subreddit is all about video lectures, talks and interesting public speeches.\n\nThe topics include mathematics, physics, computer science, programming, engineering, biology, medicine, economics, politics, social sciences, and any other subjects!",
96
+ "twitter:image": "https://www.redditstatic.com/shreddit/assets/favicon/192x192.png",
97
+ "apple-mobile-web-app-status-bar-style": "black",
98
+ "twitter:site": "@reddit",
99
+ "viewport": "width=device-width, initial-scale=1, viewport-fit=cover",
100
+ "apple-mobile-web-app-capable": "yes",
101
+ "og:ttl": "600",
102
+ "og:url": "https://www.reddit.com/r/lectures/"
103
+ }
104
+ ],
105
+ "cse_image": [
106
+ {
107
+ "src": "https://www.redditstatic.com/shreddit/assets/favicon/192x192.png"
108
+ }
109
+ ]
110
+ }
111
+ },
112
+ {
113
+ "kind": "customsearch#result",
114
+ "title": "Lectures & Discussions | Flint Institute of Arts",
115
+ "htmlTitle": "\u003cb\u003eLectures\u003c/b\u003e &amp; Discussions | Flint Institute of Arts",
116
+ "link": "https://flintarts.org/events/lectures",
117
+ "displayLink": "flintarts.org",
118
+ "snippet": "It will trace the intricate relationship between jewelry, attire, and the expression of personal identity, social hierarchy, and spiritual belief systems that ...",
119
+ "htmlSnippet": "It will trace the intricate relationship between jewelry, attire, and the expression of personal identity, social hierarchy, and spiritual belief systems that&nbsp;...",
120
+ "cacheId": "jvpb9DxrfxoJ",
121
+ "formattedUrl": "https://flintarts.org/events/lectures",
122
+ "htmlFormattedUrl": "https://flintarts.org/events/\u003cb\u003electures\u003c/b\u003e",
123
+ "pagemap": {
124
+ "cse_thumbnail": [
125
+ {
126
+ "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcS23tMtAeNhJbOWdGxShYsmnyzFdzOC9Hb7lRykA9Pw72z1IlKTkjTdZw&s",
127
+ "width": "447",
128
+ "height": "113"
129
+ }
130
+ ],
131
+ "metatags": [
132
+ {
133
+ "og:image": "https://flintarts.org/uploads/images/page-headers/_headerImage/nightshot.jpg",
134
+ "og:type": "website",
135
+ "viewport": "width=device-width, initial-scale=1",
136
+ "og:title": "Lectures & Discussions | Flint Institute of Arts",
137
+ "og:description": "The Flint Institute of Arts is the second largest art museum in Michigan and one of the largest museum art schools in the nation."
138
+ }
139
+ ],
140
+ "cse_image": [
141
+ {
142
+ "src": "https://flintarts.org/uploads/images/page-headers/_headerImage/nightshot.jpg"
143
+ }
144
+ ]
145
+ }
146
+ },
147
+ {
148
+ "kind": "customsearch#result",
149
+ "title": "Mandel Lectures | Mandel Center for the Humanities ... - Waltham",
150
+ "htmlTitle": "Mandel \u003cb\u003eLectures\u003c/b\u003e | Mandel Center for the Humanities ... - Waltham",
151
+ "link": "https://www.brandeis.edu/mandel-center-humanities/mandel-lectures.html",
152
+ "displayLink": "www.brandeis.edu",
153
+ "snippet": "Past Lectures · Lecture 1: \"Invisible Music: The Sonic Idea of Black Revolution From Captivity to Reconstruction\" · Lecture 2: \"Solidarity in Sound: Grassroots ...",
154
+ "htmlSnippet": "Past \u003cb\u003eLectures\u003c/b\u003e &middot; \u003cb\u003eLecture\u003c/b\u003e 1: &quot;Invisible Music: The Sonic Idea of Black Revolution From Captivity to Reconstruction&quot; &middot; \u003cb\u003eLecture\u003c/b\u003e 2: &quot;Solidarity in Sound: Grassroots&nbsp;...",
155
+ "cacheId": "cQLOZr0kgEEJ",
156
+ "formattedUrl": "https://www.brandeis.edu/mandel-center-humanities/mandel-lectures.html",
157
+ "htmlFormattedUrl": "https://www.brandeis.edu/mandel-center-humanities/mandel-\u003cb\u003electures\u003c/b\u003e.html",
158
+ "pagemap": {
159
+ "cse_thumbnail": [
160
+ {
161
+ "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQWlU7bcJ5pIHk7RBCk2QKE-48ejF7hyPV0pr-20_cBt2BGdfKtiYXBuyw&s",
162
+ "width": "275",
163
+ "height": "183"
164
+ }
165
+ ],
166
+ "metatags": [
167
+ {
168
+ "og:image": "https://www.brandeis.edu/mandel-center-humanities/events/events-images/mlhzumba",
169
+ "twitter:card": "summary_large_image",
170
+ "viewport": "width=device-width,initial-scale=1,minimum-scale=1",
171
+ "og:title": "Mandel Lectures in the Humanities",
172
+ "og:url": "https://www.brandeis.edu/mandel-center-humanities/mandel-lectures.html",
173
+ "og:description": "Annual Lecture Series",
174
+ "twitter:image": "https://www.brandeis.edu/mandel-center-humanities/events/events-images/mlhzumba"
175
+ }
176
+ ],
177
+ "cse_image": [
178
+ {
179
+ "src": "https://www.brandeis.edu/mandel-center-humanities/events/events-images/mlhzumba"
180
+ }
181
+ ]
182
+ }
183
+ },
184
+ {
185
+ "kind": "customsearch#result",
186
+ "title": "Brian Douglas - YouTube",
187
+ "htmlTitle": "Brian Douglas - YouTube",
188
+ "link": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
189
+ "displayLink": "www.youtube.com",
190
+ "snippet": "Welcome to Control Systems Lectures! This collection of videos is intended to supplement a first year controls class, not replace it.",
191
+ "htmlSnippet": "Welcome to Control Systems \u003cb\u003eLectures\u003c/b\u003e! This collection of videos is intended to supplement a first year controls class, not replace it.",
192
+ "cacheId": "NEROyBHolL0J",
193
+ "formattedUrl": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
194
+ "htmlFormattedUrl": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
195
+ "pagemap": {
196
+ "hcard": [
197
+ {
198
+ "fn": "Brian Douglas",
199
+ "url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg"
200
+ }
201
+ ],
202
+ "cse_thumbnail": [
203
+ {
204
+ "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcR7G0CeCBz_wVTZgjnhEr2QbiKP7f3uYzKitZYn74Mi32cDmVxvsegJoLI&s",
205
+ "width": "225",
206
+ "height": "225"
207
+ }
208
+ ],
209
+ "imageobject": [
210
+ {
211
+ "width": "900",
212
+ "url": "https://yt3.googleusercontent.com/ytc/AIdro_nLo68wetImbwGUYP3stve_iKmAEccjhqB-q4o79xdInN4=s900-c-k-c0x00ffffff-no-rj",
213
+ "height": "900"
214
+ }
215
+ ],
216
+ "person": [
217
+ {
218
+ "name": "Brian Douglas",
219
+ "url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg"
220
+ }
221
+ ],
222
+ "metatags": [
223
+ {
224
+ "apple-itunes-app": "app-id=544007664, app-argument=https://m.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg?referring_app=com.apple.mobilesafari-smartbanner, affiliate-data=ct=smart_app_banner_polymer&pt=9008",
225
+ "og:image": "https://yt3.googleusercontent.com/ytc/AIdro_nLo68wetImbwGUYP3stve_iKmAEccjhqB-q4o79xdInN4=s900-c-k-c0x00ffffff-no-rj",
226
+ "twitter:app:url:iphone": "vnd.youtube://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
227
+ "twitter:app:id:googleplay": "com.google.android.youtube",
228
+ "theme-color": "rgb(255, 255, 255)",
229
+ "og:image:width": "900",
230
+ "twitter:card": "summary",
231
+ "og:site_name": "YouTube",
232
+ "twitter:url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
233
+ "twitter:app:url:ipad": "vnd.youtube://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
234
+ "al:android:package": "com.google.android.youtube",
235
+ "twitter:app:name:googleplay": "YouTube",
236
+ "al:ios:url": "vnd.youtube://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
237
+ "twitter:app:id:iphone": "544007664",
238
+ "og:description": "Welcome to Control Systems Lectures! This collection of videos is intended to supplement a first year controls class, not replace it. My goal is to take specific concepts in controls and expand on them in order to provide an intuitive understanding which will ultimately make you a better controls engineer. \n\nI'm glad you made it to my channel and I hope you find it useful.\n\nShoot me a message at controlsystemlectures@gmail.com, leave a comment or question and I'll get back to you if I can. Don't forget to subscribe!\n \nTwitter: @BrianBDouglas for engineering tweets and announcement of new videos.\nWebpage: http://engineeringmedia.com\n\nHere is the hardware/software I use: http://www.youtube.com/watch?v=m-M5_mIyHe4\n\nHere's a list of my favorite references: http://bit.ly/2skvmWd\n\n--Brian",
239
+ "al:ios:app_store_id": "544007664",
240
+ "twitter:image": "https://yt3.googleusercontent.com/ytc/AIdro_nLo68wetImbwGUYP3stve_iKmAEccjhqB-q4o79xdInN4=s900-c-k-c0x00ffffff-no-rj",
241
+ "twitter:site": "@youtube",
242
+ "og:type": "profile",
243
+ "twitter:title": "Brian Douglas",
244
+ "al:ios:app_name": "YouTube",
245
+ "og:title": "Brian Douglas",
246
+ "og:image:height": "900",
247
+ "twitter:app:id:ipad": "544007664",
248
+ "al:web:url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg?feature=applinks",
249
+ "al:android:url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg?feature=applinks",
250
+ "fb:app_id": "87741124305",
251
+ "twitter:app:url:googleplay": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
252
+ "twitter:app:name:ipad": "YouTube",
253
+ "viewport": "width=device-width, initial-scale=1.0, minimum-scale=1.0, maximum-scale=1.0, user-scalable=no,",
254
+ "twitter:description": "Welcome to Control Systems Lectures! This collection of videos is intended to supplement a first year controls class, not replace it. My goal is to take specific concepts in controls and expand on them in order to provide an intuitive understanding which will ultimately make you a better controls engineer. \n\nI'm glad you made it to my channel and I hope you find it useful.\n\nShoot me a message at controlsystemlectures@gmail.com, leave a comment or question and I'll get back to you if I can. Don't forget to subscribe!\n \nTwitter: @BrianBDouglas for engineering tweets and announcement of new videos.\nWebpage: http://engineeringmedia.com\n\nHere is the hardware/software I use: http://www.youtube.com/watch?v=m-M5_mIyHe4\n\nHere's a list of my favorite references: http://bit.ly/2skvmWd\n\n--Brian",
255
+ "og:url": "https://www.youtube.com/channel/UCq0imsn84ShAe9PBOFnoIrg",
256
+ "al:android:app_name": "YouTube",
257
+ "twitter:app:name:iphone": "YouTube"
258
+ }
259
+ ],
260
+ "cse_image": [
261
+ {
262
+ "src": "https://yt3.googleusercontent.com/ytc/AIdro_nLo68wetImbwGUYP3stve_iKmAEccjhqB-q4o79xdInN4=s900-c-k-c0x00ffffff-no-rj"
263
+ }
264
+ ]
265
+ }
266
+ },
267
+ {
268
+ "kind": "customsearch#result",
269
+ "title": "Lecture - Wikipedia",
270
+ "htmlTitle": "\u003cb\u003eLecture\u003c/b\u003e - Wikipedia",
271
+ "link": "https://en.wikipedia.org/wiki/Lecture",
272
+ "displayLink": "en.wikipedia.org",
273
+ "snippet": "Lecture ... For the academic rank, see Lecturer. A lecture (from Latin: lēctūra 'reading') is an oral presentation intended to present information or teach people ...",
274
+ "htmlSnippet": "\u003cb\u003eLecture\u003c/b\u003e ... For the academic rank, see \u003cb\u003eLecturer\u003c/b\u003e. A \u003cb\u003electure\u003c/b\u003e (from Latin: lēctūra &#39;reading&#39;) is an oral presentation intended to present information or teach people&nbsp;...",
275
+ "cacheId": "d9Pjta02fmgJ",
276
+ "formattedUrl": "https://en.wikipedia.org/wiki/Lecture",
277
+ "htmlFormattedUrl": "https://en.wikipedia.org/wiki/Lecture",
278
+ "pagemap": {
279
+ "metatags": [
280
+ {
281
+ "referrer": "origin",
282
+ "og:image": "https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/ADFA_Lecture_Theatres.jpg/1200px-ADFA_Lecture_Theatres.jpg",
283
+ "theme-color": "#eaecf0",
284
+ "og:image:width": "1200",
285
+ "og:type": "website",
286
+ "viewport": "width=device-width, initial-scale=1.0, user-scalable=yes, minimum-scale=0.25, maximum-scale=5.0",
287
+ "og:title": "Lecture - Wikipedia",
288
+ "og:image:height": "799",
289
+ "format-detection": "telephone=no"
290
+ }
291
+ ]
292
+ }
293
+ },
294
+ {
295
+ "kind": "customsearch#result",
296
+ "title": "Mount Wilson Observatory | Lectures",
297
+ "htmlTitle": "Mount Wilson Observatory | \u003cb\u003eLectures\u003c/b\u003e",
298
+ "link": "https://www.mtwilson.edu/lectures/",
299
+ "displayLink": "www.mtwilson.edu",
300
+ "snippet": "Talks & Telescopes: August 24, 2024 – Panel: The Triumph of Hubble ... Compelling talks followed by picnicking and convivial stargazing through both the big ...",
301
+ "htmlSnippet": "Talks &amp; Telescopes: August 24, 2024 – Panel: The Triumph of Hubble ... Compelling talks followed by picnicking and convivial stargazing through both the big&nbsp;...",
302
+ "cacheId": "wdXI0azqx5UJ",
303
+ "formattedUrl": "https://www.mtwilson.edu/lectures/",
304
+ "htmlFormattedUrl": "https://www.mtwilson.edu/\u003cb\u003electures\u003c/b\u003e/",
305
+ "pagemap": {
306
+ "metatags": [
307
+ {
308
+ "viewport": "width=device-width,initial-scale=1,user-scalable=no"
309
+ }
310
+ ],
311
+ "webpage": [
312
+ {
313
+ "image": "http://www.mtwilson.edu/wp-content/uploads/2016/09/Logo.jpg",
314
+ "url": "https://www.facebook.com/WilsonObs"
315
+ }
316
+ ]
317
+ }
318
+ },
319
+ {
320
+ "kind": "customsearch#result",
321
+ "title": "Lectures | NBER",
322
+ "htmlTitle": "\u003cb\u003eLectures\u003c/b\u003e | NBER",
323
+ "link": "https://www.nber.org/research/lectures",
324
+ "displayLink": "www.nber.org",
325
+ "snippet": "Results 1 - 50 of 354 ... Among featured events at the NBER Summer Institute are the Martin Feldstein Lecture, which examines a current issue involving economic ...",
326
+ "htmlSnippet": "Results 1 - 50 of 354 \u003cb\u003e...\u003c/b\u003e Among featured events at the NBER Summer Institute are the Martin Feldstein \u003cb\u003eLecture\u003c/b\u003e, which examines a current issue involving economic&nbsp;...",
327
+ "cacheId": "CvvP3U3nb44J",
328
+ "formattedUrl": "https://www.nber.org/research/lectures",
329
+ "htmlFormattedUrl": "https://www.nber.org/research/\u003cb\u003electures\u003c/b\u003e",
330
+ "pagemap": {
331
+ "cse_thumbnail": [
332
+ {
333
+ "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTmeViEZyV1YmFEFLhcA6WdgAG3v3RV6tB93ncyxSJ5JPst_p2aWrL7D1k&s",
334
+ "width": "310",
335
+ "height": "163"
336
+ }
337
+ ],
338
+ "metatags": [
339
+ {
340
+ "og:image": "https://www.nber.org/sites/default/files/2022-06/NBER-FB-Share-Tile-1200.jpg",
341
+ "og:site_name": "NBER",
342
+ "handheldfriendly": "true",
343
+ "viewport": "width=device-width, initial-scale=1.0",
344
+ "og:title": "Lectures",
345
+ "mobileoptimized": "width",
346
+ "og:url": "https://www.nber.org/research/lectures"
347
+ }
348
+ ],
349
+ "cse_image": [
350
+ {
351
+ "src": "https://www.nber.org/sites/default/files/2022-06/NBER-FB-Share-Tile-1200.jpg"
352
+ }
353
+ ]
354
+ }
355
+ },
356
+ {
357
+ "kind": "customsearch#result",
358
+ "title": "STUDENTS CANNOT ACCESS RECORDED LECTURES ... - Solved",
359
+ "htmlTitle": "STUDENTS CANNOT ACCESS RECORDED LECTURES ... - Solved",
360
+ "link": "https://community.canvaslms.com/t5/Canvas-Question-Forum/STUDENTS-CANNOT-ACCESS-RECORDED-LECTURES/td-p/190358",
361
+ "displayLink": "community.canvaslms.com",
362
+ "snippet": "Mar 19, 2020 ... I believe the issue is that students were not invited. Are you trying to capture your screen? If not, there is an option to just record your web ...",
363
+ "htmlSnippet": "Mar 19, 2020 \u003cb\u003e...\u003c/b\u003e I believe the issue is that students were not invited. Are you trying to capture your screen? If not, there is an option to just record your web&nbsp;...",
364
+ "cacheId": "wqrynQXX61sJ",
365
+ "formattedUrl": "https://community.canvaslms.com/t5/Canvas...LECTURES/td-p/190358",
366
+ "htmlFormattedUrl": "https://community.canvaslms.com/t5/Canvas...\u003cb\u003eLECTURES\u003c/b\u003e/td-p/190358",
367
+ "pagemap": {
368
+ "cse_thumbnail": [
369
+ {
370
+ "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRUqXau3N8LfKgSD7OJOvV7xzGarLKRU-ckWXy1ZQ1p4CLPsedvLKmLMhk&s",
371
+ "width": "310",
372
+ "height": "163"
373
+ }
374
+ ],
375
+ "metatags": [
376
+ {
377
+ "og:image": "https://community.canvaslms.com/html/@6A1FDD4D5FF35E4BBB4083A1022FA0DB/assets/CommunityPreview23.png",
378
+ "og:type": "article",
379
+ "article:section": "Canvas Question Forum",
380
+ "article:published_time": "2020-03-19T15:50:03.409Z",
381
+ "og:site_name": "Instructure Community",
382
+ "article:modified_time": "2020-03-19T13:55:53-07:00",
383
+ "viewport": "width=device-width, initial-scale=1.0, user-scalable=yes",
384
+ "og:title": "STUDENTS CANNOT ACCESS RECORDED LECTURES",
385
+ "og:url": "https://community.canvaslms.com/t5/Canvas-Question-Forum/STUDENTS-CANNOT-ACCESS-RECORDED-LECTURES/m-p/190358#M93667",
386
+ "og:description": "I can access and see my recorded lectures but my students can't. They have an error message when they try to open the recorded presentation or notes.",
387
+ "article:author": "https://community.canvaslms.com/t5/user/viewprofilepage/user-id/794287",
388
+ "twitter:image": "https://community.canvaslms.com/html/@6A1FDD4D5FF35E4BBB4083A1022FA0DB/assets/CommunityPreview23.png"
389
+ }
390
+ ],
391
+ "cse_image": [
392
+ {
393
+ "src": "https://community.canvaslms.com/html/@6A1FDD4D5FF35E4BBB4083A1022FA0DB/assets/CommunityPreview23.png"
394
+ }
395
+ ]
396
+ }
397
+ },
398
+ {
399
+ "kind": "customsearch#result",
400
+ "title": "Public Lecture Series - Sam Fox School of Design & Visual Arts",
401
+ "htmlTitle": "Public \u003cb\u003eLecture\u003c/b\u003e Series - Sam Fox School of Design &amp; Visual Arts",
402
+ "link": "https://samfoxschool.wustl.edu/calendar/series/2-public-lecture-series",
403
+ "displayLink": "samfoxschool.wustl.edu",
404
+ "snippet": "The Sam Fox School's Spring 2024 Public Lecture Series highlights design and art as catalysts for change. Renowned speakers will delve into themes like ...",
405
+ "htmlSnippet": "The Sam Fox School&#39;s Spring 2024 Public \u003cb\u003eLecture\u003c/b\u003e Series highlights design and art as catalysts for change. Renowned speakers will delve into themes like&nbsp;...",
406
+ "cacheId": "B-cgQG0j6tUJ",
407
+ "formattedUrl": "https://samfoxschool.wustl.edu/calendar/series/2-public-lecture-series",
408
+ "htmlFormattedUrl": "https://samfoxschool.wustl.edu/calendar/series/2-public-lecture-series",
409
+ "pagemap": {
410
+ "cse_thumbnail": [
411
+ {
412
+ "src": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQSmHaGianm-64m-qauYjkPK_Q0JKWe-7yom4m1ogFYTmpWArA7k6dmk0sR&s",
413
+ "width": "307",
414
+ "height": "164"
415
+ }
416
+ ],
417
+ "website": [
418
+ {
419
+ "name": "Public Lecture Series - Sam Fox School of Design & Visual Arts — Washington University in St. Louis"
420
+ }
421
+ ],
422
+ "metatags": [
423
+ {
424
+ "og:image": "https://dvsp0hlm0xrn3.cloudfront.net/assets/default_og_image-44e73dee4b9d1e2c6a6295901371270c8ec5899eaed48ee8167a9b12f1b0f8b3.jpg",
425
+ "og:type": "website",
426
+ "og:site_name": "Sam Fox School of Design & Visual Arts — Washington University in St. Louis",
427
+ "viewport": "width=device-width, initial-scale=1.0",
428
+ "og:title": "Public Lecture Series - Sam Fox School of Design & Visual Arts — Washington University in St. Louis",
429
+ "csrf-token": "jBQsfZGY3RH8NVs0-KVDBYB-2N2kib4UYZHYdrShfTdLkvzfSvGeOaMrRKTRdYBPRKzdcGIuP7zwm9etqX_uvg",
430
+ "csrf-param": "authenticity_token",
431
+ "og:description": "The Sam Fox School's Spring 2024 Public Lecture Series highlights design and art as catalysts for change. Renowned speakers will delve into themes like social equity, resilient cities, and the impact of emerging technologies on contemporary life. Speakers include artists, architects, designers, and critics of the highest caliber, widely recognized for their research-based practices and multidisciplinary approaches to their fields."
432
+ }
433
+ ],
434
+ "cse_image": [
435
+ {
436
+ "src": "https://dvsp0hlm0xrn3.cloudfront.net/assets/default_og_image-44e73dee4b9d1e2c6a6295901371270c8ec5899eaed48ee8167a9b12f1b0f8b3.jpg"
437
+ }
438
+ ]
439
+ }
440
+ }
441
+ ]
442
+ }
backend/open_webui/apps/retrieval/web/testdata/searchapi.json CHANGED
The diff for this file is too large to render. See raw diff