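"""Minimal inference pipeline wrapping Unsloth's FastLanguageModel."""
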
from unsloth import FastLanguageModel

class InferencePipeline:
    """Loads a pretrained model once, then serves generation requests."""

    def __init__(self, conf, api_key):
        self.conf = conf      # nested config dict; see the illustrative layout below
        self.token = api_key  # Hugging Face access token
        self.model, self.tokenizer = self.get_model()

    def get_model(self):
        """Load the model and tokenizer, then switch the model to inference mode."""
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.conf["model"]["model_name"],
            max_seq_length=self.conf["model"]["max_seq_length"],
            dtype=self.conf["model"]["dtype"],
            load_in_4bit=self.conf["model"]["load_in_4bit"],
            token=self.token,
        )
        FastLanguageModel.for_inference(model)  # Enable native 2x faster inference
        return model, tokenizer

    def infer(self, prompt):
        """Generate a completion for a single prompt; returns a list of decoded strings."""
        # Assumes a CUDA-capable GPU is available.
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.conf["model"]["max_new_tokens"],
            use_cache=True,
        )
        return self.tokenizer.batch_decode(outputs)
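
# Illustrative `conf` layout, inferred from the keys this class reads;
# the values shown are placeholders, not settings shipped with this file:
#
# conf = {
#     "model": {
#         "model_name": "unsloth/llama-3-8b-bnb-4bit",  # any Unsloth-supported model id
#         "max_seq_length": 2048,
#         "dtype": None,        # None lets Unsloth auto-detect (bf16 on Ampere+, else fp16)
#         "load_in_4bit": True,
#         "max_new_tokens": 256,
#     }
# }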

# Example usage (assumes `conf` and an API-key dict `keys` are loaded elsewhere,
# e.g. from a YAML/JSON config):
#
# pipeline = InferencePipeline(conf, api_key=keys["huggingface"])
# outputs = pipeline.infer(prompt)