File size: 387 Bytes
7021da8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
from transformers_neuronx.mixtral.model import MixtralForSampling
import torch
def model_fn(model_dir):
    """Load the Mixtral checkpoint in *model_dir* and place it on Neuron devices.

    Args:
        model_dir: Directory containing a checkpoint readable by
            ``MixtralForSampling.from_pretrained`` (presumably the model
            directory supplied by the serving framework).

    Returns:
        A ``MixtralForSampling`` instance on which ``to_neuron()`` has been
        called, so it can be used for inference.
    """
    # from_pretrained only loads the checkpoint weights on the host side.
    # transformers-neuronx requires to_neuron() before any forward/sample
    # call; without it the Neuron graphs are never built and inference fails.
    # NOTE(review): from_pretrained is called with the library defaults for
    # batch size, tensor-parallel degree and dtype -- confirm these match the
    # target instance (e.g. tp_degree vs. the number of available NeuronCores).
    model = MixtralForSampling.from_pretrained(model_dir)
    model.to_neuron()
    return model
def predict_fn(input_data, model):
    """Run *model* on *input_data* with autograd disabled and return its output.

    Args:
        input_data: Either a mapping of keyword arguments for the model call
            (e.g. ``{"input_ids": ...}``), or a single non-mapping object
            such as a tensor, which is passed as the model's first positional
            argument.
        model: The callable to invoke (presumably the object returned by
            ``model_fn``).

    Returns:
        Whatever the model call returns, unchanged.
    """
    with torch.no_grad():
        # Anything usable with ``**`` (i.e. it exposes keys()) is unpacked as
        # keyword arguments, exactly as before. Any other input previously
        # raised TypeError on ``**``; pass it positionally instead.
        # NOTE(review): a bare tensor is likely what a default deserializer
        # hands over -- confirm against the deployment's input handling.
        if hasattr(input_data, "keys"):
            outputs = model(**input_data)
        else:
            outputs = model(input_data)
    return outputs