File size: 1,166 Bytes
05e6f93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from sparrow_parse.vllm.huggingface_inference import HuggingFaceInference
from sparrow_parse.vllm.local_gpu_inference import LocalGPUInference
from sparrow_parse.vllm.mlx_inference import MLXInference


class InferenceFactory:
    """Build an inference backend object from a configuration mapping.

    The configuration must support ``config["key"]`` subscripting and a
    ``.get(key, default)`` lookup. Its ``"method"`` entry selects which
    backend class is instantiated.
    """

    def __init__(self, config):
        """Store *config* for later use by ``get_inference_instance``."""
        self.config = config

    def get_inference_instance(self):
        """Return the backend instance selected by ``config["method"]``.

        Recognised methods and the configuration entries each one reads:

        * ``"huggingface"`` -- ``"hf_space"`` and ``"hf_token"`` (required).
        * ``"local_gpu"`` -- ``"device"`` (optional, defaults to ``"cuda"``);
          the model comes from ``_load_local_model``.
        * ``"mlx"`` -- ``"model_name"`` (required).

        Raises:
            KeyError: if ``"method"`` or a required entry is missing.
            NotImplementedError: for ``"local_gpu"``, because the model
                loader is still a placeholder.
            ValueError: if the method is not one of the values above.
        """
        cfg = self.config
        method = cfg["method"]

        if method == "huggingface":
            return HuggingFaceInference(
                hf_space=cfg["hf_space"],
                hf_token=cfg["hf_token"],
            )

        if method == "local_gpu":
            # The model is loaded before the device is looked up; the loader
            # is a stub, so this branch currently always raises.
            local_model = self._load_local_model()
            target_device = cfg.get("device", "cuda")
            return LocalGPUInference(model=local_model, device=target_device)

        if method == "mlx":
            return MLXInference(model_name=cfg["model_name"])

        raise ValueError(f"Unknown method: {method}")

    def _load_local_model(self):
        """Placeholder for loading the model used by the local GPU backend.

        Raises:
            NotImplementedError: always; real loading logic (for example,
                deserialising a saved PyTorch model) has not been written.
        """
        raise NotImplementedError("Model loading logic not implemented")