File size: 4,467 Bytes
e1079c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import time
import logging
import requests

# Logger named after this module rather than the root logger, so its records
# can be filtered or configured independently of other libraries' logging.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Hedra API key read once at import time; generate_video treats a missing or
# empty value as "not configured".
api_key = os.getenv("HEDRA_API_KEY")

class Session(requests.Session):
    """``requests.Session`` preconfigured for the Hedra public web-app API.

    Relative request URLs (e.g. ``"/models"``) are resolved against
    ``base_url``, and every request carries the ``x-api-key`` header.

    Args:
        api_key: Hedra API key, sent as the ``x-api-key`` header.
        base_url: API root that relative paths are appended to. Defaults to
            the public Hedra web-app endpoint; a trailing ``/`` is stripped.
    """

    DEFAULT_BASE_URL = "https://api.hedra.com/web-app/public"

    def __init__(self, api_key: str, base_url: str = DEFAULT_BASE_URL):
        super().__init__()

        # Strip a trailing slash so "<base>/" + "/models" doesn't yield "//".
        self.base_url: str = base_url.rstrip("/")
        self.headers["x-api-key"] = api_key

    def prepare_request(self, request: requests.Request) -> requests.PreparedRequest:
        """Prefix relative URLs with ``base_url``, then prepare as usual.

        URLs that are already absolute (``http://`` / ``https://``) are left
        as-is. This makes preparing the same ``Request`` twice safe: the
        previous version prepended the base URL again on every call. Note that
        absolute URLs still receive the session's ``x-api-key`` header.
        """
        url = request.url
        if isinstance(url, str) and not url.lower().startswith(("http://", "https://")):
            request.url = f"{self.base_url}{url}"
        return super().prepare_request(request)


def _create_and_upload_asset(session, path, asset_type):
    """Register ``path`` as a Hedra asset of ``asset_type`` and upload its bytes.

    Args:
        session: Authenticated ``Session`` used for the API calls.
        path: Local file to upload; its basename becomes the asset name.
        asset_type: Asset type sent to ``/assets`` (``"image"`` or ``"audio"`` here).

    Returns:
        The ``id`` of the created asset, as returned by the API.

    Raises:
        requests.HTTPError: If creating or uploading the asset fails.
        OSError: If ``path`` cannot be opened.
    """
    create_response = session.post(
        "/assets",
        json={"name": os.path.basename(path), "type": asset_type},
    )
    if not create_response.ok:
        # Log the body for diagnosis, then stop. Without a created asset there
        # is no "id" to upload to. ``.text`` is used because an error body is
        # not guaranteed to be JSON.
        logger.error(
            "error creating %s: %d %s",
            asset_type,
            create_response.status_code,
            create_response.text,
        )
        create_response.raise_for_status()
    asset_id = create_response.json()["id"]
    with open(path, "rb") as f:
        session.post(f"/assets/{asset_id}/upload", files={"file": f}).raise_for_status()
    logger.info("uploaded %s %s", asset_type, asset_id)
    return asset_id


def _wait_for_generation(session, generation_id, poll_interval=5.0, timeout=None):
    """Poll ``/generations/{id}/status`` until it reports ``complete`` or ``error``.

    Args:
        session: Authenticated ``Session``.
        generation_id: Id returned by ``POST /generations``.
        poll_interval: Seconds to sleep between polls.
        timeout: Maximum seconds to keep polling; ``None`` polls indefinitely.

    Returns:
        The last status payload (a dict with a ``"status"`` key). If
        ``timeout`` expires first, its status is neither complete nor error.

    Raises:
        requests.HTTPError: If a status request fails.
        KeyError: If a payload has no ``"status"`` key.
    """
    # monotonic() is unaffected by wall-clock adjustments during long waits.
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        response = session.get(f"/generations/{generation_id}/status")
        response.raise_for_status()
        status_response = response.json()
        logger.info("status response %s", status_response)
        if status_response["status"] in ("complete", "error"):
            return status_response
        if deadline is not None and time.monotonic() >= deadline:
            logger.warning(
                "Gave up waiting for generation %s after %s seconds (last status '%s').",
                generation_id,
                timeout,
                status_response["status"],
            )
            return status_response
        time.sleep(poll_interval)


