Spaces:
Configuration error
Configuration error
File size: 628 Bytes
05e6f93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch
from sparrow_parse.vllm.inference_base import ModelInference
class LocalGPUInference(ModelInference):
    """Run inference with a locally held model on a chosen torch device.

    The model is moved to ``device`` once at construction time; each call to
    :meth:`inference` moves the input to that same device, runs a forward
    pass without gradient tracking, and returns the result as a NumPy array.
    """

    def __init__(self, model, device='cuda'):
        """Store the model and move it to the target device.

        Args:
            model: Object exposing ``.to(device)``, ``.eval()`` and
                ``__call__`` (presumably a ``torch.nn.Module`` -- confirm
                against callers).
            device: Device passed to ``model.to`` and used for the inputs.
                Defaults to ``'cuda'``.
        """
        self.model = model
        self.device = device
        self.model.to(self.device)

    def inference(self, input_data, mode=None):
        """Run a forward pass on ``input_data`` and return a NumPy array.

        Args:
            input_data: Either a ``torch.Tensor`` or anything accepted by
                ``torch.tensor`` (nested lists, NumPy arrays, scalars).
            mode: Accepted for interface compatibility; not used by this
                implementation.

        Returns:
            The model output, moved to the CPU and converted with
            ``.numpy()``.
        """
        self.model.eval()  # Set the model to evaluation mode
        with torch.no_grad():  # No need to calculate gradients
            if isinstance(input_data, torch.Tensor):
                # torch.tensor(<Tensor>) emits a UserWarning on every call;
                # detach + copying .to() keeps the same "fresh copy on the
                # target device" result without the warning.
                input_tensor = input_data.detach().to(self.device, copy=True)
            else:
                input_tensor = torch.tensor(input_data).to(self.device)
            output = self.model(input_tensor)
        # Convert the output back to NumPy on the host.
        return output.cpu().numpy()
|