fuwangwang commited on
Commit
ffc95f0
1 Parent(s): 8251ef9

Create code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +9 -0
code/inference.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ import torch
3
+
4
+ def model_fn(model_dir):
5
+ """
6
+ Overrides the default model load function in the HuggingFace Deep Learning Container
7
+ """
8
+ instruct_pipeline = pipeline(model="fuwangwang/mpt-7b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
9
+ return instruct_pipeline