Spaces:
Sleeping
Sleeping
File size: 4,102 Bytes
581f159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import os
import re
import torch
import datetime
import json
import csv
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import TextStreamer, TextIteratorStreamer
from transformers import GenerationConfig, AutoConfig, GPTQConfig, AwqConfig
from models import models
# Module-level cache of whatever load_model() last loaded; infer() reuses it
# instead of reloading weights on every request.
tokenizer = model = None  # populated by load_model()
loaded_model_name = loaded_dtype = None  # (name, dtype) of the cached model
def _build_load_kwargs(dtype):
    """Return the ``AutoModelForCausalLM.from_pretrained`` kwargs for *dtype*.

    'int4' loads with 4-bit bitsandbytes quantization (bf16 compute dtype),
    'int8' with 8-bit bitsandbytes quantization and bf16 for the remaining
    weights, 'fp16' / 'bf16' in that floating-point precision, and any other
    value with from_pretrained's default precision. Every variant uses
    ``device_map="auto"`` and ``trust_remote_code=True``.
    """
    kwargs = {'device_map': 'auto', 'trust_remote_code': True}
    if dtype == 'int4':
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif dtype == 'int8':
        # torch_dtype is a from_pretrained argument; BitsAndBytesConfig has no
        # such field, so passing it there (as before) had no effect.
        kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
        kwargs['torch_dtype'] = torch.bfloat16
    elif dtype == 'fp16':
        kwargs['torch_dtype'] = torch.float16
    elif dtype == 'bf16':
        kwargs['torch_dtype'] = torch.bfloat16
    return kwargs


def load_model(model_name, dtype='int4'):
    """Load *model_name* at *dtype* into the module-level globals.

    Returns immediately if exactly this (model_name, dtype) pair is already
    loaded. Otherwise the previous model/tokenizer references are dropped and
    ``gc.collect()`` / ``torch.cuda.empty_cache()`` run before loading, so the
    old weights can be freed first.

    Args:
        model_name: identifier passed to ``from_pretrained``.
        dtype: 'int4', 'int8', 'fp16' or 'bf16'; any other value loads with
            from_pretrained's default precision.

    Raises:
        Whatever ``from_pretrained`` raises. In that case the globals record
        that nothing is loaded, so a retry really reloads.
    """
    global tokenizer, model, loaded_model_name, loaded_dtype
    if loaded_model_name == model_name and loaded_dtype == dtype:
        return
    model = None
    tokenizer = None
    # Forget the old identity up front. If loading below raises, a later call
    # for the previous (name, dtype) must not take the early return above
    # while model/tokenizer are None.
    loaded_model_name = None
    loaded_dtype = None
    gc.collect()
    torch.cuda.empty_cache()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, **_build_load_kwargs(dtype))
    loaded_model_name = model_name
    loaded_dtype = dtype
# Request/config keys that control routing and prompt construction; they are
# never forwarded to model.generate().
_NON_GENERATION_KEYS = frozenset([
    'model_name', 'template', 'instruction', 'input', 'location',
    'endpoint', 'model', 'dtype', 'is_messages',
])


def _generation_kwargs(config, defaults):
    """Return *defaults* overlaid with every generation option in *config*.

    Keys in *config* win over *defaults*, so a model entry or request can
    override e.g. ``do_sample`` instead of triggering a duplicate-keyword
    TypeError in ``generate``. Neither argument is mutated.
    """
    merged = dict(defaults)
    merged.update((k, v) for k, v in config.items() if k not in _NON_GENERATION_KEYS)
    return merged


def _build_prompt(tok, config, instruction, user_input):
    """Render the prompt text for one request.

    When ``config['is_messages']`` is truthy, the tokenizer's chat template is
    applied to a system turn (*instruction*) plus a user turn (*user_input*,
    only if non-empty). Otherwise ``config['template']`` is formatted with
    ``bos_token``, ``instruction`` and ``input``.
    """
    if config['is_messages']:
        messages = [{"role": "system", "content": instruction}]
        if user_input:
            messages.append({"role": "user", "content": user_input})
        return tok.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
    return config['template'].format(bos_token=tok.bos_token, instruction=instruction, input=user_input)


def infer(args: dict):
    """Generate a completion for one request and return the decoded text.

    Args:
        args: request dict. Required: 'model_name' (or its alias 'model') and
            'instruction'. Optional: 'input' (defaults to ''), 'dtype' (see
            load_model). Any key not in _NON_GENERATION_KEYS is forwarded to
            ``model.generate``. The model's entry in ``models``, if present,
            supplies defaults (e.g. 'is_messages', 'template') that *args*
            overrides. *args* itself is not modified.

    Returns:
        The generated text after the prompt, decoded with special tokens
        skipped.
    """
    # Work on a copy so the caller's dict is not mutated.
    args = dict(args)
    if 'model' in args:
        args['model_name'] = args['model']
    model_name = args['model_name']

    # (Re)load when nothing is loaded, a different model is requested, or the
    # same model is requested at a different explicit dtype. Without a 'dtype'
    # key the currently loaded precision is kept.
    wants_other_dtype = 'dtype' in args and args['dtype'] != loaded_dtype
    if tokenizer is None or model_name != loaded_model_name or wants_other_dtype:
        if 'dtype' in args:
            load_model(model_name, args['dtype'])
        else:
            load_model(model_name)

    # Copy the registry entry. Updating it in place would leak this request's
    # keys into every later request for the same model.
    config = dict(models[model_name]) if model_name in models else {}
    config.update(args)

    prompt = _build_prompt(tokenizer, config, args['instruction'], args.get('input', ''))

    defaults = {'do_sample': True}
    if not config['is_messages']:
        defaults.update(
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    gen_kwargs = _generation_kwargs(config, defaults)

    with torch.no_grad():
        token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        output_ids = model.generate(
            input_ids=token_ids.to(model.device),
            **gen_kwargs,
        )
    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_len = token_ids.size(1)
    return tokenizer.decode(output_ids[0, prompt_len:].tolist(), skip_special_tokens=True)
|