zhangbofei
commited on
Commit
•
2238fe2
1
Parent(s):
8d7d353
fix: src
Browse files- src/serve/api_provider.py +1 -1
- src/serve/base_model_worker.py +3 -3
- src/serve/cli.py +7 -7
- src/serve/controller.py +2 -2
- src/serve/gradio_block_arena_named.py +5 -5
- src/serve/gradio_block_arena_vision_anony.py +8 -8
- src/serve/huggingface_api.py +1 -1
- src/serve/huggingface_api_worker.py +3 -3
- src/serve/inference.py +7 -7
- src/serve/lightllm_worker.py +2 -2
- src/serve/mlx_worker.py +3 -3
- src/serve/model_worker.py +8 -8
- src/serve/multi_model_worker.py +11 -11
- src/serve/openai_api_server.py +6 -6
src/serve/api_provider.py
CHANGED
@@ -9,7 +9,7 @@ import time
|
|
9 |
|
10 |
import requests
|
11 |
|
12 |
-
from
|
13 |
|
14 |
|
15 |
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
|
|
9 |
|
10 |
import requests
|
11 |
|
12 |
+
from src.utils import build_logger
|
13 |
|
14 |
|
15 |
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
src/serve/base_model_worker.py
CHANGED
@@ -7,9 +7,9 @@ from fastapi import FastAPI, Request, BackgroundTasks
|
|
7 |
from fastapi.responses import StreamingResponse, JSONResponse
|
8 |
import requests
|
9 |
|
10 |
-
from
|
11 |
-
from
|
12 |
-
from
|
13 |
|
14 |
|
15 |
worker = None
|
|
|
7 |
from fastapi.responses import StreamingResponse, JSONResponse
|
8 |
import requests
|
9 |
|
10 |
+
from src.constants import WORKER_HEART_BEAT_INTERVAL
|
11 |
+
from src.conversation import Conversation
|
12 |
+
from src.utils import pretty_print_semaphore, build_logger
|
13 |
|
14 |
|
15 |
worker = None
|
src/serve/cli.py
CHANGED
@@ -28,13 +28,13 @@ from rich.live import Live
|
|
28 |
from rich.markdown import Markdown
|
29 |
import torch
|
30 |
|
31 |
-
from
|
32 |
-
from
|
33 |
-
from
|
34 |
-
from
|
35 |
-
from
|
36 |
-
from
|
37 |
-
from
|
38 |
|
39 |
|
40 |
class SimpleChatIO(ChatIO):
|
|
|
28 |
from rich.markdown import Markdown
|
29 |
import torch
|
30 |
|
31 |
+
from src.model.model_adapter import add_model_args
|
32 |
+
from src.modules.awq import AWQConfig
|
33 |
+
from src.modules.exllama import ExllamaConfig
|
34 |
+
from src.modules.xfastertransformer import XftConfig
|
35 |
+
from src.modules.gptq import GptqConfig
|
36 |
+
from src.serve.inference import ChatIO, chat_loop
|
37 |
+
from src.utils import str_to_torch_dtype
|
38 |
|
39 |
|
40 |
class SimpleChatIO(ChatIO):
|
src/serve/controller.py
CHANGED
@@ -19,13 +19,13 @@ import numpy as np
|
|
19 |
import requests
|
20 |
import uvicorn
|
21 |
|
22 |
-
from
|
23 |
CONTROLLER_HEART_BEAT_EXPIRATION,
|
24 |
WORKER_API_TIMEOUT,
|
25 |
ErrorCode,
|
26 |
SERVER_ERROR_MSG,
|
27 |
)
|
28 |
-
from
|
29 |
|
30 |
|
31 |
logger = build_logger("controller", "controller.log")
|
|
|
19 |
import requests
|
20 |
import uvicorn
|
21 |
|
22 |
+
from src.constants import (
|
23 |
CONTROLLER_HEART_BEAT_EXPIRATION,
|
24 |
WORKER_API_TIMEOUT,
|
25 |
ErrorCode,
|
26 |
SERVER_ERROR_MSG,
|
27 |
)
|
28 |
+
from src.utils import build_logger
|
29 |
|
30 |
|
31 |
logger = build_logger("controller", "controller.log")
|
src/serve/gradio_block_arena_named.py
CHANGED
@@ -9,14 +9,14 @@ import time
|
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
|
12 |
-
from
|
13 |
MODERATION_MSG,
|
14 |
CONVERSATION_LIMIT_MSG,
|
15 |
INPUT_CHAR_LEN_LIMIT,
|
16 |
CONVERSATION_TURN_LIMIT,
|
17 |
)
|
18 |
-
from
|
19 |
-
from
|
20 |
State,
|
21 |
bot_response,
|
22 |
get_conv_log_filename,
|
@@ -29,8 +29,8 @@ from fastchat.serve.gradio_web_server import (
|
|
29 |
_prepare_text_with_image,
|
30 |
get_model_description_md,
|
31 |
)
|
32 |
-
from
|
33 |
-
from
|
34 |
build_logger,
|
35 |
moderation_filter,
|
36 |
)
|
|
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
|
12 |
+
from src.constants import (
|
13 |
MODERATION_MSG,
|
14 |
CONVERSATION_LIMIT_MSG,
|
15 |
INPUT_CHAR_LEN_LIMIT,
|
16 |
CONVERSATION_TURN_LIMIT,
|
17 |
)
|
18 |
+
from src.model.model_adapter import get_conversation_template
|
19 |
+
from src.serve.gradio_web_server import (
|
20 |
State,
|
21 |
bot_response,
|
22 |
get_conv_log_filename,
|
|
|
29 |
_prepare_text_with_image,
|
30 |
get_model_description_md,
|
31 |
)
|
32 |
+
from src.serve.remote_logger import get_remote_logger
|
33 |
+
from src.utils import (
|
34 |
build_logger,
|
35 |
moderation_filter,
|
36 |
)
|
src/serve/gradio_block_arena_vision_anony.py
CHANGED
@@ -9,7 +9,7 @@ import time
|
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
|
12 |
-
from
|
13 |
TEXT_MODERATION_MSG,
|
14 |
IMAGE_MODERATION_MSG,
|
15 |
MODERATION_MSG,
|
@@ -18,9 +18,9 @@ from fastchat.constants import (
|
|
18 |
INPUT_CHAR_LEN_LIMIT,
|
19 |
CONVERSATION_TURN_LIMIT,
|
20 |
)
|
21 |
-
from
|
22 |
-
from
|
23 |
-
from
|
24 |
State,
|
25 |
bot_response,
|
26 |
get_conv_log_filename,
|
@@ -33,7 +33,7 @@ from fastchat.serve.gradio_web_server import (
|
|
33 |
get_model_description_md,
|
34 |
_prepare_text_with_image,
|
35 |
)
|
36 |
-
from
|
37 |
flash_buttons,
|
38 |
vote_last_response,
|
39 |
leftvote_last_response,
|
@@ -50,15 +50,15 @@ from fastchat.serve.gradio_block_arena_anony import (
|
|
50 |
get_sample_weight,
|
51 |
get_battle_pair,
|
52 |
)
|
53 |
-
from
|
54 |
get_vqa_sample,
|
55 |
set_invisible_image,
|
56 |
set_visible_image,
|
57 |
add_image,
|
58 |
moderate_input,
|
59 |
)
|
60 |
-
from
|
61 |
-
from
|
62 |
build_logger,
|
63 |
moderation_filter,
|
64 |
image_moderation_filter,
|
|
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
|
12 |
+
from src.constants import (
|
13 |
TEXT_MODERATION_MSG,
|
14 |
IMAGE_MODERATION_MSG,
|
15 |
MODERATION_MSG,
|
|
|
18 |
INPUT_CHAR_LEN_LIMIT,
|
19 |
CONVERSATION_TURN_LIMIT,
|
20 |
)
|
21 |
+
from src.model.model_adapter import get_conversation_template
|
22 |
+
from src.serve.gradio_block_arena_named import flash_buttons
|
23 |
+
from src.serve.gradio_web_server import (
|
24 |
State,
|
25 |
bot_response,
|
26 |
get_conv_log_filename,
|
|
|
33 |
get_model_description_md,
|
34 |
_prepare_text_with_image,
|
35 |
)
|
36 |
+
from src.serve.gradio_block_arena_anony import (
|
37 |
flash_buttons,
|
38 |
vote_last_response,
|
39 |
leftvote_last_response,
|
|
|
50 |
get_sample_weight,
|
51 |
get_battle_pair,
|
52 |
)
|
53 |
+
from src.serve.gradio_block_arena_vision import (
|
54 |
get_vqa_sample,
|
55 |
set_invisible_image,
|
56 |
set_visible_image,
|
57 |
add_image,
|
58 |
moderate_input,
|
59 |
)
|
60 |
+
from src.serve.remote_logger import get_remote_logger
|
61 |
+
from src.utils import (
|
62 |
build_logger,
|
63 |
moderation_filter,
|
64 |
image_moderation_filter,
|
src/serve/huggingface_api.py
CHANGED
@@ -9,7 +9,7 @@ import argparse
|
|
9 |
|
10 |
import torch
|
11 |
|
12 |
-
from
|
13 |
|
14 |
|
15 |
@torch.inference_mode()
|
|
|
9 |
|
10 |
import torch
|
11 |
|
12 |
+
from src.model import load_model, get_conversation_template, add_model_args
|
13 |
|
14 |
|
15 |
@torch.inference_mode()
|
src/serve/huggingface_api_worker.py
CHANGED
@@ -34,9 +34,9 @@ from fastapi import BackgroundTasks, FastAPI, Request
|
|
34 |
from fastapi.responses import JSONResponse, StreamingResponse
|
35 |
from huggingface_hub import InferenceClient
|
36 |
|
37 |
-
from
|
38 |
-
from
|
39 |
-
from
|
40 |
|
41 |
worker_id = str(uuid.uuid4())[:8]
|
42 |
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
|
|
34 |
from fastapi.responses import JSONResponse, StreamingResponse
|
35 |
from huggingface_hub import InferenceClient
|
36 |
|
37 |
+
from src.constants import SERVER_ERROR_MSG, ErrorCode
|
38 |
+
from src.serve.base_model_worker import BaseModelWorker
|
39 |
+
from src.utils import build_logger
|
40 |
|
41 |
worker_id = str(uuid.uuid4())[:8]
|
42 |
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
src/serve/inference.py
CHANGED
@@ -29,17 +29,17 @@ from transformers.generation.logits_process import (
|
|
29 |
TopPLogitsWarper,
|
30 |
)
|
31 |
|
32 |
-
from
|
33 |
-
from
|
34 |
load_model,
|
35 |
get_conversation_template,
|
36 |
get_generate_stream_function,
|
37 |
)
|
38 |
-
from
|
39 |
-
from
|
40 |
-
from
|
41 |
-
from
|
42 |
-
from
|
43 |
|
44 |
|
45 |
def prepare_logits_processor(
|
|
|
29 |
TopPLogitsWarper,
|
30 |
)
|
31 |
|
32 |
+
from src.conversation import get_conv_template, SeparatorStyle
|
33 |
+
from src.model.model_adapter import (
|
34 |
load_model,
|
35 |
get_conversation_template,
|
36 |
get_generate_stream_function,
|
37 |
)
|
38 |
+
from src.modules.awq import AWQConfig
|
39 |
+
from src.modules.gptq import GptqConfig
|
40 |
+
from src.modules.exllama import ExllamaConfig
|
41 |
+
from src.modules.xfastertransformer import XftConfig
|
42 |
+
from src.utils import is_partial_stop, is_sentence_complete, get_context_length
|
43 |
|
44 |
|
45 |
def prepare_logits_processor(
|
src/serve/lightllm_worker.py
CHANGED
@@ -18,8 +18,8 @@ from typing import List
|
|
18 |
from fastapi import FastAPI, Request, BackgroundTasks
|
19 |
from fastapi.responses import StreamingResponse, JSONResponse
|
20 |
|
21 |
-
from
|
22 |
-
from
|
23 |
logger,
|
24 |
worker_id,
|
25 |
)
|
|
|
18 |
from fastapi import FastAPI, Request, BackgroundTasks
|
19 |
from fastapi.responses import StreamingResponse, JSONResponse
|
20 |
|
21 |
+
from src.serve.base_model_worker import BaseModelWorker
|
22 |
+
from src.serve.model_worker import (
|
23 |
logger,
|
24 |
worker_id,
|
25 |
)
|
src/serve/mlx_worker.py
CHANGED
@@ -22,12 +22,12 @@ from fastapi.concurrency import run_in_threadpool
|
|
22 |
from fastapi.responses import StreamingResponse, JSONResponse
|
23 |
import uvicorn
|
24 |
|
25 |
-
from
|
26 |
-
from
|
27 |
logger,
|
28 |
worker_id,
|
29 |
)
|
30 |
-
from
|
31 |
|
32 |
import mlx.core as mx
|
33 |
from mlx_lm import load, generate
|
|
|
22 |
from fastapi.responses import StreamingResponse, JSONResponse
|
23 |
import uvicorn
|
24 |
|
25 |
+
from src.serve.base_model_worker import BaseModelWorker
|
26 |
+
from src.serve.model_worker import (
|
27 |
logger,
|
28 |
worker_id,
|
29 |
)
|
30 |
+
from src.utils import get_context_length, is_partial_stop
|
31 |
|
32 |
import mlx.core as mx
|
33 |
from mlx_lm import load, generate
|
src/serve/model_worker.py
CHANGED
@@ -14,18 +14,18 @@ import torch.nn.functional as F
|
|
14 |
from transformers import set_seed
|
15 |
import uvicorn
|
16 |
|
17 |
-
from
|
18 |
-
from
|
19 |
load_model,
|
20 |
add_model_args,
|
21 |
get_generate_stream_function,
|
22 |
)
|
23 |
-
from
|
24 |
-
from
|
25 |
-
from
|
26 |
-
from
|
27 |
-
from
|
28 |
-
from
|
29 |
build_logger,
|
30 |
get_context_length,
|
31 |
str_to_torch_dtype,
|
|
|
14 |
from transformers import set_seed
|
15 |
import uvicorn
|
16 |
|
17 |
+
from src.constants import ErrorCode, SERVER_ERROR_MSG
|
18 |
+
from src.model.model_adapter import (
|
19 |
load_model,
|
20 |
add_model_args,
|
21 |
get_generate_stream_function,
|
22 |
)
|
23 |
+
from src.modules.awq import AWQConfig
|
24 |
+
from src.modules.exllama import ExllamaConfig
|
25 |
+
from src.modules.xfastertransformer import XftConfig
|
26 |
+
from src.modules.gptq import GptqConfig
|
27 |
+
from src.serve.base_model_worker import BaseModelWorker, app
|
28 |
+
from src.utils import (
|
29 |
build_logger,
|
30 |
get_context_length,
|
31 |
str_to_torch_dtype,
|
src/serve/multi_model_worker.py
CHANGED
@@ -44,21 +44,21 @@ import torch
|
|
44 |
import torch.nn.functional as F
|
45 |
import uvicorn
|
46 |
|
47 |
-
from
|
48 |
-
from
|
49 |
load_model,
|
50 |
add_model_args,
|
51 |
get_conversation_template,
|
52 |
)
|
53 |
-
from
|
54 |
-
from
|
55 |
-
from
|
56 |
-
from
|
57 |
-
from
|
58 |
-
from
|
59 |
-
from
|
60 |
-
from
|
61 |
-
from
|
62 |
|
63 |
|
64 |
# We store both the underlying workers and a mapping from their model names to
|
|
|
44 |
import torch.nn.functional as F
|
45 |
import uvicorn
|
46 |
|
47 |
+
from src.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
|
48 |
+
from src.model.model_adapter import (
|
49 |
load_model,
|
50 |
add_model_args,
|
51 |
get_conversation_template,
|
52 |
)
|
53 |
+
from src.model.model_chatglm import generate_stream_chatglm
|
54 |
+
from src.model.model_falcon import generate_stream_falcon
|
55 |
+
from src.model.model_codet5p import generate_stream_codet5p
|
56 |
+
from src.modules.gptq import GptqConfig
|
57 |
+
from src.modules.exllama import ExllamaConfig
|
58 |
+
from src.modules.xfastertransformer import XftConfig
|
59 |
+
from src.serve.inference import generate_stream
|
60 |
+
from src.serve.model_worker import ModelWorker, worker_id, logger
|
61 |
+
from src.utils import build_logger, pretty_print_semaphore, get_context_length
|
62 |
|
63 |
|
64 |
# We store both the underlying workers and a mapping from their model names to
|
src/serve/openai_api_server.py
CHANGED
@@ -5,7 +5,7 @@
|
|
5 |
- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings)
|
6 |
|
7 |
Usage:
|
8 |
-
python3 -m
|
9 |
"""
|
10 |
import asyncio
|
11 |
import argparse
|
@@ -27,13 +27,13 @@ import shortuuid
|
|
27 |
import tiktoken
|
28 |
import uvicorn
|
29 |
|
30 |
-
from
|
31 |
WORKER_API_TIMEOUT,
|
32 |
WORKER_API_EMBEDDING_BATCH_SIZE,
|
33 |
ErrorCode,
|
34 |
)
|
35 |
-
from
|
36 |
-
from
|
37 |
ChatCompletionRequest,
|
38 |
ChatCompletionResponse,
|
39 |
ChatCompletionResponseStreamChoice,
|
@@ -55,13 +55,13 @@ from fastchat.protocol.openai_api_protocol import (
|
|
55 |
ModelPermission,
|
56 |
UsageInfo,
|
57 |
)
|
58 |
-
from
|
59 |
APIChatCompletionRequest,
|
60 |
APITokenCheckRequest,
|
61 |
APITokenCheckResponse,
|
62 |
APITokenCheckResponseItem,
|
63 |
)
|
64 |
-
from
|
65 |
|
66 |
logger = build_logger("openai_api_server", "openai_api_server.log")
|
67 |
|
|
|
5 |
- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings)
|
6 |
|
7 |
Usage:
|
8 |
+
python3 -m src.serve.openai_api_server
|
9 |
"""
|
10 |
import asyncio
|
11 |
import argparse
|
|
|
27 |
import tiktoken
|
28 |
import uvicorn
|
29 |
|
30 |
+
from src.constants import (
|
31 |
WORKER_API_TIMEOUT,
|
32 |
WORKER_API_EMBEDDING_BATCH_SIZE,
|
33 |
ErrorCode,
|
34 |
)
|
35 |
+
from src.conversation import Conversation, SeparatorStyle
|
36 |
+
from src.protocol.openai_api_protocol import (
|
37 |
ChatCompletionRequest,
|
38 |
ChatCompletionResponse,
|
39 |
ChatCompletionResponseStreamChoice,
|
|
|
55 |
ModelPermission,
|
56 |
UsageInfo,
|
57 |
)
|
58 |
+
from src.protocol.api_protocol import (
|
59 |
APIChatCompletionRequest,
|
60 |
APITokenCheckRequest,
|
61 |
APITokenCheckResponse,
|
62 |
APITokenCheckResponseItem,
|
63 |
)
|
64 |
+
from src.utils import build_logger
|
65 |
|
66 |
logger = build_logger("openai_api_server", "openai_api_server.log")
|
67 |
|