khawir commited on
Commit
c053e7d
·
1 Parent(s): c444f2d

Start Application

Browse files
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app/__init__.py ADDED
File without changes
app/api/__init__.py ADDED
File without changes
app/api/generate.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.db import get_db
2
+ from app.config import DEVICE
3
+ from app.core import schemas, crud
4
+ from app.security import get_current_user
5
+ from app.core.schemas import TextImage, ImageImage, BackgroundRemoval, ImageVariations
6
+
7
+ import base64
8
+ from io import BytesIO
9
+ from sqlalchemy.orm import Session
10
+ from typing import Annotated, List
11
+ from fastapi import APIRouter, Depends, HTTPException, Request
12
+
13
+ import torch
14
+ import numpy as np
15
+ from PIL import Image
16
+ import torch.nn.functional as F
17
+ from torchvision.transforms.functional import normalize
18
+
19
+
20
+ router = APIRouter()
21
+
22
+
23
+ def decode_image(image):
24
+ return Image.open(BytesIO(base64.b64decode(image))).convert("RGB")
25
+
26
+
27
+ def encode_image(image):
28
+ bytes = BytesIO()
29
+ image.save(bytes, format="PNG")
30
+ return base64.b64encode(bytes.getvalue())
31
+
32
+
33
+ def create_prompt(subject, medium, style, artist, website, resolution, additional_details, color, lightning):
34
+ if not subject:
35
+ return None
36
+ if medium:
37
+ subject = f"{medium} of {subject}"
38
+ if style:
39
+ subject = f"{subject}, {style}"
40
+ if artist:
41
+ subject = f"{subject}, by {artist}"
42
+ if website:
43
+ subject = f"{subject}, {website}"
44
+ if resolution:
45
+ subject = f"{subject}, {resolution}"
46
+ if additional_details:
47
+ subject = f"{subject}, {additional_details}"
48
+ if color:
49
+ subject = f"{subject}, {color}"
50
+ if lightning:
51
+ subject = f"{subject}, {lightning}"
52
+ return subject
53
+
54
+
55
+ @router.post("/text-image/", response_model=str)
56
+ def text_image(model: Request, request: TextImage, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
57
+ if not current_user.is_active:
58
+ raise HTTPException(status_code=403, detail="Forbidden")
59
+
60
+ generator = torch.manual_seed(request.seed)
61
+ prompt = create_prompt(request.prompt, medium=request.medium, style=request.style,
62
+ additional_details=request.additional_details, lightning=request.lightning)
63
+
64
+ crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
65
+
66
+ image = model.state.ti_pipe(prompt, num_inference_steps=request.num_inference_steps,
67
+ guidance_scale=request.guidance_scale, generator=generator, negative_prompt=request.negative_prompt).images[0]
68
+
69
+ return encode_image(image)
70
+
71
+
72
+ @router.post("/image-image/", response_model=str)
73
+ def image_image(model: Request, request: ImageImage, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
74
+ if not current_user.is_active:
75
+ raise HTTPException(status_code=403, detail="Forbidden")
76
+
77
+ generator = torch.manual_seed(request.seed)
78
+ prompt = create_prompt(request.prompt, medium=request.medium, style=request.style,
79
+ additional_details=request.additional_details, lightning=request.lightning)
80
+ image = decode_image(request.image)
81
+
82
+ crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
83
+
84
+ image = model.state.ii_pipe(prompt, image=image, num_inference_steps=request.num_inference_steps, guidance_scale=request.guidance_scale,
85
+ image_guidance_scale=request.image_guidance_scale, generator=generator, negative_prompt=request.negative_prompt).images[0]
86
+
87
+ return encode_image(image)
88
+
89
+
90
+ @router.post("/background-removal/", response_model=str)
91
+ def background_removal(model: Request, request: BackgroundRemoval, current_user: Annotated[schemas.User, Depends(get_current_user)]):
92
+ if not current_user.is_active:
93
+ raise HTTPException(status_code=403, detail="Forbidden")
94
+
95
+ image = decode_image(request.image)
96
+
97
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
98
+ if len(im.shape) < 3:
99
+ im = im[:, :, np.newaxis]
100
+ # orig_im_size=im.shape[0:2]
101
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
102
+ im_tensor = F.interpolate(torch.unsqueeze(
103
+ im_tensor, 0), size=model_input_size, mode='bilinear')
104
+ image = torch.divide(im_tensor, 255.0)
105
+ image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
106
+ return image
107
+
108
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
109
+ result = torch.squeeze(F.interpolate(
110
+ result, size=im_size, mode='bilinear'), 0)
111
+ ma = torch.max(result)
112
+ mi = torch.min(result)
113
+ result = (result-mi)/(ma-mi)
114
+ im_array = (result*255).permute(1, 2,
115
+ 0).cpu().data.numpy().astype(np.uint8)
116
+ im_array = np.squeeze(im_array)
117
+ return im_array
118
+
119
+ # prepare input
120
+ model_input_size = [1024, 1024]
121
+ orig_im = np.array(image)
122
+ orig_im_size = orig_im.shape[0:2]
123
+ image = preprocess_image(orig_im, model_input_size).to(DEVICE)
124
+
125
+ # inference
126
+ result = model.state.br_model(image)
127
+
128
+ # post process
129
+ result_image = postprocess_image(result[0][0], orig_im_size)
130
+
131
+ # save result
132
+ pil_im = Image.fromarray(result_image)
133
+ no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
134
+ orig_image = Image.fromarray(orig_im)
135
+ no_bg_image.paste(orig_image, mask=pil_im)
136
+
137
+ return encode_image(no_bg_image)
138
+
139
+
140
+ @router.post("/image-variations/", response_model=List[str])
141
+ def image_variations(model: Request, request: ImageVariations, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
142
+ if not current_user.is_active:
143
+ raise HTTPException(status_code=403, detail="Forbidden")
144
+
145
+ # prompt = create_prompt(request.prompt, medium=request.medium, style=request.style,
146
+ # additional_details=request.additional_details, lightning=request.lightning)
147
+ # image = decode_image(request.image)
148
+ # image.resize((512, 512))
149
+
150
+ # if prompt:
151
+ # crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
152
+
153
+ # images = model.state.iv_model.generate(pil_image=image, num_samples=request.num_samples, num_inference_steps=request.num_inference_steps,
154
+ # seed=request.seed, prompt=prompt, scale=request.scale, negative_prompt=request.negative_prompt)
155
+
156
+ # images = [encode_image(image) for image in images]
157
+
158
+ # return images
159
+ return ["Image Variations is not supported yet."]
app/api/prompt.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, List
2
+ from sqlalchemy.orm import Session
3
+ from fastapi import APIRouter, Depends, HTTPException
4
+
5
+ from app.db import get_db
6
+ from app.core import schemas, crud
7
+ from app.security import get_current_user
8
+
9
+
10
+ router = APIRouter()
11
+
12
+
13
+ @router.post("/get-all-prompts/", response_model=List[schemas.Prompt])
14
+ def get_all_prompts(db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
15
+ if not current_user.is_superuser:
16
+ raise HTTPException(status_code=403, detail="Forbidden")
17
+
18
+ return crud.get_all_prompts(db=db)
19
+
20
+
21
+ @ router.post("/get-prompt_by_user_id/{user_id}/", response_model=List[schemas.Prompt])
22
+ def get_prompt_by_user_id(user_id: int, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
23
+ if not current_user.is_superuser:
24
+ raise HTTPException(status_code=403, detail="Forbidden")
25
+
26
+ return crud.get_prompt_by_user_id(user_id=user_id, db=db)
27
+
28
+
29
+ # @ router.post("/create-prompt/", response_model=schemas.Prompt)
30
+ # def create_prompt(prompt: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
31
+ # if not current_user.is_superuser:
32
+ # raise HTTPException(status_code=403, detail="Forbidden")
33
+
34
+ # return crud.create_prompt(prompt=prompt, db=db)
app/api/user.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, List
2
+ from sqlalchemy.orm import Session
3
+ from fastapi import APIRouter, Depends, HTTPException
4
+
5
+ from app.db import get_db
6
+ from app.core import schemas, crud
7
+ from app.security import get_current_user
8
+
9
+
10
+ router = APIRouter()
11
+
12
+
13
+ @router.post("/create-user/", response_model=schemas.User)
14
+ def create_user(user: schemas.UserCreate, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
15
+ if not current_user.is_superuser:
16
+ raise HTTPException(status_code=403, detail="Forbidden")
17
+
18
+ user_exists = crud.get_user_by_username(username=user.username, db=db)
19
+ if user_exists:
20
+ raise HTTPException(
21
+ status_code=400, detail="Username already registered")
22
+
23
+ return crud.create_user(user=user, db=db)
24
+
25
+
26
+ @router.post("/update-user/", response_model=schemas.User)
27
+ def update_user(user: schemas.UserUpdate, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
28
+ if not current_user.is_superuser:
29
+ raise HTTPException(status_code=403, detail="Forbidden")
30
+
31
+ user_exists = crud.get_user_by_user_id(user_id=user.user_id, db=db)
32
+ if not user_exists:
33
+ raise HTTPException(status_code=404, detail="User not found")
34
+
35
+ return crud.update_user(user=user, db=db)
36
+
37
+
38
+ @router.post("/get-all-users/", response_model=List[schemas.User])
39
+ def get_all_users(db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
40
+ if not current_user.is_superuser:
41
+ raise HTTPException(status_code=403, detail="Forbidden")
42
+
43
+ return crud.get_all_users(db=db)
44
+
45
+
46
+ @router.post("/get-user_by_user_id/{user_id}/", response_model=schemas.User)
47
+ def get_user_by_user_id(user_id: int, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
48
+ if not current_user.is_superuser:
49
+ raise HTTPException(status_code=403, detail="Forbidden")
50
+
51
+ user = crud.get_user_by_user_id(user_id=user_id, db=db)
52
+ if user is None:
53
+ raise HTTPException(status_code=404, detail="User not found")
54
+
55
+ return user
app/config.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ DATABASE_URL = 'sqlite:///./sql_app.db'
2
+
3
+ SECRET_KEY = "20c7fc21a12dc8c19b06e603af10ebc11887a614f188f3d0878bb42bd784c315"
4
+ ALGORITHM = "HS256"
5
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30
6
+
7
+ SUPERUSER_USERNAME = "admin"
8
+ SUPERUSER_PASSWORD = "admin"
9
+
10
+ DEVICE = "cuda"
app/core/__init__.py ADDED
File without changes
app/core/crud.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import app.sql.models as models
2
+ import app.core.schemas as schemas
3
+
4
+ from sqlalchemy.orm import Session
5
+ from passlib.context import CryptContext
6
+
7
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
8
+
9
+
10
+ def hash_password(password: str):
11
+ return pwd_context.hash(password)
12
+
13
+
14
+ def create_prompt(prompt: str, user_id: int, db: Session):
15
+ db_prompt = models.Prompt(prompt=prompt, user_id=user_id)
16
+ db.add(db_prompt)
17
+ db.commit()
18
+ db.refresh(db_prompt)
19
+ return db_prompt
20
+
21
+
22
+ def get_all_prompts(db: Session):
23
+ return db.query(models.Prompt).all()
24
+
25
+
26
+ def get_prompt_by_user_id(user_id: int, db: Session):
27
+ return db.query(models.Prompt).filter(models.Prompt.user_id == user_id).all()
28
+
29
+
30
+ def create_user(user: schemas.UserCreate, db: Session):
31
+ hashed_password = hash_password(user.password)
32
+ db_user = models.User(username=user.username, password=hashed_password)
33
+ db.add(db_user)
34
+ db.commit()
35
+ db.refresh(db_user)
36
+ return db_user
37
+
38
+
39
+ def update_user(user: schemas.UserUpdate, db: Session):
40
+ db_user = db.query(models.User).filter(
41
+ models.User.user_id == user.user_id).first()
42
+ db_user.is_active = user.is_active
43
+ db_user.is_superuser = user.is_superuser
44
+ db.commit()
45
+ db.refresh(db_user)
46
+ return db_user
47
+
48
+
49
+ def get_all_users(db: Session):
50
+ return db.query(models.User).all()
51
+
52
+
53
+ def get_user_by_user_id(user_id: int, db: Session):
54
+ return db.query(models.User).filter(models.User.user_id == user_id).first()
55
+
56
+
57
+ def get_user_by_username(username: str, db: Session):
58
+ return db.query(models.User).filter(models.User.username == username).first()
app/core/schemas.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from pydantic import BaseModel
3
+
4
+
5
+ class PromptBase(BaseModel):
6
+ prompt: str
7
+
8
+
9
+ class Prompt(BaseModel):
10
+ id: int
11
+ created_at: datetime
12
+
13
+ class Config:
14
+ from_attributes = True
15
+
16
+
17
+ class UserBase(BaseModel):
18
+ pass
19
+
20
+
21
+ class UserCreate(UserBase):
22
+ username: str
23
+ password: str
24
+
25
+
26
+ class UserUpdate(UserBase):
27
+ user_id: int
28
+ is_active: bool = True
29
+ is_superuser: bool = False
30
+
31
+
32
+ class User(UserBase):
33
+ user_id: int
34
+ username: str
35
+ is_active: bool
36
+ is_superuser: bool
37
+ created_at: datetime
38
+ updated_at: datetime
39
+ prompts: list[Prompt] = []
40
+
41
+ class Config:
42
+ from_attributes = True
43
+
44
+
45
+ class Generate(BaseModel):
46
+ seed: int | None = None
47
+ negative_prompt : str = "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face"
48
+ medium: str | None = None
49
+ style: str | None = None
50
+ lightning: str | None = None
51
+ additional_details: str | None = None
52
+
53
+
54
+ class TextImage(Generate):
55
+ prompt: str
56
+ num_inference_steps: int = 4
57
+ guidance_scale: float = 2.0
58
+
59
+
60
+ class ImageImage(Generate):
61
+ prompt: str
62
+ image: str
63
+ num_inference_steps: int = 10
64
+ guidance_scale: float = 7.5
65
+ image_guidance_scale: float = 1.5
66
+
67
+
68
+ class BackgroundRemoval(BaseModel):
69
+ image: str
70
+
71
+
72
+ class ImageVariations(Generate):
73
+ image: str
74
+ num_samples: int = 2
75
+ num_inference_steps: int = 30
76
+ prompt: str | None = None
77
+ scale: float = 0.5
app/db.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.sql.models import Base
2
+ from app.sql.models import User
3
+ from app.core.crud import hash_password
4
+ from app.sql.database import SessionLocal, engine
5
+ from app.config import SUPERUSER_USERNAME, SUPERUSER_PASSWORD
6
+
7
+
8
+ def get_db():
9
+ db = SessionLocal()
10
+ try:
11
+ yield db
12
+ finally:
13
+ db.close()
14
+
15
+
16
+ def init_db():
17
+ db = SessionLocal()
18
+ Base.metadata.create_all(bind=engine)
19
+
20
+ hashed_superuser_password = hash_password(SUPERUSER_PASSWORD)
21
+ superuser = User(username=SUPERUSER_USERNAME,
22
+ password=hashed_superuser_password, is_superuser=True, is_active=True)
23
+
24
+ if db.query(User).filter(User.username == SUPERUSER_USERNAME).first() is None:
25
+ db.add(superuser)
26
+ db.commit()
27
+ db.refresh(superuser)
app/main.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated
2
+ from sqlalchemy.orm import Session
3
+ from contextlib import asynccontextmanager
4
+ from starlette.middleware.cors import CORSMiddleware
5
+ from fastapi.security import OAuth2PasswordRequestForm
6
+ from fastapi import APIRouter, FastAPI, HTTPException, Depends
7
+
8
+ import torch
9
+ from ip_adapter import IPAdapterXL
10
+ from transformers import AutoModelForImageSegmentation
11
+ from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler, StableDiffusionXLPipeline
12
+
13
+ from app.db import get_db, init_db
14
+ from app.sql import models
15
+ from app.api import user
16
+ from app.api import prompt
17
+ from app.sql.database import engine
18
+ from app.api import generate
19
+ from app.config import ACCESS_TOKEN_EXPIRE_MINUTES, DEVICE
20
+ from app.security import authenticate_user, create_access_token, timedelta
21
+
22
+
23
+ @asynccontextmanager
24
+ async def lifespan(app: FastAPI):
25
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
26
+ image_encoder_path = "sdxl_models/image_encoder"
27
+ ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
28
+
29
+ ti_pipe = AutoPipelineForText2Image.from_pretrained(
30
+ 'lykon/dreamshaper-xl-v2-turbo', torch_dtype=torch.float16, variant="fp16")
31
+ ti_pipe.to(DEVICE)
32
+ ti_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
33
+ ti_pipe.scheduler.config)
34
+
35
+ ii_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
36
+ "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None)
37
+ ii_pipe.to(DEVICE)
38
+ ii_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
39
+ ii_pipe.scheduler.config)
40
+
41
+ br_model = AutoModelForImageSegmentation.from_pretrained(
42
+ "briaai/RMBG-1.4", trust_remote_code=True)
43
+ br_model.to(DEVICE)
44
+
45
+ # sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
46
+ # base_model_path,
47
+ # torch_dtype=torch.float16,
48
+ # add_watermarker=False,
49
+ # )
50
+ # iv_model = IPAdapterXL(sdxl_pipe, image_encoder_path, ip_ckpt, DEVICE)
51
+
52
+ yield {'ti_pipe': ti_pipe, 'ii_pipe': ii_pipe, 'br_model': br_model} # , 'iv_model': iv_model
53
+
54
+ del ti_pipe
55
+ del ii_pipe
56
+ del br_model
57
+ # del sdxl_pipe
58
+ # del iv_model
59
+
60
+
61
+ app = FastAPI(lifespan=lifespan)
62
+ # app = FastAPI()
63
+
64
+ router = APIRouter()
65
+
66
+ origins = ["*"]
67
+
68
+ app.add_middleware(
69
+ CORSMiddleware,
70
+ allow_origins=origins,
71
+ allow_credentials=True,
72
+ allow_methods=["*"],
73
+ allow_headers=["*"],
74
+ )
75
+
76
+ init_db()
77
+
78
+
79
+ @app.get("/")
80
+ def read_root():
81
+ return {"Hello": "World"}
82
+
83
+
84
+ @app.post("/token")
85
+ async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]):
86
+ user = authenticate_user(
87
+ db, form_data.username, form_data.password)
88
+ if not user:
89
+ raise HTTPException(
90
+ status_code=400, detail="Incorrect username or password")
91
+ access_token_expires = timedelta(
92
+ minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
93
+ access_token = create_access_token(
94
+ data={"sub": user.username}, expires_delta=access_token_expires)
95
+ return {"access_token": access_token, "token_type": "bearer"}
96
+
97
+
98
+ router.include_router(user.router, prefix="/users")
99
+ router.include_router(prompt.router, prefix="/prompts")
100
+ router.include_router(generate.router, prefix="/generate")
101
+ app.include_router(router)
app/security.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated
2
+ from datetime import datetime, timedelta, timezone
3
+ from fastapi import Depends, HTTPException
4
+ from fastapi.security import OAuth2PasswordBearer
5
+ from passlib.context import CryptContext
6
+ from sqlalchemy.orm import Session
7
+ from jose import JWTError, jwt
8
+
9
+
10
+ from app.config import ALGORITHM, SECRET_KEY
11
+ from app.db import get_db
12
+ from app.sql import models
13
+
14
+
15
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
16
+
17
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
18
+
19
+
20
+ def verify_password(plain_password, hashed_password):
21
+ return pwd_context.verify(plain_password, hashed_password)
22
+
23
+
24
+ def get_user(db: Session, username: str):
25
+ return db.query(models.User).filter(models.User.username == username).first()
26
+
27
+
28
+ def authenticate_user(db: Session, username: str, password: str):
29
+ user = get_user(db, username)
30
+ if not user:
31
+ return False
32
+ if not verify_password(password, user.password):
33
+ return False
34
+ return user
35
+
36
+
37
+ def create_access_token(data: dict, expires_delta: timedelta | None = None):
38
+ to_encode = data.copy()
39
+ if expires_delta:
40
+ expire = datetime.now(timezone.utc) + expires_delta
41
+ else:
42
+ expire = datetime.now(timezone.utc) + timedelta(minutes=15)
43
+ to_encode.update({"exp": expire})
44
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
45
+ return encoded_jwt
46
+
47
+
48
+ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: Annotated[Session, Depends(get_db)]):
49
+ credentials_exception = HTTPException(
50
+ status_code=401, detail="Could not validate credentials")
51
+ try:
52
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
53
+ username: str = payload.get("sub")
54
+ if username is None:
55
+ raise credentials_exception
56
+ except JWTError:
57
+ raise credentials_exception
58
+ user = get_user(db, username)
59
+ if user is None:
60
+ raise credentials_exception
61
+ return user
app/sql/__init__.py ADDED
File without changes
app/sql/database.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.config import DATABASE_URL
2
+ from sqlalchemy import create_engine
3
+ from sqlalchemy.orm import sessionmaker
4
+ from sqlalchemy.ext.declarative import declarative_base
5
+
6
+
7
+ engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
8
+
9
+ SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
10
+
11
+ Base = declarative_base()
app/sql/models.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from app.sql.database import Base
3
+ from sqlalchemy.orm import relationship
4
+ from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String
5
+
6
+
7
+ class User(Base):
8
+ __tablename__ = 'users'
9
+
10
+ user_id = Column(Integer, primary_key=True, index=True)
11
+ username = Column(String, index=True, unique=True)
12
+ password = Column(String)
13
+ is_superuser = Column(Boolean, default=False)
14
+ is_active = Column(Boolean, default=True)
15
+ created_at = Column(DateTime, default=datetime.now)
16
+ updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
17
+
18
+ prompts = relationship("Prompt", back_populates="user")
19
+
20
+
21
+ class Prompt(Base):
22
+ __tablename__ = 'prompts'
23
+
24
+ id = Column(Integer, primary_key=True)
25
+ prompt = Column(String, index=True)
26
+ created_at = Column(DateTime, default=datetime.now)
27
+ user_id = Column(Integer, ForeignKey('users.user_id'))
28
+
29
+ user = relationship("User", back_populates="prompts")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ sqlalchemy
4
+ psycopg2-binary
5
+ python-multipart
6
+ python-jose[cryptography]
7
+ passlib[bcrypt]
sql_app.db ADDED
Binary file (24.6 kB). View file