Spaces:
Restarting
on
Zero
Restarting
on
Zero
import fal_client | |
from PIL import Image | |
import requests | |
import io | |
import os | |
class FalModel:
    """Callable wrapper that runs a fal.ai hosted model for a prompt.

    Supported ``model_type`` values:

    * ``"text2image"`` -- submits to ``fal-ai/<model_name>`` and returns the
      first generated image, downloaded and opened as a ``PIL.Image.Image``.
    * ``"text2video"`` -- maps ``model_name`` through ``TEXT2VIDEO_ENDPOINTS``
      and returns the URL of the generated video (a ``str``).
    * ``"image2image"`` -- not implemented; calling raises ``NotImplementedError``.
    """

    # Display model name -> fal.ai endpoint (without the "fal-ai/" prefix).
    TEXT2VIDEO_ENDPOINTS = {
        "AnimateDiff": "fast-animatediff/text-to-video",
        "AnimateDiffTurbo": "fast-animatediff/turbo/text-to-video",
    }

    # Seconds to wait for the generated-image download before giving up;
    # without a timeout a stalled server would hang the caller forever.
    download_timeout = 60

    def __init__(self, model_name, model_type):
        """Store the model identity and make the fal.ai API key available.

        Args:
            model_name: fal.ai model name (text2image) or a key of
                ``TEXT2VIDEO_ENDPOINTS`` (text2video).
            model_type: one of ``"text2image"``, ``"image2image"``,
                ``"text2video"``. Validated lazily, when the model is called.

        Raises:
            KeyError: if neither ``FalAPI`` nor ``FAL_KEY`` is set in the
                environment.
        """
        self.model_name = model_name
        self.model_type = model_type
        # fal_client reads its credentials from FAL_KEY; this deployment stores
        # the key under FalAPI, so mirror it across. An already-exported
        # FAL_KEY is accepted on its own instead of crashing.
        api_key = os.environ.get('FalAPI')
        if api_key is not None:
            os.environ['FAL_KEY'] = api_key
        elif 'FAL_KEY' not in os.environ:
            raise KeyError("Neither 'FalAPI' nor 'FAL_KEY' is set in the environment")

    @staticmethod
    def _require_prompt(kwargs, model_type):
        """Return ``kwargs["prompt"]``, raising ``ValueError`` if it is absent."""
        # An explicit raise (not `assert`) so the check survives `python -O`.
        if "prompt" not in kwargs:
            raise ValueError(f"prompt is required for {model_type} model")
        return kwargs["prompt"]

    @staticmethod
    def _submit_and_wait(endpoint, prompt):
        """Submit *prompt* to ``fal-ai/<endpoint>``, echo progress logs, return the result."""
        handler = fal_client.submit(
            f"fal-ai/{endpoint}",
            arguments={
                "prompt": prompt
            },
        )
        for event in handler.iter_events(with_logs=True):
            if isinstance(event, fal_client.InProgress):
                print('Request in progress')
                print(event.logs)
        return handler.get()

    def _download_image(self, url):
        """Fetch *url* and open the response body as a PIL image."""
        response = requests.get(url, timeout=self.download_timeout)
        # Fail with the HTTP error rather than a confusing PIL decode error.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))

    def __call__(self, *args, **kwargs):
        """Run the configured model.

        Args:
            *args: accepted for interface compatibility; ignored.
            **kwargs: ``prompt`` (required for text2image and text2video).

        Returns:
            ``PIL.Image.Image`` for text2image; the video URL (``str``) for
            text2video.

        Raises:
            ValueError: if ``prompt`` is missing or ``model_type`` is unknown.
            NotImplementedError: for image2image, or an unmapped text2video
                ``model_name``.
        """
        if self.model_type == "text2image":
            prompt = self._require_prompt(kwargs, self.model_type)
            result = self._submit_and_wait(self.model_name, prompt)
            return self._download_image(result['images'][0]['url'])

        if self.model_type == "image2image":
            raise NotImplementedError("image2image model is not implemented yet")

        if self.model_type == "text2video":
            prompt = self._require_prompt(kwargs, self.model_type)
            endpoint = self.TEXT2VIDEO_ENDPOINTS.get(self.model_name)
            if endpoint is None:
                raise NotImplementedError(
                    f"text2video model of {self.model_name} in fal is not implemented yet"
                )
            result = self._submit_and_wait(endpoint, prompt)
            print("result video: ====")
            print(result)
            return result['video']['url']

        raise ValueError("model_type must be one of text2image, image2image or text2video")
def load_fal_model(model_name, model_type):
    """Construct a ``FalModel`` for *model_name* of the given *model_type*.

    Factory entry point; performs no I/O beyond what ``FalModel.__init__`` does.
    """
    model = FalModel(model_name=model_name, model_type=model_type)
    return model