khawir commited on
Commit
67e167f
·
verified ·
1 Parent(s): b3e0a01

Upload 4 files

Browse files
Files changed (3) hide show
  1. app/api/generate.py +236 -158
  2. app/api/prompt.py +45 -34
  3. app/api/user.py +69 -55
app/api/generate.py CHANGED
@@ -1,158 +1,236 @@
1
- from app.db import get_db
2
- from app.config import DEVICE
3
- from app.core import schemas, crud
4
- from app.security import get_current_user
5
- from app.core.schemas import TextImage, ImageImage, BackgroundRemoval, ImageVariations
6
-
7
- import base64
8
- from io import BytesIO
9
- from sqlalchemy.orm import Session
10
- from typing import Annotated, List
11
- from fastapi import APIRouter, Depends, HTTPException, Request
12
-
13
- import torch
14
- import numpy as np
15
- from PIL import Image
16
- import torch.nn.functional as F
17
- from torchvision.transforms.functional import normalize
18
-
19
-
20
- router = APIRouter()
21
-
22
-
23
- def decode_image(image):
24
- return Image.open(BytesIO(base64.b64decode(image))).convert("RGB")
25
-
26
-
27
- def encode_image(image):
28
- bytes = BytesIO()
29
- image.save(bytes, format="PNG")
30
- return base64.b64encode(bytes.getvalue())
31
-
32
-
33
- def create_prompt(subject, medium, style, artist, website, resolution, additional_details, color, lightning):
34
- if not subject:
35
- return None
36
- if medium:
37
- subject = f"{medium} of {subject}"
38
- if style:
39
- subject = f"{subject}, {style}"
40
- if artist:
41
- subject = f"{subject}, by {artist}"
42
- if website:
43
- subject = f"{subject}, {website}"
44
- if resolution:
45
- subject = f"{subject}, {resolution}"
46
- if additional_details:
47
- subject = f"{subject}, {additional_details}"
48
- if color:
49
- subject = f"{subject}, {color}"
50
- if lightning:
51
- subject = f"{subject}, {lightning}"
52
- return subject
53
-
54
-
55
- @router.post("/text-image/", response_model=str)
56
- def text_image(model: Request, request: TextImage, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
57
- if not current_user.is_active:
58
- raise HTTPException(status_code=403, detail="Forbidden")
59
-
60
- generator = torch.manual_seed(request.seed)
61
- prompt = create_prompt(request.prompt, medium=request.medium, style=request.style, artist=request.artist, website=request.website, resolution=request.resolution, additional_details=request.additional_details, color=request.color, lightning=request.lightning)
62
-
63
- crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
64
-
65
- image = model.state.ti_pipe(prompt, num_inference_steps=request.num_inference_steps,
66
- guidance_scale=request.guidance_scale, generator=generator, negative_prompt=request.negative_prompt).images[0]
67
-
68
- return encode_image(image)
69
-
70
-
71
- @router.post("/image-image/", response_model=str)
72
- def image_image(model: Request, request: ImageImage, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
73
- if not current_user.is_active:
74
- raise HTTPException(status_code=403, detail="Forbidden")
75
-
76
- generator = torch.manual_seed(request.seed)
77
- prompt = create_prompt(request.prompt, medium=request.medium, style=request.style, artist=request.artist, website=request.website, resolution=request.resolution, additional_details=request.additional_details, color=request.color, lightning=request.lightning)
78
- image = decode_image(request.image)
79
-
80
- crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
81
-
82
- image = model.state.ii_pipe(prompt, image=image, num_inference_steps=request.num_inference_steps, guidance_scale=request.guidance_scale,
83
- image_guidance_scale=request.image_guidance_scale, generator=generator, negative_prompt=request.negative_prompt).images[0]
84
-
85
- return encode_image(image)
86
-
87
-
88
- @router.post("/background-removal/", response_model=str)
89
- def background_removal(model: Request, request: BackgroundRemoval, current_user: Annotated[schemas.User, Depends(get_current_user)]):
90
- if not current_user.is_active:
91
- raise HTTPException(status_code=403, detail="Forbidden")
92
-
93
- image = decode_image(request.image)
94
-
95
- def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
96
- if len(im.shape) < 3:
97
- im = im[:, :, np.newaxis]
98
- # orig_im_size=im.shape[0:2]
99
- im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
100
- im_tensor = F.interpolate(torch.unsqueeze(
101
- im_tensor, 0), size=model_input_size, mode='bilinear')
102
- image = torch.divide(im_tensor, 255.0)
103
- image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
104
- return image
105
-
106
- def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
107
- result = torch.squeeze(F.interpolate(
108
- result, size=im_size, mode='bilinear'), 0)
109
- ma = torch.max(result)
110
- mi = torch.min(result)
111
- result = (result-mi)/(ma-mi)
112
- im_array = (result*255).permute(1, 2,
113
- 0).cpu().data.numpy().astype(np.uint8)
114
- im_array = np.squeeze(im_array)
115
- return im_array
116
-
117
- # prepare input
118
- model_input_size = [1024, 1024]
119
- orig_im = np.array(image)
120
- orig_im_size = orig_im.shape[0:2]
121
- image = preprocess_image(orig_im, model_input_size).to(DEVICE)
122
-
123
- # inference
124
- result = model.state.br_model(image)
125
-
126
- # post process
127
- result_image = postprocess_image(result[0][0], orig_im_size)
128
-
129
- # save result
130
- pil_im = Image.fromarray(result_image)
131
- no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
132
- orig_image = Image.fromarray(orig_im)
133
- no_bg_image.paste(orig_image, mask=pil_im)
134
-
135
- return encode_image(no_bg_image)
136
-
137
-
138
- @router.post("/image-variations/", response_model=List[str])
139
- def image_variations(model: Request, request: ImageVariations, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
140
- if not current_user.is_active:
141
- raise HTTPException(status_code=403, detail="Forbidden")
142
-
143
- prompt = create_prompt(request.prompt, medium=request.medium, style=request.style, artist=request.artist, website=request.website, resolution=request.resolution, additional_details=request.additional_details, color=request.color, lightning=request.lightning)
144
- image = decode_image(request.image)
145
- image.resize((256, 256))
146
-
147
- if prompt:
148
- prompt = f"best quality, high quality, {prompt}"
149
- crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
150
- else:
151
- request.scale = 1.0
152
-
153
- images = model.state.iv_model.generate(pil_image=image, num_samples=request.num_samples, num_inference_steps=request.num_inference_steps, seed=request.seed,
154
- prompt=prompt, negative_prompt=request.negative_prompt, scale=request.scale, guidance_scale=request.guidance_scale)
155
-
156
- images = [encode_image(image) for image in images]
157
-
158
- return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.db import get_db
2
+ from app.config import DEVICE
3
+ from app.core import schemas, crud
4
+ from app.security import get_current_user
5
+ from app.core.schemas import TextImage, ImageImage, BackgroundRemoval, ImageVariations
6
+
7
+ import base64
8
+ from io import BytesIO
9
+ from sqlalchemy.orm import Session
10
+ from typing import Annotated, List
11
+ from fastapi import APIRouter, Depends, HTTPException, Request
12
+
13
+ import torch
14
+ import numpy as np
15
+ from PIL import Image
16
+ import torch.nn.functional as F
17
+ from torchvision.transforms.functional import normalize
18
+
19
+
20
+ router = APIRouter()
21
+
22
+
23
+ def decode_image(image):
24
+ return Image.open(BytesIO(base64.b64decode(image))).convert("RGB")
25
+
26
+
27
+ def encode_image(image):
28
+ bytes = BytesIO()
29
+ image.save(bytes, format="PNG")
30
+ return base64.b64encode(bytes.getvalue())
31
+
32
+
33
+ def create_prompt(
34
+ subject,
35
+ medium,
36
+ style,
37
+ artist,
38
+ website,
39
+ resolution,
40
+ additional_details,
41
+ color,
42
+ lightning,
43
+ ):
44
+ if not subject:
45
+ return None
46
+ if medium:
47
+ subject = f"{medium} of {subject}"
48
+ if style:
49
+ subject = f"{subject}, {style}"
50
+ if artist:
51
+ subject = f"{subject}, by {artist}"
52
+ if website:
53
+ subject = f"{subject}, {website}"
54
+ if resolution:
55
+ subject = f"{subject}, {resolution}"
56
+ if additional_details:
57
+ subject = f"{subject}, {additional_details}"
58
+ if color:
59
+ subject = f"{subject}, {color}"
60
+ if lightning:
61
+ subject = f"{subject}, {lightning}"
62
+ return subject
63
+
64
+
65
+ @router.post("/text-image/", response_model=str)
66
+ def text_image(
67
+ model: Request,
68
+ request: TextImage,
69
+ db: Annotated[Session, Depends(get_db)],
70
+ current_user: Annotated[schemas.User, Depends(get_current_user)],
71
+ ):
72
+ if not current_user.is_active:
73
+ raise HTTPException(status_code=403, detail="Forbidden")
74
+
75
+ generator = torch.manual_seed(request.seed)
76
+ prompt = create_prompt(
77
+ request.prompt,
78
+ medium=request.medium,
79
+ style=request.style,
80
+ artist=request.artist,
81
+ website=request.website,
82
+ resolution=request.resolution,
83
+ additional_details=request.additional_details,
84
+ color=request.color,
85
+ lightning=request.lightning,
86
+ )
87
+
88
+ crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
89
+
90
+ image = model.state.ti_pipe(
91
+ prompt,
92
+ num_inference_steps=request.num_inference_steps,
93
+ guidance_scale=request.guidance_scale,
94
+ generator=generator,
95
+ negative_prompt=request.negative_prompt,
96
+ ).images[0]
97
+
98
+ return encode_image(image)
99
+
100
+
101
+ @router.post("/image-image/", response_model=str)
102
+ def image_image(
103
+ model: Request,
104
+ request: ImageImage,
105
+ db: Annotated[Session, Depends(get_db)],
106
+ current_user: Annotated[schemas.User, Depends(get_current_user)],
107
+ ):
108
+ if not current_user.is_active:
109
+ raise HTTPException(status_code=403, detail="Forbidden")
110
+
111
+ generator = torch.manual_seed(request.seed)
112
+ prompt = create_prompt(
113
+ request.prompt,
114
+ medium=request.medium,
115
+ style=request.style,
116
+ artist=request.artist,
117
+ website=request.website,
118
+ resolution=request.resolution,
119
+ additional_details=request.additional_details,
120
+ color=request.color,
121
+ lightning=request.lightning,
122
+ )
123
+ image = decode_image(request.image)
124
+
125
+ crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
126
+
127
+ image = model.state.ii_pipe(
128
+ prompt,
129
+ image=image,
130
+ num_inference_steps=request.num_inference_steps,
131
+ guidance_scale=request.guidance_scale,
132
+ image_guidance_scale=request.image_guidance_scale,
133
+ generator=generator,
134
+ negative_prompt=request.negative_prompt,
135
+ ).images[0]
136
+
137
+ return encode_image(image)
138
+
139
+
140
+ @router.post("/background-removal/", response_model=str)
141
+ def background_removal(
142
+ model: Request,
143
+ request: BackgroundRemoval,
144
+ current_user: Annotated[schemas.User, Depends(get_current_user)],
145
+ ):
146
+ if not current_user.is_active:
147
+ raise HTTPException(status_code=403, detail="Forbidden")
148
+
149
+ image = decode_image(request.image)
150
+
151
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
152
+ if len(im.shape) < 3:
153
+ im = im[:, :, np.newaxis]
154
+ # orig_im_size=im.shape[0:2]
155
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
156
+ im_tensor = F.interpolate(
157
+ torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
158
+ )
159
+ image = torch.divide(im_tensor, 255.0)
160
+ image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
161
+ return image
162
+
163
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
164
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
165
+ ma = torch.max(result)
166
+ mi = torch.min(result)
167
+ result = (result - mi) / (ma - mi)
168
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
169
+ im_array = np.squeeze(im_array)
170
+ return im_array
171
+
172
+ # prepare input
173
+ model_input_size = [1024, 1024]
174
+ orig_im = np.array(image)
175
+ orig_im_size = orig_im.shape[0:2]
176
+ image = preprocess_image(orig_im, model_input_size).to(DEVICE)
177
+
178
+ # inference
179
+ result = model.state.br_model(image)
180
+
181
+ # post process
182
+ result_image = postprocess_image(result[0][0], orig_im_size)
183
+
184
+ # save result
185
+ pil_im = Image.fromarray(result_image)
186
+ no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
187
+ orig_image = Image.fromarray(orig_im)
188
+ no_bg_image.paste(orig_image, mask=pil_im)
189
+
190
+ return encode_image(no_bg_image)
191
+
192
+
193
+ @router.post("/image-variations/", response_model=List[str])
194
+ def image_variations(
195
+ model: Request,
196
+ request: ImageVariations,
197
+ db: Annotated[Session, Depends(get_db)],
198
+ current_user: Annotated[schemas.User, Depends(get_current_user)],
199
+ ):
200
+ if not current_user.is_active:
201
+ raise HTTPException(status_code=403, detail="Forbidden")
202
+
203
+ prompt = create_prompt(
204
+ request.prompt,
205
+ medium=request.medium,
206
+ style=request.style,
207
+ artist=request.artist,
208
+ website=request.website,
209
+ resolution=request.resolution,
210
+ additional_details=request.additional_details,
211
+ color=request.color,
212
+ lightning=request.lightning,
213
+ )
214
+ image = decode_image(request.image)
215
+ image.resize((256, 256))
216
+
217
+ if prompt:
218
+ prompt = f"best quality, high quality, {prompt}"
219
+ crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
220
+ else:
221
+ request.scale = 1.0
222
+
223
+ images = model.state.iv_model.generate(
224
+ pil_image=image,
225
+ num_samples=request.num_samples,
226
+ num_inference_steps=request.num_inference_steps,
227
+ seed=request.seed,
228
+ prompt=prompt,
229
+ negative_prompt=request.negative_prompt,
230
+ scale=request.scale,
231
+ guidance_scale=request.guidance_scale,
232
+ )
233
+
234
+ images = [encode_image(image) for image in images]
235
+
236
+ return images
app/api/prompt.py CHANGED
@@ -1,34 +1,45 @@
1
- from typing import Annotated, List
2
- from sqlalchemy.orm import Session
3
- from fastapi import APIRouter, Depends, HTTPException
4
-
5
- from app.db import get_db
6
- from app.core import schemas, crud
7
- from app.security import get_current_user
8
-
9
-
10
- router = APIRouter()
11
-
12
-
13
- @router.post("/get-all-prompts/", response_model=List[schemas.Prompt])
14
- def get_all_prompts(db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
15
- if not current_user.is_superuser:
16
- raise HTTPException(status_code=403, detail="Forbidden")
17
-
18
- return crud.get_all_prompts(db=db)
19
-
20
-
21
- @ router.post("/get-prompt_by_user_id/{user_id}/", response_model=List[schemas.Prompt])
22
- def get_prompt_by_user_id(user_id: int, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
23
- if not current_user.is_superuser:
24
- raise HTTPException(status_code=403, detail="Forbidden")
25
-
26
- return crud.get_prompt_by_user_id(user_id=user_id, db=db)
27
-
28
-
29
- # @ router.post("/create-prompt/", response_model=schemas.Prompt)
30
- # def create_prompt(prompt: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
31
- # if not current_user.is_superuser:
32
- # raise HTTPException(status_code=403, detail="Forbidden")
33
-
34
- # return crud.create_prompt(prompt=prompt, db=db)
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, List
2
+ from sqlalchemy.orm import Session
3
+ from fastapi import APIRouter, Depends, HTTPException
4
+
5
+ from app.db import get_db
6
+ from app.core import schemas, crud
7
+ from app.security import get_current_user
8
+
9
+
10
+ router = APIRouter()
11
+
12
+
13
+ @router.get("/get-all-prompts/", response_model=List[schemas.Prompt])
14
+ def get_all_prompts(
15
+ db: Annotated[Session, Depends(get_db)],
16
+ current_user: Annotated[schemas.User, Depends(get_current_user)],
17
+ ):
18
+ if not current_user.is_superuser:
19
+ raise HTTPException(status_code=403, detail="Forbidden")
20
+
21
+ return crud.get_all_prompts(db=db)
22
+
23
+
24
+ @router.get("/get-prompt_by_user_id/{user_id}/", response_model=List[schemas.Prompt])
25
+ def get_prompt_by_user_id(
26
+ user_id: int,
27
+ db: Annotated[Session, Depends(get_db)],
28
+ current_user: Annotated[schemas.User, Depends(get_current_user)],
29
+ ):
30
+ if not current_user.is_superuser:
31
+ raise HTTPException(status_code=403, detail="Forbidden")
32
+
33
+ return crud.get_prompt_by_user_id(user_id=user_id, db=db)
34
+
35
+
36
+ # @router.post("/create-prompt/", response_model=schemas.Prompt)
37
+ # def create_prompt(
38
+ # prompt: str,
39
+ # db: Annotated[Session, Depends(get_db)],
40
+ # current_user: Annotated[schemas.User, Depends(get_current_user)],
41
+ # ):
42
+ # if not current_user.is_superuser:
43
+ # raise HTTPException(status_code=403, detail="Forbidden")
44
+
45
+ # return crud.create_prompt(prompt=prompt, db=db)
app/api/user.py CHANGED
@@ -1,55 +1,69 @@
1
- from typing import Annotated, List
2
- from sqlalchemy.orm import Session
3
- from fastapi import APIRouter, Depends, HTTPException
4
-
5
- from app.db import get_db
6
- from app.core import schemas, crud
7
- from app.security import get_current_user
8
-
9
-
10
- router = APIRouter()
11
-
12
-
13
- @router.post("/create-user/", response_model=schemas.User)
14
- def create_user(user: schemas.UserCreate, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
15
- if not current_user.is_superuser:
16
- raise HTTPException(status_code=403, detail="Forbidden")
17
-
18
- user_exists = crud.get_user_by_username(username=user.username, db=db)
19
- if user_exists:
20
- raise HTTPException(
21
- status_code=400, detail="Username already registered")
22
-
23
- return crud.create_user(user=user, db=db)
24
-
25
-
26
- @router.post("/update-user/", response_model=schemas.User)
27
- def update_user(user: schemas.UserUpdate, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
28
- if not current_user.is_superuser:
29
- raise HTTPException(status_code=403, detail="Forbidden")
30
-
31
- user_exists = crud.get_user_by_user_id(user_id=user.user_id, db=db)
32
- if not user_exists:
33
- raise HTTPException(status_code=404, detail="User not found")
34
-
35
- return crud.update_user(user=user, db=db)
36
-
37
-
38
- @router.post("/get-all-users/", response_model=List[schemas.User])
39
- def get_all_users(db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
40
- if not current_user.is_superuser:
41
- raise HTTPException(status_code=403, detail="Forbidden")
42
-
43
- return crud.get_all_users(db=db)
44
-
45
-
46
- @router.post("/get-user_by_user_id/{user_id}/", response_model=schemas.User)
47
- def get_user_by_user_id(user_id: int, db: Annotated[Session, Depends(get_db)], current_user: Annotated[schemas.User, Depends(get_current_user)]):
48
- if not current_user.is_superuser:
49
- raise HTTPException(status_code=403, detail="Forbidden")
50
-
51
- user = crud.get_user_by_user_id(user_id=user_id, db=db)
52
- if user is None:
53
- raise HTTPException(status_code=404, detail="User not found")
54
-
55
- return user
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, List
2
+ from sqlalchemy.orm import Session
3
+ from fastapi import APIRouter, Depends, HTTPException
4
+
5
+ from app.db import get_db
6
+ from app.core import schemas, crud
7
+ from app.security import get_current_user
8
+
9
+
10
+ router = APIRouter()
11
+
12
+
13
+ @router.post("/create-user/", response_model=schemas.User)
14
+ def create_user(
15
+ user: schemas.UserCreate,
16
+ db: Annotated[Session, Depends(get_db)],
17
+ current_user: Annotated[schemas.User, Depends(get_current_user)],
18
+ ):
19
+ if not current_user.is_superuser:
20
+ raise HTTPException(status_code=403, detail="Forbidden")
21
+
22
+ user_exists = crud.get_user_by_username(username=user.username, db=db)
23
+ if user_exists:
24
+ raise HTTPException(status_code=400, detail="Username already registered")
25
+
26
+ return crud.create_user(user=user, db=db)
27
+
28
+
29
+ @router.put("/update-user/", response_model=schemas.User)
30
+ def update_user(
31
+ user: schemas.UserUpdate,
32
+ db: Annotated[Session, Depends(get_db)],
33
+ current_user: Annotated[schemas.User, Depends(get_current_user)],
34
+ ):
35
+ if not current_user.is_superuser:
36
+ raise HTTPException(status_code=403, detail="Forbidden")
37
+
38
+ user_exists = crud.get_user_by_user_id(user_id=user.user_id, db=db)
39
+ if not user_exists:
40
+ raise HTTPException(status_code=404, detail="User not found")
41
+
42
+ return crud.update_user(user=user, db=db)
43
+
44
+
45
+ @router.get("/get-all-users/", response_model=List[schemas.User])
46
+ def get_all_users(
47
+ db: Annotated[Session, Depends(get_db)],
48
+ current_user: Annotated[schemas.User, Depends(get_current_user)],
49
+ ):
50
+ if not current_user.is_superuser:
51
+ raise HTTPException(status_code=403, detail="Forbidden")
52
+
53
+ return crud.get_all_users(db=db)
54
+
55
+
56
+ @router.get("/get-user_by_user_id/{user_id}/", response_model=schemas.User)
57
+ def get_user_by_user_id(
58
+ user_id: int,
59
+ db: Annotated[Session, Depends(get_db)],
60
+ current_user: Annotated[schemas.User, Depends(get_current_user)],
61
+ ):
62
+ if not current_user.is_superuser:
63
+ raise HTTPException(status_code=403, detail="Forbidden")
64
+
65
+ user = crud.get_user_by_user_id(user_id=user_id, db=db)
66
+ if user is None:
67
+ raise HTTPException(status_code=404, detail="User not found")
68
+
69
+ return user