# fn.py
import os
import re
import torch
import datetime
import json
import csv
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import TextStreamer, TextIteratorStreamer
from transformers import GenerationConfig, AutoConfig, GPTQConfig, AwqConfig
from models import models
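
# `models` (imported from models.py) is assumed to map a model name to its
# per-model defaults: a prompt `template` or an `is_messages` flag, plus any
# extra generate() kwargs. infer() below merges the caller's args on top of it.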
tokenizer = None
model = None
loaded_model_name = None
loaded_dtype = None


def load_model(model_name, dtype='int4'):
    global tokenizer, model, loaded_model_name, loaded_dtype

    # Nothing to do if the requested model and dtype are already loaded.
    if loaded_model_name == model_name and loaded_dtype == dtype:
        return

    # Free the previously loaded model and tokenizer before loading a new one.
    del model
    del tokenizer
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if dtype == 'int4':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
        )
    elif dtype == 'int8':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            # torch_dtype belongs to from_pretrained, not to BitsAndBytesConfig.
            torch_dtype=torch.bfloat16,
            quantization_config=BitsAndBytesConfig(
                load_in_8bit=True,
            ),
        )
    elif dtype == 'fp16':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
    elif dtype == 'bf16':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            device_map="auto",
        )

    loaded_model_name = model_name
    loaded_dtype = dtype
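
# Usage sketch (the model id below is only a hypothetical example):
#
#     load_model("elyza/ELYZA-japanese-Llama-2-7b-instruct", dtype="int4")
#
# Calling load_model() again with the same name and dtype returns immediately;
# changing either one frees the current model and loads the new one.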


def infer(args: dict):
    global tokenizer, model, loaded_model_name

    # 'model' is accepted as an alias for 'model_name'.
    if 'model' in args:
        args['model_name'] = args['model']

    # (Re)load when nothing is loaded yet, or when a different model is requested.
    if not tokenizer or ('model_name' in args and loaded_model_name != args['model_name']):
        if 'dtype' in args:
            load_model(args['model_name'], args['dtype'])
        else:
            load_model(args['model_name'])
    # Start from the per-model defaults and let the caller's args override them.
    # Copy so the shared entry in `models` is not mutated between calls.
    config = {}
    if args.get('model_name') in models:
        config = models[args['model_name']].copy()
    config.update(args)
    if config.get('is_messages'):
        # Chat models: build a messages list and let the tokenizer's chat
        # template render the prompt.
        messages = []
        messages.append({"role": "system", "content": args['instruction']})
        if args['input']:
            messages.append({"role": "user", "content": args['input']})
        tprompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
    else:
        # Completion models: fill the model-specific prompt template.
        tprompt = config['template'].format(bos_token=tokenizer.bos_token, instruction=args['instruction'], input=args['input'])
    # Everything left in the config after dropping non-generation keys is
    # forwarded to model.generate().
    kwargs = config.copy()
    for k in ['model_name', 'template', 'instruction', 'input', 'location', 'endpoint', 'model', 'dtype', 'is_messages']:
        if k in kwargs:
            del kwargs[k]
    with torch.no_grad():
        token_ids = tokenizer.encode(tprompt, add_special_tokens=False, return_tensors="pt")
        if config.get('is_messages'):
            output_ids = model.generate(
                input_ids=token_ids.to(model.device),
                do_sample=True,
                **kwargs,
            )
        else:
            output_ids = model.generate(
                input_ids=token_ids.to(model.device),
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                **kwargs,
            )
    # Drop the prompt tokens and decode only the newly generated part.
    out = output_ids.tolist()[0][token_ids.size(1):]
    content = tokenizer.decode(out, skip_special_tokens=True)

    return content
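

if __name__ == "__main__":
    # Minimal manual test, a sketch under these assumptions: the model id is a
    # hypothetical example (swap in one listed in models.py, or any chat model
    # whose tokenizer ships a chat template), and max_new_tokens is simply
    # forwarded to model.generate() by infer().
    result = infer({
        'model_name': 'elyza/ELYZA-japanese-Llama-2-7b-instruct',  # hypothetical example id
        'dtype': 'int4',
        'is_messages': True,
        'instruction': 'You are a helpful assistant.',
        'input': 'Explain in one sentence what this script does.',
        'max_new_tokens': 128,
    })
    print(result)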