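"""Gradio app for the PSHuman multi-view generation pipeline.

On startup it fetches the SMPL-related assets from a OneDrive share and the
PSHuman_Unclip_768_6views checkpoint from the Hugging Face Hub, then exposes a
simple UI: the uploaded image is staged in a temporary folder (the rembg
background-removal step is currently disabled) and inference.py is invoked to
render the multi-view outputs shown in the gallery.
"""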
import torch
import os
import tarfile
import requests
import shutil
import tempfile
import gradio as gr
from PIL import Image
from rembg import remove
import sys
import subprocess
from glob import glob
from huggingface_hub import snapshot_download
# Define the URL and destination paths
onedrive_url = "https://hkustconnect-my.sharepoint.com/:u:/g/personal/plibp_connect_ust_hk/EZQphP-2y5BGhEIe8jb03i4BIcqiJ2mUW2JmGC5s0VKOdw?e=qVzBBD"
destination_tar = "smpl_related.tar.gz"
destination_folder = "smpl_related"
# Download the file
def download_file(url, destination):
    print(f"Downloading {url} to {destination}...")
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        # Stream the payload to disk in chunks so the large archive does not
        # have to be buffered in memory.
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Downloaded file to {destination}")
    else:
        raise Exception(f"Failed to download file. Status code: {response.status_code}")
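# Optional guard (not part of the original script): share links sometimes serve
# an HTML page rather than the archive itself, and extraction then fails with a
# confusing error. Checking the downloaded file with tarfile.is_tarfile()
# before extracting makes that failure mode explicit.
def is_valid_archive(path):
    """Return True if `path` looks like a tar archive the tarfile module can read."""
    return tarfile.is_tarfile(path)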
# Extract the tar.gz file
def extract_tar(file_path, extract_to):
    print(f"Extracting {file_path} to {extract_to}...")
    with tarfile.open(file_path, "r:gz") as tar:
        tar.extractall(path=extract_to)
    print("Extraction completed.")
# Ensure the folder exists
if not os.path.exists(destination_folder):
    try:
        # Step 1: Download the tar.gz file
        download_file(onedrive_url, destination_tar)
        # Step 2: Extract the tar.gz file
        extract_tar(destination_tar, "./")
        # Step 3: Clean up the tar.gz file after extraction
        os.remove(destination_tar)
        print(f"Cleaned up the tar file: {destination_tar}")
    except Exception as e:
        print(f"An error occurred: {e}")
else:
    print(f"Folder {destination_folder} already exists. Skipping download and extraction.")
# Download models
os.makedirs("ckpts", exist_ok=True)
snapshot_download(
    repo_id="pengHTYX/PSHuman_Unclip_768_6views",
    local_dir="./ckpts"
)
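# snapshot_download mirrors the PSHuman_Unclip_768_6views repo into ./ckpts,
# which is the same directory run_inference() passes as
# pretrained_model_name_or_path, so the two paths must stay in sync.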
def remove_background(input_url):
    # Create a temporary folder for staged and processed images
    temp_dir = tempfile.mkdtemp()
    # Load the input image (Gradio passes a local file path) and stage it in temp_dir
    image_path = os.path.join(temp_dir, 'input_image.png')
    try:
        image = Image.open(input_url).convert("RGBA")
        image.save(image_path)
    except Exception as e:
        shutil.rmtree(temp_dir)
        return f"Error loading or saving the image: {str(e)}"
    # The rembg background-removal step below is currently disabled
    """
    # Run background removal
    try:
        removed_bg_path = os.path.join(temp_dir, 'output_image_rmbg.png')
        img = Image.open(image_path)
        result = remove(img)
        result.save(removed_bg_path)
    except Exception as e:
        shutil.rmtree(temp_dir)
        return f"Error removing background: {str(e)}"
    return removed_bg_path, temp_dir
    """
    return image_path, temp_dir
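# To re-enable background removal, uncomment the disabled block above and
# return removed_bg_path instead of image_path; the rembg import at the top of
# the file is already in place.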
def run_inference(temp_dir):
    # Define the inference configuration
    inference_config = "configs/inference-768-6view.yaml"
    pretrained_model = "./ckpts"
    crop_size = 740
    seed = 600
    num_views = 7
    save_mode = "rgb"
    try:
        # Run the inference command
        subprocess.run(
            [
                "python", "inference.py",
                "--config", inference_config,
                f"pretrained_model_name_or_path={pretrained_model}",
                f"validation_dataset.crop_size={crop_size}",
                "with_smpl=false",
                f"validation_dataset.root_dir={temp_dir}",
                f"seed={seed}",
                f"num_views={num_views}",
                f"save_mode={save_mode}"
            ],
            check=True
        )
        # Collect the output images
        output_images = glob(os.path.join(temp_dir, "*.png"))
        return output_images
    except subprocess.CalledProcessError as e:
        return f"Error during inference: {str(e)}"
def process_image(input_url):
    # Remove background
    result = remove_background(input_url)
    if isinstance(result, str) and result.startswith("Error"):
        raise gr.Error(f"{result}")  # Surface the error message if something went wrong
    removed_bg_path, temp_dir = result  # Unpack only if successful
    # Run inference
    output_images = run_inference(temp_dir)
    if isinstance(output_images, str) and output_images.startswith("Error"):
        shutil.rmtree(temp_dir)
        raise gr.Error(f"{output_images}")  # Surface the error message if inference failed
    # Prepare outputs for display
    results = []
    for img_path in output_images:
        results.append((img_path, img_path))
    #shutil.rmtree(temp_dir)  # Cleanup temporary folder
    return results
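# The temp_dir cleanup above is left commented out, presumably because the
# Gallery reads the returned file paths after process_image() returns;
# deleting the folder immediately would break the displayed images.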
def gradio_interface():
    with gr.Blocks() as app:
        gr.Markdown("# Background Removal and Inference Pipeline")
        with gr.Row():
            input_image = gr.Image(label="Image input", type="filepath")
            submit_button = gr.Button("Process")
        output_gallery = gr.Gallery(label="Output Images")
        submit_button.click(process_image, inputs=[input_image], outputs=[output_gallery])
    return app
# Launch the Gradio app
app = gradio_interface()
app.launch()
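# Optional tweak (not in the original): since each request shells out to a long
# inference run, enabling Gradio's request queue can avoid HTTP timeouts, e.g.:
# app.queue().launch()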