L4GM-demo / app.py
fffiloni's picture
Update app.py
b04e321 verified
raw
history blame
2.11 kB
import torch

# Startup diagnostics, one value per line: the torch version, the CUDA version
# reported by torch, and whether torch considers CUDA available.
print(
    torch.__version__,
    torch.version.cuda,
    torch.cuda.is_available(),
    sep="\n",
)
import os, subprocess
import uuid, tempfile
import gradio as gr
from huggingface_hub import snapshot_download
# Make sure the checkpoint folder exists, then mirror the jiawei011/L4GM
# repository from the Hugging Face Hub into it. This runs at import time.
os.makedirs("pretrained", exist_ok=True)
snapshot_download(repo_id="jiawei011/L4GM", local_dir="./pretrained")
# Directory whose files are offered as clickable examples in the UI.
examples_folder = "data_test"

# Join every directory entry onto the folder path and keep only regular files
# (sub-directories are dropped). Order is whatever os.listdir returns.
video_examples = list(
    filter(
        os.path.isfile,
        map(
            lambda entry_name: os.path.join(examples_folder, entry_name),
            os.listdir(examples_folder),
        ),
    )
)
def generate(input_video):
    """Run ``infer_3d.py`` on a video and return the path of the rendered result.

    The inference script is launched as a child process with the ``big``
    config, writing into the ``results`` workspace. Afterwards the workspace is
    scanned for ``.mp4`` files and the most recently modified one is returned,
    because the output component is a single ``gr.Video`` and the workspace may
    still contain files from earlier runs.

    Args:
        input_video: Value of the ``gr.Video`` input; it is forwarded to the
            script as ``--test_path``, e.g. ``data_test/otter-on-surfboard_fg.mp4``.

    Returns:
        Path (str) of the newest ``.mp4`` found directly inside ``results``.

    Raises:
        gr.Error: If no input video was given, the inference process exits with
            a non-zero status, or no ``.mp4`` is found in the workspace. The UI
            is launched with ``show_error=True``, so the message is surfaced to
            the user instead of being handed to the video component.
    """
    # Function-scope import: `glob` was called here but never imported by the
    # file, which raised NameError after every successful inference run.
    from glob import glob

    if not input_video:
        raise gr.Error("Please provide an input video.")

    workdir = "results"
    pretrained_model = "pretrained/recon.safetensors"
    num_frames = 1

    # Keep the try body limited to the call that can raise CalledProcessError.
    try:
        subprocess.run(
            [
                "python", "infer_3d.py", "big",
                "--workspace", workdir,
                "--resume", pretrained_model,
                "--num_frames", str(num_frames),
                "--test_path", str(input_video),
            ],
            check=True,
        )
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error during inference: {str(e)}") from e

    output_videos = glob(os.path.join(workdir, "*.mp4"))
    if not output_videos:
        raise gr.Error(f"Inference finished but no .mp4 was found in '{workdir}'.")

    # Heuristic: the file written last is taken to be this run's output.
    return max(output_videos, key=os.path.getmtime)
# Build the UI: one row with two columns. The first column holds the input
# video and the Submit button; the second holds the result video.
# NOTE(review): the nesting below is reconstructed from statement order only;
# confirm that gr.Examples is meant to sit in the second column.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                output_result = gr.Video(label="Result")
                # Examples come from video_examples and target input_video.
                gr.Examples(
                    examples = video_examples,
                    inputs = [input_video]
                )
    # Clicking Submit calls generate() with the input video; its return value
    # is sent to output_result.
    submit_btn.click(
        fn = generate,
        inputs = [input_video],
        outputs = [output_result]
    )
# Turn on the request queue and start the app, with show_api disabled and
# show_error enabled. This runs at import time; there is no __main__ guard.
demo.queue().launch(show_api=False, show_error=True)