red-pajama-3b-sagemaker / inference.py
mlashcorp's picture
Upload 2 files
05009dc
raw
history blame
308 Bytes
import torch
from transformers import pipeline
def model_fn(model_dir):
instruct_pipeline = pipeline(
model=model_dir,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto",
model_kwargs={"load_in_8bit": True},
)
return instruct_pipeline