"""
Usage:
python3 -m fastchat.serve.huggingface_api --model ~/model_weights/vicuna-7b/
"""
import argparse
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastchat.conversation import get_default_conv_template, compute_skip_echo_len
from fastchat.serve.inference import load_model
def main(args):
    """Run a single-turn chat completion against a local HuggingFace model.

    Loads the model and tokenizer, wraps ``args.message`` in the model's
    default conversation template, generates one reply, strips the echoed
    prompt, and prints both turns to stdout.

    Args:
        args: Parsed CLI namespace; reads model_path, device, num_gpus,
            max_gpu_memory, load_8bit, debug, message, temperature, and
            max_new_tokens.
    """
    model, tokenizer = load_model(
        args.model_path,
        args.device,
        args.num_gpus,
        args.max_gpu_memory,
        args.load_8bit,
        debug=args.debug,
    )

    msg = args.message

    conv = get_default_conv_template(args.model_path).copy()
    conv.append_message(conv.roles[0], msg)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    inputs = tokenizer([prompt])
    # Bug fix: move the input tensor to the user-selected device rather than
    # unconditionally calling .cuda(), which crashed with --device cpu/mps.
    input_ids = torch.as_tensor(inputs.input_ids).to(args.device)
    output_ids = model.generate(
        input_ids,
        do_sample=True,
        # Bug fix: honor the --temperature / --max-new-tokens CLI flags
        # instead of the previously hard-coded 0.7 / 1024.
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
    )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    # Drop the echoed prompt so only the model's reply remains.
    skip_echo_len = compute_skip_echo_len(args.model_path, conv, prompt)
    outputs = outputs[skip_echo_len:]

    print(f"{conv.roles[0]}: {msg}")
    print(f"{conv.roles[1]}: {outputs}")
if __name__ == "__main__":
    # Declarative option table: (flag, keyword arguments) pairs, registered
    # in order so --help output matches the original layout.
    cli_options = [
        (
            "--model-path",
            dict(type=str, default="facebook/opt-350m", help="The path to the weights"),
        ),
        ("--device", dict(type=str, choices=["cpu", "cuda", "mps"], default="cuda")),
        ("--num-gpus", dict(type=str, default="1")),
        (
            "--max-gpu-memory",
            dict(type=str, help="The maximum memory per gpu. Use a string like '13Gib'"),
        ),
        ("--load-8bit", dict(action="store_true", help="Use 8-bit quantization.")),
        (
            "--conv-template",
            dict(type=str, default=None, help="Conversation prompt template."),
        ),
        ("--temperature", dict(type=float, default=0.7)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--debug", dict(action="store_true")),
        ("--message", dict(type=str, default="Hello! Who are you?")),
    ]

    parser = argparse.ArgumentParser()
    for flag, opts in cli_options:
        parser.add_argument(flag, **opts)

    main(parser.parse_args())