File size: 4,485 Bytes
5fbd25d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Functions that submit generation requests to the worker and build inpaint masks."""
from typing import List
from fastapi import Response

from fooocusapi.models.common.base import GenerateMaskRequest
from fooocusapi.models.common.requests import (
    CommonRequest as Text2ImgRequest
)
from fooocusapi.models.common.response import (
    AsyncJobResponse,
    GeneratedImageResult
)
from fooocusapi.models.common.task import (
    GenerationFinishReason,
    ImageGenerationResult,
    AsyncJobStage,
    TaskType
)
from fooocusapi.utils.api_utils import (
    req_to_params,
    generate_async_output,
    generate_streaming_output,
    generate_image_result_output
)
from fooocusapi.models.requests_v1 import (
    ImageEnhanceRequest, ImgUpscaleOrVaryRequest,
    ImgPromptRequest,
    ImgInpaintOrOutpaintRequest
)
from fooocusapi.models.requests_v2 import (
    ImageEnhanceRequestJson, ImgInpaintOrOutpaintRequestJson,
    ImgPromptRequestJson,
    ImgUpscaleOrVaryRequestJson
)
from fooocusapi.utils.img_utils import narray_to_base64img, read_input_image
from fooocusapi.worker import worker_queue, blocking_get_task_result

from extras.inpaint_mask import generate_mask_from_image, SAMOptions


def get_task_type(req: Text2ImgRequest) -> TaskType:
    """Pick the worker TaskType that matches the concrete request class.

    Each request family has a ``requests_v1`` class and a ``requests_v2``
    ``...Json`` class; both map to the same task type. Any request that
    matches none of the families is treated as plain text-to-image.
    """
    # Checked top to bottom, first match wins; keep this order in case
    # request classes share base classes.
    type_table = (
        ((ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson), TaskType.img_uov),
        ((ImgPromptRequest, ImgPromptRequestJson), TaskType.img_prompt),
        ((ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson), TaskType.img_inpaint_outpaint),
        ((ImageEnhanceRequestJson, ImageEnhanceRequest), TaskType.img_enhance),
    )
    for request_classes, matched_type in type_table:
        if isinstance(req, request_classes):
            return matched_type
    return TaskType.text_2_img


def _queue_full_response(req: Text2ImgRequest, task_type: TaskType, streaming_output: bool):
    """Build the reply sent when the worker queue refused the task (queue full)."""
    failure_results = [
        ImageGenerationResult(
            im=None,
            seed='',
            finish_reason=GenerationFinishReason.queue_is_full
        )]

    # Streaming takes priority over async here, matching the original ordering.
    if streaming_output:
        return generate_streaming_output(failure_results)
    if req.async_process:
        return AsyncJobResponse(
            job_id='',
            job_type=task_type,
            job_stage=AsyncJobStage.error,
            job_progress=0,
            job_status=None,
            job_step_preview=None,
            job_result=[GeneratedImageResult(
                base64=None,
                url=None,
                seed='',
                finish_reason=GenerationFinishReason.queue_is_full
            )])
    return generate_image_result_output(failure_results, False)


def call_worker(req: Text2ImgRequest, accept: str) -> Response | AsyncJobResponse | List[GeneratedImageResult]:
    """Submit a generation request to the worker queue and shape the reply.

    :param req: The generation request. When ``accept`` is ``'image/png'``
        this function overwrites ``req.image_number`` with 1.
    :param accept: Requested output type (presumably the HTTP Accept header);
        ``'image/png'`` selects streaming output.
    :return: If the queue rejects the task, a queue-full reply (streamed,
        async-job or result-list form, checked in that order). Otherwise an
        async-job response when ``req.async_process`` is set (even for
        ``image/png``), else the blocking result as a stream for
        ``image/png`` or as a list of image results.
    """
    streaming_output = accept == 'image/png'
    if streaming_output:
        # image_number auto set to 1 in streaming mode
        req.image_number = 1

    task_type = get_task_type(req)
    params = req_to_params(req)
    async_task = worker_queue.add_task(task_type, params, req.webhook_url)

    if async_task is None:
        # add to worker queue failed; reuse the already-computed task_type
        return _queue_full_response(req, task_type, streaming_output)

    if req.async_process:
        # return async response directly
        return generate_async_output(async_task)

    # blocking get generation result
    results = blocking_get_task_result(async_task.job_id)

    if streaming_output:
        return generate_streaming_output(results)
    return generate_image_result_output(results, req.require_base64)


async def generate_mask(request: GenerateMaskRequest):
    """
    Produce a mask for the request's input image and return it base64-encoded.

    The work is done directly by ``generate_mask_from_image``; the worker
    queue is not involved.

    :param request: Mask request; ``mask_model`` selects the model. For
        ``'u2net_cloth_seg'``, ``cloth_category`` is passed through as an
        extra. For ``'sam'``, the dino/sam fields are packed into ``SAMOptions``.
    :return: The mask encoded by ``narray_to_base64img``.
    """
    # NOTE(review): generate_mask_from_image is a synchronous call made inside
    # an ``async def``, so it blocks the event loop while it runs — confirm
    # this is acceptable before serving concurrent traffic.
    source_image = read_input_image(request.image)
    model_name = request.mask_model

    mask_extras = {}
    sam_opts = None
    if model_name == 'sam':
        sam_opts = SAMOptions(
            dino_prompt=request.dino_prompt_text,
            dino_box_threshold=request.box_threshold,
            dino_text_threshold=request.text_threshold,
            dino_erode_or_dilate=request.dino_erode_or_dilate,
            dino_debug=request.dino_debug,
            max_detections=request.sam_max_detections,
            model_type=request.sam_model
        )
    elif model_name == 'u2net_cloth_seg':
        mask_extras['cloth_category'] = request.cloth_category

    # Only the first of the four returned values (the mask) is used.
    mask, _, _, _ = generate_mask_from_image(source_image, model_name, mask_extras, sam_opts)
    return narray_to_base64img(mask)