import asyncio
import base64
import logging
from io import BytesIO
from pathlib import Path

import uvicorn
from config import Config
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel
from wrapper import StreamDiffusionWrapper

logger = logging.getLogger("uvicorn")

# Project root (this file lives one directory below it) — used to locate the
# pre-built frontend regardless of the process working directory.
PROJECT_DIR = Path(__file__).parent.parent


class PredictInputModel(BaseModel):
    """
    The input model for the /predict endpoint.
    """

    prompt: str


class PredictResponseModel(BaseModel):
    """
    The response model for the /predict endpoint.
    """

    base64_images: list[str]


class UpdatePromptResponseModel(BaseModel):
    """
    The response model for the /update_prompt endpoint.
    """

    prompt: str


class Api:
    """
    HTTP facade around a StreamDiffusionWrapper: serves the static frontend
    and a /api/predict endpoint that generates images from a text prompt.
    """

    def __init__(self, config: Config) -> None:
        """
        Initialize the API.

        Parameters
        ----------
        config : Config
            The configuration.
        """
        self.config = config
        self.stream_diffusion = StreamDiffusionWrapper(
            model_id=config.model_id,
            lcm_lora_id=config.lcm_lora_id,
            vae_id=config.vae_id,
            device=config.device,
            dtype=config.dtype,
            t_index_list=config.t_index_list,
            warmup=config.warmup,
        )
        self.app = FastAPI()
        self.app.add_api_route(
            "/api/predict",
            self._predict,
            methods=["POST"],
            response_model=PredictResponseModel,
        )
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        # Resolve the static frontend from the project root instead of the
        # cwd-relative "../view/build", so startup does not depend on the
        # directory the server is launched from.
        self.app.mount(
            "/",
            StaticFiles(directory=str(PROJECT_DIR / "view" / "build"), html=True),
            name="public",
        )
        self._predict_lock = asyncio.Lock()
        self._update_prompt_lock = asyncio.Lock()
        # Cache of the most recent prompt and its rendered images; used to
        # skip regeneration for near-identical prompts.
        self.last_prompt: str = ""
        self.last_images: list[str] = [""]

    async def _predict(self, inp: PredictInputModel) -> PredictResponseModel:
        """
        Predict an image and return.

        Reuses the previously generated images when the new prompt is within
        ``config.levenstein_distance_threshold`` edits of the cached one.

        Parameters
        ----------
        inp : PredictInputModel
            The input.

        Returns
        -------
        PredictResponseModel
            The prediction result.
        """
        async with self._predict_lock:
            # Only consult the cache after at least one real generation has
            # happened (last_prompt non-empty); otherwise a first prompt
            # shorter than the threshold would match the placeholder cache
            # and return [""] without ever generating.
            if (
                self.last_prompt
                and self._calc_levenstein_distance(inp.prompt, self.last_prompt)
                < self.config.levenstein_distance_threshold
            ):
                logger.info("Using cached images")
                return PredictResponseModel(base64_images=self.last_images)
            self.last_prompt = inp.prompt
            # NOTE(review): this inference call is synchronous and blocks the
            # event loop for the duration of generation; consider
            # run_in_executor if other requests must stay responsive.
            self.last_images = [
                self._pil_to_base64(image)
                for image in self.stream_diffusion(inp.prompt)
            ]
            return PredictResponseModel(base64_images=self.last_images)

    def _pil_to_base64(self, image: Image.Image, format: str = "JPEG") -> str:
        """
        Convert a PIL image to a base64-encoded string.

        Parameters
        ----------
        image : Image.Image
            The PIL image.
        format : str
            The image format, by default "JPEG".

        Returns
        -------
        str
            The base64-encoded image as an ASCII string.
        """
        buffered = BytesIO()
        image.convert("RGB").save(buffered, format=format)
        return base64.b64encode(buffered.getvalue()).decode("ascii")

    def _base64_to_pil(self, base64_image: str) -> Image.Image:
        """
        Convert a base64 image to PIL.

        Parameters
        ----------
        base64_image : str
            The base64 image, with or without a data-URL header
            (e.g. "data:image/png;base64,...").

        Returns
        -------
        Image.Image
            The PIL image, converted to RGB.
        """
        # Strip an optional data-URL prefix before decoding.
        if "base64," in base64_image:
            base64_image = base64_image.split("base64,")[1]
        return Image.open(BytesIO(base64.b64decode(base64_image))).convert("RGB")

    def _calc_levenstein_distance(self, a: str, b: str) -> int:
        """
        Calculate the Levenshtein (edit) distance between two strings.

        Parameters
        ----------
        a : str
            The first string.
        b : str
            The second string.

        Returns
        -------
        int
            The Levenshtein distance.
        """
        if a == b:
            return 0
        if not a:
            return len(b)
        if not b:
            return len(a)
        # Two-row dynamic programming: previous_row[j] is the distance
        # between a[:i-1] and b[:j].  Same result as the full matrix but
        # O(len(b)) memory instead of O(len(a) * len(b)).
        previous_row = list(range(len(b) + 1))
        for i, ac in enumerate(a, start=1):
            current_row = [i]
            for j, bc in enumerate(b, start=1):
                cost = 0 if ac == bc else 1
                current_row.append(
                    min(
                        previous_row[j] + 1,         # deletion
                        current_row[j - 1] + 1,      # insertion
                        previous_row[j - 1] + cost,  # substitution
                    )
                )
            previous_row = current_row
        return previous_row[-1]


if __name__ == "__main__":
    config = Config()
    uvicorn.run(
        Api(config).app,
        host=config.host,
        port=config.port,
        workers=config.workers,
    )