# Hugging Face Spaces status at time of capture: Runtime error
import os | |
import time | |
import logging | |
import requests | |
# Configure the root logger once so INFO-level progress messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-wide logger: this is the unnamed root logger configured above.
logger = logging.getLogger()
# Hedra API key taken from the environment at import time; None when unset.
api_key = os.environ.get("HEDRA_API_KEY")
class Session(requests.Session):
    """A ``requests.Session`` preconfigured for the Hedra public web-app API.

    Every request carries the ``x-api-key`` header. Relative request paths
    such as ``"/models"`` are resolved against :attr:`base_url`. Absolute
    ``http(s)://`` URLs are sent unchanged.
    """

    DEFAULT_BASE_URL = "https://api.hedra.com/web-app/public"

    def __init__(self, api_key: str, base_url: str = DEFAULT_BASE_URL):
        """Create a session authenticated with *api_key*.

        Args:
            api_key: Hedra API key, sent as the ``x-api-key`` header.
            base_url: API root that relative request paths are appended to.
                A trailing slash is stripped so ``"/path"`` joins cleanly.
        """
        super().__init__()
        self.base_url: str = base_url.rstrip("/")
        self.headers["x-api-key"] = api_key

    # Overrides requests.Session.prepare_request, which every
    # session.get/post/... call routes through before sending.
    def prepare_request(self, request: requests.Request) -> requests.PreparedRequest:
        """Resolve relative URLs against ``base_url``, then prepare normally.

        The URL is only prefixed when it is not already absolute. Previously
        an absolute URL (e.g. a presigned download link) was turned into an
        invalid ``<base_url>https://...`` string.
        """
        url = request.url
        if not url.lower().startswith(("http://", "https://")):
            request.url = f"{self.base_url}{url}"
        return super().prepare_request(request)
def _create_and_upload_asset(session, path, asset_type):
    """Register a local file as a Hedra asset and upload its bytes.

    Args:
        session: Authenticated ``Session`` whose relative paths resolve
            against the Hedra API base URL.
        path: Local filesystem path of the file to upload.
        asset_type: Asset type sent to ``POST /assets`` (this module uses
            ``"image"`` and ``"audio"``).

    Returns:
        The ``id`` field of the ``POST /assets`` response.

    Raises:
        requests.HTTPError: If creating or uploading the asset fails.
    """
    response = session.post(
        "/assets",
        json={"name": os.path.basename(path), "type": asset_type},
    )
    if not response.ok:
        # Log the raw body: an error response is not guaranteed to be JSON.
        logger.error(
            "error creating %s: %d %s",
            asset_type,
            response.status_code,
            response.text,
        )
        # Stop here instead of indexing into an error payload below.
        response.raise_for_status()
    asset_id = response.json()["id"]
    with open(path, "rb") as f:
        session.post(f"/assets/{asset_id}/upload", files={"file": f}).raise_for_status()
    logger.info("uploaded %s %s", asset_type, asset_id)
    return asset_id


