stenio123 commited on
Commit
7021da8
1 Parent(s): bb95a2e

Upload folder using huggingface_hub

Browse files
.ipynb_checkpoints/inference-checkpoint.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers_neuronx.mixtral.model import MixtralForSampling
2
+ import torch
3
+
4
+ def model_fn(model_dir):
5
+ # Load the Neuron-compiled model from the directory
6
+ model = MixtralForSampling.from_pretrained(model_dir)
7
+ return model
8
+
9
+ def predict_fn(input_data, model):
10
+ # Implement prediction logic
11
+ with torch.no_grad():
12
+ outputs = model(**input_data)
13
+ return outputs
inference.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers_neuronx.mixtral.model import MixtralForSampling
2
+ import torch
3
+
4
+ def model_fn(model_dir):
5
+ # Load the Neuron-compiled model from the directory
6
+ model = MixtralForSampling.from_pretrained(model_dir)
7
+ return model
8
+
9
+ def predict_fn(input_data, model):
10
+ # Implement prediction logic
11
+ with torch.no_grad():
12
+ outputs = model(**input_data)
13
+ return outputs