"""
Usage:
python3 -m fastchat.serve.huggingface_api --model ~/model_weights/vicuna-7b/
"""
import argparse
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from fastchat.conversation import compute_skip_echo_len, get_default_conv_template
from fastchat.serve.inference import load_model
def main(args):
    """Generate one chat completion with a HuggingFace model and print it.

    Loads the model/tokenizer, wraps ``args.message`` in the model's default
    conversation template, generates a response, strips the echoed prompt,
    and prints both sides of the exchange.
    """
    model, tokenizer = load_model(
        args.model_path,
        args.device,
        args.num_gpus,
        args.max_gpu_memory,
        args.load_8bit,
        debug=args.debug,
    )

    msg = args.message

    # NOTE(review): args.conv_template is parsed by the CLI but not consulted
    # here — the model-path default template is always used. Confirm intent.
    conv = get_default_conv_template(args.model_path).copy()
    conv.append_message(conv.roles[0], msg)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    inputs = tokenizer([prompt])
    output_ids = model.generate(
        # Place inputs on the requested device instead of hard-coding .cuda(),
        # so --device cpu / mps work as the argument choices advertise.
        torch.as_tensor(inputs.input_ids).to(args.device),
        do_sample=True,
        # Honor the CLI flags; these were previously hard-coded to 0.7/1024,
        # silently ignoring --temperature and --max-new-tokens.
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
    )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

    # Drop the echoed prompt prefix so only the model's reply is printed.
    skip_echo_len = compute_skip_echo_len(args.model_path, conv, prompt)
    outputs = outputs[skip_echo_len:]

    print(f"{conv.roles[0]}: {msg}")
    print(f"{conv.roles[1]}: {outputs}")
if __name__ == "__main__":
    # Command-line interface: model location, execution device, generation
    # knobs, and the single user message to respond to.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model-path",
        type=str,
        default="facebook/opt-350m",
        help="The path to the weights",
    )
    cli.add_argument(
        "--device", type=str, choices=["cpu", "cuda", "mps"], default="cuda"
    )
    cli.add_argument("--num-gpus", type=str, default="1")
    cli.add_argument(
        "--max-gpu-memory",
        type=str,
        help="The maximum memory per gpu. Use a string like '13Gib'",
    )
    cli.add_argument(
        "--load-8bit", action="store_true", help="Use 8-bit quantization."
    )
    cli.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    cli.add_argument("--temperature", type=float, default=0.7)
    cli.add_argument("--max-new-tokens", type=int, default=512)
    cli.add_argument("--debug", action="store_true")
    cli.add_argument("--message", type=str, default="Hello! Who are you?")
    main(cli.parse_args())