|
import os
import subprocess
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def install(package):
    """Install a package into the current interpreter's environment via pip."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


# bitsandbytes and accelerate back quantized and multi-device model loading
# in transformers; bootstrap pinned versions on first run if they are missing.
try:
    import bitsandbytes
except ImportError:
    install("bitsandbytes==0.39.1")

try:
    import accelerate
except ImportError:
    install("accelerate==0.20.0")
|
|
|
class ModelHandler:
    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load_model(self):
        model_id = "NiCETmtm/llama3_torch"
        token = os.getenv("HF_API_TOKEN")

        # The repository hosts PyTorch weights (the id itself says "torch"),
        # so the original from_tf=True flag would look for a TensorFlow
        # checkpoint and fail; load the weights directly as a torch model.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            use_auth_token=token,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            use_auth_token=token,
            trust_remote_code=True,
        )
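
        # The dependency bootstrap above suggests quantized loading was
        # intended. A minimal sketch of an 8-bit variant, assuming a CUDA
        # device and the pinned bitsandbytes/accelerate versions; it is not
        # part of the original handler, so it stays commented out:
        #
        # self.model = AutoModelForCausalLM.from_pretrained(
        #     model_id,
        #     use_auth_token=token,
        #     trust_remote_code=True,
        #     load_in_8bit=True,   # requires bitsandbytes
        #     device_map="auto",   # requires accelerate
        # )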
|
|
|
    def predict(self, inputs):
        # Tokenize the prompt, generate with the checkpoint's default
        # generation settings, and decode the first returned sequence.
        tokens = self.tokenizer(inputs, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.generate(**tokens)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


# Instantiate and load once at import time, so warm invocations reuse the
# already-loaded model instead of paying the load cost on every request.
model_handler = ModelHandler()
model_handler.load_model()
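
# generate() without arguments falls back to the checkpoint's generation
# config, which often caps output at a few dozen new tokens. A minimal
# sketch of a bounded variant; the helper name and parameter values are
# illustrative assumptions, not part of the original handler:
def predict_bounded(handler, inputs, max_new_tokens=256):
    """Generate with an explicit cap on newly generated tokens (sketch)."""
    tokens = handler.tokenizer(inputs, return_tensors="pt")
    with torch.no_grad():
        outputs = handler.model.generate(**tokens, max_new_tokens=max_new_tokens)
    return handler.tokenizer.decode(outputs[0], skip_special_tokens=True)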
|
|
|
def inference(event, context):
    """Entry point for the serving runtime: reads the prompt from event["data"]."""
    inputs = event["data"]
    outputs = model_handler.predict(inputs)
    return {"predictions": outputs}
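
# Local smoke test; the event payload shape follows the handler above,
# while the prompt text itself is an illustrative assumption.
if __name__ == "__main__":
    sample_event = {"data": "Hello, world!"}
    print(inference(sample_event, context=None))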
|
|