# File: mixtral-compiled/.ipynb_checkpoints/inference-checkpoint.py
# Author: stenio123
# Commit 7021da8 (verified): "Upload folder using huggingface_hub"
from collections.abc import Mapping

import torch
from transformers_neuronx.mixtral.model import MixtralForSampling
def model_fn(model_dir):
    """Load the Mixtral model from ``model_dir`` and place it on Neuron cores.

    The ``model_fn``/``predict_fn`` pair matches the SageMaker inference-handler
    convention (presumably how this file is served): the object returned here
    is what ``predict_fn`` later receives as ``model``.

    Args:
        model_dir: Directory passed straight to
            ``MixtralForSampling.from_pretrained``, so it must be loadable by it.

    Returns:
        The ``MixtralForSampling`` instance, after ``to_neuron()`` has run.
    """
    # NOTE(review): from_pretrained is called with its library defaults
    # (tensor-parallel degree, batch size, dtype). Confirm those match the
    # configuration the artifacts were compiled with.
    model = MixtralForSampling.from_pretrained(model_dir)
    # from_pretrained only builds the model object. to_neuron() compiles (or
    # loads already-compiled) graphs and moves the weights onto NeuronCores.
    # Without it the returned model cannot run inference.
    # NOTE(review): if model_dir holds artifacts written by model.save(), they
    # presumably need model.load(<artifact dir>) before to_neuron(). Confirm
    # the layout of model_dir.
    model.to_neuron()
    return model
def predict_fn(input_data, model):
    """Run ``model`` on ``input_data`` with autograd disabled.

    Args:
        input_data: Either a mapping of keyword arguments for the model
            (e.g. ``{"input_ids": tensor}``), which is unpacked as ``**kwargs``,
            or any other object (such as a single tensor), which is passed as
            the sole positional argument.
        model: The object returned by ``model_fn``; it must be callable.

    Returns:
        Whatever ``model(...)`` returns, unchanged.
    """
    with torch.no_grad():
        # A non-mapping input (e.g. a bare tensor from a default input
        # deserializer) cannot be unpacked with **. Pass it positionally
        # instead of failing with TypeError.
        if isinstance(input_data, Mapping):
            return model(**input_data)
        return model(input_data)