Spaces:
Runtime error
Runtime error
| #!python | |
| # -*- coding: utf-8 -*- | |
| # @author: Kun | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig | |
| max_token: int = 10000 # 10000 # 64 | |
| temperature: float = 0.75 | |
| top_p = 0.9 | |
| use_lora = False | |
| # model_name_or_path = "Hannes-Epoch/falcon-7b-instruct-8bit" # not work, miss file | |
| def load_model(opt="gptq"): | |
| if "pt" == opt: | |
| return load_pt_model() | |
| elif "gptq" == opt: | |
| return load_gptq_model() | |
| else: | |
| raise Exception("not supported opt: {}".format(opt)) | |
| ######################################################################################################## | |
| def load_gptq_model(): | |
| model_name_or_path = "TheBloke/falcon-7b-instruct-GPTQ" | |
| # You could also download the model locally, and access it there | |
| # model_name_or_path = "/path/to/TheBloke_falcon-7b-instruct-GPTQ" | |
| model_basename = "gptq_model-4bit-64g" | |
| use_triton = False | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| model_name_or_path, use_fast=True) | |
| model = AutoGPTQForCausalLM.from_quantized(model_name_or_path, | |
| model_basename=model_basename, | |
| use_safetensors=True, | |
| trust_remote_code=True, | |
| device="cuda:0", | |
| use_triton=use_triton, | |
| quantize_config=None) | |
| return tokenizer, model | |
| ######################################################################################################## | |
| def load_pt_model(): | |
| model_name_or_path = "tiiuae/falcon-7b" | |
| # model_name_or_path = "tiiuae/falcon-7b-instruct" | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| model_name_or_path, | |
| trust_remote_code=True, | |
| ) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_name_or_path, | |
| trust_remote_code=True, | |
| device_map='auto', | |
| # load_in_8bit=True, # not working "RWForCausalLM.__init__() got an unexpected keyword argument 'load_in_8bit'" | |
| ) | |
| return tokenizer, model | |
| ######################################################################################################## |