github-actions[bot] commited on
Commit
3b623f5
·
0 Parent(s):

GitHub deploy: 9b6076f726b2bbd0d7d13a3601fe27cb7e4a5db0

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +20 -0
  2. .env.example +13 -0
  3. .eslintignore +13 -0
  4. .eslintrc.cjs +31 -0
  5. .gitattributes +3 -0
  6. .github/FUNDING.yml +1 -0
  7. .github/ISSUE_TEMPLATE/bug_report.md +80 -0
  8. .github/ISSUE_TEMPLATE/feature_request.md +35 -0
  9. .github/dependabot.yml +12 -0
  10. .github/pull_request_template.md +72 -0
  11. .github/workflows/build-release.yml +72 -0
  12. .github/workflows/deploy-to-hf-spaces.yml +63 -0
  13. .github/workflows/docker-build.yaml +477 -0
  14. .github/workflows/format-backend.yaml +39 -0
  15. .github/workflows/format-build-frontend.yaml +57 -0
  16. .github/workflows/integration-test.yml +253 -0
  17. .github/workflows/lint-backend.disabled +27 -0
  18. .github/workflows/lint-frontend.disabled +21 -0
  19. .github/workflows/release-pypi.yml +32 -0
  20. .gitignore +309 -0
  21. .npmrc +1 -0
  22. .prettierignore +316 -0
  23. .prettierrc +9 -0
  24. CHANGELOG.md +0 -0
  25. CODE_OF_CONDUCT.md +99 -0
  26. Caddyfile.localhost +64 -0
  27. Dockerfile +176 -0
  28. INSTALLATION.md +35 -0
  29. LICENSE +21 -0
  30. Makefile +33 -0
  31. README.md +221 -0
  32. TROUBLESHOOTING.md +36 -0
  33. backend/.dockerignore +14 -0
  34. backend/.gitignore +12 -0
  35. backend/dev.sh +2 -0
  36. backend/open_webui/__init__.py +77 -0
  37. backend/open_webui/alembic.ini +114 -0
  38. backend/open_webui/apps/audio/main.py +703 -0
  39. backend/open_webui/apps/images/main.py +609 -0
  40. backend/open_webui/apps/images/utils/comfyui.py +186 -0
  41. backend/open_webui/apps/ollama/main.py +1351 -0
  42. backend/open_webui/apps/openai/main.py +719 -0
  43. backend/open_webui/apps/retrieval/loaders/main.py +190 -0
  44. backend/open_webui/apps/retrieval/loaders/youtube.py +117 -0
  45. backend/open_webui/apps/retrieval/main.py +1494 -0
  46. backend/open_webui/apps/retrieval/models/colbert.py +81 -0
  47. backend/open_webui/apps/retrieval/utils.py +532 -0
  48. backend/open_webui/apps/retrieval/vector/connector.py +22 -0
  49. backend/open_webui/apps/retrieval/vector/dbs/chroma.py +174 -0
  50. backend/open_webui/apps/retrieval/vector/dbs/milvus.py +286 -0
