import onnxruntime as ort
import numpy as np
import torch
import torchaudio
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM
from typing import Union, List, Dict, Any
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class InferenceEngine:
    def __init__(self, model_path: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """
        Initialize the InferenceEngine.

        Args:
            model_path (str): Path to the ONNX model.
            device (str): Device to run the model on ("cuda" or "cpu").
        """
        self.device = device
        try:
            # Initialize the ONNX Runtime session, requesting only execution
            # providers actually available in this build (recent onnxruntime
            # versions raise if an unavailable provider is requested)
            preferred = [
                "TensorrtExecutionProvider",
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ]
            available = set(ort.get_available_providers())
            self.session = ort.InferenceSession(
                model_path,
                providers=[p for p in preferred if p in available],
            )
            logger.info(f"ONNX model loaded successfully on device: {self.device}")
        except Exception as e:
            logger.error(f"Failed to load ONNX model: {e}")
            raise

    def run_text_inference(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prompt: str, max_new_tokens: int = 200) -> str:
        """
        Run text inference using a causal language model.

        Args:
            model (AutoModelForCausalLM): Pre-trained causal language model.
            tokenizer (AutoTokenizer): Tokenizer for the model.
            prompt (str): Input text prompt.
            max_new_tokens (int): Maximum number of tokens to generate,
                not counting the prompt (unlike generate's `max_length`).

        Returns:
            str: Generated text.
        """
        try:
            inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
            return tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            logger.error(f"Text inference failed: {e}")
            raise

    def run_image_inference(self, clip_model: CLIPModel, processor: CLIPProcessor, image_path: str) -> np.ndarray:
        """
        Run image inference using a CLIP model.

        Args:
            clip_model (CLIPModel): Pre-trained CLIP model.
            processor (CLIPProcessor): Processor for the CLIP model.
            image_path (str): Path to the input image.

        Returns:
            np.ndarray: Image features as a numpy array.
        """
        try:
            image = Image.open(image_path).convert("RGB")
            inputs = processor(images=image, return_tensors="pt").to(self.device)
            # Disable autograd: we only need the features, not gradients
            with torch.no_grad():
                outputs = clip_model.get_image_features(**inputs)
            return outputs.cpu().numpy()
        except Exception as e:
            logger.error(f"Image inference failed: {e}")
            raise

    def run_audio_inference(self, whisper_model: Any, audio_file: str) -> str:
        """
        Run audio inference using a Whisper model.

        Args:
            whisper_model (Any): Pre-trained Whisper model.
            audio_file (str): Path to the input audio file.

        Returns:
            str: Transcribed text.
        """
        try:
            # Whisper expects mono float audio sampled at 16 kHz, so
            # downmix multi-channel input and resample if necessary
            waveform, sample_rate = torchaudio.load(audio_file)
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            if sample_rate != 16000:
                waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
            return whisper_model.transcribe(waveform.squeeze(0))["text"]
        except Exception as e:
            logger.error(f"Audio inference failed: {e}")
            raise

    def run_general_inference(self, input_data: Union[np.ndarray, List, Dict]) -> np.ndarray:
        """
        Run general inference using the ONNX model.

        Args:
            input_data (Union[np.ndarray, List, Dict]): Input data for the model.

        Returns:
            np.ndarray: Model output.
        """
        try:
            input_name = self.session.get_inputs()[0].name
            output_name = self.session.get_outputs()[0].name

            # A dict is treated as a ready-made {input_name: array} feed;
            # anything else is coerced to a float32 numpy array first
            # (np.array() on a dict would fail, so dicts get their own path)
            if isinstance(input_data, dict):
                feed = input_data
            else:
                if not isinstance(input_data, np.ndarray):
                    input_data = np.array(input_data, dtype=np.float32)
                feed = {input_name: input_data}

            return self.session.run([output_name], feed)[0]
        except Exception as e:
            logger.error(f"General inference failed: {e}")
            raise
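

# Minimal usage sketch. Assumptions: a local "model.onnx" file exists and the
# Hugging Face checkpoints below are reachable; the model ids, file paths, and
# input shape are illustrative placeholders, not part of this module's API.
if __name__ == "__main__":
    engine = InferenceEngine("model.onnx")  # hypothetical model path

    # General ONNX inference on a dummy batch; the (1, 10) shape is a
    # placeholder and must match the loaded model's input signature
    dummy = np.random.rand(1, 10).astype(np.float32)
    logger.info(f"ONNX output shape: {engine.run_general_inference(dummy).shape}")

    # Text generation ("gpt2" used here only as a small example checkpoint)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2").to(engine.device)
    print(engine.run_text_inference(lm, tokenizer, "Hello, world", max_new_tokens=20))

    # CLIP image features (example checkpoint; "cat.jpg" is a placeholder file)
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(engine.device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    print(engine.run_image_inference(clip_model, processor, "cat.jpg").shape)

    # Audio transcription would additionally need a Whisper model, e.g. one
    # returned by whisper.load_model("base") from the openai-whisper package
    # (not imported here), passed as run_audio_inference's first argument.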