File size: 628 Bytes
05e6f93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from sparrow_parse.vllm.inference_base import ModelInference


class LocalGPUInference(ModelInference):
    """Run a PyTorch model locally on a single device.

    The wrapped model is moved to ``device`` once at construction time.
    Each :meth:`inference` call converts its input to a tensor on that
    device, runs a gradient-free forward pass and returns the result as a
    NumPy array.
    """

    def __init__(self, model, device='cuda'):
        """Store ``model`` and move it to ``device``.

        Args:
            model: Object exposing ``to``, ``eval`` and ``__call__``
                (presumably a ``torch.nn.Module`` -- confirm with callers).
            device: Target device passed to ``model.to`` and used for every
                input tensor. Defaults to ``'cuda'``.
        """
        self.model = model
        self.device = device
        self.model.to(self.device)

    def inference(self, input_data, mode=None):
        """Run a forward pass on ``input_data`` and return the output.

        Args:
            input_data: A ``torch.Tensor`` or anything ``torch.tensor``
                accepts (nested lists, NumPy arrays, scalars). The input is
                always copied, so the caller's data is never aliased.
            mode: Unused here; presumably kept for signature compatibility
                with ``ModelInference.inference`` -- confirm against the base.

        Returns:
            The model output moved to the CPU and converted with ``.numpy()``.
            Assumes the model returns a single tensor.
        """
        # Switch modules such as dropout / batch-norm to inference behaviour.
        self.model.eval()
        with torch.no_grad():  # no autograd graph is needed for inference
            if isinstance(input_data, torch.Tensor):
                # torch.tensor(<Tensor>) works but emits a UserWarning; a
                # detached, forced copy on the target device is the
                # documented equivalent (same dtype, no grad, no aliasing).
                input_tensor = input_data.detach().to(self.device, copy=True)
            else:
                # Build directly on the target device instead of creating a
                # CPU tensor first and then transferring it.
                input_tensor = torch.tensor(input_data, device=self.device)
            output = self.model(input_tensor)
        # detach() is a no-op for outputs produced under no_grad, but guards
        # against a model that re-enables grad internally, in which case
        # .numpy() would raise.
        return output.detach().cpu().numpy()