def _wait_for_generation(session, generation_id, poll_interval, timeout):
    """Poll a generation's status until it reports "complete" or "error".

    Args:
        session: Authenticated ``Session`` for the Hedra API.
        generation_id: ``id`` returned by ``POST /generations``.
        poll_interval: Seconds to sleep between status requests.
        timeout: Maximum seconds to keep polling, or None to poll forever.

    Returns:
        The final status payload (the decoded JSON of the last status call).

    Raises:
        requests.HTTPError: If a status request returns an error code.
        TimeoutError: If *timeout* elapses before a terminal status.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        response = session.get(f"/generations/{generation_id}/status")
        response.raise_for_status()
        status_response = response.json()
        logger.info("status response %s", status_response)
        if status_response["status"] in ("complete", "error"):
            return status_response
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"generation {generation_id} did not finish within {timeout} seconds"
            )
        time.sleep(poll_interval)


def _download_file(url, destination):
    """Stream *url* to the local file *destination*.

    Returns:
        True on success. False if the HTTP request or the file write failed;
        the failure is logged. A partially written file may remain on disk.
    """
    try:
        # Plain requests.get rather than the API session: the URL is assumed
        # to be a presigned link that must not get the x-api-key header
        # or the base-URL prefix.
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(destination, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
    # RequestException subclasses OSError, so it must be caught first.
    except requests.exceptions.RequestException as e:
        logger.error("Failed to download video: %s", e)
        return False
    except OSError as e:
        logger.error("Failed to save video file: %s", e)
        return False
    logger.info("Successfully downloaded video to %s", destination)
    return True


def generate_video(audio, image, aspect_ratio, resolution, text_prompt, seed,
                   *, poll_interval=5, timeout=None):
    """Generate a Hedra video from an image and an audio file, then download it.

    The function:

    1. uploads both files as assets;
    2. requests a video generation with the first model listed by ``/models``;
    3. polls until the generation finishes;
    4. saves the result as ``<asset_id or generation_id>.mp4`` in the current
       working directory.

    Args:
        audio: Local path of the audio file to upload.
        image: Local path of the image used as the start keyframe.
        aspect_ratio: Forwarded as ``generated_video_inputs.aspect_ratio``.
        resolution: Forwarded as ``generated_video_inputs.resolution``.
        text_prompt: Forwarded as ``generated_video_inputs.text_prompt``.
        seed: Optional seed; included in the request only when not None.
        poll_interval: Seconds between status polls (keyword-only).
        timeout: Maximum seconds to wait for the generation, or None to wait
            indefinitely (keyword-only).

    Returns:
        The downloaded filename. Returns None when the API key is missing,
        the generation errors, no download URL is returned, or the download
        fails.

    Raises:
        requests.HTTPError: If any Hedra API call returns an error status.
        TimeoutError: If *timeout* elapses before the generation finishes.
    """
    # api_key is the module-level value read from HEDRA_API_KEY at import
    # time; nothing in this function loads a .env file.
    if not api_key:
        print("HEDRA_API_KEY not found in environment variables or .env file.")
        return None

    # Context manager closes the session's pooled connections when done.
    with Session(api_key=api_key) as session:
        logger.info("testing against %s", session.base_url)
        models_response = session.get("/models")
        models_response.raise_for_status()
        model_id = models_response.json()[0]["id"]
        logger.info("got model id %s", model_id)

        image_id = _create_and_upload_asset(session, image, "image")
        audio_id = _create_and_upload_asset(session, audio, "audio")

        generated_video_inputs = {
            "text_prompt": text_prompt,
            "resolution": resolution,
            "aspect_ratio": aspect_ratio,
        }
        if seed is not None:
            generated_video_inputs["seed"] = seed
        generation_request_data = {
            "type": "video",
            "ai_model_id": model_id,
            "start_keyframe_id": image_id,
            "audio_id": audio_id,
            "generated_video_inputs": generated_video_inputs,
        }
        generation_response = session.post("/generations", json=generation_request_data)
        generation_response.raise_for_status()
        generation = generation_response.json()
        logger.info("generation response %s", generation)
        generation_id = generation["id"]

        status_response = _wait_for_generation(
            session, generation_id, poll_interval, timeout
        )

    status = status_response["status"]
    download_url = status_response.get("url")
    if status == "complete" and download_url:
        # `or` (not a .get default) also covers an explicit null asset_id,
        # which would otherwise produce "None.mp4".
        output_filename = f"{status_response.get('asset_id') or generation_id}.mp4"
        logger.info(
            "Generation complete. Downloading video from %s to %s",
            download_url,
            output_filename,
        )
        # Only report the file when it was fully downloaded.
        return output_filename if _download_file(download_url, output_filename) else None
    if status == "error":
        logger.error(
            "Video generation failed: %s",
            status_response.get("error_message", "Unknown error"),
        )
    else:
        # Reached when the status is "complete" but no URL was returned.
        logger.warning(
            "Video generation finished with status '%s' but no download URL was found.",
            status,
        )
    return None