# syde660summaryL3 / handler.py
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class ModelHandler:
    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load_model(self):
        # Load the model and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained("your-model-path")
        self.tokenizer = AutoTokenizer.from_pretrained("your-model-path")

    def predict(self, inputs):
        # Convert the input into a format the model can process
        inputs = self.tokenizer(inputs, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs
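    def generate_text(self, prompt, max_new_tokens=100):
        # Illustrative sketch, not part of the original file: returns decoded
        # text via the transformers generate API instead of raw model outputs.
        # The method name and the max_new_tokens default are assumptions.
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            output_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)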

model_handler = ModelHandler()
model_handler.load_model()


def handler(event, context):
    # Entry point: the instance is named model_handler so it is not shadowed
    # by this function (the original reused the name "handler" for both).
    inputs = event["data"]
    outputs = model_handler.predict(inputs)
    return outputs
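
# Illustrative local invocation only; it assumes the runtime passes an event
# dict with a "data" key and an opaque context object, as the handler expects.
if __name__ == "__main__":
    sample_event = {"data": "Hello, world"}
    print(handler(sample_event, context=None))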