KOMUChat / koalpaca.py
4n3mone's picture
Update koalpaca.py
162fbc7
raw
history blame
1.77 kB
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from peft import PeftModel, PeftConfig
from model import Model
from accelerate import Accelerator
class KoAlpaca(Model):
    """Chat model wrapper around a 4-bit quantized causal LM with a PEFT adapter.

    The base checkpoint named in the adapter's ``PeftConfig`` is loaded with
    4-bit NF4 quantization, the adapter weights are applied on top of it, and
    each question is wrapped in the ``### 질문: ... ### 답변:`` prompt template
    before generation.
    """

    # Answer marker exactly as it appears at the end of INPUT_FORMAT (no
    # trailing space). Used to cut the echoed prompt off the decoded output.
    _ANSWER_MARKER = "### 답변:"

    def __init__(self):
        """Load the quantized base model, adapter, tokenizer and generation config."""
        peft_model_id = "4n3mone/Komuchat-koalpaca-polyglot-12.8B"
        # The adapter config records which base checkpoint it was trained on.
        config = PeftConfig.from_pretrained(peft_model_id)
        base_model_id = config.base_model_name_or_path

        # Only used by generate() to pick the device the input tensors go to.
        self.accelerator = Accelerator()

        # 4-bit NF4 weights with double quantization; compute runs in bfloat16.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            quantization_config=self.bnb_config,
            device_map='auto',
        )
        # Wrap the quantized base model with the fine-tuned adapter weights.
        self.model = PeftModel.from_pretrained(self.model, peft_model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)

        # Second positional argument is the config file name inside the directory.
        self.gen_config = GenerationConfig.from_pretrained('./models/koalpaca', 'gen_config.json')
        self.INPUT_FORMAT = "### 질문: <INPUT>\n\n### 답변:"
        self.model.eval()

    def generate(self, inputs):
        """Generate an answer for a single question.

        Args:
            inputs: The question text; substituted for ``<INPUT>`` in
                ``INPUT_FORMAT``.

        Returns:
            The decoded text following the last ``### 답변:`` marker, with one
            leading space removed if present. Special tokens emitted by the
            tokenizer's ``decode`` are left in place.
        """
        prompt = self.INPUT_FORMAT.replace('<INPUT>', inputs)
        # token_type_ids are not requested, so only the remaining tokenizer
        # outputs are forwarded to model.generate().
        encoded = self.tokenizer(
            prompt,
            return_tensors='pt',
            return_token_type_ids=False,
        ).to(self.accelerator.device)
        output_ids = self.model.generate(
            **encoded,
            generation_config=self.gen_config,
        )
        decoded = self.tokenizer.decode(output_ids[0])

        # The prompt ends with "### 답변:" and has no trailing space, so split on
        # the marker as written. Splitting on "### 답변: " would fail to match
        # (and return the whole prompt) whenever the first generated token
        # does not start with a space.
        answer = decoded.split(self._ANSWER_MARKER)[-1]
        if answer.startswith(" "):
            answer = answer[1:]
        return answer