File size: 4,161 Bytes
e6d2def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# main.py (API com FastAPI)

import os
import uuid
import shutil
import subprocess
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import FileResponse
from fastapi.concurrency import run_in_threadpool

# Diretório base onde o código do SeedVR está
SEEDVR_DIR = "/app/SeedVR"

app = FastAPI()

def run_inference_blocking(input_video_path: str, output_dir: str, seed: int, res_h: int, res_w: int) -> str:
    """
    Função síncrona que executa o script torchrun.
    Ela bloqueia a execução, por isso deve ser chamada em um thread separado.
    """
    # O script de inferência espera ser executado de dentro do diretório SeedVR
    # e que os caminhos de entrada/saída sejam relativos a ele.
    
    # Constrói o caminho relativo para a pasta de entrada
    input_folder_relative = os.path.relpath(os.path.dirname(input_video_path), SEEDVR_DIR)
    
    # Constrói o caminho relativo para a pasta de saída
    output_folder_relative = os.path.relpath(output_dir, SEEDVR_DIR)
    
    command = [
        "torchrun",
        "--nproc-per-node=4",
        "projects/inference_seedvr2_3b.py",
        "--video_path", input_folder_relative,
        "--output_dir", output_folder_relative,
        "--seed", str(seed),
        "--res_h", str(res_h),
        "--res_w", str(res_w),
        "--sp_size", "1",  # Mantido fixo ou pode se tornar um parâmetro
    ]
    
    try:
        print(f"Executando comando: {' '.join(command)}")
        # Executa o subprocesso a partir do diretório do SeedVR
        subprocess.run(command, cwd=SEEDVR_DIR, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        # Se o script falhar, captura o erro e o log para depuração
        print("Erro na execução do subprocesso!")
        print(f"Stdout: {e.stdout}")
        print(f"Stderr: {e.stderr}")
        raise HTTPException(status_code=500, detail=f"A inferência falhou: {e.stderr}")

    # Encontra o arquivo de saída gerado
    output_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.png'))]
    if not output_files:
        raise HTTPException(status_code=500, detail="A inferência foi concluída, mas nenhum arquivo de saída foi encontrado.")
        
    return os.path.join(output_dir, output_files[0])


@app.get("/")
async def root():
    return {"message": "API de Inferência SeedVR2 está online. Use o endpoint /infer/ para processar vídeos."}


@app.post("/infer/", response_class=FileResponse)
async def create_inference_job(
    video: UploadFile = File(...),
    seed: int = Form(666),
    res_h: int = Form(720),
    res_w: int = Form(1280),
):
    """
    Recebe um vídeo e parâmetros, executa a inferência e retorna o vídeo processado.
    """
    # Cria diretórios temporários únicos para esta requisição para evitar conflitos
    job_id = str(uuid.uuid4())
    input_dir = os.path.join("/app", "temp_inputs", job_id)
    output_dir = os.path.join("/app", "temp_outputs", job_id)
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)
    
    input_video_path = os.path.join(input_dir, video.filename)
    
    try:
        # Salva o vídeo enviado para o disco
        with open(input_video_path, "wb") as buffer:
            shutil.copyfileobj(video.file, buffer)
        
        # Executa a função de inferência pesada em um thread separado
        # para não bloquear o servidor da API
        result_path = await run_in_threadpool(
            run_inference_blocking, 
            input_video_path=input_video_path, 
            output_dir=output_dir, 
            seed=seed,
            res_h=res_h,
            res_w=res_w
        )
        
        # Retorna o arquivo de vídeo como uma resposta para download
        return FileResponse(path=result_path, media_type='video/mp4', filename=os.path.basename(result_path))
        
    finally:
        # Limpa os diretórios temporários após a conclusão ou falha
        print("Limpando diretórios temporários...")
        shutil.rmtree(input_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)