.dockerignore ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .github
2
+ .DS_Store
3
+ docs
4
+ kubernetes
5
+ node_modules
6
+ /.svelte-kit
7
+ /package
8
+ .env
9
+ .env.*
10
+ vite.config.js.timestamp-*
11
+ vite.config.ts.timestamp-*
12
+ __pycache__
13
+ .idea
14
+ venv
15
+ _old
16
+ uploads
17
+ .ipynb_checkpoints
18
+ **/*.db
19
+ _test
20
+ backend/data/*
.env.example ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ollama URL for the backend to connect
2
+ # The path '/ollama' will be redirected to the specified backend URL
3
+ OLLAMA_BASE_URL='http://localhost:11434'
4
+
5
+ OPENAI_API_BASE_URL=''
6
+ OPENAI_API_KEY=''
7
+
8
+ # AUTOMATIC1111_BASE_URL="http://localhost:7860"
9
+
10
+ # DO NOT TRACK
11
+ SCARF_NO_ANALYTICS=true
12
+ DO_NOT_TRACK=true
13
+ ANONYMIZED_TELEMETRY=false
.eslintignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ node_modules
3
+ /build
4
+ /.svelte-kit
5
+ /package
6
+ .env
7
+ .env.*
8
+ !.env.example
9
+
10
+ # Ignore files for PNPM, NPM and YARN
11
+ pnpm-lock.yaml
12
+ package-lock.json
13
+ yarn.lock
.eslintrc.cjs ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ module.exports = {
2
+ root: true,
3
+ extends: [
4
+ 'eslint:recommended',
5
+ 'plugin:@typescript-eslint/recommended',
6
+ 'plugin:svelte/recommended',
7
+ 'plugin:cypress/recommended',
8
+ 'prettier'
9
+ ],
10
+ parser: '@typescript-eslint/parser',
11
+ plugins: ['@typescript-eslint'],
12
+ parserOptions: {
13
+ sourceType: 'module',
14
+ ecmaVersion: 2020,
15
+ extraFileExtensions: ['.svelte']
16
+ },
17
+ env: {
18
+ browser: true,
19
+ es2017: true,
20
+ node: true
21
+ },
22
+ overrides: [
23
+ {
24
+ files: ['*.svelte'],
25
+ parser: 'svelte-eslint-parser',
26
+ parserOptions: {
27
+ parser: '@typescript-eslint/parser'
28
+ }
29
+ }
30
+ ]
31
+ };
.gitattributes ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.sh text eol=lf
2
+ *.ttf filter=lfs diff=lfs merge=lfs -text
3
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.github/FUNDING.yml ADDED
@@ -0,0 +1 @@
 
 
1
+ github: tjbck
.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Bug report
3
+ about: Create a report to help us improve
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+ ---
8
+
9
+ # Bug Report
10
+
11
+ ## Important Notes
12
+
13
+ - **Before submitting a bug report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. This will help us efficiently focus on improving the project.
14
+
15
+ - **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We’re here to help if you’re open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
16
+
17
+ - **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
18
+
19
+ - **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!
20
+
21
+ Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
22
+
23
+ ---
24
+
25
+ ## Installation Method
26
+
27
+ [Describe the method you used to install the project, e.g., git clone, Docker, pip, etc.]
28
+
29
+ ## Environment
30
+
31
+ - **Open WebUI Version:** [e.g., v0.3.11]
32
+ - **Ollama (if applicable):** [e.g., v0.2.0, v0.1.32-rc1]
33
+
34
+ - **Operating System:** [e.g., Windows 10, macOS Big Sur, Ubuntu 20.04]
35
+ - **Browser (if applicable):** [e.g., Chrome 100.0, Firefox 98.0]
36
+
37
+ **Confirmation:**
38
+
39
+ - [ ] I have read and followed all the instructions provided in the README.md.
40
+ - [ ] I am on the latest version of both Open WebUI and Ollama.
41
+ - [ ] I have included the browser console logs.
42
+ - [ ] I have included the Docker container logs.
43
+ - [ ] I have provided the exact steps to reproduce the bug in the "Steps to Reproduce" section below.
44
+
45
+ ## Expected Behavior:
46
+
47
+ [Describe what you expected to happen.]
48
+
49
+ ## Actual Behavior:
50
+
51
+ [Describe what actually happened.]
52
+
53
+ ## Description
54
+
55
+ **Bug Summary:**
56
+ [Provide a brief but clear summary of the bug]
57
+
58
+ ## Reproduction Details
59
+
60
+ **Steps to Reproduce:**
61
+ [Outline the steps to reproduce the bug. Be as detailed as possible.]
62
+
63
+ ## Logs and Screenshots
64
+
65
+ **Browser Console Logs:**
66
+ [Include relevant browser console logs, if applicable]
67
+
68
+ **Docker Container Logs:**
69
+ [Include relevant Docker container logs, if applicable]
70
+
71
+ **Screenshots/Screen Recordings (if applicable):**
72
+ [Attach any relevant screenshots to help illustrate the issue]
73
+
74
+ ## Additional Information
75
+
76
+ [Include any additional details that may help in understanding and reproducing the issue. This could include specific configurations, error messages, or anything else relevant to the bug.]
77
+
78
+ ## Note
79
+
80
+ If the bug report is incomplete or does not follow the provided instructions, it may not be addressed. Please ensure that you have followed the steps outlined in the README.md and troubleshooting.md documents, and provide all necessary information for us to reproduce and address the issue. Thank you!
.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Feature request
3
+ about: Suggest an idea for this project
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+ ---
8
+
9
+ # Feature Request
10
+
11
+ ## Important Notes
12
+
13
+ - **Before submitting a report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. This will help us efficiently focus on improving the project.
14
+
15
+ - **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We’re here to help if you’re open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
16
+
17
+ - **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
18
+
19
+ - **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!
20
+
21
+ Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
22
+
23
+ ---
24
+
25
+ **Is your feature request related to a problem? Please describe.**
26
+ A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
27
+
28
+ **Describe the solution you'd like**
29
+ A clear and concise description of what you want to happen.
30
+
31
+ **Describe alternatives you've considered**
32
+ A clear and concise description of any alternative solutions or features you've considered.
33
+
34
+ **Additional context**
35
+ Add any other context or screenshots about the feature request here.
.github/dependabot.yml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 2
2
+ updates:
3
+ - package-ecosystem: pip
4
+ directory: '/backend'
5
+ schedule:
6
+ interval: monthly
7
+ target-branch: 'dev'
8
+ - package-ecosystem: 'github-actions'
9
+ directory: '/'
10
+ schedule:
11
+ # Check for updates to GitHub Actions every week
12
+ interval: monthly
.github/pull_request_template.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pull Request Checklist
2
+
3
+ ### Note to first-time contributors: Please open a discussion post in [Discussions](https://github.com/open-webui/open-webui/discussions) and describe your changes before submitting a pull request.
4
+
5
+ **Before submitting, make sure you've checked the following:**
6
+
7
+ - [ ] **Target branch:** Please verify that the pull request targets the `dev` branch.
8
+ - [ ] **Description:** Provide a concise description of the changes made in this pull request.
9
+ - [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
10
+ - [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
11
+ - [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
12
+ - [ ] **Testing:** Have you written and run sufficient tests for validating the changes?
13
+ - [ ] **Code review:** Have you performed a self-review of your code, addressing any coding standard issues and ensuring adherence to the project's coding standards?
14
+ - [ ] **Prefix:** To cleary categorize this pull request, prefix the pull request title, using one of the following:
15
+ - **BREAKING CHANGE**: Significant changes that may affect compatibility
16
+ - **build**: Changes that affect the build system or external dependencies
17
+ - **ci**: Changes to our continuous integration processes or workflows
18
+ - **chore**: Refactor, cleanup, or other non-functional code changes
19
+ - **docs**: Documentation update or addition
20
+ - **feat**: Introduces a new feature or enhancement to the codebase
21
+ - **fix**: Bug fix or error correction
22
+ - **i18n**: Internationalization or localization changes
23
+ - **perf**: Performance improvement
24
+ - **refactor**: Code restructuring for better maintainability, readability, or scalability
25
+ - **style**: Changes that do not affect the meaning of the code (white-space, formatting, missing semi-colons, etc.)
26
+ - **test**: Adding missing tests or correcting existing tests
27
+ - **WIP**: Work in progress, a temporary label for incomplete or ongoing work
28
+
29
+ # Changelog Entry
30
+
31
+ ### Description
32
+
33
+ - [Concisely describe the changes made in this pull request, including any relevant motivation and impact (e.g., fixing a bug, adding a feature, or improving performance)]
34
+
35
+ ### Added
36
+
37
+ - [List any new features, functionalities, or additions]
38
+
39
+ ### Changed
40
+
41
+ - [List any changes, updates, refactorings, or optimizations]
42
+
43
+ ### Deprecated
44
+
45
+ - [List any deprecated functionality or features that have been removed]
46
+
47
+ ### Removed
48
+
49
+ - [List any removed features, files, or functionalities]
50
+
51
+ ### Fixed
52
+
53
+ - [List any fixes, corrections, or bug fixes]
54
+
55
+ ### Security
56
+
57
+ - [List any new or updated security-related changes, including vulnerability fixes]
58
+
59
+ ### Breaking Changes
60
+
61
+ - **BREAKING CHANGE**: [List any breaking changes affecting compatibility or functionality]
62
+
63
+ ---
64
+
65
+ ### Additional Information
66
+
67
+ - [Insert any additional context, notes, or explanations for the changes]
68
+ - [Reference any related issues, commits, or other relevant information]
69
+
70
+ ### Screenshots or Videos
71
+
72
+ - [Attach any relevant screenshots or videos demonstrating the changes]
.github/workflows/build-release.yml ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Release
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main # or whatever branch you want to use
7
+
8
+ jobs:
9
+ release:
10
+ runs-on: ubuntu-latest
11
+
12
+ steps:
13
+ - name: Checkout repository
14
+ uses: actions/checkout@v4
15
+
16
+ - name: Check for changes in package.json
17
+ run: |
18
+ git diff --cached --diff-filter=d package.json || {
19
+ echo "No changes to package.json"
20
+ exit 1
21
+ }
22
+
23
+ - name: Get version number from package.json
24
+ id: get_version
25
+ run: |
26
+ VERSION=$(jq -r '.version' package.json)
27
+ echo "::set-output name=version::$VERSION"
28
+
29
+ - name: Extract latest CHANGELOG entry
30
+ id: changelog
31
+ run: |
32
+ CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
33
+ CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g')
34
+ echo "Extracted latest release notes from CHANGELOG.md:"
35
+ echo -e "$CHANGELOG_CONTENT"
36
+ echo "::set-output name=content::$CHANGELOG_ESCAPED"
37
+
38
+ - name: Create GitHub release
39
+ uses: actions/github-script@v7
40
+ with:
41
+ github-token: ${{ secrets.GITHUB_TOKEN }}
42
+ script: |
43
+ const changelog = `${{ steps.changelog.outputs.content }}`;
44
+ const release = await github.rest.repos.createRelease({
45
+ owner: context.repo.owner,
46
+ repo: context.repo.repo,
47
+ tag_name: `v${{ steps.get_version.outputs.version }}`,
48
+ name: `v${{ steps.get_version.outputs.version }}`,
49
+ body: changelog,
50
+ })
51
+ console.log(`Created release ${release.data.html_url}`)
52
+
53
+ - name: Upload package to GitHub release
54
+ uses: actions/upload-artifact@v4
55
+ with:
56
+ name: package
57
+ path: |
58
+ .
59
+ !.git
60
+ env:
61
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
62
+
63
+ - name: Trigger Docker build workflow
64
+ uses: actions/github-script@v7
65
+ with:
66
+ script: |
67
+ github.rest.actions.createWorkflowDispatch({
68
+ owner: context.repo.owner,
69
+ repo: context.repo.repo,
70
+ workflow_id: 'docker-build.yaml',
71
+ ref: 'v${{ steps.get_version.outputs.version }}',
72
+ })
.github/workflows/deploy-to-hf-spaces.yml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Deploy to HuggingFace Spaces
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - dev
7
+ - main
8
+ workflow_dispatch:
9
+
10
+ jobs:
11
+ check-secret:
12
+ runs-on: ubuntu-latest
13
+ outputs:
14
+ token-set: ${{ steps.check-key.outputs.defined }}
15
+ steps:
16
+ - id: check-key
17
+ env:
18
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
19
+ if: "${{ env.HF_TOKEN != '' }}"
20
+ run: echo "defined=true" >> $GITHUB_OUTPUT
21
+
22
+ deploy:
23
+ runs-on: ubuntu-latest
24
+ needs: [check-secret]
25
+ if: needs.check-secret.outputs.token-set == 'true'
26
+ env:
27
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
28
+ steps:
29
+ - name: Checkout repository
30
+ uses: actions/checkout@v4
31
+ with:
32
+ lfs: true
33
+
34
+ - name: Remove git history
35
+ run: rm -rf .git
36
+
37
+ - name: Prepend YAML front matter to README.md
38
+ run: |
39
+ echo "---" > temp_readme.md
40
+ echo "title: Open WebUI" >> temp_readme.md
41
+ echo "emoji: 🐳" >> temp_readme.md
42
+ echo "colorFrom: purple" >> temp_readme.md
43
+ echo "colorTo: gray" >> temp_readme.md
44
+ echo "sdk: docker" >> temp_readme.md
45
+ echo "app_port: 8080" >> temp_readme.md
46
+ echo "---" >> temp_readme.md
47
+ cat README.md >> temp_readme.md
48
+ mv temp_readme.md README.md
49
+
50
+ - name: Configure git
51
+ run: |
52
+ git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
53
+ git config --global user.name "github-actions[bot]"
54
+ - name: Set up Git and push to Space
55
+ run: |
56
+ git init --initial-branch=main
57
+ git lfs install
58
+ git lfs track "*.ttf"
59
+ git lfs track "*.jpg"
60
+ rm demo.gif
61
+ git add .
62
+ git commit -m "GitHub deploy: ${{ github.sha }}"
63
+ git push --force https://open-webui:${HF_TOKEN}@huggingface.co/spaces/open-webui/open-webui main
.github/workflows/docker-build.yaml ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Create and publish Docker images with specific build args
2
+
3
+ on:
4
+ workflow_dispatch:
5
+ push:
6
+ branches:
7
+ - main
8
+ - dev
9
+ tags:
10
+ - v*
11
+
12
+ env:
13
+ REGISTRY: ghcr.io
14
+
15
+ jobs:
16
+ build-main-image:
17
+ runs-on: ubuntu-latest
18
+ permissions:
19
+ contents: read
20
+ packages: write
21
+ strategy:
22
+ fail-fast: false
23
+ matrix:
24
+ platform:
25
+ - linux/amd64
26
+ - linux/arm64
27
+
28
+ steps:
29
+ # GitHub Packages requires the entire repository name to be in lowercase
30
+ # although the repository owner has a lowercase username, this prevents some people from running actions after forking
31
+ - name: Set repository and image name to lowercase
32
+ run: |
33
+ echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
34
+ echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
35
+ env:
36
+ IMAGE_NAME: '${{ github.repository }}'
37
+
38
+ - name: Prepare
39
+ run: |
40
+ platform=${{ matrix.platform }}
41
+ echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
42
+
43
+ - name: Checkout repository
44
+ uses: actions/checkout@v4
45
+
46
+ - name: Set up QEMU
47
+ uses: docker/setup-qemu-action@v3
48
+
49
+ - name: Set up Docker Buildx
50
+ uses: docker/setup-buildx-action@v3
51
+
52
+ - name: Log in to the Container registry
53
+ uses: docker/login-action@v3
54
+ with:
55
+ registry: ${{ env.REGISTRY }}
56
+ username: ${{ github.actor }}
57
+ password: ${{ secrets.GITHUB_TOKEN }}
58
+
59
+ - name: Extract metadata for Docker images (default latest tag)
60
+ id: meta
61
+ uses: docker/metadata-action@v5
62
+ with:
63
+ images: ${{ env.FULL_IMAGE_NAME }}
64
+ tags: |
65
+ type=ref,event=branch
66
+ type=ref,event=tag
67
+ type=sha,prefix=git-
68
+ type=semver,pattern={{version}}
69
+ type=semver,pattern={{major}}.{{minor}}
70
+ flavor: |
71
+ latest=${{ github.ref == 'refs/heads/main' }}
72
+
73
+ - name: Extract metadata for Docker cache
74
+ id: cache-meta
75
+ uses: docker/metadata-action@v5
76
+ with:
77
+ images: ${{ env.FULL_IMAGE_NAME }}
78
+ tags: |
79
+ type=ref,event=branch
80
+ ${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
81
+ flavor: |
82
+ prefix=cache-${{ matrix.platform }}-
83
+ latest=false
84
+
85
+ - name: Build Docker image (latest)
86
+ uses: docker/build-push-action@v5
87
+ id: build
88
+ with:
89
+ context: .
90
+ push: true
91
+ platforms: ${{ matrix.platform }}
92
+ labels: ${{ steps.meta.outputs.labels }}
93
+ outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
94
+ cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
95
+ cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
96
+ build-args: |
97
+ BUILD_HASH=${{ github.sha }}
98
+
99
+ - name: Export digest
100
+ run: |
101
+ mkdir -p /tmp/digests
102
+ digest="${{ steps.build.outputs.digest }}"
103
+ touch "/tmp/digests/${digest#sha256:}"
104
+
105
+ - name: Upload digest
106
+ uses: actions/upload-artifact@v4
107
+ with:
108
+ name: digests-main-${{ env.PLATFORM_PAIR }}
109
+ path: /tmp/digests/*
110
+ if-no-files-found: error
111
+ retention-days: 1
112
+
113
+ build-cuda-image:
114
+ runs-on: ubuntu-latest
115
+ permissions:
116
+ contents: read
117
+ packages: write
118
+ strategy:
119
+ fail-fast: false
120
+ matrix:
121
+ platform:
122
+ - linux/amd64
123
+ - linux/arm64
124
+
125
+ steps:
126
+ # GitHub Packages requires the entire repository name to be in lowercase
127
+ # although the repository owner has a lowercase username, this prevents some people from running actions after forking
128
+ - name: Set repository and image name to lowercase
129
+ run: |
130
+ echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
131
+ echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
132
+ env:
133
+ IMAGE_NAME: '${{ github.repository }}'
134
+
135
+ - name: Prepare
136
+ run: |
137
+ platform=${{ matrix.platform }}
138
+ echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
139
+
140
+ - name: Checkout repository
141
+ uses: actions/checkout@v4
142
+
143
+ - name: Set up QEMU
144
+ uses: docker/setup-qemu-action@v3
145
+
146
+ - name: Set up Docker Buildx
147
+ uses: docker/setup-buildx-action@v3
148
+
149
+ - name: Log in to the Container registry
150
+ uses: docker/login-action@v3
151
+ with:
152
+ registry: ${{ env.REGISTRY }}
153
+ username: ${{ github.actor }}
154
+ password: ${{ secrets.GITHUB_TOKEN }}
155
+
156
+ - name: Extract metadata for Docker images (cuda tag)
157
+ id: meta
158
+ uses: docker/metadata-action@v5
159
+ with:
160
+ images: ${{ env.FULL_IMAGE_NAME }}
161
+ tags: |
162
+ type=ref,event=branch
163
+ type=ref,event=tag
164
+ type=sha,prefix=git-
165
+ type=semver,pattern={{version}}
166
+ type=semver,pattern={{major}}.{{minor}}
167
+ type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda
168
+ flavor: |
169
+ latest=${{ github.ref == 'refs/heads/main' }}
170
+ suffix=-cuda,onlatest=true
171
+
172
+ - name: Extract metadata for Docker cache
173
+ id: cache-meta
174
+ uses: docker/metadata-action@v5
175
+ with:
176
+ images: ${{ env.FULL_IMAGE_NAME }}
177
+ tags: |
178
+ type=ref,event=branch
179
+ ${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
180
+ flavor: |
181
+ prefix=cache-cuda-${{ matrix.platform }}-
182
+ latest=false
183
+
184
+ - name: Build Docker image (cuda)
185
+ uses: docker/build-push-action@v5
186
+ id: build
187
+ with:
188
+ context: .
189
+ push: true
190
+ platforms: ${{ matrix.platform }}
191
+ labels: ${{ steps.meta.outputs.labels }}
192
+ outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
193
+ cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
194
+ cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
195
+ build-args: |
196
+ BUILD_HASH=${{ github.sha }}
197
+ USE_CUDA=true
198
+
199
+ - name: Export digest
200
+ run: |
201
+ mkdir -p /tmp/digests
202
+ digest="${{ steps.build.outputs.digest }}"
203
+ touch "/tmp/digests/${digest#sha256:}"
204
+
205
+ - name: Upload digest
206
+ uses: actions/upload-artifact@v4
207
+ with:
208
+ name: digests-cuda-${{ env.PLATFORM_PAIR }}
209
+ path: /tmp/digests/*
210
+ if-no-files-found: error
211
+ retention-days: 1
212
+
213
+ build-ollama-image:
214
+ runs-on: ubuntu-latest
215
+ permissions:
216
+ contents: read
217
+ packages: write
218
+ strategy:
219
+ fail-fast: false
220
+ matrix:
221
+ platform:
222
+ - linux/amd64
223
+ - linux/arm64
224
+
225
+ steps:
226
+ # GitHub Packages requires the entire repository name to be in lowercase
227
+ # although the repository owner has a lowercase username, this prevents some people from running actions after forking
228
+ - name: Set repository and image name to lowercase
229
+ run: |
230
+ echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
231
+ echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
232
+ env:
233
+ IMAGE_NAME: '${{ github.repository }}'
234
+
235
+ - name: Prepare
236
+ run: |
237
+ platform=${{ matrix.platform }}
238
+ echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
239
+
240
+ - name: Checkout repository
241
+ uses: actions/checkout@v4
242
+
243
+ - name: Set up QEMU
244
+ uses: docker/setup-qemu-action@v3
245
+
246
+ - name: Set up Docker Buildx
247
+ uses: docker/setup-buildx-action@v3
248
+
249
+ - name: Log in to the Container registry
250
+ uses: docker/login-action@v3
251
+ with:
252
+ registry: ${{ env.REGISTRY }}
253
+ username: ${{ github.actor }}
254
+ password: ${{ secrets.GITHUB_TOKEN }}
255
+
256
+ - name: Extract metadata for Docker images (ollama tag)
257
+ id: meta
258
+ uses: docker/metadata-action@v5
259
+ with:
260
+ images: ${{ env.FULL_IMAGE_NAME }}
261
+ tags: |
262
+ type=ref,event=branch
263
+ type=ref,event=tag
264
+ type=sha,prefix=git-
265
+ type=semver,pattern={{version}}
266
+ type=semver,pattern={{major}}.{{minor}}
267
+ type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=ollama
268
+ flavor: |
269
+ latest=${{ github.ref == 'refs/heads/main' }}
270
+ suffix=-ollama,onlatest=true
271
+
272
+ - name: Extract metadata for Docker cache
273
+ id: cache-meta
274
+ uses: docker/metadata-action@v5
275
+ with:
276
+ images: ${{ env.FULL_IMAGE_NAME }}
277
+ tags: |
278
+ type=ref,event=branch
279
+ ${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
280
+ flavor: |
281
+ prefix=cache-ollama-${{ matrix.platform }}-
282
+ latest=false
283
+
284
+ - name: Build Docker image (ollama)
285
+ uses: docker/build-push-action@v5
286
+ id: build
287
+ with:
288
+ context: .
289
+ push: true
290
+ platforms: ${{ matrix.platform }}
291
+ labels: ${{ steps.meta.outputs.labels }}
292
+ outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
293
+ cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
294
+ cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
295
+ build-args: |
296
+ BUILD_HASH=${{ github.sha }}
297
+ USE_OLLAMA=true
298
+
299
+ - name: Export digest
300
+ run: |
301
+ mkdir -p /tmp/digests
302
+ digest="${{ steps.build.outputs.digest }}"
303
+ touch "/tmp/digests/${digest#sha256:}"
304
+
305
+ - name: Upload digest
306
+ uses: actions/upload-artifact@v4
307
+ with:
308
+ name: digests-ollama-${{ env.PLATFORM_PAIR }}
309
+ path: /tmp/digests/*
310
+ if-no-files-found: error
311
+ retention-days: 1
312
+
313
+ merge-main-images:
314
+ runs-on: ubuntu-latest
315
+ needs: [build-main-image]
316
+ steps:
317
+ # GitHub Packages requires the entire repository name to be in lowercase
318
+ # although the repository owner has a lowercase username, this prevents some people from running actions after forking
319
+ - name: Set repository and image name to lowercase
320
+ run: |
321
+ echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
322
+ echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
323
+ env:
324
+ IMAGE_NAME: '${{ github.repository }}'
325
+
326
+ - name: Download digests
327
+ uses: actions/download-artifact@v4
328
+ with:
329
+ pattern: digests-main-*
330
+ path: /tmp/digests
331
+ merge-multiple: true
332
+
333
+ - name: Set up Docker Buildx
334
+ uses: docker/setup-buildx-action@v3
335
+
336
+ - name: Log in to the Container registry
337
+ uses: docker/login-action@v3
338
+ with:
339
+ registry: ${{ env.REGISTRY }}
340
+ username: ${{ github.actor }}
341
+ password: ${{ secrets.GITHUB_TOKEN }}
342
+
343
+ - name: Extract metadata for Docker images (default latest tag)
344
+ id: meta
345
+ uses: docker/metadata-action@v5
346
+ with:
347
+ images: ${{ env.FULL_IMAGE_NAME }}
348
+ tags: |
349
+ type=ref,event=branch
350
+ type=ref,event=tag
351
+ type=sha,prefix=git-
352
+ type=semver,pattern={{version}}
353
+ type=semver,pattern={{major}}.{{minor}}
354
+ flavor: |
355
+ latest=${{ github.ref == 'refs/heads/main' }}
356
+
357
+ - name: Create manifest list and push
358
+ working-directory: /tmp/digests
359
+ run: |
360
+ docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
361
+ $(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
362
+
363
+ - name: Inspect image
364
+ run: |
365
+ docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
366
+
367
+ merge-cuda-images:
368
+ runs-on: ubuntu-latest
369
+ needs: [build-cuda-image]
370
+ steps:
371
+ # GitHub Packages requires the entire repository name to be in lowercase
372
+ # although the repository owner has a lowercase username, this prevents some people from running actions after forking
373
+ - name: Set repository and image name to lowercase
374
+ run: |
375
+ echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
376
+ echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
377
+ env:
378
+ IMAGE_NAME: '${{ github.repository }}'
379
+
380
+ - name: Download digests
381
+ uses: actions/download-artifact@v4
382
+ with:
383
+ pattern: digests-cuda-*
384
+ path: /tmp/digests
385
+ merge-multiple: true
386
+
387
+ - name: Set up Docker Buildx
388
+ uses: docker/setup-buildx-action@v3
389
+
390
+ - name: Log in to the Container registry
391
+ uses: docker/login-action@v3
392
+ with:
393
+ registry: ${{ env.REGISTRY }}
394
+ username: ${{ github.actor }}
395
+ password: ${{ secrets.GITHUB_TOKEN }}
396
+
397
+ - name: Extract metadata for Docker images (default latest tag)
398
+ id: meta
399
+ uses: docker/metadata-action@v5
400
+ with:
401
+ images: ${{ env.FULL_IMAGE_NAME }}
402
+ tags: |
403
+ type=ref,event=branch
404
+ type=ref,event=tag
405
+ type=sha,prefix=git-
406
+ type=semver,pattern={{version}}
407
+ type=semver,pattern={{major}}.{{minor}}
408
+ type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda
409
+ flavor: |
410
+ latest=${{ github.ref == 'refs/heads/main' }}
411
+ suffix=-cuda,onlatest=true
412
+
413
+ - name: Create manifest list and push
414
+ working-directory: /tmp/digests
415
+ run: |
416
+ docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
417
+ $(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
418
+
419
+ - name: Inspect image
420
+ run: |
421
+ docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
422
+
423
+ merge-ollama-images:
424
+ runs-on: ubuntu-latest
425
+ needs: [build-ollama-image]
426
+ steps:
427
+ # GitHub Packages requires the entire repository name to be in lowercase
428
+ # although the repository owner has a lowercase username, this prevents some people from running actions after forking
429
+ - name: Set repository and image name to lowercase
430
+ run: |
431
+ echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
432
+ echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
433
+ env:
434
+ IMAGE_NAME: '${{ github.repository }}'
435
+
436
+ - name: Download digests
437
+ uses: actions/download-artifact@v4
438
+ with:
439
+ pattern: digests-ollama-*
440
+ path: /tmp/digests
441
+ merge-multiple: true
442
+
443
+ - name: Set up Docker Buildx
444
+ uses: docker/setup-buildx-action@v3
445
+
446
+ - name: Log in to the Container registry
447
+ uses: docker/login-action@v3
448
+ with:
449
+ registry: ${{ env.REGISTRY }}
450
+ username: ${{ github.actor }}
451
+ password: ${{ secrets.GITHUB_TOKEN }}
452
+
453
+ - name: Extract metadata for Docker images (default ollama tag)
454
+ id: meta
455
+ uses: docker/metadata-action@v5
456
+ with:
457
+ images: ${{ env.FULL_IMAGE_NAME }}
458
+ tags: |
459
+ type=ref,event=branch
460
+ type=ref,event=tag
461
+ type=sha,prefix=git-
462
+ type=semver,pattern={{version}}
463
+ type=semver,pattern={{major}}.{{minor}}
464
+ type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=ollama
465
+ flavor: |
466
+ latest=${{ github.ref == 'refs/heads/main' }}
467
+ suffix=-ollama,onlatest=true
468
+
469
+ - name: Create manifest list and push
470
+ working-directory: /tmp/digests
471
+ run: |
472
+ docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
473
+ $(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
474
+
475
+ - name: Inspect image
476
+ run: |
477
+ docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
.github/workflows/format-backend.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Python CI
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ - dev
8
+ pull_request:
9
+ branches:
10
+ - main
11
+ - dev
12
+
13
+ jobs:
14
+ build:
15
+ name: 'Format Backend'
16
+ runs-on: ubuntu-latest
17
+
18
+ strategy:
19
+ matrix:
20
+ python-version: [3.11]
21
+
22
+ steps:
23
+ - uses: actions/checkout@v4
24
+
25
+ - name: Set up Python
26
+ uses: actions/setup-python@v5
27
+ with:
28
+ python-version: ${{ matrix.python-version }}
29
+
30
+ - name: Install dependencies
31
+ run: |
32
+ python -m pip install --upgrade pip
33
+ pip install black
34
+
35
+ - name: Format backend
36
+ run: npm run format:backend
37
+
38
+ - name: Check for changes after format
39
+ run: git diff --exit-code
.github/workflows/format-build-frontend.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Frontend Build
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ - dev
8
+ pull_request:
9
+ branches:
10
+ - main
11
+ - dev
12
+
13
+ jobs:
14
+ build:
15
+ name: 'Format & Build Frontend'
16
+ runs-on: ubuntu-latest
17
+ steps:
18
+ - name: Checkout Repository
19
+ uses: actions/checkout@v4
20
+
21
+ - name: Setup Node.js
22
+ uses: actions/setup-node@v4
23
+ with:
24
+ node-version: '22' # Or specify any other version you want to use
25
+
26
+ - name: Install Dependencies
27
+ run: npm install
28
+
29
+ - name: Format Frontend
30
+ run: npm run format
31
+
32
+ - name: Run i18next
33
+ run: npm run i18n:parse
34
+
35
+ - name: Check for Changes After Format
36
+ run: git diff --exit-code
37
+
38
+ - name: Build Frontend
39
+ run: npm run build
40
+
41
+ test-frontend:
42
+ name: 'Frontend Unit Tests'
43
+ runs-on: ubuntu-latest
44
+ steps:
45
+ - name: Checkout Repository
46
+ uses: actions/checkout@v4
47
+
48
+ - name: Setup Node.js
49
+ uses: actions/setup-node@v4
50
+ with:
51
+ node-version: '22'
52
+
53
+ - name: Install Dependencies
54
+ run: npm ci
55
+
56
+ - name: Run vitest
57
+ run: npm run test:frontend
.github/workflows/integration-test.yml ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Integration Test
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ - dev
8
+ pull_request:
9
+ branches:
10
+ - main
11
+ - dev
12
+
13
+ jobs:
14
+ cypress-run:
15
+ name: Run Cypress Integration Tests
16
+ runs-on: ubuntu-latest
17
+ steps:
18
+ - name: Maximize build space
19
+ uses: AdityaGarg8/remove-unwanted-software@v4.1
20
+ with:
21
+ remove-android: 'true'
22
+ remove-haskell: 'true'
23
+ remove-codeql: 'true'
24
+
25
+ - name: Checkout Repository
26
+ uses: actions/checkout@v4
27
+
28
+ - name: Build and run Compose Stack
29
+ run: |
30
+ docker compose \
31
+ --file docker-compose.yaml \
32
+ --file docker-compose.api.yaml \
33
+ --file docker-compose.a1111-test.yaml \
34
+ up --detach --build
35
+
36
+ - name: Delete Docker build cache
37
+ run: |
38
+ docker builder prune --all --force
39
+
40
+ - name: Wait for Ollama to be up
41
+ timeout-minutes: 5
42
+ run: |
43
+ until curl --output /dev/null --silent --fail http://localhost:11434; do
44
+ printf '.'
45
+ sleep 1
46
+ done
47
+ echo "Service is up!"
48
+
49
+ - name: Preload Ollama model
50
+ run: |
51
+ docker exec ollama ollama pull qwen:0.5b-chat-v1.5-q2_K
52
+
53
+ - name: Cypress run
54
+ uses: cypress-io/github-action@v6
55
+ with:
56
+ browser: chrome
57
+ wait-on: 'http://localhost:3000'
58
+ config: baseUrl=http://localhost:3000
59
+
60
+ - uses: actions/upload-artifact@v4
61
+ if: always()
62
+ name: Upload Cypress videos
63
+ with:
64
+ name: cypress-videos
65
+ path: cypress/videos
66
+ if-no-files-found: ignore
67
+
68
+ - name: Extract Compose logs
69
+ if: always()
70
+ run: |
71
+ docker compose logs > compose-logs.txt
72
+
73
+ - uses: actions/upload-artifact@v4
74
+ if: always()
75
+ name: Upload Compose logs
76
+ with:
77
+ name: compose-logs
78
+ path: compose-logs.txt
79
+ if-no-files-found: ignore
80
+
81
+ # pytest:
82
+ # name: Run Backend Tests
83
+ # runs-on: ubuntu-latest
84
+ # steps:
85
+ # - uses: actions/checkout@v4
86
+
87
+ # - name: Set up Python
88
+ # uses: actions/setup-python@v5
89
+ # with:
90
+ # python-version: ${{ matrix.python-version }}
91
+
92
+ # - name: Install dependencies
93
+ # run: |
94
+ # python -m pip install --upgrade pip
95
+ # pip install -r backend/requirements.txt
96
+
97
+ # - name: pytest run
98
+ # run: |
99
+ # ls -al
100
+ # cd backend
101
+ # PYTHONPATH=. pytest . -o log_cli=true -o log_cli_level=INFO
102
+
103
+ migration_test:
104
+ name: Run Migration Tests
105
+ runs-on: ubuntu-latest
106
+ services:
107
+ postgres:
108
+ image: postgres
109
+ env:
110
+ POSTGRES_PASSWORD: postgres
111
+ options: >-
112
+ --health-cmd pg_isready
113
+ --health-interval 10s
114
+ --health-timeout 5s
115
+ --health-retries 5
116
+ ports:
117
+ - 5432:5432
118
+ # mysql:
119
+ # image: mysql
120
+ # env:
121
+ # MYSQL_ROOT_PASSWORD: mysql
122
+ # MYSQL_DATABASE: mysql
123
+ # options: >-
124
+ # --health-cmd "mysqladmin ping -h localhost"
125
+ # --health-interval 10s
126
+ # --health-timeout 5s
127
+ # --health-retries 5
128
+ # ports:
129
+ # - 3306:3306
130
+ steps:
131
+ - name: Checkout Repository
132
+ uses: actions/checkout@v4
133
+
134
+ - name: Set up Python
135
+ uses: actions/setup-python@v5
136
+ with:
137
+ python-version: ${{ matrix.python-version }}
138
+
139
+ - name: Set up uv
140
+ uses: yezz123/setup-uv@v4
141
+ with:
142
+ uv-venv: venv
143
+
144
+ - name: Activate virtualenv
145
+ run: |
146
+ . venv/bin/activate
147
+ echo PATH=$PATH >> $GITHUB_ENV
148
+
149
+ - name: Install dependencies
150
+ run: |
151
+ uv pip install -r backend/requirements.txt
152
+
153
+ - name: Test backend with SQLite
154
+ id: sqlite
155
+ env:
156
+ WEBUI_SECRET_KEY: secret-key
157
+ GLOBAL_LOG_LEVEL: debug
158
+ run: |
159
+ cd backend
160
+ uvicorn open_webui.main:app --port "8080" --forwarded-allow-ips '*' &
161
+ UVICORN_PID=$!
162
+ # Wait up to 40 seconds for the server to start
163
+ for i in {1..40}; do
164
+ curl -s http://localhost:8080/api/config > /dev/null && break
165
+ sleep 1
166
+ if [ $i -eq 40 ]; then
167
+ echo "Server failed to start"
168
+ kill -9 $UVICORN_PID
169
+ exit 1
170
+ fi
171
+ done
172
+ # Check that the server is still running after 5 seconds
173
+ sleep 5
174
+ if ! kill -0 $UVICORN_PID; then
175
+ echo "Server has stopped"
176
+ exit 1
177
+ fi
178
+
179
+ - name: Test backend with Postgres
180
+ if: success() || steps.sqlite.conclusion == 'failure'
181
+ env:
182
+ WEBUI_SECRET_KEY: secret-key
183
+ GLOBAL_LOG_LEVEL: debug
184
+ DATABASE_URL: postgresql://postgres:postgres@localhost:5432/postgres
185
+ DATABASE_POOL_SIZE: 10
186
+ DATABASE_POOL_MAX_OVERFLOW: 10
187
+ DATABASE_POOL_TIMEOUT: 30
188
+ run: |
189
+ cd backend
190
+ uvicorn open_webui.main:app --port "8081" --forwarded-allow-ips '*' &
191
+ UVICORN_PID=$!
192
+ # Wait up to 20 seconds for the server to start
193
+ for i in {1..20}; do
194
+ curl -s http://localhost:8081/api/config > /dev/null && break
195
+ sleep 1
196
+ if [ $i -eq 20 ]; then
197
+ echo "Server failed to start"
198
+ kill -9 $UVICORN_PID
199
+ exit 1
200
+ fi
201
+ done
202
+ # Check that the server is still running after 5 seconds
203
+ sleep 5
204
+ if ! kill -0 $UVICORN_PID; then
205
+ echo "Server has stopped"
206
+ exit 1
207
+ fi
208
+
209
+ # Check that service will reconnect to postgres when connection will be closed
210
+ status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
211
+ if [[ "$status_code" -ne 200 ]] ; then
212
+ echo "Server has failed before postgres reconnect check"
213
+ exit 1
214
+ fi
215
+
216
+ echo "Terminating all connections to postgres..."
217
+ python -c "import os, psycopg2 as pg2; \
218
+ conn = pg2.connect(dsn=os.environ['DATABASE_URL'].replace('+pool', '')); \
219
+ cur = conn.cursor(); \
220
+ cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
221
+
222
+ status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
223
+ if [[ "$status_code" -ne 200 ]] ; then
224
+ echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
225
+ exit 1
226
+ fi
227
+
228
+ # - name: Test backend with MySQL
229
+ # if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
230
+ # env:
231
+ # WEBUI_SECRET_KEY: secret-key
232
+ # GLOBAL_LOG_LEVEL: debug
233
+ # DATABASE_URL: mysql://root:mysql@localhost:3306/mysql
234
+ # run: |
235
+ # cd backend
236
+ # uvicorn open_webui.main:app --port "8083" --forwarded-allow-ips '*' &
237
+ # UVICORN_PID=$!
238
+ # # Wait up to 20 seconds for the server to start
239
+ # for i in {1..20}; do
240
+ # curl -s http://localhost:8083/api/config > /dev/null && break
241
+ # sleep 1
242
+ # if [ $i -eq 20 ]; then
243
+ # echo "Server failed to start"
244
+ # kill -9 $UVICORN_PID
245
+ # exit 1
246
+ # fi
247
+ # done
248
+ # # Check that the server is still running after 5 seconds
249
+ # sleep 5
250
+ # if ! kill -0 $UVICORN_PID; then
251
+ # echo "Server has stopped"
252
+ # exit 1
253
+ # fi
.github/workflows/lint-backend.disabled ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Python CI
2
+ on:
3
+ push:
4
+ branches: ['main']
5
+ pull_request:
6
+ jobs:
7
+ build:
8
+ name: 'Lint Backend'
9
+ env:
10
+ PUBLIC_API_BASE_URL: ''
11
+ runs-on: ubuntu-latest
12
+ strategy:
13
+ matrix:
14
+ node-version:
15
+ - latest
16
+ steps:
17
+ - uses: actions/checkout@v4
18
+ - name: Use Python
19
+ uses: actions/setup-python@v5
20
+ - name: Use Bun
21
+ uses: oven-sh/setup-bun@v1
22
+ - name: Install dependencies
23
+ run: |
24
+ python -m pip install --upgrade pip
25
+ pip install pylint
26
+ - name: Lint backend
27
+ run: bun run lint:backend
.github/workflows/lint-frontend.disabled ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Bun CI
2
+ on:
3
+ push:
4
+ branches: ['main']
5
+ pull_request:
6
+ jobs:
7
+ build:
8
+ name: 'Lint Frontend'
9
+ env:
10
+ PUBLIC_API_BASE_URL: ''
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v4
14
+ - name: Use Bun
15
+ uses: oven-sh/setup-bun@v1
16
+ - run: bun --version
17
+ - name: Install frontend dependencies
18
+ run: bun install --frozen-lockfile
19
+ - run: bun run lint:frontend
20
+ - run: bun run lint:types
21
+ if: success() || failure()
.github/workflows/release-pypi.yml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Release to PyPI
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main # or whatever branch you want to use
7
+ - pypi-release
8
+
9
+ jobs:
10
+ release:
11
+ runs-on: ubuntu-latest
12
+ environment:
13
+ name: pypi
14
+ url: https://pypi.org/p/open-webui
15
+ permissions:
16
+ id-token: write
17
+ steps:
18
+ - name: Checkout repository
19
+ uses: actions/checkout@v4
20
+ - uses: actions/setup-node@v4
21
+ with:
22
+ node-version: 18
23
+ - uses: actions/setup-python@v5
24
+ with:
25
+ python-version: 3.11
26
+ - name: Build
27
+ run: |
28
+ python -m pip install --upgrade pip
29
+ pip install build
30
+ python -m build .
31
+ - name: Publish package distributions to PyPI
32
+ uses: pypa/gh-action-pypi-publish@release/v1
.gitignore ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ node_modules
3
+ /build
4
+ /.svelte-kit
5
+ /package
6
+ .env
7
+ .env.*
8
+ !.env.example
9
+ vite.config.js.timestamp-*
10
+ vite.config.ts.timestamp-*
11
+ # Byte-compiled / optimized / DLL files
12
+ __pycache__/
13
+ *.py[cod]
14
+ *$py.class
15
+
16
+ # C extensions
17
+ *.so
18
+
19
+ # Pyodide distribution
20
+ static/pyodide/*
21
+ !static/pyodide/pyodide-lock.json
22
+
23
+ # Distribution / packaging
24
+ .Python
25
+ build/
26
+ develop-eggs/
27
+ dist/
28
+ downloads/
29
+ eggs/
30
+ .eggs/
31
+ lib64/
32
+ parts/
33
+ sdist/
34
+ var/
35
+ wheels/
36
+ share/python-wheels/
37
+ *.egg-info/
38
+ .installed.cfg
39
+ *.egg
40
+ MANIFEST
41
+
42
+ # PyInstaller
43
+ # Usually these files are written by a python script from a template
44
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
45
+ *.manifest
46
+ *.spec
47
+
48
+ # Installer logs
49
+ pip-log.txt
50
+ pip-delete-this-directory.txt
51
+
52
+ # Unit test / coverage reports
53
+ htmlcov/
54
+ .tox/
55
+ .nox/
56
+ .coverage
57
+ .coverage.*
58
+ .cache
59
+ nosetests.xml
60
+ coverage.xml
61
+ *.cover
62
+ *.py,cover
63
+ .hypothesis/
64
+ .pytest_cache/
65
+ cover/
66
+
67
+ # Translations
68
+ *.mo
69
+ *.pot
70
+
71
+ # Django stuff:
72
+ *.log
73
+ local_settings.py
74
+ db.sqlite3
75
+ db.sqlite3-journal
76
+
77
+ # Flask stuff:
78
+ instance/
79
+ .webassets-cache
80
+
81
+ # Scrapy stuff:
82
+ .scrapy
83
+
84
+ # Sphinx documentation
85
+ docs/_build/
86
+
87
+ # PyBuilder
88
+ .pybuilder/
89
+ target/
90
+
91
+ # Jupyter Notebook
92
+ .ipynb_checkpoints
93
+
94
+ # IPython
95
+ profile_default/
96
+ ipython_config.py
97
+
98
+ # pyenv
99
+ # For a library or package, you might want to ignore these files since the code is
100
+ # intended to run in multiple environments; otherwise, check them in:
101
+ # .python-version
102
+
103
+ # pipenv
104
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
105
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
106
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
107
+ # install all needed dependencies.
108
+ #Pipfile.lock
109
+
110
+ # poetry
111
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
112
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
113
+ # commonly ignored for libraries.
114
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
115
+ #poetry.lock
116
+
117
+ # pdm
118
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
119
+ #pdm.lock
120
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
121
+ # in version control.
122
+ # https://pdm.fming.dev/#use-with-ide
123
+ .pdm.toml
124
+
125
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
126
+ __pypackages__/
127
+
128
+ # Celery stuff
129
+ celerybeat-schedule
130
+ celerybeat.pid
131
+
132
+ # SageMath parsed files
133
+ *.sage.py
134
+
135
+ # Environments
136
+ .env
137
+ .venv
138
+ env/
139
+ venv/
140
+ ENV/
141
+ env.bak/
142
+ venv.bak/
143
+
144
+ # Spyder project settings
145
+ .spyderproject
146
+ .spyproject
147
+
148
+ # Rope project settings
149
+ .ropeproject
150
+
151
+ # mkdocs documentation
152
+ /site
153
+
154
+ # mypy
155
+ .mypy_cache/
156
+ .dmypy.json
157
+ dmypy.json
158
+
159
+ # Pyre type checker
160
+ .pyre/
161
+
162
+ # pytype static type analyzer
163
+ .pytype/
164
+
165
+ # Cython debug symbols
166
+ cython_debug/
167
+
168
+ # PyCharm
169
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
170
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
171
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
172
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
173
+ .idea/
174
+
175
+ # Logs
176
+ logs
177
+ *.log
178
+ npm-debug.log*
179
+ yarn-debug.log*
180
+ yarn-error.log*
181
+ lerna-debug.log*
182
+ .pnpm-debug.log*
183
+
184
+ # Diagnostic reports (https://nodejs.org/api/report.html)
185
+ report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
186
+
187
+ # Runtime data
188
+ pids
189
+ *.pid
190
+ *.seed
191
+ *.pid.lock
192
+
193
+ # Directory for instrumented libs generated by jscoverage/JSCover
194
+ lib-cov
195
+
196
+ # Coverage directory used by tools like istanbul
197
+ coverage
198
+ *.lcov
199
+
200
+ # nyc test coverage
201
+ .nyc_output
202
+
203
+ # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
204
+ .grunt
205
+
206
+ # Bower dependency directory (https://bower.io/)
207
+ bower_components
208
+
209
+ # node-waf configuration
210
+ .lock-wscript
211
+
212
+ # Compiled binary addons (https://nodejs.org/api/addons.html)
213
+ build/Release
214
+
215
+ # Dependency directories
216
+ node_modules/
217
+ jspm_packages/
218
+
219
+ # Snowpack dependency directory (https://snowpack.dev/)
220
+ web_modules/
221
+
222
+ # TypeScript cache
223
+ *.tsbuildinfo
224
+
225
+ # Optional npm cache directory
226
+ .npm
227
+
228
+ # Optional eslint cache
229
+ .eslintcache
230
+
231
+ # Optional stylelint cache
232
+ .stylelintcache
233
+
234
+ # Microbundle cache
235
+ .rpt2_cache/
236
+ .rts2_cache_cjs/
237
+ .rts2_cache_es/
238
+ .rts2_cache_umd/
239
+
240
+ # Optional REPL history
241
+ .node_repl_history
242
+
243
+ # Output of 'npm pack'
244
+ *.tgz
245
+
246
+ # Yarn Integrity file
247
+ .yarn-integrity
248
+
249
+ # dotenv environment variable files
250
+ .env
251
+ .env.development.local
252
+ .env.test.local
253
+ .env.production.local
254
+ .env.local
255
+
256
+ # parcel-bundler cache (https://parceljs.org/)
257
+ .cache
258
+ .parcel-cache
259
+
260
+ # Next.js build output
261
+ .next
262
+ out
263
+
264
+ # Nuxt.js build / generate output
265
+ .nuxt
266
+ dist
267
+
268
+ # Gatsby files
269
+ .cache/
270
+ # Comment in the public line in if your project uses Gatsby and not Next.js
271
+ # https://nextjs.org/blog/next-9-1#public-directory-support
272
+ # public
273
+
274
+ # vuepress build output
275
+ .vuepress/dist
276
+
277
+ # vuepress v2.x temp and cache directory
278
+ .temp
279
+ .cache
280
+
281
+ # Docusaurus cache and generated files
282
+ .docusaurus
283
+
284
+ # Serverless directories
285
+ .serverless/
286
+
287
+ # FuseBox cache
288
+ .fusebox/
289
+
290
+ # DynamoDB Local files
291
+ .dynamodb/
292
+
293
+ # TernJS port file
294
+ .tern-port
295
+
296
+ # Stores VSCode versions used for testing VSCode extensions
297
+ .vscode-test
298
+
299
+ # yarn v2
300
+ .yarn/cache
301
+ .yarn/unplugged
302
+ .yarn/build-state.yml
303
+ .yarn/install-state.gz
304
+ .pnp.*
305
+
306
+ # cypress artifacts
307
+ cypress/videos
308
+ cypress/screenshots
309
+ .vscode/settings.json
.npmrc ADDED
@@ -0,0 +1 @@
 
 
1
+ engine-strict=true
.prettierignore ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore files for PNPM, NPM and YARN
2
+ pnpm-lock.yaml
3
+ package-lock.json
4
+ yarn.lock
5
+
6
+ kubernetes/
7
+
8
+ # Copy of .gitignore
9
+ .DS_Store
10
+ node_modules
11
+ /build
12
+ /.svelte-kit
13
+ /package
14
+ .env
15
+ .env.*
16
+ !.env.example
17
+ vite.config.js.timestamp-*
18
+ vite.config.ts.timestamp-*
19
+ # Byte-compiled / optimized / DLL files
20
+ __pycache__/
21
+ *.py[cod]
22
+ *$py.class
23
+
24
+ # C extensions
25
+ *.so
26
+
27
+ # Distribution / packaging
28
+ .Python
29
+ build/
30
+ develop-eggs/
31
+ dist/
32
+ downloads/
33
+ eggs/
34
+ .eggs/
35
+ lib64/
36
+ parts/
37
+ sdist/
38
+ var/
39
+ wheels/
40
+ share/python-wheels/
41
+ *.egg-info/
42
+ .installed.cfg
43
+ *.egg
44
+ MANIFEST
45
+
46
+ # PyInstaller
47
+ # Usually these files are written by a python script from a template
48
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
+ *.manifest
50
+ *.spec
51
+
52
+ # Installer logs
53
+ pip-log.txt
54
+ pip-delete-this-directory.txt
55
+
56
+ # Unit test / coverage reports
57
+ htmlcov/
58
+ .tox/
59
+ .nox/
60
+ .coverage
61
+ .coverage.*
62
+ .cache
63
+ nosetests.xml
64
+ coverage.xml
65
+ *.cover
66
+ *.py,cover
67
+ .hypothesis/
68
+ .pytest_cache/
69
+ cover/
70
+
71
+ # Translations
72
+ *.mo
73
+ *.pot
74
+
75
+ # Django stuff:
76
+ *.log
77
+ local_settings.py
78
+ db.sqlite3
79
+ db.sqlite3-journal
80
+
81
+ # Flask stuff:
82
+ instance/
83
+ .webassets-cache
84
+
85
+ # Scrapy stuff:
86
+ .scrapy
87
+
88
+ # Sphinx documentation
89
+ docs/_build/
90
+
91
+ # PyBuilder
92
+ .pybuilder/
93
+ target/
94
+
95
+ # Jupyter Notebook
96
+ .ipynb_checkpoints
97
+
98
+ # IPython
99
+ profile_default/
100
+ ipython_config.py
101
+
102
+ # pyenv
103
+ # For a library or package, you might want to ignore these files since the code is
104
+ # intended to run in multiple environments; otherwise, check them in:
105
+ # .python-version
106
+
107
+ # pipenv
108
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
110
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
111
+ # install all needed dependencies.
112
+ #Pipfile.lock
113
+
114
+ # poetry
115
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
116
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
117
+ # commonly ignored for libraries.
118
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
119
+ #poetry.lock
120
+
121
+ # pdm
122
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
123
+ #pdm.lock
124
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
125
+ # in version control.
126
+ # https://pdm.fming.dev/#use-with-ide
127
+ .pdm.toml
128
+
129
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
130
+ __pypackages__/
131
+
132
+ # Celery stuff
133
+ celerybeat-schedule
134
+ celerybeat.pid
135
+
136
+ # SageMath parsed files
137
+ *.sage.py
138
+
139
+ # Environments
140
+ .env
141
+ .venv
142
+ env/
143
+ venv/
144
+ ENV/
145
+ env.bak/
146
+ venv.bak/
147
+
148
+ # Spyder project settings
149
+ .spyderproject
150
+ .spyproject
151
+
152
+ # Rope project settings
153
+ .ropeproject
154
+
155
+ # mkdocs documentation
156
+ /site
157
+
158
+ # mypy
159
+ .mypy_cache/
160
+ .dmypy.json
161
+ dmypy.json
162
+
163
+ # Pyre type checker
164
+ .pyre/
165
+
166
+ # pytype static type analyzer
167
+ .pytype/
168
+
169
+ # Cython debug symbols
170
+ cython_debug/
171
+
172
+ # PyCharm
173
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
174
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
175
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
176
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
177
+ .idea/
178
+
179
+ # Logs
180
+ logs
181
+ *.log
182
+ npm-debug.log*
183
+ yarn-debug.log*
184
+ yarn-error.log*
185
+ lerna-debug.log*
186
+ .pnpm-debug.log*
187
+
188
+ # Diagnostic reports (https://nodejs.org/api/report.html)
189
+ report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
190
+
191
+ # Runtime data
192
+ pids
193
+ *.pid
194
+ *.seed
195
+ *.pid.lock
196
+
197
+ # Directory for instrumented libs generated by jscoverage/JSCover
198
+ lib-cov
199
+
200
+ # Coverage directory used by tools like istanbul
201
+ coverage
202
+ *.lcov
203
+
204
+ # nyc test coverage
205
+ .nyc_output
206
+
207
+ # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
208
+ .grunt
209
+
210
+ # Bower dependency directory (https://bower.io/)
211
+ bower_components
212
+
213
+ # node-waf configuration
214
+ .lock-wscript
215
+
216
+ # Compiled binary addons (https://nodejs.org/api/addons.html)
217
+ build/Release
218
+
219
+ # Dependency directories
220
+ node_modules/
221
+ jspm_packages/
222
+
223
+ # Snowpack dependency directory (https://snowpack.dev/)
224
+ web_modules/
225
+
226
+ # TypeScript cache
227
+ *.tsbuildinfo
228
+
229
+ # Optional npm cache directory
230
+ .npm
231
+
232
+ # Optional eslint cache
233
+ .eslintcache
234
+
235
+ # Optional stylelint cache
236
+ .stylelintcache
237
+
238
+ # Microbundle cache
239
+ .rpt2_cache/
240
+ .rts2_cache_cjs/
241
+ .rts2_cache_es/
242
+ .rts2_cache_umd/
243
+
244
+ # Optional REPL history
245
+ .node_repl_history
246
+
247
+ # Output of 'npm pack'
248
+ *.tgz
249
+
250
+ # Yarn Integrity file
251
+ .yarn-integrity
252
+
253
+ # dotenv environment variable files
254
+ .env
255
+ .env.development.local
256
+ .env.test.local
257
+ .env.production.local
258
+ .env.local
259
+
260
+ # parcel-bundler cache (https://parceljs.org/)
261
+ .cache
262
+ .parcel-cache
263
+
264
+ # Next.js build output
265
+ .next
266
+ out
267
+
268
+ # Nuxt.js build / generate output
269
+ .nuxt
270
+ dist
271
+
272
+ # Gatsby files
273
+ .cache/
274
+ # Comment in the public line in if your project uses Gatsby and not Next.js
275
+ # https://nextjs.org/blog/next-9-1#public-directory-support
276
+ # public
277
+
278
+ # vuepress build output
279
+ .vuepress/dist
280
+
281
+ # vuepress v2.x temp and cache directory
282
+ .temp
283
+ .cache
284
+
285
+ # Docusaurus cache and generated files
286
+ .docusaurus
287
+
288
+ # Serverless directories
289
+ .serverless/
290
+
291
+ # FuseBox cache
292
+ .fusebox/
293
+
294
+ # DynamoDB Local files
295
+ .dynamodb/
296
+
297
+ # TernJS port file
298
+ .tern-port
299
+
300
+ # Stores VSCode versions used for testing VSCode extensions
301
+ .vscode-test
302
+
303
+ # yarn v2
304
+ .yarn/cache
305
+ .yarn/unplugged
306
+ .yarn/build-state.yml
307
+ .yarn/install-state.gz
308
+ .pnp.*
309
+
310
+ # cypress artifacts
311
+ cypress/videos
312
+ cypress/screenshots
313
+
314
+
315
+
316
+ /static/*
.prettierrc ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "useTabs": true,
3
+ "singleQuote": true,
4
+ "trailingComma": "none",
5
+ "printWidth": 100,
6
+ "plugins": ["prettier-plugin-svelte"],
7
+ "pluginSearchDirs": ["."],
8
+ "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }]
9
+ }
CHANGELOG.md ADDED
The diff for this file is too large to render. See raw diff
 
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ As members, contributors, and leaders of this community, we pledge to make participation in our open-source project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
6
+
7
+ We are committed to creating and maintaining an open, respectful, and professional environment where positive contributions and meaningful discussions can flourish. By participating in this project, you agree to uphold these values and align your behavior to the standards outlined in this Code of Conduct.
8
+
9
+ ## Why These Standards Are Important
10
+
11
+ Open-source projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
12
+
13
+ Maintaining a positive and respectful environment is essential to safeguarding the integrity of this project and protecting contributors' efforts. Behavior that disrupts this atmosphere—whether through hostility, entitlement, or unprofessional conduct—can severely harm the morale and productivity of the community. **Strict enforcement of these standards ensures a safe and supportive space for meaningful collaboration.**
14
+
15
+ This is a community where **respect and professionalism are mandatory.** Violations of these standards will result in **zero tolerance** and immediate enforcement to prevent disruption and ensure the well-being of all participants.
16
+
17
+ ## Our Standards
18
+
19
+ Examples of behavior that contribute to a positive and professional community include:
20
+
21
+ - **Respecting others.** Be considerate, listen actively, and engage with empathy toward others' viewpoints and experiences.
22
+ - **Constructive feedback.** Provide actionable, thoughtful, and respectful feedback that helps improve the project and encourages collaboration. Avoid unproductive negativity or hypercriticism.
23
+ - **Recognizing volunteer contributions.** Appreciate that contributors dedicate their free time and resources selflessly. Approach them with gratitude and patience.
24
+ - **Focusing on shared goals.** Collaborate in ways that prioritize the health, success, and sustainability of the community over individual agendas.
25
+
26
+ Examples of unacceptable behavior include:
27
+
28
+ - The use of discriminatory, demeaning, or sexualized language or behavior.
29
+ - Personal attacks, derogatory comments, trolling, or inflammatory political or ideological arguments.
30
+ - Harassment, intimidation, or any behavior intended to create a hostile, uncomfortable, or unsafe environment.
31
+ - Publishing others' private information (e.g., physical or email addresses) without explicit permission.
32
+ - **Entitlement, demand, or aggression toward contributors.** Volunteers are under no obligation to provide immediate or personalized support. Rude or dismissive behavior will not be tolerated.
33
+ - **Unproductive or destructive behavior.** This includes venting frustration as hostility ("tantrums"), hypercriticism, attention-seeking negativity, or anything that distracts from the project's goals.
34
+ - **Spamming and promotional exploitation.** Sharing irrelevant product promotions or self-promotion in the community is not allowed unless it directly contributes value to the discussion.
35
+
36
+ ### Feedback and Community Engagement
37
+
38
+ - **Constructive feedback is encouraged, but hostile or entitled behavior will result in immediate action.** If you disagree with elements of the project, we encourage you to offer meaningful improvements or fork the project if necessary. Healthy discussions and technical disagreements are welcome only when handled with professionalism.
39
+ - **Respect contributors' time and efforts.** No one is entitled to personalized or on-demand assistance. This is a community built on collaboration and shared effort; demanding or demeaning behavior undermines that trust and will not be allowed.
40
+
41
+ ### Zero Tolerance: No Warnings, Immediate Action
42
+
43
+ This community operates under a **zero-tolerance policy.** Any behavior deemed unacceptable under this Code of Conduct will result in **immediate enforcement, without prior warning.**
44
+
45
+ We employ this approach to ensure that unproductive or disruptive behavior does not escalate further or cause unnecessary harm to other contributors. The standards are clear, and violations of any kind—whether mild or severe—will be addressed decisively to protect the community.
46
+
47
+ ## Enforcement Responsibilities
48
+
49
+ Community leaders are responsible for upholding and enforcing these standards. They are empowered to take **immediate and appropriate action** to address any behaviors they deem unacceptable under this Code of Conduct. These actions are taken with the goal of protecting the community and preserving its safe, positive, and productive environment.
50
+
51
+ ## Scope
52
+
53
+ This Code of Conduct applies to all community spaces, including forums, repositories, social media accounts, and in-person events. It also applies when an individual represents the community in public settings, such as conferences or official communications.
54
+
55
+ Additionally, any behavior outside of these defined spaces that negatively impacts the community or its members may fall within the scope of this Code of Conduct.
56
+
57
+ ## Reporting Violations
58
+
59
+ Instances of unacceptable behavior can be reported to the leadership team at **hello@openwebui.com**. Reports will be handled promptly, confidentially, and with consideration for the safety and well-being of the reporter.
60
+
61
+ All community leaders are required to uphold confidentiality and impartiality when addressing reports of violations.
62
+
63
+ ## Enforcement Guidelines
64
+
65
+ ### Ban
66
+
67
+ **Community Impact**: Community leaders will issue a ban to any participant whose behavior is deemed unacceptable according to this Code of Conduct. Bans are enforced immediately and without prior notice.
68
+
69
+ A ban may be temporary or permanent, depending on the severity of the violation. This includes—but is not limited to—behavior such as:
70
+
71
+ - Harassment or abusive behavior toward contributors.
72
+ - Persistent negativity or hostility that disrupts the collaborative environment.
73
+ - Disrespectful, demanding, or aggressive interactions with others.
74
+ - Attempts to cause harm or sabotage the community.
75
+
76
+ **Consequence**: A banned individual is immediately removed from access to all community spaces, communication channels, and events. Community leaders reserve the right to enforce either a time-limited suspension or a permanent ban based on the specific circumstances of the violation.
77
+
78
+ This approach ensures that disruptive behaviors are addressed swiftly and decisively in order to maintain the integrity and productivity of the community.
79
+
80
+ ## Why Zero Tolerance Is Necessary
81
+
82
+ Open-source projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
83
+
84
+ By enforcing a zero-tolerance policy, we ensure that the community remains a safe, welcoming space for all participants. These measures are not about harshness—they are about protecting contributors and fostering a productive environment where innovation can thrive.
85
+
86
+ Our expectations are clear, and our enforcement reflects our commitment to this project's long-term success.
87
+
88
+ ## Attribution
89
+
90
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at
91
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
92
+
93
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
94
+
95
+ [homepage]: https://www.contributor-covenant.org
96
+
97
+ For answers to common questions about this code of conduct, see the FAQ at
98
+ https://www.contributor-covenant.org/faq. Translations are available at
99
+ https://www.contributor-covenant.org/translations.
Caddyfile.localhost ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run with
2
+ # caddy run --envfile ./example.env --config ./Caddyfile.localhost
3
+ #
4
+ # This is configured for
5
+ # - Automatic HTTPS (even for localhost)
6
+ # - Reverse Proxying to Ollama API Base URL (http://localhost:11434/api)
7
+ # - CORS
8
+ # - HTTP Basic Auth API Tokens (uncomment basicauth section)
9
+
10
+
11
+ # CORS Preflight (OPTIONS) + Request (GET, POST, PATCH, PUT, DELETE)
12
+ (cors-api) {
13
+ @match-cors-api-preflight method OPTIONS
14
+ handle @match-cors-api-preflight {
15
+ header {
16
+ Access-Control-Allow-Origin "{http.request.header.origin}"
17
+ Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
18
+ Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
19
+ Access-Control-Allow-Credentials "true"
20
+ Access-Control-Max-Age "3600"
21
+ defer
22
+ }
23
+ respond "" 204
24
+ }
25
+
26
+ @match-cors-api-request {
27
+ not {
28
+ header Origin "{http.request.scheme}://{http.request.host}"
29
+ }
30
+ header Origin "{http.request.header.origin}"
31
+ }
32
+ handle @match-cors-api-request {
33
+ header {
34
+ Access-Control-Allow-Origin "{http.request.header.origin}"
35
+ Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
36
+ Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
37
+ Access-Control-Allow-Credentials "true"
38
+ Access-Control-Max-Age "3600"
39
+ defer
40
+ }
41
+ }
42
+ }
43
+
44
+ # replace localhost with example.com or whatever
45
+ localhost {
46
+ ## HTTP Basic Auth
47
+ ## (uncomment to enable)
48
+ # basicauth {
49
+ # # see .example.env for how to generate tokens
50
+ # {env.OLLAMA_API_ID} {env.OLLAMA_API_TOKEN_DIGEST}
51
+ # }
52
+
53
+ handle /api/* {
54
+ # Comment to disable CORS
55
+ import cors-api
56
+
57
+ reverse_proxy localhost:11434
58
+ }
59
+
60
+ # Same-Origin Static Web Server
61
+ file_server {
62
+ root ./build/
63
+ }
64
+ }
Dockerfile ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # syntax=docker/dockerfile:1
2
+ # Initialize device type args
3
+ # use build args in the docker build command with --build-arg="BUILDARG=true"
4
+ ARG USE_CUDA=false
5
+ ARG USE_OLLAMA=false
6
+ # Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
7
+ ARG USE_CUDA_VER=cu121
8
+ # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
9
+ # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
10
+ # for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
11
+ # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
12
+ ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
13
+ ARG USE_RERANKING_MODEL=""
14
+
15
+ # Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
16
+ ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
17
+
18
+ ARG BUILD_HASH=dev-build
19
+ # Override at your own risk - non-root configurations are untested
20
+ ARG UID=0
21
+ ARG GID=0
22
+
23
+ ######## WebUI frontend ########
24
+ FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
25
+ ARG BUILD_HASH
26
+
27
+ WORKDIR /app
28
+
29
+ COPY package.json package-lock.json ./
30
+ RUN npm ci
31
+
32
+ COPY . .
33
+ ENV APP_BUILD_HASH=${BUILD_HASH}
34
+ RUN npm run build
35
+
36
+ ######## WebUI backend ########
37
+ FROM python:3.11-slim-bookworm AS base
38
+
39
+ # Use args
40
+ ARG USE_CUDA
41
+ ARG USE_OLLAMA
42
+ ARG USE_CUDA_VER
43
+ ARG USE_EMBEDDING_MODEL
44
+ ARG USE_RERANKING_MODEL
45
+ ARG UID
46
+ ARG GID
47
+
48
+ ## Basis ##
49
+ ENV ENV=prod \
50
+ PORT=8080 \
51
+ # pass build args to the build
52
+ USE_OLLAMA_DOCKER=${USE_OLLAMA} \
53
+ USE_CUDA_DOCKER=${USE_CUDA} \
54
+ USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
55
+ USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
56
+ USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
57
+
58
+ ## Basis URL Config ##
59
+ ENV OLLAMA_BASE_URL="/ollama" \
60
+ OPENAI_API_BASE_URL=""
61
+
62
+ ## API Key and Security Config ##
63
+ ENV OPENAI_API_KEY="" \
64
+ WEBUI_SECRET_KEY="" \
65
+ SCARF_NO_ANALYTICS=true \
66
+ DO_NOT_TRACK=true \
67
+ ANONYMIZED_TELEMETRY=false
68
+
69
+ #### Other models #########################################################
70
+ ## whisper TTS model settings ##
71
+ ENV WHISPER_MODEL="base" \
72
+ WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
73
+
74
+ ## RAG Embedding model settings ##
75
+ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
76
+ RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
77
+ SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
78
+
79
+ ## Tiktoken model settings ##
80
+ ENV TIKTOKEN_ENCODING_NAME="cl100k_base" \
81
+ TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
82
+
83
+ ## Hugging Face download cache ##
84
+ ENV HF_HOME="/app/backend/data/cache/embedding/models"
85
+
86
+ ## Torch Extensions ##
87
+ # ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions"
88
+
89
+ #### Other models ##########################################################
90
+
91
+ WORKDIR /app/backend
92
+
93
+ ENV HOME=/root
94
+ # Create user and group if not root
95
+ RUN if [ $UID -ne 0 ]; then \
96
+ if [ $GID -ne 0 ]; then \
97
+ addgroup --gid $GID app; \
98
+ fi; \
99
+ adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
100
+ fi
101
+
102
+ RUN mkdir -p $HOME/.cache/chroma
103
+ RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
104
+
105
+ # Make sure the user has access to the app and root directory
106
+ RUN chown -R $UID:$GID /app $HOME
107
+
108
+ RUN if [ "$USE_OLLAMA" = "true" ]; then \
109
+ apt-get update && \
110
+ # Install pandoc and netcat
111
+ apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
112
+ apt-get install -y --no-install-recommends gcc python3-dev && \
113
+ # for RAG OCR
114
+ apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
115
+ # install helper tools
116
+ apt-get install -y --no-install-recommends curl jq && \
117
+ # install ollama
118
+ curl -fsSL https://ollama.com/install.sh | sh && \
119
+ # cleanup
120
+ rm -rf /var/lib/apt/lists/*; \
121
+ else \
122
+ apt-get update && \
123
+ # Install pandoc, netcat and gcc
124
+ apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
125
+ apt-get install -y --no-install-recommends gcc python3-dev && \
126
+ # for RAG OCR
127
+ apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
128
+ # cleanup
129
+ rm -rf /var/lib/apt/lists/*; \
130
+ fi
131
+
132
+ # install python dependencies
133
+ COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
134
+
135
+ RUN pip3 install uv && \
136
+ if [ "$USE_CUDA" = "true" ]; then \
137
+ # If you use CUDA the whisper and embedding model will be downloaded on first use
138
+ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
139
+ uv pip install --system -r requirements.txt --no-cache-dir && \
140
+ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
141
+ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
142
+ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
143
+ else \
144
+ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
145
+ uv pip install --system -r requirements.txt --no-cache-dir && \
146
+ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
147
+ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
148
+ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
149
+ fi; \
150
+ chown -R $UID:$GID /app/backend/data/
151
+
152
+
153
+
154
+ # copy embedding weight from build
155
+ # RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
156
+ # COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
157
+
158
+ # copy built frontend files
159
+ COPY --chown=$UID:$GID --from=build /app/build /app/build
160
+ COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md
161
+ COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json
162
+
163
+ # copy backend files
164
+ COPY --chown=$UID:$GID ./backend .
165
+
166
+ EXPOSE 8080
167
+
168
+ HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
169
+
170
+ USER $UID:$GID
171
+
172
+ ARG BUILD_HASH
173
+ ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
174
+ ENV DOCKER=true
175
+
176
+ CMD [ "bash", "start.sh"]
INSTALLATION.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Installing Both Ollama and Open WebUI Using Kustomize
2
+
3
+ For cpu-only pod
4
+
5
+ ```bash
6
+ kubectl apply -f ./kubernetes/manifest/base
7
+ ```
8
+
9
+ For gpu-enabled pod
10
+
11
+ ```bash
12
+ kubectl apply -k ./kubernetes/manifest
13
+ ```
14
+
15
+ ### Installing Both Ollama and Open WebUI Using Helm
16
+
17
+ Package Helm file first
18
+
19
+ ```bash
20
+ helm package ./kubernetes/helm/
21
+ ```
22
+
23
+ For cpu-only pod
24
+
25
+ ```bash
26
+ helm install ollama-webui ./ollama-webui-*.tgz
27
+ ```
28
+
29
+ For gpu-enabled pod
30
+
31
+ ```bash
32
+ helm install ollama-webui ./ollama-webui-*.tgz --set ollama.resources.limits.nvidia.com/gpu="1"
33
+ ```
34
+
35
+ Check the `kubernetes/helm/values.yaml` file to know which parameters are available for customization
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Timothy Jaeryang Baek
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Makefile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ifneq ($(shell which docker-compose 2>/dev/null),)
3
+ DOCKER_COMPOSE := docker-compose
4
+ else
5
+ DOCKER_COMPOSE := docker compose
6
+ endif
7
+
8
+ install:
9
+ $(DOCKER_COMPOSE) up -d
10
+
11
+ remove:
12
+ @chmod +x confirm_remove.sh
13
+ @./confirm_remove.sh
14
+
15
+ start:
16
+ $(DOCKER_COMPOSE) start
17
+ startAndBuild:
18
+ $(DOCKER_COMPOSE) up -d --build
19
+
20
+ stop:
21
+ $(DOCKER_COMPOSE) stop
22
+
23
+ update:
24
+ # Calls the LLM update script
25
+ chmod +x update_ollama_models.sh
26
+ @./update_ollama_models.sh
27
+ @git pull
28
+ $(DOCKER_COMPOSE) down
29
+ # Make sure the ollama-webui container is stopped before rebuilding
30
+ @docker stop open-webui || true
31
+ $(DOCKER_COMPOSE) up --build -d
32
+ $(DOCKER_COMPOSE) start
33
+
README.md ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Open WebUI
3
+ emoji: 🐳
4
+ colorFrom: purple
5
+ colorTo: gray
6
+ sdk: docker
7
+ app_port: 8080
8
+ ---
9
+ # Open WebUI 👋
10
+
11
+ ![GitHub stars](https://img.shields.io/github/stars/open-webui/open-webui?style=social)
12
+ ![GitHub forks](https://img.shields.io/github/forks/open-webui/open-webui?style=social)
13
+ ![GitHub watchers](https://img.shields.io/github/watchers/open-webui/open-webui?style=social)
14
+ ![GitHub repo size](https://img.shields.io/github/repo-size/open-webui/open-webui)
15
+ ![GitHub language count](https://img.shields.io/github/languages/count/open-webui/open-webui)
16
+ ![GitHub top language](https://img.shields.io/github/languages/top/open-webui/open-webui)
17
+ ![GitHub last commit](https://img.shields.io/github/last-commit/open-webui/open-webui?color=red)
18
+ ![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Follama-webui%2Follama-wbui&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)
19
+ [![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s)
20
+ [![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck)
21
+
22
+ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-rich, and user-friendly self-hosted WebUI designed to operate entirely offline. It supports various LLM runners, including Ollama and OpenAI-compatible APIs. For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
23
+
24
+ ![Open WebUI Demo](./demo.gif)
25
+
26
+ ## Key Features of Open WebUI ⭐
27
+
28
+ - 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images.
29
+
30
+ - 🤝 **Ollama/OpenAI API Integration**: Effortlessly integrate OpenAI-compatible APIs for versatile conversations alongside Ollama models. Customize the OpenAI API URL to link with **LMStudio, GroqCloud, Mistral, OpenRouter, and more**.
31
+
32
+ - 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users.
33
+
34
+ - 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
35
+
36
+ - 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.
37
+
38
+ - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
39
+
40
+ - 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
41
+
42
+ - 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
43
+
44
+ - 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
45
+
46
+ - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
47
+
48
+ - 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience.
49
+
50
+ - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
51
+
52
+ - 🎨 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API or ComfyUI (local), and OpenAI's DALL-E (external), enriching your chat experience with dynamic visual content.
53
+
54
+ - ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
55
+
56
+ - 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
57
+
58
+ - 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
59
+
60
+ - 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
61
+
62
+ - 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates, fixes, and new features.
63
+
64
+ Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
65
+
66
+ ## 🔗 Also Check Out Open WebUI Community!
67
+
68
+ Don't forget to explore our sibling project, [Open WebUI Community](https://openwebui.com/), where you can discover, download, and explore customized Modelfiles. Open WebUI Community offers a wide range of exciting possibilities for enhancing your chat interactions with Open WebUI! 🚀
69
+
70
+ ## How to Install 🚀
71
+
72
+ ### Installation via Python pip 🐍
73
+
74
+ Open WebUI can be installed using pip, the Python package installer. Before proceeding, ensure you're using **Python 3.11** to avoid compatibility issues.
75
+
76
+ 1. **Install Open WebUI**:
77
+ Open your terminal and run the following command to install Open WebUI:
78
+
79
+ ```bash
80
+ pip install open-webui
81
+ ```
82
+
83
+ 2. **Running Open WebUI**:
84
+ After installation, you can start Open WebUI by executing:
85
+
86
+ ```bash
87
+ open-webui serve
88
+ ```
89
+
90
+ This will start the Open WebUI server, which you can access at [http://localhost:8080](http://localhost:8080)
91
+
92
+ ### Quick Start with Docker 🐳
93
+
94
+ > [!NOTE]
95
+ > Please note that for certain Docker environments, additional configurations might be needed. If you encounter any connection issues, our detailed guide on [Open WebUI Documentation](https://docs.openwebui.com/) is ready to assist you.
96
+
97
+ > [!WARNING]
98
+ > When using Docker to install Open WebUI, make sure to include the `-v open-webui:/app/backend/data` in your Docker command. This step is crucial as it ensures your database is properly mounted and prevents any loss of data.
99
+
100
+ > [!TIP]
101
+ > If you wish to utilize Open WebUI with Ollama included or CUDA acceleration, we recommend utilizing our official images tagged with either `:cuda` or `:ollama`. To enable CUDA, you must install the [Nvidia CUDA container toolkit](https://docs.nvidia.com/dgx/nvidia-container-runtime-upgrade/) on your Linux/WSL system.
102
+
103
+ ### Installation with Default Configuration
104
+
105
+ - **If Ollama is on your computer**, use this command:
106
+
107
+ ```bash
108
+ docker run -d -p 3000:8080 --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
109
+ ```
110
+
111
+ - **If Ollama is on a Different Server**, use this command:
112
+
113
+ To connect to Ollama on another server, change the `OLLAMA_BASE_URL` to the server's URL:
114
+
115
+ ```bash
116
+ docker run -d -p 3000:8080 -e OLLAMA_BASE_URL=https://example.com -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
117
+ ```
118
+
119
+ - **To run Open WebUI with Nvidia GPU support**, use this command:
120
+
121
+ ```bash
122
+ docker run -d -p 3000:8080 --gpus all --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:cuda
123
+ ```
124
+
125
+ ### Installation for OpenAI API Usage Only
126
+
127
+ - **If you're only using OpenAI API**, use this command:
128
+
129
+ ```bash
130
+ docker run -d -p 3000:8080 -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
131
+ ```
132
+
133
+ ### Installing Open WebUI with Bundled Ollama Support
134
+
135
+ This installation method uses a single container image that bundles Open WebUI with Ollama, allowing for a streamlined setup via a single command. Choose the appropriate command based on your hardware setup:
136
+
137
+ - **With GPU Support**:
138
+ Utilize GPU resources by running the following command:
139
+
140
+ ```bash
141
+ docker run -d -p 3000:8080 --gpus=all -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama
142
+ ```
143
+
144
+ - **For CPU Only**:
145
+ If you're not using a GPU, use this command instead:
146
+
147
+ ```bash
148
+ docker run -d -p 3000:8080 -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama
149
+ ```
150
+
151
+ Both commands facilitate a built-in, hassle-free installation of both Open WebUI and Ollama, ensuring that you can get everything up and running swiftly.
152
+
153
+ After installation, you can access Open WebUI at [http://localhost:3000](http://localhost:3000). Enjoy! 😄
154
+
155
+ ### Other Installation Methods
156
+
157
+ We offer various installation alternatives, including non-Docker native installation methods, Docker Compose, Kustomize, and Helm. Visit our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/) or join our [Discord community](https://discord.gg/5rJgQTnV4s) for comprehensive guidance.
158
+
159
+ ### Troubleshooting
160
+
161
+ Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).
162
+
163
+ #### Open WebUI: Server Connection Error
164
+
165
+ If you're experiencing connection issues, it’s often due to the WebUI docker container not being able to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) inside the container . Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`.
166
+
167
+ **Example Docker Command**:
168
+
169
+ ```bash
170
+ docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
171
+ ```
172
+
173
+ ### Keeping Your Docker Installation Up-to-Date
174
+
175
+ In case you want to update your local Docker installation to the latest version, you can do it with [Watchtower](https://containrrr.dev/watchtower/):
176
+
177
+ ```bash
178
+ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower --run-once open-webui
179
+ ```
180
+
181
+ In the last part of the command, replace `open-webui` with your container name if it is different.
182
+
183
+ Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/tutorials/migration/).
184
+
185
+ ### Using the Dev Branch 🌙
186
+
187
+ > [!WARNING]
188
+ > The `:dev` branch contains the latest unstable features and changes. Use it at your own risk as it may have bugs or incomplete features.
189
+
190
+ If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this:
191
+
192
+ ```bash
193
+ docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --add-host=host.docker.internal:host-gateway --restart always ghcr.io/open-webui/open-webui:dev
194
+ ```
195
+
196
+ ## What's Next? 🌟
197
+
198
+ Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/).
199
+
200
+ ## License 📜
201
+
202
+ This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄
203
+
204
+ ## Support 💬
205
+
206
+ If you have any questions, suggestions, or need assistance, please open an issue or join our
207
+ [Open WebUI Discord community](https://discord.gg/5rJgQTnV4s) to connect with us! 🤝
208
+
209
+ ## Star History
210
+
211
+ <a href="https://star-history.com/#open-webui/open-webui&Date">
212
+ <picture>
213
+ <source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date&theme=dark" />
214
+ <source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date" />
215
+ <img alt="Star History Chart" src="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date" />
216
+ </picture>
217
+ </a>
218
+
219
+ ---
220
+
221
+ Created by [Timothy Jaeryang Baek](https://github.com/tjbck) - Let's make Open WebUI even more amazing together! 💪
TROUBLESHOOTING.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Open WebUI Troubleshooting Guide
2
+
3
+ ## Understanding the Open WebUI Architecture
4
+
5
+ The Open WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues.
6
+
7
+ - **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_BASE_URL` environment variable. Therefore, a request made to `/ollama` in the WebUI is effectively the same as making a request to `OLLAMA_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_BASE_URL/api/tags` in the backend.
8
+
9
+ - **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer.
10
+
11
+ ## Open WebUI: Server Connection Error
12
+
13
+ If you're experiencing connection issues, it’s often due to the WebUI docker container not being able to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) inside the container . Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`.
14
+
15
+ **Example Docker Command**:
16
+
17
+ ```bash
18
+ docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
19
+ ```
20
+
21
+ ### Error on Slow Responses for Ollama
22
+
23
+ Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds.
24
+
25
+ ### General Connection Errors
26
+
27
+ **Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates.
28
+
29
+ **Troubleshooting Steps**:
30
+
31
+ 1. **Verify Ollama URL Format**:
32
+ - When running the Web UI container, ensure the `OLLAMA_BASE_URL` is correctly set. (e.g., `http://192.168.1.1:11434` for different host setups).
33
+ - In the Open WebUI, navigate to "Settings" > "General".
34
+ - Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]` (e.g., `http://localhost:11434`).
35
+
36
+ By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord.
backend/.dockerignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ .env
3
+ _old
4
+ uploads
5
+ .ipynb_checkpoints
6
+ *.db
7
+ _test
8
+ !/data
9
+ /data/*
10
+ !/data/litellm
11
+ /data/litellm/*
12
+ !data/litellm/config.yaml
13
+
14
+ !data/config.json
backend/.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ .env
3
+ _old
4
+ uploads
5
+ .ipynb_checkpoints
6
+ *.db
7
+ _test
8
+ Pipfile
9
+ !/data
10
+ /data/*
11
+ /open_webui/data/*
12
+ .webui_secret_key
backend/dev.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ PORT="${PORT:-8080}"
2
+ uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload
backend/open_webui/__init__.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ import random
4
+ from pathlib import Path
5
+
6
+ import typer
7
+ import uvicorn
8
+
9
+ app = typer.Typer()
10
+
11
+ KEY_FILE = Path.cwd() / ".webui_secret_key"
12
+
13
+
14
+ @app.command()
15
+ def serve(
16
+ host: str = "0.0.0.0",
17
+ port: int = 8080,
18
+ ):
19
+ os.environ["FROM_INIT_PY"] = "true"
20
+ if os.getenv("WEBUI_SECRET_KEY") is None:
21
+ typer.echo(
22
+ "Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
23
+ )
24
+ if not KEY_FILE.exists():
25
+ typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}")
26
+ KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12)))
27
+ typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}")
28
+ os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text()
29
+
30
+ if os.getenv("USE_CUDA_DOCKER", "false") == "true":
31
+ typer.echo(
32
+ "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
33
+ )
34
+ LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":")
35
+ os.environ["LD_LIBRARY_PATH"] = ":".join(
36
+ LD_LIBRARY_PATH
37
+ + [
38
+ "/usr/local/lib/python3.11/site-packages/torch/lib",
39
+ "/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
40
+ ]
41
+ )
42
+ try:
43
+ import torch
44
+
45
+ assert torch.cuda.is_available(), "CUDA not available"
46
+ typer.echo("CUDA seems to be working")
47
+ except Exception as e:
48
+ typer.echo(
49
+ "Error when testing CUDA but USE_CUDA_DOCKER is true. "
50
+ "Resetting USE_CUDA_DOCKER to false and removing "
51
+ f"LD_LIBRARY_PATH modifications: {e}"
52
+ )
53
+ os.environ["USE_CUDA_DOCKER"] = "false"
54
+ os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
55
+
56
+ import open_webui.main # we need set environment variables before importing main
57
+
58
+ uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
59
+
60
+
61
+ @app.command()
62
+ def dev(
63
+ host: str = "0.0.0.0",
64
+ port: int = 8080,
65
+ reload: bool = True,
66
+ ):
67
+ uvicorn.run(
68
+ "open_webui.main:app",
69
+ host=host,
70
+ port=port,
71
+ reload=reload,
72
+ forwarded_allow_ips="*",
73
+ )
74
+
75
+
76
+ if __name__ == "__main__":
77
+ app()
backend/open_webui/alembic.ini ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A generic, single database configuration.
2
+
3
+ [alembic]
4
+ # path to migration scripts
5
+ script_location = migrations
6
+
7
+ # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
8
+ # Uncomment the line below if you want the files to be prepended with date and time
9
+ # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
10
+
11
+ # sys.path path, will be prepended to sys.path if present.
12
+ # defaults to the current working directory.
13
+ prepend_sys_path = .
14
+
15
+ # timezone to use when rendering the date within the migration file
16
+ # as well as the filename.
17
+ # If specified, requires the python>=3.9 or backports.zoneinfo library.
18
+ # Any required deps can installed by adding `alembic[tz]` to the pip requirements
19
+ # string value is passed to ZoneInfo()
20
+ # leave blank for localtime
21
+ # timezone =
22
+
23
+ # max length of characters to apply to the
24
+ # "slug" field
25
+ # truncate_slug_length = 40
26
+
27
+ # set to 'true' to run the environment during
28
+ # the 'revision' command, regardless of autogenerate
29
+ # revision_environment = false
30
+
31
+ # set to 'true' to allow .pyc and .pyo files without
32
+ # a source .py file to be detected as revisions in the
33
+ # versions/ directory
34
+ # sourceless = false
35
+
36
+ # version location specification; This defaults
37
+ # to migrations/versions. When using multiple version
38
+ # directories, initial revisions must be specified with --version-path.
39
+ # The path separator used here should be the separator specified by "version_path_separator" below.
40
+ # version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
41
+
42
+ # version path separator; As mentioned above, this is the character used to split
43
+ # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
44
+ # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
45
+ # Valid values for version_path_separator are:
46
+ #
47
+ # version_path_separator = :
48
+ # version_path_separator = ;
49
+ # version_path_separator = space
50
+ version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
51
+
52
+ # set to 'true' to search source files recursively
53
+ # in each "version_locations" directory
54
+ # new in Alembic version 1.10
55
+ # recursive_version_locations = false
56
+
57
+ # the output encoding used when revision files
58
+ # are written from script.py.mako
59
+ # output_encoding = utf-8
60
+
61
+ # sqlalchemy.url = REPLACE_WITH_DATABASE_URL
62
+
63
+
64
+ [post_write_hooks]
65
+ # post_write_hooks defines scripts or Python functions that are run
66
+ # on newly generated revision scripts. See the documentation for further
67
+ # detail and examples
68
+
69
+ # format using "black" - use the console_scripts runner, against the "black" entrypoint
70
+ # hooks = black
71
+ # black.type = console_scripts
72
+ # black.entrypoint = black
73
+ # black.options = -l 79 REVISION_SCRIPT_FILENAME
74
+
75
+ # lint with attempts to fix using "ruff" - use the exec runner, execute a binary
76
+ # hooks = ruff
77
+ # ruff.type = exec
78
+ # ruff.executable = %(here)s/.venv/bin/ruff
79
+ # ruff.options = --fix REVISION_SCRIPT_FILENAME
80
+
81
+ # Logging configuration
82
+ [loggers]
83
+ keys = root,sqlalchemy,alembic
84
+
85
+ [handlers]
86
+ keys = console
87
+
88
+ [formatters]
89
+ keys = generic
90
+
91
+ [logger_root]
92
+ level = WARN
93
+ handlers = console
94
+ qualname =
95
+
96
+ [logger_sqlalchemy]
97
+ level = WARN
98
+ handlers =
99
+ qualname = sqlalchemy.engine
100
+
101
+ [logger_alembic]
102
+ level = INFO
103
+ handlers =
104
+ qualname = alembic
105
+
106
+ [handler_console]
107
+ class = StreamHandler
108
+ args = (sys.stderr,)
109
+ level = NOTSET
110
+ formatter = generic
111
+
112
+ [formatter_generic]
113
+ format = %(levelname)-5.5s [%(name)s] %(message)s
114
+ datefmt = %H:%M:%S
backend/open_webui/apps/audio/main.py ADDED
@@ -0,0 +1,703 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import json
3
+ import logging
4
+ import os
5
+ import uuid
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+ from pydub import AudioSegment
9
+ from pydub.silence import split_on_silence
10
+
11
+ import aiohttp
12
+ import aiofiles
13
+ import requests
14
+ from open_webui.config import (
15
+ AUDIO_STT_ENGINE,
16
+ AUDIO_STT_MODEL,
17
+ AUDIO_STT_OPENAI_API_BASE_URL,
18
+ AUDIO_STT_OPENAI_API_KEY,
19
+ AUDIO_TTS_API_KEY,
20
+ AUDIO_TTS_ENGINE,
21
+ AUDIO_TTS_MODEL,
22
+ AUDIO_TTS_OPENAI_API_BASE_URL,
23
+ AUDIO_TTS_OPENAI_API_KEY,
24
+ AUDIO_TTS_SPLIT_ON,
25
+ AUDIO_TTS_VOICE,
26
+ AUDIO_TTS_AZURE_SPEECH_REGION,
27
+ AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
28
+ CACHE_DIR,
29
+ CORS_ALLOW_ORIGIN,
30
+ WHISPER_MODEL,
31
+ WHISPER_MODEL_AUTO_UPDATE,
32
+ WHISPER_MODEL_DIR,
33
+ AppConfig,
34
+ )
35
+
36
+ from open_webui.constants import ERROR_MESSAGES
37
+ from open_webui.env import (
38
+ ENV,
39
+ SRC_LOG_LEVELS,
40
+ DEVICE_TYPE,
41
+ ENABLE_FORWARD_USER_INFO_HEADERS,
42
+ )
43
+
44
+ from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
45
+ from fastapi.middleware.cors import CORSMiddleware
46
+ from fastapi.responses import FileResponse
47
+ from pydantic import BaseModel
48
+ from open_webui.utils.utils import get_admin_user, get_verified_user
49
+
50
+ # Constants
51
+ MAX_FILE_SIZE_MB = 25
52
+ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
53
+
54
+
55
+ log = logging.getLogger(__name__)
56
+ log.setLevel(SRC_LOG_LEVELS["AUDIO"])
57
+
58
+ app = FastAPI(
59
+ docs_url="/docs" if ENV == "dev" else None,
60
+ openapi_url="/openapi.json" if ENV == "dev" else None,
61
+ redoc_url=None,
62
+ )
63
+
64
+ app.add_middleware(
65
+ CORSMiddleware,
66
+ allow_origins=CORS_ALLOW_ORIGIN,
67
+ allow_credentials=True,
68
+ allow_methods=["*"],
69
+ allow_headers=["*"],
70
+ )
71
+
72
+ app.state.config = AppConfig()
73
+
74
+ app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
75
+ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
76
+ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
77
+ app.state.config.STT_MODEL = AUDIO_STT_MODEL
78
+
79
+ app.state.config.WHISPER_MODEL = WHISPER_MODEL
80
+ app.state.faster_whisper_model = None
81
+
82
+ app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
83
+ app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
84
+ app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
85
+ app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
86
+ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
87
+ app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
88
+ app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
89
+
90
+
91
+ app.state.speech_synthesiser = None
92
+ app.state.speech_speaker_embeddings_dataset = None
93
+
94
+ app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
95
+ app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
96
+
97
+ # setting device type for whisper model
98
+ whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
99
+ log.info(f"whisper_device_type: {whisper_device_type}")
100
+
101
+ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
102
+ SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
103
+
104
+
105
+ def set_faster_whisper_model(model: str, auto_update: bool = False):
106
+ if model and app.state.config.STT_ENGINE == "":
107
+ from faster_whisper import WhisperModel
108
+
109
+ faster_whisper_kwargs = {
110
+ "model_size_or_path": model,
111
+ "device": whisper_device_type,
112
+ "compute_type": "int8",
113
+ "download_root": WHISPER_MODEL_DIR,
114
+ "local_files_only": not auto_update,
115
+ }
116
+
117
+ try:
118
+ app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
119
+ except Exception:
120
+ log.warning(
121
+ "WhisperModel initialization failed, attempting download with local_files_only=False"
122
+ )
123
+ faster_whisper_kwargs["local_files_only"] = False
124
+ app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
125
+
126
+ else:
127
+ app.state.faster_whisper_model = None
128
+
129
+
130
+ class TTSConfigForm(BaseModel):
131
+ OPENAI_API_BASE_URL: str
132
+ OPENAI_API_KEY: str
133
+ API_KEY: str
134
+ ENGINE: str
135
+ MODEL: str
136
+ VOICE: str
137
+ SPLIT_ON: str
138
+ AZURE_SPEECH_REGION: str
139
+ AZURE_SPEECH_OUTPUT_FORMAT: str
140
+
141
+
142
+ class STTConfigForm(BaseModel):
143
+ OPENAI_API_BASE_URL: str
144
+ OPENAI_API_KEY: str
145
+ ENGINE: str
146
+ MODEL: str
147
+ WHISPER_MODEL: str
148
+
149
+
150
+ class AudioConfigUpdateForm(BaseModel):
151
+ tts: TTSConfigForm
152
+ stt: STTConfigForm
153
+
154
+
155
+ from pydub import AudioSegment
156
+ from pydub.utils import mediainfo
157
+
158
+
159
+ def is_mp4_audio(file_path):
160
+ """Check if the given file is an MP4 audio file."""
161
+ if not os.path.isfile(file_path):
162
+ print(f"File not found: {file_path}")
163
+ return False
164
+
165
+ info = mediainfo(file_path)
166
+ if (
167
+ info.get("codec_name") == "aac"
168
+ and info.get("codec_type") == "audio"
169
+ and info.get("codec_tag_string") == "mp4a"
170
+ ):
171
+ return True
172
+ return False
173
+
174
+
175
+ def convert_mp4_to_wav(file_path, output_path):
176
+ """Convert MP4 audio file to WAV format."""
177
+ audio = AudioSegment.from_file(file_path, format="mp4")
178
+ audio.export(output_path, format="wav")
179
+ print(f"Converted {file_path} to {output_path}")
180
+
181
+
182
+ @app.get("/config")
183
+ async def get_audio_config(user=Depends(get_admin_user)):
184
+ return {
185
+ "tts": {
186
+ "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
187
+ "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
188
+ "API_KEY": app.state.config.TTS_API_KEY,
189
+ "ENGINE": app.state.config.TTS_ENGINE,
190
+ "MODEL": app.state.config.TTS_MODEL,
191
+ "VOICE": app.state.config.TTS_VOICE,
192
+ "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
193
+ "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
194
+ "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
195
+ },
196
+ "stt": {
197
+ "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
198
+ "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
199
+ "ENGINE": app.state.config.STT_ENGINE,
200
+ "MODEL": app.state.config.STT_MODEL,
201
+ "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
202
+ },
203
+ }
204
+
205
+
206
+ @app.post("/config/update")
207
+ async def update_audio_config(
208
+ form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
209
+ ):
210
+ app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
211
+ app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
212
+ app.state.config.TTS_API_KEY = form_data.tts.API_KEY
213
+ app.state.config.TTS_ENGINE = form_data.tts.ENGINE
214
+ app.state.config.TTS_MODEL = form_data.tts.MODEL
215
+ app.state.config.TTS_VOICE = form_data.tts.VOICE
216
+ app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
217
+ app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
218
+ app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
219
+ form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
220
+ )
221
+
222
+ app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
223
+ app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
224
+ app.state.config.STT_ENGINE = form_data.stt.ENGINE
225
+ app.state.config.STT_MODEL = form_data.stt.MODEL
226
+ app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
227
+ set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
228
+
229
+ return {
230
+ "tts": {
231
+ "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
232
+ "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
233
+ "API_KEY": app.state.config.TTS_API_KEY,
234
+ "ENGINE": app.state.config.TTS_ENGINE,
235
+ "MODEL": app.state.config.TTS_MODEL,
236
+ "VOICE": app.state.config.TTS_VOICE,
237
+ "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
238
+ "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
239
+ "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
240
+ },
241
+ "stt": {
242
+ "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
243
+ "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
244
+ "ENGINE": app.state.config.STT_ENGINE,
245
+ "MODEL": app.state.config.STT_MODEL,
246
+ "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
247
+ },
248
+ }
249
+
250
+
251
+ def load_speech_pipeline():
252
+ from transformers import pipeline
253
+ from datasets import load_dataset
254
+
255
+ if app.state.speech_synthesiser is None:
256
+ app.state.speech_synthesiser = pipeline(
257
+ "text-to-speech", "microsoft/speecht5_tts"
258
+ )
259
+
260
+ if app.state.speech_speaker_embeddings_dataset is None:
261
+ app.state.speech_speaker_embeddings_dataset = load_dataset(
262
+ "Matthijs/cmu-arctic-xvectors", split="validation"
263
+ )
264
+
265
+
266
+ @app.post("/speech")
267
+ async def speech(request: Request, user=Depends(get_verified_user)):
268
+ body = await request.body()
269
+ name = hashlib.sha256(body).hexdigest()
270
+
271
+ file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
272
+ file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
273
+
274
+ # Check if the file already exists in the cache
275
+ if file_path.is_file():
276
+ return FileResponse(file_path)
277
+
278
+ if app.state.config.TTS_ENGINE == "openai":
279
+ headers = {}
280
+ headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
281
+ headers["Content-Type"] = "application/json"
282
+
283
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
284
+ headers["X-OpenWebUI-User-Name"] = user.name
285
+ headers["X-OpenWebUI-User-Id"] = user.id
286
+ headers["X-OpenWebUI-User-Email"] = user.email
287
+ headers["X-OpenWebUI-User-Role"] = user.role
288
+
289
+ try:
290
+ body = body.decode("utf-8")
291
+ body = json.loads(body)
292
+ body["model"] = app.state.config.TTS_MODEL
293
+ body = json.dumps(body).encode("utf-8")
294
+ except Exception:
295
+ pass
296
+
297
+ try:
298
+ async with aiohttp.ClientSession() as session:
299
+ async with session.post(
300
+ url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
301
+ data=body,
302
+ headers=headers,
303
+ ) as r:
304
+ r.raise_for_status()
305
+ async with aiofiles.open(file_path, "wb") as f:
306
+ await f.write(await r.read())
307
+
308
+ async with aiofiles.open(file_body_path, "w") as f:
309
+ await f.write(json.dumps(json.loads(body.decode("utf-8"))))
310
+
311
+ return FileResponse(file_path)
312
+
313
+ except Exception as e:
314
+ log.exception(e)
315
+ error_detail = "Open WebUI: Server Connection Error"
316
+ try:
317
+ if r.status != 200:
318
+ res = await r.json()
319
+ if "error" in res:
320
+ error_detail = f"External: {res['error']['message']}"
321
+ except Exception:
322
+ error_detail = f"External: {e}"
323
+
324
+ raise HTTPException(
325
+ status_code=getattr(r, "status", 500),
326
+ detail=error_detail,
327
+ )
328
+
329
+ elif app.state.config.TTS_ENGINE == "elevenlabs":
330
+ try:
331
+ payload = json.loads(body.decode("utf-8"))
332
+ except Exception as e:
333
+ log.exception(e)
334
+ raise HTTPException(status_code=400, detail="Invalid JSON payload")
335
+
336
+ voice_id = payload.get("voice", "")
337
+ if voice_id not in get_available_voices():
338
+ raise HTTPException(
339
+ status_code=400,
340
+ detail="Invalid voice id",
341
+ )
342
+
343
+ url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
344
+ headers = {
345
+ "Accept": "audio/mpeg",
346
+ "Content-Type": "application/json",
347
+ "xi-api-key": app.state.config.TTS_API_KEY,
348
+ }
349
+ data = {
350
+ "text": payload["input"],
351
+ "model_id": app.state.config.TTS_MODEL,
352
+ "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
353
+ }
354
+
355
+ try:
356
+ async with aiohttp.ClientSession() as session:
357
+ async with session.post(url, json=data, headers=headers) as r:
358
+ r.raise_for_status()
359
+ async with aiofiles.open(file_path, "wb") as f:
360
+ await f.write(await r.read())
361
+
362
+ async with aiofiles.open(file_body_path, "w") as f:
363
+ await f.write(json.dumps(json.loads(body.decode("utf-8"))))
364
+
365
+ return FileResponse(file_path)
366
+
367
+ except Exception as e:
368
+ log.exception(e)
369
+ error_detail = "Open WebUI: Server Connection Error"
370
+ try:
371
+ if r.status != 200:
372
+ res = await r.json()
373
+ if "error" in res:
374
+ error_detail = f"External: {res['error']['message']}"
375
+ except Exception:
376
+ error_detail = f"External: {e}"
377
+
378
+ raise HTTPException(
379
+ status_code=getattr(r, "status", 500),
380
+ detail=error_detail,
381
+ )
382
+
383
+ elif app.state.config.TTS_ENGINE == "azure":
384
+ try:
385
+ payload = json.loads(body.decode("utf-8"))
386
+ except Exception as e:
387
+ log.exception(e)
388
+ raise HTTPException(status_code=400, detail="Invalid JSON payload")
389
+
390
+ region = app.state.config.TTS_AZURE_SPEECH_REGION
391
+ language = app.state.config.TTS_VOICE
392
+ locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
393
+ output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
394
+ url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
395
+
396
+ headers = {
397
+ "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
398
+ "Content-Type": "application/ssml+xml",
399
+ "X-Microsoft-OutputFormat": output_format,
400
+ }
401
+
402
+ data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
403
+ <voice name="{language}">{payload["input"]}</voice>
404
+ </speak>"""
405
+
406
+ try:
407
+ async with aiohttp.ClientSession() as session:
408
+ async with session.post(url, headers=headers, data=data) as response:
409
+ if response.status == 200:
410
+ async with aiofiles.open(file_path, "wb") as f:
411
+ await f.write(await response.read())
412
+ return FileResponse(file_path)
413
+ else:
414
+ error_msg = f"Error synthesizing speech - {response.reason}"
415
+ log.error(error_msg)
416
+ raise HTTPException(status_code=500, detail=error_msg)
417
+ except Exception as e:
418
+ log.exception(e)
419
+ raise HTTPException(status_code=500, detail=str(e))
420
+ elif app.state.config.TTS_ENGINE == "transformers":
421
+ payload = None
422
+ try:
423
+ payload = json.loads(body.decode("utf-8"))
424
+ except Exception as e:
425
+ log.exception(e)
426
+ raise HTTPException(status_code=400, detail="Invalid JSON payload")
427
+
428
+ import torch
429
+ import soundfile as sf
430
+
431
+ load_speech_pipeline()
432
+
433
+ embeddings_dataset = app.state.speech_speaker_embeddings_dataset
434
+
435
+ speaker_index = 6799
436
+ try:
437
+ speaker_index = embeddings_dataset["filename"].index(
438
+ app.state.config.TTS_MODEL
439
+ )
440
+ except Exception:
441
+ pass
442
+
443
+ speaker_embedding = torch.tensor(
444
+ embeddings_dataset[speaker_index]["xvector"]
445
+ ).unsqueeze(0)
446
+
447
+ speech = app.state.speech_synthesiser(
448
+ payload["input"],
449
+ forward_params={"speaker_embeddings": speaker_embedding},
450
+ )
451
+
452
+ sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
453
+ with open(file_body_path, "w") as f:
454
+ json.dump(json.loads(body.decode("utf-8")), f)
455
+
456
+ return FileResponse(file_path)
457
+
458
+
459
+ def transcribe(file_path):
460
+ print("transcribe", file_path)
461
+ filename = os.path.basename(file_path)
462
+ file_dir = os.path.dirname(file_path)
463
+ id = filename.split(".")[0]
464
+
465
+ if app.state.config.STT_ENGINE == "":
466
+ if app.state.faster_whisper_model is None:
467
+ set_faster_whisper_model(app.state.config.WHISPER_MODEL)
468
+
469
+ model = app.state.faster_whisper_model
470
+ segments, info = model.transcribe(file_path, beam_size=5)
471
+ log.info(
472
+ "Detected language '%s' with probability %f"
473
+ % (info.language, info.language_probability)
474
+ )
475
+
476
+ transcript = "".join([segment.text for segment in list(segments)])
477
+ data = {"text": transcript.strip()}
478
+
479
+ # save the transcript to a json file
480
+ transcript_file = f"{file_dir}/{id}.json"
481
+ with open(transcript_file, "w") as f:
482
+ json.dump(data, f)
483
+
484
+ log.debug(data)
485
+ return data
486
+ elif app.state.config.STT_ENGINE == "openai":
487
+ if is_mp4_audio(file_path):
488
+ print("is_mp4_audio")
489
+ os.rename(file_path, file_path.replace(".wav", ".mp4"))
490
+ # Convert MP4 audio file to WAV format
491
+ convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
492
+
493
+ headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
494
+
495
+ files = {"file": (filename, open(file_path, "rb"))}
496
+ data = {"model": app.state.config.STT_MODEL}
497
+
498
+ log.debug(files, data)
499
+
500
+ r = None
501
+ try:
502
+ r = requests.post(
503
+ url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
504
+ headers=headers,
505
+ files=files,
506
+ data=data,
507
+ )
508
+
509
+ r.raise_for_status()
510
+
511
+ data = r.json()
512
+
513
+ # save the transcript to a json file
514
+ transcript_file = f"{file_dir}/{id}.json"
515
+ with open(transcript_file, "w") as f:
516
+ json.dump(data, f)
517
+
518
+ print(data)
519
+ return data
520
+ except Exception as e:
521
+ log.exception(e)
522
+ error_detail = "Open WebUI: Server Connection Error"
523
+ if r is not None:
524
+ try:
525
+ res = r.json()
526
+ if "error" in res:
527
+ error_detail = f"External: {res['error']['message']}"
528
+ except Exception:
529
+ error_detail = f"External: {e}"
530
+
531
+ raise Exception(error_detail)
532
+
533
+
534
+ @app.post("/transcriptions")
535
+ def transcription(
536
+ file: UploadFile = File(...),
537
+ user=Depends(get_verified_user),
538
+ ):
539
+ log.info(f"file.content_type: {file.content_type}")
540
+
541
+ if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
542
+ raise HTTPException(
543
+ status_code=status.HTTP_400_BAD_REQUEST,
544
+ detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
545
+ )
546
+
547
+ try:
548
+ ext = file.filename.split(".")[-1]
549
+ id = uuid.uuid4()
550
+
551
+ filename = f"{id}.{ext}"
552
+ contents = file.file.read()
553
+
554
+ file_dir = f"{CACHE_DIR}/audio/transcriptions"
555
+ os.makedirs(file_dir, exist_ok=True)
556
+ file_path = f"{file_dir}/{filename}"
557
+
558
+ with open(file_path, "wb") as f:
559
+ f.write(contents)
560
+
561
+ try:
562
+ if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB
563
+ log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
564
+ audio = AudioSegment.from_file(file_path)
565
+ audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
566
+ compressed_path = f"{file_dir}/{id}_compressed.opus"
567
+ audio.export(compressed_path, format="opus", bitrate="32k")
568
+ log.debug(f"Compressed audio to {compressed_path}")
569
+ file_path = compressed_path
570
+
571
+ if (
572
+ os.path.getsize(file_path) > MAX_FILE_SIZE
573
+ ): # Still larger than 25MB after compression
574
+ log.debug(
575
+ f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
576
+ )
577
+ raise HTTPException(
578
+ status_code=status.HTTP_400_BAD_REQUEST,
579
+ detail=ERROR_MESSAGES.FILE_TOO_LARGE(
580
+ size=f"{MAX_FILE_SIZE_MB}MB"
581
+ ),
582
+ )
583
+
584
+ data = transcribe(file_path)
585
+ else:
586
+ data = transcribe(file_path)
587
+
588
+ file_path = file_path.split("/")[-1]
589
+ return {**data, "filename": file_path}
590
+ except Exception as e:
591
+ log.exception(e)
592
+ raise HTTPException(
593
+ status_code=status.HTTP_400_BAD_REQUEST,
594
+ detail=ERROR_MESSAGES.DEFAULT(e),
595
+ )
596
+
597
+ except Exception as e:
598
+ log.exception(e)
599
+
600
+ raise HTTPException(
601
+ status_code=status.HTTP_400_BAD_REQUEST,
602
+ detail=ERROR_MESSAGES.DEFAULT(e),
603
+ )
604
+
605
+
606
+ def get_available_models() -> list[dict]:
607
+ if app.state.config.TTS_ENGINE == "openai":
608
+ return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
609
+ elif app.state.config.TTS_ENGINE == "elevenlabs":
610
+ headers = {
611
+ "xi-api-key": app.state.config.TTS_API_KEY,
612
+ "Content-Type": "application/json",
613
+ }
614
+
615
+ try:
616
+ response = requests.get(
617
+ "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
618
+ )
619
+ response.raise_for_status()
620
+ models = response.json()
621
+ return [
622
+ {"name": model["name"], "id": model["model_id"]} for model in models
623
+ ]
624
+ except requests.RequestException as e:
625
+ log.error(f"Error fetching voices: {str(e)}")
626
+ return []
627
+
628
+
629
+ @app.get("/models")
630
+ async def get_models(user=Depends(get_verified_user)):
631
+ return {"models": get_available_models()}
632
+
633
+
634
+ def get_available_voices() -> dict:
635
+ """Returns {voice_id: voice_name} dict"""
636
+ ret = {}
637
+ if app.state.config.TTS_ENGINE == "openai":
638
+ ret = {
639
+ "alloy": "alloy",
640
+ "echo": "echo",
641
+ "fable": "fable",
642
+ "onyx": "onyx",
643
+ "nova": "nova",
644
+ "shimmer": "shimmer",
645
+ }
646
+ elif app.state.config.TTS_ENGINE == "elevenlabs":
647
+ try:
648
+ ret = get_elevenlabs_voices()
649
+ except Exception:
650
+ # Avoided @lru_cache with exception
651
+ pass
652
+ elif app.state.config.TTS_ENGINE == "azure":
653
+ try:
654
+ region = app.state.config.TTS_AZURE_SPEECH_REGION
655
+ url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
656
+ headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
657
+
658
+ response = requests.get(url, headers=headers)
659
+ response.raise_for_status()
660
+ voices = response.json()
661
+ for voice in voices:
662
+ ret[voice["ShortName"]] = (
663
+ f"{voice['DisplayName']} ({voice['ShortName']})"
664
+ )
665
+ except requests.RequestException as e:
666
+ log.error(f"Error fetching voices: {str(e)}")
667
+
668
+ return ret
669
+
670
+
671
+ @lru_cache
672
+ def get_elevenlabs_voices() -> dict:
673
+ """
674
+ Note, set the following in your .env file to use Elevenlabs:
675
+ AUDIO_TTS_ENGINE=elevenlabs
676
+ AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
677
+ AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
678
+ AUDIO_TTS_MODEL=eleven_multilingual_v2
679
+ """
680
+ headers = {
681
+ "xi-api-key": app.state.config.TTS_API_KEY,
682
+ "Content-Type": "application/json",
683
+ }
684
+ try:
685
+ # TODO: Add retries
686
+ response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
687
+ response.raise_for_status()
688
+ voices_data = response.json()
689
+
690
+ voices = {}
691
+ for voice in voices_data.get("voices", []):
692
+ voices[voice["voice_id"]] = voice["name"]
693
+ except requests.RequestException as e:
694
+ # Avoid @lru_cache with exception
695
+ log.error(f"Error fetching voices: {str(e)}")
696
+ raise RuntimeError(f"Error fetching voices: {str(e)}")
697
+
698
+ return voices
699
+
700
+
701
+ @app.get("/voices")
702
+ async def get_voices(user=Depends(get_verified_user)):
703
+ return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}
backend/open_webui/apps/images/main.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ import logging
5
+ import mimetypes
6
+ import re
7
+ import uuid
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ import requests
12
+ from open_webui.apps.images.utils.comfyui import (
13
+ ComfyUIGenerateImageForm,
14
+ ComfyUIWorkflow,
15
+ comfyui_generate_image,
16
+ )
17
+ from open_webui.config import (
18
+ AUTOMATIC1111_API_AUTH,
19
+ AUTOMATIC1111_BASE_URL,
20
+ AUTOMATIC1111_CFG_SCALE,
21
+ AUTOMATIC1111_SAMPLER,
22
+ AUTOMATIC1111_SCHEDULER,
23
+ CACHE_DIR,
24
+ COMFYUI_BASE_URL,
25
+ COMFYUI_WORKFLOW,
26
+ COMFYUI_WORKFLOW_NODES,
27
+ CORS_ALLOW_ORIGIN,
28
+ ENABLE_IMAGE_GENERATION,
29
+ IMAGE_GENERATION_ENGINE,
30
+ IMAGE_GENERATION_MODEL,
31
+ IMAGE_SIZE,
32
+ IMAGE_STEPS,
33
+ IMAGES_OPENAI_API_BASE_URL,
34
+ IMAGES_OPENAI_API_KEY,
35
+ AppConfig,
36
+ )
37
+ from open_webui.constants import ERROR_MESSAGES
38
+ from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
39
+
40
+ from fastapi import Depends, FastAPI, HTTPException, Request
41
+ from fastapi.middleware.cors import CORSMiddleware
42
+ from pydantic import BaseModel
43
+ from open_webui.utils.utils import get_admin_user, get_verified_user
44
+
45
+ log = logging.getLogger(__name__)
46
+ log.setLevel(SRC_LOG_LEVELS["IMAGES"])
47
+
48
+ IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
49
+ IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
50
+
51
+ app = FastAPI(
52
+ docs_url="/docs" if ENV == "dev" else None,
53
+ openapi_url="/openapi.json" if ENV == "dev" else None,
54
+ redoc_url=None,
55
+ )
56
+
57
+ app.add_middleware(
58
+ CORSMiddleware,
59
+ allow_origins=CORS_ALLOW_ORIGIN,
60
+ allow_credentials=True,
61
+ allow_methods=["*"],
62
+ allow_headers=["*"],
63
+ )
64
+
65
+ app.state.config = AppConfig()
66
+
67
+ app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
68
+ app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
69
+
70
+ app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
71
+ app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
72
+
73
+ app.state.config.MODEL = IMAGE_GENERATION_MODEL
74
+
75
+ app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
76
+ app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
77
+ app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
78
+ app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
79
+ app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
80
+ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
81
+ app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
82
+ app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
83
+
84
+ app.state.config.IMAGE_SIZE = IMAGE_SIZE
85
+ app.state.config.IMAGE_STEPS = IMAGE_STEPS
86
+
87
+
88
+ @app.get("/config")
89
+ async def get_config(request: Request, user=Depends(get_admin_user)):
90
+ return {
91
+ "enabled": app.state.config.ENABLED,
92
+ "engine": app.state.config.ENGINE,
93
+ "openai": {
94
+ "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
95
+ "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
96
+ },
97
+ "automatic1111": {
98
+ "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
99
+ "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
100
+ "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
101
+ "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
102
+ "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
103
+ },
104
+ "comfyui": {
105
+ "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
106
+ "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
107
+ "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
108
+ },
109
+ }
110
+
111
+
112
+ class OpenAIConfigForm(BaseModel):
113
+ OPENAI_API_BASE_URL: str
114
+ OPENAI_API_KEY: str
115
+
116
+
117
+ class Automatic1111ConfigForm(BaseModel):
118
+ AUTOMATIC1111_BASE_URL: str
119
+ AUTOMATIC1111_API_AUTH: str
120
+ AUTOMATIC1111_CFG_SCALE: Optional[str]
121
+ AUTOMATIC1111_SAMPLER: Optional[str]
122
+ AUTOMATIC1111_SCHEDULER: Optional[str]
123
+
124
+
125
+ class ComfyUIConfigForm(BaseModel):
126
+ COMFYUI_BASE_URL: str
127
+ COMFYUI_WORKFLOW: str
128
+ COMFYUI_WORKFLOW_NODES: list[dict]
129
+
130
+
131
+ class ConfigForm(BaseModel):
132
+ enabled: bool
133
+ engine: str
134
+ openai: OpenAIConfigForm
135
+ automatic1111: Automatic1111ConfigForm
136
+ comfyui: ComfyUIConfigForm
137
+
138
+
139
+ @app.post("/config/update")
140
+ async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
141
+ app.state.config.ENGINE = form_data.engine
142
+ app.state.config.ENABLED = form_data.enabled
143
+
144
+ app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
145
+ app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
146
+
147
+ app.state.config.AUTOMATIC1111_BASE_URL = (
148
+ form_data.automatic1111.AUTOMATIC1111_BASE_URL
149
+ )
150
+ app.state.config.AUTOMATIC1111_API_AUTH = (
151
+ form_data.automatic1111.AUTOMATIC1111_API_AUTH
152
+ )
153
+
154
+ app.state.config.AUTOMATIC1111_CFG_SCALE = (
155
+ float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
156
+ if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
157
+ else None
158
+ )
159
+ app.state.config.AUTOMATIC1111_SAMPLER = (
160
+ form_data.automatic1111.AUTOMATIC1111_SAMPLER
161
+ if form_data.automatic1111.AUTOMATIC1111_SAMPLER
162
+ else None
163
+ )
164
+ app.state.config.AUTOMATIC1111_SCHEDULER = (
165
+ form_data.automatic1111.AUTOMATIC1111_SCHEDULER
166
+ if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
167
+ else None
168
+ )
169
+
170
+ app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
171
+ app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
172
+ app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
173
+
174
+ return {
175
+ "enabled": app.state.config.ENABLED,
176
+ "engine": app.state.config.ENGINE,
177
+ "openai": {
178
+ "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
179
+ "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
180
+ },
181
+ "automatic1111": {
182
+ "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
183
+ "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
184
+ "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
185
+ "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
186
+ "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
187
+ },
188
+ "comfyui": {
189
+ "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
190
+ "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
191
+ "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
192
+ },
193
+ }
194
+
195
+
196
+ def get_automatic1111_api_auth():
197
+ if app.state.config.AUTOMATIC1111_API_AUTH is None:
198
+ return ""
199
+ else:
200
+ auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
201
+ auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
202
+ auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
203
+ return f"Basic {auth1111_base64_encoded_string}"
204
+
205
+
206
+ @app.get("/config/url/verify")
207
+ async def verify_url(user=Depends(get_admin_user)):
208
+ if app.state.config.ENGINE == "automatic1111":
209
+ try:
210
+ r = requests.get(
211
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
212
+ headers={"authorization": get_automatic1111_api_auth()},
213
+ )
214
+ r.raise_for_status()
215
+ return True
216
+ except Exception:
217
+ app.state.config.ENABLED = False
218
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
219
+ elif app.state.config.ENGINE == "comfyui":
220
+ try:
221
+ r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
222
+ r.raise_for_status()
223
+ return True
224
+ except Exception:
225
+ app.state.config.ENABLED = False
226
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
227
+ else:
228
+ return True
229
+
230
+
231
+ def set_image_model(model: str):
232
+ log.info(f"Setting image model to {model}")
233
+ app.state.config.MODEL = model
234
+ if app.state.config.ENGINE in ["", "automatic1111"]:
235
+ api_auth = get_automatic1111_api_auth()
236
+ r = requests.get(
237
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
238
+ headers={"authorization": api_auth},
239
+ )
240
+ options = r.json()
241
+ if model != options["sd_model_checkpoint"]:
242
+ options["sd_model_checkpoint"] = model
243
+ r = requests.post(
244
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
245
+ json=options,
246
+ headers={"authorization": api_auth},
247
+ )
248
+ return app.state.config.MODEL
249
+
250
+
251
+ def get_image_model():
252
+ if app.state.config.ENGINE == "openai":
253
+ return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
254
+ elif app.state.config.ENGINE == "comfyui":
255
+ return app.state.config.MODEL if app.state.config.MODEL else ""
256
+ elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
257
+ try:
258
+ r = requests.get(
259
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
260
+ headers={"authorization": get_automatic1111_api_auth()},
261
+ )
262
+ options = r.json()
263
+ return options["sd_model_checkpoint"]
264
+ except Exception as e:
265
+ app.state.config.ENABLED = False
266
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
267
+
268
+
269
+ class ImageConfigForm(BaseModel):
270
+ MODEL: str
271
+ IMAGE_SIZE: str
272
+ IMAGE_STEPS: int
273
+
274
+
275
+ @app.get("/image/config")
276
+ async def get_image_config(user=Depends(get_admin_user)):
277
+ return {
278
+ "MODEL": app.state.config.MODEL,
279
+ "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
280
+ "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
281
+ }
282
+
283
+
284
+ @app.post("/image/config/update")
285
+ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
286
+
287
+ set_image_model(form_data.MODEL)
288
+
289
+ pattern = r"^\d+x\d+$"
290
+ if re.match(pattern, form_data.IMAGE_SIZE):
291
+ app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
292
+ else:
293
+ raise HTTPException(
294
+ status_code=400,
295
+ detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
296
+ )
297
+
298
+ if form_data.IMAGE_STEPS >= 0:
299
+ app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
300
+ else:
301
+ raise HTTPException(
302
+ status_code=400,
303
+ detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
304
+ )
305
+
306
+ return {
307
+ "MODEL": app.state.config.MODEL,
308
+ "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
309
+ "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
310
+ }
311
+
312
+
313
+ @app.get("/models")
314
+ def get_models(user=Depends(get_verified_user)):
315
+ try:
316
+ if app.state.config.ENGINE == "openai":
317
+ return [
318
+ {"id": "dall-e-2", "name": "DALL·E 2"},
319
+ {"id": "dall-e-3", "name": "DALL·E 3"},
320
+ ]
321
+ elif app.state.config.ENGINE == "comfyui":
322
+ # TODO - get models from comfyui
323
+ r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
324
+ info = r.json()
325
+
326
+ workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
327
+ model_node_id = None
328
+
329
+ for node in app.state.config.COMFYUI_WORKFLOW_NODES:
330
+ if node["type"] == "model":
331
+ if node["node_ids"]:
332
+ model_node_id = node["node_ids"][0]
333
+ break
334
+
335
+ if model_node_id:
336
+ model_list_key = None
337
+
338
+ print(workflow[model_node_id]["class_type"])
339
+ for key in info[workflow[model_node_id]["class_type"]]["input"][
340
+ "required"
341
+ ]:
342
+ if "_name" in key:
343
+ model_list_key = key
344
+ break
345
+
346
+ if model_list_key:
347
+ return list(
348
+ map(
349
+ lambda model: {"id": model, "name": model},
350
+ info[workflow[model_node_id]["class_type"]]["input"][
351
+ "required"
352
+ ][model_list_key][0],
353
+ )
354
+ )
355
+ else:
356
+ return list(
357
+ map(
358
+ lambda model: {"id": model, "name": model},
359
+ info["CheckpointLoaderSimple"]["input"]["required"][
360
+ "ckpt_name"
361
+ ][0],
362
+ )
363
+ )
364
+ elif (
365
+ app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
366
+ ):
367
+ r = requests.get(
368
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
369
+ headers={"authorization": get_automatic1111_api_auth()},
370
+ )
371
+ models = r.json()
372
+ return list(
373
+ map(
374
+ lambda model: {"id": model["title"], "name": model["model_name"]},
375
+ models,
376
+ )
377
+ )
378
+ except Exception as e:
379
+ app.state.config.ENABLED = False
380
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
381
+
382
+
383
+ class GenerateImageForm(BaseModel):
384
+ model: Optional[str] = None
385
+ prompt: str
386
+ size: Optional[str] = None
387
+ n: int = 1
388
+ negative_prompt: Optional[str] = None
389
+
390
+
391
+ def save_b64_image(b64_str):
392
+ try:
393
+ image_id = str(uuid.uuid4())
394
+
395
+ if "," in b64_str:
396
+ header, encoded = b64_str.split(",", 1)
397
+ mime_type = header.split(";")[0]
398
+
399
+ img_data = base64.b64decode(encoded)
400
+ image_format = mimetypes.guess_extension(mime_type)
401
+
402
+ image_filename = f"{image_id}{image_format}"
403
+ file_path = IMAGE_CACHE_DIR / f"{image_filename}"
404
+ with open(file_path, "wb") as f:
405
+ f.write(img_data)
406
+ return image_filename
407
+ else:
408
+ image_filename = f"{image_id}.png"
409
+ file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
410
+
411
+ img_data = base64.b64decode(b64_str)
412
+
413
+ # Write the image data to a file
414
+ with open(file_path, "wb") as f:
415
+ f.write(img_data)
416
+ return image_filename
417
+
418
+ except Exception as e:
419
+ log.exception(f"Error saving image: {e}")
420
+ return None
421
+
422
+
423
+ def save_url_image(url):
424
+ image_id = str(uuid.uuid4())
425
+ try:
426
+ r = requests.get(url)
427
+ r.raise_for_status()
428
+ if r.headers["content-type"].split("/")[0] == "image":
429
+ mime_type = r.headers["content-type"]
430
+ image_format = mimetypes.guess_extension(mime_type)
431
+
432
+ if not image_format:
433
+ raise ValueError("Could not determine image type from MIME type")
434
+
435
+ image_filename = f"{image_id}{image_format}"
436
+
437
+ file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
438
+ with open(file_path, "wb") as image_file:
439
+ for chunk in r.iter_content(chunk_size=8192):
440
+ image_file.write(chunk)
441
+ return image_filename
442
+ else:
443
+ log.error("Url does not point to an image.")
444
+ return None
445
+
446
+ except Exception as e:
447
+ log.exception(f"Error saving image: {e}")
448
+ return None
449
+
450
+
451
+ @app.post("/generations")
452
+ async def image_generations(
453
+ form_data: GenerateImageForm,
454
+ user=Depends(get_verified_user),
455
+ ):
456
+ width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
457
+
458
+ r = None
459
+ try:
460
+ if app.state.config.ENGINE == "openai":
461
+ headers = {}
462
+ headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
463
+ headers["Content-Type"] = "application/json"
464
+
465
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
466
+ headers["X-OpenWebUI-User-Name"] = user.name
467
+ headers["X-OpenWebUI-User-Id"] = user.id
468
+ headers["X-OpenWebUI-User-Email"] = user.email
469
+ headers["X-OpenWebUI-User-Role"] = user.role
470
+
471
+ data = {
472
+ "model": (
473
+ app.state.config.MODEL
474
+ if app.state.config.MODEL != ""
475
+ else "dall-e-2"
476
+ ),
477
+ "prompt": form_data.prompt,
478
+ "n": form_data.n,
479
+ "size": (
480
+ form_data.size if form_data.size else app.state.config.IMAGE_SIZE
481
+ ),
482
+ "response_format": "b64_json",
483
+ }
484
+
485
+ # Use asyncio.to_thread for the requests.post call
486
+ r = await asyncio.to_thread(
487
+ requests.post,
488
+ url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
489
+ json=data,
490
+ headers=headers,
491
+ )
492
+
493
+ r.raise_for_status()
494
+ res = r.json()
495
+
496
+ images = []
497
+
498
+ for image in res["data"]:
499
+ image_filename = save_b64_image(image["b64_json"])
500
+ images.append({"url": f"/cache/image/generations/{image_filename}"})
501
+ file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
502
+
503
+ with open(file_body_path, "w") as f:
504
+ json.dump(data, f)
505
+
506
+ return images
507
+
508
+ elif app.state.config.ENGINE == "comfyui":
509
+ data = {
510
+ "prompt": form_data.prompt,
511
+ "width": width,
512
+ "height": height,
513
+ "n": form_data.n,
514
+ }
515
+
516
+ if app.state.config.IMAGE_STEPS is not None:
517
+ data["steps"] = app.state.config.IMAGE_STEPS
518
+
519
+ if form_data.negative_prompt is not None:
520
+ data["negative_prompt"] = form_data.negative_prompt
521
+
522
+ form_data = ComfyUIGenerateImageForm(
523
+ **{
524
+ "workflow": ComfyUIWorkflow(
525
+ **{
526
+ "workflow": app.state.config.COMFYUI_WORKFLOW,
527
+ "nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
528
+ }
529
+ ),
530
+ **data,
531
+ }
532
+ )
533
+ res = await comfyui_generate_image(
534
+ app.state.config.MODEL,
535
+ form_data,
536
+ user.id,
537
+ app.state.config.COMFYUI_BASE_URL,
538
+ )
539
+ log.debug(f"res: {res}")
540
+
541
+ images = []
542
+
543
+ for image in res["data"]:
544
+ image_filename = save_url_image(image["url"])
545
+ images.append({"url": f"/cache/image/generations/{image_filename}"})
546
+ file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
547
+
548
+ with open(file_body_path, "w") as f:
549
+ json.dump(form_data.model_dump(exclude_none=True), f)
550
+
551
+ log.debug(f"images: {images}")
552
+ return images
553
+ elif (
554
+ app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
555
+ ):
556
+ if form_data.model:
557
+ set_image_model(form_data.model)
558
+
559
+ data = {
560
+ "prompt": form_data.prompt,
561
+ "batch_size": form_data.n,
562
+ "width": width,
563
+ "height": height,
564
+ }
565
+
566
+ if app.state.config.IMAGE_STEPS is not None:
567
+ data["steps"] = app.state.config.IMAGE_STEPS
568
+
569
+ if form_data.negative_prompt is not None:
570
+ data["negative_prompt"] = form_data.negative_prompt
571
+
572
+ if app.state.config.AUTOMATIC1111_CFG_SCALE:
573
+ data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
574
+
575
+ if app.state.config.AUTOMATIC1111_SAMPLER:
576
+ data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
577
+
578
+ if app.state.config.AUTOMATIC1111_SCHEDULER:
579
+ data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
580
+
581
+ # Use asyncio.to_thread for the requests.post call
582
+ r = await asyncio.to_thread(
583
+ requests.post,
584
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
585
+ json=data,
586
+ headers={"authorization": get_automatic1111_api_auth()},
587
+ )
588
+
589
+ res = r.json()
590
+ log.debug(f"res: {res}")
591
+
592
+ images = []
593
+
594
+ for image in res["images"]:
595
+ image_filename = save_b64_image(image)
596
+ images.append({"url": f"/cache/image/generations/{image_filename}"})
597
+ file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
598
+
599
+ with open(file_body_path, "w") as f:
600
+ json.dump({**data, "info": res["info"]}, f)
601
+
602
+ return images
603
+ except Exception as e:
604
+ error = e
605
+ if r != None:
606
+ data = r.json()
607
+ if "error" in data:
608
+ error = data["error"]["message"]
609
+ raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
backend/open_webui/apps/images/utils/comfyui.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import random
5
+ import urllib.parse
6
+ import urllib.request
7
+ from typing import Optional
8
+
9
+ import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
10
+ from open_webui.env import SRC_LOG_LEVELS
11
+ from pydantic import BaseModel
12
+
13
+ log = logging.getLogger(__name__)
14
+ log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
15
+
16
+ default_headers = {"User-Agent": "Mozilla/5.0"}
17
+
18
+
19
+ def queue_prompt(prompt, client_id, base_url):
20
+ log.info("queue_prompt")
21
+ p = {"prompt": prompt, "client_id": client_id}
22
+ data = json.dumps(p).encode("utf-8")
23
+ log.debug(f"queue_prompt data: {data}")
24
+ try:
25
+ req = urllib.request.Request(
26
+ f"{base_url}/prompt", data=data, headers=default_headers
27
+ )
28
+ response = urllib.request.urlopen(req).read()
29
+ return json.loads(response)
30
+ except Exception as e:
31
+ log.exception(f"Error while queuing prompt: {e}")
32
+ raise e
33
+
34
+
35
+ def get_image(filename, subfolder, folder_type, base_url):
36
+ log.info("get_image")
37
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
38
+ url_values = urllib.parse.urlencode(data)
39
+ req = urllib.request.Request(
40
+ f"{base_url}/view?{url_values}", headers=default_headers
41
+ )
42
+ with urllib.request.urlopen(req) as response:
43
+ return response.read()
44
+
45
+
46
+ def get_image_url(filename, subfolder, folder_type, base_url):
47
+ log.info("get_image")
48
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
49
+ url_values = urllib.parse.urlencode(data)
50
+ return f"{base_url}/view?{url_values}"
51
+
52
+
53
+ def get_history(prompt_id, base_url):
54
+ log.info("get_history")
55
+
56
+ req = urllib.request.Request(
57
+ f"{base_url}/history/{prompt_id}", headers=default_headers
58
+ )
59
+ with urllib.request.urlopen(req) as response:
60
+ return json.loads(response.read())
61
+
62
+
63
+ def get_images(ws, prompt, client_id, base_url):
64
+ prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
65
+ output_images = []
66
+ while True:
67
+ out = ws.recv()
68
+ if isinstance(out, str):
69
+ message = json.loads(out)
70
+ if message["type"] == "executing":
71
+ data = message["data"]
72
+ if data["node"] is None and data["prompt_id"] == prompt_id:
73
+ break # Execution is done
74
+ else:
75
+ continue # previews are binary data
76
+
77
+ history = get_history(prompt_id, base_url)[prompt_id]
78
+ for o in history["outputs"]:
79
+ for node_id in history["outputs"]:
80
+ node_output = history["outputs"][node_id]
81
+ if "images" in node_output:
82
+ for image in node_output["images"]:
83
+ url = get_image_url(
84
+ image["filename"], image["subfolder"], image["type"], base_url
85
+ )
86
+ output_images.append({"url": url})
87
+ return {"data": output_images}
88
+
89
+
90
+ class ComfyUINodeInput(BaseModel):
91
+ type: Optional[str] = None
92
+ node_ids: list[str] = []
93
+ key: Optional[str] = "text"
94
+ value: Optional[str] = None
95
+
96
+
97
+ class ComfyUIWorkflow(BaseModel):
98
+ workflow: str
99
+ nodes: list[ComfyUINodeInput]
100
+
101
+
102
+ class ComfyUIGenerateImageForm(BaseModel):
103
+ workflow: ComfyUIWorkflow
104
+
105
+ prompt: str
106
+ negative_prompt: Optional[str] = None
107
+ width: int
108
+ height: int
109
+ n: int = 1
110
+
111
+ steps: Optional[int] = None
112
+ seed: Optional[int] = None
113
+
114
+
115
+ async def comfyui_generate_image(
116
+ model: str, payload: ComfyUIGenerateImageForm, client_id, base_url
117
+ ):
118
+ ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
119
+ workflow = json.loads(payload.workflow.workflow)
120
+
121
+ for node in payload.workflow.nodes:
122
+ if node.type:
123
+ if node.type == "model":
124
+ for node_id in node.node_ids:
125
+ workflow[node_id]["inputs"][node.key] = model
126
+ elif node.type == "prompt":
127
+ for node_id in node.node_ids:
128
+ workflow[node_id]["inputs"][
129
+ node.key if node.key else "text"
130
+ ] = payload.prompt
131
+ elif node.type == "negative_prompt":
132
+ for node_id in node.node_ids:
133
+ workflow[node_id]["inputs"][
134
+ node.key if node.key else "text"
135
+ ] = payload.negative_prompt
136
+ elif node.type == "width":
137
+ for node_id in node.node_ids:
138
+ workflow[node_id]["inputs"][
139
+ node.key if node.key else "width"
140
+ ] = payload.width
141
+ elif node.type == "height":
142
+ for node_id in node.node_ids:
143
+ workflow[node_id]["inputs"][
144
+ node.key if node.key else "height"
145
+ ] = payload.height
146
+ elif node.type == "n":
147
+ for node_id in node.node_ids:
148
+ workflow[node_id]["inputs"][
149
+ node.key if node.key else "batch_size"
150
+ ] = payload.n
151
+ elif node.type == "steps":
152
+ for node_id in node.node_ids:
153
+ workflow[node_id]["inputs"][
154
+ node.key if node.key else "steps"
155
+ ] = payload.steps
156
+ elif node.type == "seed":
157
+ seed = (
158
+ payload.seed
159
+ if payload.seed
160
+ else random.randint(0, 18446744073709551614)
161
+ )
162
+ for node_id in node.node_ids:
163
+ workflow[node_id]["inputs"][node.key] = seed
164
+ else:
165
+ for node_id in node.node_ids:
166
+ workflow[node_id]["inputs"][node.key] = node.value
167
+
168
+ try:
169
+ ws = websocket.WebSocket()
170
+ ws.connect(f"{ws_url}/ws?clientId={client_id}")
171
+ log.info("WebSocket connection established.")
172
+ except Exception as e:
173
+ log.exception(f"Failed to connect to WebSocket server: {e}")
174
+ return None
175
+
176
+ try:
177
+ log.info("Sending workflow to WebSocket server.")
178
+ log.info(f"Workflow: {workflow}")
179
+ images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url)
180
+ except Exception as e:
181
+ log.exception(f"Error while receiving images: {e}")
182
+ images = None
183
+
184
+ ws.close()
185
+
186
+ return images
backend/open_webui/apps/ollama/main.py ADDED
@@ -0,0 +1,1351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import os
5
+ import random
6
+ import re
7
+ import time
8
+ from typing import Optional, Union
9
+ from urllib.parse import urlparse
10
+
11
+ import aiohttp
12
+ from aiocache import cached
13
+
14
+ import requests
15
+ from open_webui.apps.webui.models.models import Models
16
+ from open_webui.config import (
17
+ CORS_ALLOW_ORIGIN,
18
+ ENABLE_OLLAMA_API,
19
+ OLLAMA_BASE_URLS,
20
+ OLLAMA_API_CONFIGS,
21
+ UPLOAD_DIR,
22
+ AppConfig,
23
+ )
24
+ from open_webui.env import (
25
+ AIOHTTP_CLIENT_TIMEOUT,
26
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
27
+ BYPASS_MODEL_ACCESS_CONTROL,
28
+ )
29
+
30
+
31
+ from open_webui.constants import ERROR_MESSAGES
32
+ from open_webui.env import ENV, SRC_LOG_LEVELS
33
+ from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
34
+ from fastapi.middleware.cors import CORSMiddleware
35
+ from fastapi.responses import StreamingResponse
36
+ from pydantic import BaseModel, ConfigDict
37
+ from starlette.background import BackgroundTask
38
+
39
+
40
+ from open_webui.utils.misc import (
41
+ calculate_sha256,
42
+ )
43
+ from open_webui.utils.payload import (
44
+ apply_model_params_to_body_ollama,
45
+ apply_model_params_to_body_openai,
46
+ apply_model_system_prompt_to_body,
47
+ )
48
+ from open_webui.utils.utils import get_admin_user, get_verified_user
49
+ from open_webui.utils.access_control import has_access
50
+
51
+ log = logging.getLogger(__name__)
52
+ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
53
+
54
+
55
+ app = FastAPI(
56
+ docs_url="/docs" if ENV == "dev" else None,
57
+ openapi_url="/openapi.json" if ENV == "dev" else None,
58
+ redoc_url=None,
59
+ )
60
+
61
+ app.add_middleware(
62
+ CORSMiddleware,
63
+ allow_origins=CORS_ALLOW_ORIGIN,
64
+ allow_credentials=True,
65
+ allow_methods=["*"],
66
+ allow_headers=["*"],
67
+ )
68
+
69
+ app.state.config = AppConfig()
70
+
71
+ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
72
+ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
73
+ app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
74
+
75
+
76
+ # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
77
+ # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
78
+ # least connections, or least response time for better resource utilization and performance optimization.
79
+
80
+
81
+ @app.head("/")
82
+ @app.get("/")
83
+ async def get_status():
84
+ return {"status": True}
85
+
86
+
87
+ class ConnectionVerificationForm(BaseModel):
88
+ url: str
89
+ key: Optional[str] = None
90
+
91
+
92
+ @app.post("/verify")
93
+ async def verify_connection(
94
+ form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
95
+ ):
96
+ url = form_data.url
97
+ key = form_data.key
98
+
99
+ headers = {}
100
+ if key:
101
+ headers["Authorization"] = f"Bearer {key}"
102
+
103
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
104
+ async with aiohttp.ClientSession(timeout=timeout) as session:
105
+ try:
106
+ async with session.get(f"{url}/api/version", headers=headers) as r:
107
+ if r.status != 200:
108
+ # Extract response error details if available
109
+ error_detail = f"HTTP Error: {r.status}"
110
+ res = await r.json()
111
+ if "error" in res:
112
+ error_detail = f"External Error: {res['error']}"
113
+ raise Exception(error_detail)
114
+
115
+ response_data = await r.json()
116
+ return response_data
117
+
118
+ except aiohttp.ClientError as e:
119
+ # ClientError covers all aiohttp requests issues
120
+ log.exception(f"Client error: {str(e)}")
121
+ # Handle aiohttp-specific connection issues, timeout etc.
122
+ raise HTTPException(
123
+ status_code=500, detail="Open WebUI: Server Connection Error"
124
+ )
125
+ except Exception as e:
126
+ log.exception(f"Unexpected error: {e}")
127
+ # Generic error handler in case parsing JSON or other steps fail
128
+ error_detail = f"Unexpected error: {str(e)}"
129
+ raise HTTPException(status_code=500, detail=error_detail)
130
+
131
+
132
+ @app.get("/config")
133
+ async def get_config(user=Depends(get_admin_user)):
134
+ return {
135
+ "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
136
+ "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
137
+ "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
138
+ }
139
+
140
+
141
+ class OllamaConfigForm(BaseModel):
142
+ ENABLE_OLLAMA_API: Optional[bool] = None
143
+ OLLAMA_BASE_URLS: list[str]
144
+ OLLAMA_API_CONFIGS: dict
145
+
146
+
147
+ @app.post("/config/update")
148
+ async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
149
+ app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
150
+ app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
151
+
152
+ app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
153
+
154
+ # Remove any extra configs
155
+ config_urls = app.state.config.OLLAMA_API_CONFIGS.keys()
156
+ for url in list(app.state.config.OLLAMA_BASE_URLS):
157
+ if url not in config_urls:
158
+ app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
159
+
160
+ return {
161
+ "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
162
+ "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
163
+ "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
164
+ }
165
+
166
+
167
+ async def aiohttp_get(url, key=None):
168
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
169
+ try:
170
+ headers = {"Authorization": f"Bearer {key}"} if key else {}
171
+ async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
172
+ async with session.get(url, headers=headers) as response:
173
+ return await response.json()
174
+ except Exception as e:
175
+ # Handle connection error here
176
+ log.error(f"Connection error: {e}")
177
+ return None
178
+
179
+
180
+ async def cleanup_response(
181
+ response: Optional[aiohttp.ClientResponse],
182
+ session: Optional[aiohttp.ClientSession],
183
+ ):
184
+ if response:
185
+ response.close()
186
+ if session:
187
+ await session.close()
188
+
189
+
190
+ async def post_streaming_url(
191
+ url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
192
+ ):
193
+ r = None
194
+ try:
195
+ session = aiohttp.ClientSession(
196
+ trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
197
+ )
198
+
199
+ parsed_url = urlparse(url)
200
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
201
+
202
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
203
+ key = api_config.get("key", None)
204
+
205
+ headers = {"Content-Type": "application/json"}
206
+ if key:
207
+ headers["Authorization"] = f"Bearer {key}"
208
+
209
+ r = await session.post(
210
+ url,
211
+ data=payload,
212
+ headers=headers,
213
+ )
214
+ r.raise_for_status()
215
+
216
+ if stream:
217
+ response_headers = dict(r.headers)
218
+ if content_type:
219
+ response_headers["Content-Type"] = content_type
220
+ return StreamingResponse(
221
+ r.content,
222
+ status_code=r.status,
223
+ headers=response_headers,
224
+ background=BackgroundTask(
225
+ cleanup_response, response=r, session=session
226
+ ),
227
+ )
228
+ else:
229
+ res = await r.json()
230
+ await cleanup_response(r, session)
231
+ return res
232
+
233
+ except Exception as e:
234
+ error_detail = "Open WebUI: Server Connection Error"
235
+ if r is not None:
236
+ try:
237
+ res = await r.json()
238
+ if "error" in res:
239
+ error_detail = f"Ollama: {res['error']}"
240
+ except Exception:
241
+ error_detail = f"Ollama: {e}"
242
+
243
+ raise HTTPException(
244
+ status_code=r.status if r else 500,
245
+ detail=error_detail,
246
+ )
247
+
248
+
249
+ def merge_models_lists(model_lists):
250
+ merged_models = {}
251
+
252
+ for idx, model_list in enumerate(model_lists):
253
+ if model_list is not None:
254
+ for model in model_list:
255
+ id = model["model"]
256
+ if id not in merged_models:
257
+ model["urls"] = [idx]
258
+ merged_models[id] = model
259
+ else:
260
+ merged_models[id]["urls"].append(idx)
261
+
262
+ return list(merged_models.values())
263
+
264
+
265
+ @cached(ttl=3)
266
+ async def get_all_models():
267
+ log.info("get_all_models()")
268
+ if app.state.config.ENABLE_OLLAMA_API:
269
+ tasks = []
270
+ for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
271
+ if url not in app.state.config.OLLAMA_API_CONFIGS:
272
+ tasks.append(aiohttp_get(f"{url}/api/tags"))
273
+ else:
274
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
275
+ enable = api_config.get("enable", True)
276
+ key = api_config.get("key", None)
277
+
278
+ if enable:
279
+ tasks.append(aiohttp_get(f"{url}/api/tags", key))
280
+ else:
281
+ tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
282
+
283
+ responses = await asyncio.gather(*tasks)
284
+
285
+ for idx, response in enumerate(responses):
286
+ if response:
287
+ url = app.state.config.OLLAMA_BASE_URLS[idx]
288
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
289
+
290
+ prefix_id = api_config.get("prefix_id", None)
291
+ model_ids = api_config.get("model_ids", [])
292
+
293
+ if len(model_ids) != 0 and "models" in response:
294
+ response["models"] = list(
295
+ filter(
296
+ lambda model: model["model"] in model_ids,
297
+ response["models"],
298
+ )
299
+ )
300
+
301
+ if prefix_id:
302
+ for model in response.get("models", []):
303
+ model["model"] = f"{prefix_id}.{model['model']}"
304
+
305
+ models = {
306
+ "models": merge_models_lists(
307
+ map(
308
+ lambda response: response.get("models", []) if response else None,
309
+ responses,
310
+ )
311
+ )
312
+ }
313
+
314
+ else:
315
+ models = {"models": []}
316
+
317
+ return models
318
+
319
+
320
+ @app.get("/api/tags")
321
+ @app.get("/api/tags/{url_idx}")
322
+ async def get_ollama_tags(
323
+ url_idx: Optional[int] = None, user=Depends(get_verified_user)
324
+ ):
325
+ models = []
326
+ if url_idx is None:
327
+ models = await get_all_models()
328
+ else:
329
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
330
+
331
+ parsed_url = urlparse(url)
332
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
333
+
334
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
335
+ key = api_config.get("key", None)
336
+
337
+ headers = {}
338
+ if key:
339
+ headers["Authorization"] = f"Bearer {key}"
340
+
341
+ r = None
342
+ try:
343
+ r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
344
+ r.raise_for_status()
345
+
346
+ models = r.json()
347
+ except Exception as e:
348
+ log.exception(e)
349
+ error_detail = "Open WebUI: Server Connection Error"
350
+ if r is not None:
351
+ try:
352
+ res = r.json()
353
+ if "error" in res:
354
+ error_detail = f"Ollama: {res['error']}"
355
+ except Exception:
356
+ error_detail = f"Ollama: {e}"
357
+
358
+ raise HTTPException(
359
+ status_code=r.status_code if r else 500,
360
+ detail=error_detail,
361
+ )
362
+
363
+ if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
364
+ # Filter models based on user access control
365
+ filtered_models = []
366
+ for model in models.get("models", []):
367
+ model_info = Models.get_model_by_id(model["model"])
368
+ if model_info:
369
+ if user.id == model_info.user_id or has_access(
370
+ user.id, type="read", access_control=model_info.access_control
371
+ ):
372
+ filtered_models.append(model)
373
+ models["models"] = filtered_models
374
+
375
+ return models
376
+
377
+
378
+ @app.get("/api/version")
379
+ @app.get("/api/version/{url_idx}")
380
+ async def get_ollama_versions(url_idx: Optional[int] = None):
381
+ if app.state.config.ENABLE_OLLAMA_API:
382
+ if url_idx is None:
383
+ # returns lowest version
384
+ tasks = [
385
+ aiohttp_get(
386
+ f"{url}/api/version",
387
+ app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
388
+ )
389
+ for url in app.state.config.OLLAMA_BASE_URLS
390
+ ]
391
+ responses = await asyncio.gather(*tasks)
392
+ responses = list(filter(lambda x: x is not None, responses))
393
+
394
+ if len(responses) > 0:
395
+ lowest_version = min(
396
+ responses,
397
+ key=lambda x: tuple(
398
+ map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
399
+ ),
400
+ )
401
+
402
+ return {"version": lowest_version["version"]}
403
+ else:
404
+ raise HTTPException(
405
+ status_code=500,
406
+ detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
407
+ )
408
+ else:
409
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
410
+
411
+ r = None
412
+ try:
413
+ r = requests.request(method="GET", url=f"{url}/api/version")
414
+ r.raise_for_status()
415
+
416
+ return r.json()
417
+ except Exception as e:
418
+ log.exception(e)
419
+ error_detail = "Open WebUI: Server Connection Error"
420
+ if r is not None:
421
+ try:
422
+ res = r.json()
423
+ if "error" in res:
424
+ error_detail = f"Ollama: {res['error']}"
425
+ except Exception:
426
+ error_detail = f"Ollama: {e}"
427
+
428
+ raise HTTPException(
429
+ status_code=r.status_code if r else 500,
430
+ detail=error_detail,
431
+ )
432
+ else:
433
+ return {"version": False}
434
+
435
+
436
+ class ModelNameForm(BaseModel):
437
+ name: str
438
+
439
+
440
+ @app.post("/api/pull")
441
+ @app.post("/api/pull/{url_idx}")
442
+ async def pull_model(
443
+ form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
444
+ ):
445
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
446
+ log.info(f"url: {url}")
447
+
448
+ # Admin should be able to pull models from any source
449
+ payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
450
+
451
+ return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
452
+
453
+
454
+ class PushModelForm(BaseModel):
455
+ name: str
456
+ insecure: Optional[bool] = None
457
+ stream: Optional[bool] = None
458
+
459
+
460
+ @app.delete("/api/push")
461
+ @app.delete("/api/push/{url_idx}")
462
+ async def push_model(
463
+ form_data: PushModelForm,
464
+ url_idx: Optional[int] = None,
465
+ user=Depends(get_admin_user),
466
+ ):
467
+ if url_idx is None:
468
+ model_list = await get_all_models()
469
+ models = {model["model"]: model for model in model_list["models"]}
470
+
471
+ if form_data.name in models:
472
+ url_idx = models[form_data.name]["urls"][0]
473
+ else:
474
+ raise HTTPException(
475
+ status_code=400,
476
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
477
+ )
478
+
479
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
480
+ log.debug(f"url: {url}")
481
+
482
+ return await post_streaming_url(
483
+ f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
484
+ )
485
+
486
+
487
+ class CreateModelForm(BaseModel):
488
+ name: str
489
+ modelfile: Optional[str] = None
490
+ stream: Optional[bool] = None
491
+ path: Optional[str] = None
492
+
493
+
494
+ @app.post("/api/create")
495
+ @app.post("/api/create/{url_idx}")
496
+ async def create_model(
497
+ form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
498
+ ):
499
+ log.debug(f"form_data: {form_data}")
500
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
501
+ log.info(f"url: {url}")
502
+
503
+ return await post_streaming_url(
504
+ f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
505
+ )
506
+
507
+
508
+ class CopyModelForm(BaseModel):
509
+ source: str
510
+ destination: str
511
+
512
+
513
+ @app.post("/api/copy")
514
+ @app.post("/api/copy/{url_idx}")
515
+ async def copy_model(
516
+ form_data: CopyModelForm,
517
+ url_idx: Optional[int] = None,
518
+ user=Depends(get_admin_user),
519
+ ):
520
+ if url_idx is None:
521
+ model_list = await get_all_models()
522
+ models = {model["model"]: model for model in model_list["models"]}
523
+
524
+ if form_data.source in models:
525
+ url_idx = models[form_data.source]["urls"][0]
526
+ else:
527
+ raise HTTPException(
528
+ status_code=400,
529
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
530
+ )
531
+
532
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
533
+ log.info(f"url: {url}")
534
+
535
+ parsed_url = urlparse(url)
536
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
537
+
538
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
539
+ key = api_config.get("key", None)
540
+
541
+ headers = {"Content-Type": "application/json"}
542
+ if key:
543
+ headers["Authorization"] = f"Bearer {key}"
544
+
545
+ r = requests.request(
546
+ method="POST",
547
+ url=f"{url}/api/copy",
548
+ headers=headers,
549
+ data=form_data.model_dump_json(exclude_none=True).encode(),
550
+ )
551
+
552
+ try:
553
+ r.raise_for_status()
554
+
555
+ log.debug(f"r.text: {r.text}")
556
+
557
+ return True
558
+ except Exception as e:
559
+ log.exception(e)
560
+ error_detail = "Open WebUI: Server Connection Error"
561
+ if r is not None:
562
+ try:
563
+ res = r.json()
564
+ if "error" in res:
565
+ error_detail = f"Ollama: {res['error']}"
566
+ except Exception:
567
+ error_detail = f"Ollama: {e}"
568
+
569
+ raise HTTPException(
570
+ status_code=r.status_code if r else 500,
571
+ detail=error_detail,
572
+ )
573
+
574
+
575
+ @app.delete("/api/delete")
576
+ @app.delete("/api/delete/{url_idx}")
577
+ async def delete_model(
578
+ form_data: ModelNameForm,
579
+ url_idx: Optional[int] = None,
580
+ user=Depends(get_admin_user),
581
+ ):
582
+ if url_idx is None:
583
+ model_list = await get_all_models()
584
+ models = {model["model"]: model for model in model_list["models"]}
585
+
586
+ if form_data.name in models:
587
+ url_idx = models[form_data.name]["urls"][0]
588
+ else:
589
+ raise HTTPException(
590
+ status_code=400,
591
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
592
+ )
593
+
594
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
595
+ log.info(f"url: {url}")
596
+
597
+ parsed_url = urlparse(url)
598
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
599
+
600
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
601
+ key = api_config.get("key", None)
602
+
603
+ headers = {"Content-Type": "application/json"}
604
+ if key:
605
+ headers["Authorization"] = f"Bearer {key}"
606
+
607
+ r = requests.request(
608
+ method="DELETE",
609
+ url=f"{url}/api/delete",
610
+ data=form_data.model_dump_json(exclude_none=True).encode(),
611
+ headers=headers,
612
+ )
613
+ try:
614
+ r.raise_for_status()
615
+
616
+ log.debug(f"r.text: {r.text}")
617
+
618
+ return True
619
+ except Exception as e:
620
+ log.exception(e)
621
+ error_detail = "Open WebUI: Server Connection Error"
622
+ if r is not None:
623
+ try:
624
+ res = r.json()
625
+ if "error" in res:
626
+ error_detail = f"Ollama: {res['error']}"
627
+ except Exception:
628
+ error_detail = f"Ollama: {e}"
629
+
630
+ raise HTTPException(
631
+ status_code=r.status_code if r else 500,
632
+ detail=error_detail,
633
+ )
634
+
635
+
636
+ @app.post("/api/show")
637
+ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
638
+ model_list = await get_all_models()
639
+ models = {model["model"]: model for model in model_list["models"]}
640
+
641
+ if form_data.name not in models:
642
+ raise HTTPException(
643
+ status_code=400,
644
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
645
+ )
646
+
647
+ url_idx = random.choice(models[form_data.name]["urls"])
648
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
649
+ log.info(f"url: {url}")
650
+
651
+ parsed_url = urlparse(url)
652
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
653
+
654
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
655
+ key = api_config.get("key", None)
656
+
657
+ headers = {"Content-Type": "application/json"}
658
+ if key:
659
+ headers["Authorization"] = f"Bearer {key}"
660
+
661
+ r = requests.request(
662
+ method="POST",
663
+ url=f"{url}/api/show",
664
+ headers=headers,
665
+ data=form_data.model_dump_json(exclude_none=True).encode(),
666
+ )
667
+ try:
668
+ r.raise_for_status()
669
+
670
+ return r.json()
671
+ except Exception as e:
672
+ log.exception(e)
673
+ error_detail = "Open WebUI: Server Connection Error"
674
+ if r is not None:
675
+ try:
676
+ res = r.json()
677
+ if "error" in res:
678
+ error_detail = f"Ollama: {res['error']}"
679
+ except Exception:
680
+ error_detail = f"Ollama: {e}"
681
+
682
+ raise HTTPException(
683
+ status_code=r.status_code if r else 500,
684
+ detail=error_detail,
685
+ )
686
+
687
+
688
+ class GenerateEmbeddingsForm(BaseModel):
689
+ model: str
690
+ prompt: str
691
+ options: Optional[dict] = None
692
+ keep_alive: Optional[Union[int, str]] = None
693
+
694
+
695
+ class GenerateEmbedForm(BaseModel):
696
+ model: str
697
+ input: list[str] | str
698
+ truncate: Optional[bool] = None
699
+ options: Optional[dict] = None
700
+ keep_alive: Optional[Union[int, str]] = None
701
+
702
+
703
+ @app.post("/api/embed")
704
+ @app.post("/api/embed/{url_idx}")
705
+ async def generate_embeddings(
706
+ form_data: GenerateEmbedForm,
707
+ url_idx: Optional[int] = None,
708
+ user=Depends(get_verified_user),
709
+ ):
710
+ return await generate_ollama_batch_embeddings(form_data, url_idx)
711
+
712
+
713
+ @app.post("/api/embeddings")
714
+ @app.post("/api/embeddings/{url_idx}")
715
+ async def generate_embeddings(
716
+ form_data: GenerateEmbeddingsForm,
717
+ url_idx: Optional[int] = None,
718
+ user=Depends(get_verified_user),
719
+ ):
720
+ return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
721
+
722
+
723
+ async def generate_ollama_embeddings(
724
+ form_data: GenerateEmbeddingsForm,
725
+ url_idx: Optional[int] = None,
726
+ ):
727
+ log.info(f"generate_ollama_embeddings {form_data}")
728
+
729
+ if url_idx is None:
730
+ model_list = await get_all_models()
731
+ models = {model["model"]: model for model in model_list["models"]}
732
+
733
+ model = form_data.model
734
+
735
+ if ":" not in model:
736
+ model = f"{model}:latest"
737
+
738
+ if model in models:
739
+ url_idx = random.choice(models[model]["urls"])
740
+ else:
741
+ raise HTTPException(
742
+ status_code=400,
743
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
744
+ )
745
+
746
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
747
+ log.info(f"url: {url}")
748
+
749
+ parsed_url = urlparse(url)
750
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
751
+
752
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
753
+ key = api_config.get("key", None)
754
+
755
+ headers = {"Content-Type": "application/json"}
756
+ if key:
757
+ headers["Authorization"] = f"Bearer {key}"
758
+
759
+ r = requests.request(
760
+ method="POST",
761
+ url=f"{url}/api/embeddings",
762
+ headers=headers,
763
+ data=form_data.model_dump_json(exclude_none=True).encode(),
764
+ )
765
+ try:
766
+ r.raise_for_status()
767
+
768
+ data = r.json()
769
+
770
+ log.info(f"generate_ollama_embeddings {data}")
771
+
772
+ if "embedding" in data:
773
+ return data
774
+ else:
775
+ raise Exception("Something went wrong :/")
776
+ except Exception as e:
777
+ log.exception(e)
778
+ error_detail = "Open WebUI: Server Connection Error"
779
+ if r is not None:
780
+ try:
781
+ res = r.json()
782
+ if "error" in res:
783
+ error_detail = f"Ollama: {res['error']}"
784
+ except Exception:
785
+ error_detail = f"Ollama: {e}"
786
+
787
+ raise HTTPException(
788
+ status_code=r.status_code if r else 500,
789
+ detail=error_detail,
790
+ )
791
+
792
+
793
+ async def generate_ollama_batch_embeddings(
794
+ form_data: GenerateEmbedForm,
795
+ url_idx: Optional[int] = None,
796
+ ):
797
+ log.info(f"generate_ollama_batch_embeddings {form_data}")
798
+
799
+ if url_idx is None:
800
+ model_list = await get_all_models()
801
+ models = {model["model"]: model for model in model_list["models"]}
802
+
803
+ model = form_data.model
804
+
805
+ if ":" not in model:
806
+ model = f"{model}:latest"
807
+
808
+ if model in models:
809
+ url_idx = random.choice(models[model]["urls"])
810
+ else:
811
+ raise HTTPException(
812
+ status_code=400,
813
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
814
+ )
815
+
816
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
817
+ log.info(f"url: {url}")
818
+
819
+ parsed_url = urlparse(url)
820
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
821
+
822
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
823
+ key = api_config.get("key", None)
824
+
825
+ headers = {"Content-Type": "application/json"}
826
+ if key:
827
+ headers["Authorization"] = f"Bearer {key}"
828
+
829
+ r = requests.request(
830
+ method="POST",
831
+ url=f"{url}/api/embed",
832
+ headers=headers,
833
+ data=form_data.model_dump_json(exclude_none=True).encode(),
834
+ )
835
+ try:
836
+ r.raise_for_status()
837
+
838
+ data = r.json()
839
+
840
+ log.info(f"generate_ollama_batch_embeddings {data}")
841
+
842
+ if "embeddings" in data:
843
+ return data
844
+ else:
845
+ raise Exception("Something went wrong :/")
846
+ except Exception as e:
847
+ log.exception(e)
848
+ error_detail = "Open WebUI: Server Connection Error"
849
+ if r is not None:
850
+ try:
851
+ res = r.json()
852
+ if "error" in res:
853
+ error_detail = f"Ollama: {res['error']}"
854
+ except Exception:
855
+ error_detail = f"Ollama: {e}"
856
+
857
+ raise Exception(error_detail)
858
+
859
+
860
+ class GenerateCompletionForm(BaseModel):
861
+ model: str
862
+ prompt: str
863
+ suffix: Optional[str] = None
864
+ images: Optional[list[str]] = None
865
+ format: Optional[str] = None
866
+ options: Optional[dict] = None
867
+ system: Optional[str] = None
868
+ template: Optional[str] = None
869
+ context: Optional[list[int]] = None
870
+ stream: Optional[bool] = True
871
+ raw: Optional[bool] = None
872
+ keep_alive: Optional[Union[int, str]] = None
873
+
874
+
875
+ @app.post("/api/generate")
876
+ @app.post("/api/generate/{url_idx}")
877
+ async def generate_completion(
878
+ form_data: GenerateCompletionForm,
879
+ url_idx: Optional[int] = None,
880
+ user=Depends(get_verified_user),
881
+ ):
882
+ if url_idx is None:
883
+ model_list = await get_all_models()
884
+ models = {model["model"]: model for model in model_list["models"]}
885
+
886
+ model = form_data.model
887
+
888
+ if ":" not in model:
889
+ model = f"{model}:latest"
890
+
891
+ if model in models:
892
+ url_idx = random.choice(models[model]["urls"])
893
+ else:
894
+ raise HTTPException(
895
+ status_code=400,
896
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
897
+ )
898
+
899
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
900
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
901
+ prefix_id = api_config.get("prefix_id", None)
902
+ if prefix_id:
903
+ form_data.model = form_data.model.replace(f"{prefix_id}.", "")
904
+ log.info(f"url: {url}")
905
+
906
+ return await post_streaming_url(
907
+ f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
908
+ )
909
+
910
+
911
+ class ChatMessage(BaseModel):
912
+ role: str
913
+ content: str
914
+ images: Optional[list[str]] = None
915
+
916
+
917
+ class GenerateChatCompletionForm(BaseModel):
918
+ model: str
919
+ messages: list[ChatMessage]
920
+ format: Optional[str] = None
921
+ options: Optional[dict] = None
922
+ template: Optional[str] = None
923
+ stream: Optional[bool] = True
924
+ keep_alive: Optional[Union[int, str]] = None
925
+
926
+
927
+ async def get_ollama_url(url_idx: Optional[int], model: str):
928
+ if url_idx is None:
929
+ model_list = await get_all_models()
930
+ models = {model["model"]: model for model in model_list["models"]}
931
+
932
+ if model not in models:
933
+ raise HTTPException(
934
+ status_code=400,
935
+ detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
936
+ )
937
+ url_idx = random.choice(models[model]["urls"])
938
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
939
+ return url
940
+
941
+
942
+ @app.post("/api/chat")
943
+ @app.post("/api/chat/{url_idx}")
944
+ async def generate_chat_completion(
945
+ form_data: GenerateChatCompletionForm,
946
+ url_idx: Optional[int] = None,
947
+ user=Depends(get_verified_user),
948
+ bypass_filter: Optional[bool] = False,
949
+ ):
950
+ payload = {**form_data.model_dump(exclude_none=True)}
951
+ log.debug(f"generate_chat_completion() - 1.payload = {payload}")
952
+ if "metadata" in payload:
953
+ del payload["metadata"]
954
+
955
+ model_id = payload["model"]
956
+ model_info = Models.get_model_by_id(model_id)
957
+
958
+ if model_info:
959
+ if model_info.base_model_id:
960
+ payload["model"] = model_info.base_model_id
961
+
962
+ params = model_info.params.model_dump()
963
+
964
+ if params:
965
+ if payload.get("options") is None:
966
+ payload["options"] = {}
967
+
968
+ payload["options"] = apply_model_params_to_body_ollama(
969
+ params, payload["options"]
970
+ )
971
+ payload = apply_model_system_prompt_to_body(params, payload, user)
972
+
973
+ # Check if user has access to the model
974
+ if not bypass_filter and user.role == "user":
975
+ if not (
976
+ user.id == model_info.user_id
977
+ or has_access(
978
+ user.id, type="read", access_control=model_info.access_control
979
+ )
980
+ ):
981
+ raise HTTPException(
982
+ status_code=403,
983
+ detail="Model not found",
984
+ )
985
+ elif not bypass_filter:
986
+ if user.role != "admin":
987
+ raise HTTPException(
988
+ status_code=403,
989
+ detail="Model not found",
990
+ )
991
+
992
+ if ":" not in payload["model"]:
993
+ payload["model"] = f"{payload['model']}:latest"
994
+
995
+ url = await get_ollama_url(url_idx, payload["model"])
996
+ log.info(f"url: {url}")
997
+ log.debug(f"generate_chat_completion() - 2.payload = {payload}")
998
+
999
+ parsed_url = urlparse(url)
1000
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
1001
+
1002
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
1003
+ prefix_id = api_config.get("prefix_id", None)
1004
+ if prefix_id:
1005
+ payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
1006
+
1007
+ return await post_streaming_url(
1008
+ f"{url}/api/chat",
1009
+ json.dumps(payload),
1010
+ stream=form_data.stream,
1011
+ content_type="application/x-ndjson",
1012
+ )
1013
+
1014
+
1015
+ # TODO: we should update this part once Ollama supports other types
1016
+ class OpenAIChatMessageContent(BaseModel):
1017
+ type: str
1018
+ model_config = ConfigDict(extra="allow")
1019
+
1020
+
1021
+ class OpenAIChatMessage(BaseModel):
1022
+ role: str
1023
+ content: Union[str, list[OpenAIChatMessageContent]]
1024
+
1025
+ model_config = ConfigDict(extra="allow")
1026
+
1027
+
1028
+ class OpenAIChatCompletionForm(BaseModel):
1029
+ model: str
1030
+ messages: list[OpenAIChatMessage]
1031
+
1032
+ model_config = ConfigDict(extra="allow")
1033
+
1034
+
1035
+ @app.post("/v1/chat/completions")
1036
+ @app.post("/v1/chat/completions/{url_idx}")
1037
+ async def generate_openai_chat_completion(
1038
+ form_data: dict,
1039
+ url_idx: Optional[int] = None,
1040
+ user=Depends(get_verified_user),
1041
+ ):
1042
+ try:
1043
+ completion_form = OpenAIChatCompletionForm(**form_data)
1044
+ except Exception as e:
1045
+ log.exception(e)
1046
+ raise HTTPException(
1047
+ status_code=400,
1048
+ detail=str(e),
1049
+ )
1050
+
1051
+ payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
1052
+ if "metadata" in payload:
1053
+ del payload["metadata"]
1054
+
1055
+ model_id = completion_form.model
1056
+ if ":" not in model_id:
1057
+ model_id = f"{model_id}:latest"
1058
+
1059
+ model_info = Models.get_model_by_id(model_id)
1060
+ if model_info:
1061
+ if model_info.base_model_id:
1062
+ payload["model"] = model_info.base_model_id
1063
+
1064
+ params = model_info.params.model_dump()
1065
+
1066
+ if params:
1067
+ payload = apply_model_params_to_body_openai(params, payload)
1068
+ payload = apply_model_system_prompt_to_body(params, payload, user)
1069
+
1070
+ # Check if user has access to the model
1071
+ if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
1072
+ if not (
1073
+ user.id == model_info.user_id
1074
+ or has_access(
1075
+ user.id, type="read", access_control=model_info.access_control
1076
+ )
1077
+ ):
1078
+ raise HTTPException(
1079
+ status_code=403,
1080
+ detail="Model not found",
1081
+ )
1082
+ else:
1083
+ if user.role != "admin":
1084
+ raise HTTPException(
1085
+ status_code=403,
1086
+ detail="Model not found",
1087
+ )
1088
+
1089
+ if ":" not in payload["model"]:
1090
+ payload["model"] = f"{payload['model']}:latest"
1091
+
1092
+ url = await get_ollama_url(url_idx, payload["model"])
1093
+ log.info(f"url: {url}")
1094
+
1095
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
1096
+ prefix_id = api_config.get("prefix_id", None)
1097
+ if prefix_id:
1098
+ payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
1099
+
1100
+ return await post_streaming_url(
1101
+ f"{url}/v1/chat/completions",
1102
+ json.dumps(payload),
1103
+ stream=payload.get("stream", False),
1104
+ )
1105
+
1106
+
1107
+ @app.get("/v1/models")
1108
+ @app.get("/v1/models/{url_idx}")
1109
+ async def get_openai_models(
1110
+ url_idx: Optional[int] = None,
1111
+ user=Depends(get_verified_user),
1112
+ ):
1113
+
1114
+ models = []
1115
+ if url_idx is None:
1116
+ model_list = await get_all_models()
1117
+ models = [
1118
+ {
1119
+ "id": model["model"],
1120
+ "object": "model",
1121
+ "created": int(time.time()),
1122
+ "owned_by": "openai",
1123
+ }
1124
+ for model in model_list["models"]
1125
+ ]
1126
+
1127
+ else:
1128
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
1129
+ try:
1130
+ r = requests.request(method="GET", url=f"{url}/api/tags")
1131
+ r.raise_for_status()
1132
+
1133
+ model_list = r.json()
1134
+
1135
+ models = [
1136
+ {
1137
+ "id": model["model"],
1138
+ "object": "model",
1139
+ "created": int(time.time()),
1140
+ "owned_by": "openai",
1141
+ }
1142
+ for model in models["models"]
1143
+ ]
1144
+ except Exception as e:
1145
+ log.exception(e)
1146
+ error_detail = "Open WebUI: Server Connection Error"
1147
+ if r is not None:
1148
+ try:
1149
+ res = r.json()
1150
+ if "error" in res:
1151
+ error_detail = f"Ollama: {res['error']}"
1152
+ except Exception:
1153
+ error_detail = f"Ollama: {e}"
1154
+
1155
+ raise HTTPException(
1156
+ status_code=r.status_code if r else 500,
1157
+ detail=error_detail,
1158
+ )
1159
+
1160
+ if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
1161
+ # Filter models based on user access control
1162
+ filtered_models = []
1163
+ for model in models:
1164
+ model_info = Models.get_model_by_id(model["id"])
1165
+ if model_info:
1166
+ if user.id == model_info.user_id or has_access(
1167
+ user.id, type="read", access_control=model_info.access_control
1168
+ ):
1169
+ filtered_models.append(model)
1170
+ models = filtered_models
1171
+
1172
+ return {
1173
+ "data": models,
1174
+ "object": "list",
1175
+ }
1176
+
1177
+
1178
+ class UrlForm(BaseModel):
1179
+ url: str
1180
+
1181
+
1182
+ class UploadBlobForm(BaseModel):
1183
+ filename: str
1184
+
1185
+
1186
+ def parse_huggingface_url(hf_url):
1187
+ try:
1188
+ # Parse the URL
1189
+ parsed_url = urlparse(hf_url)
1190
+
1191
+ # Get the path and split it into components
1192
+ path_components = parsed_url.path.split("/")
1193
+
1194
+ # Extract the desired output
1195
+ model_file = path_components[-1]
1196
+
1197
+ return model_file
1198
+ except ValueError:
1199
+ return None
1200
+
1201
+
1202
+ async def download_file_stream(
1203
+ ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
1204
+ ):
1205
+ done = False
1206
+
1207
+ if os.path.exists(file_path):
1208
+ current_size = os.path.getsize(file_path)
1209
+ else:
1210
+ current_size = 0
1211
+
1212
+ headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
1213
+
1214
+ timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
1215
+
1216
+ async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
1217
+ async with session.get(file_url, headers=headers) as response:
1218
+ total_size = int(response.headers.get("content-length", 0)) + current_size
1219
+
1220
+ with open(file_path, "ab+") as file:
1221
+ async for data in response.content.iter_chunked(chunk_size):
1222
+ current_size += len(data)
1223
+ file.write(data)
1224
+
1225
+ done = current_size == total_size
1226
+ progress = round((current_size / total_size) * 100, 2)
1227
+
1228
+ yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
1229
+
1230
+ if done:
1231
+ file.seek(0)
1232
+ hashed = calculate_sha256(file)
1233
+ file.seek(0)
1234
+
1235
+ url = f"{ollama_url}/api/blobs/sha256:{hashed}"
1236
+ response = requests.post(url, data=file)
1237
+
1238
+ if response.ok:
1239
+ res = {
1240
+ "done": done,
1241
+ "blob": f"sha256:{hashed}",
1242
+ "name": file_name,
1243
+ }
1244
+ os.remove(file_path)
1245
+
1246
+ yield f"data: {json.dumps(res)}\n\n"
1247
+ else:
1248
+ raise "Ollama: Could not create blob, Please try again."
1249
+
1250
+
1251
+ # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
1252
+ @app.post("/models/download")
1253
+ @app.post("/models/download/{url_idx}")
1254
+ async def download_model(
1255
+ form_data: UrlForm,
1256
+ url_idx: Optional[int] = None,
1257
+ user=Depends(get_admin_user),
1258
+ ):
1259
+ allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
1260
+
1261
+ if not any(form_data.url.startswith(host) for host in allowed_hosts):
1262
+ raise HTTPException(
1263
+ status_code=400,
1264
+ detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
1265
+ )
1266
+
1267
+ if url_idx is None:
1268
+ url_idx = 0
1269
+ url = app.state.config.OLLAMA_BASE_URLS[url_idx]
1270
+
1271
+ file_name = parse_huggingface_url(form_data.url)
1272
+
1273
+ if file_name:
1274
+ file_path = f"{UPLOAD_DIR}/{file_name}"
1275
+
1276
+ return StreamingResponse(
1277
+ download_file_stream(url, form_data.url, file_path, file_name),
1278
+ )
1279
+ else:
1280
+ return None
1281
+
1282
+
1283
+ @app.post("/models/upload")
1284
+ @app.post("/models/upload/{url_idx}")
1285
+ def upload_model(
1286
+ file: UploadFile = File(...),
1287
+ url_idx: Optional[int] = None,
1288
+ user=Depends(get_admin_user),
1289
+ ):
1290
+ if url_idx is None:
1291
+ url_idx = 0
1292
+ ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
1293
+
1294
+ file_path = f"{UPLOAD_DIR}/{file.filename}"
1295
+
1296
+ # Save file in chunks
1297
+ with open(file_path, "wb+") as f:
1298
+ for chunk in file.file:
1299
+ f.write(chunk)
1300
+
1301
+ def file_process_stream():
1302
+ nonlocal ollama_url
1303
+ total_size = os.path.getsize(file_path)
1304
+ chunk_size = 1024 * 1024
1305
+ try:
1306
+ with open(file_path, "rb") as f:
1307
+ total = 0
1308
+ done = False
1309
+
1310
+ while not done:
1311
+ chunk = f.read(chunk_size)
1312
+ if not chunk:
1313
+ done = True
1314
+ continue
1315
+
1316
+ total += len(chunk)
1317
+ progress = round((total / total_size) * 100, 2)
1318
+
1319
+ res = {
1320
+ "progress": progress,
1321
+ "total": total_size,
1322
+ "completed": total,
1323
+ }
1324
+ yield f"data: {json.dumps(res)}\n\n"
1325
+
1326
+ if done:
1327
+ f.seek(0)
1328
+ hashed = calculate_sha256(f)
1329
+ f.seek(0)
1330
+
1331
+ url = f"{ollama_url}/api/blobs/sha256:{hashed}"
1332
+ response = requests.post(url, data=f)
1333
+
1334
+ if response.ok:
1335
+ res = {
1336
+ "done": done,
1337
+ "blob": f"sha256:{hashed}",
1338
+ "name": file.filename,
1339
+ }
1340
+ os.remove(file_path)
1341
+ yield f"data: {json.dumps(res)}\n\n"
1342
+ else:
1343
+ raise Exception(
1344
+ "Ollama: Could not create blob, Please try again."
1345
+ )
1346
+
1347
+ except Exception as e:
1348
+ res = {"error": str(e)}
1349
+ yield f"data: {json.dumps(res)}\n\n"
1350
+
1351
+ return StreamingResponse(file_process_stream(), media_type="text/event-stream")
backend/open_webui/apps/openai/main.py ADDED
@@ -0,0 +1,719 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import hashlib
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Literal, Optional, overload
7
+
8
+ import aiohttp
9
+ from aiocache import cached
10
+ import requests
11
+
12
+
13
+ from open_webui.apps.webui.models.models import Models
14
+ from open_webui.config import (
15
+ CACHE_DIR,
16
+ CORS_ALLOW_ORIGIN,
17
+ ENABLE_OPENAI_API,
18
+ OPENAI_API_BASE_URLS,
19
+ OPENAI_API_KEYS,
20
+ OPENAI_API_CONFIGS,
21
+ AppConfig,
22
+ )
23
+ from open_webui.env import (
24
+ AIOHTTP_CLIENT_TIMEOUT,
25
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
26
+ ENABLE_FORWARD_USER_INFO_HEADERS,
27
+ BYPASS_MODEL_ACCESS_CONTROL,
28
+ )
29
+
30
+ from open_webui.constants import ERROR_MESSAGES
31
+ from open_webui.env import ENV, SRC_LOG_LEVELS
32
+ from fastapi import Depends, FastAPI, HTTPException, Request
33
+ from fastapi.middleware.cors import CORSMiddleware
34
+ from fastapi.responses import FileResponse, StreamingResponse
35
+ from pydantic import BaseModel
36
+ from starlette.background import BackgroundTask
37
+
38
+ from open_webui.utils.payload import (
39
+ apply_model_params_to_body_openai,
40
+ apply_model_system_prompt_to_body,
41
+ )
42
+
43
+ from open_webui.utils.utils import get_admin_user, get_verified_user
44
+ from open_webui.utils.access_control import has_access
45
+
46
+
47
+ log = logging.getLogger(__name__)
48
+ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
49
+
50
+
51
+ app = FastAPI(
52
+ docs_url="/docs" if ENV == "dev" else None,
53
+ openapi_url="/openapi.json" if ENV == "dev" else None,
54
+ redoc_url=None,
55
+ )
56
+
57
+
58
+ app.add_middleware(
59
+ CORSMiddleware,
60
+ allow_origins=CORS_ALLOW_ORIGIN,
61
+ allow_credentials=True,
62
+ allow_methods=["*"],
63
+ allow_headers=["*"],
64
+ )
65
+
66
+ app.state.config = AppConfig()
67
+
68
+ app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
69
+ app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
70
+ app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
71
+ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
72
+
73
+
74
+ @app.get("/config")
75
+ async def get_config(user=Depends(get_admin_user)):
76
+ return {
77
+ "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
78
+ "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
79
+ "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
80
+ "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
81
+ }
82
+
83
+
84
+ class OpenAIConfigForm(BaseModel):
85
+ ENABLE_OPENAI_API: Optional[bool] = None
86
+ OPENAI_API_BASE_URLS: list[str]
87
+ OPENAI_API_KEYS: list[str]
88
+ OPENAI_API_CONFIGS: dict
89
+
90
+
91
+ @app.post("/config/update")
92
+ async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
93
+ app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
94
+
95
+ app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
96
+ app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
97
+
98
+ # Check if API KEYS length is same than API URLS length
99
+ if len(app.state.config.OPENAI_API_KEYS) != len(
100
+ app.state.config.OPENAI_API_BASE_URLS
101
+ ):
102
+ if len(app.state.config.OPENAI_API_KEYS) > len(
103
+ app.state.config.OPENAI_API_BASE_URLS
104
+ ):
105
+ app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
106
+ : len(app.state.config.OPENAI_API_BASE_URLS)
107
+ ]
108
+ else:
109
+ app.state.config.OPENAI_API_KEYS += [""] * (
110
+ len(app.state.config.OPENAI_API_BASE_URLS)
111
+ - len(app.state.config.OPENAI_API_KEYS)
112
+ )
113
+
114
+ app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
115
+
116
+ # Remove any extra configs
117
+ config_urls = app.state.config.OPENAI_API_CONFIGS.keys()
118
+ for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
119
+ if url not in config_urls:
120
+ app.state.config.OPENAI_API_CONFIGS.pop(url, None)
121
+
122
+ return {
123
+ "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
124
+ "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
125
+ "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
126
+ "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
127
+ }
128
+
129
+
130
+ @app.post("/audio/speech")
131
+ async def speech(request: Request, user=Depends(get_verified_user)):
132
+ idx = None
133
+ try:
134
+ idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
135
+ body = await request.body()
136
+ name = hashlib.sha256(body).hexdigest()
137
+
138
+ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
139
+ SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
140
+ file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
141
+ file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
142
+
143
+ # Check if the file already exists in the cache
144
+ if file_path.is_file():
145
+ return FileResponse(file_path)
146
+
147
+ headers = {}
148
+ headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
149
+ headers["Content-Type"] = "application/json"
150
+ if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
151
+ headers["HTTP-Referer"] = "https://openwebui.com/"
152
+ headers["X-Title"] = "Open WebUI"
153
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
154
+ headers["X-OpenWebUI-User-Name"] = user.name
155
+ headers["X-OpenWebUI-User-Id"] = user.id
156
+ headers["X-OpenWebUI-User-Email"] = user.email
157
+ headers["X-OpenWebUI-User-Role"] = user.role
158
+ r = None
159
+ try:
160
+ r = requests.post(
161
+ url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
162
+ data=body,
163
+ headers=headers,
164
+ stream=True,
165
+ )
166
+
167
+ r.raise_for_status()
168
+
169
+ # Save the streaming content to a file
170
+ with open(file_path, "wb") as f:
171
+ for chunk in r.iter_content(chunk_size=8192):
172
+ f.write(chunk)
173
+
174
+ with open(file_body_path, "w") as f:
175
+ json.dump(json.loads(body.decode("utf-8")), f)
176
+
177
+ # Return the saved file
178
+ return FileResponse(file_path)
179
+
180
+ except Exception as e:
181
+ log.exception(e)
182
+ error_detail = "Open WebUI: Server Connection Error"
183
+ if r is not None:
184
+ try:
185
+ res = r.json()
186
+ if "error" in res:
187
+ error_detail = f"External: {res['error']}"
188
+ except Exception:
189
+ error_detail = f"External: {e}"
190
+
191
+ raise HTTPException(
192
+ status_code=r.status_code if r else 500, detail=error_detail
193
+ )
194
+
195
+ except ValueError:
196
+ raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
197
+
198
+
199
+ async def aiohttp_get(url, key=None):
200
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
201
+ try:
202
+ headers = {"Authorization": f"Bearer {key}"} if key else {}
203
+ async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
204
+ async with session.get(url, headers=headers) as response:
205
+ return await response.json()
206
+ except Exception as e:
207
+ # Handle connection error here
208
+ log.error(f"Connection error: {e}")
209
+ return None
210
+
211
+
212
+ async def cleanup_response(
213
+ response: Optional[aiohttp.ClientResponse],
214
+ session: Optional[aiohttp.ClientSession],
215
+ ):
216
+ if response:
217
+ response.close()
218
+ if session:
219
+ await session.close()
220
+
221
+
222
+ def merge_models_lists(model_lists):
223
+ log.debug(f"merge_models_lists {model_lists}")
224
+ merged_list = []
225
+
226
+ for idx, models in enumerate(model_lists):
227
+ if models is not None and "error" not in models:
228
+ merged_list.extend(
229
+ [
230
+ {
231
+ **model,
232
+ "name": model.get("name", model["id"]),
233
+ "owned_by": "openai",
234
+ "openai": model,
235
+ "urlIdx": idx,
236
+ }
237
+ for model in models
238
+ if "api.openai.com"
239
+ not in app.state.config.OPENAI_API_BASE_URLS[idx]
240
+ or not any(
241
+ name in model["id"]
242
+ for name in [
243
+ "babbage",
244
+ "dall-e",
245
+ "davinci",
246
+ "embedding",
247
+ "tts",
248
+ "whisper",
249
+ ]
250
+ )
251
+ ]
252
+ )
253
+
254
+ return merged_list
255
+
256
+
257
+ async def get_all_models_responses() -> list:
258
+ if not app.state.config.ENABLE_OPENAI_API:
259
+ return []
260
+
261
+ # Check if API KEYS length is same than API URLS length
262
+ num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
263
+ num_keys = len(app.state.config.OPENAI_API_KEYS)
264
+
265
+ if num_keys != num_urls:
266
+ # if there are more keys than urls, remove the extra keys
267
+ if num_keys > num_urls:
268
+ new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
269
+ app.state.config.OPENAI_API_KEYS = new_keys
270
+ # if there are more urls than keys, add empty keys
271
+ else:
272
+ app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
273
+
274
+ tasks = []
275
+ for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
276
+ if url not in app.state.config.OPENAI_API_CONFIGS:
277
+ tasks.append(
278
+ aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
279
+ )
280
+ else:
281
+ api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
282
+
283
+ enable = api_config.get("enable", True)
284
+ model_ids = api_config.get("model_ids", [])
285
+
286
+ if enable:
287
+ if len(model_ids) == 0:
288
+ tasks.append(
289
+ aiohttp_get(
290
+ f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]
291
+ )
292
+ )
293
+ else:
294
+ model_list = {
295
+ "object": "list",
296
+ "data": [
297
+ {
298
+ "id": model_id,
299
+ "name": model_id,
300
+ "owned_by": "openai",
301
+ "openai": {"id": model_id},
302
+ "urlIdx": idx,
303
+ }
304
+ for model_id in model_ids
305
+ ],
306
+ }
307
+
308
+ tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list)))
309
+ else:
310
+ tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
311
+
312
+ responses = await asyncio.gather(*tasks)
313
+
314
+ for idx, response in enumerate(responses):
315
+ if response:
316
+ url = app.state.config.OPENAI_API_BASE_URLS[idx]
317
+ api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
318
+
319
+ prefix_id = api_config.get("prefix_id", None)
320
+
321
+ if prefix_id:
322
+ for model in (
323
+ response if isinstance(response, list) else response.get("data", [])
324
+ ):
325
+ model["id"] = f"{prefix_id}.{model['id']}"
326
+
327
+ log.debug(f"get_all_models:responses() {responses}")
328
+
329
+ return responses
330
+
331
+
332
+ @cached(ttl=3)
333
+ async def get_all_models() -> dict[str, list]:
334
+ log.info("get_all_models()")
335
+
336
+ if not app.state.config.ENABLE_OPENAI_API:
337
+ return {"data": []}
338
+
339
+ responses = await get_all_models_responses()
340
+
341
+ def extract_data(response):
342
+ if response and "data" in response:
343
+ return response["data"]
344
+ if isinstance(response, list):
345
+ return response
346
+ return None
347
+
348
+ models = {"data": merge_models_lists(map(extract_data, responses))}
349
+ log.debug(f"models: {models}")
350
+
351
+ return models
352
+
353
+
354
+ @app.get("/models")
355
+ @app.get("/models/{url_idx}")
356
+ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
357
+ models = {
358
+ "data": [],
359
+ }
360
+
361
+ if url_idx is None:
362
+ models = await get_all_models()
363
+ else:
364
+ url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
365
+ key = app.state.config.OPENAI_API_KEYS[url_idx]
366
+
367
+ headers = {}
368
+ headers["Authorization"] = f"Bearer {key}"
369
+ headers["Content-Type"] = "application/json"
370
+
371
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
372
+ headers["X-OpenWebUI-User-Name"] = user.name
373
+ headers["X-OpenWebUI-User-Id"] = user.id
374
+ headers["X-OpenWebUI-User-Email"] = user.email
375
+ headers["X-OpenWebUI-User-Role"] = user.role
376
+
377
+ r = None
378
+
379
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
380
+ async with aiohttp.ClientSession(timeout=timeout) as session:
381
+ try:
382
+ async with session.get(f"{url}/models", headers=headers) as r:
383
+ if r.status != 200:
384
+ # Extract response error details if available
385
+ error_detail = f"HTTP Error: {r.status}"
386
+ res = await r.json()
387
+ if "error" in res:
388
+ error_detail = f"External Error: {res['error']}"
389
+ raise Exception(error_detail)
390
+
391
+ response_data = await r.json()
392
+
393
+ # Check if we're calling OpenAI API based on the URL
394
+ if "api.openai.com" in url:
395
+ # Filter models according to the specified conditions
396
+ response_data["data"] = [
397
+ model
398
+ for model in response_data.get("data", [])
399
+ if not any(
400
+ name in model["id"]
401
+ for name in [
402
+ "babbage",
403
+ "dall-e",
404
+ "davinci",
405
+ "embedding",
406
+ "tts",
407
+ "whisper",
408
+ ]
409
+ )
410
+ ]
411
+
412
+ models = response_data
413
+ except aiohttp.ClientError as e:
414
+ # ClientError covers all aiohttp requests issues
415
+ log.exception(f"Client error: {str(e)}")
416
+ # Handle aiohttp-specific connection issues, timeout etc.
417
+ raise HTTPException(
418
+ status_code=500, detail="Open WebUI: Server Connection Error"
419
+ )
420
+ except Exception as e:
421
+ log.exception(f"Unexpected error: {e}")
422
+ # Generic error handler in case parsing JSON or other steps fail
423
+ error_detail = f"Unexpected error: {str(e)}"
424
+ raise HTTPException(status_code=500, detail=error_detail)
425
+
426
+ if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
427
+ # Filter models based on user access control
428
+ filtered_models = []
429
+ for model in models.get("data", []):
430
+ model_info = Models.get_model_by_id(model["id"])
431
+ if model_info:
432
+ if user.id == model_info.user_id or has_access(
433
+ user.id, type="read", access_control=model_info.access_control
434
+ ):
435
+ filtered_models.append(model)
436
+ models["data"] = filtered_models
437
+
438
+ return models
439
+
440
+
441
+ class ConnectionVerificationForm(BaseModel):
442
+ url: str
443
+ key: str
444
+
445
+
446
+ @app.post("/verify")
447
+ async def verify_connection(
448
+ form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
449
+ ):
450
+ url = form_data.url
451
+ key = form_data.key
452
+
453
+ headers = {}
454
+ headers["Authorization"] = f"Bearer {key}"
455
+ headers["Content-Type"] = "application/json"
456
+
457
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
458
+ async with aiohttp.ClientSession(timeout=timeout) as session:
459
+ try:
460
+ async with session.get(f"{url}/models", headers=headers) as r:
461
+ if r.status != 200:
462
+ # Extract response error details if available
463
+ error_detail = f"HTTP Error: {r.status}"
464
+ res = await r.json()
465
+ if "error" in res:
466
+ error_detail = f"External Error: {res['error']}"
467
+ raise Exception(error_detail)
468
+
469
+ response_data = await r.json()
470
+ return response_data
471
+
472
+ except aiohttp.ClientError as e:
473
+ # ClientError covers all aiohttp requests issues
474
+ log.exception(f"Client error: {str(e)}")
475
+ # Handle aiohttp-specific connection issues, timeout etc.
476
+ raise HTTPException(
477
+ status_code=500, detail="Open WebUI: Server Connection Error"
478
+ )
479
+ except Exception as e:
480
+ log.exception(f"Unexpected error: {e}")
481
+ # Generic error handler in case parsing JSON or other steps fail
482
+ error_detail = f"Unexpected error: {str(e)}"
483
+ raise HTTPException(status_code=500, detail=error_detail)
484
+
485
+
486
+ @app.post("/chat/completions")
487
+ async def generate_chat_completion(
488
+ form_data: dict,
489
+ user=Depends(get_verified_user),
490
+ bypass_filter: Optional[bool] = False,
491
+ ):
492
+ idx = 0
493
+ payload = {**form_data}
494
+
495
+ if "metadata" in payload:
496
+ del payload["metadata"]
497
+
498
+ model_id = form_data.get("model")
499
+ model_info = Models.get_model_by_id(model_id)
500
+
501
+ # Check model info and override the payload
502
+ if model_info:
503
+ if model_info.base_model_id:
504
+ payload["model"] = model_info.base_model_id
505
+
506
+ params = model_info.params.model_dump()
507
+ payload = apply_model_params_to_body_openai(params, payload)
508
+ payload = apply_model_system_prompt_to_body(params, payload, user)
509
+
510
+ # Check if user has access to the model
511
+ if not bypass_filter and user.role == "user":
512
+ if not (
513
+ user.id == model_info.user_id
514
+ or has_access(
515
+ user.id, type="read", access_control=model_info.access_control
516
+ )
517
+ ):
518
+ raise HTTPException(
519
+ status_code=403,
520
+ detail="Model not found",
521
+ )
522
+ elif not bypass_filter:
523
+ if user.role != "admin":
524
+ raise HTTPException(
525
+ status_code=403,
526
+ detail="Model not found",
527
+ )
528
+
529
+ # Attemp to get urlIdx from the model
530
+ models = await get_all_models()
531
+
532
+ # Find the model from the list
533
+ model = next(
534
+ (model for model in models["data"] if model["id"] == payload.get("model")),
535
+ None,
536
+ )
537
+
538
+ if model:
539
+ idx = model["urlIdx"]
540
+ else:
541
+ raise HTTPException(
542
+ status_code=404,
543
+ detail="Model not found",
544
+ )
545
+
546
+ # Get the API config for the model
547
+ api_config = app.state.config.OPENAI_API_CONFIGS.get(
548
+ app.state.config.OPENAI_API_BASE_URLS[idx], {}
549
+ )
550
+ prefix_id = api_config.get("prefix_id", None)
551
+
552
+ if prefix_id:
553
+ payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
554
+
555
+ # Add user info to the payload if the model is a pipeline
556
+ if "pipeline" in model and model.get("pipeline"):
557
+ payload["user"] = {
558
+ "name": user.name,
559
+ "id": user.id,
560
+ "email": user.email,
561
+ "role": user.role,
562
+ }
563
+
564
+ url = app.state.config.OPENAI_API_BASE_URLS[idx]
565
+ key = app.state.config.OPENAI_API_KEYS[idx]
566
+
567
+ # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
568
+ is_o1 = payload["model"].lower().startswith("o1-")
569
+ # Change max_completion_tokens to max_tokens (Backward compatible)
570
+ if "api.openai.com" not in url and not is_o1:
571
+ if "max_completion_tokens" in payload:
572
+ # Remove "max_completion_tokens" from the payload
573
+ payload["max_tokens"] = payload["max_completion_tokens"]
574
+ del payload["max_completion_tokens"]
575
+ else:
576
+ if is_o1 and "max_tokens" in payload:
577
+ payload["max_completion_tokens"] = payload["max_tokens"]
578
+ del payload["max_tokens"]
579
+ if "max_tokens" in payload and "max_completion_tokens" in payload:
580
+ del payload["max_tokens"]
581
+
582
+ # Fix: O1 does not support the "system" parameter, Modify "system" to "user"
583
+ if is_o1 and payload["messages"][0]["role"] == "system":
584
+ payload["messages"][0]["role"] = "user"
585
+
586
+ # Convert the modified body back to JSON
587
+ payload = json.dumps(payload)
588
+
589
+ headers = {}
590
+ headers["Authorization"] = f"Bearer {key}"
591
+ headers["Content-Type"] = "application/json"
592
+ if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
593
+ headers["HTTP-Referer"] = "https://openwebui.com/"
594
+ headers["X-Title"] = "Open WebUI"
595
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
596
+ headers["X-OpenWebUI-User-Name"] = user.name
597
+ headers["X-OpenWebUI-User-Id"] = user.id
598
+ headers["X-OpenWebUI-User-Email"] = user.email
599
+ headers["X-OpenWebUI-User-Role"] = user.role
600
+
601
+ r = None
602
+ session = None
603
+ streaming = False
604
+ response = None
605
+
606
+ try:
607
+ session = aiohttp.ClientSession(
608
+ trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
609
+ )
610
+ r = await session.request(
611
+ method="POST",
612
+ url=f"{url}/chat/completions",
613
+ data=payload,
614
+ headers=headers,
615
+ )
616
+
617
+ # Check if response is SSE
618
+ if "text/event-stream" in r.headers.get("Content-Type", ""):
619
+ streaming = True
620
+ return StreamingResponse(
621
+ r.content,
622
+ status_code=r.status,
623
+ headers=dict(r.headers),
624
+ background=BackgroundTask(
625
+ cleanup_response, response=r, session=session
626
+ ),
627
+ )
628
+ else:
629
+ try:
630
+ response = await r.json()
631
+ except Exception as e:
632
+ log.error(e)
633
+ response = await r.text()
634
+
635
+ r.raise_for_status()
636
+ return response
637
+ except Exception as e:
638
+ log.exception(e)
639
+ error_detail = "Open WebUI: Server Connection Error"
640
+ if isinstance(response, dict):
641
+ if "error" in response:
642
+ error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
643
+ elif isinstance(response, str):
644
+ error_detail = response
645
+
646
+ raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
647
+ finally:
648
+ if not streaming and session:
649
+ if r:
650
+ r.close()
651
+ await session.close()
652
+
653
+
654
+ @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
655
+ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
656
+ idx = 0
657
+
658
+ body = await request.body()
659
+
660
+ url = app.state.config.OPENAI_API_BASE_URLS[idx]
661
+ key = app.state.config.OPENAI_API_KEYS[idx]
662
+
663
+ target_url = f"{url}/{path}"
664
+
665
+ headers = {}
666
+ headers["Authorization"] = f"Bearer {key}"
667
+ headers["Content-Type"] = "application/json"
668
+ if ENABLE_FORWARD_USER_INFO_HEADERS:
669
+ headers["X-OpenWebUI-User-Name"] = user.name
670
+ headers["X-OpenWebUI-User-Id"] = user.id
671
+ headers["X-OpenWebUI-User-Email"] = user.email
672
+ headers["X-OpenWebUI-User-Role"] = user.role
673
+
674
+ r = None
675
+ session = None
676
+ streaming = False
677
+
678
+ try:
679
+ session = aiohttp.ClientSession(trust_env=True)
680
+ r = await session.request(
681
+ method=request.method,
682
+ url=target_url,
683
+ data=body,
684
+ headers=headers,
685
+ )
686
+
687
+ r.raise_for_status()
688
+
689
+ # Check if response is SSE
690
+ if "text/event-stream" in r.headers.get("Content-Type", ""):
691
+ streaming = True
692
+ return StreamingResponse(
693
+ r.content,
694
+ status_code=r.status,
695
+ headers=dict(r.headers),
696
+ background=BackgroundTask(
697
+ cleanup_response, response=r, session=session
698
+ ),
699
+ )
700
+ else:
701
+ response_data = await r.json()
702
+ return response_data
703
+ except Exception as e:
704
+ log.exception(e)
705
+ error_detail = "Open WebUI: Server Connection Error"
706
+ if r is not None:
707
+ try:
708
+ res = await r.json()
709
+ print(res)
710
+ if "error" in res:
711
+ error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
712
+ except Exception:
713
+ error_detail = f"External: {e}"
714
+ raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
715
+ finally:
716
+ if not streaming and session:
717
+ if r:
718
+ r.close()
719
+ await session.close()
backend/open_webui/apps/retrieval/loaders/main.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import logging
3
+ import ftfy
4
+
5
+ from langchain_community.document_loaders import (
6
+ BSHTMLLoader,
7
+ CSVLoader,
8
+ Docx2txtLoader,
9
+ OutlookMessageLoader,
10
+ PyPDFLoader,
11
+ TextLoader,
12
+ UnstructuredEPubLoader,
13
+ UnstructuredExcelLoader,
14
+ UnstructuredMarkdownLoader,
15
+ UnstructuredPowerPointLoader,
16
+ UnstructuredRSTLoader,
17
+ UnstructuredXMLLoader,
18
+ YoutubeLoader,
19
+ )
20
+ from langchain_core.documents import Document
21
+ from open_webui.env import SRC_LOG_LEVELS
22
+
23
+ log = logging.getLogger(__name__)
24
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
25
+
26
+ known_source_ext = [
27
+ "go",
28
+ "py",
29
+ "java",
30
+ "sh",
31
+ "bat",
32
+ "ps1",
33
+ "cmd",
34
+ "js",
35
+ "ts",
36
+ "css",
37
+ "cpp",
38
+ "hpp",
39
+ "h",
40
+ "c",
41
+ "cs",
42
+ "sql",
43
+ "log",
44
+ "ini",
45
+ "pl",
46
+ "pm",
47
+ "r",
48
+ "dart",
49
+ "dockerfile",
50
+ "env",
51
+ "php",
52
+ "hs",
53
+ "hsc",
54
+ "lua",
55
+ "nginxconf",
56
+ "conf",
57
+ "m",
58
+ "mm",
59
+ "plsql",
60
+ "perl",
61
+ "rb",
62
+ "rs",
63
+ "db2",
64
+ "scala",
65
+ "bash",
66
+ "swift",
67
+ "vue",
68
+ "svelte",
69
+ "msg",
70
+ "ex",
71
+ "exs",
72
+ "erl",
73
+ "tsx",
74
+ "jsx",
75
+ "hs",
76
+ "lhs",
77
+ ]
78
+
79
+
80
+ class TikaLoader:
81
+ def __init__(self, url, file_path, mime_type=None):
82
+ self.url = url
83
+ self.file_path = file_path
84
+ self.mime_type = mime_type
85
+
86
+ def load(self) -> list[Document]:
87
+ with open(self.file_path, "rb") as f:
88
+ data = f.read()
89
+
90
+ if self.mime_type is not None:
91
+ headers = {"Content-Type": self.mime_type}
92
+ else:
93
+ headers = {}
94
+
95
+ endpoint = self.url
96
+ if not endpoint.endswith("/"):
97
+ endpoint += "/"
98
+ endpoint += "tika/text"
99
+
100
+ r = requests.put(endpoint, data=data, headers=headers)
101
+
102
+ if r.ok:
103
+ raw_metadata = r.json()
104
+ text = raw_metadata.get("X-TIKA:content", "<No text content found>")
105
+
106
+ if "Content-Type" in raw_metadata:
107
+ headers["Content-Type"] = raw_metadata["Content-Type"]
108
+
109
+ log.info("Tika extracted text: %s", text)
110
+
111
+ return [Document(page_content=text, metadata=headers)]
112
+ else:
113
+ raise Exception(f"Error calling Tika: {r.reason}")
114
+
115
+
116
+ class Loader:
117
+ def __init__(self, engine: str = "", **kwargs):
118
+ self.engine = engine
119
+ self.kwargs = kwargs
120
+
121
+ def load(
122
+ self, filename: str, file_content_type: str, file_path: str
123
+ ) -> list[Document]:
124
+ loader = self._get_loader(filename, file_content_type, file_path)
125
+ docs = loader.load()
126
+
127
+ return [
128
+ Document(
129
+ page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
130
+ )
131
+ for doc in docs
132
+ ]
133
+
134
+ def _get_loader(self, filename: str, file_content_type: str, file_path: str):
135
+ file_ext = filename.split(".")[-1].lower()
136
+
137
+ if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
138
+ if file_ext in known_source_ext or (
139
+ file_content_type and file_content_type.find("text/") >= 0
140
+ ):
141
+ loader = TextLoader(file_path, autodetect_encoding=True)
142
+ else:
143
+ loader = TikaLoader(
144
+ url=self.kwargs.get("TIKA_SERVER_URL"),
145
+ file_path=file_path,
146
+ mime_type=file_content_type,
147
+ )
148
+ else:
149
+ if file_ext == "pdf":
150
+ loader = PyPDFLoader(
151
+ file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
152
+ )
153
+ elif file_ext == "csv":
154
+ loader = CSVLoader(file_path)
155
+ elif file_ext == "rst":
156
+ loader = UnstructuredRSTLoader(file_path, mode="elements")
157
+ elif file_ext == "xml":
158
+ loader = UnstructuredXMLLoader(file_path)
159
+ elif file_ext in ["htm", "html"]:
160
+ loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
161
+ elif file_ext == "md":
162
+ loader = TextLoader(file_path, autodetect_encoding=True)
163
+ elif file_content_type == "application/epub+zip":
164
+ loader = UnstructuredEPubLoader(file_path)
165
+ elif (
166
+ file_content_type
167
+ == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
168
+ or file_ext == "docx"
169
+ ):
170
+ loader = Docx2txtLoader(file_path)
171
+ elif file_content_type in [
172
+ "application/vnd.ms-excel",
173
+ "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
174
+ ] or file_ext in ["xls", "xlsx"]:
175
+ loader = UnstructuredExcelLoader(file_path)
176
+ elif file_content_type in [
177
+ "application/vnd.ms-powerpoint",
178
+ "application/vnd.openxmlformats-officedocument.presentationml.presentation",
179
+ ] or file_ext in ["ppt", "pptx"]:
180
+ loader = UnstructuredPowerPointLoader(file_path)
181
+ elif file_ext == "msg":
182
+ loader = OutlookMessageLoader(file_path)
183
+ elif file_ext in known_source_ext or (
184
+ file_content_type and file_content_type.find("text/") >= 0
185
+ ):
186
+ loader = TextLoader(file_path, autodetect_encoding=True)
187
+ else:
188
+ loader = TextLoader(file_path, autodetect_encoding=True)
189
+
190
+ return loader
backend/open_webui/apps/retrieval/loaders/youtube.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from typing import Any, Dict, Generator, List, Optional, Sequence, Union
4
+ from urllib.parse import parse_qs, urlparse
5
+ from langchain_core.documents import Document
6
+ from open_webui.env import SRC_LOG_LEVELS
7
+
8
+ log = logging.getLogger(__name__)
9
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
10
+
11
+ ALLOWED_SCHEMES = {"http", "https"}
12
+ ALLOWED_NETLOCS = {
13
+ "youtu.be",
14
+ "m.youtube.com",
15
+ "youtube.com",
16
+ "www.youtube.com",
17
+ "www.youtube-nocookie.com",
18
+ "vid.plus",
19
+ }
20
+
21
+
22
+ def _parse_video_id(url: str) -> Optional[str]:
23
+ """Parse a YouTube URL and return the video ID if valid, otherwise None."""
24
+ parsed_url = urlparse(url)
25
+
26
+ if parsed_url.scheme not in ALLOWED_SCHEMES:
27
+ return None
28
+
29
+ if parsed_url.netloc not in ALLOWED_NETLOCS:
30
+ return None
31
+
32
+ path = parsed_url.path
33
+
34
+ if path.endswith("/watch"):
35
+ query = parsed_url.query
36
+ parsed_query = parse_qs(query)
37
+ if "v" in parsed_query:
38
+ ids = parsed_query["v"]
39
+ video_id = ids if isinstance(ids, str) else ids[0]
40
+ else:
41
+ return None
42
+ else:
43
+ path = parsed_url.path.lstrip("/")
44
+ video_id = path.split("/")[-1]
45
+
46
+ if len(video_id) != 11: # Video IDs are 11 characters long
47
+ return None
48
+
49
+ return video_id
50
+
51
+
52
+ class YoutubeLoader:
53
+ """Load `YouTube` video transcripts."""
54
+
55
+ def __init__(
56
+ self,
57
+ video_id: str,
58
+ language: Union[str, Sequence[str]] = "en",
59
+ proxy_url: Optional[str] = None,
60
+ ):
61
+ """Initialize with YouTube video ID."""
62
+ _video_id = _parse_video_id(video_id)
63
+ self.video_id = _video_id if _video_id is not None else video_id
64
+ self._metadata = {"source": video_id}
65
+ self.language = language
66
+ self.proxy_url = proxy_url
67
+ if isinstance(language, str):
68
+ self.language = [language]
69
+ else:
70
+ self.language = language
71
+
72
+ def load(self) -> List[Document]:
73
+ """Load YouTube transcripts into `Document` objects."""
74
+ try:
75
+ from youtube_transcript_api import (
76
+ NoTranscriptFound,
77
+ TranscriptsDisabled,
78
+ YouTubeTranscriptApi,
79
+ )
80
+ except ImportError:
81
+ raise ImportError(
82
+ 'Could not import "youtube_transcript_api" Python package. '
83
+ "Please install it with `pip install youtube-transcript-api`."
84
+ )
85
+
86
+ if self.proxy_url:
87
+ youtube_proxies = {
88
+ "http": self.proxy_url,
89
+ "https": self.proxy_url,
90
+ }
91
+ # Don't log complete URL because it might contain secrets
92
+ log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
93
+ else:
94
+ youtube_proxies = None
95
+
96
+ try:
97
+ transcript_list = YouTubeTranscriptApi.list_transcripts(
98
+ self.video_id, proxies=youtube_proxies
99
+ )
100
+ except Exception as e:
101
+ log.exception("Loading YouTube transcript failed")
102
+ return []
103
+
104
+ try:
105
+ transcript = transcript_list.find_transcript(self.language)
106
+ except NoTranscriptFound:
107
+ transcript = transcript_list.find_transcript(["en"])
108
+
109
+ transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
110
+
111
+ transcript = " ".join(
112
+ map(
113
+ lambda transcript_piece: transcript_piece["text"].strip(" "),
114
+ transcript_pieces,
115
+ )
116
+ )
117
+ return [Document(page_content=transcript, metadata=self._metadata)]
backend/open_webui/apps/retrieval/main.py ADDED
@@ -0,0 +1,1494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TODO: Merge this with the webui_app and make it a single app
2
+
3
+ import json
4
+ import logging
5
+ import mimetypes
6
+ import os
7
+ import shutil
8
+
9
+ import uuid
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Iterator, Optional, Sequence, Union
13
+
14
+ from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from pydantic import BaseModel
17
+ import tiktoken
18
+
19
+
20
+ from open_webui.storage.provider import Storage
21
+ from open_webui.apps.webui.models.knowledge import Knowledges
22
+ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
23
+
24
+ # Document loaders
25
+ from open_webui.apps.retrieval.loaders.main import Loader
26
+ from open_webui.apps.retrieval.loaders.youtube import YoutubeLoader
27
+
28
+ # Web search engines
29
+ from open_webui.apps.retrieval.web.main import SearchResult
30
+ from open_webui.apps.retrieval.web.utils import get_web_loader
31
+ from open_webui.apps.retrieval.web.brave import search_brave
32
+ from open_webui.apps.retrieval.web.mojeek import search_mojeek
33
+ from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo
34
+ from open_webui.apps.retrieval.web.google_pse import search_google_pse
35
+ from open_webui.apps.retrieval.web.jina_search import search_jina
36
+ from open_webui.apps.retrieval.web.searchapi import search_searchapi
37
+ from open_webui.apps.retrieval.web.searxng import search_searxng
38
+ from open_webui.apps.retrieval.web.serper import search_serper
39
+ from open_webui.apps.retrieval.web.serply import search_serply
40
+ from open_webui.apps.retrieval.web.serpstack import search_serpstack
41
+ from open_webui.apps.retrieval.web.tavily import search_tavily
42
+ from open_webui.apps.retrieval.web.bing import search_bing
43
+
44
+
45
+ from open_webui.apps.retrieval.utils import (
46
+ get_embedding_function,
47
+ get_model_path,
48
+ query_collection,
49
+ query_collection_with_hybrid_search,
50
+ query_doc,
51
+ query_doc_with_hybrid_search,
52
+ )
53
+
54
+ from open_webui.apps.webui.models.files import Files
55
+ from open_webui.config import (
56
+ BRAVE_SEARCH_API_KEY,
57
+ MOJEEK_SEARCH_API_KEY,
58
+ TIKTOKEN_ENCODING_NAME,
59
+ RAG_TEXT_SPLITTER,
60
+ CHUNK_OVERLAP,
61
+ CHUNK_SIZE,
62
+ CONTENT_EXTRACTION_ENGINE,
63
+ CORS_ALLOW_ORIGIN,
64
+ ENABLE_RAG_HYBRID_SEARCH,
65
+ ENABLE_RAG_LOCAL_WEB_FETCH,
66
+ ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
67
+ ENABLE_RAG_WEB_SEARCH,
68
+ ENV,
69
+ GOOGLE_PSE_API_KEY,
70
+ GOOGLE_PSE_ENGINE_ID,
71
+ PDF_EXTRACT_IMAGES,
72
+ RAG_EMBEDDING_ENGINE,
73
+ RAG_EMBEDDING_MODEL,
74
+ RAG_EMBEDDING_MODEL_AUTO_UPDATE,
75
+ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
76
+ RAG_EMBEDDING_BATCH_SIZE,
77
+ RAG_FILE_MAX_COUNT,
78
+ RAG_FILE_MAX_SIZE,
79
+ RAG_OPENAI_API_BASE_URL,
80
+ RAG_OPENAI_API_KEY,
81
+ RAG_OLLAMA_BASE_URL,
82
+ RAG_OLLAMA_API_KEY,
83
+ RAG_RELEVANCE_THRESHOLD,
84
+ RAG_RERANKING_MODEL,
85
+ RAG_RERANKING_MODEL_AUTO_UPDATE,
86
+ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
87
+ DEFAULT_RAG_TEMPLATE,
88
+ RAG_TEMPLATE,
89
+ RAG_TOP_K,
90
+ RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
91
+ RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
92
+ RAG_WEB_SEARCH_ENGINE,
93
+ RAG_WEB_SEARCH_RESULT_COUNT,
94
+ JINA_API_KEY,
95
+ SEARCHAPI_API_KEY,
96
+ SEARCHAPI_ENGINE,
97
+ SEARXNG_QUERY_URL,
98
+ SERPER_API_KEY,
99
+ SERPLY_API_KEY,
100
+ SERPSTACK_API_KEY,
101
+ SERPSTACK_HTTPS,
102
+ TAVILY_API_KEY,
103
+ BING_SEARCH_V7_ENDPOINT,
104
+ BING_SEARCH_V7_SUBSCRIPTION_KEY,
105
+ TIKA_SERVER_URL,
106
+ UPLOAD_DIR,
107
+ YOUTUBE_LOADER_LANGUAGE,
108
+ YOUTUBE_LOADER_PROXY_URL,
109
+ DEFAULT_LOCALE,
110
+ AppConfig,
111
+ )
112
+ from open_webui.constants import ERROR_MESSAGES
113
+ from open_webui.env import (
114
+ SRC_LOG_LEVELS,
115
+ DEVICE_TYPE,
116
+ DOCKER,
117
+ )
118
+ from open_webui.utils.misc import (
119
+ calculate_sha256,
120
+ calculate_sha256_string,
121
+ extract_folders_after_data_docs,
122
+ sanitize_filename,
123
+ )
124
+ from open_webui.utils.utils import get_admin_user, get_verified_user
125
+
126
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
127
+ from langchain_core.documents import Document
128
+
129
+
130
+ log = logging.getLogger(__name__)
131
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
132
+
133
+ app = FastAPI(
134
+ docs_url="/docs" if ENV == "dev" else None,
135
+ openapi_url="/openapi.json" if ENV == "dev" else None,
136
+ redoc_url=None,
137
+ )
138
+
139
+ app.state.config = AppConfig()
140
+
141
+ app.state.config.TOP_K = RAG_TOP_K
142
+ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
143
+ app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
144
+ app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
145
+
146
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
147
+ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
148
+ ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
149
+ )
150
+
151
+ app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
152
+ app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
153
+
154
+ app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
155
+ app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
156
+
157
+ app.state.config.CHUNK_SIZE = CHUNK_SIZE
158
+ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
159
+
160
+ app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
161
+ app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
162
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
163
+ app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
164
+ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
165
+
166
+ app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
167
+ app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
168
+
169
+ app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
170
+ app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
171
+
172
+ app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
173
+
174
+ app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
175
+ app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
176
+ app.state.YOUTUBE_LOADER_TRANSLATION = None
177
+
178
+
179
+ app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
180
+ app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
181
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
182
+
183
+ app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
184
+ app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
185
+ app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
186
+ app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
187
+ app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
188
+ app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
189
+ app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
190
+ app.state.config.SERPER_API_KEY = SERPER_API_KEY
191
+ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
192
+ app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
193
+ app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
194
+ app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
195
+ app.state.config.JINA_API_KEY = JINA_API_KEY
196
+ app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
197
+ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
198
+
199
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
200
+ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
201
+
202
+
203
+ def update_embedding_model(
204
+ embedding_model: str,
205
+ auto_update: bool = False,
206
+ ):
207
+ if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
208
+ from sentence_transformers import SentenceTransformer
209
+
210
+ try:
211
+ app.state.sentence_transformer_ef = SentenceTransformer(
212
+ get_model_path(embedding_model, auto_update),
213
+ device=DEVICE_TYPE,
214
+ trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
215
+ )
216
+ except Exception as e:
217
+ log.debug(f"Error loading SentenceTransformer: {e}")
218
+ app.state.sentence_transformer_ef = None
219
+ else:
220
+ app.state.sentence_transformer_ef = None
221
+
222
+
223
+ def update_reranking_model(
224
+ reranking_model: str,
225
+ auto_update: bool = False,
226
+ ):
227
+ if reranking_model:
228
+ if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
229
+ try:
230
+ from open_webui.apps.retrieval.models.colbert import ColBERT
231
+
232
+ app.state.sentence_transformer_rf = ColBERT(
233
+ get_model_path(reranking_model, auto_update),
234
+ env="docker" if DOCKER else None,
235
+ )
236
+ except Exception as e:
237
+ log.error(f"ColBERT: {e}")
238
+ app.state.sentence_transformer_rf = None
239
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
240
+ else:
241
+ import sentence_transformers
242
+
243
+ try:
244
+ app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
245
+ get_model_path(reranking_model, auto_update),
246
+ device=DEVICE_TYPE,
247
+ trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
248
+ )
249
+ except:
250
+ log.error("CrossEncoder error")
251
+ app.state.sentence_transformer_rf = None
252
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
253
+ else:
254
+ app.state.sentence_transformer_rf = None
255
+
256
+
257
+ update_embedding_model(
258
+ app.state.config.RAG_EMBEDDING_MODEL,
259
+ RAG_EMBEDDING_MODEL_AUTO_UPDATE,
260
+ )
261
+
262
+ update_reranking_model(
263
+ app.state.config.RAG_RERANKING_MODEL,
264
+ RAG_RERANKING_MODEL_AUTO_UPDATE,
265
+ )
266
+
267
+
268
+ app.state.EMBEDDING_FUNCTION = get_embedding_function(
269
+ app.state.config.RAG_EMBEDDING_ENGINE,
270
+ app.state.config.RAG_EMBEDDING_MODEL,
271
+ app.state.sentence_transformer_ef,
272
+ (
273
+ app.state.config.OPENAI_API_BASE_URL
274
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
275
+ else app.state.config.OLLAMA_BASE_URL
276
+ ),
277
+ (
278
+ app.state.config.OPENAI_API_KEY
279
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
280
+ else app.state.config.OLLAMA_API_KEY
281
+ ),
282
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
283
+ )
284
+
285
+ app.add_middleware(
286
+ CORSMiddleware,
287
+ allow_origins=CORS_ALLOW_ORIGIN,
288
+ allow_credentials=True,
289
+ allow_methods=["*"],
290
+ allow_headers=["*"],
291
+ )
292
+
293
+
294
+ class CollectionNameForm(BaseModel):
295
+ collection_name: Optional[str] = None
296
+
297
+
298
+ class ProcessUrlForm(CollectionNameForm):
299
+ url: str
300
+
301
+
302
+ class SearchForm(CollectionNameForm):
303
+ query: str
304
+
305
+
306
+ @app.get("/")
307
+ async def get_status():
308
+ return {
309
+ "status": True,
310
+ "chunk_size": app.state.config.CHUNK_SIZE,
311
+ "chunk_overlap": app.state.config.CHUNK_OVERLAP,
312
+ "template": app.state.config.RAG_TEMPLATE,
313
+ "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
314
+ "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
315
+ "reranking_model": app.state.config.RAG_RERANKING_MODEL,
316
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
317
+ }
318
+
319
+
320
+ @app.get("/embedding")
321
+ async def get_embedding_config(user=Depends(get_admin_user)):
322
+ return {
323
+ "status": True,
324
+ "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
325
+ "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
326
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
327
+ "openai_config": {
328
+ "url": app.state.config.OPENAI_API_BASE_URL,
329
+ "key": app.state.config.OPENAI_API_KEY,
330
+ },
331
+ "ollama_config": {
332
+ "url": app.state.config.OLLAMA_BASE_URL,
333
+ "key": app.state.config.OLLAMA_API_KEY,
334
+ },
335
+ }
336
+
337
+
338
+ @app.get("/reranking")
339
+ async def get_reraanking_config(user=Depends(get_admin_user)):
340
+ return {
341
+ "status": True,
342
+ "reranking_model": app.state.config.RAG_RERANKING_MODEL,
343
+ }
344
+
345
+
346
+ class OpenAIConfigForm(BaseModel):
347
+ url: str
348
+ key: str
349
+
350
+
351
+ class OllamaConfigForm(BaseModel):
352
+ url: str
353
+ key: str
354
+
355
+
356
+ class EmbeddingModelUpdateForm(BaseModel):
357
+ openai_config: Optional[OpenAIConfigForm] = None
358
+ ollama_config: Optional[OllamaConfigForm] = None
359
+ embedding_engine: str
360
+ embedding_model: str
361
+ embedding_batch_size: Optional[int] = 1
362
+
363
+
364
+ @app.post("/embedding/update")
365
+ async def update_embedding_config(
366
+ form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
367
+ ):
368
+ log.info(
369
+ f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
370
+ )
371
+ try:
372
+ app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
373
+ app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
374
+
375
+ if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
376
+ if form_data.openai_config is not None:
377
+ app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
378
+ app.state.config.OPENAI_API_KEY = form_data.openai_config.key
379
+
380
+ if form_data.ollama_config is not None:
381
+ app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
382
+ app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
383
+
384
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
385
+
386
+ update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
387
+
388
+ app.state.EMBEDDING_FUNCTION = get_embedding_function(
389
+ app.state.config.RAG_EMBEDDING_ENGINE,
390
+ app.state.config.RAG_EMBEDDING_MODEL,
391
+ app.state.sentence_transformer_ef,
392
+ (
393
+ app.state.config.OPENAI_API_BASE_URL
394
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
395
+ else app.state.config.OLLAMA_BASE_URL
396
+ ),
397
+ (
398
+ app.state.config.OPENAI_API_KEY
399
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
400
+ else app.state.config.OLLAMA_API_KEY
401
+ ),
402
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
403
+ )
404
+
405
+ return {
406
+ "status": True,
407
+ "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
408
+ "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
409
+ "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
410
+ "openai_config": {
411
+ "url": app.state.config.OPENAI_API_BASE_URL,
412
+ "key": app.state.config.OPENAI_API_KEY,
413
+ },
414
+ "ollama_config": {
415
+ "url": app.state.config.OLLAMA_BASE_URL,
416
+ "key": app.state.config.OLLAMA_API_KEY,
417
+ },
418
+ }
419
+ except Exception as e:
420
+ log.exception(f"Problem updating embedding model: {e}")
421
+ raise HTTPException(
422
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
423
+ detail=ERROR_MESSAGES.DEFAULT(e),
424
+ )
425
+
426
+
427
+ class RerankingModelUpdateForm(BaseModel):
428
+ reranking_model: str
429
+
430
+
431
+ @app.post("/reranking/update")
432
+ async def update_reranking_config(
433
+ form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
434
+ ):
435
+ log.info(
436
+ f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
437
+ )
438
+ try:
439
+ app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
440
+
441
+ update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)
442
+
443
+ return {
444
+ "status": True,
445
+ "reranking_model": app.state.config.RAG_RERANKING_MODEL,
446
+ }
447
+ except Exception as e:
448
+ log.exception(f"Problem updating reranking model: {e}")
449
+ raise HTTPException(
450
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
451
+ detail=ERROR_MESSAGES.DEFAULT(e),
452
+ )
453
+
454
+
455
+ @app.get("/config")
456
+ async def get_rag_config(user=Depends(get_admin_user)):
457
+ return {
458
+ "status": True,
459
+ "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
460
+ "content_extraction": {
461
+ "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
462
+ "tika_server_url": app.state.config.TIKA_SERVER_URL,
463
+ },
464
+ "chunk": {
465
+ "text_splitter": app.state.config.TEXT_SPLITTER,
466
+ "chunk_size": app.state.config.CHUNK_SIZE,
467
+ "chunk_overlap": app.state.config.CHUNK_OVERLAP,
468
+ },
469
+ "file": {
470
+ "max_size": app.state.config.FILE_MAX_SIZE,
471
+ "max_count": app.state.config.FILE_MAX_COUNT,
472
+ },
473
+ "youtube": {
474
+ "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
475
+ "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
476
+ "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL,
477
+ },
478
+ "web": {
479
+ "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
480
+ "search": {
481
+ "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
482
+ "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
483
+ "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
484
+ "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
485
+ "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
486
+ "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
487
+ "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY,
488
+ "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
489
+ "serpstack_https": app.state.config.SERPSTACK_HTTPS,
490
+ "serper_api_key": app.state.config.SERPER_API_KEY,
491
+ "serply_api_key": app.state.config.SERPLY_API_KEY,
492
+ "tavily_api_key": app.state.config.TAVILY_API_KEY,
493
+ "searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
494
+ "seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE,
495
+ "jina_api_key": app.state.config.JINA_API_KEY,
496
+ "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
497
+ "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
498
+ "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
499
+ "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
500
+ },
501
+ },
502
+ }
503
+
504
+
505
+ class FileConfig(BaseModel):
506
+ max_size: Optional[int] = None
507
+ max_count: Optional[int] = None
508
+
509
+
510
+ class ContentExtractionConfig(BaseModel):
511
+ engine: str = ""
512
+ tika_server_url: Optional[str] = None
513
+
514
+
515
+ class ChunkParamUpdateForm(BaseModel):
516
+ text_splitter: Optional[str] = None
517
+ chunk_size: int
518
+ chunk_overlap: int
519
+
520
+
521
+ class YoutubeLoaderConfig(BaseModel):
522
+ language: list[str]
523
+ translation: Optional[str] = None
524
+ proxy_url: str = ""
525
+
526
+
527
+ class WebSearchConfig(BaseModel):
528
+ enabled: bool
529
+ engine: Optional[str] = None
530
+ searxng_query_url: Optional[str] = None
531
+ google_pse_api_key: Optional[str] = None
532
+ google_pse_engine_id: Optional[str] = None
533
+ brave_search_api_key: Optional[str] = None
534
+ mojeek_search_api_key: Optional[str] = None
535
+ serpstack_api_key: Optional[str] = None
536
+ serpstack_https: Optional[bool] = None
537
+ serper_api_key: Optional[str] = None
538
+ serply_api_key: Optional[str] = None
539
+ tavily_api_key: Optional[str] = None
540
+ searchapi_api_key: Optional[str] = None
541
+ searchapi_engine: Optional[str] = None
542
+ jina_api_key: Optional[str] = None
543
+ bing_search_v7_endpoint: Optional[str] = None
544
+ bing_search_v7_subscription_key: Optional[str] = None
545
+ result_count: Optional[int] = None
546
+ concurrent_requests: Optional[int] = None
547
+
548
+
549
+ class WebConfig(BaseModel):
550
+ search: WebSearchConfig
551
+ web_loader_ssl_verification: Optional[bool] = None
552
+
553
+
554
+ class ConfigUpdateForm(BaseModel):
555
+ pdf_extract_images: Optional[bool] = None
556
+ file: Optional[FileConfig] = None
557
+ content_extraction: Optional[ContentExtractionConfig] = None
558
+ chunk: Optional[ChunkParamUpdateForm] = None
559
+ youtube: Optional[YoutubeLoaderConfig] = None
560
+ web: Optional[WebConfig] = None
561
+
562
+
563
+ @app.post("/config/update")
564
+ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
565
+ app.state.config.PDF_EXTRACT_IMAGES = (
566
+ form_data.pdf_extract_images
567
+ if form_data.pdf_extract_images is not None
568
+ else app.state.config.PDF_EXTRACT_IMAGES
569
+ )
570
+
571
+ if form_data.file is not None:
572
+ app.state.config.FILE_MAX_SIZE = form_data.file.max_size
573
+ app.state.config.FILE_MAX_COUNT = form_data.file.max_count
574
+
575
+ if form_data.content_extraction is not None:
576
+ log.info(f"Updating text settings: {form_data.content_extraction}")
577
+ app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
578
+ app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
579
+
580
+ if form_data.chunk is not None:
581
+ app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
582
+ app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
583
+ app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
584
+
585
+ if form_data.youtube is not None:
586
+ app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
587
+ app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url
588
+ app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
589
+
590
+ if form_data.web is not None:
591
+ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
592
+ # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
593
+ form_data.web.web_loader_ssl_verification
594
+ )
595
+
596
+ app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
597
+ app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
598
+ app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
599
+ app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
600
+ app.state.config.GOOGLE_PSE_ENGINE_ID = (
601
+ form_data.web.search.google_pse_engine_id
602
+ )
603
+ app.state.config.BRAVE_SEARCH_API_KEY = (
604
+ form_data.web.search.brave_search_api_key
605
+ )
606
+ app.state.config.MOJEEK_SEARCH_API_KEY = (
607
+ form_data.web.search.mojeek_search_api_key
608
+ )
609
+ app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
610
+ app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
611
+ app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
612
+ app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
613
+ app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
614
+ app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
615
+ app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
616
+
617
+ app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
618
+ app.state.config.BING_SEARCH_V7_ENDPOINT = (
619
+ form_data.web.search.bing_search_v7_endpoint
620
+ )
621
+ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = (
622
+ form_data.web.search.bing_search_v7_subscription_key
623
+ )
624
+
625
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
626
+ app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
627
+ form_data.web.search.concurrent_requests
628
+ )
629
+
630
+ return {
631
+ "status": True,
632
+ "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
633
+ "file": {
634
+ "max_size": app.state.config.FILE_MAX_SIZE,
635
+ "max_count": app.state.config.FILE_MAX_COUNT,
636
+ },
637
+ "content_extraction": {
638
+ "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
639
+ "tika_server_url": app.state.config.TIKA_SERVER_URL,
640
+ },
641
+ "chunk": {
642
+ "text_splitter": app.state.config.TEXT_SPLITTER,
643
+ "chunk_size": app.state.config.CHUNK_SIZE,
644
+ "chunk_overlap": app.state.config.CHUNK_OVERLAP,
645
+ },
646
+ "youtube": {
647
+ "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
648
+ "proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL,
649
+ "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
650
+ },
651
+ "web": {
652
+ "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
653
+ "search": {
654
+ "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
655
+ "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
656
+ "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
657
+ "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
658
+ "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
659
+ "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
660
+ "mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY,
661
+ "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
662
+ "serpstack_https": app.state.config.SERPSTACK_HTTPS,
663
+ "serper_api_key": app.state.config.SERPER_API_KEY,
664
+ "serply_api_key": app.state.config.SERPLY_API_KEY,
665
+ "serachapi_api_key": app.state.config.SEARCHAPI_API_KEY,
666
+ "searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
667
+ "tavily_api_key": app.state.config.TAVILY_API_KEY,
668
+ "jina_api_key": app.state.config.JINA_API_KEY,
669
+ "bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
670
+ "bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
671
+ "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
672
+ "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
673
+ },
674
+ },
675
+ }
676
+
677
+
678
+ @app.get("/template")
679
+ async def get_rag_template(user=Depends(get_verified_user)):
680
+ return {
681
+ "status": True,
682
+ "template": app.state.config.RAG_TEMPLATE,
683
+ }
684
+
685
+
686
+ @app.get("/query/settings")
687
+ async def get_query_settings(user=Depends(get_admin_user)):
688
+ return {
689
+ "status": True,
690
+ "template": app.state.config.RAG_TEMPLATE,
691
+ "k": app.state.config.TOP_K,
692
+ "r": app.state.config.RELEVANCE_THRESHOLD,
693
+ "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
694
+ }
695
+
696
+
697
+ class QuerySettingsForm(BaseModel):
698
+ k: Optional[int] = None
699
+ r: Optional[float] = None
700
+ template: Optional[str] = None
701
+ hybrid: Optional[bool] = None
702
+
703
+
704
+ @app.post("/query/settings/update")
705
+ async def update_query_settings(
706
+ form_data: QuerySettingsForm, user=Depends(get_admin_user)
707
+ ):
708
+ app.state.config.RAG_TEMPLATE = form_data.template
709
+ app.state.config.TOP_K = form_data.k if form_data.k else 4
710
+ app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
711
+
712
+ app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
713
+ form_data.hybrid if form_data.hybrid else False
714
+ )
715
+
716
+ return {
717
+ "status": True,
718
+ "template": app.state.config.RAG_TEMPLATE,
719
+ "k": app.state.config.TOP_K,
720
+ "r": app.state.config.RELEVANCE_THRESHOLD,
721
+ "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
722
+ }
723
+
724
+
725
+ ####################################
726
+ #
727
+ # Document process and retrieval
728
+ #
729
+ ####################################
730
+
731
+
732
+ def _get_docs_info(docs: list[Document]) -> str:
733
+ docs_info = set()
734
+
735
+ # Trying to select relevant metadata identifying the document.
736
+ for doc in docs:
737
+ metadata = getattr(doc, "metadata", {})
738
+ doc_name = metadata.get("name", "")
739
+ if not doc_name:
740
+ doc_name = metadata.get("title", "")
741
+ if not doc_name:
742
+ doc_name = metadata.get("source", "")
743
+ if doc_name:
744
+ docs_info.add(doc_name)
745
+
746
+ return ", ".join(docs_info)
747
+
748
+
749
+ def save_docs_to_vector_db(
750
+ docs,
751
+ collection_name,
752
+ metadata: Optional[dict] = None,
753
+ overwrite: bool = False,
754
+ split: bool = True,
755
+ add: bool = False,
756
+ ) -> bool:
757
+ log.info(
758
+ f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
759
+ )
760
+
761
+ # Check if entries with the same hash (metadata.hash) already exist
762
+ if metadata and "hash" in metadata:
763
+ result = VECTOR_DB_CLIENT.query(
764
+ collection_name=collection_name,
765
+ filter={"hash": metadata["hash"]},
766
+ )
767
+
768
+ if result is not None:
769
+ existing_doc_ids = result.ids[0]
770
+ if existing_doc_ids:
771
+ log.info(f"Document with hash {metadata['hash']} already exists")
772
+ raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
773
+
774
+ if split:
775
+ if app.state.config.TEXT_SPLITTER in ["", "character"]:
776
+ text_splitter = RecursiveCharacterTextSplitter(
777
+ chunk_size=app.state.config.CHUNK_SIZE,
778
+ chunk_overlap=app.state.config.CHUNK_OVERLAP,
779
+ add_start_index=True,
780
+ )
781
+ elif app.state.config.TEXT_SPLITTER == "token":
782
+ log.info(
783
+ f"Using token text splitter: {app.state.config.TIKTOKEN_ENCODING_NAME}"
784
+ )
785
+
786
+ tiktoken.get_encoding(str(app.state.config.TIKTOKEN_ENCODING_NAME))
787
+ text_splitter = TokenTextSplitter(
788
+ encoding_name=str(app.state.config.TIKTOKEN_ENCODING_NAME),
789
+ chunk_size=app.state.config.CHUNK_SIZE,
790
+ chunk_overlap=app.state.config.CHUNK_OVERLAP,
791
+ add_start_index=True,
792
+ )
793
+ else:
794
+ raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
795
+
796
+ docs = text_splitter.split_documents(docs)
797
+
798
+ if len(docs) == 0:
799
+ raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
800
+
801
+ texts = [doc.page_content for doc in docs]
802
+ metadatas = [
803
+ {
804
+ **doc.metadata,
805
+ **(metadata if metadata else {}),
806
+ "embedding_config": json.dumps(
807
+ {
808
+ "engine": app.state.config.RAG_EMBEDDING_ENGINE,
809
+ "model": app.state.config.RAG_EMBEDDING_MODEL,
810
+ }
811
+ ),
812
+ }
813
+ for doc in docs
814
+ ]
815
+
816
+ # ChromaDB does not like datetime formats
817
+ # for meta-data so convert them to string.
818
+ for metadata in metadatas:
819
+ for key, value in metadata.items():
820
+ if isinstance(value, datetime):
821
+ metadata[key] = str(value)
822
+
823
+ try:
824
+ if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
825
+ log.info(f"collection {collection_name} already exists")
826
+
827
+ if overwrite:
828
+ VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
829
+ log.info(f"deleting existing collection {collection_name}")
830
+ elif add is False:
831
+ log.info(
832
+ f"collection {collection_name} already exists, overwrite is False and add is False"
833
+ )
834
+ return True
835
+
836
+ log.info(f"adding to collection {collection_name}")
837
+ embedding_function = get_embedding_function(
838
+ app.state.config.RAG_EMBEDDING_ENGINE,
839
+ app.state.config.RAG_EMBEDDING_MODEL,
840
+ app.state.sentence_transformer_ef,
841
+ (
842
+ app.state.config.OPENAI_API_BASE_URL
843
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
844
+ else app.state.config.OLLAMA_BASE_URL
845
+ ),
846
+ (
847
+ app.state.config.OPENAI_API_KEY
848
+ if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
849
+ else app.state.config.OLLAMA_API_KEY
850
+ ),
851
+ app.state.config.RAG_EMBEDDING_BATCH_SIZE,
852
+ )
853
+
854
+ embeddings = embedding_function(
855
+ list(map(lambda x: x.replace("\n", " "), texts))
856
+ )
857
+
858
+ items = [
859
+ {
860
+ "id": str(uuid.uuid4()),
861
+ "text": text,
862
+ "vector": embeddings[idx],
863
+ "metadata": metadatas[idx],
864
+ }
865
+ for idx, text in enumerate(texts)
866
+ ]
867
+
868
+ VECTOR_DB_CLIENT.insert(
869
+ collection_name=collection_name,
870
+ items=items,
871
+ )
872
+
873
+ return True
874
+ except Exception as e:
875
+ log.exception(e)
876
+ raise e
877
+
878
+
879
+ class ProcessFileForm(BaseModel):
880
+ file_id: str
881
+ content: Optional[str] = None
882
+ collection_name: Optional[str] = None
883
+
884
+
885
+ @app.post("/process/file")
886
+ def process_file(
887
+ form_data: ProcessFileForm,
888
+ user=Depends(get_verified_user),
889
+ ):
890
+ try:
891
+ file = Files.get_file_by_id(form_data.file_id)
892
+
893
+ collection_name = form_data.collection_name
894
+
895
+ if collection_name is None:
896
+ collection_name = f"file-{file.id}"
897
+
898
+ if form_data.content:
899
+ # Update the content in the file
900
+ # Usage: /files/{file_id}/data/content/update
901
+
902
+ VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
903
+
904
+ docs = [
905
+ Document(
906
+ page_content=form_data.content.replace("<br/>", "\n"),
907
+ metadata={
908
+ **file.meta,
909
+ "name": file.filename,
910
+ "created_by": file.user_id,
911
+ "file_id": file.id,
912
+ "source": file.filename,
913
+ },
914
+ )
915
+ ]
916
+
917
+ text_content = form_data.content
918
+ elif form_data.collection_name:
919
+ # Check if the file has already been processed and save the content
920
+ # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
921
+
922
+ result = VECTOR_DB_CLIENT.query(
923
+ collection_name=f"file-{file.id}", filter={"file_id": file.id}
924
+ )
925
+
926
+ if result is not None and len(result.ids[0]) > 0:
927
+ docs = [
928
+ Document(
929
+ page_content=result.documents[0][idx],
930
+ metadata=result.metadatas[0][idx],
931
+ )
932
+ for idx, id in enumerate(result.ids[0])
933
+ ]
934
+ else:
935
+ docs = [
936
+ Document(
937
+ page_content=file.data.get("content", ""),
938
+ metadata={
939
+ **file.meta,
940
+ "name": file.filename,
941
+ "created_by": file.user_id,
942
+ "file_id": file.id,
943
+ "source": file.filename,
944
+ },
945
+ )
946
+ ]
947
+
948
+ text_content = file.data.get("content", "")
949
+ else:
950
+ # Process the file and save the content
951
+ # Usage: /files/
952
+ file_path = file.path
953
+ if file_path:
954
+ file_path = Storage.get_file(file_path)
955
+ loader = Loader(
956
+ engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
957
+ TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
958
+ PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
959
+ )
960
+ docs = loader.load(
961
+ file.filename, file.meta.get("content_type"), file_path
962
+ )
963
+
964
+ docs = [
965
+ Document(
966
+ page_content=doc.page_content,
967
+ metadata={
968
+ **doc.metadata,
969
+ "name": file.filename,
970
+ "created_by": file.user_id,
971
+ "file_id": file.id,
972
+ "source": file.filename,
973
+ },
974
+ )
975
+ for doc in docs
976
+ ]
977
+ else:
978
+ docs = [
979
+ Document(
980
+ page_content=file.data.get("content", ""),
981
+ metadata={
982
+ **file.meta,
983
+ "name": file.filename,
984
+ "created_by": file.user_id,
985
+ "file_id": file.id,
986
+ "source": file.filename,
987
+ },
988
+ )
989
+ ]
990
+ text_content = " ".join([doc.page_content for doc in docs])
991
+
992
+ log.debug(f"text_content: {text_content}")
993
+ Files.update_file_data_by_id(
994
+ file.id,
995
+ {"content": text_content},
996
+ )
997
+
998
+ hash = calculate_sha256_string(text_content)
999
+ Files.update_file_hash_by_id(file.id, hash)
1000
+
1001
+ try:
1002
+ result = save_docs_to_vector_db(
1003
+ docs=docs,
1004
+ collection_name=collection_name,
1005
+ metadata={
1006
+ "file_id": file.id,
1007
+ "name": file.filename,
1008
+ "hash": hash,
1009
+ },
1010
+ add=(True if form_data.collection_name else False),
1011
+ )
1012
+
1013
+ if result:
1014
+ Files.update_file_metadata_by_id(
1015
+ file.id,
1016
+ {
1017
+ "collection_name": collection_name,
1018
+ },
1019
+ )
1020
+
1021
+ return {
1022
+ "status": True,
1023
+ "collection_name": collection_name,
1024
+ "filename": file.filename,
1025
+ "content": text_content,
1026
+ }
1027
+ except Exception as e:
1028
+ raise e
1029
+ except Exception as e:
1030
+ log.exception(e)
1031
+ if "No pandoc was found" in str(e):
1032
+ raise HTTPException(
1033
+ status_code=status.HTTP_400_BAD_REQUEST,
1034
+ detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
1035
+ )
1036
+ else:
1037
+ raise HTTPException(
1038
+ status_code=status.HTTP_400_BAD_REQUEST,
1039
+ detail=str(e),
1040
+ )
1041
+
1042
+
1043
+ class ProcessTextForm(BaseModel):
1044
+ name: str
1045
+ content: str
1046
+ collection_name: Optional[str] = None
1047
+
1048
+
1049
+ @app.post("/process/text")
1050
+ def process_text(
1051
+ form_data: ProcessTextForm,
1052
+ user=Depends(get_verified_user),
1053
+ ):
1054
+ collection_name = form_data.collection_name
1055
+ if collection_name is None:
1056
+ collection_name = calculate_sha256_string(form_data.content)
1057
+
1058
+ docs = [
1059
+ Document(
1060
+ page_content=form_data.content,
1061
+ metadata={"name": form_data.name, "created_by": user.id},
1062
+ )
1063
+ ]
1064
+ text_content = form_data.content
1065
+ log.debug(f"text_content: {text_content}")
1066
+
1067
+ result = save_docs_to_vector_db(docs, collection_name)
1068
+
1069
+ if result:
1070
+ return {
1071
+ "status": True,
1072
+ "collection_name": collection_name,
1073
+ "content": text_content,
1074
+ }
1075
+ else:
1076
+ raise HTTPException(
1077
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1078
+ detail=ERROR_MESSAGES.DEFAULT(),
1079
+ )
1080
+
1081
+
1082
+ @app.post("/process/youtube")
1083
+ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
1084
+ try:
1085
+ collection_name = form_data.collection_name
1086
+ if not collection_name:
1087
+ collection_name = calculate_sha256_string(form_data.url)[:63]
1088
+
1089
+ loader = YoutubeLoader(
1090
+ form_data.url,
1091
+ language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
1092
+ proxy_url=app.state.config.YOUTUBE_LOADER_PROXY_URL,
1093
+ )
1094
+
1095
+ docs = loader.load()
1096
+ content = " ".join([doc.page_content for doc in docs])
1097
+ log.debug(f"text_content: {content}")
1098
+ save_docs_to_vector_db(docs, collection_name, overwrite=True)
1099
+
1100
+ return {
1101
+ "status": True,
1102
+ "collection_name": collection_name,
1103
+ "filename": form_data.url,
1104
+ "file": {
1105
+ "data": {
1106
+ "content": content,
1107
+ },
1108
+ "meta": {
1109
+ "name": form_data.url,
1110
+ },
1111
+ },
1112
+ }
1113
+ except Exception as e:
1114
+ log.exception(e)
1115
+ raise HTTPException(
1116
+ status_code=status.HTTP_400_BAD_REQUEST,
1117
+ detail=ERROR_MESSAGES.DEFAULT(e),
1118
+ )
1119
+
1120
+
1121
+ @app.post("/process/web")
1122
+ def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
1123
+ try:
1124
+ collection_name = form_data.collection_name
1125
+ if not collection_name:
1126
+ collection_name = calculate_sha256_string(form_data.url)[:63]
1127
+
1128
+ loader = get_web_loader(
1129
+ form_data.url,
1130
+ verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
1131
+ requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
1132
+ )
1133
+ docs = loader.load()
1134
+ content = " ".join([doc.page_content for doc in docs])
1135
+ log.debug(f"text_content: {content}")
1136
+ save_docs_to_vector_db(docs, collection_name, overwrite=True)
1137
+
1138
+ return {
1139
+ "status": True,
1140
+ "collection_name": collection_name,
1141
+ "filename": form_data.url,
1142
+ "file": {
1143
+ "data": {
1144
+ "content": content,
1145
+ },
1146
+ "meta": {
1147
+ "name": form_data.url,
1148
+ },
1149
+ },
1150
+ }
1151
+ except Exception as e:
1152
+ log.exception(e)
1153
+ raise HTTPException(
1154
+ status_code=status.HTTP_400_BAD_REQUEST,
1155
+ detail=ERROR_MESSAGES.DEFAULT(e),
1156
+ )
1157
+
1158
+
1159
+ def search_web(engine: str, query: str) -> list[SearchResult]:
1160
+ """Search the web using a search engine and return the results as a list of SearchResult objects.
1161
+ Will look for a search engine API key in environment variables in the following order:
1162
+ - SEARXNG_QUERY_URL
1163
+ - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
1164
+ - BRAVE_SEARCH_API_KEY
1165
+ - MOJEEK_SEARCH_API_KEY
1166
+ - SERPSTACK_API_KEY
1167
+ - SERPER_API_KEY
1168
+ - SERPLY_API_KEY
1169
+ - TAVILY_API_KEY
1170
+ - SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
1171
+ Args:
1172
+ query (str): The query to search for
1173
+ """
1174
+
1175
+ # TODO: add playwright to search the web
1176
+ if engine == "searxng":
1177
+ if app.state.config.SEARXNG_QUERY_URL:
1178
+ return search_searxng(
1179
+ app.state.config.SEARXNG_QUERY_URL,
1180
+ query,
1181
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1182
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1183
+ )
1184
+ else:
1185
+ raise Exception("No SEARXNG_QUERY_URL found in environment variables")
1186
+ elif engine == "google_pse":
1187
+ if (
1188
+ app.state.config.GOOGLE_PSE_API_KEY
1189
+ and app.state.config.GOOGLE_PSE_ENGINE_ID
1190
+ ):
1191
+ return search_google_pse(
1192
+ app.state.config.GOOGLE_PSE_API_KEY,
1193
+ app.state.config.GOOGLE_PSE_ENGINE_ID,
1194
+ query,
1195
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1196
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1197
+ )
1198
+ else:
1199
+ raise Exception(
1200
+ "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
1201
+ )
1202
+ elif engine == "brave":
1203
+ if app.state.config.BRAVE_SEARCH_API_KEY:
1204
+ return search_brave(
1205
+ app.state.config.BRAVE_SEARCH_API_KEY,
1206
+ query,
1207
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1208
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1209
+ )
1210
+ else:
1211
+ raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
1212
+ elif engine == "mojeek":
1213
+ if app.state.config.MOJEEK_SEARCH_API_KEY:
1214
+ return search_mojeek(
1215
+ app.state.config.MOJEEK_SEARCH_API_KEY,
1216
+ query,
1217
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1218
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1219
+ )
1220
+ else:
1221
+ raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
1222
+ elif engine == "serpstack":
1223
+ if app.state.config.SERPSTACK_API_KEY:
1224
+ return search_serpstack(
1225
+ app.state.config.SERPSTACK_API_KEY,
1226
+ query,
1227
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1228
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1229
+ https_enabled=app.state.config.SERPSTACK_HTTPS,
1230
+ )
1231
+ else:
1232
+ raise Exception("No SERPSTACK_API_KEY found in environment variables")
1233
+ elif engine == "serper":
1234
+ if app.state.config.SERPER_API_KEY:
1235
+ return search_serper(
1236
+ app.state.config.SERPER_API_KEY,
1237
+ query,
1238
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1239
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1240
+ )
1241
+ else:
1242
+ raise Exception("No SERPER_API_KEY found in environment variables")
1243
+ elif engine == "serply":
1244
+ if app.state.config.SERPLY_API_KEY:
1245
+ return search_serply(
1246
+ app.state.config.SERPLY_API_KEY,
1247
+ query,
1248
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1249
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1250
+ )
1251
+ else:
1252
+ raise Exception("No SERPLY_API_KEY found in environment variables")
1253
+ elif engine == "duckduckgo":
1254
+ return search_duckduckgo(
1255
+ query,
1256
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1257
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1258
+ )
1259
+ elif engine == "tavily":
1260
+ if app.state.config.TAVILY_API_KEY:
1261
+ return search_tavily(
1262
+ app.state.config.TAVILY_API_KEY,
1263
+ query,
1264
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1265
+ )
1266
+ else:
1267
+ raise Exception("No TAVILY_API_KEY found in environment variables")
1268
+ elif engine == "searchapi":
1269
+ if app.state.config.SEARCHAPI_API_KEY:
1270
+ return search_searchapi(
1271
+ app.state.config.SEARCHAPI_API_KEY,
1272
+ app.state.config.SEARCHAPI_ENGINE,
1273
+ query,
1274
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1275
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1276
+ )
1277
+ else:
1278
+ raise Exception("No SEARCHAPI_API_KEY found in environment variables")
1279
+ elif engine == "jina":
1280
+ return search_jina(
1281
+ app.state.config.JINA_API_KEY,
1282
+ query,
1283
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1284
+ )
1285
+ elif engine == "bing":
1286
+ return search_bing(
1287
+ app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
1288
+ app.state.config.BING_SEARCH_V7_ENDPOINT,
1289
+ str(DEFAULT_LOCALE),
1290
+ query,
1291
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
1292
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
1293
+ )
1294
+ else:
1295
+ raise Exception("No search engine API key found in environment variables")
1296
+
1297
+
1298
+ @app.post("/process/web/search")
1299
+ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
1300
+ try:
1301
+ logging.info(
1302
+ f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
1303
+ )
1304
+ web_results = search_web(
1305
+ app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
1306
+ )
1307
+ except Exception as e:
1308
+ log.exception(e)
1309
+
1310
+ print(e)
1311
+ raise HTTPException(
1312
+ status_code=status.HTTP_400_BAD_REQUEST,
1313
+ detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
1314
+ )
1315
+
1316
+ try:
1317
+ collection_name = form_data.collection_name
1318
+ if collection_name == "":
1319
+ collection_name = calculate_sha256_string(form_data.query)[:63]
1320
+
1321
+ urls = [result.link for result in web_results]
1322
+
1323
+ loader = get_web_loader(
1324
+ urls,
1325
+ verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
1326
+ requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
1327
+ )
1328
+ docs = loader.aload()
1329
+
1330
+ save_docs_to_vector_db(docs, collection_name, overwrite=True)
1331
+
1332
+ return {
1333
+ "status": True,
1334
+ "collection_name": collection_name,
1335
+ "filenames": urls,
1336
+ }
1337
+ except Exception as e:
1338
+ log.exception(e)
1339
+ raise HTTPException(
1340
+ status_code=status.HTTP_400_BAD_REQUEST,
1341
+ detail=ERROR_MESSAGES.DEFAULT(e),
1342
+ )
1343
+
1344
+
1345
+ class QueryDocForm(BaseModel):
1346
+ collection_name: str
1347
+ query: str
1348
+ k: Optional[int] = None
1349
+ r: Optional[float] = None
1350
+ hybrid: Optional[bool] = None
1351
+
1352
+
1353
+ @app.post("/query/doc")
1354
+ def query_doc_handler(
1355
+ form_data: QueryDocForm,
1356
+ user=Depends(get_verified_user),
1357
+ ):
1358
+ try:
1359
+ if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
1360
+ return query_doc_with_hybrid_search(
1361
+ collection_name=form_data.collection_name,
1362
+ query=form_data.query,
1363
+ embedding_function=app.state.EMBEDDING_FUNCTION,
1364
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
1365
+ reranking_function=app.state.sentence_transformer_rf,
1366
+ r=(
1367
+ form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
1368
+ ),
1369
+ )
1370
+ else:
1371
+ return query_doc(
1372
+ collection_name=form_data.collection_name,
1373
+ query=form_data.query,
1374
+ embedding_function=app.state.EMBEDDING_FUNCTION,
1375
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
1376
+ )
1377
+ except Exception as e:
1378
+ log.exception(e)
1379
+ raise HTTPException(
1380
+ status_code=status.HTTP_400_BAD_REQUEST,
1381
+ detail=ERROR_MESSAGES.DEFAULT(e),
1382
+ )
1383
+
1384
+
1385
+ class QueryCollectionsForm(BaseModel):
1386
+ collection_names: list[str]
1387
+ query: str
1388
+ k: Optional[int] = None
1389
+ r: Optional[float] = None
1390
+ hybrid: Optional[bool] = None
1391
+
1392
+
1393
+ @app.post("/query/collection")
1394
+ def query_collection_handler(
1395
+ form_data: QueryCollectionsForm,
1396
+ user=Depends(get_verified_user),
1397
+ ):
1398
+ try:
1399
+ if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
1400
+ return query_collection_with_hybrid_search(
1401
+ collection_names=form_data.collection_names,
1402
+ queries=[form_data.query],
1403
+ embedding_function=app.state.EMBEDDING_FUNCTION,
1404
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
1405
+ reranking_function=app.state.sentence_transformer_rf,
1406
+ r=(
1407
+ form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
1408
+ ),
1409
+ )
1410
+ else:
1411
+ return query_collection(
1412
+ collection_names=form_data.collection_names,
1413
+ queries=[form_data.query],
1414
+ embedding_function=app.state.EMBEDDING_FUNCTION,
1415
+ k=form_data.k if form_data.k else app.state.config.TOP_K,
1416
+ )
1417
+
1418
+ except Exception as e:
1419
+ log.exception(e)
1420
+ raise HTTPException(
1421
+ status_code=status.HTTP_400_BAD_REQUEST,
1422
+ detail=ERROR_MESSAGES.DEFAULT(e),
1423
+ )
1424
+
1425
+
1426
+ ####################################
1427
+ #
1428
+ # Vector DB operations
1429
+ #
1430
+ ####################################
1431
+
1432
+
1433
+ class DeleteForm(BaseModel):
1434
+ collection_name: str
1435
+ file_id: str
1436
+
1437
+
1438
+ @app.post("/delete")
1439
+ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
1440
+ try:
1441
+ if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
1442
+ file = Files.get_file_by_id(form_data.file_id)
1443
+ hash = file.hash
1444
+
1445
+ VECTOR_DB_CLIENT.delete(
1446
+ collection_name=form_data.collection_name,
1447
+ metadata={"hash": hash},
1448
+ )
1449
+ return {"status": True}
1450
+ else:
1451
+ return {"status": False}
1452
+ except Exception as e:
1453
+ log.exception(e)
1454
+ return {"status": False}
1455
+
1456
+
1457
+ @app.post("/reset/db")
1458
+ def reset_vector_db(user=Depends(get_admin_user)):
1459
+ VECTOR_DB_CLIENT.reset()
1460
+ Knowledges.delete_all_knowledge()
1461
+
1462
+
1463
+ @app.post("/reset/uploads")
1464
+ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
1465
+ folder = f"{UPLOAD_DIR}"
1466
+ try:
1467
+ # Check if the directory exists
1468
+ if os.path.exists(folder):
1469
+ # Iterate over all the files and directories in the specified directory
1470
+ for filename in os.listdir(folder):
1471
+ file_path = os.path.join(folder, filename)
1472
+ try:
1473
+ if os.path.isfile(file_path) or os.path.islink(file_path):
1474
+ os.unlink(file_path) # Remove the file or link
1475
+ elif os.path.isdir(file_path):
1476
+ shutil.rmtree(file_path) # Remove the directory
1477
+ except Exception as e:
1478
+ print(f"Failed to delete {file_path}. Reason: {e}")
1479
+ else:
1480
+ print(f"The directory {folder} does not exist")
1481
+ except Exception as e:
1482
+ print(f"Failed to process the directory {folder}. Reason: {e}")
1483
+ return True
1484
+
1485
+
1486
+ if ENV == "dev":
1487
+
1488
+ @app.get("/ef")
1489
+ async def get_embeddings():
1490
+ return {"result": app.state.EMBEDDING_FUNCTION("hello world")}
1491
+
1492
+ @app.get("/ef/{text}")
1493
+ async def get_embeddings_text(text: str):
1494
+ return {"result": app.state.EMBEDDING_FUNCTION(text)}
backend/open_webui/apps/retrieval/models/colbert.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from colbert.infra import ColBERTConfig
5
+ from colbert.modeling.checkpoint import Checkpoint
6
+
7
+
8
+ class ColBERT:
9
+ def __init__(self, name, **kwargs) -> None:
10
+ print("ColBERT: Loading model", name)
11
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
+
13
+ DOCKER = kwargs.get("env") == "docker"
14
+ if DOCKER:
15
+ # This is a workaround for the issue with the docker container
16
+ # where the torch extension is not loaded properly
17
+ # and the following error is thrown:
18
+ # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
19
+
20
+ lock_file = (
21
+ "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
22
+ )
23
+ if os.path.exists(lock_file):
24
+ os.remove(lock_file)
25
+
26
+ self.ckpt = Checkpoint(
27
+ name,
28
+ colbert_config=ColBERTConfig(model_name=name),
29
+ ).to(self.device)
30
+ pass
31
+
32
+ def calculate_similarity_scores(self, query_embeddings, document_embeddings):
33
+
34
+ query_embeddings = query_embeddings.to(self.device)
35
+ document_embeddings = document_embeddings.to(self.device)
36
+
37
+ # Validate dimensions to ensure compatibility
38
+ if query_embeddings.dim() != 3:
39
+ raise ValueError(
40
+ f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
41
+ )
42
+ if document_embeddings.dim() != 3:
43
+ raise ValueError(
44
+ f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
45
+ )
46
+ if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
47
+ raise ValueError(
48
+ "There should be either one query or queries equal to the number of documents."
49
+ )
50
+
51
+ # Transpose the query embeddings to align for matrix multiplication
52
+ transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
53
+ # Compute similarity scores using batch matrix multiplication
54
+ computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)
55
+ # Apply max pooling to extract the highest semantic similarity across each document's sequence
56
+ maximum_scores = torch.max(computed_scores, dim=1).values
57
+
58
+ # Sum up the maximum scores across features to get the overall document relevance scores
59
+ final_scores = maximum_scores.sum(dim=1)
60
+
61
+ normalized_scores = torch.softmax(final_scores, dim=0)
62
+
63
+ return normalized_scores.detach().cpu().numpy().astype(np.float32)
64
+
65
+ def predict(self, sentences):
66
+
67
+ query = sentences[0][0]
68
+ docs = [i[1] for i in sentences]
69
+
70
+ # Embedding the documents
71
+ embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
72
+ # Embedding the queries
73
+ embedded_queries = self.ckpt.queryFromText([query], bsize=32)
74
+ embedded_query = embedded_queries[0]
75
+
76
+ # Calculate retrieval scores for the query against all documents
77
+ scores = self.calculate_similarity_scores(
78
+ embedded_query.unsqueeze(0), embedded_docs
79
+ )
80
+
81
+ return scores
backend/open_webui/apps/retrieval/utils.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import uuid
4
+ from typing import Optional, Union
5
+
6
+ import asyncio
7
+ import requests
8
+
9
+ from huggingface_hub import snapshot_download
10
+ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
11
+ from langchain_community.retrievers import BM25Retriever
12
+ from langchain_core.documents import Document
13
+
14
+ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
15
+ from open_webui.utils.misc import get_last_user_message
16
+
17
+ from open_webui.env import SRC_LOG_LEVELS
18
+
19
+ log = logging.getLogger(__name__)
20
+ log.setLevel(SRC_LOG_LEVELS["RAG"])
21
+
22
+
23
+ from typing import Any
24
+
25
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
26
+ from langchain_core.retrievers import BaseRetriever
27
+
28
+
29
+ class VectorSearchRetriever(BaseRetriever):
30
+ collection_name: Any
31
+ embedding_function: Any
32
+ top_k: int
33
+
34
+ def _get_relevant_documents(
35
+ self,
36
+ query: str,
37
+ *,
38
+ run_manager: CallbackManagerForRetrieverRun,
39
+ ) -> list[Document]:
40
+ result = VECTOR_DB_CLIENT.search(
41
+ collection_name=self.collection_name,
42
+ vectors=[self.embedding_function(query)],
43
+ limit=self.top_k,
44
+ )
45
+
46
+ ids = result.ids[0]
47
+ metadatas = result.metadatas[0]
48
+ documents = result.documents[0]
49
+
50
+ results = []
51
+ for idx in range(len(ids)):
52
+ results.append(
53
+ Document(
54
+ metadata=metadatas[idx],
55
+ page_content=documents[idx],
56
+ )
57
+ )
58
+ return results
59
+
60
+
61
+ def query_doc(
62
+ collection_name: str,
63
+ query_embedding: list[float],
64
+ k: int,
65
+ ):
66
+ try:
67
+ result = VECTOR_DB_CLIENT.search(
68
+ collection_name=collection_name,
69
+ vectors=[query_embedding],
70
+ limit=k,
71
+ )
72
+
73
+ log.info(f"query_doc:result {result.ids} {result.metadatas}")
74
+ return result
75
+ except Exception as e:
76
+ print(e)
77
+ raise e
78
+
79
+
80
+ def query_doc_with_hybrid_search(
81
+ collection_name: str,
82
+ query: str,
83
+ embedding_function,
84
+ k: int,
85
+ reranking_function,
86
+ r: float,
87
+ ) -> dict:
88
+ try:
89
+ result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
90
+
91
+ bm25_retriever = BM25Retriever.from_texts(
92
+ texts=result.documents[0],
93
+ metadatas=result.metadatas[0],
94
+ )
95
+ bm25_retriever.k = k
96
+
97
+ vector_search_retriever = VectorSearchRetriever(
98
+ collection_name=collection_name,
99
+ embedding_function=embedding_function,
100
+ top_k=k,
101
+ )
102
+
103
+ ensemble_retriever = EnsembleRetriever(
104
+ retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
105
+ )
106
+ compressor = RerankCompressor(
107
+ embedding_function=embedding_function,
108
+ top_n=k,
109
+ reranking_function=reranking_function,
110
+ r_score=r,
111
+ )
112
+
113
+ compression_retriever = ContextualCompressionRetriever(
114
+ base_compressor=compressor, base_retriever=ensemble_retriever
115
+ )
116
+
117
+ result = compression_retriever.invoke(query)
118
+ result = {
119
+ "distances": [[d.metadata.get("score") for d in result]],
120
+ "documents": [[d.page_content for d in result]],
121
+ "metadatas": [[d.metadata for d in result]],
122
+ }
123
+
124
+ log.info(
125
+ "query_doc_with_hybrid_search:result "
126
+ + f'{result["metadatas"]} {result["distances"]}'
127
+ )
128
+ return result
129
+ except Exception as e:
130
+ raise e
131
+
132
+
133
+ def merge_and_sort_query_results(
134
+ query_results: list[dict], k: int, reverse: bool = False
135
+ ) -> list[dict]:
136
+ # Initialize lists to store combined data
137
+ combined_distances = []
138
+ combined_documents = []
139
+ combined_metadatas = []
140
+
141
+ for data in query_results:
142
+ combined_distances.extend(data["distances"][0])
143
+ combined_documents.extend(data["documents"][0])
144
+ combined_metadatas.extend(data["metadatas"][0])
145
+
146
+ # Create a list of tuples (distance, document, metadata)
147
+ combined = list(zip(combined_distances, combined_documents, combined_metadatas))
148
+
149
+ # Sort the list based on distances
150
+ combined.sort(key=lambda x: x[0], reverse=reverse)
151
+
152
+ # We don't have anything :-(
153
+ if not combined:
154
+ sorted_distances = []
155
+ sorted_documents = []
156
+ sorted_metadatas = []
157
+ else:
158
+ # Unzip the sorted list
159
+ sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
160
+
161
+ # Slicing the lists to include only k elements
162
+ sorted_distances = list(sorted_distances)[:k]
163
+ sorted_documents = list(sorted_documents)[:k]
164
+ sorted_metadatas = list(sorted_metadatas)[:k]
165
+
166
+ # Create the output dictionary
167
+ result = {
168
+ "distances": [sorted_distances],
169
+ "documents": [sorted_documents],
170
+ "metadatas": [sorted_metadatas],
171
+ }
172
+
173
+ return result
174
+
175
+
176
+ def query_collection(
177
+ collection_names: list[str],
178
+ queries: list[str],
179
+ embedding_function,
180
+ k: int,
181
+ ) -> dict:
182
+ results = []
183
+ for query in queries:
184
+ query_embedding = embedding_function(query)
185
+ for collection_name in collection_names:
186
+ if collection_name:
187
+ try:
188
+ result = query_doc(
189
+ collection_name=collection_name,
190
+ k=k,
191
+ query_embedding=query_embedding,
192
+ )
193
+ if result is not None:
194
+ results.append(result.model_dump())
195
+ except Exception as e:
196
+ log.exception(f"Error when querying the collection: {e}")
197
+ else:
198
+ pass
199
+
200
+ return merge_and_sort_query_results(results, k=k)
201
+
202
+
203
+ def query_collection_with_hybrid_search(
204
+ collection_names: list[str],
205
+ queries: list[str],
206
+ embedding_function,
207
+ k: int,
208
+ reranking_function,
209
+ r: float,
210
+ ) -> dict:
211
+ results = []
212
+ error = False
213
+ for collection_name in collection_names:
214
+ try:
215
+ for query in queries:
216
+ result = query_doc_with_hybrid_search(
217
+ collection_name=collection_name,
218
+ query=query,
219
+ embedding_function=embedding_function,
220
+ k=k,
221
+ reranking_function=reranking_function,
222
+ r=r,
223
+ )
224
+ results.append(result)
225
+ except Exception as e:
226
+ log.exception(
227
+ "Error when querying the collection with " f"hybrid_search: {e}"
228
+ )
229
+ error = True
230
+
231
+ if error:
232
+ raise Exception(
233
+ "Hybrid search failed for all collections. Using Non hybrid search as fallback."
234
+ )
235
+
236
+ return merge_and_sort_query_results(results, k=k, reverse=True)
237
+
238
+
239
+ def get_embedding_function(
240
+ embedding_engine,
241
+ embedding_model,
242
+ embedding_function,
243
+ url,
244
+ key,
245
+ embedding_batch_size,
246
+ ):
247
+ if embedding_engine == "":
248
+ return lambda query: embedding_function.encode(query).tolist()
249
+ elif embedding_engine in ["ollama", "openai"]:
250
+ func = lambda query: generate_embeddings(
251
+ engine=embedding_engine,
252
+ model=embedding_model,
253
+ text=query,
254
+ url=url,
255
+ key=key,
256
+ )
257
+
258
+ def generate_multiple(query, func):
259
+ if isinstance(query, list):
260
+ embeddings = []
261
+ for i in range(0, len(query), embedding_batch_size):
262
+ embeddings.extend(func(query[i : i + embedding_batch_size]))
263
+ return embeddings
264
+ else:
265
+ return func(query)
266
+
267
+ return lambda query: generate_multiple(query, func)
268
+
269
+
270
+ def get_sources_from_files(
271
+ files,
272
+ queries,
273
+ embedding_function,
274
+ k,
275
+ reranking_function,
276
+ r,
277
+ hybrid_search,
278
+ ):
279
+ log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
280
+
281
+ extracted_collections = []
282
+ relevant_contexts = []
283
+
284
+ for file in files:
285
+ if file.get("context") == "full":
286
+ context = {
287
+ "documents": [[file.get("file").get("data", {}).get("content")]],
288
+ "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
289
+ }
290
+ else:
291
+ context = None
292
+
293
+ collection_names = []
294
+ if file.get("type") == "collection":
295
+ if file.get("legacy"):
296
+ collection_names = file.get("collection_names", [])
297
+ else:
298
+ collection_names.append(file["id"])
299
+ elif file.get("collection_name"):
300
+ collection_names.append(file["collection_name"])
301
+ elif file.get("id"):
302
+ if file.get("legacy"):
303
+ collection_names.append(f"{file['id']}")
304
+ else:
305
+ collection_names.append(f"file-{file['id']}")
306
+
307
+ collection_names = set(collection_names).difference(extracted_collections)
308
+ if not collection_names:
309
+ log.debug(f"skipping {file} as it has already been extracted")
310
+ continue
311
+
312
+ try:
313
+ context = None
314
+ if file.get("type") == "text":
315
+ context = file["content"]
316
+ else:
317
+ if hybrid_search:
318
+ try:
319
+ context = query_collection_with_hybrid_search(
320
+ collection_names=collection_names,
321
+ queries=queries,
322
+ embedding_function=embedding_function,
323
+ k=k,
324
+ reranking_function=reranking_function,
325
+ r=r,
326
+ )
327
+ except Exception as e:
328
+ log.debug(
329
+ "Error when using hybrid search, using"
330
+ " non hybrid search as fallback."
331
+ )
332
+
333
+ if (not hybrid_search) or (context is None):
334
+ context = query_collection(
335
+ collection_names=collection_names,
336
+ queries=queries,
337
+ embedding_function=embedding_function,
338
+ k=k,
339
+ )
340
+ except Exception as e:
341
+ log.exception(e)
342
+
343
+ extracted_collections.extend(collection_names)
344
+
345
+ if context:
346
+ if "data" in file:
347
+ del file["data"]
348
+ relevant_contexts.append({**context, "file": file})
349
+
350
+ sources = []
351
+ for context in relevant_contexts:
352
+ try:
353
+ if "documents" in context:
354
+ if "metadatas" in context:
355
+ source = {
356
+ "source": context["file"],
357
+ "document": context["documents"][0],
358
+ "metadata": context["metadatas"][0],
359
+ }
360
+ if "distances" in context and context["distances"]:
361
+ source["distances"] = context["distances"][0]
362
+
363
+ sources.append(source)
364
+ except Exception as e:
365
+ log.exception(e)
366
+
367
+ return sources
368
+
369
+
370
+ def get_model_path(model: str, update_model: bool = False):
371
+ # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
372
+ cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
373
+
374
+ local_files_only = not update_model
375
+
376
+ snapshot_kwargs = {
377
+ "cache_dir": cache_dir,
378
+ "local_files_only": local_files_only,
379
+ }
380
+
381
+ log.debug(f"model: {model}")
382
+ log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
383
+
384
+ # Inspiration from upstream sentence_transformers
385
+ if (
386
+ os.path.exists(model)
387
+ or ("\\" in model or model.count("/") > 1)
388
+ and local_files_only
389
+ ):
390
+ # If fully qualified path exists, return input, else set repo_id
391
+ return model
392
+ elif "/" not in model:
393
+ # Set valid repo_id for model short-name
394
+ model = "sentence-transformers" + "/" + model
395
+
396
+ snapshot_kwargs["repo_id"] = model
397
+
398
+ # Attempt to query the huggingface_hub library to determine the local path and/or to update
399
+ try:
400
+ model_repo_path = snapshot_download(**snapshot_kwargs)
401
+ log.debug(f"model_repo_path: {model_repo_path}")
402
+ return model_repo_path
403
+ except Exception as e:
404
+ log.exception(f"Cannot determine model snapshot path: {e}")
405
+ return model
406
+
407
+
408
+ def generate_openai_batch_embeddings(
409
+ model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
410
+ ) -> Optional[list[list[float]]]:
411
+ try:
412
+ r = requests.post(
413
+ f"{url}/embeddings",
414
+ headers={
415
+ "Content-Type": "application/json",
416
+ "Authorization": f"Bearer {key}",
417
+ },
418
+ json={"input": texts, "model": model},
419
+ )
420
+ r.raise_for_status()
421
+ data = r.json()
422
+ if "data" in data:
423
+ return [elem["embedding"] for elem in data["data"]]
424
+ else:
425
+ raise "Something went wrong :/"
426
+ except Exception as e:
427
+ print(e)
428
+ return None
429
+
430
+
431
+ def generate_ollama_batch_embeddings(
432
+ model: str, texts: list[str], url: str, key: str = ""
433
+ ) -> Optional[list[list[float]]]:
434
+ try:
435
+ r = requests.post(
436
+ f"{url}/api/embed",
437
+ headers={
438
+ "Content-Type": "application/json",
439
+ "Authorization": f"Bearer {key}",
440
+ },
441
+ json={"input": texts, "model": model},
442
+ )
443
+ r.raise_for_status()
444
+ data = r.json()
445
+
446
+ if "embeddings" in data:
447
+ return data["embeddings"]
448
+ else:
449
+ raise "Something went wrong :/"
450
+ except Exception as e:
451
+ print(e)
452
+ return None
453
+
454
+
455
+ def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
456
+ url = kwargs.get("url", "")
457
+ key = kwargs.get("key", "")
458
+
459
+ if engine == "ollama":
460
+ if isinstance(text, list):
461
+ embeddings = generate_ollama_batch_embeddings(
462
+ **{"model": model, "texts": text, "url": url, "key": key}
463
+ )
464
+ else:
465
+ embeddings = generate_ollama_batch_embeddings(
466
+ **{"model": model, "texts": [text], "url": url, "key": key}
467
+ )
468
+ return embeddings[0] if isinstance(text, str) else embeddings
469
+ elif engine == "openai":
470
+ if isinstance(text, list):
471
+ embeddings = generate_openai_batch_embeddings(model, text, url, key)
472
+ else:
473
+ embeddings = generate_openai_batch_embeddings(model, [text], url, key)
474
+
475
+ return embeddings[0] if isinstance(text, str) else embeddings
476
+
477
+
478
+ import operator
479
+ from typing import Optional, Sequence
480
+
481
+ from langchain_core.callbacks import Callbacks
482
+ from langchain_core.documents import BaseDocumentCompressor, Document
483
+
484
+
485
+ class RerankCompressor(BaseDocumentCompressor):
486
+ embedding_function: Any
487
+ top_n: int
488
+ reranking_function: Any
489
+ r_score: float
490
+
491
+ class Config:
492
+ extra = "forbid"
493
+ arbitrary_types_allowed = True
494
+
495
+ def compress_documents(
496
+ self,
497
+ documents: Sequence[Document],
498
+ query: str,
499
+ callbacks: Optional[Callbacks] = None,
500
+ ) -> Sequence[Document]:
501
+ reranking = self.reranking_function is not None
502
+
503
+ if reranking:
504
+ scores = self.reranking_function.predict(
505
+ [(query, doc.page_content) for doc in documents]
506
+ )
507
+ else:
508
+ from sentence_transformers import util
509
+
510
+ query_embedding = self.embedding_function(query)
511
+ document_embedding = self.embedding_function(
512
+ [doc.page_content for doc in documents]
513
+ )
514
+ scores = util.cos_sim(query_embedding, document_embedding)[0]
515
+
516
+ docs_with_scores = list(zip(documents, scores.tolist()))
517
+ if self.r_score:
518
+ docs_with_scores = [
519
+ (d, s) for d, s in docs_with_scores if s >= self.r_score
520
+ ]
521
+
522
+ result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
523
+ final_results = []
524
+ for doc, doc_score in result[: self.top_n]:
525
+ metadata = doc.metadata
526
+ metadata["score"] = doc_score
527
+ doc = Document(
528
+ page_content=doc.page_content,
529
+ metadata=metadata,
530
+ )
531
+ final_results.append(doc)
532
+ return final_results
backend/open_webui/apps/retrieval/vector/connector.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from open_webui.config import VECTOR_DB
2
+
3
+ if VECTOR_DB == "milvus":
4
+ from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
5
+
6
+ VECTOR_DB_CLIENT = MilvusClient()
7
+ elif VECTOR_DB == "qdrant":
8
+ from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
9
+
10
+ VECTOR_DB_CLIENT = QdrantClient()
11
+ elif VECTOR_DB == "opensearch":
12
+ from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient
13
+
14
+ VECTOR_DB_CLIENT = OpenSearchClient()
15
+ elif VECTOR_DB == "pgvector":
16
+ from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient
17
+
18
+ VECTOR_DB_CLIENT = PgvectorClient()
19
+ else:
20
+ from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
21
+
22
+ VECTOR_DB_CLIENT = ChromaClient()
backend/open_webui/apps/retrieval/vector/dbs/chroma.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ from chromadb import Settings
3
+ from chromadb.utils.batch_utils import create_batches
4
+
5
+ from typing import Optional
6
+
7
+ from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
8
+ from open_webui.config import (
9
+ CHROMA_DATA_PATH,
10
+ CHROMA_HTTP_HOST,
11
+ CHROMA_HTTP_PORT,
12
+ CHROMA_HTTP_HEADERS,
13
+ CHROMA_HTTP_SSL,
14
+ CHROMA_TENANT,
15
+ CHROMA_DATABASE,
16
+ CHROMA_CLIENT_AUTH_PROVIDER,
17
+ CHROMA_CLIENT_AUTH_CREDENTIALS,
18
+ )
19
+
20
+
21
+ class ChromaClient:
22
+ def __init__(self):
23
+ settings_dict = {
24
+ "allow_reset": True,
25
+ "anonymized_telemetry": False,
26
+ }
27
+ if CHROMA_CLIENT_AUTH_PROVIDER is not None:
28
+ settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
29
+ if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
30
+ settings_dict["chroma_client_auth_credentials"] = (
31
+ CHROMA_CLIENT_AUTH_CREDENTIALS
32
+ )
33
+
34
+ if CHROMA_HTTP_HOST != "":
35
+ self.client = chromadb.HttpClient(
36
+ host=CHROMA_HTTP_HOST,
37
+ port=CHROMA_HTTP_PORT,
38
+ headers=CHROMA_HTTP_HEADERS,
39
+ ssl=CHROMA_HTTP_SSL,
40
+ tenant=CHROMA_TENANT,
41
+ database=CHROMA_DATABASE,
42
+ settings=Settings(**settings_dict),
43
+ )
44
+ else:
45
+ self.client = chromadb.PersistentClient(
46
+ path=CHROMA_DATA_PATH,
47
+ settings=Settings(**settings_dict),
48
+ tenant=CHROMA_TENANT,
49
+ database=CHROMA_DATABASE,
50
+ )
51
+
52
+ def has_collection(self, collection_name: str) -> bool:
53
+ # Check if the collection exists based on the collection name.
54
+ collections = self.client.list_collections()
55
+ return collection_name in [collection.name for collection in collections]
56
+
57
+ def delete_collection(self, collection_name: str):
58
+ # Delete the collection based on the collection name.
59
+ return self.client.delete_collection(name=collection_name)
60
+
61
+ def search(
62
+ self, collection_name: str, vectors: list[list[float | int]], limit: int
63
+ ) -> Optional[SearchResult]:
64
+ # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
65
+ try:
66
+ collection = self.client.get_collection(name=collection_name)
67
+ if collection:
68
+ result = collection.query(
69
+ query_embeddings=vectors,
70
+ n_results=limit,
71
+ )
72
+
73
+ return SearchResult(
74
+ **{
75
+ "ids": result["ids"],
76
+ "distances": result["distances"],
77
+ "documents": result["documents"],
78
+ "metadatas": result["metadatas"],
79
+ }
80
+ )
81
+ return None
82
+ except Exception as e:
83
+ return None
84
+
85
+ def query(
86
+ self, collection_name: str, filter: dict, limit: Optional[int] = None
87
+ ) -> Optional[GetResult]:
88
+ # Query the items from the collection based on the filter.
89
+ try:
90
+ collection = self.client.get_collection(name=collection_name)
91
+ if collection:
92
+ result = collection.get(
93
+ where=filter,
94
+ limit=limit,
95
+ )
96
+
97
+ return GetResult(
98
+ **{
99
+ "ids": [result["ids"]],
100
+ "documents": [result["documents"]],
101
+ "metadatas": [result["metadatas"]],
102
+ }
103
+ )
104
+ return None
105
+ except Exception as e:
106
+ print(e)
107
+ return None
108
+
109
+ def get(self, collection_name: str) -> Optional[GetResult]:
110
+ # Get all the items in the collection.
111
+ collection = self.client.get_collection(name=collection_name)
112
+ if collection:
113
+ result = collection.get()
114
+ return GetResult(
115
+ **{
116
+ "ids": [result["ids"]],
117
+ "documents": [result["documents"]],
118
+ "metadatas": [result["metadatas"]],
119
+ }
120
+ )
121
+ return None
122
+
123
+ def insert(self, collection_name: str, items: list[VectorItem]):
124
+ # Insert the items into the collection, if the collection does not exist, it will be created.
125
+ collection = self.client.get_or_create_collection(
126
+ name=collection_name, metadata={"hnsw:space": "cosine"}
127
+ )
128
+
129
+ ids = [item["id"] for item in items]
130
+ documents = [item["text"] for item in items]
131
+ embeddings = [item["vector"] for item in items]
132
+ metadatas = [item["metadata"] for item in items]
133
+
134
+ for batch in create_batches(
135
+ api=self.client,
136
+ documents=documents,
137
+ embeddings=embeddings,
138
+ ids=ids,
139
+ metadatas=metadatas,
140
+ ):
141
+ collection.add(*batch)
142
+
143
+ def upsert(self, collection_name: str, items: list[VectorItem]):
144
+ # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
145
+ collection = self.client.get_or_create_collection(
146
+ name=collection_name, metadata={"hnsw:space": "cosine"}
147
+ )
148
+
149
+ ids = [item["id"] for item in items]
150
+ documents = [item["text"] for item in items]
151
+ embeddings = [item["vector"] for item in items]
152
+ metadatas = [item["metadata"] for item in items]
153
+
154
+ collection.upsert(
155
+ ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
156
+ )
157
+
158
+ def delete(
159
+ self,
160
+ collection_name: str,
161
+ ids: Optional[list[str]] = None,
162
+ filter: Optional[dict] = None,
163
+ ):
164
+ # Delete the items from the collection based on the ids.
165
+ collection = self.client.get_collection(name=collection_name)
166
+ if collection:
167
+ if ids:
168
+ collection.delete(ids=ids)
169
+ elif filter:
170
+ collection.delete(where=filter)
171
+
172
+ def reset(self):
173
+ # Resets the database. This will delete all collections and item entries.
174
+ return self.client.reset()
backend/open_webui/apps/retrieval/vector/dbs/milvus.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pymilvus import MilvusClient as Client
2
+ from pymilvus import FieldSchema, DataType
3
+ import json
4
+
5
+ from typing import Optional
6
+
7
+ from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
8
+ from open_webui.config import (
9
+ MILVUS_URI,
10
+ )
11
+
12
+
13
+ class MilvusClient:
14
+ def __init__(self):
15
+ self.collection_prefix = "open_webui"
16
+ self.client = Client(uri=MILVUS_URI)
17
+
18
+ def _result_to_get_result(self, result) -> GetResult:
19
+ ids = []
20
+ documents = []
21
+ metadatas = []
22
+
23
+ for match in result:
24
+ _ids = []
25
+ _documents = []
26
+ _metadatas = []
27
+ for item in match:
28
+ _ids.append(item.get("id"))
29
+ _documents.append(item.get("data", {}).get("text"))
30
+ _metadatas.append(item.get("metadata"))
31
+
32
+ ids.append(_ids)
33
+ documents.append(_documents)
34
+ metadatas.append(_metadatas)
35
+
36
+ return GetResult(
37
+ **{
38
+ "ids": ids,
39
+ "documents": documents,
40
+ "metadatas": metadatas,
41
+ }
42
+ )
43
+
44
+ def _result_to_search_result(self, result) -> SearchResult:
45
+ ids = []
46
+ distances = []
47
+ documents = []
48
+ metadatas = []
49
+
50
+ for match in result:
51
+ _ids = []
52
+ _distances = []
53
+ _documents = []
54
+ _metadatas = []
55
+
56
+ for item in match:
57
+ _ids.append(item.get("id"))
58
+ _distances.append(item.get("distance"))
59
+ _documents.append(item.get("entity", {}).get("data", {}).get("text"))
60
+ _metadatas.append(item.get("entity", {}).get("metadata"))
61
+
62
+ ids.append(_ids)
63
+ distances.append(_distances)
64
+ documents.append(_documents)
65
+ metadatas.append(_metadatas)
66
+
67
+ return SearchResult(
68
+ **{
69
+ "ids": ids,
70
+ "distances": distances,
71
+ "documents": documents,
72
+ "metadatas": metadatas,
73
+ }
74
+ )
75
+
76
+ def _create_collection(self, collection_name: str, dimension: int):
77
+ schema = self.client.create_schema(
78
+ auto_id=False,
79
+ enable_dynamic_field=True,
80
+ )
81
+ schema.add_field(
82
+ field_name="id",
83
+ datatype=DataType.VARCHAR,
84
+ is_primary=True,
85
+ max_length=65535,
86
+ )
87
+ schema.add_field(
88
+ field_name="vector",
89
+ datatype=DataType.FLOAT_VECTOR,
90
+ dim=dimension,
91
+ description="vector",
92
+ )
93
+ schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
94
+ schema.add_field(
95
+ field_name="metadata", datatype=DataType.JSON, description="metadata"
96
+ )
97
+
98
+ index_params = self.client.prepare_index_params()
99
+ index_params.add_index(
100
+ field_name="vector",
101
+ index_type="HNSW",
102
+ metric_type="COSINE",
103
+ params={"M": 16, "efConstruction": 100},
104
+ )
105
+
106
+ self.client.create_collection(
107
+ collection_name=f"{self.collection_prefix}_{collection_name}",
108
+ schema=schema,
109
+ index_params=index_params,
110
+ )
111
+
112
+ def has_collection(self, collection_name: str) -> bool:
113
+ # Check if the collection exists based on the collection name.
114
+ collection_name = collection_name.replace("-", "_")
115
+ return self.client.has_collection(
116
+ collection_name=f"{self.collection_prefix}_{collection_name}"
117
+ )
118
+
119
+ def delete_collection(self, collection_name: str):
120
+ # Delete the collection based on the collection name.
121
+ collection_name = collection_name.replace("-", "_")
122
+ return self.client.drop_collection(
123
+ collection_name=f"{self.collection_prefix}_{collection_name}"
124
+ )
125
+
126
+ def search(
127
+ self, collection_name: str, vectors: list[list[float | int]], limit: int
128
+ ) -> Optional[SearchResult]:
129
+ # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
130
+ collection_name = collection_name.replace("-", "_")
131
+ result = self.client.search(
132
+ collection_name=f"{self.collection_prefix}_{collection_name}",
133
+ data=vectors,
134
+ limit=limit,
135
+ output_fields=["data", "metadata"],
136
+ )
137
+
138
+ return self._result_to_search_result(result)
139
+
140
+ def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
141
+ # Construct the filter string for querying
142
+ collection_name = collection_name.replace("-", "_")
143
+ if not self.has_collection(collection_name):
144
+ return None
145
+
146
+ filter_string = " && ".join(
147
+ [
148
+ f'metadata["{key}"] == {json.dumps(value)}'
149
+ for key, value in filter.items()
150
+ ]
151
+ )
152
+
153
+ max_limit = 16383 # The maximum number of records per request
154
+ all_results = []
155
+
156
+ if limit is None:
157
+ limit = float("inf") # Use infinity as a placeholder for no limit
158
+
159
+ # Initialize offset and remaining to handle pagination
160
+ offset = 0
161
+ remaining = limit
162
+
163
+ try:
164
+ # Loop until there are no more items to fetch or the desired limit is reached
165
+ while remaining > 0:
166
+ print("remaining", remaining)
167
+ current_fetch = min(
168
+ max_limit, remaining
169
+ ) # Determine how many items to fetch in this iteration
170
+
171
+ results = self.client.query(
172
+ collection_name=f"{self.collection_prefix}_{collection_name}",
173
+ filter=filter_string,
174
+ output_fields=["*"],
175
+ limit=current_fetch,
176
+ offset=offset,
177
+ )
178
+
179
+ if not results:
180
+ break
181
+
182
+ all_results.extend(results)
183
+ results_count = len(results)
184
+ remaining -= (
185
+ results_count # Decrease remaining by the number of items fetched
186
+ )
187
+ offset += results_count
188
+
189
+ # Break the loop if the results returned are less than the requested fetch count
190
+ if results_count < current_fetch:
191
+ break
192
+
193
+ print(all_results)
194
+ return self._result_to_get_result([all_results])
195
+ except Exception as e:
196
+ print(e)
197
+ return None
198
+
199
+ def get(self, collection_name: str) -> Optional[GetResult]:
200
+ # Get all the items in the collection.
201
+ collection_name = collection_name.replace("-", "_")
202
+ result = self.client.query(
203
+ collection_name=f"{self.collection_prefix}_{collection_name}",
204
+ filter='id != ""',
205
+ )
206
+ return self._result_to_get_result([result])
207
+
208
+ def insert(self, collection_name: str, items: list[VectorItem]):
209
+ # Insert the items into the collection, if the collection does not exist, it will be created.
210
+ collection_name = collection_name.replace("-", "_")
211
+ if not self.client.has_collection(
212
+ collection_name=f"{self.collection_prefix}_{collection_name}"
213
+ ):
214
+ self._create_collection(
215
+ collection_name=collection_name, dimension=len(items[0]["vector"])
216
+ )
217
+
218
+ return self.client.insert(
219
+ collection_name=f"{self.collection_prefix}_{collection_name}",
220
+ data=[
221
+ {
222
+ "id": item["id"],
223
+ "vector": item["vector"],
224
+ "data": {"text": item["text"]},
225
+ "metadata": item["metadata"],
226
+ }
227
+ for item in items
228
+ ],
229
+ )
230
+
231
+ def upsert(self, collection_name: str, items: list[VectorItem]):
232
+ # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
233
+ collection_name = collection_name.replace("-", "_")
234
+ if not self.client.has_collection(
235
+ collection_name=f"{self.collection_prefix}_{collection_name}"
236
+ ):
237
+ self._create_collection(
238
+ collection_name=collection_name, dimension=len(items[0]["vector"])
239
+ )
240
+
241
+ return self.client.upsert(
242
+ collection_name=f"{self.collection_prefix}_{collection_name}",
243
+ data=[
244
+ {
245
+ "id": item["id"],
246
+ "vector": item["vector"],
247
+ "data": {"text": item["text"]},
248
+ "metadata": item["metadata"],
249
+ }
250
+ for item in items
251
+ ],
252
+ )
253
+
254
+ def delete(
255
+ self,
256
+ collection_name: str,
257
+ ids: Optional[list[str]] = None,
258
+ filter: Optional[dict] = None,
259
+ ):
260
+ # Delete the items from the collection based on the ids.
261
+ collection_name = collection_name.replace("-", "_")
262
+ if ids:
263
+ return self.client.delete(
264
+ collection_name=f"{self.collection_prefix}_{collection_name}",
265
+ ids=ids,
266
+ )
267
+ elif filter:
268
+ # Convert the filter dictionary to a string using JSON_CONTAINS.
269
+ filter_string = " && ".join(
270
+ [
271
+ f'metadata["{key}"] == {json.dumps(value)}'
272
+ for key, value in filter.items()
273
+ ]
274
+ )
275
+
276
+ return self.client.delete(
277
+ collection_name=f"{self.collection_prefix}_{collection_name}",
278
+ filter=filter_string,
279
+ )
280
+
281
+ def reset(self):
282
+ # Resets the database. This will delete all collections and item entries.
283
+ collection_names = self.client.list_collections()
284
+ for collection_name in collection_names:
285
+ if collection_name.startswith(self.collection_prefix):
286
+ self.client.drop_collection(collection_name=collection_name)