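"""Gradio demo: text-to-image with FLUX.1 [dev] via Hugging Face Inference
Providers (fal-ai), then image-to-video with Kling v2.1 via the fal.ai API
directly. Launched with mcp_server=True, the app's functions are also
exposed as MCP tools.
"""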

import os
import tempfile

import fal_client
import gradio as gr
import numpy as np
import requests
from dotenv import load_dotenv
from huggingface_hub import InferenceClient

load_dotenv()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Credentials are resolved at runtime: TOKEN from the HF OAuth flow or the
# HF_TOKEN environment variable, FAL_KEY from the UI textbox or the FAL_KEY
# environment variable.
TOKEN = None
FAL_KEY = None


def download_locally(url: str, local_path: str = "downloaded_file.png") -> str:
    """Download an image or a video from a URL to a local path.

    Args:
        url (str): The URL of the file to download. Must be an http(s) URL.
        local_path (str, optional): The path (including filename) where the
            file should be saved. Defaults to "downloaded_file.png".

    Returns:
        str: The filesystem path of the saved file, suitable for returning to
            a gr.File output or as an MCP tool response.
    """
    if local_path == "":
        local_path = "downloaded_file.png"
    response = requests.get(url, timeout=30)
    response.raise_for_status()

    # Bare filenames are written to the system temp dir so the app does not
    # depend on a writable working directory.
    if os.path.dirname(local_path) == "":
        tmp_dir = tempfile.gettempdir()
        local_path = os.path.join(tmp_dir, local_path)
    with open(local_path, "wb") as f:
        f.write(response.content)
    return local_path
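
# Illustrative call (assumes the URL is reachable):
#   download_locally("https://example.com/cat.png", "cat.png")
# returns "/tmp/cat.png" on most Linux systems.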


def login_hf(oauth_token: gr.OAuthToken | None):
    """Login to Hugging Face and store the token for image generation.

    Args:
        oauth_token (gr.OAuthToken | None): The OAuth token from Hugging Face.
    """
    global TOKEN
    if oauth_token and oauth_token.token:
        print("Received OAuth token, logging in...")
        TOKEN = oauth_token.token
    else:
        print("No OAuth token provided, using environment variable HF_TOKEN.")
        TOKEN = os.environ.get("HF_TOKEN")
    # Avoid echoing the secret itself.
    print("HF token configured:", bool(TOKEN))


def login_fal(fal_key_from_ui: str | None = None):
    """Set the FAL API key from the UI textbox, falling back to the environment.

    Args:
        fal_key_from_ui (str | None): The FAL key from the UI textbox.
    """
    global FAL_KEY
    if fal_key_from_ui and fal_key_from_ui.strip():
        FAL_KEY = fal_key_from_ui.strip()
        # fal_client reads the key from the environment.
        os.environ["FAL_KEY"] = FAL_KEY
        print("FAL_KEY has been set from UI input.")
    else:
        FAL_KEY = os.environ.get("FAL_KEY")
        print("FAL_KEY loaded from environment variable." if FAL_KEY else "FAL_KEY is not set.")
    # Avoid echoing the secret itself.
    print("FAL key configured:", bool(FAL_KEY))


def generate_image(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024, num_inference_steps: int = 25):
    """Generate an image from a prompt.

    Args:
        prompt (str):
            The prompt to generate an image from.
        seed (int, default=42):
            Seed for the random number generator.
        width (int, default=1024):
            The width in pixels of the output image.
        height (int, default=1024):
            The height in pixels of the output image.
        num_inference_steps (int, default=25):
            The number of denoising steps. More denoising steps usually lead
            to a higher quality image at the expense of slower inference.

    Returns:
        tuple: The generated PIL image and the seed that was used.
    """
    client = InferenceClient(provider="fal-ai", token=TOKEN)
    image = client.text_to_image(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        seed=seed,
        model="black-forest-labs/FLUX.1-dev",
    )
    return image, seed
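
# Illustrative call (requires a valid HF token with inference access):
#   image, seed = generate_image("a red fox in the snow", seed=7)
#   image.save("fox.png")  # text_to_image returns a PIL.Image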


def generate_video_from_image(
    image_filepath: str,
    video_prompt: str,
    duration: str,
    aspect_ratio: str,
    video_negative_prompt: str,
    cfg_scale_video: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video from an image using the fal-ai/kling-video API."""
    if not FAL_KEY:
        raise gr.Error("FAL_KEY is not set. Cannot generate video.")
    if not image_filepath:
        gr.Warning("No image provided to generate video from.")
        return None
    if not os.path.exists(image_filepath):
        raise gr.Error(f"Image file not found at: {image_filepath}")

    print(f"Video generation started for image: {image_filepath}")
    progress(0, desc="Preparing for video generation...")

    try:
        progress(0.1, desc="Uploading image...")
        print("Uploading image to fal.ai storage...")
        image_url = fal_client.upload_file(image_filepath)
        print(f"Image uploaded, URL: {image_url}")
        progress(0.3, desc="Image uploaded. Submitting video request...")

        # Stream queue logs from fal.ai while the request is in progress.
        def on_queue_update(update):
            if isinstance(update, fal_client.InProgress):
                if update.logs:
                    for log in update.logs:
                        print(f"[fal-ai log] {log['message']}")

        print("Subscribing to fal-ai/kling-video/v2.1/master/image-to-video...")
        api_result = fal_client.subscribe(
            "fal-ai/kling-video/v2.1/master/image-to-video",
            arguments={
                "prompt": video_prompt,
                "image_url": image_url,
                "duration": duration,
                "aspect_ratio": aspect_ratio,
                "negative_prompt": video_negative_prompt,
                "cfg_scale": cfg_scale_video,
            },
            with_logs=True,
            on_queue_update=on_queue_update,
        )

        progress(0.9, desc="Video processing complete.")
        # The response payload is shaped like {"video": {"url": ...}}.
        video_output_url = api_result.get("video", {}).get("url")

        if video_output_url:
            print(f"Video generated successfully: {video_output_url}")
            progress(1, desc="Video ready!")
            return video_output_url
        print(f"Video generation failed or no URL in response. API Result: {api_result}")
        raise gr.Error("Video generation failed or no video URL returned.")

    except gr.Error:
        # Let Gradio-facing errors raised above propagate unchanged.
        raise
    except Exception as e:
        print(f"Error during video generation: {e}")
        raise gr.Error(f"An error occurred: {e}") from e
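
# Illustrative call (assumes FAL_KEY is set and "frame.png" exists):
#   url = generate_video_from_image("frame.png", "slow zoom out", "5", "16:9", "blur", 0.5)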


examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""


with gr.Blocks(css=css) as demo:
    # Resolve credentials on page load (OAuth token injection for login_hf,
    # environment fallback for login_fal).
    demo.load(login_hf, inputs=None, outputs=None)
    demo.load(login_fal, inputs=None, outputs=None)
    with gr.Sidebar():
        gr.Markdown("# Authentication")
        gr.Markdown(
            "Sign in with Hugging Face for image generation. Separately, set your fal.ai API key for image-to-video generation."
        )

        gr.Markdown("### Hugging Face Login")
        hf_login_button = gr.LoginButton("Sign in with Hugging Face")
        # The gr.OAuthToken parameter of login_hf is injected by Gradio,
        # so no input component is passed here.
        hf_login_button.click(fn=login_hf, inputs=None, outputs=None)

        gr.Markdown("### FAL Login (for Image to Video)")
        fal_key_input = gr.Textbox(
            label="FAL API Key",
            placeholder="Enter your FAL API Key here",
            type="password",
            value=os.environ.get("FAL_KEY", ""),
        )
        set_fal_key_button = gr.Button("Set FAL Key")
        set_fal_key_button.click(fn=login_fal, inputs=[fal_key_input], outputs=None)

    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """# Text to Image to Video with fal-ai through HF Inference Providers ⚡
Learn more about HF Inference Providers [here](https://huggingface.co/docs/inference-providers/index)
## Text to Image uses [FLUX.1 [dev]](https://fal.ai/models/fal-ai/flux/dev) with fal-ai through HF Inference Providers
## Image to Video uses [kling-video v2.1](https://fal.ai/models/fal-ai/kling-video/v2.1/master/image-to-video/playground) with fal-ai directly (you will need to set your `FAL_KEY`)."""
        )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Generated Image", show_label=False, format="png", type="filepath")
        download_btn = gr.DownloadButton(
            label="Download result image",
            visible=False,
            value=None,
            variant="primary",
        )

        seed_number = gr.Number(label="Seed", precision=0, value=42, interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed_slider = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            with gr.Row():
                width_slider = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height_slider = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            steps_slider = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=25,
            )

        gr.Examples(
            examples=examples,
            fn=generate_image,
            inputs=[prompt],
            outputs=[result, seed_number],
            cache_examples="lazy",
        )

        # run_button is wired up after the video components are defined below,
        # so the follow-up callback can reference them.

        video_result_output = gr.Video(label="Generated Video", show_label=False)

        with gr.Accordion("Video Generation from Image", open=False) as video_gen_accordion:
            video_prompt_input = gr.Text(
                label="Prompt for Video",
                placeholder="Describe the animation or changes for the video (e.g., 'camera zooms out slowly')",
                value="A gentle breeze rustles the leaves, subtle camera movement.",
            )
            with gr.Row():
                video_duration_input = gr.Dropdown(label="Duration (seconds)", choices=["5", "10"], value="5")
                video_aspect_ratio_input = gr.Dropdown(
                    label="Aspect Ratio",
                    choices=["16:9", "9:16", "1:1"],
                    value="16:9",
                )
            video_negative_prompt_input = gr.Text(
                label="Negative Prompt for Video",
                value="blur, distort, low quality",
            )
            video_cfg_scale_input = gr.Slider(
                label="CFG Scale for Video",
                minimum=0.0,
                maximum=10.0,
                value=0.5,
                step=0.1,
            )
            generate_video_btn = gr.Button("Generate Video", interactive=False)

        generate_video_btn.click(
            fn=generate_video_from_image,
            inputs=[
                result,
                video_prompt_input,
                video_duration_input,
                video_aspect_ratio_input,
                video_negative_prompt_input,
                video_cfg_scale_input,
            ],
            outputs=[video_result_output],
        )

        # Generate the image, then open the video accordion and enable the
        # follow-up buttons once an image is available.
        run_button.click(
            fn=generate_image,
            inputs=[prompt, seed_slider, width_slider, height_slider, steps_slider],
            outputs=[result, seed_number],
        ).then(
            lambda image_filepath: {
                video_gen_accordion: gr.Accordion(open=True),
                generate_video_btn: gr.Button(interactive=bool(image_filepath)),
                download_btn: gr.DownloadButton(value=image_filepath, visible=bool(image_filepath)),
            },
            inputs=[result],
            outputs=[video_gen_accordion, generate_video_btn, download_btn],
        )

    with gr.Accordion("Download Image from URL", open=False):
        image_url_input = gr.Text(label="Image URL", placeholder="Enter image URL (e.g., http://.../image.png)")
        filename_input = gr.Text(
            label="Filename (optional)",
            placeholder="Filename",
        )
        download_from_url_btn = gr.DownloadButton(label="Download Image")

        download_from_url_btn.click(
            fn=download_locally,
            inputs=[image_url_input, filename_input],
            outputs=[download_from_url_btn],
        )


if __name__ == "__main__":
    # mcp_server=True also exposes the app's functions as MCP tools.
    demo.launch(mcp_server=True)