import os
import re
import torch
import datetime
import json
import csv
import gc
import local_gemma
from transformers import AutoTokenizer, TextStreamer
from transformers import TextIteratorStreamer
from transformers import BitsAndBytesConfig, GPTQConfig
from threading import Thread

tokenizer = None
model = None

cfg = {
    'size': None,
}

default_args = {
    'instruction': None,
    'first_assistant': None,
    'chat_template': None,
    'max_new_tokens': 1024,
    'temperature': 0.9,
    'top_p': 0.95,
    'top_k': 40,
    'repetition_penalty': 1.2,
}

# Per-conversation KV caches and message histories, keyed by conversation_id.
chat_past_key_values = {}
chat_messages = {}


def load_model(size='9b'):
    global tokenizer, model, cfg
    if cfg['size'] == size:
        return
    # Free the previously loaded model before switching sizes.
    del model
    del tokenizer
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()
    model_name = f"SillyTilly/google-gemma-2-{size}-it"
    model = local_gemma.LocalGemma2ForCausalLM.from_pretrained(model_name, preset="memory")
    model._supports_cache_class = True
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cfg['size'] = size


def set_config(size, instruction, first_assistant, chat_template,
               max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    global default_args
    load_model(size)
    default_args.update({
        'instruction': instruction,
        'first_assistant': first_assistant,
        'chat_template': chat_template,
        'max_new_tokens': int(max_new_tokens),
        'temperature': float(temperature),
        'top_p': float(top_p),
        'top_k': int(top_k),
        'repetition_penalty': float(repetition_penalty),
    })
    return 'done.'


def set_config_args(args):
    global default_args
    load_model(args['size'])
    default_args.update(args)
    return 'done.'


def chatinterface_to_messages(history):
    # Convert gradio ChatInterface history ([[user, assistant], ...]) into a messages list.
    messages = []
    for pair in history:
        [user, assistant] = pair
        if user:
            messages.append({'role': 'user', 'content': user})
        if assistant:
            messages.append({'role': 'assistant', 'content': assistant})
    return messages


# This one is fairly convoluted.
def tokenize(user_input, history=[], instruction=None, conversation_id='gradio', args={}):
    global tokenizer, chat_messages

    # Build the messages to prepend at the head of the conversation.
    inst_messages = []
    if instruction:
        if 'first_assistant' in args and args['first_assistant']:
            # Claude-compatible format:
            # user and assistant turns must strictly alternate.
            inst_messages = [
                {'role': 'user', 'content': instruction},
                {'role': 'assistant', 'content': args['first_assistant']},
            ]
        else:
            # OpenAI-compatible format.
            inst_messages = [{'role': 'system', 'content': instruction}]

    # If 'messages' is supplied, it overwrites the whole stored history.
    if conversation_id and 'messages' in args:
        chat_messages[conversation_id] = inst_messages + args['messages']

    # If a stored history (and hence a cache) exists, send the conversation in messages format.
    # The instruction is already part of the cache, so it is not re-applied
    # (it cannot be changed mid-conversation).
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Append the new user input.
        chat_messages[conversation_id] += [{'role': 'user', 'content': user_input}]
        tokenized_chat = tokenizer.apply_chat_template(
            chat_messages[conversation_id], tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
    else:
        # Apply the instruction if one is given (the input itself is optional).
        if instruction:
            user_input = instruction.format(input=user_input)
        # With neither an input nor an instruction there is nothing to tokenize.
        if not user_input:
            raise ValueError('input or instruction is required.')
        tokenized_chat = tokenizer(user_input, return_tensors="pt").input_ids

    return tokenized_chat


def chat(message, history=[], instruction=None, conversation_id='gradio', args={}):
    global tokenizer, model, chat_past_key_values, chat_messages

    for k, v in default_args.items():
        args.setdefault(k, v)

    cache = None
    # If a conversation_id is given, read its cache.
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # If 'clear' is specified, wipe the conversation first.
        if 'clear' in args and args['clear']:
            chat_past_key_values[conversation_id] = None
            chat_messages[conversation_id] = None
        else:
            cache = chat_past_key_values[conversation_id]

    # Apply a custom chat_template if one is given.
    if args['chat_template']:
        tokenizer.chat_template = args['chat_template']

    device = local_gemma.utils.config.infer_device(None)
    generation_kwargs = local_gemma.utils.config.get_generation_kwargs('chat')

    # Tokenize the prompt.
    tokenized_chat = tokenize(message, history, instruction, conversation_id, args).to(device)

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs.update(
        {
            "streamer": streamer,
            "assistant_model": None,
            "return_dict_in_generate": True,
            "past_key_values": cache,
        }
    )
    for k in ['max_new_tokens', 'temperature', 'top_p', 'top_k', 'repetition_penalty']:
        if args[k]:
            generation_kwargs[k] = args[k]

    # TODO(joao): this if shouldn't be needed, fix in transformers
    if cache is not None:
        generation_kwargs["cache_implementation"] = None

    if args['max_new_tokens'] is not None:
        input_ids_len = tokenized_chat.shape[-1]
        max_cache_len = args['max_new_tokens'] + input_ids_len
        if cache is not None and cache.max_cache_len < max_cache_len:
            # The existing cache is too small for this generation: reset it.
            generation_kwargs.pop("past_key_values")
            generation_kwargs["cache_implementation"] = "hybrid"
    else:
        generation_kwargs["max_length"] = model.config.max_position_embeddings

    gen_out = model.generate(input_ids=tokenized_chat, **generation_kwargs)
    model_tokens = gen_out.sequences[0, tokenized_chat.shape[1]:]
    model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)

    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Store the cache for the next generation round and pull the model output
        # into the chat history. The user turn was already appended in tokenize(),
        # so only the assistant turn is added here (appending it again would break
        # the strict user/assistant alternation required by the chat template).
        chat_past_key_values[conversation_id] = gen_out.past_key_values
        chat_messages[conversation_id] += [{"role": "assistant", "content": model_output_text}]

        # Sanity check: EOS was removed, the rendered conversation ends in "\n".
        tokenized_chat = tokenizer.apply_chat_template(
            chat_messages[conversation_id], tokenize=True, add_generation_prompt=False, return_tensors="pt"
        ).tolist()[0]
        assert tokenized_chat[0] == 2     # <bos>
        assert tokenized_chat[-1] == 108  # "\n"
        assert tokenized_chat[-2] == 107  # <end_of_turn>

    # TODO: support streaming.
    return model_output_text


# Non-streaming entry point.
def infer(message, history=[], instruction=None, conversation_id='gradio', args={}):
    return chat(message, history, instruction, conversation_id, args)


def numel(message, history=[], instruction=None, conversation_id='gradio', args={}):
    # Number of prompt tokens the given message would produce.
    global tokenizer, chat_messages
    tokenized_chat = tokenize(message, history, instruction, conversation_id, args)
    return torch.numel(tokenized_chat)
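

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a CUDA
# GPU with enough memory for the 9b checkpoint under local_gemma's "memory"
# preset; the conversation id 'demo', the instruction text and the prompts
# below are illustrative placeholders only.
if __name__ == '__main__':
    # Load the model and set the default sampling parameters.
    set_config(size='9b', instruction=None, first_assistant=None,
               chat_template=None, max_new_tokens=256, temperature=0.7,
               top_p=0.95, top_k=40, repetition_penalty=1.2)

    # First turn: passing 'messages' (here together with a Claude-style
    # instruction/first_assistant pair) initialises the stored history, so
    # this and later turns go through the chat-template + KV-cache path.
    print(chat('Hello! Who are you?',
               instruction='You are a concise assistant. Answer in one sentence.',
               conversation_id='demo',
               args={'messages': [], 'first_assistant': 'Understood.'}))

    # Follow-up turn in the same conversation reuses the stored past_key_values.
    print(chat('And what can you do?', conversation_id='demo'))

    # Rough token count for a standalone prompt (no conversation history).
    print(numel('One more question?', conversation_id=None))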