Sanket17's picture
added all files
5fbd25d
"""Functions for calling the generation worker."""
from typing import List
from fastapi import Response
from fooocusapi.models.common.base import GenerateMaskRequest
from fooocusapi.models.common.requests import (
CommonRequest as Text2ImgRequest
)
from fooocusapi.models.common.response import (
AsyncJobResponse,
GeneratedImageResult
)
from fooocusapi.models.common.task import (
GenerationFinishReason,
ImageGenerationResult,
AsyncJobStage,
TaskType
)
from fooocusapi.utils.api_utils import (
req_to_params,
generate_async_output,
generate_streaming_output,
generate_image_result_output
)
from fooocusapi.models.requests_v1 import (
ImageEnhanceRequest, ImgUpscaleOrVaryRequest,
ImgPromptRequest,
ImgInpaintOrOutpaintRequest
)
from fooocusapi.models.requests_v2 import (
ImageEnhanceRequestJson, ImgInpaintOrOutpaintRequestJson,
ImgPromptRequestJson,
ImgUpscaleOrVaryRequestJson
)
from fooocusapi.utils.img_utils import narray_to_base64img, read_input_image
from fooocusapi.worker import worker_queue, blocking_get_task_result
from extras.inpaint_mask import generate_mask_from_image, SAMOptions
def get_task_type(req: Text2ImgRequest) -> TaskType:
    """Map a request object to the task type the worker should run.

    The request is matched by class against the ``requests_v1`` and
    ``requests_v2`` models, in the order listed below. A request that
    matches none of them is treated as text-to-image.

    :param req: The incoming generation request.
    :return: The ``TaskType`` member that corresponds to the request class.
    """
    # Ordered (request classes, task type) pairs; the first match wins.
    dispatch = (
        ((ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson),
         TaskType.img_uov),
        ((ImgPromptRequest, ImgPromptRequestJson),
         TaskType.img_prompt),
        ((ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson),
         TaskType.img_inpaint_outpaint),
        ((ImageEnhanceRequestJson, ImageEnhanceRequest),
         TaskType.img_enhance),
    )
    for request_classes, matched_type in dispatch:
        if isinstance(req, request_classes):
            return matched_type
    return TaskType.text_2_img
def call_worker(req: Text2ImgRequest, accept: str) -> Response | AsyncJobResponse | List[GeneratedImageResult]:
    """Queue a generation request and build the response for the caller.

    The response shape depends on two switches:

    * ``accept == 'image/png'`` selects streaming output, and also forces
      ``req.image_number`` to 1 (the request object is mutated).
    * ``req.async_process`` returns an async job response straight away
      instead of blocking until the task result is available.

    If the worker queue does not accept the task, a result carrying
    ``GenerationFinishReason.queue_is_full`` is returned in whichever of
    the three shapes applies.

    :param req: The generation request.
    :param accept: The media type requested by the client.
    :return: A streaming response, an async job response, or the
        generated image results.
    """
    streaming_output = accept == 'image/png'
    if streaming_output:
        # Streaming mode returns a single image, so pin the count to 1.
        req.image_number = 1

    task_type = get_task_type(req)
    async_task = worker_queue.add_task(
        task_type, req_to_params(req), req.webhook_url
    )

    if async_task is not None:
        if req.async_process:
            # Caller polls for the result; hand back the job description.
            return generate_async_output(async_task)

        # Synchronous path: wait here until the worker finishes the job.
        results = blocking_get_task_result(async_task.job_id)
        if streaming_output:
            return generate_streaming_output(results)
        return generate_image_result_output(results, req.require_base64)

    # From here on the task could not be added to the worker queue.
    queue_full = GenerationFinishReason.queue_is_full
    failure_results = [
        ImageGenerationResult(im=None, seed='', finish_reason=queue_full)
    ]

    if streaming_output:
        return generate_streaming_output(failure_results)

    if req.async_process:
        failed_image = GeneratedImageResult(
            base64=None,
            url=None,
            seed='',
            finish_reason=queue_full
        )
        return AsyncJobResponse(
            job_id='',
            job_type=task_type,
            job_stage=AsyncJobStage.error,
            job_progress=0,
            job_status=None,
            job_step_preview=None,
            job_result=[failed_image]
        )

    return generate_image_result_output(failure_results, False)
async def generate_mask(request: GenerateMaskRequest):
    """
    Generate a mask for the image carried by the request.

    The input image is read from ``request.image`` and passed to
    ``generate_mask_from_image`` together with the selected mask model.
    Model-specific settings are forwarded as follows:

    * ``'u2net_cloth_seg'``: ``request.cloth_category`` is passed in the
      extras dict.
    * ``'sam'``: a ``SAMOptions`` object is built from the request's
      DINO/SAM fields.
    * any other model: no extras and no SAM options.

    :param request: The request object containing the image and mask params.
    :return: The mask, encoded with ``narray_to_base64img``.
    """
    image = read_input_image(request.image)
    model_name = request.mask_model

    extras = {}
    sam_options = None
    if model_name == 'sam':
        sam_options = SAMOptions(
            dino_prompt=request.dino_prompt_text,
            dino_box_threshold=request.box_threshold,
            dino_text_threshold=request.text_threshold,
            dino_erode_or_dilate=request.dino_erode_or_dilate,
            dino_debug=request.dino_debug,
            max_detections=request.sam_max_detections,
            model_type=request.sam_model
        )
    elif model_name == 'u2net_cloth_seg':
        extras['cloth_category'] = request.cloth_category

    # Only the first of the four returned values (the mask) is used here.
    mask, _, _, _ = generate_mask_from_image(
        image, model_name, extras, sam_options
    )
    return narray_to_base64img(mask)