# mixtral-compiled / inference.py
from transformers_neuronx.mixtral.model import MixtralForSampling
import torch


def model_fn(model_dir):
    # Load the Neuron-compiled Mixtral model from the model directory.
    model = MixtralForSampling.from_pretrained(model_dir)
    # Load the compiled graphs onto the NeuronCores before serving requests.
    model.to_neuron()
    return model


def predict_fn(input_data, model):
    # Run a forward pass without gradient tracking and return the raw outputs.
    with torch.no_grad():
        outputs = model(**input_data)
    return outputs
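

# A minimal local smoke test of the two handlers above. This is a sketch and
# not part of the original handler code: the tokenizer checkpoint, the model
# directory path, and the input format are assumptions for illustration, and
# running it requires Neuron hardware.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Assumed tokenizer checkpoint and compiled-model directory.
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
    model = model_fn("./mixtral-compiled")

    prompt = "Hello, my name is"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # predict_fn runs a plain forward pass, so the result is the raw model
    # output (next-token logits), not generated text.
    outputs = predict_fn({"input_ids": input_ids}, model)
    print(outputs)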