import json
import os
import re

import cv2
import dashscope
from dashscope import ImageSynthesis
from modelscope.utils.constant import Tasks

from ..output_wrapper import ImageWrapper
from .pipeline_tool import ModelscopePipelineTool


class TextToImageTool(ModelscopePipelineTool):
    default_model = 'AI-ModelScope/stable-diffusion-xl-base-1.0'
    # Description (zh): AI painting (image generation) service; given a text
    # description and an image resolution, returns the URL of an image drawn
    # from the text.
    description = 'AI绘画(图像生成)服务,输入文本描述和图像分辨率,返回根据文本信息绘制的图片URL。'
    name = 'image_gen'
    parameters: list = [{
        'name': 'text',
        # Description (zh): detailed description of the desired image content,
        # e.g. characters, environment, actions and other details.
        'description': '详细描述了希望生成的图像具有什么内容,例如人物、环境、动作等细节描述',
        'required': True,
        'schema': {
            'type': 'string'
        }
    }, {
        'name': 'resolution',
        # Description (zh): a string of the form number*number giving the
        # resolution of the generated image; one of
        # [1024*1024, 720*1280, 1280*720].
        'description':
        '格式是 数字*数字,表示希望生成的图像的分辨率大小,选项有[1024*1024, 720*1280, 1280*720]',
        'required': True,
        'schema': {
            'type': 'string'
        }
    }]
    model_revision = 'v1.0.0'
    task = Tasks.text_to_image_synthesis

    # def _remote_parse_input(self, *args, **kwargs):
    #     params = {
    #         'input': {
    #             'text': kwargs['text'],
    #             'resolution': kwargs['resolution']
    #         }
    #     }
    #     if kwargs.get('seed', None):
    #         params['input']['seed'] = kwargs['seed']
    #     return params

    def _remote_call(self, *args, **kwargs):
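        """Generate an image via the DashScope wanx text-to-image API.

        Falls back to 1280*720 when the requested resolution is not one of
        the supported options.
        """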

        if ('resolution' in kwargs) and (kwargs['resolution'] in [
                '1024*1024', '720*1280', '1280*720'
        ]):
            resolution = kwargs['resolution']
        else:
            resolution = '1280*720'

        prompt = kwargs.get('text')
        seed = kwargs.get('seed', None)
        if prompt is None:
            return None
        dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')
        response = ImageSynthesis.call(
            model=ImageSynthesis.Models.wanx_v1,
            prompt=prompt,
            n=1,
            size=resolution,
            steps=10,
            seed=seed)
        final_result = self._parse_output(response, remote=True)
        return final_result

    def _local_parse_input(self, *args, **kwargs):
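        """Keep only the 'text' argument for the local ModelScope pipeline."""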

        text = kwargs.pop('text', '')

        parsed_args = ({'text': text}, )

        return parsed_args, {}

    def _parse_output(self, origin_result, remote=True):
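        """Wrap the generated image in an ImageWrapper.

        The local pipeline returns a BGR ndarray (converted to RGB here);
        the remote DashScope call returns a URL to the generated image.
        """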
        if not remote:
            image = cv2.cvtColor(origin_result['output_imgs'][0],
                                 cv2.COLOR_BGR2RGB)
        else:
            image = origin_result.output['results'][0]['url']

        return {'result': ImageWrapper(image)}

    def _handle_input_fallback(self, **kwargs):
        """
        an alternative method is to parse image is that get item between { and }
        for last try

        :param fallback_text:
        :return: language, cocde
        """

        text = kwargs.get('text', None)
        fallback = kwargs.get('fallback', None)

        if text:
            return text
        elif fallback:
            text = fallback
            try:
                # Grab everything between the outermost braces and re-parse it
                # as JSON to recover the 'text' field.
                json_block = re.search(r'\{([\s\S]+)\}', text)  # noqa W605
                if json_block:
                    result_json = json.loads('{' + json_block.group(1) + '}')
                    return result_json['text']
            except (ValueError, KeyError):
                pass
            return text
        else:
            return text
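

# Usage sketch (illustrative only, not part of the tool): exercises the remote
# path defined above directly; assumes DASHSCOPE_API_KEY is set in the
# environment and the surrounding modelscope-agent package is importable. In
# practice the tool is normally invoked through its base-class interface.
#
#     tool = TextToImageTool()
#     out = tool._remote_call(text='a corgi running in the snow',
#                             resolution='1024*1024')
#     image = out['result']  # ImageWrapper around the generated image URL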