# trllm/fn.py
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextIteratorStreamer
from transformers import BitsAndBytesConfig, GPTQConfig
from threading import Thread
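
# Module-level cache: the loaded tokenizer/model plus the active generation
# settings live in globals so that repeated calls reuse the same model.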
tokenizer = None
model = None
default_cfg = {
    'model_name': None,
    'qtype': 'bnb',
    'dtype': '4bit',
    'instruction': None,
    'inst_template': None,
    'chat_template': None,
    'max_new_tokens': 1024,
    'temperature': 0.9,
    'top_p': 0.95,
    'top_k': 40,
    'repetition_penalty': 1.2,
}
cfg = default_cfg.copy()
def load_model(model_name, qtype='bnb', dtype='4bit'):
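    """Load tokenizer and model with the requested quantization.

    qtype selects the backend ('bnb', 'gptq', 'gguf' or 'awq') and dtype the
    precision/variant within it. Reloading is skipped when nothing changed.
    """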
    global tokenizer, model, cfg
    if cfg['model_name'] == model_name and cfg['qtype'] == qtype and cfg['dtype'] == dtype:
        return
    # Drop any previously loaded model first so its GPU memory is released.
    del model
    del tokenizer
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    match qtype:
        case 'bnb':
            match dtype:
                case '4bit' | 'int4':
                    kwargs = dict(
                        quantization_config=BitsAndBytesConfig(
                            load_in_4bit=True,
                            bnb_4bit_compute_dtype=torch.bfloat16,
                        ),
                    )
                case '8bit' | 'int8':
                    # 8-bit loading takes no bnb_4bit_* compute dtype.
                    kwargs = dict(
                        quantization_config=BitsAndBytesConfig(
                            load_in_8bit=True,
                        ),
                    )
                case 'fp16':
                    kwargs = dict(
                        torch_dtype=torch.float16,
                    )
                case 'bf16':
                    kwargs = dict(
                        torch_dtype=torch.bfloat16,
                    )
                case _:
                    kwargs = dict()
        case 'gptq':
            match dtype:
                case '4bit' | 'int4':
                    kwargs = dict(
                        quantization_config=GPTQConfig(
                            bits=4,
                            tokenizer=tokenizer,
                        ),
                    )
                case '8bit' | 'int8':
                    kwargs = dict(
                        quantization_config=GPTQConfig(
                            bits=8,
                            tokenizer=tokenizer,
                        ),
                    )
                case _:
                    # Fall back to an empty kwargs dict so from_pretrained
                    # below never sees an unbound variable.
                    kwargs = dict()
        case 'gguf':
            # For GGUF checkpoints, dtype carries the quantized filename to load.
            kwargs = dict(
                gguf_file=dtype,
            )
        case 'awq':
            match dtype:
                case 'fa2':
                    # Current transformers releases request FlashAttention-2
                    # via attn_implementation.
                    kwargs = dict(
                        attn_implementation='flash_attention_2',
                    )
                case _:
                    kwargs = dict()
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        trust_remote_code=True,
        **kwargs,
    )
    cfg['model_name'] = model_name
    cfg['qtype'] = qtype
    cfg['dtype'] = dtype
def clear_config():
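    """Reset cfg to the defaults (the loaded model itself is left as-is)."""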
    global cfg
    cfg = default_cfg.copy()
def set_config(model_name, qtype, dtype, instruction, inst_template, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
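    """(Re)load the model if necessary, then update the generation settings."""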
    global cfg
    load_model(model_name, qtype, dtype)
    cfg.update({
        'instruction': instruction,
        'inst_template': inst_template,
        'chat_template': chat_template,
        'max_new_tokens': int(max_new_tokens),
        'temperature': float(temperature),
        'top_p': float(top_p),
        'top_k': int(top_k),
        'repetition_penalty': float(repetition_penalty),
    })
    return 'done.'
def set_config_args(args):
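    """Dict-based variant of set_config(); args must contain model_name, qtype and dtype."""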
    global cfg
    load_model(args['model_name'], args['qtype'], args['dtype'])
    cfg.update(args)
    return 'done.'
def chatinterface_to_messages(message, history):
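    """Convert a Gradio ChatInterface (message, history) pair into a list of
    {'role': ..., 'content': ...} messages, prepending the system instruction."""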
    global cfg
    messages = []
    if cfg['instruction']:
        messages.append({'role': 'system', 'content': cfg['instruction']})
    for user, assistant in history:
        if user:
            messages.append({'role': 'user', 'content': user})
        if assistant:
            messages.append({'role': 'assistant', 'content': assistant})
    if message:
        messages.append({'role': 'user', 'content': message})
    return messages
def apply_template(messages):
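    """Render messages into a prompt string: a plain string goes through the
    instruction templates, a message list through the tokenizer's chat template."""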
    global tokenizer, cfg
    if cfg['chat_template']:
        tokenizer.chat_template = cfg['chat_template']
    if isinstance(messages, str):
        if cfg['inst_template']:
            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
        return cfg['instruction'].format(input=messages)
    if isinstance(messages, list):
        return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
def chat(message, history=None, instruction=None, args=None):
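    """Generate a reply and stream it; yields deltas when args contains
    'fastapi', otherwise the full text generated so far (for Gradio)."""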
    global tokenizer, model, cfg
    history = history or []
    args = args or {}
    if instruction:
        cfg['instruction'] = instruction
        prompt = apply_template(message)
    else:
        messages = chatinterface_to_messages(message, history)
        prompt = apply_template(messages)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
    )
    # BatchEncoding is a mapping, so dict() unpacks input_ids/attention_mask.
    generate_kwargs = dict(
        model_inputs,
        do_sample=True,
        streamer=streamer,
        num_beams=1,
    )
    for k in [
        'max_new_tokens',
        'temperature',
        'top_p',
        'top_k',
        'repetition_penalty',
    ]:
        if cfg[k]:
            generate_kwargs[k] = cfg[k]
    # Run generation on a worker thread and consume the streamer here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        if 'fastapi' in args:
            # FastAPI callers want only the newly generated delta
            yield new_text
        else:
            # Gradio callers want the full text generated so far
            yield model_output
def infer(message, history=None, instruction=None, args=None):
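    """Non-streaming wrapper around chat(): return only the final text."""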
    content = ''
    for s in chat(message, history, instruction, args):
        if args and 'fastapi' in args:
            content += s   # streaming deltas: accumulate them
        else:
            content = s    # full running text: keep only the latest
    return content
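

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a CUDA GPU and a causal-LM checkpoint that
# supports bitsandbytes 4-bit loading. The model name below is only an
# illustrative placeholder, not something this module depends on.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    set_config_args({
        'model_name': 'elyza/ELYZA-japanese-Llama-2-7b-instruct',  # assumed example
        'qtype': 'bnb',
        'dtype': '4bit',
        'instruction': 'You are a helpful assistant.',
    })
    # Stream deltas FastAPI-style and print them as they arrive.
    for delta in chat('Hello!', args={'fastapi': True}):
        print(delta, end='', flush=True)
    print()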