mlashcorp committed on
Commit
05009dc
1 Parent(s): 694b9d6

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference.py +14 -0
  2. requirements.txt +3 -0
inference.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import pipeline
3
+
4
+
5
def model_fn(model_dir):
    """Create the ``transformers`` pipeline for the model stored in *model_dir*.

    Args:
        model_dir: Forwarded unchanged as the pipeline's ``model`` argument.

    Returns:
        Whatever ``transformers.pipeline`` returns for these options.
    """
    # No task name is given, so the pipeline has to work it out from the
    # model itself.
    # NOTE(review): ``trust_remote_code=True`` presumably lets code bundled
    # with the model run at load time -- confirm the model source is trusted.
    # NOTE(review): 8-bit loading and ``device_map="auto"`` presumably rely on
    # the bitsandbytes / accelerate packages -- confirm they are installed.
    loader_options = {
        "model": model_dir,
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
        "device_map": "auto",
        "model_kwargs": {"load_in_8bit": True},
    }
    return pipeline(**loader_options)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ accelerate==0.18.0
2
+ transformers==4.27.2
3
+ bitsandbytes==0.38.1