fangxia commited on
Commit
df5cb06
1 Parent(s): 32bcd21

first release

Browse files
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import gradio as gr
4
+ from client1 import inference
5
+
6
+ title = "Dreamoving-Phantom"
7
+ img_urls=['https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/public/dashscope/show3.png']
8
+ description = f""" Gradio demo for [Dreamoving-Phantom](https://github.com/dreamoving/Phantom)
9
+
10
+ DreaMoving-Phantom is a general and automatic image enhancement and super resolution framework.
11
+
12
+ **No need to adjust parameters or select models, just run with one click.** The demo can be adapted to a variety of scenarios.
13
+
14
+ 🔥 New feature: We added text super-resolution module so that the demo can better handle text scenes. This module will still be updated iteratively.
15
+
16
+ 🧭 Instructions: Input resolution: 64<=short_side<=2160, long_side<=3840, aspect ratio<=4; best input resolution: no larger than 1080p.
17
+ """
18
+
19
+
20
+ examples=[['examples/3.png'],['examples/4.png'],['examples/5.png'],['examples/6.png'],
21
+ ['examples/7.png'],['examples/8.png'],['examples/1.png'],['examples/9.jpg'],['examples/10.png'],
22
+ ['examples/12.png'],['examples/13.png'],['examples/14.png']]
23
+ # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2)
24
+
25
+ with gr.Blocks(css="style.css") as demo:
26
+
27
+ gr.Markdown(f"<h1 style='text-align: center; font-size: 2em;'>{title}</h1>")
28
+ gr.Markdown(description, elem_id='description')
29
+
30
+ with gr.Row():
31
+ with gr.Column(scale=0.67):
32
+ input_image = gr.Image(type="pil", label="Input Image", image_mode="RGBA")
33
+ upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1, elem_id='slider')
34
+
35
+ btn = gr.Button("Run", elem_id='button_param')
36
+
37
+ with gr.Column():
38
+ output_gallery = gr.Gallery(label="Output", elem_id="output_gallery")
39
+ # output_gallery = gr.Image(type="pil", label="Output", elem_id="output_gallery", image_mode="RGBA")
40
+ output_text = gr.Textbox(label="Log", elem_id="output_text")
41
+
42
+ btn.click(
43
+ inference,
44
+ inputs=[input_image, upsample_scale],
45
+ outputs=[output_gallery, output_text]
46
+ )
47
+
48
+ with gr.Row():
49
+ with gr.Column(scale=2):
50
+ gr.Markdown('**Examples**', elem_id='example')
51
+ gr.Examples(label='', examples=examples, inputs=[input_image, upsample_scale], outputs=[output_gallery, output_text], examples_per_page=15, elem_id='examples')
52
+
53
+ with gr.Column(scale=3):
54
+ additional_text = "**Gallery**"
55
+ gr.Markdown(additional_text, elem_id='additional_text')
56
+ gr.HTML(
57
+ f"""
58
+ <div style='text-align: center;'>
59
+ <img src='{img_urls[0]}' alt='gallery' style='max-width: 100%; height: auto; margin-top: 5px;'>
60
+ </div>
61
+ """,
62
+ elem_id='html_image'
63
+ )
64
+
65
+
66
+ demo.queue(api_open=False, concurrency_count=100).launch(
67
+ server_name="0.0.0.0" if os.getenv('GRADIO_LISTEN', '') != '' else "127.0.0.1",
68
+ share=False,
69
+ root_path=f"/{os.getenv('GRADIO_PROXY_PATH')}" if os.getenv('GRADIO_PROXY_PATH') else ""
70
+ )
client1.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from oss_utils import ossService
3
+ import requests
4
+ import random
5
+ import json
6
+ import time
7
+ from diffusers.utils import load_image
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import numpy as np
10
+
11
+ # oss config
12
+ BUCKET = os.environ.get('BUCKET', '')
13
+ ENDPOINT = os.environ.get('ENDPOINT', '')
14
+ PREFIX = os.environ.get('OSS_DIR_PREFIX', '')
15
+ AK = os.environ.get('AK', '')
16
+ SK = os.environ.get('SK', '')
17
+ DASHONE_SERVICE_ID = os.environ.get('SERVICE_ID', '')
18
+ URL = os.environ.get('URL', '')
19
+ GET_URL = os.environ.get('GET_URL', '')
20
+
21
+ OUTPUT_PATH = './output'
22
+
23
+ os.makedirs(OUTPUT_PATH, exist_ok=True)
24
+ oss_service = ossService(AK, SK, ENDPOINT, BUCKET, PREFIX)
25
+
26
+ def async_request_and_query(data, request_id):
27
+ url = URL
28
+ headers = {'Content-Type': 'application/json'}
29
+
30
+ # 1.发起一个异步请求
31
+ print('Start sending request')
32
+ response = requests.post(url, headers=headers, data=json.dumps(data))
33
+ if response.status_code != requests.codes.ok:
34
+ response.raise_for_status()
35
+ response_json = json.loads(response.content.decode("utf-8"))
36
+ print('Finish sending request')
37
+
38
+ # 2.异步查询结果
39
+ is_running = True
40
+ running_print_count = 0
41
+ sign_oss_path = None
42
+
43
+ task_id = response_json['header']['task_id']
44
+ get_url = GET_URL
45
+ get_header = headers
46
+ get_data = {"header": {"request_id":request_id,"service_id":DASHONE_SERVICE_ID,"task_id":f"{task_id}"}}
47
+
48
+ print('Start querying results')
49
+ while is_running:
50
+ response2 = requests.post(get_url, headers=get_header, data=json.dumps(get_data))
51
+ if response2.status_code != requests.codes.ok:
52
+ response2.raise_for_status()
53
+ response2_json = json.loads(response2.content.decode("utf-8"))
54
+
55
+ task_status = response2_json['header']['task_status']
56
+
57
+ if task_status == 'SUCCESS':
58
+ sign_oss_path = response2_json['payload']['output']['res']
59
+ break
60
+ elif task_status in ['FAILED', 'ERROR'] or running_print_count >= 120:
61
+ raise ValueError(f'Task Failed')
62
+ else:
63
+ time.sleep(1)
64
+ running_print_count += 1
65
+ continue
66
+
67
+ print('Task succeeded!')
68
+ return sign_oss_path
69
+
70
+ # def add_transparent_watermark(pil_image, watermark_text, position, opacity, font_path, font_size):
71
+ # # 加载字体
72
+ # font = ImageFont.truetype(font_path, font_size)
73
+
74
+ # # 创建一个半透明的水印图层
75
+ # watermark_layer = Image.new("RGBA", pil_image.size)
76
+ # draw = ImageDraw.Draw(watermark_layer)
77
+
78
+ # # 文本颜色和透明度
79
+ # text_color = (255, 255, 255, opacity) # 白色文本
80
+ # outline_color = (0, 0, 0, opacity) # 黑色轮廓
81
+
82
+ # # 获取文本尺寸
83
+ # text_width = draw.textlength(watermark_text, font=font)
84
+ # text_height = text_width // 5
85
+
86
+ # # 计算水印位置
87
+ # img_width, img_height = pil_image.size
88
+ # x = img_width - text_width - position[0]
89
+ # y = img_height - text_height - position[1]
90
+
91
+ # outline_range = 1 # 轮廓的粗细
92
+ # for adj in range(-outline_range, outline_range+1):
93
+ # for ord in range(-outline_range, outline_range+1):
94
+ # if adj != 0 or ord != 0: # 避免中心位置,那是真正的文本
95
+ # draw.text((x+adj, y+ord), watermark_text, font=font, fill=outline_color)
96
+
97
+ # # 将文本绘制到水印层上
98
+ # draw.text((x, y), watermark_text, font=font, fill=text_color)
99
+
100
+ # # 将水印层叠加到原始图像上
101
+ # pil_image_with_watermark = Image.alpha_composite(pil_image.convert("RGBA"), watermark_layer)
102
+
103
+ # # 返回添加了水印的图像
104
+ # return pil_image_with_watermark
105
+
106
+ def inference(input_image, upscale):
107
+ # process alpha channel
108
+ alpha_channel = input_image.split()[-1]
109
+ input_image = input_image.convert('RGB')
110
+
111
+ local_save_path = os.path.join(OUTPUT_PATH, 'tmp.png')
112
+ local_output_save_path = os.path.join(OUTPUT_PATH, 'out.png')
113
+ input_image.save(local_save_path)
114
+
115
+ # generate image url
116
+ oss_key = os.path.join(PREFIX, 'tmp.png')
117
+ _, image_url = oss_service.uploadOssFile(oss_key, local_save_path)
118
+
119
+ # rm local file
120
+ if os.path.isfile(local_save_path):
121
+ os.remove(local_save_path)
122
+
123
+ data = {}
124
+ data_header = {}
125
+ data_payload = {}
126
+ data_input = {}
127
+ data_para = {}
128
+
129
+ data_header['request_id'] = "".join(random.sample("0123456789abcdefghijklmnopqrstuvwxyz", 10))
130
+ data_header['service_id'] = DASHONE_SERVICE_ID
131
+ data_input['image_url'] = image_url
132
+ data_para['upscale'] = upscale
133
+ data_para['platform'] = 'modelscope'
134
+
135
+ data_payload['input'] = data_input
136
+ data_payload['parameters'] = data_para
137
+
138
+ data['header'] = data_header
139
+ data['payload'] = data_payload
140
+
141
+ try:
142
+ output_url = async_request_and_query(data=data, request_id=data_header['request_id'])
143
+ download_status = oss_service.downloadFile(output_url, local_output_save_path)
144
+ if not download_status:
145
+ raise ValueError(f'Download output image failed')
146
+ except:
147
+ output_image = load_image('./error.png')
148
+ return [output_image], 'The input image format or resolution does not meet the requirements. Please change the image or resize it and try again.'
149
+
150
+
151
+ output_image = load_image(local_output_save_path)
152
+
153
+ if os.path.isfile(local_output_save_path):
154
+ os.remove(local_output_save_path)
155
+
156
+ out_width, out_height = output_image.size
157
+
158
+ # add alpha channel
159
+ output_alpha_channel = alpha_channel.resize(output_image.size, resample=Image.LANCZOS)
160
+
161
+ # merge
162
+ output_image_alpha = Image.merge("RGBA", (*output_image.split(), output_alpha_channel))
163
+
164
+ # new_width = out_width // 4
165
+ # new_height = out_height // 4
166
+
167
+ # resized_image = np.array(input_image.resize((new_width, new_height)))
168
+ # np_output_image = np.array(output_image)
169
+ # np_output_image[0:new_height, 0:new_width] = resized_image.copy()
170
+
171
+ # new_output_image = Image.fromarray(np_output_image)
172
+
173
+ # new_output_image = add_transparent_watermark(
174
+ # pil_image=new_output_image,
175
+ # watermark_text="追影-放大镜",
176
+ # position=(5, 5),
177
+ # opacity=200,
178
+ # font_path="AlibabaPuHuiTi-3-45-Light.ttf", # 例如:"Arial", "Helvetica", "Times New Roman"
179
+ # font_size= out_width // 30
180
+ # )
181
+
182
+ org_width, org_height = input_image.size
183
+ if max(org_width, org_height) > 1920 or min(org_width, org_height) > 1080:
184
+ msg = 'The input image size has exceeded the optimal range. You can consider scaling the input image for better generation effect'
185
+ else:
186
+ msg = 'Task succeeded'
187
+ return [output_image_alpha], msg
error.png ADDED
examples/1.png ADDED
examples/10.png ADDED
examples/12.png ADDED
examples/13.png ADDED
examples/14.png ADDED
examples/2.png ADDED
examples/3.png ADDED
examples/4.png ADDED
examples/5.png ADDED
examples/6.png ADDED
examples/7.png ADDED
examples/8.png ADDED
examples/9.jpg ADDED
examples/oss_utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import sys
4
+ import string
5
+ # import platform
6
+ import time
7
+ import datetime
8
+ import json
9
+ # import numpy as np
10
+ # import threading
11
+ # import cv2
12
+ # import PIL.Image as Image
13
+ # import ffmpeg
14
+ from io import BytesIO
15
+ # from queue import Queue
16
+ # import glob
17
+ import oss2
18
+ import random
19
+ import requests
20
+ import shutil
21
+ # import torch
22
+ # import ctypes
23
+
24
+ use_internal_network = False
25
+
26
+ # OSSAccessKeyId = os.getenv('OSSAccessKeyId', "")
27
+ # OSSAccessKeySecret = os.getenv('OSSAccessKeySecret', "")
28
+
29
+ def get_random_string():
30
+ now = datetime.datetime.now()
31
+ date = now.strftime('%Y%m%d')
32
+ time = now.strftime('%H%M%S')
33
+ microsecond = now.strftime('%f')
34
+ microsecond = microsecond[:6] # 取前6位,即微秒
35
+
36
+ rand_num = ''.join([str(random.randint(0, 9)) for _ in range(6)])
37
+ random_string = ''.join(random.choices(string.ascii_uppercase, k=6)) # ascii_lowercase
38
+ return date + "-" + time + "-" + microsecond + "-" + random_string
39
+
40
+ class ossService():
41
+ def __init__(self, OSSAccessKeyId, OSSAccessKeySecret, Endpoint, BucketName, ObjectName):
42
+ self.AccessKeyId = OSSAccessKeyId
43
+ self.AccessKeySecret = OSSAccessKeySecret
44
+ self.Endpoint = Endpoint
45
+ self.BucketName = BucketName # "vigen-video"
46
+ self.ObjectName = ObjectName # "VideoGeneration"
47
+ self.Prefix = "oss://" + self.BucketName
48
+
49
+ auth = oss2.Auth(self.AccessKeyId, self.AccessKeySecret)
50
+ self.bucket = oss2.Bucket(auth, self.Endpoint, self.BucketName)
51
+
52
+
53
+ # oss_url: eg: oss://BucketName/ObjectName/xxx.mp4
54
+ def sign(self, oss_url, timeout=86400):
55
+ try:
56
+ oss_path = oss_url[len("oss://" + self.BucketName + "/"):]
57
+ return 1, self.bucket.sign_url('GET', oss_path, timeout, slash_safe=True)
58
+ except Exception as e:
59
+ print("sign error: {}".format(e))
60
+ return 0, ""
61
+
62
+ def uploadOssFile(self, oss_full_path, local_full_path):
63
+ try:
64
+ self.bucket.put_object_from_file(oss_full_path, local_full_path)
65
+ return self.sign(self.Prefix+"/"+oss_full_path, timeout=86400)
66
+ except oss2.exceptions.OssError as e:
67
+ print("oss upload error: ", e)
68
+ return 0, ""
69
+
70
+ def downloadOssFile(self, oss_full_path, local_full_path):
71
+ status = 1
72
+ try:
73
+ self.bucket.get_object_to_file(oss_full_path, local_full_path)
74
+ except oss2.exceptions.OssError as e:
75
+ print("oss download error: ", e)
76
+ status = 0
77
+ return status
78
+
79
+
80
+ def downloadFile(self, file_full_url, local_full_path):
81
+ status = 1
82
+ response = requests.get(file_full_url)
83
+ if response.status_code == 200:
84
+ with open(local_full_path, "wb") as f:
85
+ f.write(response.content)
86
+ else:
87
+ print("oss download error. ")
88
+ status = 0
89
+ return status
oss_utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import sys
4
+ import string
5
+ # import platform
6
+ import time
7
+ import datetime
8
+ import json
9
+ # import numpy as np
10
+ # import threading
11
+ # import cv2
12
+ # import PIL.Image as Image
13
+ # import ffmpeg
14
+ from io import BytesIO
15
+ # from queue import Queue
16
+ # import glob
17
+ import oss2
18
+ import random
19
+ import requests
20
+ import shutil
21
+ # import torch
22
+ # import ctypes
23
+
24
+ use_internal_network = False
25
+
26
+ # OSSAccessKeyId = os.getenv('OSSAccessKeyId', "")
27
+ # OSSAccessKeySecret = os.getenv('OSSAccessKeySecret', "")
28
+
29
+ def get_random_string():
30
+ now = datetime.datetime.now()
31
+ date = now.strftime('%Y%m%d')
32
+ time = now.strftime('%H%M%S')
33
+ microsecond = now.strftime('%f')
34
+ microsecond = microsecond[:6] # 取前6位,即微秒
35
+
36
+ rand_num = ''.join([str(random.randint(0, 9)) for _ in range(6)])
37
+ random_string = ''.join(random.choices(string.ascii_uppercase, k=6)) # ascii_lowercase
38
+ return date + "-" + time + "-" + microsecond + "-" + random_string
39
+
40
+ class ossService():
41
+ def __init__(self, OSSAccessKeyId, OSSAccessKeySecret, Endpoint, BucketName, ObjectName):
42
+ self.AccessKeyId = OSSAccessKeyId
43
+ self.AccessKeySecret = OSSAccessKeySecret
44
+ self.Endpoint = Endpoint
45
+ self.BucketName = BucketName # "vigen-video"
46
+ self.ObjectName = ObjectName # "VideoGeneration"
47
+ self.Prefix = "oss://" + self.BucketName
48
+
49
+ auth = oss2.Auth(self.AccessKeyId, self.AccessKeySecret)
50
+ self.bucket = oss2.Bucket(auth, self.Endpoint, self.BucketName)
51
+
52
+
53
+ # oss_url: eg: oss://BucketName/ObjectName/xxx.mp4
54
+ def sign(self, oss_url, timeout=86400):
55
+ try:
56
+ oss_path = oss_url[len("oss://" + self.BucketName + "/"):]
57
+ return 1, self.bucket.sign_url('GET', oss_path, timeout, slash_safe=True)
58
+ except Exception as e:
59
+ print("sign error: {}".format(e))
60
+ return 0, ""
61
+
62
+ def uploadOssFile(self, oss_full_path, local_full_path):
63
+ try:
64
+ self.bucket.put_object_from_file(oss_full_path, local_full_path)
65
+ return self.sign(self.Prefix+"/"+oss_full_path, timeout=86400)
66
+ except oss2.exceptions.OssError as e:
67
+ print("oss upload error: ", e)
68
+ return 0, ""
69
+
70
+ def downloadOssFile(self, oss_full_path, local_full_path):
71
+ status = 1
72
+ try:
73
+ self.bucket.get_object_to_file(oss_full_path, local_full_path)
74
+ except oss2.exceptions.OssError as e:
75
+ print("oss download error: ", e)
76
+ status = 0
77
+ return status
78
+
79
+
80
+ def downloadFile(self, file_full_url, local_full_path):
81
+ status = 1
82
+ response = requests.get(file_full_url)
83
+ if response.status_code == 200:
84
+ with open(local_full_path, "wb") as f:
85
+ f.write(response.content)
86
+ else:
87
+ print("oss download error. ")
88
+ status = 0
89
+ return status
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ oss2
style.css ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* style.css */
2
+
3
+ #button_param {
4
+ background: #5110d3 !important;
5
+ color: white !important;
6
+ border-radius: 3px !important;
7
+ }
8
+
9
+ #feature {
10
+ font-size: 24px;
11
+ margin-bottom: 24px;
12
+ }
13
+
14
+ #description {
15
+ margin-top: 10px;
16
+ }
17
+
18
+ #example {
19
+ margin-top: 20px;
20
+ }
21
+
22
+ #additional_text {
23
+ margin-top: 20px;
24
+ }
25
+
26
+ #html_image {
27
+ margin-top: 0px;
28
+ }