Spaces:
Running
Running
github-actions[bot]
commited on
Commit
•
a259699
0
Parent(s):
GitHub deploy: 4dd6ef04617035c07b6ca33ed42cbbbf95640fb8
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .dockerignore +20 -0
- .env.example +13 -0
- .eslintignore +13 -0
- .eslintrc.cjs +31 -0
- .gitattributes +3 -0
- .github/FUNDING.yml +1 -0
- .github/ISSUE_TEMPLATE/bug_report.md +80 -0
- .github/ISSUE_TEMPLATE/feature_request.md +35 -0
- .github/dependabot.yml +12 -0
- .github/pull_request_template.md +72 -0
- .github/workflows/build-release.yml +72 -0
- .github/workflows/deploy-to-hf-spaces.yml +69 -0
- .github/workflows/docker-build.yaml +477 -0
- .github/workflows/format-backend.yaml +39 -0
- .github/workflows/format-build-frontend.yaml +57 -0
- .github/workflows/integration-test.yml +253 -0
- .github/workflows/lint-backend.disabled +27 -0
- .github/workflows/lint-frontend.disabled +21 -0
- .github/workflows/release-pypi.yml +32 -0
- .github/workflows/sync-hf-spaces-with-dev.yml +30 -0
- .gitignore +309 -0
- .npmrc +1 -0
- .prettierignore +316 -0
- .prettierrc +9 -0
- CHANGELOG.md +0 -0
- CODE_OF_CONDUCT.md +99 -0
- Caddyfile.localhost +64 -0
- Dockerfile +176 -0
- INSTALLATION.md +35 -0
- LICENSE +21 -0
- Makefile +33 -0
- README.md +233 -0
- TROUBLESHOOTING.md +36 -0
- backend/.dockerignore +14 -0
- backend/.gitignore +12 -0
- backend/dev.sh +2 -0
- backend/open_webui/__init__.py +77 -0
- backend/open_webui/alembic.ini +114 -0
- backend/open_webui/apps/audio/main.py +703 -0
- backend/open_webui/apps/images/main.py +609 -0
- backend/open_webui/apps/images/utils/comfyui.py +186 -0
- backend/open_webui/apps/ollama/main.py +1351 -0
- backend/open_webui/apps/openai/main.py +719 -0
- backend/open_webui/apps/retrieval/loaders/main.py +190 -0
- backend/open_webui/apps/retrieval/loaders/youtube.py +117 -0
- backend/open_webui/apps/retrieval/main.py +1494 -0
- backend/open_webui/apps/retrieval/models/colbert.py +81 -0
- backend/open_webui/apps/retrieval/utils.py +532 -0
- backend/open_webui/apps/retrieval/vector/connector.py +22 -0
- backend/open_webui/apps/retrieval/vector/dbs/chroma.py +174 -0
.dockerignore
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.github
|
2 |
+
.DS_Store
|
3 |
+
docs
|
4 |
+
kubernetes
|
5 |
+
node_modules
|
6 |
+
/.svelte-kit
|
7 |
+
/package
|
8 |
+
.env
|
9 |
+
.env.*
|
10 |
+
vite.config.js.timestamp-*
|
11 |
+
vite.config.ts.timestamp-*
|
12 |
+
__pycache__
|
13 |
+
.idea
|
14 |
+
venv
|
15 |
+
_old
|
16 |
+
uploads
|
17 |
+
.ipynb_checkpoints
|
18 |
+
**/*.db
|
19 |
+
_test
|
20 |
+
backend/data/*
|
.env.example
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ollama URL for the backend to connect
|
2 |
+
# The path '/ollama' will be redirected to the specified backend URL
|
3 |
+
OLLAMA_BASE_URL='http://localhost:11434'
|
4 |
+
|
5 |
+
OPENAI_API_BASE_URL=''
|
6 |
+
OPENAI_API_KEY=''
|
7 |
+
|
8 |
+
# AUTOMATIC1111_BASE_URL="http://localhost:7860"
|
9 |
+
|
10 |
+
# DO NOT TRACK
|
11 |
+
SCARF_NO_ANALYTICS=true
|
12 |
+
DO_NOT_TRACK=true
|
13 |
+
ANONYMIZED_TELEMETRY=false
|
.eslintignore
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
node_modules
|
3 |
+
/build
|
4 |
+
/.svelte-kit
|
5 |
+
/package
|
6 |
+
.env
|
7 |
+
.env.*
|
8 |
+
!.env.example
|
9 |
+
|
10 |
+
# Ignore files for PNPM, NPM and YARN
|
11 |
+
pnpm-lock.yaml
|
12 |
+
package-lock.json
|
13 |
+
yarn.lock
|
.eslintrc.cjs
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
module.exports = {
|
2 |
+
root: true,
|
3 |
+
extends: [
|
4 |
+
'eslint:recommended',
|
5 |
+
'plugin:@typescript-eslint/recommended',
|
6 |
+
'plugin:svelte/recommended',
|
7 |
+
'plugin:cypress/recommended',
|
8 |
+
'prettier'
|
9 |
+
],
|
10 |
+
parser: '@typescript-eslint/parser',
|
11 |
+
plugins: ['@typescript-eslint'],
|
12 |
+
parserOptions: {
|
13 |
+
sourceType: 'module',
|
14 |
+
ecmaVersion: 2020,
|
15 |
+
extraFileExtensions: ['.svelte']
|
16 |
+
},
|
17 |
+
env: {
|
18 |
+
browser: true,
|
19 |
+
es2017: true,
|
20 |
+
node: true
|
21 |
+
},
|
22 |
+
overrides: [
|
23 |
+
{
|
24 |
+
files: ['*.svelte'],
|
25 |
+
parser: 'svelte-eslint-parser',
|
26 |
+
parserOptions: {
|
27 |
+
parser: '@typescript-eslint/parser'
|
28 |
+
}
|
29 |
+
}
|
30 |
+
]
|
31 |
+
};
|
.gitattributes
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
*.sh text eol=lf
|
2 |
+
*.ttf filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
.github/FUNDING.yml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
github: tjbck
|
.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Bug report
|
3 |
+
about: Create a report to help us improve
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
---
|
8 |
+
|
9 |
+
# Bug Report
|
10 |
+
|
11 |
+
## Important Notes
|
12 |
+
|
13 |
+
- **Before submitting a bug report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. This will help us efficiently focus on improving the project.
|
14 |
+
|
15 |
+
- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We’re here to help if you’re open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
|
16 |
+
|
17 |
+
- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
|
18 |
+
|
19 |
+
- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!
|
20 |
+
|
21 |
+
Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
|
22 |
+
|
23 |
+
---
|
24 |
+
|
25 |
+
## Installation Method
|
26 |
+
|
27 |
+
[Describe the method you used to install the project, e.g., git clone, Docker, pip, etc.]
|
28 |
+
|
29 |
+
## Environment
|
30 |
+
|
31 |
+
- **Open WebUI Version:** [e.g., v0.3.11]
|
32 |
+
- **Ollama (if applicable):** [e.g., v0.2.0, v0.1.32-rc1]
|
33 |
+
|
34 |
+
- **Operating System:** [e.g., Windows 10, macOS Big Sur, Ubuntu 20.04]
|
35 |
+
- **Browser (if applicable):** [e.g., Chrome 100.0, Firefox 98.0]
|
36 |
+
|
37 |
+
**Confirmation:**
|
38 |
+
|
39 |
+
- [ ] I have read and followed all the instructions provided in the README.md.
|
40 |
+
- [ ] I am on the latest version of both Open WebUI and Ollama.
|
41 |
+
- [ ] I have included the browser console logs.
|
42 |
+
- [ ] I have included the Docker container logs.
|
43 |
+
- [ ] I have provided the exact steps to reproduce the bug in the "Steps to Reproduce" section below.
|
44 |
+
|
45 |
+
## Expected Behavior:
|
46 |
+
|
47 |
+
[Describe what you expected to happen.]
|
48 |
+
|
49 |
+
## Actual Behavior:
|
50 |
+
|
51 |
+
[Describe what actually happened.]
|
52 |
+
|
53 |
+
## Description
|
54 |
+
|
55 |
+
**Bug Summary:**
|
56 |
+
[Provide a brief but clear summary of the bug]
|
57 |
+
|
58 |
+
## Reproduction Details
|
59 |
+
|
60 |
+
**Steps to Reproduce:**
|
61 |
+
[Outline the steps to reproduce the bug. Be as detailed as possible.]
|
62 |
+
|
63 |
+
## Logs and Screenshots
|
64 |
+
|
65 |
+
**Browser Console Logs:**
|
66 |
+
[Include relevant browser console logs, if applicable]
|
67 |
+
|
68 |
+
**Docker Container Logs:**
|
69 |
+
[Include relevant Docker container logs, if applicable]
|
70 |
+
|
71 |
+
**Screenshots/Screen Recordings (if applicable):**
|
72 |
+
[Attach any relevant screenshots to help illustrate the issue]
|
73 |
+
|
74 |
+
## Additional Information
|
75 |
+
|
76 |
+
[Include any additional details that may help in understanding and reproducing the issue. This could include specific configurations, error messages, or anything else relevant to the bug.]
|
77 |
+
|
78 |
+
## Note
|
79 |
+
|
80 |
+
If the bug report is incomplete or does not follow the provided instructions, it may not be addressed. Please ensure that you have followed the steps outlined in the README.md and troubleshooting.md documents, and provide all necessary information for us to reproduce and address the issue. Thank you!
|
.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Feature request
|
3 |
+
about: Suggest an idea for this project
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
---
|
8 |
+
|
9 |
+
# Feature Request
|
10 |
+
|
11 |
+
## Important Notes
|
12 |
+
|
13 |
+
- **Before submitting a report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. This will help us efficiently focus on improving the project.
|
14 |
+
|
15 |
+
- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We’re here to help if you’re open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
|
16 |
+
|
17 |
+
- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
|
18 |
+
|
19 |
+
- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, it’s not that the issue doesn’t exist; we need your help!
|
20 |
+
|
21 |
+
Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
|
22 |
+
|
23 |
+
---
|
24 |
+
|
25 |
+
**Is your feature request related to a problem? Please describe.**
|
26 |
+
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
27 |
+
|
28 |
+
**Describe the solution you'd like**
|
29 |
+
A clear and concise description of what you want to happen.
|
30 |
+
|
31 |
+
**Describe alternatives you've considered**
|
32 |
+
A clear and concise description of any alternative solutions or features you've considered.
|
33 |
+
|
34 |
+
**Additional context**
|
35 |
+
Add any other context or screenshots about the feature request here.
|
.github/dependabot.yml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: 2
|
2 |
+
updates:
|
3 |
+
- package-ecosystem: pip
|
4 |
+
directory: '/backend'
|
5 |
+
schedule:
|
6 |
+
interval: monthly
|
7 |
+
target-branch: 'dev'
|
8 |
+
- package-ecosystem: 'github-actions'
|
9 |
+
directory: '/'
|
10 |
+
schedule:
|
11 |
+
# Check for updates to GitHub Actions every week
|
12 |
+
interval: monthly
|
.github/pull_request_template.md
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Pull Request Checklist
|
2 |
+
|
3 |
+
### Note to first-time contributors: Please open a discussion post in [Discussions](https://github.com/open-webui/open-webui/discussions) and describe your changes before submitting a pull request.
|
4 |
+
|
5 |
+
**Before submitting, make sure you've checked the following:**
|
6 |
+
|
7 |
+
- [ ] **Target branch:** Please verify that the pull request targets the `dev` branch.
|
8 |
+
- [ ] **Description:** Provide a concise description of the changes made in this pull request.
|
9 |
+
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
|
10 |
+
- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
|
11 |
+
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
|
12 |
+
- [ ] **Testing:** Have you written and run sufficient tests for validating the changes?
|
13 |
+
- [ ] **Code review:** Have you performed a self-review of your code, addressing any coding standard issues and ensuring adherence to the project's coding standards?
|
14 |
+
- [ ] **Prefix:** To cleary categorize this pull request, prefix the pull request title, using one of the following:
|
15 |
+
- **BREAKING CHANGE**: Significant changes that may affect compatibility
|
16 |
+
- **build**: Changes that affect the build system or external dependencies
|
17 |
+
- **ci**: Changes to our continuous integration processes or workflows
|
18 |
+
- **chore**: Refactor, cleanup, or other non-functional code changes
|
19 |
+
- **docs**: Documentation update or addition
|
20 |
+
- **feat**: Introduces a new feature or enhancement to the codebase
|
21 |
+
- **fix**: Bug fix or error correction
|
22 |
+
- **i18n**: Internationalization or localization changes
|
23 |
+
- **perf**: Performance improvement
|
24 |
+
- **refactor**: Code restructuring for better maintainability, readability, or scalability
|
25 |
+
- **style**: Changes that do not affect the meaning of the code (white-space, formatting, missing semi-colons, etc.)
|
26 |
+
- **test**: Adding missing tests or correcting existing tests
|
27 |
+
- **WIP**: Work in progress, a temporary label for incomplete or ongoing work
|
28 |
+
|
29 |
+
# Changelog Entry
|
30 |
+
|
31 |
+
### Description
|
32 |
+
|
33 |
+
- [Concisely describe the changes made in this pull request, including any relevant motivation and impact (e.g., fixing a bug, adding a feature, or improving performance)]
|
34 |
+
|
35 |
+
### Added
|
36 |
+
|
37 |
+
- [List any new features, functionalities, or additions]
|
38 |
+
|
39 |
+
### Changed
|
40 |
+
|
41 |
+
- [List any changes, updates, refactorings, or optimizations]
|
42 |
+
|
43 |
+
### Deprecated
|
44 |
+
|
45 |
+
- [List any deprecated functionality or features that have been removed]
|
46 |
+
|
47 |
+
### Removed
|
48 |
+
|
49 |
+
- [List any removed features, files, or functionalities]
|
50 |
+
|
51 |
+
### Fixed
|
52 |
+
|
53 |
+
- [List any fixes, corrections, or bug fixes]
|
54 |
+
|
55 |
+
### Security
|
56 |
+
|
57 |
+
- [List any new or updated security-related changes, including vulnerability fixes]
|
58 |
+
|
59 |
+
### Breaking Changes
|
60 |
+
|
61 |
+
- **BREAKING CHANGE**: [List any breaking changes affecting compatibility or functionality]
|
62 |
+
|
63 |
+
---
|
64 |
+
|
65 |
+
### Additional Information
|
66 |
+
|
67 |
+
- [Insert any additional context, notes, or explanations for the changes]
|
68 |
+
- [Reference any related issues, commits, or other relevant information]
|
69 |
+
|
70 |
+
### Screenshots or Videos
|
71 |
+
|
72 |
+
- [Attach any relevant screenshots or videos demonstrating the changes]
|
.github/workflows/build-release.yml
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Release
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main # or whatever branch you want to use
|
7 |
+
|
8 |
+
jobs:
|
9 |
+
release:
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
|
12 |
+
steps:
|
13 |
+
- name: Checkout repository
|
14 |
+
uses: actions/checkout@v4
|
15 |
+
|
16 |
+
- name: Check for changes in package.json
|
17 |
+
run: |
|
18 |
+
git diff --cached --diff-filter=d package.json || {
|
19 |
+
echo "No changes to package.json"
|
20 |
+
exit 1
|
21 |
+
}
|
22 |
+
|
23 |
+
- name: Get version number from package.json
|
24 |
+
id: get_version
|
25 |
+
run: |
|
26 |
+
VERSION=$(jq -r '.version' package.json)
|
27 |
+
echo "::set-output name=version::$VERSION"
|
28 |
+
|
29 |
+
- name: Extract latest CHANGELOG entry
|
30 |
+
id: changelog
|
31 |
+
run: |
|
32 |
+
CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
|
33 |
+
CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g')
|
34 |
+
echo "Extracted latest release notes from CHANGELOG.md:"
|
35 |
+
echo -e "$CHANGELOG_CONTENT"
|
36 |
+
echo "::set-output name=content::$CHANGELOG_ESCAPED"
|
37 |
+
|
38 |
+
- name: Create GitHub release
|
39 |
+
uses: actions/github-script@v7
|
40 |
+
with:
|
41 |
+
github-token: ${{ secrets.GITHUB_TOKEN }}
|
42 |
+
script: |
|
43 |
+
const changelog = `${{ steps.changelog.outputs.content }}`;
|
44 |
+
const release = await github.rest.repos.createRelease({
|
45 |
+
owner: context.repo.owner,
|
46 |
+
repo: context.repo.repo,
|
47 |
+
tag_name: `v${{ steps.get_version.outputs.version }}`,
|
48 |
+
name: `v${{ steps.get_version.outputs.version }}`,
|
49 |
+
body: changelog,
|
50 |
+
})
|
51 |
+
console.log(`Created release ${release.data.html_url}`)
|
52 |
+
|
53 |
+
- name: Upload package to GitHub release
|
54 |
+
uses: actions/upload-artifact@v4
|
55 |
+
with:
|
56 |
+
name: package
|
57 |
+
path: |
|
58 |
+
.
|
59 |
+
!.git
|
60 |
+
env:
|
61 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
62 |
+
|
63 |
+
- name: Trigger Docker build workflow
|
64 |
+
uses: actions/github-script@v7
|
65 |
+
with:
|
66 |
+
script: |
|
67 |
+
github.rest.actions.createWorkflowDispatch({
|
68 |
+
owner: context.repo.owner,
|
69 |
+
repo: context.repo.repo,
|
70 |
+
workflow_id: 'docker-build.yaml',
|
71 |
+
ref: 'v${{ steps.get_version.outputs.version }}',
|
72 |
+
})
|
.github/workflows/deploy-to-hf-spaces.yml
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Deploy to HuggingFace Spaces
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- hf-space
|
7 |
+
workflow_dispatch:
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
check-secret:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
outputs:
|
13 |
+
token-set: ${{ steps.check-key.outputs.defined }}
|
14 |
+
steps:
|
15 |
+
- id: check-key
|
16 |
+
env:
|
17 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
18 |
+
if: "${{ env.HF_TOKEN != '' }}"
|
19 |
+
run: echo "defined=true" >> $GITHUB_OUTPUT
|
20 |
+
|
21 |
+
deploy:
|
22 |
+
runs-on: ubuntu-latest
|
23 |
+
needs: [check-secret]
|
24 |
+
if: needs.check-secret.outputs.token-set == 'true'
|
25 |
+
env:
|
26 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
27 |
+
HF_USERNAME: ${{ secrets.HF_USERNAME }}
|
28 |
+
HF_SPACE_NAME: ${{ secrets.HF_SPACE_NAME }}
|
29 |
+
steps:
|
30 |
+
- name: Checkout repository
|
31 |
+
uses: actions/checkout@v4
|
32 |
+
with:
|
33 |
+
lfs: true
|
34 |
+
|
35 |
+
- name: Remove git history
|
36 |
+
run: rm -rf .git
|
37 |
+
|
38 |
+
- name: Prepend YAML front matter to README.md
|
39 |
+
run: |
|
40 |
+
cat <<EOF >README.md
|
41 |
+
---
|
42 |
+
title: Open WebUI
|
43 |
+
emoji: 🐳
|
44 |
+
colorFrom: purple
|
45 |
+
colorTo: gray
|
46 |
+
sdk: docker
|
47 |
+
app_port: 8080
|
48 |
+
hf_oauth: true
|
49 |
+
hf_oauth_scopes:
|
50 |
+
- email
|
51 |
+
---
|
52 |
+
$(cat README.md)
|
53 |
+
EOF
|
54 |
+
|
55 |
+
- name: Configure git
|
56 |
+
run: |
|
57 |
+
git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
58 |
+
git config --global user.name "github-actions[bot]"
|
59 |
+
|
60 |
+
- name: Set up Git and push to Space
|
61 |
+
run: |
|
62 |
+
git init --initial-branch=main
|
63 |
+
git lfs install
|
64 |
+
git lfs track "*.ttf"
|
65 |
+
git lfs track "*.jpg"
|
66 |
+
rm demo.gif
|
67 |
+
git add .
|
68 |
+
git commit -m "GitHub deploy: ${{ github.sha }}"
|
69 |
+
git push --force https://${HF_USERNAME}:${HF_TOKEN}@huggingface.co/spaces/${HF_USERNAME}/${HF_SPACE_NAME} main
|
.github/workflows/docker-build.yaml
ADDED
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Create and publish Docker images with specific build args
|
2 |
+
|
3 |
+
on:
|
4 |
+
workflow_dispatch:
|
5 |
+
push:
|
6 |
+
branches:
|
7 |
+
- main
|
8 |
+
- dev
|
9 |
+
tags:
|
10 |
+
- v*
|
11 |
+
|
12 |
+
env:
|
13 |
+
REGISTRY: ghcr.io
|
14 |
+
|
15 |
+
jobs:
|
16 |
+
build-main-image:
|
17 |
+
runs-on: ubuntu-latest
|
18 |
+
permissions:
|
19 |
+
contents: read
|
20 |
+
packages: write
|
21 |
+
strategy:
|
22 |
+
fail-fast: false
|
23 |
+
matrix:
|
24 |
+
platform:
|
25 |
+
- linux/amd64
|
26 |
+
- linux/arm64
|
27 |
+
|
28 |
+
steps:
|
29 |
+
# GitHub Packages requires the entire repository name to be in lowercase
|
30 |
+
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
|
31 |
+
- name: Set repository and image name to lowercase
|
32 |
+
run: |
|
33 |
+
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
34 |
+
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
35 |
+
env:
|
36 |
+
IMAGE_NAME: '${{ github.repository }}'
|
37 |
+
|
38 |
+
- name: Prepare
|
39 |
+
run: |
|
40 |
+
platform=${{ matrix.platform }}
|
41 |
+
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
42 |
+
|
43 |
+
- name: Checkout repository
|
44 |
+
uses: actions/checkout@v4
|
45 |
+
|
46 |
+
- name: Set up QEMU
|
47 |
+
uses: docker/setup-qemu-action@v3
|
48 |
+
|
49 |
+
- name: Set up Docker Buildx
|
50 |
+
uses: docker/setup-buildx-action@v3
|
51 |
+
|
52 |
+
- name: Log in to the Container registry
|
53 |
+
uses: docker/login-action@v3
|
54 |
+
with:
|
55 |
+
registry: ${{ env.REGISTRY }}
|
56 |
+
username: ${{ github.actor }}
|
57 |
+
password: ${{ secrets.GITHUB_TOKEN }}
|
58 |
+
|
59 |
+
- name: Extract metadata for Docker images (default latest tag)
|
60 |
+
id: meta
|
61 |
+
uses: docker/metadata-action@v5
|
62 |
+
with:
|
63 |
+
images: ${{ env.FULL_IMAGE_NAME }}
|
64 |
+
tags: |
|
65 |
+
type=ref,event=branch
|
66 |
+
type=ref,event=tag
|
67 |
+
type=sha,prefix=git-
|
68 |
+
type=semver,pattern={{version}}
|
69 |
+
type=semver,pattern={{major}}.{{minor}}
|
70 |
+
flavor: |
|
71 |
+
latest=${{ github.ref == 'refs/heads/main' }}
|
72 |
+
|
73 |
+
- name: Extract metadata for Docker cache
|
74 |
+
id: cache-meta
|
75 |
+
uses: docker/metadata-action@v5
|
76 |
+
with:
|
77 |
+
images: ${{ env.FULL_IMAGE_NAME }}
|
78 |
+
tags: |
|
79 |
+
type=ref,event=branch
|
80 |
+
${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
|
81 |
+
flavor: |
|
82 |
+
prefix=cache-${{ matrix.platform }}-
|
83 |
+
latest=false
|
84 |
+
|
85 |
+
- name: Build Docker image (latest)
|
86 |
+
uses: docker/build-push-action@v5
|
87 |
+
id: build
|
88 |
+
with:
|
89 |
+
context: .
|
90 |
+
push: true
|
91 |
+
platforms: ${{ matrix.platform }}
|
92 |
+
labels: ${{ steps.meta.outputs.labels }}
|
93 |
+
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
|
94 |
+
cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
|
95 |
+
cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
|
96 |
+
build-args: |
|
97 |
+
BUILD_HASH=${{ github.sha }}
|
98 |
+
|
99 |
+
- name: Export digest
|
100 |
+
run: |
|
101 |
+
mkdir -p /tmp/digests
|
102 |
+
digest="${{ steps.build.outputs.digest }}"
|
103 |
+
touch "/tmp/digests/${digest#sha256:}"
|
104 |
+
|
105 |
+
- name: Upload digest
|
106 |
+
uses: actions/upload-artifact@v4
|
107 |
+
with:
|
108 |
+
name: digests-main-${{ env.PLATFORM_PAIR }}
|
109 |
+
path: /tmp/digests/*
|
110 |
+
if-no-files-found: error
|
111 |
+
retention-days: 1
|
112 |
+
|
113 |
+
build-cuda-image:
|
114 |
+
runs-on: ubuntu-latest
|
115 |
+
permissions:
|
116 |
+
contents: read
|
117 |
+
packages: write
|
118 |
+
strategy:
|
119 |
+
fail-fast: false
|
120 |
+
matrix:
|
121 |
+
platform:
|
122 |
+
- linux/amd64
|
123 |
+
- linux/arm64
|
124 |
+
|
125 |
+
steps:
|
126 |
+
# GitHub Packages requires the entire repository name to be in lowercase
|
127 |
+
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
|
128 |
+
- name: Set repository and image name to lowercase
|
129 |
+
run: |
|
130 |
+
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
131 |
+
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
132 |
+
env:
|
133 |
+
IMAGE_NAME: '${{ github.repository }}'
|
134 |
+
|
135 |
+
- name: Prepare
|
136 |
+
run: |
|
137 |
+
platform=${{ matrix.platform }}
|
138 |
+
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
139 |
+
|
140 |
+
- name: Checkout repository
|
141 |
+
uses: actions/checkout@v4
|
142 |
+
|
143 |
+
- name: Set up QEMU
|
144 |
+
uses: docker/setup-qemu-action@v3
|
145 |
+
|
146 |
+
- name: Set up Docker Buildx
|
147 |
+
uses: docker/setup-buildx-action@v3
|
148 |
+
|
149 |
+
- name: Log in to the Container registry
|
150 |
+
uses: docker/login-action@v3
|
151 |
+
with:
|
152 |
+
registry: ${{ env.REGISTRY }}
|
153 |
+
username: ${{ github.actor }}
|
154 |
+
password: ${{ secrets.GITHUB_TOKEN }}
|
155 |
+
|
156 |
+
- name: Extract metadata for Docker images (cuda tag)
|
157 |
+
id: meta
|
158 |
+
uses: docker/metadata-action@v5
|
159 |
+
with:
|
160 |
+
images: ${{ env.FULL_IMAGE_NAME }}
|
161 |
+
tags: |
|
162 |
+
type=ref,event=branch
|
163 |
+
type=ref,event=tag
|
164 |
+
type=sha,prefix=git-
|
165 |
+
type=semver,pattern={{version}}
|
166 |
+
type=semver,pattern={{major}}.{{minor}}
|
167 |
+
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda
|
168 |
+
flavor: |
|
169 |
+
latest=${{ github.ref == 'refs/heads/main' }}
|
170 |
+
suffix=-cuda,onlatest=true
|
171 |
+
|
172 |
+
- name: Extract metadata for Docker cache
|
173 |
+
id: cache-meta
|
174 |
+
uses: docker/metadata-action@v5
|
175 |
+
with:
|
176 |
+
images: ${{ env.FULL_IMAGE_NAME }}
|
177 |
+
tags: |
|
178 |
+
type=ref,event=branch
|
179 |
+
${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
|
180 |
+
flavor: |
|
181 |
+
prefix=cache-cuda-${{ matrix.platform }}-
|
182 |
+
latest=false
|
183 |
+
|
184 |
+
- name: Build Docker image (cuda)
|
185 |
+
uses: docker/build-push-action@v5
|
186 |
+
id: build
|
187 |
+
with:
|
188 |
+
context: .
|
189 |
+
push: true
|
190 |
+
platforms: ${{ matrix.platform }}
|
191 |
+
labels: ${{ steps.meta.outputs.labels }}
|
192 |
+
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
|
193 |
+
cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
|
194 |
+
cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
|
195 |
+
build-args: |
|
196 |
+
BUILD_HASH=${{ github.sha }}
|
197 |
+
USE_CUDA=true
|
198 |
+
|
199 |
+
- name: Export digest
|
200 |
+
run: |
|
201 |
+
mkdir -p /tmp/digests
|
202 |
+
digest="${{ steps.build.outputs.digest }}"
|
203 |
+
touch "/tmp/digests/${digest#sha256:}"
|
204 |
+
|
205 |
+
- name: Upload digest
|
206 |
+
uses: actions/upload-artifact@v4
|
207 |
+
with:
|
208 |
+
name: digests-cuda-${{ env.PLATFORM_PAIR }}
|
209 |
+
path: /tmp/digests/*
|
210 |
+
if-no-files-found: error
|
211 |
+
retention-days: 1
|
212 |
+
|
213 |
+
build-ollama-image:
|
214 |
+
runs-on: ubuntu-latest
|
215 |
+
permissions:
|
216 |
+
contents: read
|
217 |
+
packages: write
|
218 |
+
strategy:
|
219 |
+
fail-fast: false
|
220 |
+
matrix:
|
221 |
+
platform:
|
222 |
+
- linux/amd64
|
223 |
+
- linux/arm64
|
224 |
+
|
225 |
+
steps:
|
226 |
+
# GitHub Packages requires the entire repository name to be in lowercase
|
227 |
+
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
|
228 |
+
- name: Set repository and image name to lowercase
|
229 |
+
run: |
|
230 |
+
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
231 |
+
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
232 |
+
env:
|
233 |
+
IMAGE_NAME: '${{ github.repository }}'
|
234 |
+
|
235 |
+
- name: Prepare
|
236 |
+
run: |
|
237 |
+
platform=${{ matrix.platform }}
|
238 |
+
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
239 |
+
|
240 |
+
- name: Checkout repository
|
241 |
+
uses: actions/checkout@v4
|
242 |
+
|
243 |
+
- name: Set up QEMU
|
244 |
+
uses: docker/setup-qemu-action@v3
|
245 |
+
|
246 |
+
- name: Set up Docker Buildx
|
247 |
+
uses: docker/setup-buildx-action@v3
|
248 |
+
|
249 |
+
- name: Log in to the Container registry
|
250 |
+
uses: docker/login-action@v3
|
251 |
+
with:
|
252 |
+
registry: ${{ env.REGISTRY }}
|
253 |
+
username: ${{ github.actor }}
|
254 |
+
password: ${{ secrets.GITHUB_TOKEN }}
|
255 |
+
|
256 |
+
- name: Extract metadata for Docker images (ollama tag)
|
257 |
+
id: meta
|
258 |
+
uses: docker/metadata-action@v5
|
259 |
+
with:
|
260 |
+
images: ${{ env.FULL_IMAGE_NAME }}
|
261 |
+
tags: |
|
262 |
+
type=ref,event=branch
|
263 |
+
type=ref,event=tag
|
264 |
+
type=sha,prefix=git-
|
265 |
+
type=semver,pattern={{version}}
|
266 |
+
type=semver,pattern={{major}}.{{minor}}
|
267 |
+
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=ollama
|
268 |
+
flavor: |
|
269 |
+
latest=${{ github.ref == 'refs/heads/main' }}
|
270 |
+
suffix=-ollama,onlatest=true
|
271 |
+
|
272 |
+
- name: Extract metadata for Docker cache
|
273 |
+
id: cache-meta
|
274 |
+
uses: docker/metadata-action@v5
|
275 |
+
with:
|
276 |
+
images: ${{ env.FULL_IMAGE_NAME }}
|
277 |
+
tags: |
|
278 |
+
type=ref,event=branch
|
279 |
+
${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }}
|
280 |
+
flavor: |
|
281 |
+
prefix=cache-ollama-${{ matrix.platform }}-
|
282 |
+
latest=false
|
283 |
+
|
284 |
+
- name: Build Docker image (ollama)
|
285 |
+
uses: docker/build-push-action@v5
|
286 |
+
id: build
|
287 |
+
with:
|
288 |
+
context: .
|
289 |
+
push: true
|
290 |
+
platforms: ${{ matrix.platform }}
|
291 |
+
labels: ${{ steps.meta.outputs.labels }}
|
292 |
+
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
|
293 |
+
cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
|
294 |
+
cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
|
295 |
+
build-args: |
|
296 |
+
BUILD_HASH=${{ github.sha }}
|
297 |
+
USE_OLLAMA=true
|
298 |
+
|
299 |
+
- name: Export digest
|
300 |
+
run: |
|
301 |
+
mkdir -p /tmp/digests
|
302 |
+
digest="${{ steps.build.outputs.digest }}"
|
303 |
+
touch "/tmp/digests/${digest#sha256:}"
|
304 |
+
|
305 |
+
- name: Upload digest
|
306 |
+
uses: actions/upload-artifact@v4
|
307 |
+
with:
|
308 |
+
name: digests-ollama-${{ env.PLATFORM_PAIR }}
|
309 |
+
path: /tmp/digests/*
|
310 |
+
if-no-files-found: error
|
311 |
+
retention-days: 1
|
312 |
+
|
313 |
+
merge-main-images:
|
314 |
+
runs-on: ubuntu-latest
|
315 |
+
needs: [build-main-image]
|
316 |
+
steps:
|
317 |
+
# GitHub Packages requires the entire repository name to be in lowercase
|
318 |
+
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
|
319 |
+
- name: Set repository and image name to lowercase
|
320 |
+
run: |
|
321 |
+
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
322 |
+
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
323 |
+
env:
|
324 |
+
IMAGE_NAME: '${{ github.repository }}'
|
325 |
+
|
326 |
+
- name: Download digests
|
327 |
+
uses: actions/download-artifact@v4
|
328 |
+
with:
|
329 |
+
pattern: digests-main-*
|
330 |
+
path: /tmp/digests
|
331 |
+
merge-multiple: true
|
332 |
+
|
333 |
+
- name: Set up Docker Buildx
|
334 |
+
uses: docker/setup-buildx-action@v3
|
335 |
+
|
336 |
+
- name: Log in to the Container registry
|
337 |
+
uses: docker/login-action@v3
|
338 |
+
with:
|
339 |
+
registry: ${{ env.REGISTRY }}
|
340 |
+
username: ${{ github.actor }}
|
341 |
+
password: ${{ secrets.GITHUB_TOKEN }}
|
342 |
+
|
343 |
+
- name: Extract metadata for Docker images (default latest tag)
|
344 |
+
id: meta
|
345 |
+
uses: docker/metadata-action@v5
|
346 |
+
with:
|
347 |
+
images: ${{ env.FULL_IMAGE_NAME }}
|
348 |
+
tags: |
|
349 |
+
type=ref,event=branch
|
350 |
+
type=ref,event=tag
|
351 |
+
type=sha,prefix=git-
|
352 |
+
type=semver,pattern={{version}}
|
353 |
+
type=semver,pattern={{major}}.{{minor}}
|
354 |
+
flavor: |
|
355 |
+
latest=${{ github.ref == 'refs/heads/main' }}
|
356 |
+
|
357 |
+
- name: Create manifest list and push
|
358 |
+
working-directory: /tmp/digests
|
359 |
+
run: |
|
360 |
+
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
361 |
+
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
|
362 |
+
|
363 |
+
- name: Inspect image
|
364 |
+
run: |
|
365 |
+
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
|
366 |
+
|
367 |
+
merge-cuda-images:
|
368 |
+
runs-on: ubuntu-latest
|
369 |
+
needs: [build-cuda-image]
|
370 |
+
steps:
|
371 |
+
# GitHub Packages requires the entire repository name to be in lowercase
|
372 |
+
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
|
373 |
+
- name: Set repository and image name to lowercase
|
374 |
+
run: |
|
375 |
+
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
376 |
+
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
377 |
+
env:
|
378 |
+
IMAGE_NAME: '${{ github.repository }}'
|
379 |
+
|
380 |
+
- name: Download digests
|
381 |
+
uses: actions/download-artifact@v4
|
382 |
+
with:
|
383 |
+
pattern: digests-cuda-*
|
384 |
+
path: /tmp/digests
|
385 |
+
merge-multiple: true
|
386 |
+
|
387 |
+
- name: Set up Docker Buildx
|
388 |
+
uses: docker/setup-buildx-action@v3
|
389 |
+
|
390 |
+
- name: Log in to the Container registry
|
391 |
+
uses: docker/login-action@v3
|
392 |
+
with:
|
393 |
+
registry: ${{ env.REGISTRY }}
|
394 |
+
username: ${{ github.actor }}
|
395 |
+
password: ${{ secrets.GITHUB_TOKEN }}
|
396 |
+
|
397 |
+
- name: Extract metadata for Docker images (default latest tag)
|
398 |
+
id: meta
|
399 |
+
uses: docker/metadata-action@v5
|
400 |
+
with:
|
401 |
+
images: ${{ env.FULL_IMAGE_NAME }}
|
402 |
+
tags: |
|
403 |
+
type=ref,event=branch
|
404 |
+
type=ref,event=tag
|
405 |
+
type=sha,prefix=git-
|
406 |
+
type=semver,pattern={{version}}
|
407 |
+
type=semver,pattern={{major}}.{{minor}}
|
408 |
+
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda
|
409 |
+
flavor: |
|
410 |
+
latest=${{ github.ref == 'refs/heads/main' }}
|
411 |
+
suffix=-cuda,onlatest=true
|
412 |
+
|
413 |
+
- name: Create manifest list and push
|
414 |
+
working-directory: /tmp/digests
|
415 |
+
run: |
|
416 |
+
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
417 |
+
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
|
418 |
+
|
419 |
+
- name: Inspect image
|
420 |
+
run: |
|
421 |
+
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
|
422 |
+
|
423 |
+
merge-ollama-images:
|
424 |
+
runs-on: ubuntu-latest
|
425 |
+
needs: [build-ollama-image]
|
426 |
+
steps:
|
427 |
+
# GitHub Packages requires the entire repository name to be in lowercase
|
428 |
+
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
|
429 |
+
- name: Set repository and image name to lowercase
|
430 |
+
run: |
|
431 |
+
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
432 |
+
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
|
433 |
+
env:
|
434 |
+
IMAGE_NAME: '${{ github.repository }}'
|
435 |
+
|
436 |
+
- name: Download digests
|
437 |
+
uses: actions/download-artifact@v4
|
438 |
+
with:
|
439 |
+
pattern: digests-ollama-*
|
440 |
+
path: /tmp/digests
|
441 |
+
merge-multiple: true
|
442 |
+
|
443 |
+
- name: Set up Docker Buildx
|
444 |
+
uses: docker/setup-buildx-action@v3
|
445 |
+
|
446 |
+
- name: Log in to the Container registry
|
447 |
+
uses: docker/login-action@v3
|
448 |
+
with:
|
449 |
+
registry: ${{ env.REGISTRY }}
|
450 |
+
username: ${{ github.actor }}
|
451 |
+
password: ${{ secrets.GITHUB_TOKEN }}
|
452 |
+
|
453 |
+
- name: Extract metadata for Docker images (default ollama tag)
|
454 |
+
id: meta
|
455 |
+
uses: docker/metadata-action@v5
|
456 |
+
with:
|
457 |
+
images: ${{ env.FULL_IMAGE_NAME }}
|
458 |
+
tags: |
|
459 |
+
type=ref,event=branch
|
460 |
+
type=ref,event=tag
|
461 |
+
type=sha,prefix=git-
|
462 |
+
type=semver,pattern={{version}}
|
463 |
+
type=semver,pattern={{major}}.{{minor}}
|
464 |
+
type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=ollama
|
465 |
+
flavor: |
|
466 |
+
latest=${{ github.ref == 'refs/heads/main' }}
|
467 |
+
suffix=-ollama,onlatest=true
|
468 |
+
|
469 |
+
- name: Create manifest list and push
|
470 |
+
working-directory: /tmp/digests
|
471 |
+
run: |
|
472 |
+
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
473 |
+
$(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *)
|
474 |
+
|
475 |
+
- name: Inspect image
|
476 |
+
run: |
|
477 |
+
docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }}
|
.github/workflows/format-backend.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Python CI
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
- dev
|
8 |
+
pull_request:
|
9 |
+
branches:
|
10 |
+
- main
|
11 |
+
- dev
|
12 |
+
|
13 |
+
jobs:
|
14 |
+
build:
|
15 |
+
name: 'Format Backend'
|
16 |
+
runs-on: ubuntu-latest
|
17 |
+
|
18 |
+
strategy:
|
19 |
+
matrix:
|
20 |
+
python-version: [3.11]
|
21 |
+
|
22 |
+
steps:
|
23 |
+
- uses: actions/checkout@v4
|
24 |
+
|
25 |
+
- name: Set up Python
|
26 |
+
uses: actions/setup-python@v5
|
27 |
+
with:
|
28 |
+
python-version: ${{ matrix.python-version }}
|
29 |
+
|
30 |
+
- name: Install dependencies
|
31 |
+
run: |
|
32 |
+
python -m pip install --upgrade pip
|
33 |
+
pip install black
|
34 |
+
|
35 |
+
- name: Format backend
|
36 |
+
run: npm run format:backend
|
37 |
+
|
38 |
+
- name: Check for changes after format
|
39 |
+
run: git diff --exit-code
|
.github/workflows/format-build-frontend.yaml
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Frontend Build
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
- dev
|
8 |
+
pull_request:
|
9 |
+
branches:
|
10 |
+
- main
|
11 |
+
- dev
|
12 |
+
|
13 |
+
jobs:
|
14 |
+
build:
|
15 |
+
name: 'Format & Build Frontend'
|
16 |
+
runs-on: ubuntu-latest
|
17 |
+
steps:
|
18 |
+
- name: Checkout Repository
|
19 |
+
uses: actions/checkout@v4
|
20 |
+
|
21 |
+
- name: Setup Node.js
|
22 |
+
uses: actions/setup-node@v4
|
23 |
+
with:
|
24 |
+
node-version: '22' # Or specify any other version you want to use
|
25 |
+
|
26 |
+
- name: Install Dependencies
|
27 |
+
run: npm install
|
28 |
+
|
29 |
+
- name: Format Frontend
|
30 |
+
run: npm run format
|
31 |
+
|
32 |
+
- name: Run i18next
|
33 |
+
run: npm run i18n:parse
|
34 |
+
|
35 |
+
- name: Check for Changes After Format
|
36 |
+
run: git diff --exit-code
|
37 |
+
|
38 |
+
- name: Build Frontend
|
39 |
+
run: npm run build
|
40 |
+
|
41 |
+
test-frontend:
|
42 |
+
name: 'Frontend Unit Tests'
|
43 |
+
runs-on: ubuntu-latest
|
44 |
+
steps:
|
45 |
+
- name: Checkout Repository
|
46 |
+
uses: actions/checkout@v4
|
47 |
+
|
48 |
+
- name: Setup Node.js
|
49 |
+
uses: actions/setup-node@v4
|
50 |
+
with:
|
51 |
+
node-version: '22'
|
52 |
+
|
53 |
+
- name: Install Dependencies
|
54 |
+
run: npm ci
|
55 |
+
|
56 |
+
- name: Run vitest
|
57 |
+
run: npm run test:frontend
|
.github/workflows/integration-test.yml
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Integration Test
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
- dev
|
8 |
+
pull_request:
|
9 |
+
branches:
|
10 |
+
- main
|
11 |
+
- dev
|
12 |
+
|
13 |
+
jobs:
|
14 |
+
cypress-run:
|
15 |
+
name: Run Cypress Integration Tests
|
16 |
+
runs-on: ubuntu-latest
|
17 |
+
steps:
|
18 |
+
- name: Maximize build space
|
19 |
+
uses: AdityaGarg8/remove-unwanted-software@v4.1
|
20 |
+
with:
|
21 |
+
remove-android: 'true'
|
22 |
+
remove-haskell: 'true'
|
23 |
+
remove-codeql: 'true'
|
24 |
+
|
25 |
+
- name: Checkout Repository
|
26 |
+
uses: actions/checkout@v4
|
27 |
+
|
28 |
+
- name: Build and run Compose Stack
|
29 |
+
run: |
|
30 |
+
docker compose \
|
31 |
+
--file docker-compose.yaml \
|
32 |
+
--file docker-compose.api.yaml \
|
33 |
+
--file docker-compose.a1111-test.yaml \
|
34 |
+
up --detach --build
|
35 |
+
|
36 |
+
- name: Delete Docker build cache
|
37 |
+
run: |
|
38 |
+
docker builder prune --all --force
|
39 |
+
|
40 |
+
- name: Wait for Ollama to be up
|
41 |
+
timeout-minutes: 5
|
42 |
+
run: |
|
43 |
+
until curl --output /dev/null --silent --fail http://localhost:11434; do
|
44 |
+
printf '.'
|
45 |
+
sleep 1
|
46 |
+
done
|
47 |
+
echo "Service is up!"
|
48 |
+
|
49 |
+
- name: Preload Ollama model
|
50 |
+
run: |
|
51 |
+
docker exec ollama ollama pull qwen:0.5b-chat-v1.5-q2_K
|
52 |
+
|
53 |
+
- name: Cypress run
|
54 |
+
uses: cypress-io/github-action@v6
|
55 |
+
with:
|
56 |
+
browser: chrome
|
57 |
+
wait-on: 'http://localhost:3000'
|
58 |
+
config: baseUrl=http://localhost:3000
|
59 |
+
|
60 |
+
- uses: actions/upload-artifact@v4
|
61 |
+
if: always()
|
62 |
+
name: Upload Cypress videos
|
63 |
+
with:
|
64 |
+
name: cypress-videos
|
65 |
+
path: cypress/videos
|
66 |
+
if-no-files-found: ignore
|
67 |
+
|
68 |
+
- name: Extract Compose logs
|
69 |
+
if: always()
|
70 |
+
run: |
|
71 |
+
docker compose logs > compose-logs.txt
|
72 |
+
|
73 |
+
- uses: actions/upload-artifact@v4
|
74 |
+
if: always()
|
75 |
+
name: Upload Compose logs
|
76 |
+
with:
|
77 |
+
name: compose-logs
|
78 |
+
path: compose-logs.txt
|
79 |
+
if-no-files-found: ignore
|
80 |
+
|
81 |
+
# pytest:
|
82 |
+
# name: Run Backend Tests
|
83 |
+
# runs-on: ubuntu-latest
|
84 |
+
# steps:
|
85 |
+
# - uses: actions/checkout@v4
|
86 |
+
|
87 |
+
# - name: Set up Python
|
88 |
+
# uses: actions/setup-python@v5
|
89 |
+
# with:
|
90 |
+
# python-version: ${{ matrix.python-version }}
|
91 |
+
|
92 |
+
# - name: Install dependencies
|
93 |
+
# run: |
|
94 |
+
# python -m pip install --upgrade pip
|
95 |
+
# pip install -r backend/requirements.txt
|
96 |
+
|
97 |
+
# - name: pytest run
|
98 |
+
# run: |
|
99 |
+
# ls -al
|
100 |
+
# cd backend
|
101 |
+
# PYTHONPATH=. pytest . -o log_cli=true -o log_cli_level=INFO
|
102 |
+
|
103 |
+
migration_test:
|
104 |
+
name: Run Migration Tests
|
105 |
+
runs-on: ubuntu-latest
|
106 |
+
services:
|
107 |
+
postgres:
|
108 |
+
image: postgres
|
109 |
+
env:
|
110 |
+
POSTGRES_PASSWORD: postgres
|
111 |
+
options: >-
|
112 |
+
--health-cmd pg_isready
|
113 |
+
--health-interval 10s
|
114 |
+
--health-timeout 5s
|
115 |
+
--health-retries 5
|
116 |
+
ports:
|
117 |
+
- 5432:5432
|
118 |
+
# mysql:
|
119 |
+
# image: mysql
|
120 |
+
# env:
|
121 |
+
# MYSQL_ROOT_PASSWORD: mysql
|
122 |
+
# MYSQL_DATABASE: mysql
|
123 |
+
# options: >-
|
124 |
+
# --health-cmd "mysqladmin ping -h localhost"
|
125 |
+
# --health-interval 10s
|
126 |
+
# --health-timeout 5s
|
127 |
+
# --health-retries 5
|
128 |
+
# ports:
|
129 |
+
# - 3306:3306
|
130 |
+
steps:
|
131 |
+
- name: Checkout Repository
|
132 |
+
uses: actions/checkout@v4
|
133 |
+
|
134 |
+
- name: Set up Python
|
135 |
+
uses: actions/setup-python@v5
|
136 |
+
with:
|
137 |
+
python-version: ${{ matrix.python-version }}
|
138 |
+
|
139 |
+
- name: Set up uv
|
140 |
+
uses: yezz123/setup-uv@v4
|
141 |
+
with:
|
142 |
+
uv-venv: venv
|
143 |
+
|
144 |
+
- name: Activate virtualenv
|
145 |
+
run: |
|
146 |
+
. venv/bin/activate
|
147 |
+
echo PATH=$PATH >> $GITHUB_ENV
|
148 |
+
|
149 |
+
- name: Install dependencies
|
150 |
+
run: |
|
151 |
+
uv pip install -r backend/requirements.txt
|
152 |
+
|
153 |
+
- name: Test backend with SQLite
|
154 |
+
id: sqlite
|
155 |
+
env:
|
156 |
+
WEBUI_SECRET_KEY: secret-key
|
157 |
+
GLOBAL_LOG_LEVEL: debug
|
158 |
+
run: |
|
159 |
+
cd backend
|
160 |
+
uvicorn open_webui.main:app --port "8080" --forwarded-allow-ips '*' &
|
161 |
+
UVICORN_PID=$!
|
162 |
+
# Wait up to 40 seconds for the server to start
|
163 |
+
for i in {1..40}; do
|
164 |
+
curl -s http://localhost:8080/api/config > /dev/null && break
|
165 |
+
sleep 1
|
166 |
+
if [ $i -eq 40 ]; then
|
167 |
+
echo "Server failed to start"
|
168 |
+
kill -9 $UVICORN_PID
|
169 |
+
exit 1
|
170 |
+
fi
|
171 |
+
done
|
172 |
+
# Check that the server is still running after 5 seconds
|
173 |
+
sleep 5
|
174 |
+
if ! kill -0 $UVICORN_PID; then
|
175 |
+
echo "Server has stopped"
|
176 |
+
exit 1
|
177 |
+
fi
|
178 |
+
|
179 |
+
- name: Test backend with Postgres
|
180 |
+
if: success() || steps.sqlite.conclusion == 'failure'
|
181 |
+
env:
|
182 |
+
WEBUI_SECRET_KEY: secret-key
|
183 |
+
GLOBAL_LOG_LEVEL: debug
|
184 |
+
DATABASE_URL: postgresql://postgres:postgres@localhost:5432/postgres
|
185 |
+
DATABASE_POOL_SIZE: 10
|
186 |
+
DATABASE_POOL_MAX_OVERFLOW: 10
|
187 |
+
DATABASE_POOL_TIMEOUT: 30
|
188 |
+
run: |
|
189 |
+
cd backend
|
190 |
+
uvicorn open_webui.main:app --port "8081" --forwarded-allow-ips '*' &
|
191 |
+
UVICORN_PID=$!
|
192 |
+
# Wait up to 20 seconds for the server to start
|
193 |
+
for i in {1..20}; do
|
194 |
+
curl -s http://localhost:8081/api/config > /dev/null && break
|
195 |
+
sleep 1
|
196 |
+
if [ $i -eq 20 ]; then
|
197 |
+
echo "Server failed to start"
|
198 |
+
kill -9 $UVICORN_PID
|
199 |
+
exit 1
|
200 |
+
fi
|
201 |
+
done
|
202 |
+
# Check that the server is still running after 5 seconds
|
203 |
+
sleep 5
|
204 |
+
if ! kill -0 $UVICORN_PID; then
|
205 |
+
echo "Server has stopped"
|
206 |
+
exit 1
|
207 |
+
fi
|
208 |
+
|
209 |
+
# Check that service will reconnect to postgres when connection will be closed
|
210 |
+
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
|
211 |
+
if [[ "$status_code" -ne 200 ]] ; then
|
212 |
+
echo "Server has failed before postgres reconnect check"
|
213 |
+
exit 1
|
214 |
+
fi
|
215 |
+
|
216 |
+
echo "Terminating all connections to postgres..."
|
217 |
+
python -c "import os, psycopg2 as pg2; \
|
218 |
+
conn = pg2.connect(dsn=os.environ['DATABASE_URL'].replace('+pool', '')); \
|
219 |
+
cur = conn.cursor(); \
|
220 |
+
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
|
221 |
+
|
222 |
+
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
|
223 |
+
if [[ "$status_code" -ne 200 ]] ; then
|
224 |
+
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
|
225 |
+
exit 1
|
226 |
+
fi
|
227 |
+
|
228 |
+
# - name: Test backend with MySQL
|
229 |
+
# if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
|
230 |
+
# env:
|
231 |
+
# WEBUI_SECRET_KEY: secret-key
|
232 |
+
# GLOBAL_LOG_LEVEL: debug
|
233 |
+
# DATABASE_URL: mysql://root:mysql@localhost:3306/mysql
|
234 |
+
# run: |
|
235 |
+
# cd backend
|
236 |
+
# uvicorn open_webui.main:app --port "8083" --forwarded-allow-ips '*' &
|
237 |
+
# UVICORN_PID=$!
|
238 |
+
# # Wait up to 20 seconds for the server to start
|
239 |
+
# for i in {1..20}; do
|
240 |
+
# curl -s http://localhost:8083/api/config > /dev/null && break
|
241 |
+
# sleep 1
|
242 |
+
# if [ $i -eq 20 ]; then
|
243 |
+
# echo "Server failed to start"
|
244 |
+
# kill -9 $UVICORN_PID
|
245 |
+
# exit 1
|
246 |
+
# fi
|
247 |
+
# done
|
248 |
+
# # Check that the server is still running after 5 seconds
|
249 |
+
# sleep 5
|
250 |
+
# if ! kill -0 $UVICORN_PID; then
|
251 |
+
# echo "Server has stopped"
|
252 |
+
# exit 1
|
253 |
+
# fi
|
.github/workflows/lint-backend.disabled
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Python CI
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches: ['main']
|
5 |
+
pull_request:
|
6 |
+
jobs:
|
7 |
+
build:
|
8 |
+
name: 'Lint Backend'
|
9 |
+
env:
|
10 |
+
PUBLIC_API_BASE_URL: ''
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
strategy:
|
13 |
+
matrix:
|
14 |
+
node-version:
|
15 |
+
- latest
|
16 |
+
steps:
|
17 |
+
- uses: actions/checkout@v4
|
18 |
+
- name: Use Python
|
19 |
+
uses: actions/setup-python@v5
|
20 |
+
- name: Use Bun
|
21 |
+
uses: oven-sh/setup-bun@v1
|
22 |
+
- name: Install dependencies
|
23 |
+
run: |
|
24 |
+
python -m pip install --upgrade pip
|
25 |
+
pip install pylint
|
26 |
+
- name: Lint backend
|
27 |
+
run: bun run lint:backend
|
.github/workflows/lint-frontend.disabled
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Bun CI
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches: ['main']
|
5 |
+
pull_request:
|
6 |
+
jobs:
|
7 |
+
build:
|
8 |
+
name: 'Lint Frontend'
|
9 |
+
env:
|
10 |
+
PUBLIC_API_BASE_URL: ''
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v4
|
14 |
+
- name: Use Bun
|
15 |
+
uses: oven-sh/setup-bun@v1
|
16 |
+
- run: bun --version
|
17 |
+
- name: Install frontend dependencies
|
18 |
+
run: bun install --frozen-lockfile
|
19 |
+
- run: bun run lint:frontend
|
20 |
+
- run: bun run lint:types
|
21 |
+
if: success() || failure()
|
.github/workflows/release-pypi.yml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Release to PyPI
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main # or whatever branch you want to use
|
7 |
+
- pypi-release
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
release:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
environment:
|
13 |
+
name: pypi
|
14 |
+
url: https://pypi.org/p/open-webui
|
15 |
+
permissions:
|
16 |
+
id-token: write
|
17 |
+
steps:
|
18 |
+
- name: Checkout repository
|
19 |
+
uses: actions/checkout@v4
|
20 |
+
- uses: actions/setup-node@v4
|
21 |
+
with:
|
22 |
+
node-version: 18
|
23 |
+
- uses: actions/setup-python@v5
|
24 |
+
with:
|
25 |
+
python-version: 3.11
|
26 |
+
- name: Build
|
27 |
+
run: |
|
28 |
+
python -m pip install --upgrade pip
|
29 |
+
pip install build
|
30 |
+
python -m build .
|
31 |
+
- name: Publish package distributions to PyPI
|
32 |
+
uses: pypa/gh-action-pypi-publish@release/v1
|
.github/workflows/sync-hf-spaces-with-dev.yml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Sync hf-spaces with dev
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- dev
|
7 |
+
- hf-spaces
|
8 |
+
schedule:
|
9 |
+
- cron: '0 0 * * *'
|
10 |
+
workflow_dispatch:
|
11 |
+
|
12 |
+
jobs:
|
13 |
+
sync:
|
14 |
+
runs-on: ubuntu-latest
|
15 |
+
permissions:
|
16 |
+
contents: write
|
17 |
+
steps:
|
18 |
+
- name: Checkout repository
|
19 |
+
uses: actions/checkout@v4
|
20 |
+
with:
|
21 |
+
fetch-depth: 0
|
22 |
+
|
23 |
+
- name: Sync with dev
|
24 |
+
run: |
|
25 |
+
git checkout dev
|
26 |
+
git fetch origin
|
27 |
+
git checkout hf-space
|
28 |
+
git pull
|
29 |
+
git merge origin/dev
|
30 |
+
git push origin hf-space
|
.gitignore
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
node_modules
|
3 |
+
/build
|
4 |
+
/.svelte-kit
|
5 |
+
/package
|
6 |
+
.env
|
7 |
+
.env.*
|
8 |
+
!.env.example
|
9 |
+
vite.config.js.timestamp-*
|
10 |
+
vite.config.ts.timestamp-*
|
11 |
+
# Byte-compiled / optimized / DLL files
|
12 |
+
__pycache__/
|
13 |
+
*.py[cod]
|
14 |
+
*$py.class
|
15 |
+
|
16 |
+
# C extensions
|
17 |
+
*.so
|
18 |
+
|
19 |
+
# Pyodide distribution
|
20 |
+
static/pyodide/*
|
21 |
+
!static/pyodide/pyodide-lock.json
|
22 |
+
|
23 |
+
# Distribution / packaging
|
24 |
+
.Python
|
25 |
+
build/
|
26 |
+
develop-eggs/
|
27 |
+
dist/
|
28 |
+
downloads/
|
29 |
+
eggs/
|
30 |
+
.eggs/
|
31 |
+
lib64/
|
32 |
+
parts/
|
33 |
+
sdist/
|
34 |
+
var/
|
35 |
+
wheels/
|
36 |
+
share/python-wheels/
|
37 |
+
*.egg-info/
|
38 |
+
.installed.cfg
|
39 |
+
*.egg
|
40 |
+
MANIFEST
|
41 |
+
|
42 |
+
# PyInstaller
|
43 |
+
# Usually these files are written by a python script from a template
|
44 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
45 |
+
*.manifest
|
46 |
+
*.spec
|
47 |
+
|
48 |
+
# Installer logs
|
49 |
+
pip-log.txt
|
50 |
+
pip-delete-this-directory.txt
|
51 |
+
|
52 |
+
# Unit test / coverage reports
|
53 |
+
htmlcov/
|
54 |
+
.tox/
|
55 |
+
.nox/
|
56 |
+
.coverage
|
57 |
+
.coverage.*
|
58 |
+
.cache
|
59 |
+
nosetests.xml
|
60 |
+
coverage.xml
|
61 |
+
*.cover
|
62 |
+
*.py,cover
|
63 |
+
.hypothesis/
|
64 |
+
.pytest_cache/
|
65 |
+
cover/
|
66 |
+
|
67 |
+
# Translations
|
68 |
+
*.mo
|
69 |
+
*.pot
|
70 |
+
|
71 |
+
# Django stuff:
|
72 |
+
*.log
|
73 |
+
local_settings.py
|
74 |
+
db.sqlite3
|
75 |
+
db.sqlite3-journal
|
76 |
+
|
77 |
+
# Flask stuff:
|
78 |
+
instance/
|
79 |
+
.webassets-cache
|
80 |
+
|
81 |
+
# Scrapy stuff:
|
82 |
+
.scrapy
|
83 |
+
|
84 |
+
# Sphinx documentation
|
85 |
+
docs/_build/
|
86 |
+
|
87 |
+
# PyBuilder
|
88 |
+
.pybuilder/
|
89 |
+
target/
|
90 |
+
|
91 |
+
# Jupyter Notebook
|
92 |
+
.ipynb_checkpoints
|
93 |
+
|
94 |
+
# IPython
|
95 |
+
profile_default/
|
96 |
+
ipython_config.py
|
97 |
+
|
98 |
+
# pyenv
|
99 |
+
# For a library or package, you might want to ignore these files since the code is
|
100 |
+
# intended to run in multiple environments; otherwise, check them in:
|
101 |
+
# .python-version
|
102 |
+
|
103 |
+
# pipenv
|
104 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
105 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
106 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
107 |
+
# install all needed dependencies.
|
108 |
+
#Pipfile.lock
|
109 |
+
|
110 |
+
# poetry
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
112 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
113 |
+
# commonly ignored for libraries.
|
114 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
115 |
+
#poetry.lock
|
116 |
+
|
117 |
+
# pdm
|
118 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
119 |
+
#pdm.lock
|
120 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
121 |
+
# in version control.
|
122 |
+
# https://pdm.fming.dev/#use-with-ide
|
123 |
+
.pdm.toml
|
124 |
+
|
125 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
126 |
+
__pypackages__/
|
127 |
+
|
128 |
+
# Celery stuff
|
129 |
+
celerybeat-schedule
|
130 |
+
celerybeat.pid
|
131 |
+
|
132 |
+
# SageMath parsed files
|
133 |
+
*.sage.py
|
134 |
+
|
135 |
+
# Environments
|
136 |
+
.env
|
137 |
+
.venv
|
138 |
+
env/
|
139 |
+
venv/
|
140 |
+
ENV/
|
141 |
+
env.bak/
|
142 |
+
venv.bak/
|
143 |
+
|
144 |
+
# Spyder project settings
|
145 |
+
.spyderproject
|
146 |
+
.spyproject
|
147 |
+
|
148 |
+
# Rope project settings
|
149 |
+
.ropeproject
|
150 |
+
|
151 |
+
# mkdocs documentation
|
152 |
+
/site
|
153 |
+
|
154 |
+
# mypy
|
155 |
+
.mypy_cache/
|
156 |
+
.dmypy.json
|
157 |
+
dmypy.json
|
158 |
+
|
159 |
+
# Pyre type checker
|
160 |
+
.pyre/
|
161 |
+
|
162 |
+
# pytype static type analyzer
|
163 |
+
.pytype/
|
164 |
+
|
165 |
+
# Cython debug symbols
|
166 |
+
cython_debug/
|
167 |
+
|
168 |
+
# PyCharm
|
169 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
170 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
171 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
172 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
173 |
+
.idea/
|
174 |
+
|
175 |
+
# Logs
|
176 |
+
logs
|
177 |
+
*.log
|
178 |
+
npm-debug.log*
|
179 |
+
yarn-debug.log*
|
180 |
+
yarn-error.log*
|
181 |
+
lerna-debug.log*
|
182 |
+
.pnpm-debug.log*
|
183 |
+
|
184 |
+
# Diagnostic reports (https://nodejs.org/api/report.html)
|
185 |
+
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
|
186 |
+
|
187 |
+
# Runtime data
|
188 |
+
pids
|
189 |
+
*.pid
|
190 |
+
*.seed
|
191 |
+
*.pid.lock
|
192 |
+
|
193 |
+
# Directory for instrumented libs generated by jscoverage/JSCover
|
194 |
+
lib-cov
|
195 |
+
|
196 |
+
# Coverage directory used by tools like istanbul
|
197 |
+
coverage
|
198 |
+
*.lcov
|
199 |
+
|
200 |
+
# nyc test coverage
|
201 |
+
.nyc_output
|
202 |
+
|
203 |
+
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
|
204 |
+
.grunt
|
205 |
+
|
206 |
+
# Bower dependency directory (https://bower.io/)
|
207 |
+
bower_components
|
208 |
+
|
209 |
+
# node-waf configuration
|
210 |
+
.lock-wscript
|
211 |
+
|
212 |
+
# Compiled binary addons (https://nodejs.org/api/addons.html)
|
213 |
+
build/Release
|
214 |
+
|
215 |
+
# Dependency directories
|
216 |
+
node_modules/
|
217 |
+
jspm_packages/
|
218 |
+
|
219 |
+
# Snowpack dependency directory (https://snowpack.dev/)
|
220 |
+
web_modules/
|
221 |
+
|
222 |
+
# TypeScript cache
|
223 |
+
*.tsbuildinfo
|
224 |
+
|
225 |
+
# Optional npm cache directory
|
226 |
+
.npm
|
227 |
+
|
228 |
+
# Optional eslint cache
|
229 |
+
.eslintcache
|
230 |
+
|
231 |
+
# Optional stylelint cache
|
232 |
+
.stylelintcache
|
233 |
+
|
234 |
+
# Microbundle cache
|
235 |
+
.rpt2_cache/
|
236 |
+
.rts2_cache_cjs/
|
237 |
+
.rts2_cache_es/
|
238 |
+
.rts2_cache_umd/
|
239 |
+
|
240 |
+
# Optional REPL history
|
241 |
+
.node_repl_history
|
242 |
+
|
243 |
+
# Output of 'npm pack'
|
244 |
+
*.tgz
|
245 |
+
|
246 |
+
# Yarn Integrity file
|
247 |
+
.yarn-integrity
|
248 |
+
|
249 |
+
# dotenv environment variable files
|
250 |
+
.env
|
251 |
+
.env.development.local
|
252 |
+
.env.test.local
|
253 |
+
.env.production.local
|
254 |
+
.env.local
|
255 |
+
|
256 |
+
# parcel-bundler cache (https://parceljs.org/)
|
257 |
+
.cache
|
258 |
+
.parcel-cache
|
259 |
+
|
260 |
+
# Next.js build output
|
261 |
+
.next
|
262 |
+
out
|
263 |
+
|
264 |
+
# Nuxt.js build / generate output
|
265 |
+
.nuxt
|
266 |
+
dist
|
267 |
+
|
268 |
+
# Gatsby files
|
269 |
+
.cache/
|
270 |
+
# Comment in the public line in if your project uses Gatsby and not Next.js
|
271 |
+
# https://nextjs.org/blog/next-9-1#public-directory-support
|
272 |
+
# public
|
273 |
+
|
274 |
+
# vuepress build output
|
275 |
+
.vuepress/dist
|
276 |
+
|
277 |
+
# vuepress v2.x temp and cache directory
|
278 |
+
.temp
|
279 |
+
.cache
|
280 |
+
|
281 |
+
# Docusaurus cache and generated files
|
282 |
+
.docusaurus
|
283 |
+
|
284 |
+
# Serverless directories
|
285 |
+
.serverless/
|
286 |
+
|
287 |
+
# FuseBox cache
|
288 |
+
.fusebox/
|
289 |
+
|
290 |
+
# DynamoDB Local files
|
291 |
+
.dynamodb/
|
292 |
+
|
293 |
+
# TernJS port file
|
294 |
+
.tern-port
|
295 |
+
|
296 |
+
# Stores VSCode versions used for testing VSCode extensions
|
297 |
+
.vscode-test
|
298 |
+
|
299 |
+
# yarn v2
|
300 |
+
.yarn/cache
|
301 |
+
.yarn/unplugged
|
302 |
+
.yarn/build-state.yml
|
303 |
+
.yarn/install-state.gz
|
304 |
+
.pnp.*
|
305 |
+
|
306 |
+
# cypress artifacts
|
307 |
+
cypress/videos
|
308 |
+
cypress/screenshots
|
309 |
+
.vscode/settings.json
|
.npmrc
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
engine-strict=true
|
.prettierignore
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ignore files for PNPM, NPM and YARN
|
2 |
+
pnpm-lock.yaml
|
3 |
+
package-lock.json
|
4 |
+
yarn.lock
|
5 |
+
|
6 |
+
kubernetes/
|
7 |
+
|
8 |
+
# Copy of .gitignore
|
9 |
+
.DS_Store
|
10 |
+
node_modules
|
11 |
+
/build
|
12 |
+
/.svelte-kit
|
13 |
+
/package
|
14 |
+
.env
|
15 |
+
.env.*
|
16 |
+
!.env.example
|
17 |
+
vite.config.js.timestamp-*
|
18 |
+
vite.config.ts.timestamp-*
|
19 |
+
# Byte-compiled / optimized / DLL files
|
20 |
+
__pycache__/
|
21 |
+
*.py[cod]
|
22 |
+
*$py.class
|
23 |
+
|
24 |
+
# C extensions
|
25 |
+
*.so
|
26 |
+
|
27 |
+
# Distribution / packaging
|
28 |
+
.Python
|
29 |
+
build/
|
30 |
+
develop-eggs/
|
31 |
+
dist/
|
32 |
+
downloads/
|
33 |
+
eggs/
|
34 |
+
.eggs/
|
35 |
+
lib64/
|
36 |
+
parts/
|
37 |
+
sdist/
|
38 |
+
var/
|
39 |
+
wheels/
|
40 |
+
share/python-wheels/
|
41 |
+
*.egg-info/
|
42 |
+
.installed.cfg
|
43 |
+
*.egg
|
44 |
+
MANIFEST
|
45 |
+
|
46 |
+
# PyInstaller
|
47 |
+
# Usually these files are written by a python script from a template
|
48 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
49 |
+
*.manifest
|
50 |
+
*.spec
|
51 |
+
|
52 |
+
# Installer logs
|
53 |
+
pip-log.txt
|
54 |
+
pip-delete-this-directory.txt
|
55 |
+
|
56 |
+
# Unit test / coverage reports
|
57 |
+
htmlcov/
|
58 |
+
.tox/
|
59 |
+
.nox/
|
60 |
+
.coverage
|
61 |
+
.coverage.*
|
62 |
+
.cache
|
63 |
+
nosetests.xml
|
64 |
+
coverage.xml
|
65 |
+
*.cover
|
66 |
+
*.py,cover
|
67 |
+
.hypothesis/
|
68 |
+
.pytest_cache/
|
69 |
+
cover/
|
70 |
+
|
71 |
+
# Translations
|
72 |
+
*.mo
|
73 |
+
*.pot
|
74 |
+
|
75 |
+
# Django stuff:
|
76 |
+
*.log
|
77 |
+
local_settings.py
|
78 |
+
db.sqlite3
|
79 |
+
db.sqlite3-journal
|
80 |
+
|
81 |
+
# Flask stuff:
|
82 |
+
instance/
|
83 |
+
.webassets-cache
|
84 |
+
|
85 |
+
# Scrapy stuff:
|
86 |
+
.scrapy
|
87 |
+
|
88 |
+
# Sphinx documentation
|
89 |
+
docs/_build/
|
90 |
+
|
91 |
+
# PyBuilder
|
92 |
+
.pybuilder/
|
93 |
+
target/
|
94 |
+
|
95 |
+
# Jupyter Notebook
|
96 |
+
.ipynb_checkpoints
|
97 |
+
|
98 |
+
# IPython
|
99 |
+
profile_default/
|
100 |
+
ipython_config.py
|
101 |
+
|
102 |
+
# pyenv
|
103 |
+
# For a library or package, you might want to ignore these files since the code is
|
104 |
+
# intended to run in multiple environments; otherwise, check them in:
|
105 |
+
# .python-version
|
106 |
+
|
107 |
+
# pipenv
|
108 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
109 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
110 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
111 |
+
# install all needed dependencies.
|
112 |
+
#Pipfile.lock
|
113 |
+
|
114 |
+
# poetry
|
115 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
116 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
117 |
+
# commonly ignored for libraries.
|
118 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
119 |
+
#poetry.lock
|
120 |
+
|
121 |
+
# pdm
|
122 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
123 |
+
#pdm.lock
|
124 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
125 |
+
# in version control.
|
126 |
+
# https://pdm.fming.dev/#use-with-ide
|
127 |
+
.pdm.toml
|
128 |
+
|
129 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
130 |
+
__pypackages__/
|
131 |
+
|
132 |
+
# Celery stuff
|
133 |
+
celerybeat-schedule
|
134 |
+
celerybeat.pid
|
135 |
+
|
136 |
+
# SageMath parsed files
|
137 |
+
*.sage.py
|
138 |
+
|
139 |
+
# Environments
|
140 |
+
.env
|
141 |
+
.venv
|
142 |
+
env/
|
143 |
+
venv/
|
144 |
+
ENV/
|
145 |
+
env.bak/
|
146 |
+
venv.bak/
|
147 |
+
|
148 |
+
# Spyder project settings
|
149 |
+
.spyderproject
|
150 |
+
.spyproject
|
151 |
+
|
152 |
+
# Rope project settings
|
153 |
+
.ropeproject
|
154 |
+
|
155 |
+
# mkdocs documentation
|
156 |
+
/site
|
157 |
+
|
158 |
+
# mypy
|
159 |
+
.mypy_cache/
|
160 |
+
.dmypy.json
|
161 |
+
dmypy.json
|
162 |
+
|
163 |
+
# Pyre type checker
|
164 |
+
.pyre/
|
165 |
+
|
166 |
+
# pytype static type analyzer
|
167 |
+
.pytype/
|
168 |
+
|
169 |
+
# Cython debug symbols
|
170 |
+
cython_debug/
|
171 |
+
|
172 |
+
# PyCharm
|
173 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
174 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
175 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
176 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
177 |
+
.idea/
|
178 |
+
|
179 |
+
# Logs
|
180 |
+
logs
|
181 |
+
*.log
|
182 |
+
npm-debug.log*
|
183 |
+
yarn-debug.log*
|
184 |
+
yarn-error.log*
|
185 |
+
lerna-debug.log*
|
186 |
+
.pnpm-debug.log*
|
187 |
+
|
188 |
+
# Diagnostic reports (https://nodejs.org/api/report.html)
|
189 |
+
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
|
190 |
+
|
191 |
+
# Runtime data
|
192 |
+
pids
|
193 |
+
*.pid
|
194 |
+
*.seed
|
195 |
+
*.pid.lock
|
196 |
+
|
197 |
+
# Directory for instrumented libs generated by jscoverage/JSCover
|
198 |
+
lib-cov
|
199 |
+
|
200 |
+
# Coverage directory used by tools like istanbul
|
201 |
+
coverage
|
202 |
+
*.lcov
|
203 |
+
|
204 |
+
# nyc test coverage
|
205 |
+
.nyc_output
|
206 |
+
|
207 |
+
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
|
208 |
+
.grunt
|
209 |
+
|
210 |
+
# Bower dependency directory (https://bower.io/)
|
211 |
+
bower_components
|
212 |
+
|
213 |
+
# node-waf configuration
|
214 |
+
.lock-wscript
|
215 |
+
|
216 |
+
# Compiled binary addons (https://nodejs.org/api/addons.html)
|
217 |
+
build/Release
|
218 |
+
|
219 |
+
# Dependency directories
|
220 |
+
node_modules/
|
221 |
+
jspm_packages/
|
222 |
+
|
223 |
+
# Snowpack dependency directory (https://snowpack.dev/)
|
224 |
+
web_modules/
|
225 |
+
|
226 |
+
# TypeScript cache
|
227 |
+
*.tsbuildinfo
|
228 |
+
|
229 |
+
# Optional npm cache directory
|
230 |
+
.npm
|
231 |
+
|
232 |
+
# Optional eslint cache
|
233 |
+
.eslintcache
|
234 |
+
|
235 |
+
# Optional stylelint cache
|
236 |
+
.stylelintcache
|
237 |
+
|
238 |
+
# Microbundle cache
|
239 |
+
.rpt2_cache/
|
240 |
+
.rts2_cache_cjs/
|
241 |
+
.rts2_cache_es/
|
242 |
+
.rts2_cache_umd/
|
243 |
+
|
244 |
+
# Optional REPL history
|
245 |
+
.node_repl_history
|
246 |
+
|
247 |
+
# Output of 'npm pack'
|
248 |
+
*.tgz
|
249 |
+
|
250 |
+
# Yarn Integrity file
|
251 |
+
.yarn-integrity
|
252 |
+
|
253 |
+
# dotenv environment variable files
|
254 |
+
.env
|
255 |
+
.env.development.local
|
256 |
+
.env.test.local
|
257 |
+
.env.production.local
|
258 |
+
.env.local
|
259 |
+
|
260 |
+
# parcel-bundler cache (https://parceljs.org/)
|
261 |
+
.cache
|
262 |
+
.parcel-cache
|
263 |
+
|
264 |
+
# Next.js build output
|
265 |
+
.next
|
266 |
+
out
|
267 |
+
|
268 |
+
# Nuxt.js build / generate output
|
269 |
+
.nuxt
|
270 |
+
dist
|
271 |
+
|
272 |
+
# Gatsby files
|
273 |
+
.cache/
|
274 |
+
# Comment in the public line in if your project uses Gatsby and not Next.js
|
275 |
+
# https://nextjs.org/blog/next-9-1#public-directory-support
|
276 |
+
# public
|
277 |
+
|
278 |
+
# vuepress build output
|
279 |
+
.vuepress/dist
|
280 |
+
|
281 |
+
# vuepress v2.x temp and cache directory
|
282 |
+
.temp
|
283 |
+
.cache
|
284 |
+
|
285 |
+
# Docusaurus cache and generated files
|
286 |
+
.docusaurus
|
287 |
+
|
288 |
+
# Serverless directories
|
289 |
+
.serverless/
|
290 |
+
|
291 |
+
# FuseBox cache
|
292 |
+
.fusebox/
|
293 |
+
|
294 |
+
# DynamoDB Local files
|
295 |
+
.dynamodb/
|
296 |
+
|
297 |
+
# TernJS port file
|
298 |
+
.tern-port
|
299 |
+
|
300 |
+
# Stores VSCode versions used for testing VSCode extensions
|
301 |
+
.vscode-test
|
302 |
+
|
303 |
+
# yarn v2
|
304 |
+
.yarn/cache
|
305 |
+
.yarn/unplugged
|
306 |
+
.yarn/build-state.yml
|
307 |
+
.yarn/install-state.gz
|
308 |
+
.pnp.*
|
309 |
+
|
310 |
+
# cypress artifacts
|
311 |
+
cypress/videos
|
312 |
+
cypress/screenshots
|
313 |
+
|
314 |
+
|
315 |
+
|
316 |
+
/static/*
|
.prettierrc
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"useTabs": true,
|
3 |
+
"singleQuote": true,
|
4 |
+
"trailingComma": "none",
|
5 |
+
"printWidth": 100,
|
6 |
+
"plugins": ["prettier-plugin-svelte"],
|
7 |
+
"pluginSearchDirs": ["."],
|
8 |
+
"overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }]
|
9 |
+
}
|
CHANGELOG.md
ADDED
The diff for this file is too large to render.
See raw diff
|
|
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributor Covenant Code of Conduct
|
2 |
+
|
3 |
+
## Our Pledge
|
4 |
+
|
5 |
+
As members, contributors, and leaders of this community, we pledge to make participation in our open-source project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
|
6 |
+
|
7 |
+
We are committed to creating and maintaining an open, respectful, and professional environment where positive contributions and meaningful discussions can flourish. By participating in this project, you agree to uphold these values and align your behavior to the standards outlined in this Code of Conduct.
|
8 |
+
|
9 |
+
## Why These Standards Are Important
|
10 |
+
|
11 |
+
Open-source projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
|
12 |
+
|
13 |
+
Maintaining a positive and respectful environment is essential to safeguarding the integrity of this project and protecting contributors' efforts. Behavior that disrupts this atmosphere—whether through hostility, entitlement, or unprofessional conduct—can severely harm the morale and productivity of the community. **Strict enforcement of these standards ensures a safe and supportive space for meaningful collaboration.**
|
14 |
+
|
15 |
+
This is a community where **respect and professionalism are mandatory.** Violations of these standards will result in **zero tolerance** and immediate enforcement to prevent disruption and ensure the well-being of all participants.
|
16 |
+
|
17 |
+
## Our Standards
|
18 |
+
|
19 |
+
Examples of behavior that contribute to a positive and professional community include:
|
20 |
+
|
21 |
+
- **Respecting others.** Be considerate, listen actively, and engage with empathy toward others' viewpoints and experiences.
|
22 |
+
- **Constructive feedback.** Provide actionable, thoughtful, and respectful feedback that helps improve the project and encourages collaboration. Avoid unproductive negativity or hypercriticism.
|
23 |
+
- **Recognizing volunteer contributions.** Appreciate that contributors dedicate their free time and resources selflessly. Approach them with gratitude and patience.
|
24 |
+
- **Focusing on shared goals.** Collaborate in ways that prioritize the health, success, and sustainability of the community over individual agendas.
|
25 |
+
|
26 |
+
Examples of unacceptable behavior include:
|
27 |
+
|
28 |
+
- The use of discriminatory, demeaning, or sexualized language or behavior.
|
29 |
+
- Personal attacks, derogatory comments, trolling, or inflammatory political or ideological arguments.
|
30 |
+
- Harassment, intimidation, or any behavior intended to create a hostile, uncomfortable, or unsafe environment.
|
31 |
+
- Publishing others' private information (e.g., physical or email addresses) without explicit permission.
|
32 |
+
- **Entitlement, demand, or aggression toward contributors.** Volunteers are under no obligation to provide immediate or personalized support. Rude or dismissive behavior will not be tolerated.
|
33 |
+
- **Unproductive or destructive behavior.** This includes venting frustration as hostility ("tantrums"), hypercriticism, attention-seeking negativity, or anything that distracts from the project's goals.
|
34 |
+
- **Spamming and promotional exploitation.** Sharing irrelevant product promotions or self-promotion in the community is not allowed unless it directly contributes value to the discussion.
|
35 |
+
|
36 |
+
### Feedback and Community Engagement
|
37 |
+
|
38 |
+
- **Constructive feedback is encouraged, but hostile or entitled behavior will result in immediate action.** If you disagree with elements of the project, we encourage you to offer meaningful improvements or fork the project if necessary. Healthy discussions and technical disagreements are welcome only when handled with professionalism.
|
39 |
+
- **Respect contributors' time and efforts.** No one is entitled to personalized or on-demand assistance. This is a community built on collaboration and shared effort; demanding or demeaning behavior undermines that trust and will not be allowed.
|
40 |
+
|
41 |
+
### Zero Tolerance: No Warnings, Immediate Action
|
42 |
+
|
43 |
+
This community operates under a **zero-tolerance policy.** Any behavior deemed unacceptable under this Code of Conduct will result in **immediate enforcement, without prior warning.**
|
44 |
+
|
45 |
+
We employ this approach to ensure that unproductive or disruptive behavior does not escalate further or cause unnecessary harm to other contributors. The standards are clear, and violations of any kind—whether mild or severe—will be addressed decisively to protect the community.
|
46 |
+
|
47 |
+
## Enforcement Responsibilities
|
48 |
+
|
49 |
+
Community leaders are responsible for upholding and enforcing these standards. They are empowered to take **immediate and appropriate action** to address any behaviors they deem unacceptable under this Code of Conduct. These actions are taken with the goal of protecting the community and preserving its safe, positive, and productive environment.
|
50 |
+
|
51 |
+
## Scope
|
52 |
+
|
53 |
+
This Code of Conduct applies to all community spaces, including forums, repositories, social media accounts, and in-person events. It also applies when an individual represents the community in public settings, such as conferences or official communications.
|
54 |
+
|
55 |
+
Additionally, any behavior outside of these defined spaces that negatively impacts the community or its members may fall within the scope of this Code of Conduct.
|
56 |
+
|
57 |
+
## Reporting Violations
|
58 |
+
|
59 |
+
Instances of unacceptable behavior can be reported to the leadership team at **hello@openwebui.com**. Reports will be handled promptly, confidentially, and with consideration for the safety and well-being of the reporter.
|
60 |
+
|
61 |
+
All community leaders are required to uphold confidentiality and impartiality when addressing reports of violations.
|
62 |
+
|
63 |
+
## Enforcement Guidelines
|
64 |
+
|
65 |
+
### Ban
|
66 |
+
|
67 |
+
**Community Impact**: Community leaders will issue a ban to any participant whose behavior is deemed unacceptable according to this Code of Conduct. Bans are enforced immediately and without prior notice.
|
68 |
+
|
69 |
+
A ban may be temporary or permanent, depending on the severity of the violation. This includes—but is not limited to—behavior such as:
|
70 |
+
|
71 |
+
- Harassment or abusive behavior toward contributors.
|
72 |
+
- Persistent negativity or hostility that disrupts the collaborative environment.
|
73 |
+
- Disrespectful, demanding, or aggressive interactions with others.
|
74 |
+
- Attempts to cause harm or sabotage the community.
|
75 |
+
|
76 |
+
**Consequence**: A banned individual is immediately removed from access to all community spaces, communication channels, and events. Community leaders reserve the right to enforce either a time-limited suspension or a permanent ban based on the specific circumstances of the violation.
|
77 |
+
|
78 |
+
This approach ensures that disruptive behaviors are addressed swiftly and decisively in order to maintain the integrity and productivity of the community.
|
79 |
+
|
80 |
+
## Why Zero Tolerance Is Necessary
|
81 |
+
|
82 |
+
Open-source projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
|
83 |
+
|
84 |
+
By enforcing a zero-tolerance policy, we ensure that the community remains a safe, welcoming space for all participants. These measures are not about harshness—they are about protecting contributors and fostering a productive environment where innovation can thrive.
|
85 |
+
|
86 |
+
Our expectations are clear, and our enforcement reflects our commitment to this project's long-term success.
|
87 |
+
|
88 |
+
## Attribution
|
89 |
+
|
90 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at
|
91 |
+
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
92 |
+
|
93 |
+
Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
|
94 |
+
|
95 |
+
[homepage]: https://www.contributor-covenant.org
|
96 |
+
|
97 |
+
For answers to common questions about this code of conduct, see the FAQ at
|
98 |
+
https://www.contributor-covenant.org/faq. Translations are available at
|
99 |
+
https://www.contributor-covenant.org/translations.
|
Caddyfile.localhost
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Run with
|
2 |
+
# caddy run --envfile ./example.env --config ./Caddyfile.localhost
|
3 |
+
#
|
4 |
+
# This is configured for
|
5 |
+
# - Automatic HTTPS (even for localhost)
|
6 |
+
# - Reverse Proxying to Ollama API Base URL (http://localhost:11434/api)
|
7 |
+
# - CORS
|
8 |
+
# - HTTP Basic Auth API Tokens (uncomment basicauth section)
|
9 |
+
|
10 |
+
|
11 |
+
# CORS Preflight (OPTIONS) + Request (GET, POST, PATCH, PUT, DELETE)
|
12 |
+
(cors-api) {
|
13 |
+
@match-cors-api-preflight method OPTIONS
|
14 |
+
handle @match-cors-api-preflight {
|
15 |
+
header {
|
16 |
+
Access-Control-Allow-Origin "{http.request.header.origin}"
|
17 |
+
Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
18 |
+
Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
|
19 |
+
Access-Control-Allow-Credentials "true"
|
20 |
+
Access-Control-Max-Age "3600"
|
21 |
+
defer
|
22 |
+
}
|
23 |
+
respond "" 204
|
24 |
+
}
|
25 |
+
|
26 |
+
@match-cors-api-request {
|
27 |
+
not {
|
28 |
+
header Origin "{http.request.scheme}://{http.request.host}"
|
29 |
+
}
|
30 |
+
header Origin "{http.request.header.origin}"
|
31 |
+
}
|
32 |
+
handle @match-cors-api-request {
|
33 |
+
header {
|
34 |
+
Access-Control-Allow-Origin "{http.request.header.origin}"
|
35 |
+
Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
36 |
+
Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
|
37 |
+
Access-Control-Allow-Credentials "true"
|
38 |
+
Access-Control-Max-Age "3600"
|
39 |
+
defer
|
40 |
+
}
|
41 |
+
}
|
42 |
+
}
|
43 |
+
|
44 |
+
# replace localhost with example.com or whatever
|
45 |
+
localhost {
|
46 |
+
## HTTP Basic Auth
|
47 |
+
## (uncomment to enable)
|
48 |
+
# basicauth {
|
49 |
+
# # see .example.env for how to generate tokens
|
50 |
+
# {env.OLLAMA_API_ID} {env.OLLAMA_API_TOKEN_DIGEST}
|
51 |
+
# }
|
52 |
+
|
53 |
+
handle /api/* {
|
54 |
+
# Comment to disable CORS
|
55 |
+
import cors-api
|
56 |
+
|
57 |
+
reverse_proxy localhost:11434
|
58 |
+
}
|
59 |
+
|
60 |
+
# Same-Origin Static Web Server
|
61 |
+
file_server {
|
62 |
+
root ./build/
|
63 |
+
}
|
64 |
+
}
|
Dockerfile
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# syntax=docker/dockerfile:1
|
2 |
+
# Initialize device type args
|
3 |
+
# use build args in the docker build command with --build-arg="BUILDARG=true"
|
4 |
+
ARG USE_CUDA=false
|
5 |
+
ARG USE_OLLAMA=false
|
6 |
+
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
|
7 |
+
ARG USE_CUDA_VER=cu121
|
8 |
+
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
|
9 |
+
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
|
10 |
+
# for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
|
11 |
+
# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
|
12 |
+
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
13 |
+
ARG USE_RERANKING_MODEL=""
|
14 |
+
|
15 |
+
# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
|
16 |
+
ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
|
17 |
+
|
18 |
+
ARG BUILD_HASH=dev-build
|
19 |
+
# Override at your own risk - non-root configurations are untested
|
20 |
+
ARG UID=0
|
21 |
+
ARG GID=0
|
22 |
+
|
23 |
+
######## WebUI frontend ########
|
24 |
+
FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
|
25 |
+
ARG BUILD_HASH
|
26 |
+
|
27 |
+
WORKDIR /app
|
28 |
+
|
29 |
+
COPY package.json package-lock.json ./
|
30 |
+
RUN npm ci
|
31 |
+
|
32 |
+
COPY . .
|
33 |
+
ENV APP_BUILD_HASH=${BUILD_HASH}
|
34 |
+
RUN npm run build
|
35 |
+
|
36 |
+
######## WebUI backend ########
|
37 |
+
FROM python:3.11-slim-bookworm AS base
|
38 |
+
|
39 |
+
# Use args
|
40 |
+
ARG USE_CUDA
|
41 |
+
ARG USE_OLLAMA
|
42 |
+
ARG USE_CUDA_VER
|
43 |
+
ARG USE_EMBEDDING_MODEL
|
44 |
+
ARG USE_RERANKING_MODEL
|
45 |
+
ARG UID
|
46 |
+
ARG GID
|
47 |
+
|
48 |
+
## Basis ##
|
49 |
+
ENV ENV=prod \
|
50 |
+
PORT=8080 \
|
51 |
+
# pass build args to the build
|
52 |
+
USE_OLLAMA_DOCKER=${USE_OLLAMA} \
|
53 |
+
USE_CUDA_DOCKER=${USE_CUDA} \
|
54 |
+
USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
|
55 |
+
USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
|
56 |
+
USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
|
57 |
+
|
58 |
+
## Basis URL Config ##
|
59 |
+
ENV OLLAMA_BASE_URL="/ollama" \
|
60 |
+
OPENAI_API_BASE_URL=""
|
61 |
+
|
62 |
+
## API Key and Security Config ##
|
63 |
+
ENV OPENAI_API_KEY="" \
|
64 |
+
WEBUI_SECRET_KEY="" \
|
65 |
+
SCARF_NO_ANALYTICS=true \
|
66 |
+
DO_NOT_TRACK=true \
|
67 |
+
ANONYMIZED_TELEMETRY=false
|
68 |
+
|
69 |
+
#### Other models #########################################################
|
70 |
+
## whisper TTS model settings ##
|
71 |
+
ENV WHISPER_MODEL="base" \
|
72 |
+
WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
|
73 |
+
|
74 |
+
## RAG Embedding model settings ##
|
75 |
+
ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
|
76 |
+
RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
|
77 |
+
SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
|
78 |
+
|
79 |
+
## Tiktoken model settings ##
|
80 |
+
ENV TIKTOKEN_ENCODING_NAME="cl100k_base" \
|
81 |
+
TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
|
82 |
+
|
83 |
+
## Hugging Face download cache ##
|
84 |
+
ENV HF_HOME="/app/backend/data/cache/embedding/models"
|
85 |
+
|
86 |
+
## Torch Extensions ##
|
87 |
+
# ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions"
|
88 |
+
|
89 |
+
#### Other models ##########################################################
|
90 |
+
|
91 |
+
WORKDIR /app/backend
|
92 |
+
|
93 |
+
ENV HOME=/root
|
94 |
+
# Create user and group if not root
|
95 |
+
RUN if [ $UID -ne 0 ]; then \
|
96 |
+
if [ $GID -ne 0 ]; then \
|
97 |
+
addgroup --gid $GID app; \
|
98 |
+
fi; \
|
99 |
+
adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
|
100 |
+
fi
|
101 |
+
|
102 |
+
RUN mkdir -p $HOME/.cache/chroma
|
103 |
+
RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
|
104 |
+
|
105 |
+
# Make sure the user has access to the app and root directory
|
106 |
+
RUN chown -R $UID:$GID /app $HOME
|
107 |
+
|
108 |
+
RUN if [ "$USE_OLLAMA" = "true" ]; then \
|
109 |
+
apt-get update && \
|
110 |
+
# Install pandoc and netcat
|
111 |
+
apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
|
112 |
+
apt-get install -y --no-install-recommends gcc python3-dev && \
|
113 |
+
# for RAG OCR
|
114 |
+
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
115 |
+
# install helper tools
|
116 |
+
apt-get install -y --no-install-recommends curl jq && \
|
117 |
+
# install ollama
|
118 |
+
curl -fsSL https://ollama.com/install.sh | sh && \
|
119 |
+
# cleanup
|
120 |
+
rm -rf /var/lib/apt/lists/*; \
|
121 |
+
else \
|
122 |
+
apt-get update && \
|
123 |
+
# Install pandoc, netcat and gcc
|
124 |
+
apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
|
125 |
+
apt-get install -y --no-install-recommends gcc python3-dev && \
|
126 |
+
# for RAG OCR
|
127 |
+
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
128 |
+
# cleanup
|
129 |
+
rm -rf /var/lib/apt/lists/*; \
|
130 |
+
fi
|
131 |
+
|
132 |
+
# install python dependencies
|
133 |
+
COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
|
134 |
+
|
135 |
+
RUN pip3 install uv && \
|
136 |
+
if [ "$USE_CUDA" = "true" ]; then \
|
137 |
+
# If you use CUDA the whisper and embedding model will be downloaded on first use
|
138 |
+
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
|
139 |
+
uv pip install --system -r requirements.txt --no-cache-dir && \
|
140 |
+
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
|
141 |
+
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
|
142 |
+
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
|
143 |
+
else \
|
144 |
+
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
|
145 |
+
uv pip install --system -r requirements.txt --no-cache-dir && \
|
146 |
+
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
|
147 |
+
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
|
148 |
+
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
|
149 |
+
fi; \
|
150 |
+
chown -R $UID:$GID /app/backend/data/
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
# copy embedding weight from build
|
155 |
+
# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
|
156 |
+
# COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
|
157 |
+
|
158 |
+
# copy built frontend files
|
159 |
+
COPY --chown=$UID:$GID --from=build /app/build /app/build
|
160 |
+
COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md
|
161 |
+
COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json
|
162 |
+
|
163 |
+
# copy backend files
|
164 |
+
COPY --chown=$UID:$GID ./backend .
|
165 |
+
|
166 |
+
EXPOSE 8080
|
167 |
+
|
168 |
+
HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
|
169 |
+
|
170 |
+
USER $UID:$GID
|
171 |
+
|
172 |
+
ARG BUILD_HASH
|
173 |
+
ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
|
174 |
+
ENV DOCKER=true
|
175 |
+
|
176 |
+
CMD [ "bash", "start.sh"]
|
INSTALLATION.md
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Installing Both Ollama and Open WebUI Using Kustomize
|
2 |
+
|
3 |
+
For cpu-only pod
|
4 |
+
|
5 |
+
```bash
|
6 |
+
kubectl apply -f ./kubernetes/manifest/base
|
7 |
+
```
|
8 |
+
|
9 |
+
For gpu-enabled pod
|
10 |
+
|
11 |
+
```bash
|
12 |
+
kubectl apply -k ./kubernetes/manifest
|
13 |
+
```
|
14 |
+
|
15 |
+
### Installing Both Ollama and Open WebUI Using Helm
|
16 |
+
|
17 |
+
Package Helm file first
|
18 |
+
|
19 |
+
```bash
|
20 |
+
helm package ./kubernetes/helm/
|
21 |
+
```
|
22 |
+
|
23 |
+
For cpu-only pod
|
24 |
+
|
25 |
+
```bash
|
26 |
+
helm install ollama-webui ./ollama-webui-*.tgz
|
27 |
+
```
|
28 |
+
|
29 |
+
For gpu-enabled pod
|
30 |
+
|
31 |
+
```bash
|
32 |
+
helm install ollama-webui ./ollama-webui-*.tgz --set ollama.resources.limits.nvidia.com/gpu="1"
|
33 |
+
```
|
34 |
+
|
35 |
+
Check the `kubernetes/helm/values.yaml` file to know which parameters are available for customization
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Timothy Jaeryang Baek
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
Makefile
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
ifneq ($(shell which docker-compose 2>/dev/null),)
|
3 |
+
DOCKER_COMPOSE := docker-compose
|
4 |
+
else
|
5 |
+
DOCKER_COMPOSE := docker compose
|
6 |
+
endif
|
7 |
+
|
8 |
+
install:
|
9 |
+
$(DOCKER_COMPOSE) up -d
|
10 |
+
|
11 |
+
remove:
|
12 |
+
@chmod +x confirm_remove.sh
|
13 |
+
@./confirm_remove.sh
|
14 |
+
|
15 |
+
start:
|
16 |
+
$(DOCKER_COMPOSE) start
|
17 |
+
startAndBuild:
|
18 |
+
$(DOCKER_COMPOSE) up -d --build
|
19 |
+
|
20 |
+
stop:
|
21 |
+
$(DOCKER_COMPOSE) stop
|
22 |
+
|
23 |
+
update:
|
24 |
+
# Calls the LLM update script
|
25 |
+
chmod +x update_ollama_models.sh
|
26 |
+
@./update_ollama_models.sh
|
27 |
+
@git pull
|
28 |
+
$(DOCKER_COMPOSE) down
|
29 |
+
# Make sure the ollama-webui container is stopped before rebuilding
|
30 |
+
@docker stop open-webui || true
|
31 |
+
$(DOCKER_COMPOSE) up --build -d
|
32 |
+
$(DOCKER_COMPOSE) start
|
33 |
+
|
README.md
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Open WebUI
|
3 |
+
emoji: 🐳
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: gray
|
6 |
+
sdk: docker
|
7 |
+
app_port: 8080
|
8 |
+
hf_oauth: true
|
9 |
+
hf_oauth_scopes:
|
10 |
+
- email
|
11 |
+
---
|
12 |
+
---
|
13 |
+
title: Open WebUI
|
14 |
+
emoji: 🐳
|
15 |
+
colorFrom: purple
|
16 |
+
colorTo: gray
|
17 |
+
sdk: docker
|
18 |
+
app_port: 8080
|
19 |
+
---
|
20 |
+
|
21 |
+
# Open WebUI 👋
|
22 |
+
|
23 |
+
![GitHub stars](https://img.shields.io/github/stars/open-webui/open-webui?style=social)
|
24 |
+
![GitHub forks](https://img.shields.io/github/forks/open-webui/open-webui?style=social)
|
25 |
+
![GitHub watchers](https://img.shields.io/github/watchers/open-webui/open-webui?style=social)
|
26 |
+
![GitHub repo size](https://img.shields.io/github/repo-size/open-webui/open-webui)
|
27 |
+
![GitHub language count](https://img.shields.io/github/languages/count/open-webui/open-webui)
|
28 |
+
![GitHub top language](https://img.shields.io/github/languages/top/open-webui/open-webui)
|
29 |
+
![GitHub last commit](https://img.shields.io/github/last-commit/open-webui/open-webui?color=red)
|
30 |
+
![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Follama-webui%2Follama-wbui&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)
|
31 |
+
[![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s)
|
32 |
+
[![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck)
|
33 |
+
|
34 |
+
Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-rich, and user-friendly self-hosted WebUI designed to operate entirely offline. It supports various LLM runners, including Ollama and OpenAI-compatible APIs. For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
|
35 |
+
|
36 |
+
![Open WebUI Demo](./demo.gif)
|
37 |
+
|
38 |
+
## Key Features of Open WebUI ⭐
|
39 |
+
|
40 |
+
- 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images.
|
41 |
+
|
42 |
+
- 🤝 **Ollama/OpenAI API Integration**: Effortlessly integrate OpenAI-compatible APIs for versatile conversations alongside Ollama models. Customize the OpenAI API URL to link with **LMStudio, GroqCloud, Mistral, OpenRouter, and more**.
|
43 |
+
|
44 |
+
- 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users.
|
45 |
+
|
46 |
+
- 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
|
47 |
+
|
48 |
+
- 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.
|
49 |
+
|
50 |
+
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
|
51 |
+
|
52 |
+
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
|
53 |
+
|
54 |
+
- 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
|
55 |
+
|
56 |
+
- 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
|
57 |
+
|
58 |
+
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
|
59 |
+
|
60 |
+
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience.
|
61 |
+
|
62 |
+
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
|
63 |
+
|
64 |
+
- 🎨 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API or ComfyUI (local), and OpenAI's DALL-E (external), enriching your chat experience with dynamic visual content.
|
65 |
+
|
66 |
+
- ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
|
67 |
+
|
68 |
+
- 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
|
69 |
+
|
70 |
+
- 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
|
71 |
+
|
72 |
+
- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
|
73 |
+
|
74 |
+
- 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates, fixes, and new features.
|
75 |
+
|
76 |
+
Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
|
77 |
+
|
78 |
+
## 🔗 Also Check Out Open WebUI Community!
|
79 |
+
|
80 |
+
Don't forget to explore our sibling project, [Open WebUI Community](https://openwebui.com/), where you can discover, download, and explore customized Modelfiles. Open WebUI Community offers a wide range of exciting possibilities for enhancing your chat interactions with Open WebUI! 🚀
|
81 |
+
|
82 |
+
## How to Install 🚀
|
83 |
+
|
84 |
+
### Installation via Python pip 🐍
|
85 |
+
|
86 |
+
Open WebUI can be installed using pip, the Python package installer. Before proceeding, ensure you're using **Python 3.11** to avoid compatibility issues.
|
87 |
+
|
88 |
+
1. **Install Open WebUI**:
|
89 |
+
Open your terminal and run the following command to install Open WebUI:
|
90 |
+
|
91 |
+
```bash
|
92 |
+
pip install open-webui
|
93 |
+
```
|
94 |
+
|
95 |
+
2. **Running Open WebUI**:
|
96 |
+
After installation, you can start Open WebUI by executing:
|
97 |
+
|
98 |
+
```bash
|
99 |
+
open-webui serve
|
100 |
+
```
|
101 |
+
|
102 |
+
This will start the Open WebUI server, which you can access at [http://localhost:8080](http://localhost:8080)
|
103 |
+
|
104 |
+
### Quick Start with Docker 🐳
|
105 |
+
|
106 |
+
> [!NOTE]
|
107 |
+
> Please note that for certain Docker environments, additional configurations might be needed. If you encounter any connection issues, our detailed guide on [Open WebUI Documentation](https://docs.openwebui.com/) is ready to assist you.
|
108 |
+
|
109 |
+
> [!WARNING]
|
110 |
+
> When using Docker to install Open WebUI, make sure to include the `-v open-webui:/app/backend/data` in your Docker command. This step is crucial as it ensures your database is properly mounted and prevents any loss of data.
|
111 |
+
|
112 |
+
> [!TIP]
|
113 |
+
> If you wish to utilize Open WebUI with Ollama included or CUDA acceleration, we recommend utilizing our official images tagged with either `:cuda` or `:ollama`. To enable CUDA, you must install the [Nvidia CUDA container toolkit](https://docs.nvidia.com/dgx/nvidia-container-runtime-upgrade/) on your Linux/WSL system.
|
114 |
+
|
115 |
+
### Installation with Default Configuration
|
116 |
+
|
117 |
+
- **If Ollama is on your computer**, use this command:
|
118 |
+
|
119 |
+
```bash
|
120 |
+
docker run -d -p 3000:8080 --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
|
121 |
+
```
|
122 |
+
|
123 |
+
- **If Ollama is on a Different Server**, use this command:
|
124 |
+
|
125 |
+
To connect to Ollama on another server, change the `OLLAMA_BASE_URL` to the server's URL:
|
126 |
+
|
127 |
+
```bash
|
128 |
+
docker run -d -p 3000:8080 -e OLLAMA_BASE_URL=https://example.com -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
|
129 |
+
```
|
130 |
+
|
131 |
+
- **To run Open WebUI with Nvidia GPU support**, use this command:
|
132 |
+
|
133 |
+
```bash
|
134 |
+
docker run -d -p 3000:8080 --gpus all --add-host=host.docker.internal:host-gateway -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:cuda
|
135 |
+
```
|
136 |
+
|
137 |
+
### Installation for OpenAI API Usage Only
|
138 |
+
|
139 |
+
- **If you're only using OpenAI API**, use this command:
|
140 |
+
|
141 |
+
```bash
|
142 |
+
docker run -d -p 3000:8080 -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
|
143 |
+
```
|
144 |
+
|
145 |
+
### Installing Open WebUI with Bundled Ollama Support
|
146 |
+
|
147 |
+
This installation method uses a single container image that bundles Open WebUI with Ollama, allowing for a streamlined setup via a single command. Choose the appropriate command based on your hardware setup:
|
148 |
+
|
149 |
+
- **With GPU Support**:
|
150 |
+
Utilize GPU resources by running the following command:
|
151 |
+
|
152 |
+
```bash
|
153 |
+
docker run -d -p 3000:8080 --gpus=all -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama
|
154 |
+
```
|
155 |
+
|
156 |
+
- **For CPU Only**:
|
157 |
+
If you're not using a GPU, use this command instead:
|
158 |
+
|
159 |
+
```bash
|
160 |
+
docker run -d -p 3000:8080 -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:ollama
|
161 |
+
```
|
162 |
+
|
163 |
+
Both commands facilitate a built-in, hassle-free installation of both Open WebUI and Ollama, ensuring that you can get everything up and running swiftly.
|
164 |
+
|
165 |
+
After installation, you can access Open WebUI at [http://localhost:3000](http://localhost:3000). Enjoy! 😄
|
166 |
+
|
167 |
+
### Other Installation Methods
|
168 |
+
|
169 |
+
We offer various installation alternatives, including non-Docker native installation methods, Docker Compose, Kustomize, and Helm. Visit our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/) or join our [Discord community](https://discord.gg/5rJgQTnV4s) for comprehensive guidance.
|
170 |
+
|
171 |
+
### Troubleshooting
|
172 |
+
|
173 |
+
Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).
|
174 |
+
|
175 |
+
#### Open WebUI: Server Connection Error
|
176 |
+
|
177 |
+
If you're experiencing connection issues, it’s often due to the WebUI docker container not being able to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) inside the container . Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`.
|
178 |
+
|
179 |
+
**Example Docker Command**:
|
180 |
+
|
181 |
+
```bash
|
182 |
+
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
|
183 |
+
```
|
184 |
+
|
185 |
+
### Keeping Your Docker Installation Up-to-Date
|
186 |
+
|
187 |
+
In case you want to update your local Docker installation to the latest version, you can do it with [Watchtower](https://containrrr.dev/watchtower/):
|
188 |
+
|
189 |
+
```bash
|
190 |
+
docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower --run-once open-webui
|
191 |
+
```
|
192 |
+
|
193 |
+
In the last part of the command, replace `open-webui` with your container name if it is different.
|
194 |
+
|
195 |
+
Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/tutorials/migration/).
|
196 |
+
|
197 |
+
### Using the Dev Branch 🌙
|
198 |
+
|
199 |
+
> [!WARNING]
|
200 |
+
> The `:dev` branch contains the latest unstable features and changes. Use it at your own risk as it may have bugs or incomplete features.
|
201 |
+
|
202 |
+
If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this:
|
203 |
+
|
204 |
+
```bash
|
205 |
+
docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --add-host=host.docker.internal:host-gateway --restart always ghcr.io/open-webui/open-webui:dev
|
206 |
+
```
|
207 |
+
|
208 |
+
## What's Next? 🌟
|
209 |
+
|
210 |
+
Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/).
|
211 |
+
|
212 |
+
## License 📜
|
213 |
+
|
214 |
+
This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LICENSE) file for details. 📄
|
215 |
+
|
216 |
+
## Support 💬
|
217 |
+
|
218 |
+
If you have any questions, suggestions, or need assistance, please open an issue or join our
|
219 |
+
[Open WebUI Discord community](https://discord.gg/5rJgQTnV4s) to connect with us! 🤝
|
220 |
+
|
221 |
+
## Star History
|
222 |
+
|
223 |
+
<a href="https://star-history.com/#open-webui/open-webui&Date">
|
224 |
+
<picture>
|
225 |
+
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date&theme=dark" />
|
226 |
+
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date" />
|
227 |
+
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=open-webui/open-webui&type=Date" />
|
228 |
+
</picture>
|
229 |
+
</a>
|
230 |
+
|
231 |
+
---
|
232 |
+
|
233 |
+
Created by [Timothy Jaeryang Baek](https://github.com/tjbck) - Let's make Open WebUI even more amazing together! 💪
|
TROUBLESHOOTING.md
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Open WebUI Troubleshooting Guide
|
2 |
+
|
3 |
+
## Understanding the Open WebUI Architecture
|
4 |
+
|
5 |
+
The Open WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues.
|
6 |
+
|
7 |
+
- **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_BASE_URL` environment variable. Therefore, a request made to `/ollama` in the WebUI is effectively the same as making a request to `OLLAMA_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_BASE_URL/api/tags` in the backend.
|
8 |
+
|
9 |
+
- **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer.
|
10 |
+
|
11 |
+
## Open WebUI: Server Connection Error
|
12 |
+
|
13 |
+
If you're experiencing connection issues, it’s often due to the WebUI docker container not being able to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) inside the container . Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`.
|
14 |
+
|
15 |
+
**Example Docker Command**:
|
16 |
+
|
17 |
+
```bash
|
18 |
+
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
|
19 |
+
```
|
20 |
+
|
21 |
+
### Error on Slow Responses for Ollama
|
22 |
+
|
23 |
+
Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds.
|
24 |
+
|
25 |
+
### General Connection Errors
|
26 |
+
|
27 |
+
**Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates.
|
28 |
+
|
29 |
+
**Troubleshooting Steps**:
|
30 |
+
|
31 |
+
1. **Verify Ollama URL Format**:
|
32 |
+
- When running the Web UI container, ensure the `OLLAMA_BASE_URL` is correctly set. (e.g., `http://192.168.1.1:11434` for different host setups).
|
33 |
+
- In the Open WebUI, navigate to "Settings" > "General".
|
34 |
+
- Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]` (e.g., `http://localhost:11434`).
|
35 |
+
|
36 |
+
By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord.
|
backend/.dockerignore
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
.env
|
3 |
+
_old
|
4 |
+
uploads
|
5 |
+
.ipynb_checkpoints
|
6 |
+
*.db
|
7 |
+
_test
|
8 |
+
!/data
|
9 |
+
/data/*
|
10 |
+
!/data/litellm
|
11 |
+
/data/litellm/*
|
12 |
+
!data/litellm/config.yaml
|
13 |
+
|
14 |
+
!data/config.json
|
backend/.gitignore
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
.env
|
3 |
+
_old
|
4 |
+
uploads
|
5 |
+
.ipynb_checkpoints
|
6 |
+
*.db
|
7 |
+
_test
|
8 |
+
Pipfile
|
9 |
+
!/data
|
10 |
+
/data/*
|
11 |
+
/open_webui/data/*
|
12 |
+
.webui_secret_key
|
backend/dev.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
PORT="${PORT:-8080}"
|
2 |
+
uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload
|
backend/open_webui/__init__.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import typer
|
7 |
+
import uvicorn
|
8 |
+
|
9 |
+
app = typer.Typer()
|
10 |
+
|
11 |
+
KEY_FILE = Path.cwd() / ".webui_secret_key"
|
12 |
+
|
13 |
+
|
14 |
+
@app.command()
|
15 |
+
def serve(
|
16 |
+
host: str = "0.0.0.0",
|
17 |
+
port: int = 8080,
|
18 |
+
):
|
19 |
+
os.environ["FROM_INIT_PY"] = "true"
|
20 |
+
if os.getenv("WEBUI_SECRET_KEY") is None:
|
21 |
+
typer.echo(
|
22 |
+
"Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
|
23 |
+
)
|
24 |
+
if not KEY_FILE.exists():
|
25 |
+
typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}")
|
26 |
+
KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12)))
|
27 |
+
typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}")
|
28 |
+
os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text()
|
29 |
+
|
30 |
+
if os.getenv("USE_CUDA_DOCKER", "false") == "true":
|
31 |
+
typer.echo(
|
32 |
+
"CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
|
33 |
+
)
|
34 |
+
LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":")
|
35 |
+
os.environ["LD_LIBRARY_PATH"] = ":".join(
|
36 |
+
LD_LIBRARY_PATH
|
37 |
+
+ [
|
38 |
+
"/usr/local/lib/python3.11/site-packages/torch/lib",
|
39 |
+
"/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
|
40 |
+
]
|
41 |
+
)
|
42 |
+
try:
|
43 |
+
import torch
|
44 |
+
|
45 |
+
assert torch.cuda.is_available(), "CUDA not available"
|
46 |
+
typer.echo("CUDA seems to be working")
|
47 |
+
except Exception as e:
|
48 |
+
typer.echo(
|
49 |
+
"Error when testing CUDA but USE_CUDA_DOCKER is true. "
|
50 |
+
"Resetting USE_CUDA_DOCKER to false and removing "
|
51 |
+
f"LD_LIBRARY_PATH modifications: {e}"
|
52 |
+
)
|
53 |
+
os.environ["USE_CUDA_DOCKER"] = "false"
|
54 |
+
os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
|
55 |
+
|
56 |
+
import open_webui.main # we need set environment variables before importing main
|
57 |
+
|
58 |
+
uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
|
59 |
+
|
60 |
+
|
61 |
+
@app.command()
|
62 |
+
def dev(
|
63 |
+
host: str = "0.0.0.0",
|
64 |
+
port: int = 8080,
|
65 |
+
reload: bool = True,
|
66 |
+
):
|
67 |
+
uvicorn.run(
|
68 |
+
"open_webui.main:app",
|
69 |
+
host=host,
|
70 |
+
port=port,
|
71 |
+
reload=reload,
|
72 |
+
forwarded_allow_ips="*",
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == "__main__":
|
77 |
+
app()
|
backend/open_webui/alembic.ini
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A generic, single database configuration.
|
2 |
+
|
3 |
+
[alembic]
|
4 |
+
# path to migration scripts
|
5 |
+
script_location = migrations
|
6 |
+
|
7 |
+
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
8 |
+
# Uncomment the line below if you want the files to be prepended with date and time
|
9 |
+
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
10 |
+
|
11 |
+
# sys.path path, will be prepended to sys.path if present.
|
12 |
+
# defaults to the current working directory.
|
13 |
+
prepend_sys_path = .
|
14 |
+
|
15 |
+
# timezone to use when rendering the date within the migration file
|
16 |
+
# as well as the filename.
|
17 |
+
# If specified, requires the python>=3.9 or backports.zoneinfo library.
|
18 |
+
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
19 |
+
# string value is passed to ZoneInfo()
|
20 |
+
# leave blank for localtime
|
21 |
+
# timezone =
|
22 |
+
|
23 |
+
# max length of characters to apply to the
|
24 |
+
# "slug" field
|
25 |
+
# truncate_slug_length = 40
|
26 |
+
|
27 |
+
# set to 'true' to run the environment during
|
28 |
+
# the 'revision' command, regardless of autogenerate
|
29 |
+
# revision_environment = false
|
30 |
+
|
31 |
+
# set to 'true' to allow .pyc and .pyo files without
|
32 |
+
# a source .py file to be detected as revisions in the
|
33 |
+
# versions/ directory
|
34 |
+
# sourceless = false
|
35 |
+
|
36 |
+
# version location specification; This defaults
|
37 |
+
# to migrations/versions. When using multiple version
|
38 |
+
# directories, initial revisions must be specified with --version-path.
|
39 |
+
# The path separator used here should be the separator specified by "version_path_separator" below.
|
40 |
+
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
|
41 |
+
|
42 |
+
# version path separator; As mentioned above, this is the character used to split
|
43 |
+
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
44 |
+
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
45 |
+
# Valid values for version_path_separator are:
|
46 |
+
#
|
47 |
+
# version_path_separator = :
|
48 |
+
# version_path_separator = ;
|
49 |
+
# version_path_separator = space
|
50 |
+
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
51 |
+
|
52 |
+
# set to 'true' to search source files recursively
|
53 |
+
# in each "version_locations" directory
|
54 |
+
# new in Alembic version 1.10
|
55 |
+
# recursive_version_locations = false
|
56 |
+
|
57 |
+
# the output encoding used when revision files
|
58 |
+
# are written from script.py.mako
|
59 |
+
# output_encoding = utf-8
|
60 |
+
|
61 |
+
# sqlalchemy.url = REPLACE_WITH_DATABASE_URL
|
62 |
+
|
63 |
+
|
64 |
+
[post_write_hooks]
|
65 |
+
# post_write_hooks defines scripts or Python functions that are run
|
66 |
+
# on newly generated revision scripts. See the documentation for further
|
67 |
+
# detail and examples
|
68 |
+
|
69 |
+
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
70 |
+
# hooks = black
|
71 |
+
# black.type = console_scripts
|
72 |
+
# black.entrypoint = black
|
73 |
+
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
74 |
+
|
75 |
+
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
76 |
+
# hooks = ruff
|
77 |
+
# ruff.type = exec
|
78 |
+
# ruff.executable = %(here)s/.venv/bin/ruff
|
79 |
+
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
80 |
+
|
81 |
+
# Logging configuration
|
82 |
+
[loggers]
|
83 |
+
keys = root,sqlalchemy,alembic
|
84 |
+
|
85 |
+
[handlers]
|
86 |
+
keys = console
|
87 |
+
|
88 |
+
[formatters]
|
89 |
+
keys = generic
|
90 |
+
|
91 |
+
[logger_root]
|
92 |
+
level = WARN
|
93 |
+
handlers = console
|
94 |
+
qualname =
|
95 |
+
|
96 |
+
[logger_sqlalchemy]
|
97 |
+
level = WARN
|
98 |
+
handlers =
|
99 |
+
qualname = sqlalchemy.engine
|
100 |
+
|
101 |
+
[logger_alembic]
|
102 |
+
level = INFO
|
103 |
+
handlers =
|
104 |
+
qualname = alembic
|
105 |
+
|
106 |
+
[handler_console]
|
107 |
+
class = StreamHandler
|
108 |
+
args = (sys.stderr,)
|
109 |
+
level = NOTSET
|
110 |
+
formatter = generic
|
111 |
+
|
112 |
+
[formatter_generic]
|
113 |
+
format = %(levelname)-5.5s [%(name)s] %(message)s
|
114 |
+
datefmt = %H:%M:%S
|
backend/open_webui/apps/audio/main.py
ADDED
@@ -0,0 +1,703 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import uuid
|
6 |
+
from functools import lru_cache
|
7 |
+
from pathlib import Path
|
8 |
+
from pydub import AudioSegment
|
9 |
+
from pydub.silence import split_on_silence
|
10 |
+
|
11 |
+
import aiohttp
|
12 |
+
import aiofiles
|
13 |
+
import requests
|
14 |
+
from open_webui.config import (
|
15 |
+
AUDIO_STT_ENGINE,
|
16 |
+
AUDIO_STT_MODEL,
|
17 |
+
AUDIO_STT_OPENAI_API_BASE_URL,
|
18 |
+
AUDIO_STT_OPENAI_API_KEY,
|
19 |
+
AUDIO_TTS_API_KEY,
|
20 |
+
AUDIO_TTS_ENGINE,
|
21 |
+
AUDIO_TTS_MODEL,
|
22 |
+
AUDIO_TTS_OPENAI_API_BASE_URL,
|
23 |
+
AUDIO_TTS_OPENAI_API_KEY,
|
24 |
+
AUDIO_TTS_SPLIT_ON,
|
25 |
+
AUDIO_TTS_VOICE,
|
26 |
+
AUDIO_TTS_AZURE_SPEECH_REGION,
|
27 |
+
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
28 |
+
CACHE_DIR,
|
29 |
+
CORS_ALLOW_ORIGIN,
|
30 |
+
WHISPER_MODEL,
|
31 |
+
WHISPER_MODEL_AUTO_UPDATE,
|
32 |
+
WHISPER_MODEL_DIR,
|
33 |
+
AppConfig,
|
34 |
+
)
|
35 |
+
|
36 |
+
from open_webui.constants import ERROR_MESSAGES
|
37 |
+
from open_webui.env import (
|
38 |
+
ENV,
|
39 |
+
SRC_LOG_LEVELS,
|
40 |
+
DEVICE_TYPE,
|
41 |
+
ENABLE_FORWARD_USER_INFO_HEADERS,
|
42 |
+
)
|
43 |
+
|
44 |
+
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
|
45 |
+
from fastapi.middleware.cors import CORSMiddleware
|
46 |
+
from fastapi.responses import FileResponse
|
47 |
+
from pydantic import BaseModel
|
48 |
+
from open_webui.utils.utils import get_admin_user, get_verified_user
|
49 |
+
|
50 |
+
# Constants
|
51 |
+
MAX_FILE_SIZE_MB = 25
|
52 |
+
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
53 |
+
|
54 |
+
|
55 |
+
log = logging.getLogger(__name__)
|
56 |
+
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
57 |
+
|
58 |
+
app = FastAPI(
|
59 |
+
docs_url="/docs" if ENV == "dev" else None,
|
60 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
61 |
+
redoc_url=None,
|
62 |
+
)
|
63 |
+
|
64 |
+
app.add_middleware(
|
65 |
+
CORSMiddleware,
|
66 |
+
allow_origins=CORS_ALLOW_ORIGIN,
|
67 |
+
allow_credentials=True,
|
68 |
+
allow_methods=["*"],
|
69 |
+
allow_headers=["*"],
|
70 |
+
)
|
71 |
+
|
72 |
+
app.state.config = AppConfig()
|
73 |
+
|
74 |
+
app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
|
75 |
+
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
|
76 |
+
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
|
77 |
+
app.state.config.STT_MODEL = AUDIO_STT_MODEL
|
78 |
+
|
79 |
+
app.state.config.WHISPER_MODEL = WHISPER_MODEL
|
80 |
+
app.state.faster_whisper_model = None
|
81 |
+
|
82 |
+
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
|
83 |
+
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
|
84 |
+
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
|
85 |
+
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
|
86 |
+
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
|
87 |
+
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
|
88 |
+
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
|
89 |
+
|
90 |
+
|
91 |
+
app.state.speech_synthesiser = None
|
92 |
+
app.state.speech_speaker_embeddings_dataset = None
|
93 |
+
|
94 |
+
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
|
95 |
+
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
|
96 |
+
|
97 |
+
# setting device type for whisper model
|
98 |
+
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
|
99 |
+
log.info(f"whisper_device_type: {whisper_device_type}")
|
100 |
+
|
101 |
+
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
102 |
+
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
103 |
+
|
104 |
+
|
105 |
+
def set_faster_whisper_model(model: str, auto_update: bool = False):
|
106 |
+
if model and app.state.config.STT_ENGINE == "":
|
107 |
+
from faster_whisper import WhisperModel
|
108 |
+
|
109 |
+
faster_whisper_kwargs = {
|
110 |
+
"model_size_or_path": model,
|
111 |
+
"device": whisper_device_type,
|
112 |
+
"compute_type": "int8",
|
113 |
+
"download_root": WHISPER_MODEL_DIR,
|
114 |
+
"local_files_only": not auto_update,
|
115 |
+
}
|
116 |
+
|
117 |
+
try:
|
118 |
+
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
|
119 |
+
except Exception:
|
120 |
+
log.warning(
|
121 |
+
"WhisperModel initialization failed, attempting download with local_files_only=False"
|
122 |
+
)
|
123 |
+
faster_whisper_kwargs["local_files_only"] = False
|
124 |
+
app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
|
125 |
+
|
126 |
+
else:
|
127 |
+
app.state.faster_whisper_model = None
|
128 |
+
|
129 |
+
|
130 |
+
class TTSConfigForm(BaseModel):
|
131 |
+
OPENAI_API_BASE_URL: str
|
132 |
+
OPENAI_API_KEY: str
|
133 |
+
API_KEY: str
|
134 |
+
ENGINE: str
|
135 |
+
MODEL: str
|
136 |
+
VOICE: str
|
137 |
+
SPLIT_ON: str
|
138 |
+
AZURE_SPEECH_REGION: str
|
139 |
+
AZURE_SPEECH_OUTPUT_FORMAT: str
|
140 |
+
|
141 |
+
|
142 |
+
class STTConfigForm(BaseModel):
|
143 |
+
OPENAI_API_BASE_URL: str
|
144 |
+
OPENAI_API_KEY: str
|
145 |
+
ENGINE: str
|
146 |
+
MODEL: str
|
147 |
+
WHISPER_MODEL: str
|
148 |
+
|
149 |
+
|
150 |
+
class AudioConfigUpdateForm(BaseModel):
|
151 |
+
tts: TTSConfigForm
|
152 |
+
stt: STTConfigForm
|
153 |
+
|
154 |
+
|
155 |
+
from pydub import AudioSegment
|
156 |
+
from pydub.utils import mediainfo
|
157 |
+
|
158 |
+
|
159 |
+
def is_mp4_audio(file_path):
|
160 |
+
"""Check if the given file is an MP4 audio file."""
|
161 |
+
if not os.path.isfile(file_path):
|
162 |
+
print(f"File not found: {file_path}")
|
163 |
+
return False
|
164 |
+
|
165 |
+
info = mediainfo(file_path)
|
166 |
+
if (
|
167 |
+
info.get("codec_name") == "aac"
|
168 |
+
and info.get("codec_type") == "audio"
|
169 |
+
and info.get("codec_tag_string") == "mp4a"
|
170 |
+
):
|
171 |
+
return True
|
172 |
+
return False
|
173 |
+
|
174 |
+
|
175 |
+
def convert_mp4_to_wav(file_path, output_path):
|
176 |
+
"""Convert MP4 audio file to WAV format."""
|
177 |
+
audio = AudioSegment.from_file(file_path, format="mp4")
|
178 |
+
audio.export(output_path, format="wav")
|
179 |
+
print(f"Converted {file_path} to {output_path}")
|
180 |
+
|
181 |
+
|
182 |
+
@app.get("/config")
|
183 |
+
async def get_audio_config(user=Depends(get_admin_user)):
|
184 |
+
return {
|
185 |
+
"tts": {
|
186 |
+
"OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
|
187 |
+
"OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
|
188 |
+
"API_KEY": app.state.config.TTS_API_KEY,
|
189 |
+
"ENGINE": app.state.config.TTS_ENGINE,
|
190 |
+
"MODEL": app.state.config.TTS_MODEL,
|
191 |
+
"VOICE": app.state.config.TTS_VOICE,
|
192 |
+
"SPLIT_ON": app.state.config.TTS_SPLIT_ON,
|
193 |
+
"AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
|
194 |
+
"AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
195 |
+
},
|
196 |
+
"stt": {
|
197 |
+
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
|
198 |
+
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
|
199 |
+
"ENGINE": app.state.config.STT_ENGINE,
|
200 |
+
"MODEL": app.state.config.STT_MODEL,
|
201 |
+
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
|
202 |
+
},
|
203 |
+
}
|
204 |
+
|
205 |
+
|
206 |
+
@app.post("/config/update")
|
207 |
+
async def update_audio_config(
|
208 |
+
form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
|
209 |
+
):
|
210 |
+
app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
|
211 |
+
app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
|
212 |
+
app.state.config.TTS_API_KEY = form_data.tts.API_KEY
|
213 |
+
app.state.config.TTS_ENGINE = form_data.tts.ENGINE
|
214 |
+
app.state.config.TTS_MODEL = form_data.tts.MODEL
|
215 |
+
app.state.config.TTS_VOICE = form_data.tts.VOICE
|
216 |
+
app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
|
217 |
+
app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
|
218 |
+
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
|
219 |
+
form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
|
220 |
+
)
|
221 |
+
|
222 |
+
app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
|
223 |
+
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
|
224 |
+
app.state.config.STT_ENGINE = form_data.stt.ENGINE
|
225 |
+
app.state.config.STT_MODEL = form_data.stt.MODEL
|
226 |
+
app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
|
227 |
+
set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
|
228 |
+
|
229 |
+
return {
|
230 |
+
"tts": {
|
231 |
+
"OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
|
232 |
+
"OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
|
233 |
+
"API_KEY": app.state.config.TTS_API_KEY,
|
234 |
+
"ENGINE": app.state.config.TTS_ENGINE,
|
235 |
+
"MODEL": app.state.config.TTS_MODEL,
|
236 |
+
"VOICE": app.state.config.TTS_VOICE,
|
237 |
+
"SPLIT_ON": app.state.config.TTS_SPLIT_ON,
|
238 |
+
"AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
|
239 |
+
"AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
240 |
+
},
|
241 |
+
"stt": {
|
242 |
+
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
|
243 |
+
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
|
244 |
+
"ENGINE": app.state.config.STT_ENGINE,
|
245 |
+
"MODEL": app.state.config.STT_MODEL,
|
246 |
+
"WHISPER_MODEL": app.state.config.WHISPER_MODEL,
|
247 |
+
},
|
248 |
+
}
|
249 |
+
|
250 |
+
|
251 |
+
def load_speech_pipeline():
|
252 |
+
from transformers import pipeline
|
253 |
+
from datasets import load_dataset
|
254 |
+
|
255 |
+
if app.state.speech_synthesiser is None:
|
256 |
+
app.state.speech_synthesiser = pipeline(
|
257 |
+
"text-to-speech", "microsoft/speecht5_tts"
|
258 |
+
)
|
259 |
+
|
260 |
+
if app.state.speech_speaker_embeddings_dataset is None:
|
261 |
+
app.state.speech_speaker_embeddings_dataset = load_dataset(
|
262 |
+
"Matthijs/cmu-arctic-xvectors", split="validation"
|
263 |
+
)
|
264 |
+
|
265 |
+
|
266 |
+
@app.post("/speech")
|
267 |
+
async def speech(request: Request, user=Depends(get_verified_user)):
|
268 |
+
body = await request.body()
|
269 |
+
name = hashlib.sha256(body).hexdigest()
|
270 |
+
|
271 |
+
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
|
272 |
+
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
|
273 |
+
|
274 |
+
# Check if the file already exists in the cache
|
275 |
+
if file_path.is_file():
|
276 |
+
return FileResponse(file_path)
|
277 |
+
|
278 |
+
if app.state.config.TTS_ENGINE == "openai":
|
279 |
+
headers = {}
|
280 |
+
headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
|
281 |
+
headers["Content-Type"] = "application/json"
|
282 |
+
|
283 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
284 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
285 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
286 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
287 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
288 |
+
|
289 |
+
try:
|
290 |
+
body = body.decode("utf-8")
|
291 |
+
body = json.loads(body)
|
292 |
+
body["model"] = app.state.config.TTS_MODEL
|
293 |
+
body = json.dumps(body).encode("utf-8")
|
294 |
+
except Exception:
|
295 |
+
pass
|
296 |
+
|
297 |
+
try:
|
298 |
+
async with aiohttp.ClientSession() as session:
|
299 |
+
async with session.post(
|
300 |
+
url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
|
301 |
+
data=body,
|
302 |
+
headers=headers,
|
303 |
+
) as r:
|
304 |
+
r.raise_for_status()
|
305 |
+
async with aiofiles.open(file_path, "wb") as f:
|
306 |
+
await f.write(await r.read())
|
307 |
+
|
308 |
+
async with aiofiles.open(file_body_path, "w") as f:
|
309 |
+
await f.write(json.dumps(json.loads(body.decode("utf-8"))))
|
310 |
+
|
311 |
+
return FileResponse(file_path)
|
312 |
+
|
313 |
+
except Exception as e:
|
314 |
+
log.exception(e)
|
315 |
+
error_detail = "Open WebUI: Server Connection Error"
|
316 |
+
try:
|
317 |
+
if r.status != 200:
|
318 |
+
res = await r.json()
|
319 |
+
if "error" in res:
|
320 |
+
error_detail = f"External: {res['error']['message']}"
|
321 |
+
except Exception:
|
322 |
+
error_detail = f"External: {e}"
|
323 |
+
|
324 |
+
raise HTTPException(
|
325 |
+
status_code=getattr(r, "status", 500),
|
326 |
+
detail=error_detail,
|
327 |
+
)
|
328 |
+
|
329 |
+
elif app.state.config.TTS_ENGINE == "elevenlabs":
|
330 |
+
try:
|
331 |
+
payload = json.loads(body.decode("utf-8"))
|
332 |
+
except Exception as e:
|
333 |
+
log.exception(e)
|
334 |
+
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
335 |
+
|
336 |
+
voice_id = payload.get("voice", "")
|
337 |
+
if voice_id not in get_available_voices():
|
338 |
+
raise HTTPException(
|
339 |
+
status_code=400,
|
340 |
+
detail="Invalid voice id",
|
341 |
+
)
|
342 |
+
|
343 |
+
url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
|
344 |
+
headers = {
|
345 |
+
"Accept": "audio/mpeg",
|
346 |
+
"Content-Type": "application/json",
|
347 |
+
"xi-api-key": app.state.config.TTS_API_KEY,
|
348 |
+
}
|
349 |
+
data = {
|
350 |
+
"text": payload["input"],
|
351 |
+
"model_id": app.state.config.TTS_MODEL,
|
352 |
+
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
|
353 |
+
}
|
354 |
+
|
355 |
+
try:
|
356 |
+
async with aiohttp.ClientSession() as session:
|
357 |
+
async with session.post(url, json=data, headers=headers) as r:
|
358 |
+
r.raise_for_status()
|
359 |
+
async with aiofiles.open(file_path, "wb") as f:
|
360 |
+
await f.write(await r.read())
|
361 |
+
|
362 |
+
async with aiofiles.open(file_body_path, "w") as f:
|
363 |
+
await f.write(json.dumps(json.loads(body.decode("utf-8"))))
|
364 |
+
|
365 |
+
return FileResponse(file_path)
|
366 |
+
|
367 |
+
except Exception as e:
|
368 |
+
log.exception(e)
|
369 |
+
error_detail = "Open WebUI: Server Connection Error"
|
370 |
+
try:
|
371 |
+
if r.status != 200:
|
372 |
+
res = await r.json()
|
373 |
+
if "error" in res:
|
374 |
+
error_detail = f"External: {res['error']['message']}"
|
375 |
+
except Exception:
|
376 |
+
error_detail = f"External: {e}"
|
377 |
+
|
378 |
+
raise HTTPException(
|
379 |
+
status_code=getattr(r, "status", 500),
|
380 |
+
detail=error_detail,
|
381 |
+
)
|
382 |
+
|
383 |
+
elif app.state.config.TTS_ENGINE == "azure":
|
384 |
+
try:
|
385 |
+
payload = json.loads(body.decode("utf-8"))
|
386 |
+
except Exception as e:
|
387 |
+
log.exception(e)
|
388 |
+
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
389 |
+
|
390 |
+
region = app.state.config.TTS_AZURE_SPEECH_REGION
|
391 |
+
language = app.state.config.TTS_VOICE
|
392 |
+
locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
|
393 |
+
output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
|
394 |
+
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
395 |
+
|
396 |
+
headers = {
|
397 |
+
"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
|
398 |
+
"Content-Type": "application/ssml+xml",
|
399 |
+
"X-Microsoft-OutputFormat": output_format,
|
400 |
+
}
|
401 |
+
|
402 |
+
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
|
403 |
+
<voice name="{language}">{payload["input"]}</voice>
|
404 |
+
</speak>"""
|
405 |
+
|
406 |
+
try:
|
407 |
+
async with aiohttp.ClientSession() as session:
|
408 |
+
async with session.post(url, headers=headers, data=data) as response:
|
409 |
+
if response.status == 200:
|
410 |
+
async with aiofiles.open(file_path, "wb") as f:
|
411 |
+
await f.write(await response.read())
|
412 |
+
return FileResponse(file_path)
|
413 |
+
else:
|
414 |
+
error_msg = f"Error synthesizing speech - {response.reason}"
|
415 |
+
log.error(error_msg)
|
416 |
+
raise HTTPException(status_code=500, detail=error_msg)
|
417 |
+
except Exception as e:
|
418 |
+
log.exception(e)
|
419 |
+
raise HTTPException(status_code=500, detail=str(e))
|
420 |
+
elif app.state.config.TTS_ENGINE == "transformers":
|
421 |
+
payload = None
|
422 |
+
try:
|
423 |
+
payload = json.loads(body.decode("utf-8"))
|
424 |
+
except Exception as e:
|
425 |
+
log.exception(e)
|
426 |
+
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
427 |
+
|
428 |
+
import torch
|
429 |
+
import soundfile as sf
|
430 |
+
|
431 |
+
load_speech_pipeline()
|
432 |
+
|
433 |
+
embeddings_dataset = app.state.speech_speaker_embeddings_dataset
|
434 |
+
|
435 |
+
speaker_index = 6799
|
436 |
+
try:
|
437 |
+
speaker_index = embeddings_dataset["filename"].index(
|
438 |
+
app.state.config.TTS_MODEL
|
439 |
+
)
|
440 |
+
except Exception:
|
441 |
+
pass
|
442 |
+
|
443 |
+
speaker_embedding = torch.tensor(
|
444 |
+
embeddings_dataset[speaker_index]["xvector"]
|
445 |
+
).unsqueeze(0)
|
446 |
+
|
447 |
+
speech = app.state.speech_synthesiser(
|
448 |
+
payload["input"],
|
449 |
+
forward_params={"speaker_embeddings": speaker_embedding},
|
450 |
+
)
|
451 |
+
|
452 |
+
sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
|
453 |
+
with open(file_body_path, "w") as f:
|
454 |
+
json.dump(json.loads(body.decode("utf-8")), f)
|
455 |
+
|
456 |
+
return FileResponse(file_path)
|
457 |
+
|
458 |
+
|
459 |
+
def transcribe(file_path):
|
460 |
+
print("transcribe", file_path)
|
461 |
+
filename = os.path.basename(file_path)
|
462 |
+
file_dir = os.path.dirname(file_path)
|
463 |
+
id = filename.split(".")[0]
|
464 |
+
|
465 |
+
if app.state.config.STT_ENGINE == "":
|
466 |
+
if app.state.faster_whisper_model is None:
|
467 |
+
set_faster_whisper_model(app.state.config.WHISPER_MODEL)
|
468 |
+
|
469 |
+
model = app.state.faster_whisper_model
|
470 |
+
segments, info = model.transcribe(file_path, beam_size=5)
|
471 |
+
log.info(
|
472 |
+
"Detected language '%s' with probability %f"
|
473 |
+
% (info.language, info.language_probability)
|
474 |
+
)
|
475 |
+
|
476 |
+
transcript = "".join([segment.text for segment in list(segments)])
|
477 |
+
data = {"text": transcript.strip()}
|
478 |
+
|
479 |
+
# save the transcript to a json file
|
480 |
+
transcript_file = f"{file_dir}/{id}.json"
|
481 |
+
with open(transcript_file, "w") as f:
|
482 |
+
json.dump(data, f)
|
483 |
+
|
484 |
+
log.debug(data)
|
485 |
+
return data
|
486 |
+
elif app.state.config.STT_ENGINE == "openai":
|
487 |
+
if is_mp4_audio(file_path):
|
488 |
+
print("is_mp4_audio")
|
489 |
+
os.rename(file_path, file_path.replace(".wav", ".mp4"))
|
490 |
+
# Convert MP4 audio file to WAV format
|
491 |
+
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
|
492 |
+
|
493 |
+
headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
|
494 |
+
|
495 |
+
files = {"file": (filename, open(file_path, "rb"))}
|
496 |
+
data = {"model": app.state.config.STT_MODEL}
|
497 |
+
|
498 |
+
log.debug(files, data)
|
499 |
+
|
500 |
+
r = None
|
501 |
+
try:
|
502 |
+
r = requests.post(
|
503 |
+
url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
|
504 |
+
headers=headers,
|
505 |
+
files=files,
|
506 |
+
data=data,
|
507 |
+
)
|
508 |
+
|
509 |
+
r.raise_for_status()
|
510 |
+
|
511 |
+
data = r.json()
|
512 |
+
|
513 |
+
# save the transcript to a json file
|
514 |
+
transcript_file = f"{file_dir}/{id}.json"
|
515 |
+
with open(transcript_file, "w") as f:
|
516 |
+
json.dump(data, f)
|
517 |
+
|
518 |
+
print(data)
|
519 |
+
return data
|
520 |
+
except Exception as e:
|
521 |
+
log.exception(e)
|
522 |
+
error_detail = "Open WebUI: Server Connection Error"
|
523 |
+
if r is not None:
|
524 |
+
try:
|
525 |
+
res = r.json()
|
526 |
+
if "error" in res:
|
527 |
+
error_detail = f"External: {res['error']['message']}"
|
528 |
+
except Exception:
|
529 |
+
error_detail = f"External: {e}"
|
530 |
+
|
531 |
+
raise Exception(error_detail)
|
532 |
+
|
533 |
+
|
534 |
+
@app.post("/transcriptions")
|
535 |
+
def transcription(
|
536 |
+
file: UploadFile = File(...),
|
537 |
+
user=Depends(get_verified_user),
|
538 |
+
):
|
539 |
+
log.info(f"file.content_type: {file.content_type}")
|
540 |
+
|
541 |
+
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
|
542 |
+
raise HTTPException(
|
543 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
544 |
+
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
545 |
+
)
|
546 |
+
|
547 |
+
try:
|
548 |
+
ext = file.filename.split(".")[-1]
|
549 |
+
id = uuid.uuid4()
|
550 |
+
|
551 |
+
filename = f"{id}.{ext}"
|
552 |
+
contents = file.file.read()
|
553 |
+
|
554 |
+
file_dir = f"{CACHE_DIR}/audio/transcriptions"
|
555 |
+
os.makedirs(file_dir, exist_ok=True)
|
556 |
+
file_path = f"{file_dir}/{filename}"
|
557 |
+
|
558 |
+
with open(file_path, "wb") as f:
|
559 |
+
f.write(contents)
|
560 |
+
|
561 |
+
try:
|
562 |
+
if os.path.getsize(file_path) > MAX_FILE_SIZE: # file is bigger than 25MB
|
563 |
+
log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
|
564 |
+
audio = AudioSegment.from_file(file_path)
|
565 |
+
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
|
566 |
+
compressed_path = f"{file_dir}/{id}_compressed.opus"
|
567 |
+
audio.export(compressed_path, format="opus", bitrate="32k")
|
568 |
+
log.debug(f"Compressed audio to {compressed_path}")
|
569 |
+
file_path = compressed_path
|
570 |
+
|
571 |
+
if (
|
572 |
+
os.path.getsize(file_path) > MAX_FILE_SIZE
|
573 |
+
): # Still larger than 25MB after compression
|
574 |
+
log.debug(
|
575 |
+
f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
|
576 |
+
)
|
577 |
+
raise HTTPException(
|
578 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
579 |
+
detail=ERROR_MESSAGES.FILE_TOO_LARGE(
|
580 |
+
size=f"{MAX_FILE_SIZE_MB}MB"
|
581 |
+
),
|
582 |
+
)
|
583 |
+
|
584 |
+
data = transcribe(file_path)
|
585 |
+
else:
|
586 |
+
data = transcribe(file_path)
|
587 |
+
|
588 |
+
file_path = file_path.split("/")[-1]
|
589 |
+
return {**data, "filename": file_path}
|
590 |
+
except Exception as e:
|
591 |
+
log.exception(e)
|
592 |
+
raise HTTPException(
|
593 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
594 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
595 |
+
)
|
596 |
+
|
597 |
+
except Exception as e:
|
598 |
+
log.exception(e)
|
599 |
+
|
600 |
+
raise HTTPException(
|
601 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
602 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
603 |
+
)
|
604 |
+
|
605 |
+
|
606 |
+
def get_available_models() -> list[dict]:
|
607 |
+
if app.state.config.TTS_ENGINE == "openai":
|
608 |
+
return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
609 |
+
elif app.state.config.TTS_ENGINE == "elevenlabs":
|
610 |
+
headers = {
|
611 |
+
"xi-api-key": app.state.config.TTS_API_KEY,
|
612 |
+
"Content-Type": "application/json",
|
613 |
+
}
|
614 |
+
|
615 |
+
try:
|
616 |
+
response = requests.get(
|
617 |
+
"https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
|
618 |
+
)
|
619 |
+
response.raise_for_status()
|
620 |
+
models = response.json()
|
621 |
+
return [
|
622 |
+
{"name": model["name"], "id": model["model_id"]} for model in models
|
623 |
+
]
|
624 |
+
except requests.RequestException as e:
|
625 |
+
log.error(f"Error fetching voices: {str(e)}")
|
626 |
+
return []
|
627 |
+
|
628 |
+
|
629 |
+
@app.get("/models")
|
630 |
+
async def get_models(user=Depends(get_verified_user)):
|
631 |
+
return {"models": get_available_models()}
|
632 |
+
|
633 |
+
|
634 |
+
def get_available_voices() -> dict:
|
635 |
+
"""Returns {voice_id: voice_name} dict"""
|
636 |
+
ret = {}
|
637 |
+
if app.state.config.TTS_ENGINE == "openai":
|
638 |
+
ret = {
|
639 |
+
"alloy": "alloy",
|
640 |
+
"echo": "echo",
|
641 |
+
"fable": "fable",
|
642 |
+
"onyx": "onyx",
|
643 |
+
"nova": "nova",
|
644 |
+
"shimmer": "shimmer",
|
645 |
+
}
|
646 |
+
elif app.state.config.TTS_ENGINE == "elevenlabs":
|
647 |
+
try:
|
648 |
+
ret = get_elevenlabs_voices()
|
649 |
+
except Exception:
|
650 |
+
# Avoided @lru_cache with exception
|
651 |
+
pass
|
652 |
+
elif app.state.config.TTS_ENGINE == "azure":
|
653 |
+
try:
|
654 |
+
region = app.state.config.TTS_AZURE_SPEECH_REGION
|
655 |
+
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
|
656 |
+
headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
|
657 |
+
|
658 |
+
response = requests.get(url, headers=headers)
|
659 |
+
response.raise_for_status()
|
660 |
+
voices = response.json()
|
661 |
+
for voice in voices:
|
662 |
+
ret[voice["ShortName"]] = (
|
663 |
+
f"{voice['DisplayName']} ({voice['ShortName']})"
|
664 |
+
)
|
665 |
+
except requests.RequestException as e:
|
666 |
+
log.error(f"Error fetching voices: {str(e)}")
|
667 |
+
|
668 |
+
return ret
|
669 |
+
|
670 |
+
|
671 |
+
@lru_cache
|
672 |
+
def get_elevenlabs_voices() -> dict:
|
673 |
+
"""
|
674 |
+
Note, set the following in your .env file to use Elevenlabs:
|
675 |
+
AUDIO_TTS_ENGINE=elevenlabs
|
676 |
+
AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
|
677 |
+
AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
|
678 |
+
AUDIO_TTS_MODEL=eleven_multilingual_v2
|
679 |
+
"""
|
680 |
+
headers = {
|
681 |
+
"xi-api-key": app.state.config.TTS_API_KEY,
|
682 |
+
"Content-Type": "application/json",
|
683 |
+
}
|
684 |
+
try:
|
685 |
+
# TODO: Add retries
|
686 |
+
response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
|
687 |
+
response.raise_for_status()
|
688 |
+
voices_data = response.json()
|
689 |
+
|
690 |
+
voices = {}
|
691 |
+
for voice in voices_data.get("voices", []):
|
692 |
+
voices[voice["voice_id"]] = voice["name"]
|
693 |
+
except requests.RequestException as e:
|
694 |
+
# Avoid @lru_cache with exception
|
695 |
+
log.error(f"Error fetching voices: {str(e)}")
|
696 |
+
raise RuntimeError(f"Error fetching voices: {str(e)}")
|
697 |
+
|
698 |
+
return voices
|
699 |
+
|
700 |
+
|
701 |
+
@app.get("/voices")
|
702 |
+
async def get_voices(user=Depends(get_verified_user)):
|
703 |
+
return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}
|
backend/open_webui/apps/images/main.py
ADDED
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import base64
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import mimetypes
|
6 |
+
import re
|
7 |
+
import uuid
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Optional
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from open_webui.apps.images.utils.comfyui import (
|
13 |
+
ComfyUIGenerateImageForm,
|
14 |
+
ComfyUIWorkflow,
|
15 |
+
comfyui_generate_image,
|
16 |
+
)
|
17 |
+
from open_webui.config import (
|
18 |
+
AUTOMATIC1111_API_AUTH,
|
19 |
+
AUTOMATIC1111_BASE_URL,
|
20 |
+
AUTOMATIC1111_CFG_SCALE,
|
21 |
+
AUTOMATIC1111_SAMPLER,
|
22 |
+
AUTOMATIC1111_SCHEDULER,
|
23 |
+
CACHE_DIR,
|
24 |
+
COMFYUI_BASE_URL,
|
25 |
+
COMFYUI_WORKFLOW,
|
26 |
+
COMFYUI_WORKFLOW_NODES,
|
27 |
+
CORS_ALLOW_ORIGIN,
|
28 |
+
ENABLE_IMAGE_GENERATION,
|
29 |
+
IMAGE_GENERATION_ENGINE,
|
30 |
+
IMAGE_GENERATION_MODEL,
|
31 |
+
IMAGE_SIZE,
|
32 |
+
IMAGE_STEPS,
|
33 |
+
IMAGES_OPENAI_API_BASE_URL,
|
34 |
+
IMAGES_OPENAI_API_KEY,
|
35 |
+
AppConfig,
|
36 |
+
)
|
37 |
+
from open_webui.constants import ERROR_MESSAGES
|
38 |
+
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
|
39 |
+
|
40 |
+
from fastapi import Depends, FastAPI, HTTPException, Request
|
41 |
+
from fastapi.middleware.cors import CORSMiddleware
|
42 |
+
from pydantic import BaseModel
|
43 |
+
from open_webui.utils.utils import get_admin_user, get_verified_user
|
44 |
+
|
45 |
+
log = logging.getLogger(__name__)
|
46 |
+
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
47 |
+
|
48 |
+
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
|
49 |
+
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
50 |
+
|
51 |
+
app = FastAPI(
|
52 |
+
docs_url="/docs" if ENV == "dev" else None,
|
53 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
54 |
+
redoc_url=None,
|
55 |
+
)
|
56 |
+
|
57 |
+
app.add_middleware(
|
58 |
+
CORSMiddleware,
|
59 |
+
allow_origins=CORS_ALLOW_ORIGIN,
|
60 |
+
allow_credentials=True,
|
61 |
+
allow_methods=["*"],
|
62 |
+
allow_headers=["*"],
|
63 |
+
)
|
64 |
+
|
65 |
+
app.state.config = AppConfig()
|
66 |
+
|
67 |
+
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
|
68 |
+
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
|
69 |
+
|
70 |
+
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
|
71 |
+
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
|
72 |
+
|
73 |
+
app.state.config.MODEL = IMAGE_GENERATION_MODEL
|
74 |
+
|
75 |
+
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
76 |
+
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
|
77 |
+
app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
|
78 |
+
app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
|
79 |
+
app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
|
80 |
+
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
81 |
+
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
|
82 |
+
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
|
83 |
+
|
84 |
+
app.state.config.IMAGE_SIZE = IMAGE_SIZE
|
85 |
+
app.state.config.IMAGE_STEPS = IMAGE_STEPS
|
86 |
+
|
87 |
+
|
88 |
+
@app.get("/config")
|
89 |
+
async def get_config(request: Request, user=Depends(get_admin_user)):
|
90 |
+
return {
|
91 |
+
"enabled": app.state.config.ENABLED,
|
92 |
+
"engine": app.state.config.ENGINE,
|
93 |
+
"openai": {
|
94 |
+
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
|
95 |
+
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
|
96 |
+
},
|
97 |
+
"automatic1111": {
|
98 |
+
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
|
99 |
+
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
|
100 |
+
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
|
101 |
+
"AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
|
102 |
+
"AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
|
103 |
+
},
|
104 |
+
"comfyui": {
|
105 |
+
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
|
106 |
+
"COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
|
107 |
+
"COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
|
108 |
+
},
|
109 |
+
}
|
110 |
+
|
111 |
+
|
112 |
+
class OpenAIConfigForm(BaseModel):
|
113 |
+
OPENAI_API_BASE_URL: str
|
114 |
+
OPENAI_API_KEY: str
|
115 |
+
|
116 |
+
|
117 |
+
class Automatic1111ConfigForm(BaseModel):
|
118 |
+
AUTOMATIC1111_BASE_URL: str
|
119 |
+
AUTOMATIC1111_API_AUTH: str
|
120 |
+
AUTOMATIC1111_CFG_SCALE: Optional[str]
|
121 |
+
AUTOMATIC1111_SAMPLER: Optional[str]
|
122 |
+
AUTOMATIC1111_SCHEDULER: Optional[str]
|
123 |
+
|
124 |
+
|
125 |
+
class ComfyUIConfigForm(BaseModel):
|
126 |
+
COMFYUI_BASE_URL: str
|
127 |
+
COMFYUI_WORKFLOW: str
|
128 |
+
COMFYUI_WORKFLOW_NODES: list[dict]
|
129 |
+
|
130 |
+
|
131 |
+
class ConfigForm(BaseModel):
|
132 |
+
enabled: bool
|
133 |
+
engine: str
|
134 |
+
openai: OpenAIConfigForm
|
135 |
+
automatic1111: Automatic1111ConfigForm
|
136 |
+
comfyui: ComfyUIConfigForm
|
137 |
+
|
138 |
+
|
139 |
+
@app.post("/config/update")
|
140 |
+
async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
|
141 |
+
app.state.config.ENGINE = form_data.engine
|
142 |
+
app.state.config.ENABLED = form_data.enabled
|
143 |
+
|
144 |
+
app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
|
145 |
+
app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
146 |
+
|
147 |
+
app.state.config.AUTOMATIC1111_BASE_URL = (
|
148 |
+
form_data.automatic1111.AUTOMATIC1111_BASE_URL
|
149 |
+
)
|
150 |
+
app.state.config.AUTOMATIC1111_API_AUTH = (
|
151 |
+
form_data.automatic1111.AUTOMATIC1111_API_AUTH
|
152 |
+
)
|
153 |
+
|
154 |
+
app.state.config.AUTOMATIC1111_CFG_SCALE = (
|
155 |
+
float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
|
156 |
+
if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
|
157 |
+
else None
|
158 |
+
)
|
159 |
+
app.state.config.AUTOMATIC1111_SAMPLER = (
|
160 |
+
form_data.automatic1111.AUTOMATIC1111_SAMPLER
|
161 |
+
if form_data.automatic1111.AUTOMATIC1111_SAMPLER
|
162 |
+
else None
|
163 |
+
)
|
164 |
+
app.state.config.AUTOMATIC1111_SCHEDULER = (
|
165 |
+
form_data.automatic1111.AUTOMATIC1111_SCHEDULER
|
166 |
+
if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
|
167 |
+
else None
|
168 |
+
)
|
169 |
+
|
170 |
+
app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
|
171 |
+
app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
|
172 |
+
app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
|
173 |
+
|
174 |
+
return {
|
175 |
+
"enabled": app.state.config.ENABLED,
|
176 |
+
"engine": app.state.config.ENGINE,
|
177 |
+
"openai": {
|
178 |
+
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
|
179 |
+
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
|
180 |
+
},
|
181 |
+
"automatic1111": {
|
182 |
+
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
|
183 |
+
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
|
184 |
+
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
|
185 |
+
"AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
|
186 |
+
"AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
|
187 |
+
},
|
188 |
+
"comfyui": {
|
189 |
+
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
|
190 |
+
"COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
|
191 |
+
"COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
|
192 |
+
},
|
193 |
+
}
|
194 |
+
|
195 |
+
|
196 |
+
def get_automatic1111_api_auth():
|
197 |
+
if app.state.config.AUTOMATIC1111_API_AUTH is None:
|
198 |
+
return ""
|
199 |
+
else:
|
200 |
+
auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
|
201 |
+
auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
|
202 |
+
auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
|
203 |
+
return f"Basic {auth1111_base64_encoded_string}"
|
204 |
+
|
205 |
+
|
206 |
+
@app.get("/config/url/verify")
|
207 |
+
async def verify_url(user=Depends(get_admin_user)):
|
208 |
+
if app.state.config.ENGINE == "automatic1111":
|
209 |
+
try:
|
210 |
+
r = requests.get(
|
211 |
+
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
212 |
+
headers={"authorization": get_automatic1111_api_auth()},
|
213 |
+
)
|
214 |
+
r.raise_for_status()
|
215 |
+
return True
|
216 |
+
except Exception:
|
217 |
+
app.state.config.ENABLED = False
|
218 |
+
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
219 |
+
elif app.state.config.ENGINE == "comfyui":
|
220 |
+
try:
|
221 |
+
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
|
222 |
+
r.raise_for_status()
|
223 |
+
return True
|
224 |
+
except Exception:
|
225 |
+
app.state.config.ENABLED = False
|
226 |
+
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
227 |
+
else:
|
228 |
+
return True
|
229 |
+
|
230 |
+
|
231 |
+
def set_image_model(model: str):
|
232 |
+
log.info(f"Setting image model to {model}")
|
233 |
+
app.state.config.MODEL = model
|
234 |
+
if app.state.config.ENGINE in ["", "automatic1111"]:
|
235 |
+
api_auth = get_automatic1111_api_auth()
|
236 |
+
r = requests.get(
|
237 |
+
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
238 |
+
headers={"authorization": api_auth},
|
239 |
+
)
|
240 |
+
options = r.json()
|
241 |
+
if model != options["sd_model_checkpoint"]:
|
242 |
+
options["sd_model_checkpoint"] = model
|
243 |
+
r = requests.post(
|
244 |
+
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
245 |
+
json=options,
|
246 |
+
headers={"authorization": api_auth},
|
247 |
+
)
|
248 |
+
return app.state.config.MODEL
|
249 |
+
|
250 |
+
|
251 |
+
def get_image_model():
|
252 |
+
if app.state.config.ENGINE == "openai":
|
253 |
+
return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
|
254 |
+
elif app.state.config.ENGINE == "comfyui":
|
255 |
+
return app.state.config.MODEL if app.state.config.MODEL else ""
|
256 |
+
elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
|
257 |
+
try:
|
258 |
+
r = requests.get(
|
259 |
+
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
260 |
+
headers={"authorization": get_automatic1111_api_auth()},
|
261 |
+
)
|
262 |
+
options = r.json()
|
263 |
+
return options["sd_model_checkpoint"]
|
264 |
+
except Exception as e:
|
265 |
+
app.state.config.ENABLED = False
|
266 |
+
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
267 |
+
|
268 |
+
|
269 |
+
class ImageConfigForm(BaseModel):
|
270 |
+
MODEL: str
|
271 |
+
IMAGE_SIZE: str
|
272 |
+
IMAGE_STEPS: int
|
273 |
+
|
274 |
+
|
275 |
+
@app.get("/image/config")
|
276 |
+
async def get_image_config(user=Depends(get_admin_user)):
|
277 |
+
return {
|
278 |
+
"MODEL": app.state.config.MODEL,
|
279 |
+
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
|
280 |
+
"IMAGE_STEPS": app.state.config.IMAGE_STEPS,
|
281 |
+
}
|
282 |
+
|
283 |
+
|
284 |
+
@app.post("/image/config/update")
|
285 |
+
async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
|
286 |
+
|
287 |
+
set_image_model(form_data.MODEL)
|
288 |
+
|
289 |
+
pattern = r"^\d+x\d+$"
|
290 |
+
if re.match(pattern, form_data.IMAGE_SIZE):
|
291 |
+
app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
|
292 |
+
else:
|
293 |
+
raise HTTPException(
|
294 |
+
status_code=400,
|
295 |
+
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
|
296 |
+
)
|
297 |
+
|
298 |
+
if form_data.IMAGE_STEPS >= 0:
|
299 |
+
app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
|
300 |
+
else:
|
301 |
+
raise HTTPException(
|
302 |
+
status_code=400,
|
303 |
+
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
|
304 |
+
)
|
305 |
+
|
306 |
+
return {
|
307 |
+
"MODEL": app.state.config.MODEL,
|
308 |
+
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
|
309 |
+
"IMAGE_STEPS": app.state.config.IMAGE_STEPS,
|
310 |
+
}
|
311 |
+
|
312 |
+
|
313 |
+
@app.get("/models")
|
314 |
+
def get_models(user=Depends(get_verified_user)):
|
315 |
+
try:
|
316 |
+
if app.state.config.ENGINE == "openai":
|
317 |
+
return [
|
318 |
+
{"id": "dall-e-2", "name": "DALL·E 2"},
|
319 |
+
{"id": "dall-e-3", "name": "DALL·E 3"},
|
320 |
+
]
|
321 |
+
elif app.state.config.ENGINE == "comfyui":
|
322 |
+
# TODO - get models from comfyui
|
323 |
+
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
|
324 |
+
info = r.json()
|
325 |
+
|
326 |
+
workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
|
327 |
+
model_node_id = None
|
328 |
+
|
329 |
+
for node in app.state.config.COMFYUI_WORKFLOW_NODES:
|
330 |
+
if node["type"] == "model":
|
331 |
+
if node["node_ids"]:
|
332 |
+
model_node_id = node["node_ids"][0]
|
333 |
+
break
|
334 |
+
|
335 |
+
if model_node_id:
|
336 |
+
model_list_key = None
|
337 |
+
|
338 |
+
print(workflow[model_node_id]["class_type"])
|
339 |
+
for key in info[workflow[model_node_id]["class_type"]]["input"][
|
340 |
+
"required"
|
341 |
+
]:
|
342 |
+
if "_name" in key:
|
343 |
+
model_list_key = key
|
344 |
+
break
|
345 |
+
|
346 |
+
if model_list_key:
|
347 |
+
return list(
|
348 |
+
map(
|
349 |
+
lambda model: {"id": model, "name": model},
|
350 |
+
info[workflow[model_node_id]["class_type"]]["input"][
|
351 |
+
"required"
|
352 |
+
][model_list_key][0],
|
353 |
+
)
|
354 |
+
)
|
355 |
+
else:
|
356 |
+
return list(
|
357 |
+
map(
|
358 |
+
lambda model: {"id": model, "name": model},
|
359 |
+
info["CheckpointLoaderSimple"]["input"]["required"][
|
360 |
+
"ckpt_name"
|
361 |
+
][0],
|
362 |
+
)
|
363 |
+
)
|
364 |
+
elif (
|
365 |
+
app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
|
366 |
+
):
|
367 |
+
r = requests.get(
|
368 |
+
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
|
369 |
+
headers={"authorization": get_automatic1111_api_auth()},
|
370 |
+
)
|
371 |
+
models = r.json()
|
372 |
+
return list(
|
373 |
+
map(
|
374 |
+
lambda model: {"id": model["title"], "name": model["model_name"]},
|
375 |
+
models,
|
376 |
+
)
|
377 |
+
)
|
378 |
+
except Exception as e:
|
379 |
+
app.state.config.ENABLED = False
|
380 |
+
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
381 |
+
|
382 |
+
|
383 |
+
class GenerateImageForm(BaseModel):
|
384 |
+
model: Optional[str] = None
|
385 |
+
prompt: str
|
386 |
+
size: Optional[str] = None
|
387 |
+
n: int = 1
|
388 |
+
negative_prompt: Optional[str] = None
|
389 |
+
|
390 |
+
|
391 |
+
def save_b64_image(b64_str):
|
392 |
+
try:
|
393 |
+
image_id = str(uuid.uuid4())
|
394 |
+
|
395 |
+
if "," in b64_str:
|
396 |
+
header, encoded = b64_str.split(",", 1)
|
397 |
+
mime_type = header.split(";")[0]
|
398 |
+
|
399 |
+
img_data = base64.b64decode(encoded)
|
400 |
+
image_format = mimetypes.guess_extension(mime_type)
|
401 |
+
|
402 |
+
image_filename = f"{image_id}{image_format}"
|
403 |
+
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
|
404 |
+
with open(file_path, "wb") as f:
|
405 |
+
f.write(img_data)
|
406 |
+
return image_filename
|
407 |
+
else:
|
408 |
+
image_filename = f"{image_id}.png"
|
409 |
+
file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
|
410 |
+
|
411 |
+
img_data = base64.b64decode(b64_str)
|
412 |
+
|
413 |
+
# Write the image data to a file
|
414 |
+
with open(file_path, "wb") as f:
|
415 |
+
f.write(img_data)
|
416 |
+
return image_filename
|
417 |
+
|
418 |
+
except Exception as e:
|
419 |
+
log.exception(f"Error saving image: {e}")
|
420 |
+
return None
|
421 |
+
|
422 |
+
|
423 |
+
def save_url_image(url):
|
424 |
+
image_id = str(uuid.uuid4())
|
425 |
+
try:
|
426 |
+
r = requests.get(url)
|
427 |
+
r.raise_for_status()
|
428 |
+
if r.headers["content-type"].split("/")[0] == "image":
|
429 |
+
mime_type = r.headers["content-type"]
|
430 |
+
image_format = mimetypes.guess_extension(mime_type)
|
431 |
+
|
432 |
+
if not image_format:
|
433 |
+
raise ValueError("Could not determine image type from MIME type")
|
434 |
+
|
435 |
+
image_filename = f"{image_id}{image_format}"
|
436 |
+
|
437 |
+
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
|
438 |
+
with open(file_path, "wb") as image_file:
|
439 |
+
for chunk in r.iter_content(chunk_size=8192):
|
440 |
+
image_file.write(chunk)
|
441 |
+
return image_filename
|
442 |
+
else:
|
443 |
+
log.error("Url does not point to an image.")
|
444 |
+
return None
|
445 |
+
|
446 |
+
except Exception as e:
|
447 |
+
log.exception(f"Error saving image: {e}")
|
448 |
+
return None
|
449 |
+
|
450 |
+
|
451 |
+
@app.post("/generations")
|
452 |
+
async def image_generations(
|
453 |
+
form_data: GenerateImageForm,
|
454 |
+
user=Depends(get_verified_user),
|
455 |
+
):
|
456 |
+
width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
|
457 |
+
|
458 |
+
r = None
|
459 |
+
try:
|
460 |
+
if app.state.config.ENGINE == "openai":
|
461 |
+
headers = {}
|
462 |
+
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
|
463 |
+
headers["Content-Type"] = "application/json"
|
464 |
+
|
465 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
466 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
467 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
468 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
469 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
470 |
+
|
471 |
+
data = {
|
472 |
+
"model": (
|
473 |
+
app.state.config.MODEL
|
474 |
+
if app.state.config.MODEL != ""
|
475 |
+
else "dall-e-2"
|
476 |
+
),
|
477 |
+
"prompt": form_data.prompt,
|
478 |
+
"n": form_data.n,
|
479 |
+
"size": (
|
480 |
+
form_data.size if form_data.size else app.state.config.IMAGE_SIZE
|
481 |
+
),
|
482 |
+
"response_format": "b64_json",
|
483 |
+
}
|
484 |
+
|
485 |
+
# Use asyncio.to_thread for the requests.post call
|
486 |
+
r = await asyncio.to_thread(
|
487 |
+
requests.post,
|
488 |
+
url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
|
489 |
+
json=data,
|
490 |
+
headers=headers,
|
491 |
+
)
|
492 |
+
|
493 |
+
r.raise_for_status()
|
494 |
+
res = r.json()
|
495 |
+
|
496 |
+
images = []
|
497 |
+
|
498 |
+
for image in res["data"]:
|
499 |
+
image_filename = save_b64_image(image["b64_json"])
|
500 |
+
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
501 |
+
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
502 |
+
|
503 |
+
with open(file_body_path, "w") as f:
|
504 |
+
json.dump(data, f)
|
505 |
+
|
506 |
+
return images
|
507 |
+
|
508 |
+
elif app.state.config.ENGINE == "comfyui":
|
509 |
+
data = {
|
510 |
+
"prompt": form_data.prompt,
|
511 |
+
"width": width,
|
512 |
+
"height": height,
|
513 |
+
"n": form_data.n,
|
514 |
+
}
|
515 |
+
|
516 |
+
if app.state.config.IMAGE_STEPS is not None:
|
517 |
+
data["steps"] = app.state.config.IMAGE_STEPS
|
518 |
+
|
519 |
+
if form_data.negative_prompt is not None:
|
520 |
+
data["negative_prompt"] = form_data.negative_prompt
|
521 |
+
|
522 |
+
form_data = ComfyUIGenerateImageForm(
|
523 |
+
**{
|
524 |
+
"workflow": ComfyUIWorkflow(
|
525 |
+
**{
|
526 |
+
"workflow": app.state.config.COMFYUI_WORKFLOW,
|
527 |
+
"nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
|
528 |
+
}
|
529 |
+
),
|
530 |
+
**data,
|
531 |
+
}
|
532 |
+
)
|
533 |
+
res = await comfyui_generate_image(
|
534 |
+
app.state.config.MODEL,
|
535 |
+
form_data,
|
536 |
+
user.id,
|
537 |
+
app.state.config.COMFYUI_BASE_URL,
|
538 |
+
)
|
539 |
+
log.debug(f"res: {res}")
|
540 |
+
|
541 |
+
images = []
|
542 |
+
|
543 |
+
for image in res["data"]:
|
544 |
+
image_filename = save_url_image(image["url"])
|
545 |
+
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
546 |
+
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
547 |
+
|
548 |
+
with open(file_body_path, "w") as f:
|
549 |
+
json.dump(form_data.model_dump(exclude_none=True), f)
|
550 |
+
|
551 |
+
log.debug(f"images: {images}")
|
552 |
+
return images
|
553 |
+
elif (
|
554 |
+
app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
|
555 |
+
):
|
556 |
+
if form_data.model:
|
557 |
+
set_image_model(form_data.model)
|
558 |
+
|
559 |
+
data = {
|
560 |
+
"prompt": form_data.prompt,
|
561 |
+
"batch_size": form_data.n,
|
562 |
+
"width": width,
|
563 |
+
"height": height,
|
564 |
+
}
|
565 |
+
|
566 |
+
if app.state.config.IMAGE_STEPS is not None:
|
567 |
+
data["steps"] = app.state.config.IMAGE_STEPS
|
568 |
+
|
569 |
+
if form_data.negative_prompt is not None:
|
570 |
+
data["negative_prompt"] = form_data.negative_prompt
|
571 |
+
|
572 |
+
if app.state.config.AUTOMATIC1111_CFG_SCALE:
|
573 |
+
data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
|
574 |
+
|
575 |
+
if app.state.config.AUTOMATIC1111_SAMPLER:
|
576 |
+
data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
|
577 |
+
|
578 |
+
if app.state.config.AUTOMATIC1111_SCHEDULER:
|
579 |
+
data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
|
580 |
+
|
581 |
+
# Use asyncio.to_thread for the requests.post call
|
582 |
+
r = await asyncio.to_thread(
|
583 |
+
requests.post,
|
584 |
+
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
|
585 |
+
json=data,
|
586 |
+
headers={"authorization": get_automatic1111_api_auth()},
|
587 |
+
)
|
588 |
+
|
589 |
+
res = r.json()
|
590 |
+
log.debug(f"res: {res}")
|
591 |
+
|
592 |
+
images = []
|
593 |
+
|
594 |
+
for image in res["images"]:
|
595 |
+
image_filename = save_b64_image(image)
|
596 |
+
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
597 |
+
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
598 |
+
|
599 |
+
with open(file_body_path, "w") as f:
|
600 |
+
json.dump({**data, "info": res["info"]}, f)
|
601 |
+
|
602 |
+
return images
|
603 |
+
except Exception as e:
|
604 |
+
error = e
|
605 |
+
if r != None:
|
606 |
+
data = r.json()
|
607 |
+
if "error" in data:
|
608 |
+
error = data["error"]["message"]
|
609 |
+
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
|
backend/open_webui/apps/images/utils/comfyui.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import random
|
5 |
+
import urllib.parse
|
6 |
+
import urllib.request
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
10 |
+
from open_webui.env import SRC_LOG_LEVELS
|
11 |
+
from pydantic import BaseModel
|
12 |
+
|
13 |
+
log = logging.getLogger(__name__)
|
14 |
+
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
|
15 |
+
|
16 |
+
default_headers = {"User-Agent": "Mozilla/5.0"}
|
17 |
+
|
18 |
+
|
19 |
+
def queue_prompt(prompt, client_id, base_url):
|
20 |
+
log.info("queue_prompt")
|
21 |
+
p = {"prompt": prompt, "client_id": client_id}
|
22 |
+
data = json.dumps(p).encode("utf-8")
|
23 |
+
log.debug(f"queue_prompt data: {data}")
|
24 |
+
try:
|
25 |
+
req = urllib.request.Request(
|
26 |
+
f"{base_url}/prompt", data=data, headers=default_headers
|
27 |
+
)
|
28 |
+
response = urllib.request.urlopen(req).read()
|
29 |
+
return json.loads(response)
|
30 |
+
except Exception as e:
|
31 |
+
log.exception(f"Error while queuing prompt: {e}")
|
32 |
+
raise e
|
33 |
+
|
34 |
+
|
35 |
+
def get_image(filename, subfolder, folder_type, base_url):
|
36 |
+
log.info("get_image")
|
37 |
+
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
38 |
+
url_values = urllib.parse.urlencode(data)
|
39 |
+
req = urllib.request.Request(
|
40 |
+
f"{base_url}/view?{url_values}", headers=default_headers
|
41 |
+
)
|
42 |
+
with urllib.request.urlopen(req) as response:
|
43 |
+
return response.read()
|
44 |
+
|
45 |
+
|
46 |
+
def get_image_url(filename, subfolder, folder_type, base_url):
|
47 |
+
log.info("get_image")
|
48 |
+
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
49 |
+
url_values = urllib.parse.urlencode(data)
|
50 |
+
return f"{base_url}/view?{url_values}"
|
51 |
+
|
52 |
+
|
53 |
+
def get_history(prompt_id, base_url):
|
54 |
+
log.info("get_history")
|
55 |
+
|
56 |
+
req = urllib.request.Request(
|
57 |
+
f"{base_url}/history/{prompt_id}", headers=default_headers
|
58 |
+
)
|
59 |
+
with urllib.request.urlopen(req) as response:
|
60 |
+
return json.loads(response.read())
|
61 |
+
|
62 |
+
|
63 |
+
def get_images(ws, prompt, client_id, base_url):
|
64 |
+
prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
|
65 |
+
output_images = []
|
66 |
+
while True:
|
67 |
+
out = ws.recv()
|
68 |
+
if isinstance(out, str):
|
69 |
+
message = json.loads(out)
|
70 |
+
if message["type"] == "executing":
|
71 |
+
data = message["data"]
|
72 |
+
if data["node"] is None and data["prompt_id"] == prompt_id:
|
73 |
+
break # Execution is done
|
74 |
+
else:
|
75 |
+
continue # previews are binary data
|
76 |
+
|
77 |
+
history = get_history(prompt_id, base_url)[prompt_id]
|
78 |
+
for o in history["outputs"]:
|
79 |
+
for node_id in history["outputs"]:
|
80 |
+
node_output = history["outputs"][node_id]
|
81 |
+
if "images" in node_output:
|
82 |
+
for image in node_output["images"]:
|
83 |
+
url = get_image_url(
|
84 |
+
image["filename"], image["subfolder"], image["type"], base_url
|
85 |
+
)
|
86 |
+
output_images.append({"url": url})
|
87 |
+
return {"data": output_images}
|
88 |
+
|
89 |
+
|
90 |
+
class ComfyUINodeInput(BaseModel):
|
91 |
+
type: Optional[str] = None
|
92 |
+
node_ids: list[str] = []
|
93 |
+
key: Optional[str] = "text"
|
94 |
+
value: Optional[str] = None
|
95 |
+
|
96 |
+
|
97 |
+
class ComfyUIWorkflow(BaseModel):
|
98 |
+
workflow: str
|
99 |
+
nodes: list[ComfyUINodeInput]
|
100 |
+
|
101 |
+
|
102 |
+
class ComfyUIGenerateImageForm(BaseModel):
|
103 |
+
workflow: ComfyUIWorkflow
|
104 |
+
|
105 |
+
prompt: str
|
106 |
+
negative_prompt: Optional[str] = None
|
107 |
+
width: int
|
108 |
+
height: int
|
109 |
+
n: int = 1
|
110 |
+
|
111 |
+
steps: Optional[int] = None
|
112 |
+
seed: Optional[int] = None
|
113 |
+
|
114 |
+
|
115 |
+
async def comfyui_generate_image(
|
116 |
+
model: str, payload: ComfyUIGenerateImageForm, client_id, base_url
|
117 |
+
):
|
118 |
+
ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
|
119 |
+
workflow = json.loads(payload.workflow.workflow)
|
120 |
+
|
121 |
+
for node in payload.workflow.nodes:
|
122 |
+
if node.type:
|
123 |
+
if node.type == "model":
|
124 |
+
for node_id in node.node_ids:
|
125 |
+
workflow[node_id]["inputs"][node.key] = model
|
126 |
+
elif node.type == "prompt":
|
127 |
+
for node_id in node.node_ids:
|
128 |
+
workflow[node_id]["inputs"][
|
129 |
+
node.key if node.key else "text"
|
130 |
+
] = payload.prompt
|
131 |
+
elif node.type == "negative_prompt":
|
132 |
+
for node_id in node.node_ids:
|
133 |
+
workflow[node_id]["inputs"][
|
134 |
+
node.key if node.key else "text"
|
135 |
+
] = payload.negative_prompt
|
136 |
+
elif node.type == "width":
|
137 |
+
for node_id in node.node_ids:
|
138 |
+
workflow[node_id]["inputs"][
|
139 |
+
node.key if node.key else "width"
|
140 |
+
] = payload.width
|
141 |
+
elif node.type == "height":
|
142 |
+
for node_id in node.node_ids:
|
143 |
+
workflow[node_id]["inputs"][
|
144 |
+
node.key if node.key else "height"
|
145 |
+
] = payload.height
|
146 |
+
elif node.type == "n":
|
147 |
+
for node_id in node.node_ids:
|
148 |
+
workflow[node_id]["inputs"][
|
149 |
+
node.key if node.key else "batch_size"
|
150 |
+
] = payload.n
|
151 |
+
elif node.type == "steps":
|
152 |
+
for node_id in node.node_ids:
|
153 |
+
workflow[node_id]["inputs"][
|
154 |
+
node.key if node.key else "steps"
|
155 |
+
] = payload.steps
|
156 |
+
elif node.type == "seed":
|
157 |
+
seed = (
|
158 |
+
payload.seed
|
159 |
+
if payload.seed
|
160 |
+
else random.randint(0, 18446744073709551614)
|
161 |
+
)
|
162 |
+
for node_id in node.node_ids:
|
163 |
+
workflow[node_id]["inputs"][node.key] = seed
|
164 |
+
else:
|
165 |
+
for node_id in node.node_ids:
|
166 |
+
workflow[node_id]["inputs"][node.key] = node.value
|
167 |
+
|
168 |
+
try:
|
169 |
+
ws = websocket.WebSocket()
|
170 |
+
ws.connect(f"{ws_url}/ws?clientId={client_id}")
|
171 |
+
log.info("WebSocket connection established.")
|
172 |
+
except Exception as e:
|
173 |
+
log.exception(f"Failed to connect to WebSocket server: {e}")
|
174 |
+
return None
|
175 |
+
|
176 |
+
try:
|
177 |
+
log.info("Sending workflow to WebSocket server.")
|
178 |
+
log.info(f"Workflow: {workflow}")
|
179 |
+
images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url)
|
180 |
+
except Exception as e:
|
181 |
+
log.exception(f"Error while receiving images: {e}")
|
182 |
+
images = None
|
183 |
+
|
184 |
+
ws.close()
|
185 |
+
|
186 |
+
return images
|
backend/open_webui/apps/ollama/main.py
ADDED
@@ -0,0 +1,1351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import re
|
7 |
+
import time
|
8 |
+
from typing import Optional, Union
|
9 |
+
from urllib.parse import urlparse
|
10 |
+
|
11 |
+
import aiohttp
|
12 |
+
from aiocache import cached
|
13 |
+
|
14 |
+
import requests
|
15 |
+
from open_webui.apps.webui.models.models import Models
|
16 |
+
from open_webui.config import (
|
17 |
+
CORS_ALLOW_ORIGIN,
|
18 |
+
ENABLE_OLLAMA_API,
|
19 |
+
OLLAMA_BASE_URLS,
|
20 |
+
OLLAMA_API_CONFIGS,
|
21 |
+
UPLOAD_DIR,
|
22 |
+
AppConfig,
|
23 |
+
)
|
24 |
+
from open_webui.env import (
|
25 |
+
AIOHTTP_CLIENT_TIMEOUT,
|
26 |
+
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
27 |
+
BYPASS_MODEL_ACCESS_CONTROL,
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
from open_webui.constants import ERROR_MESSAGES
|
32 |
+
from open_webui.env import ENV, SRC_LOG_LEVELS
|
33 |
+
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
|
34 |
+
from fastapi.middleware.cors import CORSMiddleware
|
35 |
+
from fastapi.responses import StreamingResponse
|
36 |
+
from pydantic import BaseModel, ConfigDict
|
37 |
+
from starlette.background import BackgroundTask
|
38 |
+
|
39 |
+
|
40 |
+
from open_webui.utils.misc import (
|
41 |
+
calculate_sha256,
|
42 |
+
)
|
43 |
+
from open_webui.utils.payload import (
|
44 |
+
apply_model_params_to_body_ollama,
|
45 |
+
apply_model_params_to_body_openai,
|
46 |
+
apply_model_system_prompt_to_body,
|
47 |
+
)
|
48 |
+
from open_webui.utils.utils import get_admin_user, get_verified_user
|
49 |
+
from open_webui.utils.access_control import has_access
|
50 |
+
|
51 |
+
log = logging.getLogger(__name__)
|
52 |
+
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
53 |
+
|
54 |
+
|
55 |
+
app = FastAPI(
|
56 |
+
docs_url="/docs" if ENV == "dev" else None,
|
57 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
58 |
+
redoc_url=None,
|
59 |
+
)
|
60 |
+
|
61 |
+
app.add_middleware(
|
62 |
+
CORSMiddleware,
|
63 |
+
allow_origins=CORS_ALLOW_ORIGIN,
|
64 |
+
allow_credentials=True,
|
65 |
+
allow_methods=["*"],
|
66 |
+
allow_headers=["*"],
|
67 |
+
)
|
68 |
+
|
69 |
+
app.state.config = AppConfig()
|
70 |
+
|
71 |
+
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
72 |
+
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
73 |
+
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
|
74 |
+
|
75 |
+
|
76 |
+
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
|
77 |
+
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
|
78 |
+
# least connections, or least response time for better resource utilization and performance optimization.
|
79 |
+
|
80 |
+
|
81 |
+
@app.head("/")
|
82 |
+
@app.get("/")
|
83 |
+
async def get_status():
|
84 |
+
return {"status": True}
|
85 |
+
|
86 |
+
|
87 |
+
class ConnectionVerificationForm(BaseModel):
|
88 |
+
url: str
|
89 |
+
key: Optional[str] = None
|
90 |
+
|
91 |
+
|
92 |
+
@app.post("/verify")
|
93 |
+
async def verify_connection(
|
94 |
+
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
|
95 |
+
):
|
96 |
+
url = form_data.url
|
97 |
+
key = form_data.key
|
98 |
+
|
99 |
+
headers = {}
|
100 |
+
if key:
|
101 |
+
headers["Authorization"] = f"Bearer {key}"
|
102 |
+
|
103 |
+
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
104 |
+
async with aiohttp.ClientSession(timeout=timeout) as session:
|
105 |
+
try:
|
106 |
+
async with session.get(f"{url}/api/version", headers=headers) as r:
|
107 |
+
if r.status != 200:
|
108 |
+
# Extract response error details if available
|
109 |
+
error_detail = f"HTTP Error: {r.status}"
|
110 |
+
res = await r.json()
|
111 |
+
if "error" in res:
|
112 |
+
error_detail = f"External Error: {res['error']}"
|
113 |
+
raise Exception(error_detail)
|
114 |
+
|
115 |
+
response_data = await r.json()
|
116 |
+
return response_data
|
117 |
+
|
118 |
+
except aiohttp.ClientError as e:
|
119 |
+
# ClientError covers all aiohttp requests issues
|
120 |
+
log.exception(f"Client error: {str(e)}")
|
121 |
+
# Handle aiohttp-specific connection issues, timeout etc.
|
122 |
+
raise HTTPException(
|
123 |
+
status_code=500, detail="Open WebUI: Server Connection Error"
|
124 |
+
)
|
125 |
+
except Exception as e:
|
126 |
+
log.exception(f"Unexpected error: {e}")
|
127 |
+
# Generic error handler in case parsing JSON or other steps fail
|
128 |
+
error_detail = f"Unexpected error: {str(e)}"
|
129 |
+
raise HTTPException(status_code=500, detail=error_detail)
|
130 |
+
|
131 |
+
|
132 |
+
@app.get("/config")
|
133 |
+
async def get_config(user=Depends(get_admin_user)):
|
134 |
+
return {
|
135 |
+
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
136 |
+
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
137 |
+
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
|
138 |
+
}
|
139 |
+
|
140 |
+
|
141 |
+
class OllamaConfigForm(BaseModel):
|
142 |
+
ENABLE_OLLAMA_API: Optional[bool] = None
|
143 |
+
OLLAMA_BASE_URLS: list[str]
|
144 |
+
OLLAMA_API_CONFIGS: dict
|
145 |
+
|
146 |
+
|
147 |
+
@app.post("/config/update")
|
148 |
+
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
|
149 |
+
app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
|
150 |
+
app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
|
151 |
+
|
152 |
+
app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
|
153 |
+
|
154 |
+
# Remove any extra configs
|
155 |
+
config_urls = app.state.config.OLLAMA_API_CONFIGS.keys()
|
156 |
+
for url in list(app.state.config.OLLAMA_BASE_URLS):
|
157 |
+
if url not in config_urls:
|
158 |
+
app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
|
159 |
+
|
160 |
+
return {
|
161 |
+
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
162 |
+
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
163 |
+
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
|
164 |
+
}
|
165 |
+
|
166 |
+
|
167 |
+
async def aiohttp_get(url, key=None):
|
168 |
+
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
169 |
+
try:
|
170 |
+
headers = {"Authorization": f"Bearer {key}"} if key else {}
|
171 |
+
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
172 |
+
async with session.get(url, headers=headers) as response:
|
173 |
+
return await response.json()
|
174 |
+
except Exception as e:
|
175 |
+
# Handle connection error here
|
176 |
+
log.error(f"Connection error: {e}")
|
177 |
+
return None
|
178 |
+
|
179 |
+
|
180 |
+
async def cleanup_response(
|
181 |
+
response: Optional[aiohttp.ClientResponse],
|
182 |
+
session: Optional[aiohttp.ClientSession],
|
183 |
+
):
|
184 |
+
if response:
|
185 |
+
response.close()
|
186 |
+
if session:
|
187 |
+
await session.close()
|
188 |
+
|
189 |
+
|
190 |
+
async def post_streaming_url(
|
191 |
+
url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
|
192 |
+
):
|
193 |
+
r = None
|
194 |
+
try:
|
195 |
+
session = aiohttp.ClientSession(
|
196 |
+
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
197 |
+
)
|
198 |
+
|
199 |
+
parsed_url = urlparse(url)
|
200 |
+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
201 |
+
|
202 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
203 |
+
key = api_config.get("key", None)
|
204 |
+
|
205 |
+
headers = {"Content-Type": "application/json"}
|
206 |
+
if key:
|
207 |
+
headers["Authorization"] = f"Bearer {key}"
|
208 |
+
|
209 |
+
r = await session.post(
|
210 |
+
url,
|
211 |
+
data=payload,
|
212 |
+
headers=headers,
|
213 |
+
)
|
214 |
+
r.raise_for_status()
|
215 |
+
|
216 |
+
if stream:
|
217 |
+
response_headers = dict(r.headers)
|
218 |
+
if content_type:
|
219 |
+
response_headers["Content-Type"] = content_type
|
220 |
+
return StreamingResponse(
|
221 |
+
r.content,
|
222 |
+
status_code=r.status,
|
223 |
+
headers=response_headers,
|
224 |
+
background=BackgroundTask(
|
225 |
+
cleanup_response, response=r, session=session
|
226 |
+
),
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
res = await r.json()
|
230 |
+
await cleanup_response(r, session)
|
231 |
+
return res
|
232 |
+
|
233 |
+
except Exception as e:
|
234 |
+
error_detail = "Open WebUI: Server Connection Error"
|
235 |
+
if r is not None:
|
236 |
+
try:
|
237 |
+
res = await r.json()
|
238 |
+
if "error" in res:
|
239 |
+
error_detail = f"Ollama: {res['error']}"
|
240 |
+
except Exception:
|
241 |
+
error_detail = f"Ollama: {e}"
|
242 |
+
|
243 |
+
raise HTTPException(
|
244 |
+
status_code=r.status if r else 500,
|
245 |
+
detail=error_detail,
|
246 |
+
)
|
247 |
+
|
248 |
+
|
249 |
+
def merge_models_lists(model_lists):
|
250 |
+
merged_models = {}
|
251 |
+
|
252 |
+
for idx, model_list in enumerate(model_lists):
|
253 |
+
if model_list is not None:
|
254 |
+
for model in model_list:
|
255 |
+
id = model["model"]
|
256 |
+
if id not in merged_models:
|
257 |
+
model["urls"] = [idx]
|
258 |
+
merged_models[id] = model
|
259 |
+
else:
|
260 |
+
merged_models[id]["urls"].append(idx)
|
261 |
+
|
262 |
+
return list(merged_models.values())
|
263 |
+
|
264 |
+
|
265 |
+
@cached(ttl=3)
|
266 |
+
async def get_all_models():
|
267 |
+
log.info("get_all_models()")
|
268 |
+
if app.state.config.ENABLE_OLLAMA_API:
|
269 |
+
tasks = []
|
270 |
+
for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
|
271 |
+
if url not in app.state.config.OLLAMA_API_CONFIGS:
|
272 |
+
tasks.append(aiohttp_get(f"{url}/api/tags"))
|
273 |
+
else:
|
274 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
275 |
+
enable = api_config.get("enable", True)
|
276 |
+
key = api_config.get("key", None)
|
277 |
+
|
278 |
+
if enable:
|
279 |
+
tasks.append(aiohttp_get(f"{url}/api/tags", key))
|
280 |
+
else:
|
281 |
+
tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
282 |
+
|
283 |
+
responses = await asyncio.gather(*tasks)
|
284 |
+
|
285 |
+
for idx, response in enumerate(responses):
|
286 |
+
if response:
|
287 |
+
url = app.state.config.OLLAMA_BASE_URLS[idx]
|
288 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
289 |
+
|
290 |
+
prefix_id = api_config.get("prefix_id", None)
|
291 |
+
model_ids = api_config.get("model_ids", [])
|
292 |
+
|
293 |
+
if len(model_ids) != 0 and "models" in response:
|
294 |
+
response["models"] = list(
|
295 |
+
filter(
|
296 |
+
lambda model: model["model"] in model_ids,
|
297 |
+
response["models"],
|
298 |
+
)
|
299 |
+
)
|
300 |
+
|
301 |
+
if prefix_id:
|
302 |
+
for model in response.get("models", []):
|
303 |
+
model["model"] = f"{prefix_id}.{model['model']}"
|
304 |
+
|
305 |
+
models = {
|
306 |
+
"models": merge_models_lists(
|
307 |
+
map(
|
308 |
+
lambda response: response.get("models", []) if response else None,
|
309 |
+
responses,
|
310 |
+
)
|
311 |
+
)
|
312 |
+
}
|
313 |
+
|
314 |
+
else:
|
315 |
+
models = {"models": []}
|
316 |
+
|
317 |
+
return models
|
318 |
+
|
319 |
+
|
320 |
+
@app.get("/api/tags")
|
321 |
+
@app.get("/api/tags/{url_idx}")
|
322 |
+
async def get_ollama_tags(
|
323 |
+
url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
324 |
+
):
|
325 |
+
models = []
|
326 |
+
if url_idx is None:
|
327 |
+
models = await get_all_models()
|
328 |
+
else:
|
329 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
330 |
+
|
331 |
+
parsed_url = urlparse(url)
|
332 |
+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
333 |
+
|
334 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
335 |
+
key = api_config.get("key", None)
|
336 |
+
|
337 |
+
headers = {}
|
338 |
+
if key:
|
339 |
+
headers["Authorization"] = f"Bearer {key}"
|
340 |
+
|
341 |
+
r = None
|
342 |
+
try:
|
343 |
+
r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
|
344 |
+
r.raise_for_status()
|
345 |
+
|
346 |
+
models = r.json()
|
347 |
+
except Exception as e:
|
348 |
+
log.exception(e)
|
349 |
+
error_detail = "Open WebUI: Server Connection Error"
|
350 |
+
if r is not None:
|
351 |
+
try:
|
352 |
+
res = r.json()
|
353 |
+
if "error" in res:
|
354 |
+
error_detail = f"Ollama: {res['error']}"
|
355 |
+
except Exception:
|
356 |
+
error_detail = f"Ollama: {e}"
|
357 |
+
|
358 |
+
raise HTTPException(
|
359 |
+
status_code=r.status_code if r else 500,
|
360 |
+
detail=error_detail,
|
361 |
+
)
|
362 |
+
|
363 |
+
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
364 |
+
# Filter models based on user access control
|
365 |
+
filtered_models = []
|
366 |
+
for model in models.get("models", []):
|
367 |
+
model_info = Models.get_model_by_id(model["model"])
|
368 |
+
if model_info:
|
369 |
+
if user.id == model_info.user_id or has_access(
|
370 |
+
user.id, type="read", access_control=model_info.access_control
|
371 |
+
):
|
372 |
+
filtered_models.append(model)
|
373 |
+
models["models"] = filtered_models
|
374 |
+
|
375 |
+
return models
|
376 |
+
|
377 |
+
|
378 |
+
@app.get("/api/version")
|
379 |
+
@app.get("/api/version/{url_idx}")
|
380 |
+
async def get_ollama_versions(url_idx: Optional[int] = None):
|
381 |
+
if app.state.config.ENABLE_OLLAMA_API:
|
382 |
+
if url_idx is None:
|
383 |
+
# returns lowest version
|
384 |
+
tasks = [
|
385 |
+
aiohttp_get(
|
386 |
+
f"{url}/api/version",
|
387 |
+
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
|
388 |
+
)
|
389 |
+
for url in app.state.config.OLLAMA_BASE_URLS
|
390 |
+
]
|
391 |
+
responses = await asyncio.gather(*tasks)
|
392 |
+
responses = list(filter(lambda x: x is not None, responses))
|
393 |
+
|
394 |
+
if len(responses) > 0:
|
395 |
+
lowest_version = min(
|
396 |
+
responses,
|
397 |
+
key=lambda x: tuple(
|
398 |
+
map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
|
399 |
+
),
|
400 |
+
)
|
401 |
+
|
402 |
+
return {"version": lowest_version["version"]}
|
403 |
+
else:
|
404 |
+
raise HTTPException(
|
405 |
+
status_code=500,
|
406 |
+
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
|
407 |
+
)
|
408 |
+
else:
|
409 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
410 |
+
|
411 |
+
r = None
|
412 |
+
try:
|
413 |
+
r = requests.request(method="GET", url=f"{url}/api/version")
|
414 |
+
r.raise_for_status()
|
415 |
+
|
416 |
+
return r.json()
|
417 |
+
except Exception as e:
|
418 |
+
log.exception(e)
|
419 |
+
error_detail = "Open WebUI: Server Connection Error"
|
420 |
+
if r is not None:
|
421 |
+
try:
|
422 |
+
res = r.json()
|
423 |
+
if "error" in res:
|
424 |
+
error_detail = f"Ollama: {res['error']}"
|
425 |
+
except Exception:
|
426 |
+
error_detail = f"Ollama: {e}"
|
427 |
+
|
428 |
+
raise HTTPException(
|
429 |
+
status_code=r.status_code if r else 500,
|
430 |
+
detail=error_detail,
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
return {"version": False}
|
434 |
+
|
435 |
+
|
436 |
+
class ModelNameForm(BaseModel):
|
437 |
+
name: str
|
438 |
+
|
439 |
+
|
440 |
+
@app.post("/api/pull")
|
441 |
+
@app.post("/api/pull/{url_idx}")
|
442 |
+
async def pull_model(
|
443 |
+
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
|
444 |
+
):
|
445 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
446 |
+
log.info(f"url: {url}")
|
447 |
+
|
448 |
+
# Admin should be able to pull models from any source
|
449 |
+
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
|
450 |
+
|
451 |
+
return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
|
452 |
+
|
453 |
+
|
454 |
+
class PushModelForm(BaseModel):
|
455 |
+
name: str
|
456 |
+
insecure: Optional[bool] = None
|
457 |
+
stream: Optional[bool] = None
|
458 |
+
|
459 |
+
|
460 |
+
@app.delete("/api/push")
|
461 |
+
@app.delete("/api/push/{url_idx}")
|
462 |
+
async def push_model(
|
463 |
+
form_data: PushModelForm,
|
464 |
+
url_idx: Optional[int] = None,
|
465 |
+
user=Depends(get_admin_user),
|
466 |
+
):
|
467 |
+
if url_idx is None:
|
468 |
+
model_list = await get_all_models()
|
469 |
+
models = {model["model"]: model for model in model_list["models"]}
|
470 |
+
|
471 |
+
if form_data.name in models:
|
472 |
+
url_idx = models[form_data.name]["urls"][0]
|
473 |
+
else:
|
474 |
+
raise HTTPException(
|
475 |
+
status_code=400,
|
476 |
+
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
477 |
+
)
|
478 |
+
|
479 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
480 |
+
log.debug(f"url: {url}")
|
481 |
+
|
482 |
+
return await post_streaming_url(
|
483 |
+
f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
|
484 |
+
)
|
485 |
+
|
486 |
+
|
487 |
+
class CreateModelForm(BaseModel):
|
488 |
+
name: str
|
489 |
+
modelfile: Optional[str] = None
|
490 |
+
stream: Optional[bool] = None
|
491 |
+
path: Optional[str] = None
|
492 |
+
|
493 |
+
|
494 |
+
@app.post("/api/create")
|
495 |
+
@app.post("/api/create/{url_idx}")
|
496 |
+
async def create_model(
|
497 |
+
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
|
498 |
+
):
|
499 |
+
log.debug(f"form_data: {form_data}")
|
500 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
501 |
+
log.info(f"url: {url}")
|
502 |
+
|
503 |
+
return await post_streaming_url(
|
504 |
+
f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
|
505 |
+
)
|
506 |
+
|
507 |
+
|
508 |
+
class CopyModelForm(BaseModel):
|
509 |
+
source: str
|
510 |
+
destination: str
|
511 |
+
|
512 |
+
|
513 |
+
@app.post("/api/copy")
|
514 |
+
@app.post("/api/copy/{url_idx}")
|
515 |
+
async def copy_model(
|
516 |
+
form_data: CopyModelForm,
|
517 |
+
url_idx: Optional[int] = None,
|
518 |
+
user=Depends(get_admin_user),
|
519 |
+
):
|
520 |
+
if url_idx is None:
|
521 |
+
model_list = await get_all_models()
|
522 |
+
models = {model["model"]: model for model in model_list["models"]}
|
523 |
+
|
524 |
+
if form_data.source in models:
|
525 |
+
url_idx = models[form_data.source]["urls"][0]
|
526 |
+
else:
|
527 |
+
raise HTTPException(
|
528 |
+
status_code=400,
|
529 |
+
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
|
530 |
+
)
|
531 |
+
|
532 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
533 |
+
log.info(f"url: {url}")
|
534 |
+
|
535 |
+
parsed_url = urlparse(url)
|
536 |
+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
537 |
+
|
538 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
539 |
+
key = api_config.get("key", None)
|
540 |
+
|
541 |
+
headers = {"Content-Type": "application/json"}
|
542 |
+
if key:
|
543 |
+
headers["Authorization"] = f"Bearer {key}"
|
544 |
+
|
545 |
+
r = requests.request(
|
546 |
+
method="POST",
|
547 |
+
url=f"{url}/api/copy",
|
548 |
+
headers=headers,
|
549 |
+
data=form_data.model_dump_json(exclude_none=True).encode(),
|
550 |
+
)
|
551 |
+
|
552 |
+
try:
|
553 |
+
r.raise_for_status()
|
554 |
+
|
555 |
+
log.debug(f"r.text: {r.text}")
|
556 |
+
|
557 |
+
return True
|
558 |
+
except Exception as e:
|
559 |
+
log.exception(e)
|
560 |
+
error_detail = "Open WebUI: Server Connection Error"
|
561 |
+
if r is not None:
|
562 |
+
try:
|
563 |
+
res = r.json()
|
564 |
+
if "error" in res:
|
565 |
+
error_detail = f"Ollama: {res['error']}"
|
566 |
+
except Exception:
|
567 |
+
error_detail = f"Ollama: {e}"
|
568 |
+
|
569 |
+
raise HTTPException(
|
570 |
+
status_code=r.status_code if r else 500,
|
571 |
+
detail=error_detail,
|
572 |
+
)
|
573 |
+
|
574 |
+
|
575 |
+
@app.delete("/api/delete")
|
576 |
+
@app.delete("/api/delete/{url_idx}")
|
577 |
+
async def delete_model(
|
578 |
+
form_data: ModelNameForm,
|
579 |
+
url_idx: Optional[int] = None,
|
580 |
+
user=Depends(get_admin_user),
|
581 |
+
):
|
582 |
+
if url_idx is None:
|
583 |
+
model_list = await get_all_models()
|
584 |
+
models = {model["model"]: model for model in model_list["models"]}
|
585 |
+
|
586 |
+
if form_data.name in models:
|
587 |
+
url_idx = models[form_data.name]["urls"][0]
|
588 |
+
else:
|
589 |
+
raise HTTPException(
|
590 |
+
status_code=400,
|
591 |
+
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
592 |
+
)
|
593 |
+
|
594 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
595 |
+
log.info(f"url: {url}")
|
596 |
+
|
597 |
+
parsed_url = urlparse(url)
|
598 |
+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
599 |
+
|
600 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
601 |
+
key = api_config.get("key", None)
|
602 |
+
|
603 |
+
headers = {"Content-Type": "application/json"}
|
604 |
+
if key:
|
605 |
+
headers["Authorization"] = f"Bearer {key}"
|
606 |
+
|
607 |
+
r = requests.request(
|
608 |
+
method="DELETE",
|
609 |
+
url=f"{url}/api/delete",
|
610 |
+
data=form_data.model_dump_json(exclude_none=True).encode(),
|
611 |
+
headers=headers,
|
612 |
+
)
|
613 |
+
try:
|
614 |
+
r.raise_for_status()
|
615 |
+
|
616 |
+
log.debug(f"r.text: {r.text}")
|
617 |
+
|
618 |
+
return True
|
619 |
+
except Exception as e:
|
620 |
+
log.exception(e)
|
621 |
+
error_detail = "Open WebUI: Server Connection Error"
|
622 |
+
if r is not None:
|
623 |
+
try:
|
624 |
+
res = r.json()
|
625 |
+
if "error" in res:
|
626 |
+
error_detail = f"Ollama: {res['error']}"
|
627 |
+
except Exception:
|
628 |
+
error_detail = f"Ollama: {e}"
|
629 |
+
|
630 |
+
raise HTTPException(
|
631 |
+
status_code=r.status_code if r else 500,
|
632 |
+
detail=error_detail,
|
633 |
+
)
|
634 |
+
|
635 |
+
|
636 |
+
@app.post("/api/show")
|
637 |
+
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
|
638 |
+
model_list = await get_all_models()
|
639 |
+
models = {model["model"]: model for model in model_list["models"]}
|
640 |
+
|
641 |
+
if form_data.name not in models:
|
642 |
+
raise HTTPException(
|
643 |
+
status_code=400,
|
644 |
+
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
645 |
+
)
|
646 |
+
|
647 |
+
url_idx = random.choice(models[form_data.name]["urls"])
|
648 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
649 |
+
log.info(f"url: {url}")
|
650 |
+
|
651 |
+
parsed_url = urlparse(url)
|
652 |
+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
653 |
+
|
654 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
655 |
+
key = api_config.get("key", None)
|
656 |
+
|
657 |
+
headers = {"Content-Type": "application/json"}
|
658 |
+
if key:
|
659 |
+
headers["Authorization"] = f"Bearer {key}"
|
660 |
+
|
661 |
+
r = requests.request(
|
662 |
+
method="POST",
|
663 |
+
url=f"{url}/api/show",
|
664 |
+
headers=headers,
|
665 |
+
data=form_data.model_dump_json(exclude_none=True).encode(),
|
666 |
+
)
|
667 |
+
try:
|
668 |
+
r.raise_for_status()
|
669 |
+
|
670 |
+
return r.json()
|
671 |
+
except Exception as e:
|
672 |
+
log.exception(e)
|
673 |
+
error_detail = "Open WebUI: Server Connection Error"
|
674 |
+
if r is not None:
|
675 |
+
try:
|
676 |
+
res = r.json()
|
677 |
+
if "error" in res:
|
678 |
+
error_detail = f"Ollama: {res['error']}"
|
679 |
+
except Exception:
|
680 |
+
error_detail = f"Ollama: {e}"
|
681 |
+
|
682 |
+
raise HTTPException(
|
683 |
+
status_code=r.status_code if r else 500,
|
684 |
+
detail=error_detail,
|
685 |
+
)
|
686 |
+
|
687 |
+
|
688 |
+
class GenerateEmbeddingsForm(BaseModel):
|
689 |
+
model: str
|
690 |
+
prompt: str
|
691 |
+
options: Optional[dict] = None
|
692 |
+
keep_alive: Optional[Union[int, str]] = None
|
693 |
+
|
694 |
+
|
695 |
+
class GenerateEmbedForm(BaseModel):
|
696 |
+
model: str
|
697 |
+
input: list[str] | str
|
698 |
+
truncate: Optional[bool] = None
|
699 |
+
options: Optional[dict] = None
|
700 |
+
keep_alive: Optional[Union[int, str]] = None
|
701 |
+
|
702 |
+
|
703 |
+
@app.post("/api/embed")
|
704 |
+
@app.post("/api/embed/{url_idx}")
|
705 |
+
async def generate_embeddings(
|
706 |
+
form_data: GenerateEmbedForm,
|
707 |
+
url_idx: Optional[int] = None,
|
708 |
+
user=Depends(get_verified_user),
|
709 |
+
):
|
710 |
+
return await generate_ollama_batch_embeddings(form_data, url_idx)
|
711 |
+
|
712 |
+
|
713 |
+
@app.post("/api/embeddings")
|
714 |
+
@app.post("/api/embeddings/{url_idx}")
|
715 |
+
async def generate_embeddings(
|
716 |
+
form_data: GenerateEmbeddingsForm,
|
717 |
+
url_idx: Optional[int] = None,
|
718 |
+
user=Depends(get_verified_user),
|
719 |
+
):
|
720 |
+
return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
|
721 |
+
|
722 |
+
|
723 |
+
async def generate_ollama_embeddings(
|
724 |
+
form_data: GenerateEmbeddingsForm,
|
725 |
+
url_idx: Optional[int] = None,
|
726 |
+
):
|
727 |
+
log.info(f"generate_ollama_embeddings {form_data}")
|
728 |
+
|
729 |
+
if url_idx is None:
|
730 |
+
model_list = await get_all_models()
|
731 |
+
models = {model["model"]: model for model in model_list["models"]}
|
732 |
+
|
733 |
+
model = form_data.model
|
734 |
+
|
735 |
+
if ":" not in model:
|
736 |
+
model = f"{model}:latest"
|
737 |
+
|
738 |
+
if model in models:
|
739 |
+
url_idx = random.choice(models[model]["urls"])
|
740 |
+
else:
|
741 |
+
raise HTTPException(
|
742 |
+
status_code=400,
|
743 |
+
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
744 |
+
)
|
745 |
+
|
746 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
747 |
+
log.info(f"url: {url}")
|
748 |
+
|
749 |
+
parsed_url = urlparse(url)
|
750 |
+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
751 |
+
|
752 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
753 |
+
key = api_config.get("key", None)
|
754 |
+
|
755 |
+
headers = {"Content-Type": "application/json"}
|
756 |
+
if key:
|
757 |
+
headers["Authorization"] = f"Bearer {key}"
|
758 |
+
|
759 |
+
r = requests.request(
|
760 |
+
method="POST",
|
761 |
+
url=f"{url}/api/embeddings",
|
762 |
+
headers=headers,
|
763 |
+
data=form_data.model_dump_json(exclude_none=True).encode(),
|
764 |
+
)
|
765 |
+
try:
|
766 |
+
r.raise_for_status()
|
767 |
+
|
768 |
+
data = r.json()
|
769 |
+
|
770 |
+
log.info(f"generate_ollama_embeddings {data}")
|
771 |
+
|
772 |
+
if "embedding" in data:
|
773 |
+
return data
|
774 |
+
else:
|
775 |
+
raise Exception("Something went wrong :/")
|
776 |
+
except Exception as e:
|
777 |
+
log.exception(e)
|
778 |
+
error_detail = "Open WebUI: Server Connection Error"
|
779 |
+
if r is not None:
|
780 |
+
try:
|
781 |
+
res = r.json()
|
782 |
+
if "error" in res:
|
783 |
+
error_detail = f"Ollama: {res['error']}"
|
784 |
+
except Exception:
|
785 |
+
error_detail = f"Ollama: {e}"
|
786 |
+
|
787 |
+
raise HTTPException(
|
788 |
+
status_code=r.status_code if r else 500,
|
789 |
+
detail=error_detail,
|
790 |
+
)
|
791 |
+
|
792 |
+
|
793 |
+
async def generate_ollama_batch_embeddings(
|
794 |
+
form_data: GenerateEmbedForm,
|
795 |
+
url_idx: Optional[int] = None,
|
796 |
+
):
|
797 |
+
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
798 |
+
|
799 |
+
if url_idx is None:
|
800 |
+
model_list = await get_all_models()
|
801 |
+
models = {model["model"]: model for model in model_list["models"]}
|
802 |
+
|
803 |
+
model = form_data.model
|
804 |
+
|
805 |
+
if ":" not in model:
|
806 |
+
model = f"{model}:latest"
|
807 |
+
|
808 |
+
if model in models:
|
809 |
+
url_idx = random.choice(models[model]["urls"])
|
810 |
+
else:
|
811 |
+
raise HTTPException(
|
812 |
+
status_code=400,
|
813 |
+
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
814 |
+
)
|
815 |
+
|
816 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
817 |
+
log.info(f"url: {url}")
|
818 |
+
|
819 |
+
parsed_url = urlparse(url)
|
820 |
+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
821 |
+
|
822 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
823 |
+
key = api_config.get("key", None)
|
824 |
+
|
825 |
+
headers = {"Content-Type": "application/json"}
|
826 |
+
if key:
|
827 |
+
headers["Authorization"] = f"Bearer {key}"
|
828 |
+
|
829 |
+
r = requests.request(
|
830 |
+
method="POST",
|
831 |
+
url=f"{url}/api/embed",
|
832 |
+
headers=headers,
|
833 |
+
data=form_data.model_dump_json(exclude_none=True).encode(),
|
834 |
+
)
|
835 |
+
try:
|
836 |
+
r.raise_for_status()
|
837 |
+
|
838 |
+
data = r.json()
|
839 |
+
|
840 |
+
log.info(f"generate_ollama_batch_embeddings {data}")
|
841 |
+
|
842 |
+
if "embeddings" in data:
|
843 |
+
return data
|
844 |
+
else:
|
845 |
+
raise Exception("Something went wrong :/")
|
846 |
+
except Exception as e:
|
847 |
+
log.exception(e)
|
848 |
+
error_detail = "Open WebUI: Server Connection Error"
|
849 |
+
if r is not None:
|
850 |
+
try:
|
851 |
+
res = r.json()
|
852 |
+
if "error" in res:
|
853 |
+
error_detail = f"Ollama: {res['error']}"
|
854 |
+
except Exception:
|
855 |
+
error_detail = f"Ollama: {e}"
|
856 |
+
|
857 |
+
raise Exception(error_detail)
|
858 |
+
|
859 |
+
|
860 |
+
class GenerateCompletionForm(BaseModel):
|
861 |
+
model: str
|
862 |
+
prompt: str
|
863 |
+
suffix: Optional[str] = None
|
864 |
+
images: Optional[list[str]] = None
|
865 |
+
format: Optional[str] = None
|
866 |
+
options: Optional[dict] = None
|
867 |
+
system: Optional[str] = None
|
868 |
+
template: Optional[str] = None
|
869 |
+
context: Optional[list[int]] = None
|
870 |
+
stream: Optional[bool] = True
|
871 |
+
raw: Optional[bool] = None
|
872 |
+
keep_alive: Optional[Union[int, str]] = None
|
873 |
+
|
874 |
+
|
875 |
+
@app.post("/api/generate")
|
876 |
+
@app.post("/api/generate/{url_idx}")
|
877 |
+
async def generate_completion(
|
878 |
+
form_data: GenerateCompletionForm,
|
879 |
+
url_idx: Optional[int] = None,
|
880 |
+
user=Depends(get_verified_user),
|
881 |
+
):
|
882 |
+
if url_idx is None:
|
883 |
+
model_list = await get_all_models()
|
884 |
+
models = {model["model"]: model for model in model_list["models"]}
|
885 |
+
|
886 |
+
model = form_data.model
|
887 |
+
|
888 |
+
if ":" not in model:
|
889 |
+
model = f"{model}:latest"
|
890 |
+
|
891 |
+
if model in models:
|
892 |
+
url_idx = random.choice(models[model]["urls"])
|
893 |
+
else:
|
894 |
+
raise HTTPException(
|
895 |
+
status_code=400,
|
896 |
+
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
897 |
+
)
|
898 |
+
|
899 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
900 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
901 |
+
prefix_id = api_config.get("prefix_id", None)
|
902 |
+
if prefix_id:
|
903 |
+
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
|
904 |
+
log.info(f"url: {url}")
|
905 |
+
|
906 |
+
return await post_streaming_url(
|
907 |
+
f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
|
908 |
+
)
|
909 |
+
|
910 |
+
|
911 |
+
class ChatMessage(BaseModel):
|
912 |
+
role: str
|
913 |
+
content: str
|
914 |
+
images: Optional[list[str]] = None
|
915 |
+
|
916 |
+
|
917 |
+
class GenerateChatCompletionForm(BaseModel):
|
918 |
+
model: str
|
919 |
+
messages: list[ChatMessage]
|
920 |
+
format: Optional[str] = None
|
921 |
+
options: Optional[dict] = None
|
922 |
+
template: Optional[str] = None
|
923 |
+
stream: Optional[bool] = True
|
924 |
+
keep_alive: Optional[Union[int, str]] = None
|
925 |
+
|
926 |
+
|
927 |
+
async def get_ollama_url(url_idx: Optional[int], model: str):
|
928 |
+
if url_idx is None:
|
929 |
+
model_list = await get_all_models()
|
930 |
+
models = {model["model"]: model for model in model_list["models"]}
|
931 |
+
|
932 |
+
if model not in models:
|
933 |
+
raise HTTPException(
|
934 |
+
status_code=400,
|
935 |
+
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
|
936 |
+
)
|
937 |
+
url_idx = random.choice(models[model]["urls"])
|
938 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
939 |
+
return url
|
940 |
+
|
941 |
+
|
942 |
+
@app.post("/api/chat")
|
943 |
+
@app.post("/api/chat/{url_idx}")
|
944 |
+
async def generate_chat_completion(
|
945 |
+
form_data: GenerateChatCompletionForm,
|
946 |
+
url_idx: Optional[int] = None,
|
947 |
+
user=Depends(get_verified_user),
|
948 |
+
bypass_filter: Optional[bool] = False,
|
949 |
+
):
|
950 |
+
payload = {**form_data.model_dump(exclude_none=True)}
|
951 |
+
log.debug(f"generate_chat_completion() - 1.payload = {payload}")
|
952 |
+
if "metadata" in payload:
|
953 |
+
del payload["metadata"]
|
954 |
+
|
955 |
+
model_id = payload["model"]
|
956 |
+
model_info = Models.get_model_by_id(model_id)
|
957 |
+
|
958 |
+
if model_info:
|
959 |
+
if model_info.base_model_id:
|
960 |
+
payload["model"] = model_info.base_model_id
|
961 |
+
|
962 |
+
params = model_info.params.model_dump()
|
963 |
+
|
964 |
+
if params:
|
965 |
+
if payload.get("options") is None:
|
966 |
+
payload["options"] = {}
|
967 |
+
|
968 |
+
payload["options"] = apply_model_params_to_body_ollama(
|
969 |
+
params, payload["options"]
|
970 |
+
)
|
971 |
+
payload = apply_model_system_prompt_to_body(params, payload, user)
|
972 |
+
|
973 |
+
# Check if user has access to the model
|
974 |
+
if not bypass_filter and user.role == "user":
|
975 |
+
if not (
|
976 |
+
user.id == model_info.user_id
|
977 |
+
or has_access(
|
978 |
+
user.id, type="read", access_control=model_info.access_control
|
979 |
+
)
|
980 |
+
):
|
981 |
+
raise HTTPException(
|
982 |
+
status_code=403,
|
983 |
+
detail="Model not found",
|
984 |
+
)
|
985 |
+
elif not bypass_filter:
|
986 |
+
if user.role != "admin":
|
987 |
+
raise HTTPException(
|
988 |
+
status_code=403,
|
989 |
+
detail="Model not found",
|
990 |
+
)
|
991 |
+
|
992 |
+
if ":" not in payload["model"]:
|
993 |
+
payload["model"] = f"{payload['model']}:latest"
|
994 |
+
|
995 |
+
url = await get_ollama_url(url_idx, payload["model"])
|
996 |
+
log.info(f"url: {url}")
|
997 |
+
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
|
998 |
+
|
999 |
+
parsed_url = urlparse(url)
|
1000 |
+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
1001 |
+
|
1002 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
1003 |
+
prefix_id = api_config.get("prefix_id", None)
|
1004 |
+
if prefix_id:
|
1005 |
+
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
1006 |
+
|
1007 |
+
return await post_streaming_url(
|
1008 |
+
f"{url}/api/chat",
|
1009 |
+
json.dumps(payload),
|
1010 |
+
stream=form_data.stream,
|
1011 |
+
content_type="application/x-ndjson",
|
1012 |
+
)
|
1013 |
+
|
1014 |
+
|
1015 |
+
# TODO: we should update this part once Ollama supports other types
|
1016 |
+
class OpenAIChatMessageContent(BaseModel):
|
1017 |
+
type: str
|
1018 |
+
model_config = ConfigDict(extra="allow")
|
1019 |
+
|
1020 |
+
|
1021 |
+
class OpenAIChatMessage(BaseModel):
|
1022 |
+
role: str
|
1023 |
+
content: Union[str, list[OpenAIChatMessageContent]]
|
1024 |
+
|
1025 |
+
model_config = ConfigDict(extra="allow")
|
1026 |
+
|
1027 |
+
|
1028 |
+
class OpenAIChatCompletionForm(BaseModel):
|
1029 |
+
model: str
|
1030 |
+
messages: list[OpenAIChatMessage]
|
1031 |
+
|
1032 |
+
model_config = ConfigDict(extra="allow")
|
1033 |
+
|
1034 |
+
|
1035 |
+
@app.post("/v1/chat/completions")
|
1036 |
+
@app.post("/v1/chat/completions/{url_idx}")
|
1037 |
+
async def generate_openai_chat_completion(
|
1038 |
+
form_data: dict,
|
1039 |
+
url_idx: Optional[int] = None,
|
1040 |
+
user=Depends(get_verified_user),
|
1041 |
+
):
|
1042 |
+
try:
|
1043 |
+
completion_form = OpenAIChatCompletionForm(**form_data)
|
1044 |
+
except Exception as e:
|
1045 |
+
log.exception(e)
|
1046 |
+
raise HTTPException(
|
1047 |
+
status_code=400,
|
1048 |
+
detail=str(e),
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
|
1052 |
+
if "metadata" in payload:
|
1053 |
+
del payload["metadata"]
|
1054 |
+
|
1055 |
+
model_id = completion_form.model
|
1056 |
+
if ":" not in model_id:
|
1057 |
+
model_id = f"{model_id}:latest"
|
1058 |
+
|
1059 |
+
model_info = Models.get_model_by_id(model_id)
|
1060 |
+
if model_info:
|
1061 |
+
if model_info.base_model_id:
|
1062 |
+
payload["model"] = model_info.base_model_id
|
1063 |
+
|
1064 |
+
params = model_info.params.model_dump()
|
1065 |
+
|
1066 |
+
if params:
|
1067 |
+
payload = apply_model_params_to_body_openai(params, payload)
|
1068 |
+
payload = apply_model_system_prompt_to_body(params, payload, user)
|
1069 |
+
|
1070 |
+
# Check if user has access to the model
|
1071 |
+
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
1072 |
+
if not (
|
1073 |
+
user.id == model_info.user_id
|
1074 |
+
or has_access(
|
1075 |
+
user.id, type="read", access_control=model_info.access_control
|
1076 |
+
)
|
1077 |
+
):
|
1078 |
+
raise HTTPException(
|
1079 |
+
status_code=403,
|
1080 |
+
detail="Model not found",
|
1081 |
+
)
|
1082 |
+
else:
|
1083 |
+
if user.role != "admin":
|
1084 |
+
raise HTTPException(
|
1085 |
+
status_code=403,
|
1086 |
+
detail="Model not found",
|
1087 |
+
)
|
1088 |
+
|
1089 |
+
if ":" not in payload["model"]:
|
1090 |
+
payload["model"] = f"{payload['model']}:latest"
|
1091 |
+
|
1092 |
+
url = await get_ollama_url(url_idx, payload["model"])
|
1093 |
+
log.info(f"url: {url}")
|
1094 |
+
|
1095 |
+
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
1096 |
+
prefix_id = api_config.get("prefix_id", None)
|
1097 |
+
if prefix_id:
|
1098 |
+
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
1099 |
+
|
1100 |
+
return await post_streaming_url(
|
1101 |
+
f"{url}/v1/chat/completions",
|
1102 |
+
json.dumps(payload),
|
1103 |
+
stream=payload.get("stream", False),
|
1104 |
+
)
|
1105 |
+
|
1106 |
+
|
1107 |
+
@app.get("/v1/models")
|
1108 |
+
@app.get("/v1/models/{url_idx}")
|
1109 |
+
async def get_openai_models(
|
1110 |
+
url_idx: Optional[int] = None,
|
1111 |
+
user=Depends(get_verified_user),
|
1112 |
+
):
|
1113 |
+
|
1114 |
+
models = []
|
1115 |
+
if url_idx is None:
|
1116 |
+
model_list = await get_all_models()
|
1117 |
+
models = [
|
1118 |
+
{
|
1119 |
+
"id": model["model"],
|
1120 |
+
"object": "model",
|
1121 |
+
"created": int(time.time()),
|
1122 |
+
"owned_by": "openai",
|
1123 |
+
}
|
1124 |
+
for model in model_list["models"]
|
1125 |
+
]
|
1126 |
+
|
1127 |
+
else:
|
1128 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
1129 |
+
try:
|
1130 |
+
r = requests.request(method="GET", url=f"{url}/api/tags")
|
1131 |
+
r.raise_for_status()
|
1132 |
+
|
1133 |
+
model_list = r.json()
|
1134 |
+
|
1135 |
+
models = [
|
1136 |
+
{
|
1137 |
+
"id": model["model"],
|
1138 |
+
"object": "model",
|
1139 |
+
"created": int(time.time()),
|
1140 |
+
"owned_by": "openai",
|
1141 |
+
}
|
1142 |
+
for model in models["models"]
|
1143 |
+
]
|
1144 |
+
except Exception as e:
|
1145 |
+
log.exception(e)
|
1146 |
+
error_detail = "Open WebUI: Server Connection Error"
|
1147 |
+
if r is not None:
|
1148 |
+
try:
|
1149 |
+
res = r.json()
|
1150 |
+
if "error" in res:
|
1151 |
+
error_detail = f"Ollama: {res['error']}"
|
1152 |
+
except Exception:
|
1153 |
+
error_detail = f"Ollama: {e}"
|
1154 |
+
|
1155 |
+
raise HTTPException(
|
1156 |
+
status_code=r.status_code if r else 500,
|
1157 |
+
detail=error_detail,
|
1158 |
+
)
|
1159 |
+
|
1160 |
+
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
1161 |
+
# Filter models based on user access control
|
1162 |
+
filtered_models = []
|
1163 |
+
for model in models:
|
1164 |
+
model_info = Models.get_model_by_id(model["id"])
|
1165 |
+
if model_info:
|
1166 |
+
if user.id == model_info.user_id or has_access(
|
1167 |
+
user.id, type="read", access_control=model_info.access_control
|
1168 |
+
):
|
1169 |
+
filtered_models.append(model)
|
1170 |
+
models = filtered_models
|
1171 |
+
|
1172 |
+
return {
|
1173 |
+
"data": models,
|
1174 |
+
"object": "list",
|
1175 |
+
}
|
1176 |
+
|
1177 |
+
|
1178 |
+
class UrlForm(BaseModel):
|
1179 |
+
url: str
|
1180 |
+
|
1181 |
+
|
1182 |
+
class UploadBlobForm(BaseModel):
|
1183 |
+
filename: str
|
1184 |
+
|
1185 |
+
|
1186 |
+
def parse_huggingface_url(hf_url):
|
1187 |
+
try:
|
1188 |
+
# Parse the URL
|
1189 |
+
parsed_url = urlparse(hf_url)
|
1190 |
+
|
1191 |
+
# Get the path and split it into components
|
1192 |
+
path_components = parsed_url.path.split("/")
|
1193 |
+
|
1194 |
+
# Extract the desired output
|
1195 |
+
model_file = path_components[-1]
|
1196 |
+
|
1197 |
+
return model_file
|
1198 |
+
except ValueError:
|
1199 |
+
return None
|
1200 |
+
|
1201 |
+
|
1202 |
+
async def download_file_stream(
|
1203 |
+
ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
|
1204 |
+
):
|
1205 |
+
done = False
|
1206 |
+
|
1207 |
+
if os.path.exists(file_path):
|
1208 |
+
current_size = os.path.getsize(file_path)
|
1209 |
+
else:
|
1210 |
+
current_size = 0
|
1211 |
+
|
1212 |
+
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
|
1213 |
+
|
1214 |
+
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
|
1215 |
+
|
1216 |
+
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
1217 |
+
async with session.get(file_url, headers=headers) as response:
|
1218 |
+
total_size = int(response.headers.get("content-length", 0)) + current_size
|
1219 |
+
|
1220 |
+
with open(file_path, "ab+") as file:
|
1221 |
+
async for data in response.content.iter_chunked(chunk_size):
|
1222 |
+
current_size += len(data)
|
1223 |
+
file.write(data)
|
1224 |
+
|
1225 |
+
done = current_size == total_size
|
1226 |
+
progress = round((current_size / total_size) * 100, 2)
|
1227 |
+
|
1228 |
+
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
|
1229 |
+
|
1230 |
+
if done:
|
1231 |
+
file.seek(0)
|
1232 |
+
hashed = calculate_sha256(file)
|
1233 |
+
file.seek(0)
|
1234 |
+
|
1235 |
+
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
|
1236 |
+
response = requests.post(url, data=file)
|
1237 |
+
|
1238 |
+
if response.ok:
|
1239 |
+
res = {
|
1240 |
+
"done": done,
|
1241 |
+
"blob": f"sha256:{hashed}",
|
1242 |
+
"name": file_name,
|
1243 |
+
}
|
1244 |
+
os.remove(file_path)
|
1245 |
+
|
1246 |
+
yield f"data: {json.dumps(res)}\n\n"
|
1247 |
+
else:
|
1248 |
+
raise "Ollama: Could not create blob, Please try again."
|
1249 |
+
|
1250 |
+
|
1251 |
+
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
|
1252 |
+
@app.post("/models/download")
|
1253 |
+
@app.post("/models/download/{url_idx}")
|
1254 |
+
async def download_model(
|
1255 |
+
form_data: UrlForm,
|
1256 |
+
url_idx: Optional[int] = None,
|
1257 |
+
user=Depends(get_admin_user),
|
1258 |
+
):
|
1259 |
+
allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
|
1260 |
+
|
1261 |
+
if not any(form_data.url.startswith(host) for host in allowed_hosts):
|
1262 |
+
raise HTTPException(
|
1263 |
+
status_code=400,
|
1264 |
+
detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
|
1265 |
+
)
|
1266 |
+
|
1267 |
+
if url_idx is None:
|
1268 |
+
url_idx = 0
|
1269 |
+
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
1270 |
+
|
1271 |
+
file_name = parse_huggingface_url(form_data.url)
|
1272 |
+
|
1273 |
+
if file_name:
|
1274 |
+
file_path = f"{UPLOAD_DIR}/{file_name}"
|
1275 |
+
|
1276 |
+
return StreamingResponse(
|
1277 |
+
download_file_stream(url, form_data.url, file_path, file_name),
|
1278 |
+
)
|
1279 |
+
else:
|
1280 |
+
return None
|
1281 |
+
|
1282 |
+
|
1283 |
+
@app.post("/models/upload")
|
1284 |
+
@app.post("/models/upload/{url_idx}")
|
1285 |
+
def upload_model(
|
1286 |
+
file: UploadFile = File(...),
|
1287 |
+
url_idx: Optional[int] = None,
|
1288 |
+
user=Depends(get_admin_user),
|
1289 |
+
):
|
1290 |
+
if url_idx is None:
|
1291 |
+
url_idx = 0
|
1292 |
+
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
1293 |
+
|
1294 |
+
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
1295 |
+
|
1296 |
+
# Save file in chunks
|
1297 |
+
with open(file_path, "wb+") as f:
|
1298 |
+
for chunk in file.file:
|
1299 |
+
f.write(chunk)
|
1300 |
+
|
1301 |
+
def file_process_stream():
|
1302 |
+
nonlocal ollama_url
|
1303 |
+
total_size = os.path.getsize(file_path)
|
1304 |
+
chunk_size = 1024 * 1024
|
1305 |
+
try:
|
1306 |
+
with open(file_path, "rb") as f:
|
1307 |
+
total = 0
|
1308 |
+
done = False
|
1309 |
+
|
1310 |
+
while not done:
|
1311 |
+
chunk = f.read(chunk_size)
|
1312 |
+
if not chunk:
|
1313 |
+
done = True
|
1314 |
+
continue
|
1315 |
+
|
1316 |
+
total += len(chunk)
|
1317 |
+
progress = round((total / total_size) * 100, 2)
|
1318 |
+
|
1319 |
+
res = {
|
1320 |
+
"progress": progress,
|
1321 |
+
"total": total_size,
|
1322 |
+
"completed": total,
|
1323 |
+
}
|
1324 |
+
yield f"data: {json.dumps(res)}\n\n"
|
1325 |
+
|
1326 |
+
if done:
|
1327 |
+
f.seek(0)
|
1328 |
+
hashed = calculate_sha256(f)
|
1329 |
+
f.seek(0)
|
1330 |
+
|
1331 |
+
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
|
1332 |
+
response = requests.post(url, data=f)
|
1333 |
+
|
1334 |
+
if response.ok:
|
1335 |
+
res = {
|
1336 |
+
"done": done,
|
1337 |
+
"blob": f"sha256:{hashed}",
|
1338 |
+
"name": file.filename,
|
1339 |
+
}
|
1340 |
+
os.remove(file_path)
|
1341 |
+
yield f"data: {json.dumps(res)}\n\n"
|
1342 |
+
else:
|
1343 |
+
raise Exception(
|
1344 |
+
"Ollama: Could not create blob, Please try again."
|
1345 |
+
)
|
1346 |
+
|
1347 |
+
except Exception as e:
|
1348 |
+
res = {"error": str(e)}
|
1349 |
+
yield f"data: {json.dumps(res)}\n\n"
|
1350 |
+
|
1351 |
+
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
|
backend/open_webui/apps/openai/main.py
ADDED
@@ -0,0 +1,719 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import hashlib
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Literal, Optional, overload
|
7 |
+
|
8 |
+
import aiohttp
|
9 |
+
from aiocache import cached
|
10 |
+
import requests
|
11 |
+
|
12 |
+
|
13 |
+
from open_webui.apps.webui.models.models import Models
|
14 |
+
from open_webui.config import (
|
15 |
+
CACHE_DIR,
|
16 |
+
CORS_ALLOW_ORIGIN,
|
17 |
+
ENABLE_OPENAI_API,
|
18 |
+
OPENAI_API_BASE_URLS,
|
19 |
+
OPENAI_API_KEYS,
|
20 |
+
OPENAI_API_CONFIGS,
|
21 |
+
AppConfig,
|
22 |
+
)
|
23 |
+
from open_webui.env import (
|
24 |
+
AIOHTTP_CLIENT_TIMEOUT,
|
25 |
+
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
26 |
+
ENABLE_FORWARD_USER_INFO_HEADERS,
|
27 |
+
BYPASS_MODEL_ACCESS_CONTROL,
|
28 |
+
)
|
29 |
+
|
30 |
+
from open_webui.constants import ERROR_MESSAGES
|
31 |
+
from open_webui.env import ENV, SRC_LOG_LEVELS
|
32 |
+
from fastapi import Depends, FastAPI, HTTPException, Request
|
33 |
+
from fastapi.middleware.cors import CORSMiddleware
|
34 |
+
from fastapi.responses import FileResponse, StreamingResponse
|
35 |
+
from pydantic import BaseModel
|
36 |
+
from starlette.background import BackgroundTask
|
37 |
+
|
38 |
+
from open_webui.utils.payload import (
|
39 |
+
apply_model_params_to_body_openai,
|
40 |
+
apply_model_system_prompt_to_body,
|
41 |
+
)
|
42 |
+
|
43 |
+
from open_webui.utils.utils import get_admin_user, get_verified_user
|
44 |
+
from open_webui.utils.access_control import has_access
|
45 |
+
|
46 |
+
|
47 |
+
log = logging.getLogger(__name__)
|
48 |
+
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
49 |
+
|
50 |
+
|
51 |
+
app = FastAPI(
|
52 |
+
docs_url="/docs" if ENV == "dev" else None,
|
53 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
54 |
+
redoc_url=None,
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
app.add_middleware(
|
59 |
+
CORSMiddleware,
|
60 |
+
allow_origins=CORS_ALLOW_ORIGIN,
|
61 |
+
allow_credentials=True,
|
62 |
+
allow_methods=["*"],
|
63 |
+
allow_headers=["*"],
|
64 |
+
)
|
65 |
+
|
66 |
+
app.state.config = AppConfig()
|
67 |
+
|
68 |
+
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
|
69 |
+
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
70 |
+
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
|
71 |
+
app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
|
72 |
+
|
73 |
+
|
74 |
+
@app.get("/config")
|
75 |
+
async def get_config(user=Depends(get_admin_user)):
|
76 |
+
return {
|
77 |
+
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
|
78 |
+
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
|
79 |
+
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
|
80 |
+
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
|
81 |
+
}
|
82 |
+
|
83 |
+
|
84 |
+
class OpenAIConfigForm(BaseModel):
|
85 |
+
ENABLE_OPENAI_API: Optional[bool] = None
|
86 |
+
OPENAI_API_BASE_URLS: list[str]
|
87 |
+
OPENAI_API_KEYS: list[str]
|
88 |
+
OPENAI_API_CONFIGS: dict
|
89 |
+
|
90 |
+
|
91 |
+
@app.post("/config/update")
|
92 |
+
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
|
93 |
+
app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
|
94 |
+
|
95 |
+
app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
|
96 |
+
app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
|
97 |
+
|
98 |
+
# Check if API KEYS length is same than API URLS length
|
99 |
+
if len(app.state.config.OPENAI_API_KEYS) != len(
|
100 |
+
app.state.config.OPENAI_API_BASE_URLS
|
101 |
+
):
|
102 |
+
if len(app.state.config.OPENAI_API_KEYS) > len(
|
103 |
+
app.state.config.OPENAI_API_BASE_URLS
|
104 |
+
):
|
105 |
+
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
|
106 |
+
: len(app.state.config.OPENAI_API_BASE_URLS)
|
107 |
+
]
|
108 |
+
else:
|
109 |
+
app.state.config.OPENAI_API_KEYS += [""] * (
|
110 |
+
len(app.state.config.OPENAI_API_BASE_URLS)
|
111 |
+
- len(app.state.config.OPENAI_API_KEYS)
|
112 |
+
)
|
113 |
+
|
114 |
+
app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
|
115 |
+
|
116 |
+
# Remove any extra configs
|
117 |
+
config_urls = app.state.config.OPENAI_API_CONFIGS.keys()
|
118 |
+
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
|
119 |
+
if url not in config_urls:
|
120 |
+
app.state.config.OPENAI_API_CONFIGS.pop(url, None)
|
121 |
+
|
122 |
+
return {
|
123 |
+
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
|
124 |
+
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
|
125 |
+
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
|
126 |
+
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
|
127 |
+
}
|
128 |
+
|
129 |
+
|
130 |
+
@app.post("/audio/speech")
|
131 |
+
async def speech(request: Request, user=Depends(get_verified_user)):
|
132 |
+
idx = None
|
133 |
+
try:
|
134 |
+
idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
|
135 |
+
body = await request.body()
|
136 |
+
name = hashlib.sha256(body).hexdigest()
|
137 |
+
|
138 |
+
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
139 |
+
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
140 |
+
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
|
141 |
+
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
|
142 |
+
|
143 |
+
# Check if the file already exists in the cache
|
144 |
+
if file_path.is_file():
|
145 |
+
return FileResponse(file_path)
|
146 |
+
|
147 |
+
headers = {}
|
148 |
+
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
|
149 |
+
headers["Content-Type"] = "application/json"
|
150 |
+
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
|
151 |
+
headers["HTTP-Referer"] = "https://openwebui.com/"
|
152 |
+
headers["X-Title"] = "Open WebUI"
|
153 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
154 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
155 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
156 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
157 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
158 |
+
r = None
|
159 |
+
try:
|
160 |
+
r = requests.post(
|
161 |
+
url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
|
162 |
+
data=body,
|
163 |
+
headers=headers,
|
164 |
+
stream=True,
|
165 |
+
)
|
166 |
+
|
167 |
+
r.raise_for_status()
|
168 |
+
|
169 |
+
# Save the streaming content to a file
|
170 |
+
with open(file_path, "wb") as f:
|
171 |
+
for chunk in r.iter_content(chunk_size=8192):
|
172 |
+
f.write(chunk)
|
173 |
+
|
174 |
+
with open(file_body_path, "w") as f:
|
175 |
+
json.dump(json.loads(body.decode("utf-8")), f)
|
176 |
+
|
177 |
+
# Return the saved file
|
178 |
+
return FileResponse(file_path)
|
179 |
+
|
180 |
+
except Exception as e:
|
181 |
+
log.exception(e)
|
182 |
+
error_detail = "Open WebUI: Server Connection Error"
|
183 |
+
if r is not None:
|
184 |
+
try:
|
185 |
+
res = r.json()
|
186 |
+
if "error" in res:
|
187 |
+
error_detail = f"External: {res['error']}"
|
188 |
+
except Exception:
|
189 |
+
error_detail = f"External: {e}"
|
190 |
+
|
191 |
+
raise HTTPException(
|
192 |
+
status_code=r.status_code if r else 500, detail=error_detail
|
193 |
+
)
|
194 |
+
|
195 |
+
except ValueError:
|
196 |
+
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
197 |
+
|
198 |
+
|
199 |
+
async def aiohttp_get(url, key=None):
|
200 |
+
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
201 |
+
try:
|
202 |
+
headers = {"Authorization": f"Bearer {key}"} if key else {}
|
203 |
+
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
204 |
+
async with session.get(url, headers=headers) as response:
|
205 |
+
return await response.json()
|
206 |
+
except Exception as e:
|
207 |
+
# Handle connection error here
|
208 |
+
log.error(f"Connection error: {e}")
|
209 |
+
return None
|
210 |
+
|
211 |
+
|
212 |
+
async def cleanup_response(
|
213 |
+
response: Optional[aiohttp.ClientResponse],
|
214 |
+
session: Optional[aiohttp.ClientSession],
|
215 |
+
):
|
216 |
+
if response:
|
217 |
+
response.close()
|
218 |
+
if session:
|
219 |
+
await session.close()
|
220 |
+
|
221 |
+
|
222 |
+
def merge_models_lists(model_lists):
|
223 |
+
log.debug(f"merge_models_lists {model_lists}")
|
224 |
+
merged_list = []
|
225 |
+
|
226 |
+
for idx, models in enumerate(model_lists):
|
227 |
+
if models is not None and "error" not in models:
|
228 |
+
merged_list.extend(
|
229 |
+
[
|
230 |
+
{
|
231 |
+
**model,
|
232 |
+
"name": model.get("name", model["id"]),
|
233 |
+
"owned_by": "openai",
|
234 |
+
"openai": model,
|
235 |
+
"urlIdx": idx,
|
236 |
+
}
|
237 |
+
for model in models
|
238 |
+
if "api.openai.com"
|
239 |
+
not in app.state.config.OPENAI_API_BASE_URLS[idx]
|
240 |
+
or not any(
|
241 |
+
name in model["id"]
|
242 |
+
for name in [
|
243 |
+
"babbage",
|
244 |
+
"dall-e",
|
245 |
+
"davinci",
|
246 |
+
"embedding",
|
247 |
+
"tts",
|
248 |
+
"whisper",
|
249 |
+
]
|
250 |
+
)
|
251 |
+
]
|
252 |
+
)
|
253 |
+
|
254 |
+
return merged_list
|
255 |
+
|
256 |
+
|
257 |
+
async def get_all_models_responses() -> list:
|
258 |
+
if not app.state.config.ENABLE_OPENAI_API:
|
259 |
+
return []
|
260 |
+
|
261 |
+
# Check if API KEYS length is same than API URLS length
|
262 |
+
num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
|
263 |
+
num_keys = len(app.state.config.OPENAI_API_KEYS)
|
264 |
+
|
265 |
+
if num_keys != num_urls:
|
266 |
+
# if there are more keys than urls, remove the extra keys
|
267 |
+
if num_keys > num_urls:
|
268 |
+
new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
|
269 |
+
app.state.config.OPENAI_API_KEYS = new_keys
|
270 |
+
# if there are more urls than keys, add empty keys
|
271 |
+
else:
|
272 |
+
app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
|
273 |
+
|
274 |
+
tasks = []
|
275 |
+
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
|
276 |
+
if url not in app.state.config.OPENAI_API_CONFIGS:
|
277 |
+
tasks.append(
|
278 |
+
aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
|
279 |
+
)
|
280 |
+
else:
|
281 |
+
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
|
282 |
+
|
283 |
+
enable = api_config.get("enable", True)
|
284 |
+
model_ids = api_config.get("model_ids", [])
|
285 |
+
|
286 |
+
if enable:
|
287 |
+
if len(model_ids) == 0:
|
288 |
+
tasks.append(
|
289 |
+
aiohttp_get(
|
290 |
+
f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]
|
291 |
+
)
|
292 |
+
)
|
293 |
+
else:
|
294 |
+
model_list = {
|
295 |
+
"object": "list",
|
296 |
+
"data": [
|
297 |
+
{
|
298 |
+
"id": model_id,
|
299 |
+
"name": model_id,
|
300 |
+
"owned_by": "openai",
|
301 |
+
"openai": {"id": model_id},
|
302 |
+
"urlIdx": idx,
|
303 |
+
}
|
304 |
+
for model_id in model_ids
|
305 |
+
],
|
306 |
+
}
|
307 |
+
|
308 |
+
tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list)))
|
309 |
+
else:
|
310 |
+
tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
311 |
+
|
312 |
+
responses = await asyncio.gather(*tasks)
|
313 |
+
|
314 |
+
for idx, response in enumerate(responses):
|
315 |
+
if response:
|
316 |
+
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
317 |
+
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
|
318 |
+
|
319 |
+
prefix_id = api_config.get("prefix_id", None)
|
320 |
+
|
321 |
+
if prefix_id:
|
322 |
+
for model in (
|
323 |
+
response if isinstance(response, list) else response.get("data", [])
|
324 |
+
):
|
325 |
+
model["id"] = f"{prefix_id}.{model['id']}"
|
326 |
+
|
327 |
+
log.debug(f"get_all_models:responses() {responses}")
|
328 |
+
|
329 |
+
return responses
|
330 |
+
|
331 |
+
|
332 |
+
@cached(ttl=3)
|
333 |
+
async def get_all_models() -> dict[str, list]:
|
334 |
+
log.info("get_all_models()")
|
335 |
+
|
336 |
+
if not app.state.config.ENABLE_OPENAI_API:
|
337 |
+
return {"data": []}
|
338 |
+
|
339 |
+
responses = await get_all_models_responses()
|
340 |
+
|
341 |
+
def extract_data(response):
|
342 |
+
if response and "data" in response:
|
343 |
+
return response["data"]
|
344 |
+
if isinstance(response, list):
|
345 |
+
return response
|
346 |
+
return None
|
347 |
+
|
348 |
+
models = {"data": merge_models_lists(map(extract_data, responses))}
|
349 |
+
log.debug(f"models: {models}")
|
350 |
+
|
351 |
+
return models
|
352 |
+
|
353 |
+
|
354 |
+
@app.get("/models")
|
355 |
+
@app.get("/models/{url_idx}")
|
356 |
+
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
|
357 |
+
models = {
|
358 |
+
"data": [],
|
359 |
+
}
|
360 |
+
|
361 |
+
if url_idx is None:
|
362 |
+
models = await get_all_models()
|
363 |
+
else:
|
364 |
+
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
365 |
+
key = app.state.config.OPENAI_API_KEYS[url_idx]
|
366 |
+
|
367 |
+
headers = {}
|
368 |
+
headers["Authorization"] = f"Bearer {key}"
|
369 |
+
headers["Content-Type"] = "application/json"
|
370 |
+
|
371 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
372 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
373 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
374 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
375 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
376 |
+
|
377 |
+
r = None
|
378 |
+
|
379 |
+
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
380 |
+
async with aiohttp.ClientSession(timeout=timeout) as session:
|
381 |
+
try:
|
382 |
+
async with session.get(f"{url}/models", headers=headers) as r:
|
383 |
+
if r.status != 200:
|
384 |
+
# Extract response error details if available
|
385 |
+
error_detail = f"HTTP Error: {r.status}"
|
386 |
+
res = await r.json()
|
387 |
+
if "error" in res:
|
388 |
+
error_detail = f"External Error: {res['error']}"
|
389 |
+
raise Exception(error_detail)
|
390 |
+
|
391 |
+
response_data = await r.json()
|
392 |
+
|
393 |
+
# Check if we're calling OpenAI API based on the URL
|
394 |
+
if "api.openai.com" in url:
|
395 |
+
# Filter models according to the specified conditions
|
396 |
+
response_data["data"] = [
|
397 |
+
model
|
398 |
+
for model in response_data.get("data", [])
|
399 |
+
if not any(
|
400 |
+
name in model["id"]
|
401 |
+
for name in [
|
402 |
+
"babbage",
|
403 |
+
"dall-e",
|
404 |
+
"davinci",
|
405 |
+
"embedding",
|
406 |
+
"tts",
|
407 |
+
"whisper",
|
408 |
+
]
|
409 |
+
)
|
410 |
+
]
|
411 |
+
|
412 |
+
models = response_data
|
413 |
+
except aiohttp.ClientError as e:
|
414 |
+
# ClientError covers all aiohttp requests issues
|
415 |
+
log.exception(f"Client error: {str(e)}")
|
416 |
+
# Handle aiohttp-specific connection issues, timeout etc.
|
417 |
+
raise HTTPException(
|
418 |
+
status_code=500, detail="Open WebUI: Server Connection Error"
|
419 |
+
)
|
420 |
+
except Exception as e:
|
421 |
+
log.exception(f"Unexpected error: {e}")
|
422 |
+
# Generic error handler in case parsing JSON or other steps fail
|
423 |
+
error_detail = f"Unexpected error: {str(e)}"
|
424 |
+
raise HTTPException(status_code=500, detail=error_detail)
|
425 |
+
|
426 |
+
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
427 |
+
# Filter models based on user access control
|
428 |
+
filtered_models = []
|
429 |
+
for model in models.get("data", []):
|
430 |
+
model_info = Models.get_model_by_id(model["id"])
|
431 |
+
if model_info:
|
432 |
+
if user.id == model_info.user_id or has_access(
|
433 |
+
user.id, type="read", access_control=model_info.access_control
|
434 |
+
):
|
435 |
+
filtered_models.append(model)
|
436 |
+
models["data"] = filtered_models
|
437 |
+
|
438 |
+
return models
|
439 |
+
|
440 |
+
|
441 |
+
class ConnectionVerificationForm(BaseModel):
|
442 |
+
url: str
|
443 |
+
key: str
|
444 |
+
|
445 |
+
|
446 |
+
@app.post("/verify")
|
447 |
+
async def verify_connection(
|
448 |
+
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
|
449 |
+
):
|
450 |
+
url = form_data.url
|
451 |
+
key = form_data.key
|
452 |
+
|
453 |
+
headers = {}
|
454 |
+
headers["Authorization"] = f"Bearer {key}"
|
455 |
+
headers["Content-Type"] = "application/json"
|
456 |
+
|
457 |
+
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
458 |
+
async with aiohttp.ClientSession(timeout=timeout) as session:
|
459 |
+
try:
|
460 |
+
async with session.get(f"{url}/models", headers=headers) as r:
|
461 |
+
if r.status != 200:
|
462 |
+
# Extract response error details if available
|
463 |
+
error_detail = f"HTTP Error: {r.status}"
|
464 |
+
res = await r.json()
|
465 |
+
if "error" in res:
|
466 |
+
error_detail = f"External Error: {res['error']}"
|
467 |
+
raise Exception(error_detail)
|
468 |
+
|
469 |
+
response_data = await r.json()
|
470 |
+
return response_data
|
471 |
+
|
472 |
+
except aiohttp.ClientError as e:
|
473 |
+
# ClientError covers all aiohttp requests issues
|
474 |
+
log.exception(f"Client error: {str(e)}")
|
475 |
+
# Handle aiohttp-specific connection issues, timeout etc.
|
476 |
+
raise HTTPException(
|
477 |
+
status_code=500, detail="Open WebUI: Server Connection Error"
|
478 |
+
)
|
479 |
+
except Exception as e:
|
480 |
+
log.exception(f"Unexpected error: {e}")
|
481 |
+
# Generic error handler in case parsing JSON or other steps fail
|
482 |
+
error_detail = f"Unexpected error: {str(e)}"
|
483 |
+
raise HTTPException(status_code=500, detail=error_detail)
|
484 |
+
|
485 |
+
|
486 |
+
@app.post("/chat/completions")
|
487 |
+
async def generate_chat_completion(
|
488 |
+
form_data: dict,
|
489 |
+
user=Depends(get_verified_user),
|
490 |
+
bypass_filter: Optional[bool] = False,
|
491 |
+
):
|
492 |
+
idx = 0
|
493 |
+
payload = {**form_data}
|
494 |
+
|
495 |
+
if "metadata" in payload:
|
496 |
+
del payload["metadata"]
|
497 |
+
|
498 |
+
model_id = form_data.get("model")
|
499 |
+
model_info = Models.get_model_by_id(model_id)
|
500 |
+
|
501 |
+
# Check model info and override the payload
|
502 |
+
if model_info:
|
503 |
+
if model_info.base_model_id:
|
504 |
+
payload["model"] = model_info.base_model_id
|
505 |
+
|
506 |
+
params = model_info.params.model_dump()
|
507 |
+
payload = apply_model_params_to_body_openai(params, payload)
|
508 |
+
payload = apply_model_system_prompt_to_body(params, payload, user)
|
509 |
+
|
510 |
+
# Check if user has access to the model
|
511 |
+
if not bypass_filter and user.role == "user":
|
512 |
+
if not (
|
513 |
+
user.id == model_info.user_id
|
514 |
+
or has_access(
|
515 |
+
user.id, type="read", access_control=model_info.access_control
|
516 |
+
)
|
517 |
+
):
|
518 |
+
raise HTTPException(
|
519 |
+
status_code=403,
|
520 |
+
detail="Model not found",
|
521 |
+
)
|
522 |
+
elif not bypass_filter:
|
523 |
+
if user.role != "admin":
|
524 |
+
raise HTTPException(
|
525 |
+
status_code=403,
|
526 |
+
detail="Model not found",
|
527 |
+
)
|
528 |
+
|
529 |
+
# Attemp to get urlIdx from the model
|
530 |
+
models = await get_all_models()
|
531 |
+
|
532 |
+
# Find the model from the list
|
533 |
+
model = next(
|
534 |
+
(model for model in models["data"] if model["id"] == payload.get("model")),
|
535 |
+
None,
|
536 |
+
)
|
537 |
+
|
538 |
+
if model:
|
539 |
+
idx = model["urlIdx"]
|
540 |
+
else:
|
541 |
+
raise HTTPException(
|
542 |
+
status_code=404,
|
543 |
+
detail="Model not found",
|
544 |
+
)
|
545 |
+
|
546 |
+
# Get the API config for the model
|
547 |
+
api_config = app.state.config.OPENAI_API_CONFIGS.get(
|
548 |
+
app.state.config.OPENAI_API_BASE_URLS[idx], {}
|
549 |
+
)
|
550 |
+
prefix_id = api_config.get("prefix_id", None)
|
551 |
+
|
552 |
+
if prefix_id:
|
553 |
+
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
554 |
+
|
555 |
+
# Add user info to the payload if the model is a pipeline
|
556 |
+
if "pipeline" in model and model.get("pipeline"):
|
557 |
+
payload["user"] = {
|
558 |
+
"name": user.name,
|
559 |
+
"id": user.id,
|
560 |
+
"email": user.email,
|
561 |
+
"role": user.role,
|
562 |
+
}
|
563 |
+
|
564 |
+
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
565 |
+
key = app.state.config.OPENAI_API_KEYS[idx]
|
566 |
+
|
567 |
+
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
568 |
+
is_o1 = payload["model"].lower().startswith("o1-")
|
569 |
+
# Change max_completion_tokens to max_tokens (Backward compatible)
|
570 |
+
if "api.openai.com" not in url and not is_o1:
|
571 |
+
if "max_completion_tokens" in payload:
|
572 |
+
# Remove "max_completion_tokens" from the payload
|
573 |
+
payload["max_tokens"] = payload["max_completion_tokens"]
|
574 |
+
del payload["max_completion_tokens"]
|
575 |
+
else:
|
576 |
+
if is_o1 and "max_tokens" in payload:
|
577 |
+
payload["max_completion_tokens"] = payload["max_tokens"]
|
578 |
+
del payload["max_tokens"]
|
579 |
+
if "max_tokens" in payload and "max_completion_tokens" in payload:
|
580 |
+
del payload["max_tokens"]
|
581 |
+
|
582 |
+
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
|
583 |
+
if is_o1 and payload["messages"][0]["role"] == "system":
|
584 |
+
payload["messages"][0]["role"] = "user"
|
585 |
+
|
586 |
+
# Convert the modified body back to JSON
|
587 |
+
payload = json.dumps(payload)
|
588 |
+
|
589 |
+
headers = {}
|
590 |
+
headers["Authorization"] = f"Bearer {key}"
|
591 |
+
headers["Content-Type"] = "application/json"
|
592 |
+
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
|
593 |
+
headers["HTTP-Referer"] = "https://openwebui.com/"
|
594 |
+
headers["X-Title"] = "Open WebUI"
|
595 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
596 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
597 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
598 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
599 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
600 |
+
|
601 |
+
r = None
|
602 |
+
session = None
|
603 |
+
streaming = False
|
604 |
+
response = None
|
605 |
+
|
606 |
+
try:
|
607 |
+
session = aiohttp.ClientSession(
|
608 |
+
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
609 |
+
)
|
610 |
+
r = await session.request(
|
611 |
+
method="POST",
|
612 |
+
url=f"{url}/chat/completions",
|
613 |
+
data=payload,
|
614 |
+
headers=headers,
|
615 |
+
)
|
616 |
+
|
617 |
+
# Check if response is SSE
|
618 |
+
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
619 |
+
streaming = True
|
620 |
+
return StreamingResponse(
|
621 |
+
r.content,
|
622 |
+
status_code=r.status,
|
623 |
+
headers=dict(r.headers),
|
624 |
+
background=BackgroundTask(
|
625 |
+
cleanup_response, response=r, session=session
|
626 |
+
),
|
627 |
+
)
|
628 |
+
else:
|
629 |
+
try:
|
630 |
+
response = await r.json()
|
631 |
+
except Exception as e:
|
632 |
+
log.error(e)
|
633 |
+
response = await r.text()
|
634 |
+
|
635 |
+
r.raise_for_status()
|
636 |
+
return response
|
637 |
+
except Exception as e:
|
638 |
+
log.exception(e)
|
639 |
+
error_detail = "Open WebUI: Server Connection Error"
|
640 |
+
if isinstance(response, dict):
|
641 |
+
if "error" in response:
|
642 |
+
error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
|
643 |
+
elif isinstance(response, str):
|
644 |
+
error_detail = response
|
645 |
+
|
646 |
+
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
|
647 |
+
finally:
|
648 |
+
if not streaming and session:
|
649 |
+
if r:
|
650 |
+
r.close()
|
651 |
+
await session.close()
|
652 |
+
|
653 |
+
|
654 |
+
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
655 |
+
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
656 |
+
idx = 0
|
657 |
+
|
658 |
+
body = await request.body()
|
659 |
+
|
660 |
+
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
661 |
+
key = app.state.config.OPENAI_API_KEYS[idx]
|
662 |
+
|
663 |
+
target_url = f"{url}/{path}"
|
664 |
+
|
665 |
+
headers = {}
|
666 |
+
headers["Authorization"] = f"Bearer {key}"
|
667 |
+
headers["Content-Type"] = "application/json"
|
668 |
+
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
669 |
+
headers["X-OpenWebUI-User-Name"] = user.name
|
670 |
+
headers["X-OpenWebUI-User-Id"] = user.id
|
671 |
+
headers["X-OpenWebUI-User-Email"] = user.email
|
672 |
+
headers["X-OpenWebUI-User-Role"] = user.role
|
673 |
+
|
674 |
+
r = None
|
675 |
+
session = None
|
676 |
+
streaming = False
|
677 |
+
|
678 |
+
try:
|
679 |
+
session = aiohttp.ClientSession(trust_env=True)
|
680 |
+
r = await session.request(
|
681 |
+
method=request.method,
|
682 |
+
url=target_url,
|
683 |
+
data=body,
|
684 |
+
headers=headers,
|
685 |
+
)
|
686 |
+
|
687 |
+
r.raise_for_status()
|
688 |
+
|
689 |
+
# Check if response is SSE
|
690 |
+
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
691 |
+
streaming = True
|
692 |
+
return StreamingResponse(
|
693 |
+
r.content,
|
694 |
+
status_code=r.status,
|
695 |
+
headers=dict(r.headers),
|
696 |
+
background=BackgroundTask(
|
697 |
+
cleanup_response, response=r, session=session
|
698 |
+
),
|
699 |
+
)
|
700 |
+
else:
|
701 |
+
response_data = await r.json()
|
702 |
+
return response_data
|
703 |
+
except Exception as e:
|
704 |
+
log.exception(e)
|
705 |
+
error_detail = "Open WebUI: Server Connection Error"
|
706 |
+
if r is not None:
|
707 |
+
try:
|
708 |
+
res = await r.json()
|
709 |
+
print(res)
|
710 |
+
if "error" in res:
|
711 |
+
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
712 |
+
except Exception:
|
713 |
+
error_detail = f"External: {e}"
|
714 |
+
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
|
715 |
+
finally:
|
716 |
+
if not streaming and session:
|
717 |
+
if r:
|
718 |
+
r.close()
|
719 |
+
await session.close()
|
backend/open_webui/apps/retrieval/loaders/main.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import logging
|
3 |
+
import ftfy
|
4 |
+
|
5 |
+
from langchain_community.document_loaders import (
|
6 |
+
BSHTMLLoader,
|
7 |
+
CSVLoader,
|
8 |
+
Docx2txtLoader,
|
9 |
+
OutlookMessageLoader,
|
10 |
+
PyPDFLoader,
|
11 |
+
TextLoader,
|
12 |
+
UnstructuredEPubLoader,
|
13 |
+
UnstructuredExcelLoader,
|
14 |
+
UnstructuredMarkdownLoader,
|
15 |
+
UnstructuredPowerPointLoader,
|
16 |
+
UnstructuredRSTLoader,
|
17 |
+
UnstructuredXMLLoader,
|
18 |
+
YoutubeLoader,
|
19 |
+
)
|
20 |
+
from langchain_core.documents import Document
|
21 |
+
from open_webui.env import SRC_LOG_LEVELS
|
22 |
+
|
23 |
+
log = logging.getLogger(__name__)
|
24 |
+
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
25 |
+
|
26 |
+
known_source_ext = [
|
27 |
+
"go",
|
28 |
+
"py",
|
29 |
+
"java",
|
30 |
+
"sh",
|
31 |
+
"bat",
|
32 |
+
"ps1",
|
33 |
+
"cmd",
|
34 |
+
"js",
|
35 |
+
"ts",
|
36 |
+
"css",
|
37 |
+
"cpp",
|
38 |
+
"hpp",
|
39 |
+
"h",
|
40 |
+
"c",
|
41 |
+
"cs",
|
42 |
+
"sql",
|
43 |
+
"log",
|
44 |
+
"ini",
|
45 |
+
"pl",
|
46 |
+
"pm",
|
47 |
+
"r",
|
48 |
+
"dart",
|
49 |
+
"dockerfile",
|
50 |
+
"env",
|
51 |
+
"php",
|
52 |
+
"hs",
|
53 |
+
"hsc",
|
54 |
+
"lua",
|
55 |
+
"nginxconf",
|
56 |
+
"conf",
|
57 |
+
"m",
|
58 |
+
"mm",
|
59 |
+
"plsql",
|
60 |
+
"perl",
|
61 |
+
"rb",
|
62 |
+
"rs",
|
63 |
+
"db2",
|
64 |
+
"scala",
|
65 |
+
"bash",
|
66 |
+
"swift",
|
67 |
+
"vue",
|
68 |
+
"svelte",
|
69 |
+
"msg",
|
70 |
+
"ex",
|
71 |
+
"exs",
|
72 |
+
"erl",
|
73 |
+
"tsx",
|
74 |
+
"jsx",
|
75 |
+
"hs",
|
76 |
+
"lhs",
|
77 |
+
]
|
78 |
+
|
79 |
+
|
80 |
+
class TikaLoader:
|
81 |
+
def __init__(self, url, file_path, mime_type=None):
|
82 |
+
self.url = url
|
83 |
+
self.file_path = file_path
|
84 |
+
self.mime_type = mime_type
|
85 |
+
|
86 |
+
def load(self) -> list[Document]:
|
87 |
+
with open(self.file_path, "rb") as f:
|
88 |
+
data = f.read()
|
89 |
+
|
90 |
+
if self.mime_type is not None:
|
91 |
+
headers = {"Content-Type": self.mime_type}
|
92 |
+
else:
|
93 |
+
headers = {}
|
94 |
+
|
95 |
+
endpoint = self.url
|
96 |
+
if not endpoint.endswith("/"):
|
97 |
+
endpoint += "/"
|
98 |
+
endpoint += "tika/text"
|
99 |
+
|
100 |
+
r = requests.put(endpoint, data=data, headers=headers)
|
101 |
+
|
102 |
+
if r.ok:
|
103 |
+
raw_metadata = r.json()
|
104 |
+
text = raw_metadata.get("X-TIKA:content", "<No text content found>")
|
105 |
+
|
106 |
+
if "Content-Type" in raw_metadata:
|
107 |
+
headers["Content-Type"] = raw_metadata["Content-Type"]
|
108 |
+
|
109 |
+
log.info("Tika extracted text: %s", text)
|
110 |
+
|
111 |
+
return [Document(page_content=text, metadata=headers)]
|
112 |
+
else:
|
113 |
+
raise Exception(f"Error calling Tika: {r.reason}")
|
114 |
+
|
115 |
+
|
116 |
+
class Loader:
|
117 |
+
def __init__(self, engine: str = "", **kwargs):
|
118 |
+
self.engine = engine
|
119 |
+
self.kwargs = kwargs
|
120 |
+
|
121 |
+
def load(
|
122 |
+
self, filename: str, file_content_type: str, file_path: str
|
123 |
+
) -> list[Document]:
|
124 |
+
loader = self._get_loader(filename, file_content_type, file_path)
|
125 |
+
docs = loader.load()
|
126 |
+
|
127 |
+
return [
|
128 |
+
Document(
|
129 |
+
page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
|
130 |
+
)
|
131 |
+
for doc in docs
|
132 |
+
]
|
133 |
+
|
134 |
+
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
|
135 |
+
file_ext = filename.split(".")[-1].lower()
|
136 |
+
|
137 |
+
if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
|
138 |
+
if file_ext in known_source_ext or (
|
139 |
+
file_content_type and file_content_type.find("text/") >= 0
|
140 |
+
):
|
141 |
+
loader = TextLoader(file_path, autodetect_encoding=True)
|
142 |
+
else:
|
143 |
+
loader = TikaLoader(
|
144 |
+
url=self.kwargs.get("TIKA_SERVER_URL"),
|
145 |
+
file_path=file_path,
|
146 |
+
mime_type=file_content_type,
|
147 |
+
)
|
148 |
+
else:
|
149 |
+
if file_ext == "pdf":
|
150 |
+
loader = PyPDFLoader(
|
151 |
+
file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
|
152 |
+
)
|
153 |
+
elif file_ext == "csv":
|
154 |
+
loader = CSVLoader(file_path)
|
155 |
+
elif file_ext == "rst":
|
156 |
+
loader = UnstructuredRSTLoader(file_path, mode="elements")
|
157 |
+
elif file_ext == "xml":
|
158 |
+
loader = UnstructuredXMLLoader(file_path)
|
159 |
+
elif file_ext in ["htm", "html"]:
|
160 |
+
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
|
161 |
+
elif file_ext == "md":
|
162 |
+
loader = TextLoader(file_path, autodetect_encoding=True)
|
163 |
+
elif file_content_type == "application/epub+zip":
|
164 |
+
loader = UnstructuredEPubLoader(file_path)
|
165 |
+
elif (
|
166 |
+
file_content_type
|
167 |
+
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
168 |
+
or file_ext == "docx"
|
169 |
+
):
|
170 |
+
loader = Docx2txtLoader(file_path)
|
171 |
+
elif file_content_type in [
|
172 |
+
"application/vnd.ms-excel",
|
173 |
+
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
174 |
+
] or file_ext in ["xls", "xlsx"]:
|
175 |
+
loader = UnstructuredExcelLoader(file_path)
|
176 |
+
elif file_content_type in [
|
177 |
+
"application/vnd.ms-powerpoint",
|
178 |
+
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
179 |
+
] or file_ext in ["ppt", "pptx"]:
|
180 |
+
loader = UnstructuredPowerPointLoader(file_path)
|
181 |
+
elif file_ext == "msg":
|
182 |
+
loader = OutlookMessageLoader(file_path)
|
183 |
+
elif file_ext in known_source_ext or (
|
184 |
+
file_content_type and file_content_type.find("text/") >= 0
|
185 |
+
):
|
186 |
+
loader = TextLoader(file_path, autodetect_encoding=True)
|
187 |
+
else:
|
188 |
+
loader = TextLoader(file_path, autodetect_encoding=True)
|
189 |
+
|
190 |
+
return loader
|
backend/open_webui/apps/retrieval/loaders/youtube.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
|
4 |
+
from urllib.parse import parse_qs, urlparse
|
5 |
+
from langchain_core.documents import Document
|
6 |
+
from open_webui.env import SRC_LOG_LEVELS
|
7 |
+
|
8 |
+
log = logging.getLogger(__name__)
|
9 |
+
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
10 |
+
|
11 |
+
ALLOWED_SCHEMES = {"http", "https"}
|
12 |
+
ALLOWED_NETLOCS = {
|
13 |
+
"youtu.be",
|
14 |
+
"m.youtube.com",
|
15 |
+
"youtube.com",
|
16 |
+
"www.youtube.com",
|
17 |
+
"www.youtube-nocookie.com",
|
18 |
+
"vid.plus",
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def _parse_video_id(url: str) -> Optional[str]:
|
23 |
+
"""Parse a YouTube URL and return the video ID if valid, otherwise None."""
|
24 |
+
parsed_url = urlparse(url)
|
25 |
+
|
26 |
+
if parsed_url.scheme not in ALLOWED_SCHEMES:
|
27 |
+
return None
|
28 |
+
|
29 |
+
if parsed_url.netloc not in ALLOWED_NETLOCS:
|
30 |
+
return None
|
31 |
+
|
32 |
+
path = parsed_url.path
|
33 |
+
|
34 |
+
if path.endswith("/watch"):
|
35 |
+
query = parsed_url.query
|
36 |
+
parsed_query = parse_qs(query)
|
37 |
+
if "v" in parsed_query:
|
38 |
+
ids = parsed_query["v"]
|
39 |
+
video_id = ids if isinstance(ids, str) else ids[0]
|
40 |
+
else:
|
41 |
+
return None
|
42 |
+
else:
|
43 |
+
path = parsed_url.path.lstrip("/")
|
44 |
+
video_id = path.split("/")[-1]
|
45 |
+
|
46 |
+
if len(video_id) != 11: # Video IDs are 11 characters long
|
47 |
+
return None
|
48 |
+
|
49 |
+
return video_id
|
50 |
+
|
51 |
+
|
52 |
+
class YoutubeLoader:
|
53 |
+
"""Load `YouTube` video transcripts."""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
video_id: str,
|
58 |
+
language: Union[str, Sequence[str]] = "en",
|
59 |
+
proxy_url: Optional[str] = None,
|
60 |
+
):
|
61 |
+
"""Initialize with YouTube video ID."""
|
62 |
+
_video_id = _parse_video_id(video_id)
|
63 |
+
self.video_id = _video_id if _video_id is not None else video_id
|
64 |
+
self._metadata = {"source": video_id}
|
65 |
+
self.language = language
|
66 |
+
self.proxy_url = proxy_url
|
67 |
+
if isinstance(language, str):
|
68 |
+
self.language = [language]
|
69 |
+
else:
|
70 |
+
self.language = language
|
71 |
+
|
72 |
+
def load(self) -> List[Document]:
|
73 |
+
"""Load YouTube transcripts into `Document` objects."""
|
74 |
+
try:
|
75 |
+
from youtube_transcript_api import (
|
76 |
+
NoTranscriptFound,
|
77 |
+
TranscriptsDisabled,
|
78 |
+
YouTubeTranscriptApi,
|
79 |
+
)
|
80 |
+
except ImportError:
|
81 |
+
raise ImportError(
|
82 |
+
'Could not import "youtube_transcript_api" Python package. '
|
83 |
+
"Please install it with `pip install youtube-transcript-api`."
|
84 |
+
)
|
85 |
+
|
86 |
+
if self.proxy_url:
|
87 |
+
youtube_proxies = {
|
88 |
+
"http": self.proxy_url,
|
89 |
+
"https": self.proxy_url,
|
90 |
+
}
|
91 |
+
# Don't log complete URL because it might contain secrets
|
92 |
+
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
|
93 |
+
else:
|
94 |
+
youtube_proxies = None
|
95 |
+
|
96 |
+
try:
|
97 |
+
transcript_list = YouTubeTranscriptApi.list_transcripts(
|
98 |
+
self.video_id, proxies=youtube_proxies
|
99 |
+
)
|
100 |
+
except Exception as e:
|
101 |
+
log.exception("Loading YouTube transcript failed")
|
102 |
+
return []
|
103 |
+
|
104 |
+
try:
|
105 |
+
transcript = transcript_list.find_transcript(self.language)
|
106 |
+
except NoTranscriptFound:
|
107 |
+
transcript = transcript_list.find_transcript(["en"])
|
108 |
+
|
109 |
+
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
|
110 |
+
|
111 |
+
transcript = " ".join(
|
112 |
+
map(
|
113 |
+
lambda transcript_piece: transcript_piece["text"].strip(" "),
|
114 |
+
transcript_pieces,
|
115 |
+
)
|
116 |
+
)
|
117 |
+
return [Document(page_content=transcript, metadata=self._metadata)]
|
backend/open_webui/apps/retrieval/main.py
ADDED
@@ -0,0 +1,1494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TODO: Merge this with the webui_app and make it a single app
|
2 |
+
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import mimetypes
|
6 |
+
import os
|
7 |
+
import shutil
|
8 |
+
|
9 |
+
import uuid
|
10 |
+
from datetime import datetime
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import Iterator, Optional, Sequence, Union
|
13 |
+
|
14 |
+
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
|
15 |
+
from fastapi.middleware.cors import CORSMiddleware
|
16 |
+
from pydantic import BaseModel
|
17 |
+
import tiktoken
|
18 |
+
|
19 |
+
|
20 |
+
from open_webui.storage.provider import Storage
|
21 |
+
from open_webui.apps.webui.models.knowledge import Knowledges
|
22 |
+
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
23 |
+
|
24 |
+
# Document loaders
|
25 |
+
from open_webui.apps.retrieval.loaders.main import Loader
|
26 |
+
from open_webui.apps.retrieval.loaders.youtube import YoutubeLoader
|
27 |
+
|
28 |
+
# Web search engines
|
29 |
+
from open_webui.apps.retrieval.web.main import SearchResult
|
30 |
+
from open_webui.apps.retrieval.web.utils import get_web_loader
|
31 |
+
from open_webui.apps.retrieval.web.brave import search_brave
|
32 |
+
from open_webui.apps.retrieval.web.mojeek import search_mojeek
|
33 |
+
from open_webui.apps.retrieval.web.duckduckgo import search_duckduckgo
|
34 |
+
from open_webui.apps.retrieval.web.google_pse import search_google_pse
|
35 |
+
from open_webui.apps.retrieval.web.jina_search import search_jina
|
36 |
+
from open_webui.apps.retrieval.web.searchapi import search_searchapi
|
37 |
+
from open_webui.apps.retrieval.web.searxng import search_searxng
|
38 |
+
from open_webui.apps.retrieval.web.serper import search_serper
|
39 |
+
from open_webui.apps.retrieval.web.serply import search_serply
|
40 |
+
from open_webui.apps.retrieval.web.serpstack import search_serpstack
|
41 |
+
from open_webui.apps.retrieval.web.tavily import search_tavily
|
42 |
+
from open_webui.apps.retrieval.web.bing import search_bing
|
43 |
+
|
44 |
+
|
45 |
+
from open_webui.apps.retrieval.utils import (
|
46 |
+
get_embedding_function,
|
47 |
+
get_model_path,
|
48 |
+
query_collection,
|
49 |
+
query_collection_with_hybrid_search,
|
50 |
+
query_doc,
|
51 |
+
query_doc_with_hybrid_search,
|
52 |
+
)
|
53 |
+
|
54 |
+
from open_webui.apps.webui.models.files import Files
|
55 |
+
from open_webui.config import (
|
56 |
+
BRAVE_SEARCH_API_KEY,
|
57 |
+
MOJEEK_SEARCH_API_KEY,
|
58 |
+
TIKTOKEN_ENCODING_NAME,
|
59 |
+
RAG_TEXT_SPLITTER,
|
60 |
+
CHUNK_OVERLAP,
|
61 |
+
CHUNK_SIZE,
|
62 |
+
CONTENT_EXTRACTION_ENGINE,
|
63 |
+
CORS_ALLOW_ORIGIN,
|
64 |
+
ENABLE_RAG_HYBRID_SEARCH,
|
65 |
+
ENABLE_RAG_LOCAL_WEB_FETCH,
|
66 |
+
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
67 |
+
ENABLE_RAG_WEB_SEARCH,
|
68 |
+
ENV,
|
69 |
+
GOOGLE_PSE_API_KEY,
|
70 |
+
GOOGLE_PSE_ENGINE_ID,
|
71 |
+
PDF_EXTRACT_IMAGES,
|
72 |
+
RAG_EMBEDDING_ENGINE,
|
73 |
+
RAG_EMBEDDING_MODEL,
|
74 |
+
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
75 |
+
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
76 |
+
RAG_EMBEDDING_BATCH_SIZE,
|
77 |
+
RAG_FILE_MAX_COUNT,
|
78 |
+
RAG_FILE_MAX_SIZE,
|
79 |
+
RAG_OPENAI_API_BASE_URL,
|
80 |
+
RAG_OPENAI_API_KEY,
|
81 |
+
RAG_OLLAMA_BASE_URL,
|
82 |
+
RAG_OLLAMA_API_KEY,
|
83 |
+
RAG_RELEVANCE_THRESHOLD,
|
84 |
+
RAG_RERANKING_MODEL,
|
85 |
+
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
86 |
+
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
87 |
+
DEFAULT_RAG_TEMPLATE,
|
88 |
+
RAG_TEMPLATE,
|
89 |
+
RAG_TOP_K,
|
90 |
+
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
91 |
+
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
92 |
+
RAG_WEB_SEARCH_ENGINE,
|
93 |
+
RAG_WEB_SEARCH_RESULT_COUNT,
|
94 |
+
JINA_API_KEY,
|
95 |
+
SEARCHAPI_API_KEY,
|
96 |
+
SEARCHAPI_ENGINE,
|
97 |
+
SEARXNG_QUERY_URL,
|
98 |
+
SERPER_API_KEY,
|
99 |
+
SERPLY_API_KEY,
|
100 |
+
SERPSTACK_API_KEY,
|
101 |
+
SERPSTACK_HTTPS,
|
102 |
+
TAVILY_API_KEY,
|
103 |
+
BING_SEARCH_V7_ENDPOINT,
|
104 |
+
BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
105 |
+
TIKA_SERVER_URL,
|
106 |
+
UPLOAD_DIR,
|
107 |
+
YOUTUBE_LOADER_LANGUAGE,
|
108 |
+
YOUTUBE_LOADER_PROXY_URL,
|
109 |
+
DEFAULT_LOCALE,
|
110 |
+
AppConfig,
|
111 |
+
)
|
112 |
+
from open_webui.constants import ERROR_MESSAGES
|
113 |
+
from open_webui.env import (
|
114 |
+
SRC_LOG_LEVELS,
|
115 |
+
DEVICE_TYPE,
|
116 |
+
DOCKER,
|
117 |
+
)
|
118 |
+
from open_webui.utils.misc import (
|
119 |
+
calculate_sha256,
|
120 |
+
calculate_sha256_string,
|
121 |
+
extract_folders_after_data_docs,
|
122 |
+
sanitize_filename,
|
123 |
+
)
|
124 |
+
from open_webui.utils.utils import get_admin_user, get_verified_user
|
125 |
+
|
126 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
|
127 |
+
from langchain_core.documents import Document
|
128 |
+
|
129 |
+
|
130 |
+
log = logging.getLogger(__name__)
|
131 |
+
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
132 |
+
|
133 |
+
app = FastAPI(
|
134 |
+
docs_url="/docs" if ENV == "dev" else None,
|
135 |
+
openapi_url="/openapi.json" if ENV == "dev" else None,
|
136 |
+
redoc_url=None,
|
137 |
+
)
|
138 |
+
|
139 |
+
app.state.config = AppConfig()
|
140 |
+
|
141 |
+
app.state.config.TOP_K = RAG_TOP_K
|
142 |
+
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
143 |
+
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
|
144 |
+
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
|
145 |
+
|
146 |
+
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
147 |
+
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
148 |
+
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
149 |
+
)
|
150 |
+
|
151 |
+
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
152 |
+
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
153 |
+
|
154 |
+
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
|
155 |
+
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
|
156 |
+
|
157 |
+
app.state.config.CHUNK_SIZE = CHUNK_SIZE
|
158 |
+
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
|
159 |
+
|
160 |
+
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
161 |
+
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
162 |
+
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
|
163 |
+
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
164 |
+
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
165 |
+
|
166 |
+
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
167 |
+
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
168 |
+
|
169 |
+
app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
|
170 |
+
app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
|
171 |
+
|
172 |
+
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
|
173 |
+
|
174 |
+
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
|
175 |
+
app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
|
176 |
+
app.state.YOUTUBE_LOADER_TRANSLATION = None
|
177 |
+
|
178 |
+
|
179 |
+
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
|
180 |
+
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
|
181 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
|
182 |
+
|
183 |
+
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
|
184 |
+
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
|
185 |
+
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
|
186 |
+
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
|
187 |
+
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
|
188 |
+
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
|
189 |
+
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
|
190 |
+
app.state.config.SERPER_API_KEY = SERPER_API_KEY
|
191 |
+
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
|
192 |
+
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
|
193 |
+
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
|
194 |
+
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
|
195 |
+
app.state.config.JINA_API_KEY = JINA_API_KEY
|
196 |
+
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
|
197 |
+
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
|
198 |
+
|
199 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
|
200 |
+
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
|
201 |
+
|
202 |
+
|
203 |
+
def update_embedding_model(
|
204 |
+
embedding_model: str,
|
205 |
+
auto_update: bool = False,
|
206 |
+
):
|
207 |
+
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
|
208 |
+
from sentence_transformers import SentenceTransformer
|
209 |
+
|
210 |
+
try:
|
211 |
+
app.state.sentence_transformer_ef = SentenceTransformer(
|
212 |
+
get_model_path(embedding_model, auto_update),
|
213 |
+
device=DEVICE_TYPE,
|
214 |
+
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
215 |
+
)
|
216 |
+
except Exception as e:
|
217 |
+
log.debug(f"Error loading SentenceTransformer: {e}")
|
218 |
+
app.state.sentence_transformer_ef = None
|
219 |
+
else:
|
220 |
+
app.state.sentence_transformer_ef = None
|
221 |
+
|
222 |
+
|
223 |
+
def update_reranking_model(
|
224 |
+
reranking_model: str,
|
225 |
+
auto_update: bool = False,
|
226 |
+
):
|
227 |
+
if reranking_model:
|
228 |
+
if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
|
229 |
+
try:
|
230 |
+
from open_webui.apps.retrieval.models.colbert import ColBERT
|
231 |
+
|
232 |
+
app.state.sentence_transformer_rf = ColBERT(
|
233 |
+
get_model_path(reranking_model, auto_update),
|
234 |
+
env="docker" if DOCKER else None,
|
235 |
+
)
|
236 |
+
except Exception as e:
|
237 |
+
log.error(f"ColBERT: {e}")
|
238 |
+
app.state.sentence_transformer_rf = None
|
239 |
+
app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
240 |
+
else:
|
241 |
+
import sentence_transformers
|
242 |
+
|
243 |
+
try:
|
244 |
+
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
245 |
+
get_model_path(reranking_model, auto_update),
|
246 |
+
device=DEVICE_TYPE,
|
247 |
+
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
248 |
+
)
|
249 |
+
except:
|
250 |
+
log.error("CrossEncoder error")
|
251 |
+
app.state.sentence_transformer_rf = None
|
252 |
+
app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
253 |
+
else:
|
254 |
+
app.state.sentence_transformer_rf = None
|
255 |
+
|
256 |
+
|
257 |
+
update_embedding_model(
|
258 |
+
app.state.config.RAG_EMBEDDING_MODEL,
|
259 |
+
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
260 |
+
)
|
261 |
+
|
262 |
+
update_reranking_model(
|
263 |
+
app.state.config.RAG_RERANKING_MODEL,
|
264 |
+
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
265 |
+
)
|
266 |
+
|
267 |
+
|
268 |
+
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
269 |
+
app.state.config.RAG_EMBEDDING_ENGINE,
|
270 |
+
app.state.config.RAG_EMBEDDING_MODEL,
|
271 |
+
app.state.sentence_transformer_ef,
|
272 |
+
(
|
273 |
+
app.state.config.OPENAI_API_BASE_URL
|
274 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
275 |
+
else app.state.config.OLLAMA_BASE_URL
|
276 |
+
),
|
277 |
+
(
|
278 |
+
app.state.config.OPENAI_API_KEY
|
279 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
280 |
+
else app.state.config.OLLAMA_API_KEY
|
281 |
+
),
|
282 |
+
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
283 |
+
)
|
284 |
+
|
285 |
+
app.add_middleware(
|
286 |
+
CORSMiddleware,
|
287 |
+
allow_origins=CORS_ALLOW_ORIGIN,
|
288 |
+
allow_credentials=True,
|
289 |
+
allow_methods=["*"],
|
290 |
+
allow_headers=["*"],
|
291 |
+
)
|
292 |
+
|
293 |
+
|
294 |
+
class CollectionNameForm(BaseModel):
|
295 |
+
collection_name: Optional[str] = None
|
296 |
+
|
297 |
+
|
298 |
+
class ProcessUrlForm(CollectionNameForm):
|
299 |
+
url: str
|
300 |
+
|
301 |
+
|
302 |
+
class SearchForm(CollectionNameForm):
|
303 |
+
query: str
|
304 |
+
|
305 |
+
|
306 |
+
@app.get("/")
|
307 |
+
async def get_status():
|
308 |
+
return {
|
309 |
+
"status": True,
|
310 |
+
"chunk_size": app.state.config.CHUNK_SIZE,
|
311 |
+
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
|
312 |
+
"template": app.state.config.RAG_TEMPLATE,
|
313 |
+
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
314 |
+
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
315 |
+
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
|
316 |
+
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
317 |
+
}
|
318 |
+
|
319 |
+
|
320 |
+
@app.get("/embedding")
|
321 |
+
async def get_embedding_config(user=Depends(get_admin_user)):
|
322 |
+
return {
|
323 |
+
"status": True,
|
324 |
+
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
325 |
+
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
326 |
+
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
327 |
+
"openai_config": {
|
328 |
+
"url": app.state.config.OPENAI_API_BASE_URL,
|
329 |
+
"key": app.state.config.OPENAI_API_KEY,
|
330 |
+
},
|
331 |
+
"ollama_config": {
|
332 |
+
"url": app.state.config.OLLAMA_BASE_URL,
|
333 |
+
"key": app.state.config.OLLAMA_API_KEY,
|
334 |
+
},
|
335 |
+
}
|
336 |
+
|
337 |
+
|
338 |
+
@app.get("/reranking")
|
339 |
+
async def get_reraanking_config(user=Depends(get_admin_user)):
|
340 |
+
return {
|
341 |
+
"status": True,
|
342 |
+
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
|
343 |
+
}
|
344 |
+
|
345 |
+
|
346 |
+
class OpenAIConfigForm(BaseModel):
|
347 |
+
url: str
|
348 |
+
key: str
|
349 |
+
|
350 |
+
|
351 |
+
class OllamaConfigForm(BaseModel):
|
352 |
+
url: str
|
353 |
+
key: str
|
354 |
+
|
355 |
+
|
356 |
+
class EmbeddingModelUpdateForm(BaseModel):
|
357 |
+
openai_config: Optional[OpenAIConfigForm] = None
|
358 |
+
ollama_config: Optional[OllamaConfigForm] = None
|
359 |
+
embedding_engine: str
|
360 |
+
embedding_model: str
|
361 |
+
embedding_batch_size: Optional[int] = 1
|
362 |
+
|
363 |
+
|
364 |
+
@app.post("/embedding/update")
|
365 |
+
async def update_embedding_config(
|
366 |
+
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
367 |
+
):
|
368 |
+
log.info(
|
369 |
+
f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
370 |
+
)
|
371 |
+
try:
|
372 |
+
app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
373 |
+
app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
374 |
+
|
375 |
+
if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
376 |
+
if form_data.openai_config is not None:
|
377 |
+
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
378 |
+
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
379 |
+
|
380 |
+
if form_data.ollama_config is not None:
|
381 |
+
app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
|
382 |
+
app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
|
383 |
+
|
384 |
+
app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
|
385 |
+
|
386 |
+
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
387 |
+
|
388 |
+
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
389 |
+
app.state.config.RAG_EMBEDDING_ENGINE,
|
390 |
+
app.state.config.RAG_EMBEDDING_MODEL,
|
391 |
+
app.state.sentence_transformer_ef,
|
392 |
+
(
|
393 |
+
app.state.config.OPENAI_API_BASE_URL
|
394 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
395 |
+
else app.state.config.OLLAMA_BASE_URL
|
396 |
+
),
|
397 |
+
(
|
398 |
+
app.state.config.OPENAI_API_KEY
|
399 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
400 |
+
else app.state.config.OLLAMA_API_KEY
|
401 |
+
),
|
402 |
+
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
403 |
+
)
|
404 |
+
|
405 |
+
return {
|
406 |
+
"status": True,
|
407 |
+
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
408 |
+
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
409 |
+
"embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
410 |
+
"openai_config": {
|
411 |
+
"url": app.state.config.OPENAI_API_BASE_URL,
|
412 |
+
"key": app.state.config.OPENAI_API_KEY,
|
413 |
+
},
|
414 |
+
"ollama_config": {
|
415 |
+
"url": app.state.config.OLLAMA_BASE_URL,
|
416 |
+
"key": app.state.config.OLLAMA_API_KEY,
|
417 |
+
},
|
418 |
+
}
|
419 |
+
except Exception as e:
|
420 |
+
log.exception(f"Problem updating embedding model: {e}")
|
421 |
+
raise HTTPException(
|
422 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
423 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
424 |
+
)
|
425 |
+
|
426 |
+
|
427 |
+
class RerankingModelUpdateForm(BaseModel):
|
428 |
+
reranking_model: str
|
429 |
+
|
430 |
+
|
431 |
+
@app.post("/reranking/update")
|
432 |
+
async def update_reranking_config(
|
433 |
+
form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
|
434 |
+
):
|
435 |
+
log.info(
|
436 |
+
f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
|
437 |
+
)
|
438 |
+
try:
|
439 |
+
app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
|
440 |
+
|
441 |
+
update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)
|
442 |
+
|
443 |
+
return {
|
444 |
+
"status": True,
|
445 |
+
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
|
446 |
+
}
|
447 |
+
except Exception as e:
|
448 |
+
log.exception(f"Problem updating reranking model: {e}")
|
449 |
+
raise HTTPException(
|
450 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
451 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
452 |
+
)
|
453 |
+
|
454 |
+
|
455 |
+
@app.get("/config")
|
456 |
+
async def get_rag_config(user=Depends(get_admin_user)):
|
457 |
+
return {
|
458 |
+
"status": True,
|
459 |
+
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
|
460 |
+
"content_extraction": {
|
461 |
+
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
|
462 |
+
"tika_server_url": app.state.config.TIKA_SERVER_URL,
|
463 |
+
},
|
464 |
+
"chunk": {
|
465 |
+
"text_splitter": app.state.config.TEXT_SPLITTER,
|
466 |
+
"chunk_size": app.state.config.CHUNK_SIZE,
|
467 |
+
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
|
468 |
+
},
|
469 |
+
"file": {
|
470 |
+
"max_size": app.state.config.FILE_MAX_SIZE,
|
471 |
+
"max_count": app.state.config.FILE_MAX_COUNT,
|
472 |
+
},
|
473 |
+
"youtube": {
|
474 |
+
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
|
475 |
+
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
476 |
+
"proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
477 |
+
},
|
478 |
+
"web": {
|
479 |
+
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
480 |
+
"search": {
|
481 |
+
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
482 |
+
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
|
483 |
+
"searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
|
484 |
+
"google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
|
485 |
+
"google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
|
486 |
+
"brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
|
487 |
+
"mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY,
|
488 |
+
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
|
489 |
+
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
|
490 |
+
"serper_api_key": app.state.config.SERPER_API_KEY,
|
491 |
+
"serply_api_key": app.state.config.SERPLY_API_KEY,
|
492 |
+
"tavily_api_key": app.state.config.TAVILY_API_KEY,
|
493 |
+
"searchapi_api_key": app.state.config.SEARCHAPI_API_KEY,
|
494 |
+
"seaarchapi_engine": app.state.config.SEARCHAPI_ENGINE,
|
495 |
+
"jina_api_key": app.state.config.JINA_API_KEY,
|
496 |
+
"bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
|
497 |
+
"bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
498 |
+
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
499 |
+
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
500 |
+
},
|
501 |
+
},
|
502 |
+
}
|
503 |
+
|
504 |
+
|
505 |
+
class FileConfig(BaseModel):
|
506 |
+
max_size: Optional[int] = None
|
507 |
+
max_count: Optional[int] = None
|
508 |
+
|
509 |
+
|
510 |
+
class ContentExtractionConfig(BaseModel):
|
511 |
+
engine: str = ""
|
512 |
+
tika_server_url: Optional[str] = None
|
513 |
+
|
514 |
+
|
515 |
+
class ChunkParamUpdateForm(BaseModel):
|
516 |
+
text_splitter: Optional[str] = None
|
517 |
+
chunk_size: int
|
518 |
+
chunk_overlap: int
|
519 |
+
|
520 |
+
|
521 |
+
class YoutubeLoaderConfig(BaseModel):
|
522 |
+
language: list[str]
|
523 |
+
translation: Optional[str] = None
|
524 |
+
proxy_url: str = ""
|
525 |
+
|
526 |
+
|
527 |
+
class WebSearchConfig(BaseModel):
|
528 |
+
enabled: bool
|
529 |
+
engine: Optional[str] = None
|
530 |
+
searxng_query_url: Optional[str] = None
|
531 |
+
google_pse_api_key: Optional[str] = None
|
532 |
+
google_pse_engine_id: Optional[str] = None
|
533 |
+
brave_search_api_key: Optional[str] = None
|
534 |
+
mojeek_search_api_key: Optional[str] = None
|
535 |
+
serpstack_api_key: Optional[str] = None
|
536 |
+
serpstack_https: Optional[bool] = None
|
537 |
+
serper_api_key: Optional[str] = None
|
538 |
+
serply_api_key: Optional[str] = None
|
539 |
+
tavily_api_key: Optional[str] = None
|
540 |
+
searchapi_api_key: Optional[str] = None
|
541 |
+
searchapi_engine: Optional[str] = None
|
542 |
+
jina_api_key: Optional[str] = None
|
543 |
+
bing_search_v7_endpoint: Optional[str] = None
|
544 |
+
bing_search_v7_subscription_key: Optional[str] = None
|
545 |
+
result_count: Optional[int] = None
|
546 |
+
concurrent_requests: Optional[int] = None
|
547 |
+
|
548 |
+
|
549 |
+
class WebConfig(BaseModel):
|
550 |
+
search: WebSearchConfig
|
551 |
+
web_loader_ssl_verification: Optional[bool] = None
|
552 |
+
|
553 |
+
|
554 |
+
class ConfigUpdateForm(BaseModel):
|
555 |
+
pdf_extract_images: Optional[bool] = None
|
556 |
+
file: Optional[FileConfig] = None
|
557 |
+
content_extraction: Optional[ContentExtractionConfig] = None
|
558 |
+
chunk: Optional[ChunkParamUpdateForm] = None
|
559 |
+
youtube: Optional[YoutubeLoaderConfig] = None
|
560 |
+
web: Optional[WebConfig] = None
|
561 |
+
|
562 |
+
|
563 |
+
@app.post("/config/update")
|
564 |
+
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
|
565 |
+
app.state.config.PDF_EXTRACT_IMAGES = (
|
566 |
+
form_data.pdf_extract_images
|
567 |
+
if form_data.pdf_extract_images is not None
|
568 |
+
else app.state.config.PDF_EXTRACT_IMAGES
|
569 |
+
)
|
570 |
+
|
571 |
+
if form_data.file is not None:
|
572 |
+
app.state.config.FILE_MAX_SIZE = form_data.file.max_size
|
573 |
+
app.state.config.FILE_MAX_COUNT = form_data.file.max_count
|
574 |
+
|
575 |
+
if form_data.content_extraction is not None:
|
576 |
+
log.info(f"Updating text settings: {form_data.content_extraction}")
|
577 |
+
app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
|
578 |
+
app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
|
579 |
+
|
580 |
+
if form_data.chunk is not None:
|
581 |
+
app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
|
582 |
+
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
|
583 |
+
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
|
584 |
+
|
585 |
+
if form_data.youtube is not None:
|
586 |
+
app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
|
587 |
+
app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url
|
588 |
+
app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
|
589 |
+
|
590 |
+
if form_data.web is not None:
|
591 |
+
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
592 |
+
# Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
|
593 |
+
form_data.web.web_loader_ssl_verification
|
594 |
+
)
|
595 |
+
|
596 |
+
app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
|
597 |
+
app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
|
598 |
+
app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
|
599 |
+
app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
|
600 |
+
app.state.config.GOOGLE_PSE_ENGINE_ID = (
|
601 |
+
form_data.web.search.google_pse_engine_id
|
602 |
+
)
|
603 |
+
app.state.config.BRAVE_SEARCH_API_KEY = (
|
604 |
+
form_data.web.search.brave_search_api_key
|
605 |
+
)
|
606 |
+
app.state.config.MOJEEK_SEARCH_API_KEY = (
|
607 |
+
form_data.web.search.mojeek_search_api_key
|
608 |
+
)
|
609 |
+
app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
|
610 |
+
app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
|
611 |
+
app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
|
612 |
+
app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
|
613 |
+
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
|
614 |
+
app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
|
615 |
+
app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
|
616 |
+
|
617 |
+
app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
|
618 |
+
app.state.config.BING_SEARCH_V7_ENDPOINT = (
|
619 |
+
form_data.web.search.bing_search_v7_endpoint
|
620 |
+
)
|
621 |
+
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = (
|
622 |
+
form_data.web.search.bing_search_v7_subscription_key
|
623 |
+
)
|
624 |
+
|
625 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
|
626 |
+
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
|
627 |
+
form_data.web.search.concurrent_requests
|
628 |
+
)
|
629 |
+
|
630 |
+
return {
|
631 |
+
"status": True,
|
632 |
+
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
|
633 |
+
"file": {
|
634 |
+
"max_size": app.state.config.FILE_MAX_SIZE,
|
635 |
+
"max_count": app.state.config.FILE_MAX_COUNT,
|
636 |
+
},
|
637 |
+
"content_extraction": {
|
638 |
+
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
|
639 |
+
"tika_server_url": app.state.config.TIKA_SERVER_URL,
|
640 |
+
},
|
641 |
+
"chunk": {
|
642 |
+
"text_splitter": app.state.config.TEXT_SPLITTER,
|
643 |
+
"chunk_size": app.state.config.CHUNK_SIZE,
|
644 |
+
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
|
645 |
+
},
|
646 |
+
"youtube": {
|
647 |
+
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
|
648 |
+
"proxy_url": app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
649 |
+
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
|
650 |
+
},
|
651 |
+
"web": {
|
652 |
+
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
653 |
+
"search": {
|
654 |
+
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
655 |
+
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
|
656 |
+
"searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
|
657 |
+
"google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
|
658 |
+
"google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
|
659 |
+
"brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
|
660 |
+
"mojeek_search_api_key": app.state.config.MOJEEK_SEARCH_API_KEY,
|
661 |
+
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
|
662 |
+
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
|
663 |
+
"serper_api_key": app.state.config.SERPER_API_KEY,
|
664 |
+
"serply_api_key": app.state.config.SERPLY_API_KEY,
|
665 |
+
"serachapi_api_key": app.state.config.SEARCHAPI_API_KEY,
|
666 |
+
"searchapi_engine": app.state.config.SEARCHAPI_ENGINE,
|
667 |
+
"tavily_api_key": app.state.config.TAVILY_API_KEY,
|
668 |
+
"jina_api_key": app.state.config.JINA_API_KEY,
|
669 |
+
"bing_search_v7_endpoint": app.state.config.BING_SEARCH_V7_ENDPOINT,
|
670 |
+
"bing_search_v7_subscription_key": app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
671 |
+
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
672 |
+
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
673 |
+
},
|
674 |
+
},
|
675 |
+
}
|
676 |
+
|
677 |
+
|
678 |
+
@app.get("/template")
|
679 |
+
async def get_rag_template(user=Depends(get_verified_user)):
|
680 |
+
return {
|
681 |
+
"status": True,
|
682 |
+
"template": app.state.config.RAG_TEMPLATE,
|
683 |
+
}
|
684 |
+
|
685 |
+
|
686 |
+
@app.get("/query/settings")
|
687 |
+
async def get_query_settings(user=Depends(get_admin_user)):
|
688 |
+
return {
|
689 |
+
"status": True,
|
690 |
+
"template": app.state.config.RAG_TEMPLATE,
|
691 |
+
"k": app.state.config.TOP_K,
|
692 |
+
"r": app.state.config.RELEVANCE_THRESHOLD,
|
693 |
+
"hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
694 |
+
}
|
695 |
+
|
696 |
+
|
697 |
+
class QuerySettingsForm(BaseModel):
|
698 |
+
k: Optional[int] = None
|
699 |
+
r: Optional[float] = None
|
700 |
+
template: Optional[str] = None
|
701 |
+
hybrid: Optional[bool] = None
|
702 |
+
|
703 |
+
|
704 |
+
@app.post("/query/settings/update")
|
705 |
+
async def update_query_settings(
|
706 |
+
form_data: QuerySettingsForm, user=Depends(get_admin_user)
|
707 |
+
):
|
708 |
+
app.state.config.RAG_TEMPLATE = form_data.template
|
709 |
+
app.state.config.TOP_K = form_data.k if form_data.k else 4
|
710 |
+
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
|
711 |
+
|
712 |
+
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
|
713 |
+
form_data.hybrid if form_data.hybrid else False
|
714 |
+
)
|
715 |
+
|
716 |
+
return {
|
717 |
+
"status": True,
|
718 |
+
"template": app.state.config.RAG_TEMPLATE,
|
719 |
+
"k": app.state.config.TOP_K,
|
720 |
+
"r": app.state.config.RELEVANCE_THRESHOLD,
|
721 |
+
"hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
722 |
+
}
|
723 |
+
|
724 |
+
|
725 |
+
####################################
|
726 |
+
#
|
727 |
+
# Document process and retrieval
|
728 |
+
#
|
729 |
+
####################################
|
730 |
+
|
731 |
+
|
732 |
+
def _get_docs_info(docs: list[Document]) -> str:
|
733 |
+
docs_info = set()
|
734 |
+
|
735 |
+
# Trying to select relevant metadata identifying the document.
|
736 |
+
for doc in docs:
|
737 |
+
metadata = getattr(doc, "metadata", {})
|
738 |
+
doc_name = metadata.get("name", "")
|
739 |
+
if not doc_name:
|
740 |
+
doc_name = metadata.get("title", "")
|
741 |
+
if not doc_name:
|
742 |
+
doc_name = metadata.get("source", "")
|
743 |
+
if doc_name:
|
744 |
+
docs_info.add(doc_name)
|
745 |
+
|
746 |
+
return ", ".join(docs_info)
|
747 |
+
|
748 |
+
|
749 |
+
def save_docs_to_vector_db(
|
750 |
+
docs,
|
751 |
+
collection_name,
|
752 |
+
metadata: Optional[dict] = None,
|
753 |
+
overwrite: bool = False,
|
754 |
+
split: bool = True,
|
755 |
+
add: bool = False,
|
756 |
+
) -> bool:
|
757 |
+
log.info(
|
758 |
+
f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
|
759 |
+
)
|
760 |
+
|
761 |
+
# Check if entries with the same hash (metadata.hash) already exist
|
762 |
+
if metadata and "hash" in metadata:
|
763 |
+
result = VECTOR_DB_CLIENT.query(
|
764 |
+
collection_name=collection_name,
|
765 |
+
filter={"hash": metadata["hash"]},
|
766 |
+
)
|
767 |
+
|
768 |
+
if result is not None:
|
769 |
+
existing_doc_ids = result.ids[0]
|
770 |
+
if existing_doc_ids:
|
771 |
+
log.info(f"Document with hash {metadata['hash']} already exists")
|
772 |
+
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
|
773 |
+
|
774 |
+
if split:
|
775 |
+
if app.state.config.TEXT_SPLITTER in ["", "character"]:
|
776 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
777 |
+
chunk_size=app.state.config.CHUNK_SIZE,
|
778 |
+
chunk_overlap=app.state.config.CHUNK_OVERLAP,
|
779 |
+
add_start_index=True,
|
780 |
+
)
|
781 |
+
elif app.state.config.TEXT_SPLITTER == "token":
|
782 |
+
log.info(
|
783 |
+
f"Using token text splitter: {app.state.config.TIKTOKEN_ENCODING_NAME}"
|
784 |
+
)
|
785 |
+
|
786 |
+
tiktoken.get_encoding(str(app.state.config.TIKTOKEN_ENCODING_NAME))
|
787 |
+
text_splitter = TokenTextSplitter(
|
788 |
+
encoding_name=str(app.state.config.TIKTOKEN_ENCODING_NAME),
|
789 |
+
chunk_size=app.state.config.CHUNK_SIZE,
|
790 |
+
chunk_overlap=app.state.config.CHUNK_OVERLAP,
|
791 |
+
add_start_index=True,
|
792 |
+
)
|
793 |
+
else:
|
794 |
+
raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
|
795 |
+
|
796 |
+
docs = text_splitter.split_documents(docs)
|
797 |
+
|
798 |
+
if len(docs) == 0:
|
799 |
+
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
|
800 |
+
|
801 |
+
texts = [doc.page_content for doc in docs]
|
802 |
+
metadatas = [
|
803 |
+
{
|
804 |
+
**doc.metadata,
|
805 |
+
**(metadata if metadata else {}),
|
806 |
+
"embedding_config": json.dumps(
|
807 |
+
{
|
808 |
+
"engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
809 |
+
"model": app.state.config.RAG_EMBEDDING_MODEL,
|
810 |
+
}
|
811 |
+
),
|
812 |
+
}
|
813 |
+
for doc in docs
|
814 |
+
]
|
815 |
+
|
816 |
+
# ChromaDB does not like datetime formats
|
817 |
+
# for meta-data so convert them to string.
|
818 |
+
for metadata in metadatas:
|
819 |
+
for key, value in metadata.items():
|
820 |
+
if isinstance(value, datetime):
|
821 |
+
metadata[key] = str(value)
|
822 |
+
|
823 |
+
try:
|
824 |
+
if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
|
825 |
+
log.info(f"collection {collection_name} already exists")
|
826 |
+
|
827 |
+
if overwrite:
|
828 |
+
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
|
829 |
+
log.info(f"deleting existing collection {collection_name}")
|
830 |
+
elif add is False:
|
831 |
+
log.info(
|
832 |
+
f"collection {collection_name} already exists, overwrite is False and add is False"
|
833 |
+
)
|
834 |
+
return True
|
835 |
+
|
836 |
+
log.info(f"adding to collection {collection_name}")
|
837 |
+
embedding_function = get_embedding_function(
|
838 |
+
app.state.config.RAG_EMBEDDING_ENGINE,
|
839 |
+
app.state.config.RAG_EMBEDDING_MODEL,
|
840 |
+
app.state.sentence_transformer_ef,
|
841 |
+
(
|
842 |
+
app.state.config.OPENAI_API_BASE_URL
|
843 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
844 |
+
else app.state.config.OLLAMA_BASE_URL
|
845 |
+
),
|
846 |
+
(
|
847 |
+
app.state.config.OPENAI_API_KEY
|
848 |
+
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
849 |
+
else app.state.config.OLLAMA_API_KEY
|
850 |
+
),
|
851 |
+
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
852 |
+
)
|
853 |
+
|
854 |
+
embeddings = embedding_function(
|
855 |
+
list(map(lambda x: x.replace("\n", " "), texts))
|
856 |
+
)
|
857 |
+
|
858 |
+
items = [
|
859 |
+
{
|
860 |
+
"id": str(uuid.uuid4()),
|
861 |
+
"text": text,
|
862 |
+
"vector": embeddings[idx],
|
863 |
+
"metadata": metadatas[idx],
|
864 |
+
}
|
865 |
+
for idx, text in enumerate(texts)
|
866 |
+
]
|
867 |
+
|
868 |
+
VECTOR_DB_CLIENT.insert(
|
869 |
+
collection_name=collection_name,
|
870 |
+
items=items,
|
871 |
+
)
|
872 |
+
|
873 |
+
return True
|
874 |
+
except Exception as e:
|
875 |
+
log.exception(e)
|
876 |
+
raise e
|
877 |
+
|
878 |
+
|
879 |
+
class ProcessFileForm(BaseModel):
|
880 |
+
file_id: str
|
881 |
+
content: Optional[str] = None
|
882 |
+
collection_name: Optional[str] = None
|
883 |
+
|
884 |
+
|
885 |
+
@app.post("/process/file")
|
886 |
+
def process_file(
|
887 |
+
form_data: ProcessFileForm,
|
888 |
+
user=Depends(get_verified_user),
|
889 |
+
):
|
890 |
+
try:
|
891 |
+
file = Files.get_file_by_id(form_data.file_id)
|
892 |
+
|
893 |
+
collection_name = form_data.collection_name
|
894 |
+
|
895 |
+
if collection_name is None:
|
896 |
+
collection_name = f"file-{file.id}"
|
897 |
+
|
898 |
+
if form_data.content:
|
899 |
+
# Update the content in the file
|
900 |
+
# Usage: /files/{file_id}/data/content/update
|
901 |
+
|
902 |
+
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
|
903 |
+
|
904 |
+
docs = [
|
905 |
+
Document(
|
906 |
+
page_content=form_data.content.replace("<br/>", "\n"),
|
907 |
+
metadata={
|
908 |
+
**file.meta,
|
909 |
+
"name": file.filename,
|
910 |
+
"created_by": file.user_id,
|
911 |
+
"file_id": file.id,
|
912 |
+
"source": file.filename,
|
913 |
+
},
|
914 |
+
)
|
915 |
+
]
|
916 |
+
|
917 |
+
text_content = form_data.content
|
918 |
+
elif form_data.collection_name:
|
919 |
+
# Check if the file has already been processed and save the content
|
920 |
+
# Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
|
921 |
+
|
922 |
+
result = VECTOR_DB_CLIENT.query(
|
923 |
+
collection_name=f"file-{file.id}", filter={"file_id": file.id}
|
924 |
+
)
|
925 |
+
|
926 |
+
if result is not None and len(result.ids[0]) > 0:
|
927 |
+
docs = [
|
928 |
+
Document(
|
929 |
+
page_content=result.documents[0][idx],
|
930 |
+
metadata=result.metadatas[0][idx],
|
931 |
+
)
|
932 |
+
for idx, id in enumerate(result.ids[0])
|
933 |
+
]
|
934 |
+
else:
|
935 |
+
docs = [
|
936 |
+
Document(
|
937 |
+
page_content=file.data.get("content", ""),
|
938 |
+
metadata={
|
939 |
+
**file.meta,
|
940 |
+
"name": file.filename,
|
941 |
+
"created_by": file.user_id,
|
942 |
+
"file_id": file.id,
|
943 |
+
"source": file.filename,
|
944 |
+
},
|
945 |
+
)
|
946 |
+
]
|
947 |
+
|
948 |
+
text_content = file.data.get("content", "")
|
949 |
+
else:
|
950 |
+
# Process the file and save the content
|
951 |
+
# Usage: /files/
|
952 |
+
file_path = file.path
|
953 |
+
if file_path:
|
954 |
+
file_path = Storage.get_file(file_path)
|
955 |
+
loader = Loader(
|
956 |
+
engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
|
957 |
+
TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
|
958 |
+
PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
|
959 |
+
)
|
960 |
+
docs = loader.load(
|
961 |
+
file.filename, file.meta.get("content_type"), file_path
|
962 |
+
)
|
963 |
+
|
964 |
+
docs = [
|
965 |
+
Document(
|
966 |
+
page_content=doc.page_content,
|
967 |
+
metadata={
|
968 |
+
**doc.metadata,
|
969 |
+
"name": file.filename,
|
970 |
+
"created_by": file.user_id,
|
971 |
+
"file_id": file.id,
|
972 |
+
"source": file.filename,
|
973 |
+
},
|
974 |
+
)
|
975 |
+
for doc in docs
|
976 |
+
]
|
977 |
+
else:
|
978 |
+
docs = [
|
979 |
+
Document(
|
980 |
+
page_content=file.data.get("content", ""),
|
981 |
+
metadata={
|
982 |
+
**file.meta,
|
983 |
+
"name": file.filename,
|
984 |
+
"created_by": file.user_id,
|
985 |
+
"file_id": file.id,
|
986 |
+
"source": file.filename,
|
987 |
+
},
|
988 |
+
)
|
989 |
+
]
|
990 |
+
text_content = " ".join([doc.page_content for doc in docs])
|
991 |
+
|
992 |
+
log.debug(f"text_content: {text_content}")
|
993 |
+
Files.update_file_data_by_id(
|
994 |
+
file.id,
|
995 |
+
{"content": text_content},
|
996 |
+
)
|
997 |
+
|
998 |
+
hash = calculate_sha256_string(text_content)
|
999 |
+
Files.update_file_hash_by_id(file.id, hash)
|
1000 |
+
|
1001 |
+
try:
|
1002 |
+
result = save_docs_to_vector_db(
|
1003 |
+
docs=docs,
|
1004 |
+
collection_name=collection_name,
|
1005 |
+
metadata={
|
1006 |
+
"file_id": file.id,
|
1007 |
+
"name": file.filename,
|
1008 |
+
"hash": hash,
|
1009 |
+
},
|
1010 |
+
add=(True if form_data.collection_name else False),
|
1011 |
+
)
|
1012 |
+
|
1013 |
+
if result:
|
1014 |
+
Files.update_file_metadata_by_id(
|
1015 |
+
file.id,
|
1016 |
+
{
|
1017 |
+
"collection_name": collection_name,
|
1018 |
+
},
|
1019 |
+
)
|
1020 |
+
|
1021 |
+
return {
|
1022 |
+
"status": True,
|
1023 |
+
"collection_name": collection_name,
|
1024 |
+
"filename": file.filename,
|
1025 |
+
"content": text_content,
|
1026 |
+
}
|
1027 |
+
except Exception as e:
|
1028 |
+
raise e
|
1029 |
+
except Exception as e:
|
1030 |
+
log.exception(e)
|
1031 |
+
if "No pandoc was found" in str(e):
|
1032 |
+
raise HTTPException(
|
1033 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
1034 |
+
detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
|
1035 |
+
)
|
1036 |
+
else:
|
1037 |
+
raise HTTPException(
|
1038 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
1039 |
+
detail=str(e),
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
|
1043 |
+
class ProcessTextForm(BaseModel):
|
1044 |
+
name: str
|
1045 |
+
content: str
|
1046 |
+
collection_name: Optional[str] = None
|
1047 |
+
|
1048 |
+
|
1049 |
+
@app.post("/process/text")
|
1050 |
+
def process_text(
|
1051 |
+
form_data: ProcessTextForm,
|
1052 |
+
user=Depends(get_verified_user),
|
1053 |
+
):
|
1054 |
+
collection_name = form_data.collection_name
|
1055 |
+
if collection_name is None:
|
1056 |
+
collection_name = calculate_sha256_string(form_data.content)
|
1057 |
+
|
1058 |
+
docs = [
|
1059 |
+
Document(
|
1060 |
+
page_content=form_data.content,
|
1061 |
+
metadata={"name": form_data.name, "created_by": user.id},
|
1062 |
+
)
|
1063 |
+
]
|
1064 |
+
text_content = form_data.content
|
1065 |
+
log.debug(f"text_content: {text_content}")
|
1066 |
+
|
1067 |
+
result = save_docs_to_vector_db(docs, collection_name)
|
1068 |
+
|
1069 |
+
if result:
|
1070 |
+
return {
|
1071 |
+
"status": True,
|
1072 |
+
"collection_name": collection_name,
|
1073 |
+
"content": text_content,
|
1074 |
+
}
|
1075 |
+
else:
|
1076 |
+
raise HTTPException(
|
1077 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
1078 |
+
detail=ERROR_MESSAGES.DEFAULT(),
|
1079 |
+
)
|
1080 |
+
|
1081 |
+
|
1082 |
+
@app.post("/process/youtube")
|
1083 |
+
def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
|
1084 |
+
try:
|
1085 |
+
collection_name = form_data.collection_name
|
1086 |
+
if not collection_name:
|
1087 |
+
collection_name = calculate_sha256_string(form_data.url)[:63]
|
1088 |
+
|
1089 |
+
loader = YoutubeLoader(
|
1090 |
+
form_data.url,
|
1091 |
+
language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
|
1092 |
+
proxy_url=app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
1093 |
+
)
|
1094 |
+
|
1095 |
+
docs = loader.load()
|
1096 |
+
content = " ".join([doc.page_content for doc in docs])
|
1097 |
+
log.debug(f"text_content: {content}")
|
1098 |
+
save_docs_to_vector_db(docs, collection_name, overwrite=True)
|
1099 |
+
|
1100 |
+
return {
|
1101 |
+
"status": True,
|
1102 |
+
"collection_name": collection_name,
|
1103 |
+
"filename": form_data.url,
|
1104 |
+
"file": {
|
1105 |
+
"data": {
|
1106 |
+
"content": content,
|
1107 |
+
},
|
1108 |
+
"meta": {
|
1109 |
+
"name": form_data.url,
|
1110 |
+
},
|
1111 |
+
},
|
1112 |
+
}
|
1113 |
+
except Exception as e:
|
1114 |
+
log.exception(e)
|
1115 |
+
raise HTTPException(
|
1116 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
1117 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
1118 |
+
)
|
1119 |
+
|
1120 |
+
|
1121 |
+
@app.post("/process/web")
|
1122 |
+
def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
|
1123 |
+
try:
|
1124 |
+
collection_name = form_data.collection_name
|
1125 |
+
if not collection_name:
|
1126 |
+
collection_name = calculate_sha256_string(form_data.url)[:63]
|
1127 |
+
|
1128 |
+
loader = get_web_loader(
|
1129 |
+
form_data.url,
|
1130 |
+
verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
1131 |
+
requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
1132 |
+
)
|
1133 |
+
docs = loader.load()
|
1134 |
+
content = " ".join([doc.page_content for doc in docs])
|
1135 |
+
log.debug(f"text_content: {content}")
|
1136 |
+
save_docs_to_vector_db(docs, collection_name, overwrite=True)
|
1137 |
+
|
1138 |
+
return {
|
1139 |
+
"status": True,
|
1140 |
+
"collection_name": collection_name,
|
1141 |
+
"filename": form_data.url,
|
1142 |
+
"file": {
|
1143 |
+
"data": {
|
1144 |
+
"content": content,
|
1145 |
+
},
|
1146 |
+
"meta": {
|
1147 |
+
"name": form_data.url,
|
1148 |
+
},
|
1149 |
+
},
|
1150 |
+
}
|
1151 |
+
except Exception as e:
|
1152 |
+
log.exception(e)
|
1153 |
+
raise HTTPException(
|
1154 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
1155 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
1156 |
+
)
|
1157 |
+
|
1158 |
+
|
1159 |
+
def search_web(engine: str, query: str) -> list[SearchResult]:
|
1160 |
+
"""Search the web using a search engine and return the results as a list of SearchResult objects.
|
1161 |
+
Will look for a search engine API key in environment variables in the following order:
|
1162 |
+
- SEARXNG_QUERY_URL
|
1163 |
+
- GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
|
1164 |
+
- BRAVE_SEARCH_API_KEY
|
1165 |
+
- MOJEEK_SEARCH_API_KEY
|
1166 |
+
- SERPSTACK_API_KEY
|
1167 |
+
- SERPER_API_KEY
|
1168 |
+
- SERPLY_API_KEY
|
1169 |
+
- TAVILY_API_KEY
|
1170 |
+
- SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
|
1171 |
+
Args:
|
1172 |
+
query (str): The query to search for
|
1173 |
+
"""
|
1174 |
+
|
1175 |
+
# TODO: add playwright to search the web
|
1176 |
+
if engine == "searxng":
|
1177 |
+
if app.state.config.SEARXNG_QUERY_URL:
|
1178 |
+
return search_searxng(
|
1179 |
+
app.state.config.SEARXNG_QUERY_URL,
|
1180 |
+
query,
|
1181 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1182 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1183 |
+
)
|
1184 |
+
else:
|
1185 |
+
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
|
1186 |
+
elif engine == "google_pse":
|
1187 |
+
if (
|
1188 |
+
app.state.config.GOOGLE_PSE_API_KEY
|
1189 |
+
and app.state.config.GOOGLE_PSE_ENGINE_ID
|
1190 |
+
):
|
1191 |
+
return search_google_pse(
|
1192 |
+
app.state.config.GOOGLE_PSE_API_KEY,
|
1193 |
+
app.state.config.GOOGLE_PSE_ENGINE_ID,
|
1194 |
+
query,
|
1195 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1196 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1197 |
+
)
|
1198 |
+
else:
|
1199 |
+
raise Exception(
|
1200 |
+
"No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
|
1201 |
+
)
|
1202 |
+
elif engine == "brave":
|
1203 |
+
if app.state.config.BRAVE_SEARCH_API_KEY:
|
1204 |
+
return search_brave(
|
1205 |
+
app.state.config.BRAVE_SEARCH_API_KEY,
|
1206 |
+
query,
|
1207 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1208 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1209 |
+
)
|
1210 |
+
else:
|
1211 |
+
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
|
1212 |
+
elif engine == "mojeek":
|
1213 |
+
if app.state.config.MOJEEK_SEARCH_API_KEY:
|
1214 |
+
return search_mojeek(
|
1215 |
+
app.state.config.MOJEEK_SEARCH_API_KEY,
|
1216 |
+
query,
|
1217 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1218 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1219 |
+
)
|
1220 |
+
else:
|
1221 |
+
raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
|
1222 |
+
elif engine == "serpstack":
|
1223 |
+
if app.state.config.SERPSTACK_API_KEY:
|
1224 |
+
return search_serpstack(
|
1225 |
+
app.state.config.SERPSTACK_API_KEY,
|
1226 |
+
query,
|
1227 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1228 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1229 |
+
https_enabled=app.state.config.SERPSTACK_HTTPS,
|
1230 |
+
)
|
1231 |
+
else:
|
1232 |
+
raise Exception("No SERPSTACK_API_KEY found in environment variables")
|
1233 |
+
elif engine == "serper":
|
1234 |
+
if app.state.config.SERPER_API_KEY:
|
1235 |
+
return search_serper(
|
1236 |
+
app.state.config.SERPER_API_KEY,
|
1237 |
+
query,
|
1238 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1239 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1240 |
+
)
|
1241 |
+
else:
|
1242 |
+
raise Exception("No SERPER_API_KEY found in environment variables")
|
1243 |
+
elif engine == "serply":
|
1244 |
+
if app.state.config.SERPLY_API_KEY:
|
1245 |
+
return search_serply(
|
1246 |
+
app.state.config.SERPLY_API_KEY,
|
1247 |
+
query,
|
1248 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1249 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1250 |
+
)
|
1251 |
+
else:
|
1252 |
+
raise Exception("No SERPLY_API_KEY found in environment variables")
|
1253 |
+
elif engine == "duckduckgo":
|
1254 |
+
return search_duckduckgo(
|
1255 |
+
query,
|
1256 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1257 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1258 |
+
)
|
1259 |
+
elif engine == "tavily":
|
1260 |
+
if app.state.config.TAVILY_API_KEY:
|
1261 |
+
return search_tavily(
|
1262 |
+
app.state.config.TAVILY_API_KEY,
|
1263 |
+
query,
|
1264 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1265 |
+
)
|
1266 |
+
else:
|
1267 |
+
raise Exception("No TAVILY_API_KEY found in environment variables")
|
1268 |
+
elif engine == "searchapi":
|
1269 |
+
if app.state.config.SEARCHAPI_API_KEY:
|
1270 |
+
return search_searchapi(
|
1271 |
+
app.state.config.SEARCHAPI_API_KEY,
|
1272 |
+
app.state.config.SEARCHAPI_ENGINE,
|
1273 |
+
query,
|
1274 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1275 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1276 |
+
)
|
1277 |
+
else:
|
1278 |
+
raise Exception("No SEARCHAPI_API_KEY found in environment variables")
|
1279 |
+
elif engine == "jina":
|
1280 |
+
return search_jina(
|
1281 |
+
app.state.config.JINA_API_KEY,
|
1282 |
+
query,
|
1283 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1284 |
+
)
|
1285 |
+
elif engine == "bing":
|
1286 |
+
return search_bing(
|
1287 |
+
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
1288 |
+
app.state.config.BING_SEARCH_V7_ENDPOINT,
|
1289 |
+
str(DEFAULT_LOCALE),
|
1290 |
+
query,
|
1291 |
+
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
1292 |
+
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
1293 |
+
)
|
1294 |
+
else:
|
1295 |
+
raise Exception("No search engine API key found in environment variables")
|
1296 |
+
|
1297 |
+
|
1298 |
+
@app.post("/process/web/search")
|
1299 |
+
def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
|
1300 |
+
try:
|
1301 |
+
logging.info(
|
1302 |
+
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
|
1303 |
+
)
|
1304 |
+
web_results = search_web(
|
1305 |
+
app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
|
1306 |
+
)
|
1307 |
+
except Exception as e:
|
1308 |
+
log.exception(e)
|
1309 |
+
|
1310 |
+
print(e)
|
1311 |
+
raise HTTPException(
|
1312 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
1313 |
+
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
|
1314 |
+
)
|
1315 |
+
|
1316 |
+
try:
|
1317 |
+
collection_name = form_data.collection_name
|
1318 |
+
if collection_name == "":
|
1319 |
+
collection_name = calculate_sha256_string(form_data.query)[:63]
|
1320 |
+
|
1321 |
+
urls = [result.link for result in web_results]
|
1322 |
+
|
1323 |
+
loader = get_web_loader(
|
1324 |
+
urls,
|
1325 |
+
verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
1326 |
+
requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
1327 |
+
)
|
1328 |
+
docs = loader.aload()
|
1329 |
+
|
1330 |
+
save_docs_to_vector_db(docs, collection_name, overwrite=True)
|
1331 |
+
|
1332 |
+
return {
|
1333 |
+
"status": True,
|
1334 |
+
"collection_name": collection_name,
|
1335 |
+
"filenames": urls,
|
1336 |
+
}
|
1337 |
+
except Exception as e:
|
1338 |
+
log.exception(e)
|
1339 |
+
raise HTTPException(
|
1340 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
1341 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
1342 |
+
)
|
1343 |
+
|
1344 |
+
|
1345 |
+
class QueryDocForm(BaseModel):
|
1346 |
+
collection_name: str
|
1347 |
+
query: str
|
1348 |
+
k: Optional[int] = None
|
1349 |
+
r: Optional[float] = None
|
1350 |
+
hybrid: Optional[bool] = None
|
1351 |
+
|
1352 |
+
|
1353 |
+
@app.post("/query/doc")
|
1354 |
+
def query_doc_handler(
|
1355 |
+
form_data: QueryDocForm,
|
1356 |
+
user=Depends(get_verified_user),
|
1357 |
+
):
|
1358 |
+
try:
|
1359 |
+
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
1360 |
+
return query_doc_with_hybrid_search(
|
1361 |
+
collection_name=form_data.collection_name,
|
1362 |
+
query=form_data.query,
|
1363 |
+
embedding_function=app.state.EMBEDDING_FUNCTION,
|
1364 |
+
k=form_data.k if form_data.k else app.state.config.TOP_K,
|
1365 |
+
reranking_function=app.state.sentence_transformer_rf,
|
1366 |
+
r=(
|
1367 |
+
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
|
1368 |
+
),
|
1369 |
+
)
|
1370 |
+
else:
|
1371 |
+
return query_doc(
|
1372 |
+
collection_name=form_data.collection_name,
|
1373 |
+
query=form_data.query,
|
1374 |
+
embedding_function=app.state.EMBEDDING_FUNCTION,
|
1375 |
+
k=form_data.k if form_data.k else app.state.config.TOP_K,
|
1376 |
+
)
|
1377 |
+
except Exception as e:
|
1378 |
+
log.exception(e)
|
1379 |
+
raise HTTPException(
|
1380 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
1381 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
1382 |
+
)
|
1383 |
+
|
1384 |
+
|
1385 |
+
class QueryCollectionsForm(BaseModel):
|
1386 |
+
collection_names: list[str]
|
1387 |
+
query: str
|
1388 |
+
k: Optional[int] = None
|
1389 |
+
r: Optional[float] = None
|
1390 |
+
hybrid: Optional[bool] = None
|
1391 |
+
|
1392 |
+
|
1393 |
+
@app.post("/query/collection")
|
1394 |
+
def query_collection_handler(
|
1395 |
+
form_data: QueryCollectionsForm,
|
1396 |
+
user=Depends(get_verified_user),
|
1397 |
+
):
|
1398 |
+
try:
|
1399 |
+
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
1400 |
+
return query_collection_with_hybrid_search(
|
1401 |
+
collection_names=form_data.collection_names,
|
1402 |
+
queries=[form_data.query],
|
1403 |
+
embedding_function=app.state.EMBEDDING_FUNCTION,
|
1404 |
+
k=form_data.k if form_data.k else app.state.config.TOP_K,
|
1405 |
+
reranking_function=app.state.sentence_transformer_rf,
|
1406 |
+
r=(
|
1407 |
+
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
|
1408 |
+
),
|
1409 |
+
)
|
1410 |
+
else:
|
1411 |
+
return query_collection(
|
1412 |
+
collection_names=form_data.collection_names,
|
1413 |
+
queries=[form_data.query],
|
1414 |
+
embedding_function=app.state.EMBEDDING_FUNCTION,
|
1415 |
+
k=form_data.k if form_data.k else app.state.config.TOP_K,
|
1416 |
+
)
|
1417 |
+
|
1418 |
+
except Exception as e:
|
1419 |
+
log.exception(e)
|
1420 |
+
raise HTTPException(
|
1421 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
1422 |
+
detail=ERROR_MESSAGES.DEFAULT(e),
|
1423 |
+
)
|
1424 |
+
|
1425 |
+
|
1426 |
+
####################################
|
1427 |
+
#
|
1428 |
+
# Vector DB operations
|
1429 |
+
#
|
1430 |
+
####################################
|
1431 |
+
|
1432 |
+
|
1433 |
+
class DeleteForm(BaseModel):
|
1434 |
+
collection_name: str
|
1435 |
+
file_id: str
|
1436 |
+
|
1437 |
+
|
1438 |
+
@app.post("/delete")
|
1439 |
+
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
|
1440 |
+
try:
|
1441 |
+
if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
|
1442 |
+
file = Files.get_file_by_id(form_data.file_id)
|
1443 |
+
hash = file.hash
|
1444 |
+
|
1445 |
+
VECTOR_DB_CLIENT.delete(
|
1446 |
+
collection_name=form_data.collection_name,
|
1447 |
+
metadata={"hash": hash},
|
1448 |
+
)
|
1449 |
+
return {"status": True}
|
1450 |
+
else:
|
1451 |
+
return {"status": False}
|
1452 |
+
except Exception as e:
|
1453 |
+
log.exception(e)
|
1454 |
+
return {"status": False}
|
1455 |
+
|
1456 |
+
|
1457 |
+
@app.post("/reset/db")
|
1458 |
+
def reset_vector_db(user=Depends(get_admin_user)):
|
1459 |
+
VECTOR_DB_CLIENT.reset()
|
1460 |
+
Knowledges.delete_all_knowledge()
|
1461 |
+
|
1462 |
+
|
1463 |
+
@app.post("/reset/uploads")
|
1464 |
+
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
|
1465 |
+
folder = f"{UPLOAD_DIR}"
|
1466 |
+
try:
|
1467 |
+
# Check if the directory exists
|
1468 |
+
if os.path.exists(folder):
|
1469 |
+
# Iterate over all the files and directories in the specified directory
|
1470 |
+
for filename in os.listdir(folder):
|
1471 |
+
file_path = os.path.join(folder, filename)
|
1472 |
+
try:
|
1473 |
+
if os.path.isfile(file_path) or os.path.islink(file_path):
|
1474 |
+
os.unlink(file_path) # Remove the file or link
|
1475 |
+
elif os.path.isdir(file_path):
|
1476 |
+
shutil.rmtree(file_path) # Remove the directory
|
1477 |
+
except Exception as e:
|
1478 |
+
print(f"Failed to delete {file_path}. Reason: {e}")
|
1479 |
+
else:
|
1480 |
+
print(f"The directory {folder} does not exist")
|
1481 |
+
except Exception as e:
|
1482 |
+
print(f"Failed to process the directory {folder}. Reason: {e}")
|
1483 |
+
return True
|
1484 |
+
|
1485 |
+
|
1486 |
+
if ENV == "dev":
|
1487 |
+
|
1488 |
+
@app.get("/ef")
|
1489 |
+
async def get_embeddings():
|
1490 |
+
return {"result": app.state.EMBEDDING_FUNCTION("hello world")}
|
1491 |
+
|
1492 |
+
@app.get("/ef/{text}")
|
1493 |
+
async def get_embeddings_text(text: str):
|
1494 |
+
return {"result": app.state.EMBEDDING_FUNCTION(text)}
|
backend/open_webui/apps/retrieval/models/colbert.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from colbert.infra import ColBERTConfig
|
5 |
+
from colbert.modeling.checkpoint import Checkpoint
|
6 |
+
|
7 |
+
|
8 |
+
class ColBERT:
|
9 |
+
def __init__(self, name, **kwargs) -> None:
|
10 |
+
print("ColBERT: Loading model", name)
|
11 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
12 |
+
|
13 |
+
DOCKER = kwargs.get("env") == "docker"
|
14 |
+
if DOCKER:
|
15 |
+
# This is a workaround for the issue with the docker container
|
16 |
+
# where the torch extension is not loaded properly
|
17 |
+
# and the following error is thrown:
|
18 |
+
# /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
|
19 |
+
|
20 |
+
lock_file = (
|
21 |
+
"/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
|
22 |
+
)
|
23 |
+
if os.path.exists(lock_file):
|
24 |
+
os.remove(lock_file)
|
25 |
+
|
26 |
+
self.ckpt = Checkpoint(
|
27 |
+
name,
|
28 |
+
colbert_config=ColBERTConfig(model_name=name),
|
29 |
+
).to(self.device)
|
30 |
+
pass
|
31 |
+
|
32 |
+
def calculate_similarity_scores(self, query_embeddings, document_embeddings):
|
33 |
+
|
34 |
+
query_embeddings = query_embeddings.to(self.device)
|
35 |
+
document_embeddings = document_embeddings.to(self.device)
|
36 |
+
|
37 |
+
# Validate dimensions to ensure compatibility
|
38 |
+
if query_embeddings.dim() != 3:
|
39 |
+
raise ValueError(
|
40 |
+
f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
|
41 |
+
)
|
42 |
+
if document_embeddings.dim() != 3:
|
43 |
+
raise ValueError(
|
44 |
+
f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
|
45 |
+
)
|
46 |
+
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
|
47 |
+
raise ValueError(
|
48 |
+
"There should be either one query or queries equal to the number of documents."
|
49 |
+
)
|
50 |
+
|
51 |
+
# Transpose the query embeddings to align for matrix multiplication
|
52 |
+
transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
|
53 |
+
# Compute similarity scores using batch matrix multiplication
|
54 |
+
computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)
|
55 |
+
# Apply max pooling to extract the highest semantic similarity across each document's sequence
|
56 |
+
maximum_scores = torch.max(computed_scores, dim=1).values
|
57 |
+
|
58 |
+
# Sum up the maximum scores across features to get the overall document relevance scores
|
59 |
+
final_scores = maximum_scores.sum(dim=1)
|
60 |
+
|
61 |
+
normalized_scores = torch.softmax(final_scores, dim=0)
|
62 |
+
|
63 |
+
return normalized_scores.detach().cpu().numpy().astype(np.float32)
|
64 |
+
|
65 |
+
def predict(self, sentences):
|
66 |
+
|
67 |
+
query = sentences[0][0]
|
68 |
+
docs = [i[1] for i in sentences]
|
69 |
+
|
70 |
+
# Embedding the documents
|
71 |
+
embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
|
72 |
+
# Embedding the queries
|
73 |
+
embedded_queries = self.ckpt.queryFromText([query], bsize=32)
|
74 |
+
embedded_query = embedded_queries[0]
|
75 |
+
|
76 |
+
# Calculate retrieval scores for the query against all documents
|
77 |
+
scores = self.calculate_similarity_scores(
|
78 |
+
embedded_query.unsqueeze(0), embedded_docs
|
79 |
+
)
|
80 |
+
|
81 |
+
return scores
|
backend/open_webui/apps/retrieval/utils.py
ADDED
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import uuid
|
4 |
+
from typing import Optional, Union
|
5 |
+
|
6 |
+
import asyncio
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from huggingface_hub import snapshot_download
|
10 |
+
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
11 |
+
from langchain_community.retrievers import BM25Retriever
|
12 |
+
from langchain_core.documents import Document
|
13 |
+
|
14 |
+
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
15 |
+
from open_webui.utils.misc import get_last_user_message
|
16 |
+
|
17 |
+
from open_webui.env import SRC_LOG_LEVELS
|
18 |
+
|
19 |
+
log = logging.getLogger(__name__)
|
20 |
+
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
21 |
+
|
22 |
+
|
23 |
+
from typing import Any
|
24 |
+
|
25 |
+
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
26 |
+
from langchain_core.retrievers import BaseRetriever
|
27 |
+
|
28 |
+
|
29 |
+
class VectorSearchRetriever(BaseRetriever):
|
30 |
+
collection_name: Any
|
31 |
+
embedding_function: Any
|
32 |
+
top_k: int
|
33 |
+
|
34 |
+
def _get_relevant_documents(
|
35 |
+
self,
|
36 |
+
query: str,
|
37 |
+
*,
|
38 |
+
run_manager: CallbackManagerForRetrieverRun,
|
39 |
+
) -> list[Document]:
|
40 |
+
result = VECTOR_DB_CLIENT.search(
|
41 |
+
collection_name=self.collection_name,
|
42 |
+
vectors=[self.embedding_function(query)],
|
43 |
+
limit=self.top_k,
|
44 |
+
)
|
45 |
+
|
46 |
+
ids = result.ids[0]
|
47 |
+
metadatas = result.metadatas[0]
|
48 |
+
documents = result.documents[0]
|
49 |
+
|
50 |
+
results = []
|
51 |
+
for idx in range(len(ids)):
|
52 |
+
results.append(
|
53 |
+
Document(
|
54 |
+
metadata=metadatas[idx],
|
55 |
+
page_content=documents[idx],
|
56 |
+
)
|
57 |
+
)
|
58 |
+
return results
|
59 |
+
|
60 |
+
|
61 |
+
def query_doc(
|
62 |
+
collection_name: str,
|
63 |
+
query_embedding: list[float],
|
64 |
+
k: int,
|
65 |
+
):
|
66 |
+
try:
|
67 |
+
result = VECTOR_DB_CLIENT.search(
|
68 |
+
collection_name=collection_name,
|
69 |
+
vectors=[query_embedding],
|
70 |
+
limit=k,
|
71 |
+
)
|
72 |
+
|
73 |
+
log.info(f"query_doc:result {result.ids} {result.metadatas}")
|
74 |
+
return result
|
75 |
+
except Exception as e:
|
76 |
+
print(e)
|
77 |
+
raise e
|
78 |
+
|
79 |
+
|
80 |
+
def query_doc_with_hybrid_search(
|
81 |
+
collection_name: str,
|
82 |
+
query: str,
|
83 |
+
embedding_function,
|
84 |
+
k: int,
|
85 |
+
reranking_function,
|
86 |
+
r: float,
|
87 |
+
) -> dict:
|
88 |
+
try:
|
89 |
+
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
90 |
+
|
91 |
+
bm25_retriever = BM25Retriever.from_texts(
|
92 |
+
texts=result.documents[0],
|
93 |
+
metadatas=result.metadatas[0],
|
94 |
+
)
|
95 |
+
bm25_retriever.k = k
|
96 |
+
|
97 |
+
vector_search_retriever = VectorSearchRetriever(
|
98 |
+
collection_name=collection_name,
|
99 |
+
embedding_function=embedding_function,
|
100 |
+
top_k=k,
|
101 |
+
)
|
102 |
+
|
103 |
+
ensemble_retriever = EnsembleRetriever(
|
104 |
+
retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
|
105 |
+
)
|
106 |
+
compressor = RerankCompressor(
|
107 |
+
embedding_function=embedding_function,
|
108 |
+
top_n=k,
|
109 |
+
reranking_function=reranking_function,
|
110 |
+
r_score=r,
|
111 |
+
)
|
112 |
+
|
113 |
+
compression_retriever = ContextualCompressionRetriever(
|
114 |
+
base_compressor=compressor, base_retriever=ensemble_retriever
|
115 |
+
)
|
116 |
+
|
117 |
+
result = compression_retriever.invoke(query)
|
118 |
+
result = {
|
119 |
+
"distances": [[d.metadata.get("score") for d in result]],
|
120 |
+
"documents": [[d.page_content for d in result]],
|
121 |
+
"metadatas": [[d.metadata for d in result]],
|
122 |
+
}
|
123 |
+
|
124 |
+
log.info(
|
125 |
+
"query_doc_with_hybrid_search:result "
|
126 |
+
+ f'{result["metadatas"]} {result["distances"]}'
|
127 |
+
)
|
128 |
+
return result
|
129 |
+
except Exception as e:
|
130 |
+
raise e
|
131 |
+
|
132 |
+
|
133 |
+
def merge_and_sort_query_results(
|
134 |
+
query_results: list[dict], k: int, reverse: bool = False
|
135 |
+
) -> list[dict]:
|
136 |
+
# Initialize lists to store combined data
|
137 |
+
combined_distances = []
|
138 |
+
combined_documents = []
|
139 |
+
combined_metadatas = []
|
140 |
+
|
141 |
+
for data in query_results:
|
142 |
+
combined_distances.extend(data["distances"][0])
|
143 |
+
combined_documents.extend(data["documents"][0])
|
144 |
+
combined_metadatas.extend(data["metadatas"][0])
|
145 |
+
|
146 |
+
# Create a list of tuples (distance, document, metadata)
|
147 |
+
combined = list(zip(combined_distances, combined_documents, combined_metadatas))
|
148 |
+
|
149 |
+
# Sort the list based on distances
|
150 |
+
combined.sort(key=lambda x: x[0], reverse=reverse)
|
151 |
+
|
152 |
+
# We don't have anything :-(
|
153 |
+
if not combined:
|
154 |
+
sorted_distances = []
|
155 |
+
sorted_documents = []
|
156 |
+
sorted_metadatas = []
|
157 |
+
else:
|
158 |
+
# Unzip the sorted list
|
159 |
+
sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
|
160 |
+
|
161 |
+
# Slicing the lists to include only k elements
|
162 |
+
sorted_distances = list(sorted_distances)[:k]
|
163 |
+
sorted_documents = list(sorted_documents)[:k]
|
164 |
+
sorted_metadatas = list(sorted_metadatas)[:k]
|
165 |
+
|
166 |
+
# Create the output dictionary
|
167 |
+
result = {
|
168 |
+
"distances": [sorted_distances],
|
169 |
+
"documents": [sorted_documents],
|
170 |
+
"metadatas": [sorted_metadatas],
|
171 |
+
}
|
172 |
+
|
173 |
+
return result
|
174 |
+
|
175 |
+
|
176 |
+
def query_collection(
|
177 |
+
collection_names: list[str],
|
178 |
+
queries: list[str],
|
179 |
+
embedding_function,
|
180 |
+
k: int,
|
181 |
+
) -> dict:
|
182 |
+
results = []
|
183 |
+
for query in queries:
|
184 |
+
query_embedding = embedding_function(query)
|
185 |
+
for collection_name in collection_names:
|
186 |
+
if collection_name:
|
187 |
+
try:
|
188 |
+
result = query_doc(
|
189 |
+
collection_name=collection_name,
|
190 |
+
k=k,
|
191 |
+
query_embedding=query_embedding,
|
192 |
+
)
|
193 |
+
if result is not None:
|
194 |
+
results.append(result.model_dump())
|
195 |
+
except Exception as e:
|
196 |
+
log.exception(f"Error when querying the collection: {e}")
|
197 |
+
else:
|
198 |
+
pass
|
199 |
+
|
200 |
+
return merge_and_sort_query_results(results, k=k)
|
201 |
+
|
202 |
+
|
203 |
+
def query_collection_with_hybrid_search(
|
204 |
+
collection_names: list[str],
|
205 |
+
queries: list[str],
|
206 |
+
embedding_function,
|
207 |
+
k: int,
|
208 |
+
reranking_function,
|
209 |
+
r: float,
|
210 |
+
) -> dict:
|
211 |
+
results = []
|
212 |
+
error = False
|
213 |
+
for collection_name in collection_names:
|
214 |
+
try:
|
215 |
+
for query in queries:
|
216 |
+
result = query_doc_with_hybrid_search(
|
217 |
+
collection_name=collection_name,
|
218 |
+
query=query,
|
219 |
+
embedding_function=embedding_function,
|
220 |
+
k=k,
|
221 |
+
reranking_function=reranking_function,
|
222 |
+
r=r,
|
223 |
+
)
|
224 |
+
results.append(result)
|
225 |
+
except Exception as e:
|
226 |
+
log.exception(
|
227 |
+
"Error when querying the collection with " f"hybrid_search: {e}"
|
228 |
+
)
|
229 |
+
error = True
|
230 |
+
|
231 |
+
if error:
|
232 |
+
raise Exception(
|
233 |
+
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
|
234 |
+
)
|
235 |
+
|
236 |
+
return merge_and_sort_query_results(results, k=k, reverse=True)
|
237 |
+
|
238 |
+
|
239 |
+
def get_embedding_function(
|
240 |
+
embedding_engine,
|
241 |
+
embedding_model,
|
242 |
+
embedding_function,
|
243 |
+
url,
|
244 |
+
key,
|
245 |
+
embedding_batch_size,
|
246 |
+
):
|
247 |
+
if embedding_engine == "":
|
248 |
+
return lambda query: embedding_function.encode(query).tolist()
|
249 |
+
elif embedding_engine in ["ollama", "openai"]:
|
250 |
+
func = lambda query: generate_embeddings(
|
251 |
+
engine=embedding_engine,
|
252 |
+
model=embedding_model,
|
253 |
+
text=query,
|
254 |
+
url=url,
|
255 |
+
key=key,
|
256 |
+
)
|
257 |
+
|
258 |
+
def generate_multiple(query, func):
|
259 |
+
if isinstance(query, list):
|
260 |
+
embeddings = []
|
261 |
+
for i in range(0, len(query), embedding_batch_size):
|
262 |
+
embeddings.extend(func(query[i : i + embedding_batch_size]))
|
263 |
+
return embeddings
|
264 |
+
else:
|
265 |
+
return func(query)
|
266 |
+
|
267 |
+
return lambda query: generate_multiple(query, func)
|
268 |
+
|
269 |
+
|
270 |
+
def get_sources_from_files(
|
271 |
+
files,
|
272 |
+
queries,
|
273 |
+
embedding_function,
|
274 |
+
k,
|
275 |
+
reranking_function,
|
276 |
+
r,
|
277 |
+
hybrid_search,
|
278 |
+
):
|
279 |
+
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
|
280 |
+
|
281 |
+
extracted_collections = []
|
282 |
+
relevant_contexts = []
|
283 |
+
|
284 |
+
for file in files:
|
285 |
+
if file.get("context") == "full":
|
286 |
+
context = {
|
287 |
+
"documents": [[file.get("file").get("data", {}).get("content")]],
|
288 |
+
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
|
289 |
+
}
|
290 |
+
else:
|
291 |
+
context = None
|
292 |
+
|
293 |
+
collection_names = []
|
294 |
+
if file.get("type") == "collection":
|
295 |
+
if file.get("legacy"):
|
296 |
+
collection_names = file.get("collection_names", [])
|
297 |
+
else:
|
298 |
+
collection_names.append(file["id"])
|
299 |
+
elif file.get("collection_name"):
|
300 |
+
collection_names.append(file["collection_name"])
|
301 |
+
elif file.get("id"):
|
302 |
+
if file.get("legacy"):
|
303 |
+
collection_names.append(f"{file['id']}")
|
304 |
+
else:
|
305 |
+
collection_names.append(f"file-{file['id']}")
|
306 |
+
|
307 |
+
collection_names = set(collection_names).difference(extracted_collections)
|
308 |
+
if not collection_names:
|
309 |
+
log.debug(f"skipping {file} as it has already been extracted")
|
310 |
+
continue
|
311 |
+
|
312 |
+
try:
|
313 |
+
context = None
|
314 |
+
if file.get("type") == "text":
|
315 |
+
context = file["content"]
|
316 |
+
else:
|
317 |
+
if hybrid_search:
|
318 |
+
try:
|
319 |
+
context = query_collection_with_hybrid_search(
|
320 |
+
collection_names=collection_names,
|
321 |
+
queries=queries,
|
322 |
+
embedding_function=embedding_function,
|
323 |
+
k=k,
|
324 |
+
reranking_function=reranking_function,
|
325 |
+
r=r,
|
326 |
+
)
|
327 |
+
except Exception as e:
|
328 |
+
log.debug(
|
329 |
+
"Error when using hybrid search, using"
|
330 |
+
" non hybrid search as fallback."
|
331 |
+
)
|
332 |
+
|
333 |
+
if (not hybrid_search) or (context is None):
|
334 |
+
context = query_collection(
|
335 |
+
collection_names=collection_names,
|
336 |
+
queries=queries,
|
337 |
+
embedding_function=embedding_function,
|
338 |
+
k=k,
|
339 |
+
)
|
340 |
+
except Exception as e:
|
341 |
+
log.exception(e)
|
342 |
+
|
343 |
+
extracted_collections.extend(collection_names)
|
344 |
+
|
345 |
+
if context:
|
346 |
+
if "data" in file:
|
347 |
+
del file["data"]
|
348 |
+
relevant_contexts.append({**context, "file": file})
|
349 |
+
|
350 |
+
sources = []
|
351 |
+
for context in relevant_contexts:
|
352 |
+
try:
|
353 |
+
if "documents" in context:
|
354 |
+
if "metadatas" in context:
|
355 |
+
source = {
|
356 |
+
"source": context["file"],
|
357 |
+
"document": context["documents"][0],
|
358 |
+
"metadata": context["metadatas"][0],
|
359 |
+
}
|
360 |
+
if "distances" in context and context["distances"]:
|
361 |
+
source["distances"] = context["distances"][0]
|
362 |
+
|
363 |
+
sources.append(source)
|
364 |
+
except Exception as e:
|
365 |
+
log.exception(e)
|
366 |
+
|
367 |
+
return sources
|
368 |
+
|
369 |
+
|
370 |
+
def get_model_path(model: str, update_model: bool = False):
|
371 |
+
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
|
372 |
+
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
|
373 |
+
|
374 |
+
local_files_only = not update_model
|
375 |
+
|
376 |
+
snapshot_kwargs = {
|
377 |
+
"cache_dir": cache_dir,
|
378 |
+
"local_files_only": local_files_only,
|
379 |
+
}
|
380 |
+
|
381 |
+
log.debug(f"model: {model}")
|
382 |
+
log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
|
383 |
+
|
384 |
+
# Inspiration from upstream sentence_transformers
|
385 |
+
if (
|
386 |
+
os.path.exists(model)
|
387 |
+
or ("\\" in model or model.count("/") > 1)
|
388 |
+
and local_files_only
|
389 |
+
):
|
390 |
+
# If fully qualified path exists, return input, else set repo_id
|
391 |
+
return model
|
392 |
+
elif "/" not in model:
|
393 |
+
# Set valid repo_id for model short-name
|
394 |
+
model = "sentence-transformers" + "/" + model
|
395 |
+
|
396 |
+
snapshot_kwargs["repo_id"] = model
|
397 |
+
|
398 |
+
# Attempt to query the huggingface_hub library to determine the local path and/or to update
|
399 |
+
try:
|
400 |
+
model_repo_path = snapshot_download(**snapshot_kwargs)
|
401 |
+
log.debug(f"model_repo_path: {model_repo_path}")
|
402 |
+
return model_repo_path
|
403 |
+
except Exception as e:
|
404 |
+
log.exception(f"Cannot determine model snapshot path: {e}")
|
405 |
+
return model
|
406 |
+
|
407 |
+
|
408 |
+
def generate_openai_batch_embeddings(
|
409 |
+
model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
|
410 |
+
) -> Optional[list[list[float]]]:
|
411 |
+
try:
|
412 |
+
r = requests.post(
|
413 |
+
f"{url}/embeddings",
|
414 |
+
headers={
|
415 |
+
"Content-Type": "application/json",
|
416 |
+
"Authorization": f"Bearer {key}",
|
417 |
+
},
|
418 |
+
json={"input": texts, "model": model},
|
419 |
+
)
|
420 |
+
r.raise_for_status()
|
421 |
+
data = r.json()
|
422 |
+
if "data" in data:
|
423 |
+
return [elem["embedding"] for elem in data["data"]]
|
424 |
+
else:
|
425 |
+
raise "Something went wrong :/"
|
426 |
+
except Exception as e:
|
427 |
+
print(e)
|
428 |
+
return None
|
429 |
+
|
430 |
+
|
431 |
+
def generate_ollama_batch_embeddings(
|
432 |
+
model: str, texts: list[str], url: str, key: str = ""
|
433 |
+
) -> Optional[list[list[float]]]:
|
434 |
+
try:
|
435 |
+
r = requests.post(
|
436 |
+
f"{url}/api/embed",
|
437 |
+
headers={
|
438 |
+
"Content-Type": "application/json",
|
439 |
+
"Authorization": f"Bearer {key}",
|
440 |
+
},
|
441 |
+
json={"input": texts, "model": model},
|
442 |
+
)
|
443 |
+
r.raise_for_status()
|
444 |
+
data = r.json()
|
445 |
+
|
446 |
+
if "embeddings" in data:
|
447 |
+
return data["embeddings"]
|
448 |
+
else:
|
449 |
+
raise "Something went wrong :/"
|
450 |
+
except Exception as e:
|
451 |
+
print(e)
|
452 |
+
return None
|
453 |
+
|
454 |
+
|
455 |
+
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
456 |
+
url = kwargs.get("url", "")
|
457 |
+
key = kwargs.get("key", "")
|
458 |
+
|
459 |
+
if engine == "ollama":
|
460 |
+
if isinstance(text, list):
|
461 |
+
embeddings = generate_ollama_batch_embeddings(
|
462 |
+
**{"model": model, "texts": text, "url": url, "key": key}
|
463 |
+
)
|
464 |
+
else:
|
465 |
+
embeddings = generate_ollama_batch_embeddings(
|
466 |
+
**{"model": model, "texts": [text], "url": url, "key": key}
|
467 |
+
)
|
468 |
+
return embeddings[0] if isinstance(text, str) else embeddings
|
469 |
+
elif engine == "openai":
|
470 |
+
if isinstance(text, list):
|
471 |
+
embeddings = generate_openai_batch_embeddings(model, text, url, key)
|
472 |
+
else:
|
473 |
+
embeddings = generate_openai_batch_embeddings(model, [text], url, key)
|
474 |
+
|
475 |
+
return embeddings[0] if isinstance(text, str) else embeddings
|
476 |
+
|
477 |
+
|
478 |
+
import operator
|
479 |
+
from typing import Optional, Sequence
|
480 |
+
|
481 |
+
from langchain_core.callbacks import Callbacks
|
482 |
+
from langchain_core.documents import BaseDocumentCompressor, Document
|
483 |
+
|
484 |
+
|
485 |
+
class RerankCompressor(BaseDocumentCompressor):
|
486 |
+
embedding_function: Any
|
487 |
+
top_n: int
|
488 |
+
reranking_function: Any
|
489 |
+
r_score: float
|
490 |
+
|
491 |
+
class Config:
|
492 |
+
extra = "forbid"
|
493 |
+
arbitrary_types_allowed = True
|
494 |
+
|
495 |
+
def compress_documents(
|
496 |
+
self,
|
497 |
+
documents: Sequence[Document],
|
498 |
+
query: str,
|
499 |
+
callbacks: Optional[Callbacks] = None,
|
500 |
+
) -> Sequence[Document]:
|
501 |
+
reranking = self.reranking_function is not None
|
502 |
+
|
503 |
+
if reranking:
|
504 |
+
scores = self.reranking_function.predict(
|
505 |
+
[(query, doc.page_content) for doc in documents]
|
506 |
+
)
|
507 |
+
else:
|
508 |
+
from sentence_transformers import util
|
509 |
+
|
510 |
+
query_embedding = self.embedding_function(query)
|
511 |
+
document_embedding = self.embedding_function(
|
512 |
+
[doc.page_content for doc in documents]
|
513 |
+
)
|
514 |
+
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
515 |
+
|
516 |
+
docs_with_scores = list(zip(documents, scores.tolist()))
|
517 |
+
if self.r_score:
|
518 |
+
docs_with_scores = [
|
519 |
+
(d, s) for d, s in docs_with_scores if s >= self.r_score
|
520 |
+
]
|
521 |
+
|
522 |
+
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
523 |
+
final_results = []
|
524 |
+
for doc, doc_score in result[: self.top_n]:
|
525 |
+
metadata = doc.metadata
|
526 |
+
metadata["score"] = doc_score
|
527 |
+
doc = Document(
|
528 |
+
page_content=doc.page_content,
|
529 |
+
metadata=metadata,
|
530 |
+
)
|
531 |
+
final_results.append(doc)
|
532 |
+
return final_results
|
backend/open_webui/apps/retrieval/vector/connector.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from open_webui.config import VECTOR_DB
|
2 |
+
|
3 |
+
if VECTOR_DB == "milvus":
|
4 |
+
from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient
|
5 |
+
|
6 |
+
VECTOR_DB_CLIENT = MilvusClient()
|
7 |
+
elif VECTOR_DB == "qdrant":
|
8 |
+
from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient
|
9 |
+
|
10 |
+
VECTOR_DB_CLIENT = QdrantClient()
|
11 |
+
elif VECTOR_DB == "opensearch":
|
12 |
+
from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient
|
13 |
+
|
14 |
+
VECTOR_DB_CLIENT = OpenSearchClient()
|
15 |
+
elif VECTOR_DB == "pgvector":
|
16 |
+
from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient
|
17 |
+
|
18 |
+
VECTOR_DB_CLIENT = PgvectorClient()
|
19 |
+
else:
|
20 |
+
from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient
|
21 |
+
|
22 |
+
VECTOR_DB_CLIENT = ChromaClient()
|
backend/open_webui/apps/retrieval/vector/dbs/chroma.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chromadb
|
2 |
+
from chromadb import Settings
|
3 |
+
from chromadb.utils.batch_utils import create_batches
|
4 |
+
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
8 |
+
from open_webui.config import (
|
9 |
+
CHROMA_DATA_PATH,
|
10 |
+
CHROMA_HTTP_HOST,
|
11 |
+
CHROMA_HTTP_PORT,
|
12 |
+
CHROMA_HTTP_HEADERS,
|
13 |
+
CHROMA_HTTP_SSL,
|
14 |
+
CHROMA_TENANT,
|
15 |
+
CHROMA_DATABASE,
|
16 |
+
CHROMA_CLIENT_AUTH_PROVIDER,
|
17 |
+
CHROMA_CLIENT_AUTH_CREDENTIALS,
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
class ChromaClient:
|
22 |
+
def __init__(self):
|
23 |
+
settings_dict = {
|
24 |
+
"allow_reset": True,
|
25 |
+
"anonymized_telemetry": False,
|
26 |
+
}
|
27 |
+
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
|
28 |
+
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
|
29 |
+
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
|
30 |
+
settings_dict["chroma_client_auth_credentials"] = (
|
31 |
+
CHROMA_CLIENT_AUTH_CREDENTIALS
|
32 |
+
)
|
33 |
+
|
34 |
+
if CHROMA_HTTP_HOST != "":
|
35 |
+
self.client = chromadb.HttpClient(
|
36 |
+
host=CHROMA_HTTP_HOST,
|
37 |
+
port=CHROMA_HTTP_PORT,
|
38 |
+
headers=CHROMA_HTTP_HEADERS,
|
39 |
+
ssl=CHROMA_HTTP_SSL,
|
40 |
+
tenant=CHROMA_TENANT,
|
41 |
+
database=CHROMA_DATABASE,
|
42 |
+
settings=Settings(**settings_dict),
|
43 |
+
)
|
44 |
+
else:
|
45 |
+
self.client = chromadb.PersistentClient(
|
46 |
+
path=CHROMA_DATA_PATH,
|
47 |
+
settings=Settings(**settings_dict),
|
48 |
+
tenant=CHROMA_TENANT,
|
49 |
+
database=CHROMA_DATABASE,
|
50 |
+
)
|
51 |
+
|
52 |
+
def has_collection(self, collection_name: str) -> bool:
|
53 |
+
# Check if the collection exists based on the collection name.
|
54 |
+
collections = self.client.list_collections()
|
55 |
+
return collection_name in [collection.name for collection in collections]
|
56 |
+
|
57 |
+
def delete_collection(self, collection_name: str):
|
58 |
+
# Delete the collection based on the collection name.
|
59 |
+
return self.client.delete_collection(name=collection_name)
|
60 |
+
|
61 |
+
def search(
|
62 |
+
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
63 |
+
) -> Optional[SearchResult]:
|
64 |
+
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
65 |
+
try:
|
66 |
+
collection = self.client.get_collection(name=collection_name)
|
67 |
+
if collection:
|
68 |
+
result = collection.query(
|
69 |
+
query_embeddings=vectors,
|
70 |
+
n_results=limit,
|
71 |
+
)
|
72 |
+
|
73 |
+
return SearchResult(
|
74 |
+
**{
|
75 |
+
"ids": result["ids"],
|
76 |
+
"distances": result["distances"],
|
77 |
+
"documents": result["documents"],
|
78 |
+
"metadatas": result["metadatas"],
|
79 |
+
}
|
80 |
+
)
|
81 |
+
return None
|
82 |
+
except Exception as e:
|
83 |
+
return None
|
84 |
+
|
85 |
+
def query(
|
86 |
+
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
87 |
+
) -> Optional[GetResult]:
|
88 |
+
# Query the items from the collection based on the filter.
|
89 |
+
try:
|
90 |
+
collection = self.client.get_collection(name=collection_name)
|
91 |
+
if collection:
|
92 |
+
result = collection.get(
|
93 |
+
where=filter,
|
94 |
+
limit=limit,
|
95 |
+
)
|
96 |
+
|
97 |
+
return GetResult(
|
98 |
+
**{
|
99 |
+
"ids": [result["ids"]],
|
100 |
+
"documents": [result["documents"]],
|
101 |
+
"metadatas": [result["metadatas"]],
|
102 |
+
}
|
103 |
+
)
|
104 |
+
return None
|
105 |
+
except Exception as e:
|
106 |
+
print(e)
|
107 |
+
return None
|
108 |
+
|
109 |
+
def get(self, collection_name: str) -> Optional[GetResult]:
|
110 |
+
# Get all the items in the collection.
|
111 |
+
collection = self.client.get_collection(name=collection_name)
|
112 |
+
if collection:
|
113 |
+
result = collection.get()
|
114 |
+
return GetResult(
|
115 |
+
**{
|
116 |
+
"ids": [result["ids"]],
|
117 |
+
"documents": [result["documents"]],
|
118 |
+
"metadatas": [result["metadatas"]],
|
119 |
+
}
|
120 |
+
)
|
121 |
+
return None
|
122 |
+
|
123 |
+
def insert(self, collection_name: str, items: list[VectorItem]):
|
124 |
+
# Insert the items into the collection, if the collection does not exist, it will be created.
|
125 |
+
collection = self.client.get_or_create_collection(
|
126 |
+
name=collection_name, metadata={"hnsw:space": "cosine"}
|
127 |
+
)
|
128 |
+
|
129 |
+
ids = [item["id"] for item in items]
|
130 |
+
documents = [item["text"] for item in items]
|
131 |
+
embeddings = [item["vector"] for item in items]
|
132 |
+
metadatas = [item["metadata"] for item in items]
|
133 |
+
|
134 |
+
for batch in create_batches(
|
135 |
+
api=self.client,
|
136 |
+
documents=documents,
|
137 |
+
embeddings=embeddings,
|
138 |
+
ids=ids,
|
139 |
+
metadatas=metadatas,
|
140 |
+
):
|
141 |
+
collection.add(*batch)
|
142 |
+
|
143 |
+
def upsert(self, collection_name: str, items: list[VectorItem]):
|
144 |
+
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
145 |
+
collection = self.client.get_or_create_collection(
|
146 |
+
name=collection_name, metadata={"hnsw:space": "cosine"}
|
147 |
+
)
|
148 |
+
|
149 |
+
ids = [item["id"] for item in items]
|
150 |
+
documents = [item["text"] for item in items]
|
151 |
+
embeddings = [item["vector"] for item in items]
|
152 |
+
metadatas = [item["metadata"] for item in items]
|
153 |
+
|
154 |
+
collection.upsert(
|
155 |
+
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
|
156 |
+
)
|
157 |
+
|
158 |
+
def delete(
|
159 |
+
self,
|
160 |
+
collection_name: str,
|
161 |
+
ids: Optional[list[str]] = None,
|
162 |
+
filter: Optional[dict] = None,
|
163 |
+
):
|
164 |
+
# Delete the items from the collection based on the ids.
|
165 |
+
collection = self.client.get_collection(name=collection_name)
|
166 |
+
if collection:
|
167 |
+
if ids:
|
168 |
+
collection.delete(ids=ids)
|
169 |
+
elif filter:
|
170 |
+
collection.delete(where=filter)
|
171 |
+
|
172 |
+
def reset(self):
|
173 |
+
# Resets the database. This will delete all collections and item entries.
|
174 |
+
return self.client.reset()
|