UniPortrait / src /util.py
Junjie96's picture
Upload 38 files
dbac7c5 verified
raw
history blame
1.71 kB
import concurrent.futures
import io
import os
import oss2
import requests
from PIL import Image
from .log import logger
# OSS client (Aliyun `oss2` SDK), configured once at import time from
# environment variables. os.getenv yields None for any variable that is unset.
access_key_id = os.getenv("ACCESS_KEY_ID")
access_key_secret = os.getenv("ACCESS_KEY_SECRET")
bucket_name = os.getenv("BUCKET_NAME")
endpoint = os.getenv("ENDPOINT")
# Object-key prefix for uploads; object keys are built as oss_path + "/" + name.
oss_path = os.getenv("OSS_PATH")

bucket = oss2.Bucket(
    oss2.Auth(access_key_id, access_key_secret),
    endpoint,
    bucket_name,
)
def download_img_pil(index, img_url, timeout=30):
    """Download one image over HTTP and open it with PIL.

    Args:
        index: Caller-supplied position tag, returned unchanged so results that
            arrive out of order (e.g. from a thread pool) can be slotted back.
        img_url: URL to fetch.
        timeout: Connect/read timeout in seconds passed to ``requests.get`` so a
            stalled server cannot block the calling thread forever. ``None``
            disables the timeout.

    Returns:
        ``(index, image)`` on HTTP 200, where ``image`` is the object returned by
        ``PIL.Image.open`` (pixel data is decoded lazily). On any other status
        the failure is logged and ``(index, None)`` is returned, so callers can
        always unpack the result as a pair.

    Raises:
        Network errors from ``requests`` and decoding errors from PIL propagate.
    """
    # The context manager releases the connection in every case, including the
    # non-200 path where the streamed body is never read.
    with requests.get(img_url, stream=True, timeout=timeout) as r:
        if r.status_code != 200:
            logger.error(f"Fail to download: {img_url}")
            return index, None
        data = r.content
    return index, Image.open(io.BytesIO(data))
def download_images(img_urls, batch_size, max_workers=4):
    """Download several images concurrently, preserving input order.

    Args:
        img_urls: Iterable of image URLs.
        batch_size: Minimum length of the returned list; slots with no URL or
            whose download failed hold ``None``. If more URLs than
            ``batch_size`` are supplied, the list grows to fit them instead of
            raising ``IndexError``.
        max_workers: Size of the thread pool (downloads are I/O-bound).

    Returns:
        List whose position ``i`` holds the PIL image for the ``i``-th URL, or
        ``None`` if that download failed or there was no URL for the slot.

    Raises:
        Any exception raised inside ``download_img_pil`` is re-raised here via
        ``future.result()``.
    """
    urls = list(img_urls)
    imgs_pil = [None] * max(batch_size, len(urls))
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(download_img_pil, i, url) for i, url in enumerate(urls)
        ]
        for future in concurrent.futures.as_completed(futures):
            ret = future.result()
            # A failed download may come back as a bare None rather than an
            # (index, image) pair; the slot simply stays None in that case.
            if ret is None:
                continue
            index, img_pil = ret
            imgs_pil[index] = img_pil
    return imgs_pil
def upload_np_2_oss(input_image, name="cache.png", expires=60 * 60 * 24):
    """Encode an image array, upload it to OSS and return a signed GET URL.

    Args:
        input_image: Array accepted by ``PIL.Image.fromarray`` (presumably a
            uint8 HxW or HxWxC NumPy array — TODO confirm against callers).
        name: Object file name placed under ``oss_path``. Its extension selects
            the encoding: ``.png`` -> PNG, ``.jpg``/``.jpeg`` -> JPEG at
            quality 95. NOTE: the default name is shared, so concurrent calls
            relying on it overwrite each other's object.
        expires: Lifetime of the signed URL in seconds (default 24 hours).

    Returns:
        The signed URL returned by ``bucket.sign_url``.

    Raises:
        ValueError: if ``name`` does not end in ``.png``, ``.jpg`` or ``.jpeg``.
        TypeError: if ``oss_path`` is ``None`` (``OSS_PATH`` unset), since it
            is concatenated with ``"/"``.
    """
    lowered = name.lower()
    if lowered.endswith(".png"):
        fmt, save_kwargs = "PNG", {}
    elif lowered.endswith((".jpg", ".jpeg")):
        fmt, save_kwargs = "JPEG", {"quality": 95}
    else:
        # Raised explicitly: an assert would be stripped under `python -O`.
        raise ValueError(f"unsupported image extension (expected .png/.jpg/.jpeg): {name}")

    img = Image.fromarray(input_image)
    if fmt == "JPEG" and img.mode in ("RGBA", "LA"):
        # JPEG cannot store an alpha band and PIL raises OSError for these
        # modes; drop alpha ("RGBA" -> "RGB", "LA" -> "L") without compositing.
        img = img.convert(img.mode[:-1])

    buf = io.BytesIO()
    img.save(buf, format=fmt, **save_kwargs)

    key = oss_path + "/" + name
    bucket.put_object(key, buf.getvalue())
    return bucket.sign_url('GET', key, expires)