import os import re import torch import datetime import json import csv import gc from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import TextIteratorStreamer from transformers import BitsAndBytesConfig, GPTQConfig from threading import Thread tokenizer = None model = None default_cfg = { 'model_name': None, 'qtype': 'bnb', 'dtype': '4bit', 'instruction': None, 'inst_template': None, 'chat_template': None, 'max_new_tokens': 1024, 'temperature': 0.9, 'top_p': 0.95, 'top_k': 40, 'repetition_penalty': 1.2, } cfg = default_cfg.copy() def load_model(model_name, qtype = 'bnb', dtype = '4bit'): global tokenizer, model, cfg if cfg['model_name'] == model_name and cfg['qtype'] == qtype and cfg['dtype'] == dtype: return del model del tokenizer model = None tokenizer = None gc.collect() torch.cuda.empty_cache() tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) match qtype: case 'bnb': match dtype: case '4bit' | 'int4': kwargs = dict( quantization_config=BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, ), ) case '8bit' | 'int8': kwargs = dict( quantization_config=BitsAndBytesConfig( load_in_8bit=True, bnb_4bit_compute_dtype=torch.bfloat16, ), ) case 'fp16': kwargs = dict( torch_dtype=torch.float16, ) case 'bf16': kwargs = dict( torch_dtype=torch.bfloat16, ) case _: kwargs = dict() case 'gptq': match dtype: case '4bit' | 'int4': kwargs = dict( quantization_config=GPTQConfig( bits=4, tokenizer=tokenizer, ), ) case '8bit' | 'int8': kwargs = dict( quantization_config=GPTQConfig( bits=8, tokenizer=tokenizer, ), ) case 'gguf': kwargs = dict( gguf_file=qtype, ) case 'awq': match dtype: case 'fa2': kwargs = dict( use_flash_attention_2=True, ) case _: kwargs = dict() model = AutoModelForCausalLM.from_pretrained( model_name, device_map="auto", trust_remote_code=True, **kwargs, ) cfg['model_name'] = model_name cfg['qtype'] = qtype cfg['dtype'] = dtype def clear_config(): global cfg cfg = default_cfg.copy() def set_config(model_name, qtype, dtype, instruction, inst_template, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty): global cfg load_model(model_name, qtype, dtype) cfg.update({ 'instruction': instruction, 'inst_template': inst_template, 'chat_template': chat_template, 'max_new_tokens': int(max_new_tokens), 'temperature': float(temperature), 'top_p': float(top_p), 'top_k': int(top_k), 'repetition_penalty': float(repetition_penalty), }) return 'done.' def set_config_args(args): global cfg load_model(args['model_name'], args['qtype'], args['dtype']) cfg.update(args) return 'done.' def chatinterface_to_messages(message, history): global cfg messages = [] if cfg['instruction']: messages.append({'role': 'system', 'content': cfg['instruction']}) for pair in history: [user, assistant] = pair if user: messages.append({'role': 'user', 'content': user}) if assistant: messages.append({'role': 'assistant', 'content': assistant}) if message: messages.append({'role': 'user', 'content': message}) return messages def chat(message, history = [], instruction = None, args = {}): global tokenizer, model, cfg if instruction: cfg['instruction'] = instruction prompt = apply_template(message) else: messages = chatinterface_to_messages(message, history) prompt = apply_template(messages) model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device) streamer = TextIteratorStreamer( tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True, ) generate_kwargs = dict( model_inputs, streamer=streamer, do_sample=True, num_beams=1, ) for k in [ 'max_new_tokens', 'temperature', 'top_p', 'top_k', 'repetition_penalty' ]: if cfg[k]: generate_kwargs[k] = cfg[k] t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() model_output = "" for new_text in streamer: model_output += new_text if 'fastapi' in args: # fastapiは差分だけを返して欲しい yield new_text else: # gradioは常に全文を返して欲しい yield model_output return model_output def infer(args: dict): global cfg if 'model_name' in args: load_model(args['model_name'], args['qtype'], args['dtype']) for k in [ 'instruction', 'inst_template', 'chat_template', 'max_new_tokens', 'temperature', 'top_p', 'top_k', 'repetition_penalty' ]: cfg[k] = args[k] if 'messages' in args: return chat(args['input'], args['messages']) if 'instruction' in args: return instruct(args['instruction'], args['input']) def apply_template(messages): global tokenizer, cfg if cfg['chat_template']: tokenizer.chat_template = cfg['chat_template'] if type(messages) is str: if cfg['inst_template']: return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages) return cfg['instruction'] if type(messages) is list: return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)