Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| from typing import Type | |
| from api.aws import AWSBedrockAPI | |
| from api.baseline import BaselineAPI | |
| from api.fal import FalAPI | |
| from api.fireworks import FireworksAPI | |
| from api.flux import FluxAPI | |
| from api.pruna import PrunaAPI | |
| from api.pruna_dev import PrunaDevAPI | |
| from api.replicate import ReplicateAPI | |
| from api.together import TogetherAPI | |
# Public names of this module: the factory function plus every concrete API
# class that `create_api` can construct (and the FluxAPI type it returns).
__all__ = [
    "create_api",
    "FluxAPI",
    "BaselineAPI",
    "FireworksAPI",
    "PrunaAPI",
    "ReplicateAPI",
    "TogetherAPI",
    "FalAPI",
    "PrunaDevAPI",
    # Previously missing even though create_api("aws") returns it.
    "AWSBedrockAPI",
]
def create_api(api_type: str) -> FluxAPI:
    """
    Factory function to create API instances.

    Args:
        api_type (str): The type of API to create. Must be one of:
            - "baseline"
            - "fireworks"
            - "pruna_dev" (returns a PrunaDevAPI)
            - "pruna_<speed_mode>" (any other value starting with "pruna_";
              the text after the prefix is passed to PrunaAPI as the
              speed mode)
            - "replicate"
            - "together"
            - "fal"
            - "aws"

    Returns:
        FluxAPI: An instance of the requested API implementation

    Raises:
        ValueError: If an invalid API type is provided
    """
    pruna_prefix = "pruna_"

    # "pruna_dev" must be matched before the generic prefix below; otherwise
    # it would be routed to PrunaAPI with speed mode "dev".
    if api_type == "pruna_dev":
        return PrunaDevAPI()
    if api_type.startswith(pruna_prefix):
        # Everything after the prefix is forwarded verbatim as the speed mode.
        # Slicing by len() keeps this in sync with the prefix literal.
        speed_mode = api_type[len(pruna_prefix):]
        return PrunaAPI(speed_mode)

    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
        "replicate": ReplicateAPI,
        "together": TogetherAPI,
        "fal": FalAPI,
        "aws": AWSBedrockAPI,
    }
    api_cls = api_map.get(api_type)
    if api_cls is None:
        raise ValueError(
            f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'"
        )
    return api_cls()