def _download_file(url, output_filename):
    """Stream ``url`` into ``output_filename``. Return True on success, else False.

    Failures are logged rather than raised. If the output file had already been
    opened for writing when the failure happened, the partial file is deleted so
    a truncated video is never left behind.
    """
    started_writing = False
    try:
        # Use plain requests.get rather than the API session. The URL is
        # presumably an absolute presigned link that needs neither the base-URL
        # prefix nor the API key header.
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(output_filename, "wb") as f:
                started_writing = True
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
    except requests.exceptions.RequestException as e:
        logger.error("Failed to download video: %s", e)
    except OSError as e:
        logger.error("Failed to save video file: %s", e)
    else:
        logger.info("Successfully downloaded video to %s", output_filename)
        return True
    if started_writing:
        try:
            os.remove(output_filename)
        except OSError:
            pass  # best effort; the original failure has already been logged
    return False


def generate_video(
    audio,
    image,
    aspect_ratio,
    resolution,
    text_prompt,
    seed,
    poll_interval=5.0,
    timeout=None,
):
    """Generate a video from an image and an audio file via the Hedra API.

    Uploads ``image`` and ``audio`` as assets and requests a ``"video``"
    generation with the first model listed by ``/models``, using the image as
    the start keyframe. It then polls until the generation finishes and
    downloads the result to ``<asset_id or generation_id>.mp4`` in the current
    working directory.

    Args:
        audio: Path to the local audio file.
        image: Path to the local image file (used as ``start_keyframe_id``).
        aspect_ratio: Forwarded as ``generated_video_inputs["aspect_ratio"]``.
        resolution: Forwarded as ``generated_video_inputs["resolution"]``.
        text_prompt: Forwarded as ``generated_video_inputs["text_prompt"]``.
        seed: Forwarded as ``generated_video_inputs["seed"]`` unless ``None``.
        poll_interval: Seconds between status polls (default 5).
        timeout: Maximum seconds to wait for the generation; ``None`` waits
            indefinitely.

    Returns:
        The output filename on success. Returns ``None`` in any of these cases:
        the API key is missing, the generation reports an error or times out,
        no download URL is returned, or the download fails.

    Raises:
        requests.HTTPError: If an API call (models, assets, generation, status)
            fails.
        RuntimeError: If the API lists no models.
        OSError: If an input file cannot be read.
    """
    # Re-check the environment at call time in case HEDRA_API_KEY was set
    # (e.g. loaded from a .env file) after this module was imported.
    key = api_key or os.getenv("HEDRA_API_KEY")
    if not key:
        logger.error("HEDRA_API_KEY not found in environment variables or .env file.")
        return None

    # The context manager closes the session's pooled connections on exit.
    with Session(api_key=key) as session:
        logger.info("testing against %s", session.base_url)
        models_response = session.get("/models")
        models_response.raise_for_status()
        models = models_response.json()
        if not models:
            raise RuntimeError("Hedra API returned no models to generate with.")
        model_id = models[0]["id"]
        logger.info("got model id %s", model_id)

        image_id = _create_and_upload_asset(session, image, "image")
        audio_id = _create_and_upload_asset(session, audio, "audio")

        generated_video_inputs = {
            "text_prompt": text_prompt,
            "resolution": resolution,
            "aspect_ratio": aspect_ratio,
        }
        # The seed is optional; only send it when the caller supplied one.
        if seed is not None:
            generated_video_inputs["seed"] = seed
        generation_request_data = {
            "type": "video",
            "ai_model_id": model_id,
            "start_keyframe_id": image_id,
            "audio_id": audio_id,
            "generated_video_inputs": generated_video_inputs,
        }

        generation_response = session.post("/generations", json=generation_request_data)
        if not generation_response.ok:
            logger.error(
                "error creating generation: %d %s",
                generation_response.status_code,
                generation_response.text,
            )
            generation_response.raise_for_status()
        generation = generation_response.json()
        logger.info("generation response %s", generation)
        generation_id = generation["id"]

        status_response = _wait_for_generation(
            session, generation_id, poll_interval=poll_interval, timeout=timeout
        )

    status = status_response["status"]
    if status == "error":
        logger.error(
            "Video generation failed: %s",
            status_response.get("error_message", "Unknown error"),
        )
        return None
    if status != "complete":
        # Only reachable when polling timed out; _wait_for_generation already
        # logged a warning.
        return None

    download_url = status_response.get("url")
    if not download_url:
        logger.warning(
            "Video generation finished with status '%s' but no download URL was found.",
            status,
        )
        return None

    # Prefer the asset id for the filename. Fall back to the generation id when
    # the payload has no asset_id, or a null/empty one.
    output_filename = f"{status_response.get('asset_id') or generation_id}.mp4"
    logger.info(
        "Generation complete. Downloading video from %s to %s",
        download_url,
        output_filename,
    )
    if not _download_file(download_url, output_filename):
        return None
    return output_filename