Spaces:
Paused
Paused
Start Application
Browse files- .gitignore +160 -0
- Dockerfile +11 -0
- app/__init__.py +0 -0
- app/api/__init__.py +0 -0
- app/api/generate.py +159 -0
- app/api/prompt.py +34 -0
- app/api/user.py +55 -0
- app/config.py +10 -0
- app/core/__init__.py +0 -0
- app/core/crud.py +58 -0
- app/core/schemas.py +77 -0
- app/db.py +27 -0
- app/main.py +101 -0
- app/security.py +61 -0
- app/sql/__init__.py +0 -0
- app/sql/database.py +11 -0
- app/sql/models.py +29 -0
- requirements.txt +7 -0
- sql_app.db +0 -0
.gitignore
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
Dockerfile
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.11
|
2 |
+
|
3 |
+
WORKDIR /code
|
4 |
+
|
5 |
+
COPY ./requirements.txt /code/requirements.txt
|
6 |
+
|
7 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
8 |
+
|
9 |
+
COPY . .
|
10 |
+
|
11 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
app/__init__.py
ADDED
File without changes
|
app/api/__init__.py
ADDED
File without changes
|
app/api/generate.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from app.db import get_db
|
2 |
+
from app.config import DEVICE
|
3 |
+
from app.core import schemas, crud
|
4 |
+
from app.security import get_current_user
|
5 |
+
from app.core.schemas import TextImage, ImageImage, BackgroundRemoval, ImageVariations
|
6 |
+
|
7 |
+
import base64
|
8 |
+
from io import BytesIO
|
9 |
+
from sqlalchemy.orm import Session
|
10 |
+
from typing import Annotated, List
|
11 |
+
from fastapi import APIRouter, Depends, HTTPException, Request
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import numpy as np
|
15 |
+
from PIL import Image
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from torchvision.transforms.functional import normalize
|
18 |
+
|
19 |
+
|
20 |
+
router = APIRouter()
|
21 |
+
|
22 |
+
|
23 |
+
def decode_image(image):
|
24 |
+
return Image.open(BytesIO(base64.b64decode(image))).convert("RGB")
|
25 |
+
|
26 |
+
|
27 |
+
def encode_image(image):
|
28 |
+
bytes = BytesIO()
|
29 |
+
image.save(bytes, format="PNG")
|
30 |
+
return base64.b64encode(bytes.getvalue())
|
31 |
+
|
32 |
+
|
33 |
+
def create_prompt(subject, medium, style, artist, website, resolution, additional_details, color, lightning):
|
34 |
+
if not subject:
|
35 |
+
return None
|
36 |
+
if medium:
|
37 |
+
subject = f"{medium} of {subject}"
|
38 |
+
if style:
|
39 |
+
subject = f"{subject}, {style}"
|
40 |
+
if artist:
|
41 |
+
subject = f"{subject}, by {artist}"
|
42 |
+
if website:
|
43 |
+
subject = f"{subject}, {website}"
|
44 |
+
if resolution:
|
45 |
+
subject = f"{subject}, {resolution}"
|
46 |
+
if additional_details:
|
47 |
+
subject = f"{subject}, {additional_details}"
|
48 |
+
if color:
|
49 |
+
subject = f"{subject}, {color}"
|
50 |
+
if lightning:
|
51 |
+
subject = f"{subject}, {lightning}"
|
52 |
+
return subject
|
53 |
+
|
54 |
+
|
55 |
+
@router.post("/text-image/", response_model=str)
|
56 |
+
def text_image(model: Request, request: TextImage, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
57 |
+
if not current_user.is_active:
|
58 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
59 |
+
|
60 |
+
generator = torch.manual_seed(request.seed)
|
61 |
+
prompt = create_prompt(request.prompt, medium=request.medium, style=request.style,
|
62 |
+
additional_details=request.additional_details, lightning=request.lightning)
|
63 |
+
|
64 |
+
crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
|
65 |
+
|
66 |
+
image = model.state.ti_pipe(prompt, num_inference_steps=request.num_inference_steps,
|
67 |
+
guidance_scale=request.guidance_scale, generator=generator, negative_prompt=request.negative_prompt).images[0]
|
68 |
+
|
69 |
+
return encode_image(image)
|
70 |
+
|
71 |
+
|
72 |
+
@router.post("/image-image/", response_model=str)
|
73 |
+
def image_image(model: Request, request: ImageImage, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
74 |
+
if not current_user.is_active:
|
75 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
76 |
+
|
77 |
+
generator = torch.manual_seed(request.seed)
|
78 |
+
prompt = create_prompt(request.prompt, medium=request.medium, style=request.style,
|
79 |
+
additional_details=request.additional_details, lightning=request.lightning)
|
80 |
+
image = decode_image(request.image)
|
81 |
+
|
82 |
+
crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
|
83 |
+
|
84 |
+
image = model.state.ii_pipe(prompt, image=image, num_inference_steps=request.num_inference_steps, guidance_scale=request.guidance_scale,
|
85 |
+
image_guidance_scale=request.image_guidance_scale, generator=generator, negative_prompt=request.negative_prompt).images[0]
|
86 |
+
|
87 |
+
return encode_image(image)
|
88 |
+
|
89 |
+
|
90 |
+
@router.post("/background-removal/", response_model=str)
|
91 |
+
def background_removal(model: Request, request: BackgroundRemoval, current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
92 |
+
if not current_user.is_active:
|
93 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
94 |
+
|
95 |
+
image = decode_image(request.image)
|
96 |
+
|
97 |
+
def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
|
98 |
+
if len(im.shape) < 3:
|
99 |
+
im = im[:, :, np.newaxis]
|
100 |
+
# orig_im_size=im.shape[0:2]
|
101 |
+
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
|
102 |
+
im_tensor = F.interpolate(torch.unsqueeze(
|
103 |
+
im_tensor, 0), size=model_input_size, mode='bilinear')
|
104 |
+
image = torch.divide(im_tensor, 255.0)
|
105 |
+
image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
|
106 |
+
return image
|
107 |
+
|
108 |
+
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
|
109 |
+
result = torch.squeeze(F.interpolate(
|
110 |
+
result, size=im_size, mode='bilinear'), 0)
|
111 |
+
ma = torch.max(result)
|
112 |
+
mi = torch.min(result)
|
113 |
+
result = (result-mi)/(ma-mi)
|
114 |
+
im_array = (result*255).permute(1, 2,
|
115 |
+
0).cpu().data.numpy().astype(np.uint8)
|
116 |
+
im_array = np.squeeze(im_array)
|
117 |
+
return im_array
|
118 |
+
|
119 |
+
# prepare input
|
120 |
+
model_input_size = [1024, 1024]
|
121 |
+
orig_im = np.array(image)
|
122 |
+
orig_im_size = orig_im.shape[0:2]
|
123 |
+
image = preprocess_image(orig_im, model_input_size).to(DEVICE)
|
124 |
+
|
125 |
+
# inference
|
126 |
+
result = model.state.br_model(image)
|
127 |
+
|
128 |
+
# post process
|
129 |
+
result_image = postprocess_image(result[0][0], orig_im_size)
|
130 |
+
|
131 |
+
# save result
|
132 |
+
pil_im = Image.fromarray(result_image)
|
133 |
+
no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
|
134 |
+
orig_image = Image.fromarray(orig_im)
|
135 |
+
no_bg_image.paste(orig_image, mask=pil_im)
|
136 |
+
|
137 |
+
return encode_image(no_bg_image)
|
138 |
+
|
139 |
+
|
140 |
+
@router.post("/image-variations/", response_model=List[str])
|
141 |
+
def image_variations(model: Request, request: ImageVariations, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
142 |
+
if not current_user.is_active:
|
143 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
144 |
+
|
145 |
+
# prompt = create_prompt(request.prompt, medium=request.medium, style=request.style,
|
146 |
+
# additional_details=request.additional_details, lightning=request.lightning)
|
147 |
+
# image = decode_image(request.image)
|
148 |
+
# image.resize((512, 512))
|
149 |
+
|
150 |
+
# if prompt:
|
151 |
+
# crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
|
152 |
+
|
153 |
+
# images = model.state.iv_model.generate(pil_image=image, num_samples=request.num_samples, num_inference_steps=request.num_inference_steps,
|
154 |
+
# seed=request.seed, prompt=prompt, scale=request.scale, negative_prompt=request.negative_prompt)
|
155 |
+
|
156 |
+
# images = [encode_image(image) for image in images]
|
157 |
+
|
158 |
+
# return images
|
159 |
+
return ["Image Variations is not supported yet."]
|
app/api/prompt.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated, List
|
2 |
+
from sqlalchemy.orm import Session
|
3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
4 |
+
|
5 |
+
from app.db import get_db
|
6 |
+
from app.core import schemas, crud
|
7 |
+
from app.security import get_current_user
|
8 |
+
|
9 |
+
|
10 |
+
router = APIRouter()
|
11 |
+
|
12 |
+
|
13 |
+
@router.post("/get-all-prompts/", response_model=List[schemas.Prompt])
|
14 |
+
def get_all_prompts(db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
15 |
+
if not current_user.is_superuser:
|
16 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
17 |
+
|
18 |
+
return crud.get_all_prompts(db=db)
|
19 |
+
|
20 |
+
|
21 |
+
@ router.post("/get-prompt_by_user_id/{user_id}/", response_model=List[schemas.Prompt])
|
22 |
+
def get_prompt_by_user_id(user_id: int, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
23 |
+
if not current_user.is_superuser:
|
24 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
25 |
+
|
26 |
+
return crud.get_prompt_by_user_id(user_id=user_id, db=db)
|
27 |
+
|
28 |
+
|
29 |
+
# @ router.post("/create-prompt/", response_model=schemas.Prompt)
|
30 |
+
# def create_prompt(prompt: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
31 |
+
# if not current_user.is_superuser:
|
32 |
+
# raise HTTPException(status_code=403, detail="Forbidden")
|
33 |
+
|
34 |
+
# return crud.create_prompt(prompt=prompt, db=db)
|
app/api/user.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated, List
|
2 |
+
from sqlalchemy.orm import Session
|
3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
4 |
+
|
5 |
+
from app.db import get_db
|
6 |
+
from app.core import schemas, crud
|
7 |
+
from app.security import get_current_user
|
8 |
+
|
9 |
+
|
10 |
+
router = APIRouter()
|
11 |
+
|
12 |
+
|
13 |
+
@router.post("/create-user/", response_model=schemas.User)
|
14 |
+
def create_user(user: schemas.UserCreate, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
15 |
+
if not current_user.is_superuser:
|
16 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
17 |
+
|
18 |
+
user_exists = crud.get_user_by_username(username=user.username, db=db)
|
19 |
+
if user_exists:
|
20 |
+
raise HTTPException(
|
21 |
+
status_code=400, detail="Username already registered")
|
22 |
+
|
23 |
+
return crud.create_user(user=user, db=db)
|
24 |
+
|
25 |
+
|
26 |
+
@router.post("/update-user/", response_model=schemas.User)
|
27 |
+
def update_user(user: schemas.UserUpdate, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
28 |
+
if not current_user.is_superuser:
|
29 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
30 |
+
|
31 |
+
user_exists = crud.get_user_by_user_id(user_id=user.user_id, db=db)
|
32 |
+
if not user_exists:
|
33 |
+
raise HTTPException(status_code=404, detail="User not found")
|
34 |
+
|
35 |
+
return crud.update_user(user=user, db=db)
|
36 |
+
|
37 |
+
|
38 |
+
@router.post("/get-all-users/", response_model=List[schemas.User])
|
39 |
+
def get_all_users(db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
40 |
+
if not current_user.is_superuser:
|
41 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
42 |
+
|
43 |
+
return crud.get_all_users(db=db)
|
44 |
+
|
45 |
+
|
46 |
+
@router.post("/get-user_by_user_id/{user_id}/", response_model=schemas.User)
|
47 |
+
def get_user_by_user_id(user_id: int, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
|
48 |
+
if not current_user.is_superuser:
|
49 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
50 |
+
|
51 |
+
user = crud.get_user_by_user_id(user_id=user_id, db=db)
|
52 |
+
if user is None:
|
53 |
+
raise HTTPException(status_code=404, detail="User not found")
|
54 |
+
|
55 |
+
return user
|
app/config.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DATABASE_URL = 'sqlite:///./sql_app.db'
|
2 |
+
|
3 |
+
SECRET_KEY = "20c7fc21a12dc8c19b06e603af10ebc11887a614f188f3d0878bb42bd784c315"
|
4 |
+
ALGORITHM = "HS256"
|
5 |
+
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
6 |
+
|
7 |
+
SUPERUSER_USERNAME = "admin"
|
8 |
+
SUPERUSER_PASSWORD = "admin"
|
9 |
+
|
10 |
+
DEVICE = "cuda"
|
app/core/__init__.py
ADDED
File without changes
|
app/core/crud.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import app.sql.models as models
|
2 |
+
import app.core.schemas as schemas
|
3 |
+
|
4 |
+
from sqlalchemy.orm import Session
|
5 |
+
from passlib.context import CryptContext
|
6 |
+
|
7 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
8 |
+
|
9 |
+
|
10 |
+
def hash_password(password: str):
|
11 |
+
return pwd_context.hash(password)
|
12 |
+
|
13 |
+
|
14 |
+
def create_prompt(prompt: str, user_id: int, db: Session):
|
15 |
+
db_prompt = models.Prompt(prompt=prompt, user_id=user_id)
|
16 |
+
db.add(db_prompt)
|
17 |
+
db.commit()
|
18 |
+
db.refresh(db_prompt)
|
19 |
+
return db_prompt
|
20 |
+
|
21 |
+
|
22 |
+
def get_all_prompts(db: Session):
|
23 |
+
return db.query(models.Prompt).all()
|
24 |
+
|
25 |
+
|
26 |
+
def get_prompt_by_user_id(user_id: int, db: Session):
|
27 |
+
return db.query(models.Prompt).filter(models.Prompt.user_id == user_id).all()
|
28 |
+
|
29 |
+
|
30 |
+
def create_user(user: schemas.UserCreate, db: Session):
|
31 |
+
hashed_password = hash_password(user.password)
|
32 |
+
db_user = models.User(username=user.username, password=hashed_password)
|
33 |
+
db.add(db_user)
|
34 |
+
db.commit()
|
35 |
+
db.refresh(db_user)
|
36 |
+
return db_user
|
37 |
+
|
38 |
+
|
39 |
+
def update_user(user: schemas.UserUpdate, db: Session):
|
40 |
+
db_user = db.query(models.User).filter(
|
41 |
+
models.User.user_id == user.user_id).first()
|
42 |
+
db_user.is_active = user.is_active
|
43 |
+
db_user.is_superuser = user.is_superuser
|
44 |
+
db.commit()
|
45 |
+
db.refresh(db_user)
|
46 |
+
return db_user
|
47 |
+
|
48 |
+
|
49 |
+
def get_all_users(db: Session):
|
50 |
+
return db.query(models.User).all()
|
51 |
+
|
52 |
+
|
53 |
+
def get_user_by_user_id(user_id: int, db: Session):
|
54 |
+
return db.query(models.User).filter(models.User.user_id == user_id).first()
|
55 |
+
|
56 |
+
|
57 |
+
def get_user_by_username(username: str, db: Session):
|
58 |
+
return db.query(models.User).filter(models.User.username == username).first()
|
app/core/schemas.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from pydantic import BaseModel
|
3 |
+
|
4 |
+
|
5 |
+
class PromptBase(BaseModel):
|
6 |
+
prompt: str
|
7 |
+
|
8 |
+
|
9 |
+
class Prompt(BaseModel):
|
10 |
+
id: int
|
11 |
+
created_at: datetime
|
12 |
+
|
13 |
+
class Config:
|
14 |
+
from_attributes = True
|
15 |
+
|
16 |
+
|
17 |
+
class UserBase(BaseModel):
|
18 |
+
pass
|
19 |
+
|
20 |
+
|
21 |
+
class UserCreate(UserBase):
|
22 |
+
username: str
|
23 |
+
password: str
|
24 |
+
|
25 |
+
|
26 |
+
class UserUpdate(UserBase):
|
27 |
+
user_id: int
|
28 |
+
is_active: bool = True
|
29 |
+
is_superuser: bool = False
|
30 |
+
|
31 |
+
|
32 |
+
class User(UserBase):
|
33 |
+
user_id: int
|
34 |
+
username: str
|
35 |
+
is_active: bool
|
36 |
+
is_superuser: bool
|
37 |
+
created_at: datetime
|
38 |
+
updated_at: datetime
|
39 |
+
prompts: list[Prompt] = []
|
40 |
+
|
41 |
+
class Config:
|
42 |
+
from_attributes = True
|
43 |
+
|
44 |
+
|
45 |
+
class Generate(BaseModel):
|
46 |
+
seed: int | None = None
|
47 |
+
negative_prompt : str = "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face"
|
48 |
+
medium: str | None = None
|
49 |
+
style: str | None = None
|
50 |
+
lightning: str | None = None
|
51 |
+
additional_details: str | None = None
|
52 |
+
|
53 |
+
|
54 |
+
class TextImage(Generate):
|
55 |
+
prompt: str
|
56 |
+
num_inference_steps: int = 4
|
57 |
+
guidance_scale: float = 2.0
|
58 |
+
|
59 |
+
|
60 |
+
class ImageImage(Generate):
|
61 |
+
prompt: str
|
62 |
+
image: str
|
63 |
+
num_inference_steps: int = 10
|
64 |
+
guidance_scale: float = 7.5
|
65 |
+
image_guidance_scale: float = 1.5
|
66 |
+
|
67 |
+
|
68 |
+
class BackgroundRemoval(BaseModel):
|
69 |
+
image: str
|
70 |
+
|
71 |
+
|
72 |
+
class ImageVariations(Generate):
|
73 |
+
image: str
|
74 |
+
num_samples: int = 2
|
75 |
+
num_inference_steps: int = 30
|
76 |
+
prompt: str | None = None
|
77 |
+
scale: float = 0.5
|
app/db.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from app.sql.models import Base
|
2 |
+
from app.sql.models import User
|
3 |
+
from app.core.crud import hash_password
|
4 |
+
from app.sql.database import SessionLocal, engine
|
5 |
+
from app.config import SUPERUSER_USERNAME, SUPERUSER_PASSWORD
|
6 |
+
|
7 |
+
|
8 |
+
def get_db():
|
9 |
+
db = SessionLocal()
|
10 |
+
try:
|
11 |
+
yield db
|
12 |
+
finally:
|
13 |
+
db.close()
|
14 |
+
|
15 |
+
|
16 |
+
def init_db():
|
17 |
+
db = SessionLocal()
|
18 |
+
Base.metadata.create_all(bind=engine)
|
19 |
+
|
20 |
+
hashed_superuser_password = hash_password(SUPERUSER_PASSWORD)
|
21 |
+
superuser = User(username=SUPERUSER_USERNAME,
|
22 |
+
password=hashed_superuser_password, is_superuser=True, is_active=True)
|
23 |
+
|
24 |
+
if db.query(User).filter(User.username == SUPERUSER_USERNAME).first() is None:
|
25 |
+
db.add(superuser)
|
26 |
+
db.commit()
|
27 |
+
db.refresh(superuser)
|
app/main.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated
|
2 |
+
from sqlalchemy.orm import Session
|
3 |
+
from contextlib import asynccontextmanager
|
4 |
+
from starlette.middleware.cors import CORSMiddleware
|
5 |
+
from fastapi.security import OAuth2PasswordRequestForm
|
6 |
+
from fastapi import APIRouter, FastAPI, HTTPException, Depends
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from ip_adapter import IPAdapterXL
|
10 |
+
from transformers import AutoModelForImageSegmentation
|
11 |
+
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler, StableDiffusionXLPipeline
|
12 |
+
|
13 |
+
from app.db import get_db, init_db
|
14 |
+
from app.sql import models
|
15 |
+
from app.api import user
|
16 |
+
from app.api import prompt
|
17 |
+
from app.sql.database import engine
|
18 |
+
from app.api import generate
|
19 |
+
from app.config import ACCESS_TOKEN_EXPIRE_MINUTES, DEVICE
|
20 |
+
from app.security import authenticate_user, create_access_token, timedelta
|
21 |
+
|
22 |
+
|
23 |
+
@asynccontextmanager
|
24 |
+
async def lifespan(app: FastAPI):
|
25 |
+
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
26 |
+
image_encoder_path = "sdxl_models/image_encoder"
|
27 |
+
ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
|
28 |
+
|
29 |
+
ti_pipe = AutoPipelineForText2Image.from_pretrained(
|
30 |
+
'lykon/dreamshaper-xl-v2-turbo', torch_dtype=torch.float16, variant="fp16")
|
31 |
+
ti_pipe.to(DEVICE)
|
32 |
+
ti_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
|
33 |
+
ti_pipe.scheduler.config)
|
34 |
+
|
35 |
+
ii_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
36 |
+
"timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None)
|
37 |
+
ii_pipe.to(DEVICE)
|
38 |
+
ii_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
39 |
+
ii_pipe.scheduler.config)
|
40 |
+
|
41 |
+
br_model = AutoModelForImageSegmentation.from_pretrained(
|
42 |
+
"briaai/RMBG-1.4", trust_remote_code=True)
|
43 |
+
br_model.to(DEVICE)
|
44 |
+
|
45 |
+
# sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
|
46 |
+
# base_model_path,
|
47 |
+
# torch_dtype=torch.float16,
|
48 |
+
# add_watermarker=False,
|
49 |
+
# )
|
50 |
+
# iv_model = IPAdapterXL(sdxl_pipe, image_encoder_path, ip_ckpt, DEVICE)
|
51 |
+
|
52 |
+
yield {'ti_pipe': ti_pipe, 'ii_pipe': ii_pipe, 'br_model': br_model} # , 'iv_model': iv_model
|
53 |
+
|
54 |
+
del ti_pipe
|
55 |
+
del ii_pipe
|
56 |
+
del br_model
|
57 |
+
# del sdxl_pipe
|
58 |
+
# del iv_model
|
59 |
+
|
60 |
+
|
61 |
+
app = FastAPI(lifespan=lifespan)
|
62 |
+
# app = FastAPI()
|
63 |
+
|
64 |
+
router = APIRouter()
|
65 |
+
|
66 |
+
origins = ["*"]
|
67 |
+
|
68 |
+
app.add_middleware(
|
69 |
+
CORSMiddleware,
|
70 |
+
allow_origins=origins,
|
71 |
+
allow_credentials=True,
|
72 |
+
allow_methods=["*"],
|
73 |
+
allow_headers=["*"],
|
74 |
+
)
|
75 |
+
|
76 |
+
init_db()
|
77 |
+
|
78 |
+
|
79 |
+
@app.get("/")
|
80 |
+
def read_root():
|
81 |
+
return {"Hello": "World"}
|
82 |
+
|
83 |
+
|
84 |
+
@app.post("/token")
|
85 |
+
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]):
|
86 |
+
user = authenticate_user(
|
87 |
+
db, form_data.username, form_data.password)
|
88 |
+
if not user:
|
89 |
+
raise HTTPException(
|
90 |
+
status_code=400, detail="Incorrect username or password")
|
91 |
+
access_token_expires = timedelta(
|
92 |
+
minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
93 |
+
access_token = create_access_token(
|
94 |
+
data={"sub": user.username}, expires_delta=access_token_expires)
|
95 |
+
return {"access_token": access_token, "token_type": "bearer"}
|
96 |
+
|
97 |
+
|
98 |
+
router.include_router(user.router, prefix="/users")
|
99 |
+
router.include_router(prompt.router, prefix="/prompts")
|
100 |
+
router.include_router(generate.router, prefix="/generate")
|
101 |
+
app.include_router(router)
|
app/security.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated
|
2 |
+
from datetime import datetime, timedelta, timezone
|
3 |
+
from fastapi import Depends, HTTPException
|
4 |
+
from fastapi.security import OAuth2PasswordBearer
|
5 |
+
from passlib.context import CryptContext
|
6 |
+
from sqlalchemy.orm import Session
|
7 |
+
from jose import JWTError, jwt
|
8 |
+
|
9 |
+
|
10 |
+
from app.config import ALGORITHM, SECRET_KEY
|
11 |
+
from app.db import get_db
|
12 |
+
from app.sql import models
|
13 |
+
|
14 |
+
|
15 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
16 |
+
|
17 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
18 |
+
|
19 |
+
|
20 |
+
def verify_password(plain_password, hashed_password):
|
21 |
+
return pwd_context.verify(plain_password, hashed_password)
|
22 |
+
|
23 |
+
|
24 |
+
def get_user(db: Session, username: str):
|
25 |
+
return db.query(models.User).filter(models.User.username == username).first()
|
26 |
+
|
27 |
+
|
28 |
+
def authenticate_user(db: Session, username: str, password: str):
|
29 |
+
user = get_user(db, username)
|
30 |
+
if not user:
|
31 |
+
return False
|
32 |
+
if not verify_password(password, user.password):
|
33 |
+
return False
|
34 |
+
return user
|
35 |
+
|
36 |
+
|
37 |
+
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
38 |
+
to_encode = data.copy()
|
39 |
+
if expires_delta:
|
40 |
+
expire = datetime.now(timezone.utc) + expires_delta
|
41 |
+
else:
|
42 |
+
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
|
43 |
+
to_encode.update({"exp": expire})
|
44 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
45 |
+
return encoded_jwt
|
46 |
+
|
47 |
+
|
48 |
+
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: Annotated[Session, Depends(get_db)]):
|
49 |
+
credentials_exception = HTTPException(
|
50 |
+
status_code=401, detail="Could not validate credentials")
|
51 |
+
try:
|
52 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
53 |
+
username: str = payload.get("sub")
|
54 |
+
if username is None:
|
55 |
+
raise credentials_exception
|
56 |
+
except JWTError:
|
57 |
+
raise credentials_exception
|
58 |
+
user = get_user(db, username)
|
59 |
+
if user is None:
|
60 |
+
raise credentials_exception
|
61 |
+
return user
|
app/sql/__init__.py
ADDED
File without changes
|
app/sql/database.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from app.config import DATABASE_URL
|
2 |
+
from sqlalchemy import create_engine
|
3 |
+
from sqlalchemy.orm import sessionmaker
|
4 |
+
from sqlalchemy.ext.declarative import declarative_base
|
5 |
+
|
6 |
+
|
7 |
+
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
8 |
+
|
9 |
+
SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
|
10 |
+
|
11 |
+
Base = declarative_base()
|
app/sql/models.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from app.sql.database import Base
|
3 |
+
from sqlalchemy.orm import relationship
|
4 |
+
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String
|
5 |
+
|
6 |
+
|
7 |
+
class User(Base):
|
8 |
+
__tablename__ = 'users'
|
9 |
+
|
10 |
+
user_id = Column(Integer, primary_key=True, index=True)
|
11 |
+
username = Column(String, index=True, unique=True)
|
12 |
+
password = Column(String)
|
13 |
+
is_superuser = Column(Boolean, default=False)
|
14 |
+
is_active = Column(Boolean, default=True)
|
15 |
+
created_at = Column(DateTime, default=datetime.now)
|
16 |
+
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
|
17 |
+
|
18 |
+
prompts = relationship("Prompt", back_populates="user")
|
19 |
+
|
20 |
+
|
21 |
+
class Prompt(Base):
|
22 |
+
__tablename__ = 'prompts'
|
23 |
+
|
24 |
+
id = Column(Integer, primary_key=True)
|
25 |
+
prompt = Column(String, index=True)
|
26 |
+
created_at = Column(DateTime, default=datetime.now)
|
27 |
+
user_id = Column(Integer, ForeignKey('users.user_id'))
|
28 |
+
|
29 |
+
user = relationship("User", back_populates="prompts")
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fastapi
|
2 |
+
uvicorn[standard]
|
3 |
+
sqlalchemy
|
4 |
+
psycopg2-binary
|
5 |
+
python-multipart
|
6 |
+
python-jose[cryptography]
|
7 |
+
passlib[bcrypt]
|
sql_app.db
ADDED
Binary file (24.6 kB). View file
|
|