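# Chat helpers around local_gemma's LocalGemma2ForCausalLM (Gemma 2 instruction-tuned
# checkpoints). The module keeps a per-conversation message history and past_key_values
# cache so multi-turn chats can reuse the KV cache between generation rounds.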
import os
import re
import torch
import datetime
import json
import csv
import gc
import local_gemma
from transformers import AutoTokenizer, TextStreamer
from transformers import TextIteratorStreamer
from transformers import BitsAndBytesConfig, GPTQConfig
from threading import Thread

tokenizer = None
model = None

cfg = {
    'size': None,
}

default_args = {
    'instruction': None,
    'first_assistant': None,
    'chat_template': None,
    'max_new_tokens': 1024,
    'temperature': 0.9,
    'top_p': 0.95,
    'top_k': 40,
    'repetition_penalty': 1.2,
}

chat_past_key_values = {}
chat_messages = {}

def load_model(size='9b'):
    global tokenizer, model, cfg
    if cfg['size'] == size:
        return
    # Release the previously loaded model/tokenizer before switching sizes.
    del model
    del tokenizer
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()
    model_name = f"SillyTilly/google-gemma-2-{size}-it"
    model = local_gemma.LocalGemma2ForCausalLM.from_pretrained(model_name, preset="memory")
    model._supports_cache_class = True
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cfg['size'] = size

def set_config(size, instruction, first_assistant, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    global default_args
    load_model(size)
    default_args.update({
        'instruction': instruction,
        'first_assistant': first_assistant,
        'chat_template': chat_template,
        'max_new_tokens': int(max_new_tokens),
        'temperature': float(temperature),
        'top_p': float(top_p),
        'top_k': int(top_k),
        'repetition_penalty': float(repetition_penalty),
    })
    return 'done.'

def set_config_args(args):
    global default_args
    load_model(args['size'])
    default_args.update(args)
    return 'done.'

def chatinterface_to_messages(history):
    messages = []
    for user, assistant in history:
        if user:
            messages.append({'role': 'user', 'content': user})
        if assistant:
            messages.append({'role': 'assistant', 'content': assistant})
    return messages
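
# For reference, Gradio ChatInterface history pairs map to OpenAI-style messages:
# chatinterface_to_messages([['Hi', 'Hello!']]) ->
#   [{'role': 'user', 'content': 'Hi'}, {'role': 'assistant', 'content': 'Hello!'}]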

# Fairly involved.
def tokenize(user_input, history=[], instruction=None, conversation_id='gradio', args={}):
    global tokenizer, chat_messages
    # Build the messages to prepend at the head of the conversation
    inst_messages = []
    if instruction:
        if 'first_assistant' in args and args['first_assistant']:
            # Claude-compatible format:
            # user and assistant turns must alternate
            inst_messages = [
                {'role': 'user', 'content': instruction},
                {'role': 'assistant', 'content': args['first_assistant']},
            ]
        else:
            # OpenAI-compatible format
            inst_messages = [{'role': 'system', 'content': instruction}]
    # When messages are passed in, overwrite the whole stored history
    if conversation_id and 'messages' in args:
        chat_messages[conversation_id] = inst_messages + args['messages']
    # If a cached conversation exists, send it in messages format.
    # The instruction is already part of the cache, so it is not added again
    # (it cannot be changed mid-conversation).
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Append user_input as the new user turn
        chat_messages[conversation_id] += [{'role': 'user', 'content': user_input}]
        tokenized_chat = tokenizer.apply_chat_template(
            chat_messages[conversation_id], tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
    else:
        # Apply the instruction if there is one (the input itself is optional)
        if instruction:
            user_input = instruction.format(input=user_input)
        # With neither an input nor an instruction, raise an error
        if not user_input:
            raise ValueError('require input or instruction.')
        tokenized_chat = tokenizer(user_input, return_tensors="pt").input_ids
    return tokenized_chat
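
# Example (illustrative values only): with instruction='You are terse.' and
# args={'first_assistant': 'Understood.', 'messages': []}, the stored history for the
# conversation starts as
#   [{'role': 'user', 'content': 'You are terse.'},
#    {'role': 'assistant', 'content': 'Understood.'}]
# and each subsequent call appends the new user turn before applying the chat template.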

def chat(message, history=[], instruction=None, conversation_id='gradio', args={}):
    global tokenizer, model, chat_past_key_values, chat_messages
    # Copy before filling in defaults so neither the caller's dict nor the
    # mutable default argument is modified across calls.
    args = dict(args)
    for k, v in default_args.items():
        args.setdefault(k, v)
    cache = None
    # When a conversation_id is given, read its cache
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # If clear is specified, wipe the conversation first
        if 'clear' in args and args['clear']:
            chat_past_key_values[conversation_id] = None
            chat_messages[conversation_id] = None
        else:
            # .get() so a history populated via args['messages'] without a prior
            # generation round does not raise a KeyError
            cache = chat_past_key_values.get(conversation_id)
    # Apply chat_template if one is provided
    if args['chat_template']:
        tokenizer.chat_template = args['chat_template']
    device = local_gemma.utils.config.infer_device(None)
    generation_kwargs = local_gemma.utils.config.get_generation_kwargs('chat')
    # Tokenize the request (messages format when a cached conversation exists)
    tokenized_chat = tokenize(message, history, instruction, conversation_id, args).to(device)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs.update(
        {
            "streamer": streamer,
            "assistant_model": None,
            "return_dict_in_generate": True,
            "past_key_values": cache,
        }
    )
    for k in [
        'max_new_tokens',
        'temperature',
        'top_p',
        'top_k',
        'repetition_penalty'
    ]:
        if args[k]:
            generation_kwargs[k] = args[k]
    # TODO(joao): this if shouldn't be needed, fix in transformers
    if cache is not None:
        generation_kwargs["cache_implementation"] = None
    if args['max_new_tokens'] is not None:
        input_ids_len = tokenized_chat.shape[-1]
        max_cache_len = args['max_new_tokens'] + input_ids_len
        if cache is not None and cache.max_cache_len < max_cache_len:
            # reset the cache
            generation_kwargs.pop("past_key_values")
            generation_kwargs["cache_implementation"] = "hybrid"
    else:
        generation_kwargs["max_length"] = model.config.max_position_embeddings
    gen_out = model.generate(input_ids=tokenized_chat, **generation_kwargs)
    model_tokens = gen_out.sequences[0, tokenized_chat.shape[1]:]
    model_output_text = tokenizer.decode(model_tokens, skip_special_tokens=True)
    if conversation_id and conversation_id in chat_messages and chat_messages[conversation_id]:
        # Store the cache for the next generation round and pull the model output into
        # the chat history. tokenize() already appended the user turn, so only the
        # assistant turn is added here (adding both would duplicate the user turn and
        # break the role alternation that the Gemma chat template enforces).
        chat_past_key_values[conversation_id] = gen_out.past_key_values
        chat_messages[conversation_id] += [{"role": "assistant", "content": model_output_text}]
        # Sanity check: EOS was removed, ends in "<end_of_turn>\n"
        tokenized_chat = tokenizer.apply_chat_template(
            chat_messages[conversation_id], tokenize=True, add_generation_prompt=False, return_tensors="pt"
        ).tolist()[0]
        assert tokenized_chat[0] == 2
        assert tokenized_chat[-1] == 108
        assert tokenized_chat[-2] == 107
    # TODO: support streaming
    return model_output_text

# Return the result without streaming
def infer(message, history=[], instruction=None, conversation_id='gradio', args={}):
    return chat(message, history, instruction, conversation_id, args)

def numel(message, history=[], instruction=None, conversation_id='gradio', args={}):
    global tokenizer, chat_messages
    tokenized_chat = tokenize(message, history, instruction, conversation_id, args)
    return torch.numel(tokenized_chat)
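
# A minimal usage sketch, not part of the Space itself. It assumes a CUDA-capable
# machine with enough memory for the 9b checkpoint and that the
# SillyTilly/google-gemma-2-9b-it repository is reachable; the conversation id and
# prompt below are illustrative only.
if __name__ == '__main__':
    set_config_args({'size': '9b', 'max_new_tokens': 256})
    print(numel('Explain the KV cache in one sentence.'))
    print(chat('Explain the KV cache in one sentence.', conversation_id='demo'))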