import torch
from transformers import pipeline


def model_fn(model_dir):
    # Build a text-generation pipeline from the checkpoint unpacked in model_dir.
    # 8-bit loading plus bfloat16 keeps GPU memory usage low, and
    # device_map="auto" places the layers on the available device(s).
    instruct_pipeline = pipeline(
        model=model_dir,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        model_kwargs={"load_in_8bit": True},
    )
    return instruct_pipeline
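A quick way to sanity-check the handler locally is to call it directly and run a prompt through the returned pipeline; the checkpoint path below is a placeholder for wherever the model files are unpacked, not something defined in this script:

    # Hypothetical local smoke test; in hosted inference model_fn is invoked for you.
    pipe = model_fn("/path/to/unpacked/model")
    print(pipe("Explain what this model does in one sentence."))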