from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from api.config import SETTINGS
from api.utils.compat import model_dump


def create_app() -> FastAPI:
""" create fastapi app server """
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
    return app


def create_embedding_model():
    """ get embedding model from a TEI endpoint or local sentence-transformers. """
if SETTINGS.tei_endpoint is not None:
from openai import AsyncOpenAI
client = AsyncOpenAI(base_url=SETTINGS.tei_endpoint, api_key="none")
else:
from sentence_transformers import SentenceTransformer
client = SentenceTransformer(SETTINGS.embedding_name, device=SETTINGS.embedding_device)
return client
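
# NOTE: downstream code must handle both client types returned above. A minimal
# sketch of how the returned client is called (the TEI branch assumes the
# endpoint exposes the OpenAI-compatible /v1/embeddings route; the model name
# is illustrative, TEI typically ignores it):
#
#   if isinstance(client, SentenceTransformer):
#       vectors = client.encode(["hello world"])
#   else:  # AsyncOpenAI client, must be awaited
#       resp = await client.embeddings.create(model="tei", input=["hello world"])
#       vectors = [d.embedding for d in resp.data]
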
def create_generate_model():
""" get generate model for chat or completion. """
from api.core.default import DefaultEngine
from api.adapter.model import load_model
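    # optionally patch the model before loading: "attention" swaps in a
    # memory-efficient attention implementation, while "ntk" applies NTK-aware
    # RoPE scaling (scaled by SETTINGS.alpha) to extend the usable context length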
if SETTINGS.patch_type == "attention":
from api.utils.patches import apply_attention_patch
apply_attention_patch(use_memory_efficient_attention=True)
if SETTINGS.patch_type == "ntk":
from api.utils.patches import apply_ntk_scaling_patch
apply_ntk_scaling_patch(SETTINGS.alpha)
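    # keep only the settings that load_model() understands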
include = {
"model_name", "quantize", "device", "device_map", "num_gpus", "pre_seq_len",
"load_in_8bit", "load_in_4bit", "using_ptuning_v2", "dtype", "resize_embeddings"
}
kwargs = model_dump(SETTINGS, include=include)
model, tokenizer = load_model(
model_name_or_path=SETTINGS.model_path,
adapter_model=SETTINGS.adapter_model_path,
**kwargs,
)
logger.info("Using default engine")
return DefaultEngine(
model,
tokenizer,
SETTINGS.device,
model_name=SETTINGS.model_name,
context_len=SETTINGS.context_length if SETTINGS.context_length > 0 else None,
prompt_name=SETTINGS.chat_template,
use_streamer_v2=SETTINGS.use_streamer_v2,
    )


def create_vllm_engine():
""" get vllm generate engine for chat or completion. """
try:
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.transformers_utils.tokenizer import get_tokenizer
from api.core.vllm_engine import VllmEngine
except ImportError:
return None
include = {
"tokenizer_mode", "trust_remote_code", "tensor_parallel_size",
"dtype", "gpu_memory_utilization", "max_num_seqs",
}
kwargs = model_dump(SETTINGS, include=include)
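    # non-positive limits mean "unset": pass None so vLLM applies its own defaults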
engine_args = AsyncEngineArgs(
model=SETTINGS.model_path,
max_num_batched_tokens=SETTINGS.max_num_batched_tokens if SETTINGS.max_num_batched_tokens > 0 else None,
max_model_len=SETTINGS.context_length if SETTINGS.context_length > 0 else None,
quantization=SETTINGS.quantization_method,
**kwargs,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(
engine_args.tokenizer,
tokenizer_mode=engine_args.tokenizer_mode,
trust_remote_code=True,
)
logger.info("Using vllm engine")
return VllmEngine(
engine,
tokenizer,
SETTINGS.model_name,
SETTINGS.chat_template,
SETTINGS.context_length,
    )


def create_llama_cpp_engine():
""" get llama.cpp generate engine for chat or completion. """
try:
from llama_cpp import Llama
from api.core.llama_cpp_engine import LlamaCppEngine
except ImportError:
return None
include = {
"n_gpu_layers", "main_gpu", "tensor_split", "n_batch", "n_threads",
"n_threads_batch", "rope_scaling_type", "rope_freq_base", "rope_freq_scale"
}
kwargs = model_dump(SETTINGS, include=include)
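    # llama.cpp needs a concrete context window; fall back to 2048 tokens when unset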
engine = Llama(
model_path=SETTINGS.model_path,
n_ctx=SETTINGS.context_length if SETTINGS.context_length > 0 else 2048,
**kwargs,
)
logger.info("Using llama.cpp engine")
    return LlamaCppEngine(engine, SETTINGS.model_name, SETTINGS.chat_template)


def create_tgi_engine():
    """ get TGI generate engine for chat or completion. """
try:
from text_generation import AsyncClient
from api.core.tgi import TGIEngine
except ImportError:
return None
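    # AsyncClient talks over HTTP to an already-running text-generation-inference server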
client = AsyncClient(SETTINGS.tgi_endpoint)
logger.info("Using TGI engine")
    return TGIEngine(client, SETTINGS.model_name, SETTINGS.chat_template)


# fastapi app
app = create_app()

# model for embedding
EMBEDDED_MODEL = create_embedding_model() if (SETTINGS.embedding_name and SETTINGS.activate_inference) else None

# generation engine for chat / completion, selected by SETTINGS.engine
if (not SETTINGS.only_embedding) and SETTINGS.activate_inference:
if SETTINGS.engine == "default":
GENERATE_ENGINE = create_generate_model()
elif SETTINGS.engine == "vllm":
GENERATE_ENGINE = create_vllm_engine()
elif SETTINGS.engine == "llama.cpp":
GENERATE_ENGINE = create_llama_cpp_engine()
elif SETTINGS.engine == "tgi":
GENERATE_ENGINE = create_tgi_engine()
else:
    GENERATE_ENGINE = None

# model names for special processing
EXCLUDE_MODELS = ["baichuan-13b", "baichuan2-13b", "qwen", "chatglm3"]
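
# A minimal way to serve this app (illustrative only; the project's actual
# entry point, host, and port may differ):
#
#   if __name__ == "__main__":
#       import uvicorn
#       uvicorn.run(app, host="0.0.0.0", port=8000)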