"""
Chat with a model with command line interface.
Usage:
python3 -m fastchat.serve.cli --model ~/model_weights/llama-7b
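python3 -m fastchat.serve.cli --model-path facebook/opt-350m --style rich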
"""
import argparse
import os
import re

from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from rich.console import Console
from rich.markdown import Markdown
from rich.live import Live

from fastchat.serve.inference import chat_loop, ChatIO


class SimpleChatIO(ChatIO):
    def prompt_for_input(self, role) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int):
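        """Print the stream word by word, holding back the last
        (possibly incomplete) word until the next chunk arrives."""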
        pre = 0
        for outputs in output_stream:
            outputs = outputs[skip_echo_len:].strip()
            outputs = outputs.split(" ")
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(outputs[pre:]), flush=True)
        return " ".join(outputs)


class RichChatIO(ChatIO):
    def __init__(self):
        self._prompt_session = PromptSession(history=InMemoryHistory())
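        # Offer the "!exit" and "!reset" chat commands as completions.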
        self._completer = WordCompleter(
            words=["!exit", "!reset"], pattern=re.compile("$")
        )
        self._console = Console()

    def prompt_for_input(self, role) -> str:
        self._console.print(f"[bold]{role}:")
        # TODO(suquark): multiline input has some issues. fix it later.
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str):
        self._console.print(f"[bold]{role}:")

    def stream_output(self, output_stream, skip_echo_len: int):
        """Stream output from a role."""
        # TODO(suquark): the console flickers when there is a code block
        # above it. We need to cut off "live" when a code block is done.

        # Create a Live context for updating the console output
        with Live(console=self._console, refresh_per_second=4) as live:
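            # Cap refreshes at 4 per second to limit redraw churn.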
            # Read lines from the stream
            for outputs in output_stream:
                accumulated_text = outputs[skip_echo_len:]
                if not accumulated_text:
                    continue
                # Render the accumulated text as Markdown
                # NOTE: this is a workaround for rendering non-standard
                # markdown with rich. The chatbot's output treats "\n" as a
                # new line for better compatibility with real-world text.
                # However, rendering it as markdown would break the format,
                # because standard markdown treats a single "\n" in normal
                # text as a space. Our workaround is to add two spaces to
                # the end of each line, which markdown renders as a hard
                # line break. This is not a perfect solution, as it
                # introduces trailing spaces inside code blocks, but it
                # works well for console output, which generally does not
                # care about trailing spaces.
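                # For example, "a\nb" renders as a single line in
                # standard markdown, while "a  \nb" renders as two lines.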
                lines = []
                for line in accumulated_text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        # Code block marker - do not add trailing spaces,
                        # as that would break the syntax highlighting
                        lines.append("\n")
                    else:
                        lines.append("  \n")
                markdown = Markdown("".join(lines))
                # Update the Live console output
                live.update(markdown)
        self._console.print()
        return outputs[skip_echo_len:]


def main(args):
    if args.gpus:
        if args.num_gpus and len(args.gpus.split(",")) < int(args.num_gpus):
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
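        # Restrict the visible devices before CUDA is initialized.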
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
if args.style == "simple":
chatio = SimpleChatIO()
elif args.style == "rich":
chatio = RichChatIO()
else:
raise ValueError(f"Invalid style for console: {args.style}")
try:
chat_loop(
args.model_path,
args.device,
args.num_gpus,
args.max_gpu_memory,
args.load_8bit,
args.conv_template,
args.temperature,
args.max_new_tokens,
chatio,
args.debug,
)
except KeyboardInterrupt:
print("exit...")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-path",
type=str,
default="facebook/opt-350m",
help="The path to the weights",
)
parser.add_argument(
"--device", type=str, choices=["cpu", "cuda", "mps"], default="cuda"
)
parser.add_argument(
"--gpus",
type=str,
default=None,
help="A single GPU like 1 or multiple GPUs like 0,2"
)
parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="The maximum memory per GPU. Use a string like '13GiB'",
    )
    parser.add_argument(
        "--load-8bit", action="store_true", help="Use 8-bit quantization."
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument(
        "--style",
        type=str,
        default="simple",
        choices=["simple", "rich"],
        help="Display style.",
    )
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    main(args)