import os from typing import List from fastapi import APIRouter from src.libs.image import save_img, generate_img_index from src.libs.model import CganCols, get_model from src.libs.s3 import s3client from src.models.generate import GenerateResult, ImageResult from src.models.main import User, Method, UserIndex IMAGE_STORE_PATH = os.path.abspath("./src/store") BUCKET_NAME = "pimthaigans-image-container" # just make sure to have IMAGE_STORE_PATH folder created if not os.path.exists(IMAGE_STORE_PATH): os.makedirs(IMAGE_STORE_PATH) router = APIRouter( prefix="/generate", tags=["Generate"], responses={404: {"description": "Not found"}}, ) model = CganCols() @router.get("/") async def info(): return {"info": "This is the generate endpoint"} @router.get("/status") async def status(): return {"status": "OK"} @router.post("/") async def generate(user: UserIndex) -> GenerateResult: if user.method == Method.index: result: GenerateResult = await generate_index(user.user, user.index) return result result: GenerateResult = await generate_all(user.user) return result async def generate_index(user: User, index: int) -> GenerateResult: s3 = s3client() img_detail = s3uploadimage(user, s3, index) result: List[ImageResult] = [img_detail] s3.close() return GenerateResult(user=user, method=Method.index, result=result) async def generate_all(user: User): s3 = s3client() result: List[ImageResult] = [] for index in range(0, 88): img_detail = s3uploadimage(user, s3, index) result.append(img_detail) s3.close() return GenerateResult(user=user, method=Method.all, result=result) def s3uploadimage(user, s3, index): output_path = os.path.join( IMAGE_STORE_PATH, f"{user.uuid}-{str(index).zfill(2)}.png") used_model = model.model_cols[get_model(index)] image = generate_img_index(reloaded_model=used_model, index=index % 11) save_img(image, output_path) s3_path: str = f"{user.uuid}/{str(index).zfill(2)}.png" s3.upload_file(output_path, BUCKET_NAME, s3_path) image_url = f'https://{BUCKET_NAME}.s3.amazonaws.com/{s3_path}' img_detail = ImageResult(index=index, image_url=image_url) os.remove(output_path) return img_detail