import fal


class FalModelArgumentError(ValueError, AssertionError):
    """Raised when a required keyword argument is missing from a call.

    Inherits from ``AssertionError`` as well as ``ValueError`` so callers that
    caught the ``AssertionError`` produced by the earlier ``assert``-based
    checks keep working, while the check itself no longer vanishes when
    Python runs with ``-O``.
    """


class FalModel:
    """Callable wrapper around the hosted ``fal-ai/<model_name>`` application.

    Args:
        model_name: Application name, submitted as ``f"fal-ai/{model_name}"``.
        model_type: ``"text2image"`` or ``"image2image"``; selects which
            keyword arguments ``__call__`` requires and forwards. Any other
            value is rejected with ``ValueError`` when the model is called.
    """

    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type

    @property
    def modle_type(self):
        """Deprecated misspelled alias of ``model_type``, kept for compatibility."""
        return self.model_type

    def __call__(self, *args, **kwargs):
        """Submit one request to the fal application and return its result.

        Positional arguments are accepted but ignored.

        Keyword Args:
            prompt: Required when ``model_type == "text2image"``.
            image_url: URL of the input image for ``"image2image"``; when
                present it is forwarded as the ``image_url`` argument.
            image: Accepted for ``"image2image"`` to satisfy the required-input
                check, but not yet supported on its own (see Raises).

        Returns:
            Whatever ``handler.get()`` of the submitted request returns.

        Raises:
            FalModelArgumentError: A required keyword argument is missing.
            NotImplementedError: ``"image2image"`` was called with ``image``
                but without ``image_url``; previously this silently sent
                ``image_url=None`` to the service.
            ValueError: ``model_type`` is neither supported value.
        """
        if self.model_type == "text2image":
            if "prompt" not in kwargs:
                raise FalModelArgumentError("prompt is required for text2image model")
            return self._run({"prompt": kwargs["prompt"]})

        if self.model_type == "image2image":
            if "image" not in kwargs and "image_url" not in kwargs:
                raise FalModelArgumentError(
                    "image or image_url is required for image2image model"
                )
            # An explicit image_url always wins; it used to be discarded
            # whenever `image` was also present.
            if "image_url" in kwargs:
                return self._run({"image_url": kwargs["image_url"]})
            # TODO: upload/encode `image` and forward its URL instead.
            raise NotImplementedError(
                "passing `image` directly is not supported yet; "
                "upload it and pass `image_url` instead"
            )

        raise ValueError("model_type must be text2image or image2image")

    def _run(self, arguments):
        """Submit ``arguments``, echo progress logs, and block for the result."""
        handler = fal.apps.submit(
            f"fal-ai/{self.model_name}",
            arguments=arguments,
        )
        for event in handler.iter_events():
            if isinstance(event, fal.apps.InProgress):
                print('Request in progress')
                print(event.logs)
        return handler.get()


def load_fal_model(model_name, model_type):
    """Return a ``FalModel`` for ``model_name`` with the given ``model_type``."""
    return FalModel(model_name, model_type)