from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger

from api.config import SETTINGS
from api.utils.compat import model_dump


def create_app() -> FastAPI:
    """ create fastapi app server """
    app = FastAPI()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app
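
# Note: the wildcard CORS settings above accept requests from any origin with
# credentials allowed; convenient for local or demo deployments, but worth
# tightening for production use.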


def create_embedding_model():
    """ get embedding model: a TEI endpoint client if configured, otherwise a local sentence-transformers model. """
    if SETTINGS.tei_endpoint is not None:
        from openai import AsyncOpenAI

        client = AsyncOpenAI(base_url=SETTINGS.tei_endpoint, api_key="none")
    else:
        from sentence_transformers import SentenceTransformer

        client = SentenceTransformer(SETTINGS.embedding_name, device=SETTINGS.embedding_device)
    return client
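
# Usage sketch (assumption, not part of this module): callers receive either an
# AsyncOpenAI client pointed at the TEI endpoint or a local SentenceTransformer,
# so embedding code is expected to branch on the client type, e.g.
#     vectors = client.encode(["hello world"])                 # SentenceTransformer path
#     resp = await client.embeddings.create(                   # AsyncOpenAI / TEI path
#         input=["hello world"], model=SETTINGS.embedding_name,
#     )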


def create_generate_model():
    """ get generate model for chat or completion. """
    from api.core.default import DefaultEngine
    from api.adapter.model import load_model

    # optional runtime patches: memory-efficient attention or NTK-aware RoPE scaling
    if SETTINGS.patch_type == "attention":
        from api.utils.patches import apply_attention_patch

        apply_attention_patch(use_memory_efficient_attention=True)
    if SETTINGS.patch_type == "ntk":
        from api.utils.patches import apply_ntk_scaling_patch

        apply_ntk_scaling_patch(SETTINGS.alpha)

    # forward only the loader-related settings to load_model
    include = {
        "model_name", "quantize", "device", "device_map", "num_gpus", "pre_seq_len",
        "load_in_8bit", "load_in_4bit", "using_ptuning_v2", "dtype", "resize_embeddings",
    }
    kwargs = model_dump(SETTINGS, include=include)
    model, tokenizer = load_model(
        model_name_or_path=SETTINGS.model_path,
        adapter_model=SETTINGS.adapter_model_path,
        **kwargs,
    )

    logger.info("Using default engine")

    return DefaultEngine(
        model,
        tokenizer,
        SETTINGS.device,
        model_name=SETTINGS.model_name,
        context_len=SETTINGS.context_length if SETTINGS.context_length > 0 else None,
        prompt_name=SETTINGS.chat_template,
        use_streamer_v2=SETTINGS.use_streamer_v2,
    )
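
# Note (assumption): model_dump(SETTINGS, include=include) behaves like pydantic's
# model_dump/dict with an include filter, so only the whitelisted settings are
# forwarded; a context_length <= 0 is passed as None and left for the engine to decide.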


def create_vllm_engine():
    """ get vllm generate engine for chat or completion. """
    try:
        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine
        from vllm.transformers_utils.tokenizer import get_tokenizer

        from api.core.vllm_engine import VllmEngine
    except ImportError:
        return None

    include = {
        "tokenizer_mode", "trust_remote_code", "tensor_parallel_size",
        "dtype", "gpu_memory_utilization", "max_num_seqs",
    }
    kwargs = model_dump(SETTINGS, include=include)
    engine_args = AsyncEngineArgs(
        model=SETTINGS.model_path,
        max_num_batched_tokens=SETTINGS.max_num_batched_tokens if SETTINGS.max_num_batched_tokens > 0 else None,
        max_model_len=SETTINGS.context_length if SETTINGS.context_length > 0 else None,
        quantization=SETTINGS.quantization_method,
        **kwargs,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(
        engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=True,
    )

    logger.info("Using vllm engine")

    return VllmEngine(
        engine,
        tokenizer,
        SETTINGS.model_name,
        SETTINGS.chat_template,
        SETTINGS.context_length,
    )
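
# Note: AsyncLLMEngine is vLLM's asynchronous engine that batches concurrent
# requests; max_num_batched_tokens / context_length values <= 0 are passed as None
# so vLLM falls back to defaults derived from the model configuration.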


def create_llama_cpp_engine():
    """ get llama.cpp generate engine for chat or completion. """
    try:
        from llama_cpp import Llama

        from api.core.llama_cpp_engine import LlamaCppEngine
    except ImportError:
        return None

    include = {
        "n_gpu_layers", "main_gpu", "tensor_split", "n_batch", "n_threads",
        "n_threads_batch", "rope_scaling_type", "rope_freq_base", "rope_freq_scale",
    }
    kwargs = model_dump(SETTINGS, include=include)
    engine = Llama(
        model_path=SETTINGS.model_path,
        n_ctx=SETTINGS.context_length if SETTINGS.context_length > 0 else 2048,
        **kwargs,
    )

    logger.info("Using llama.cpp engine")

    return LlamaCppEngine(engine, SETTINGS.model_name, SETTINGS.chat_template)
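
# Note: the settings collected above correspond to llama_cpp.Llama constructor
# arguments; when context_length is unset (<= 0) the context window defaults to 2048.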


def create_tgi_engine():
    """ get TGI generate engine for chat or completion. """
    try:
        from text_generation import AsyncClient

        from api.core.tgi import TGIEngine
    except ImportError:
        return None

    client = AsyncClient(SETTINGS.tgi_endpoint)

    logger.info("Using TGI engine")

    return TGIEngine(client, SETTINGS.model_name, SETTINGS.chat_template)
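
# Assumption: SETTINGS.tgi_endpoint points at a running text-generation-inference
# server (e.g. "http://127.0.0.1:8080") that text_generation.AsyncClient talks to over HTTP.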


# fastapi app
app = create_app()

# model for embedding
EMBEDDED_MODEL = create_embedding_model() if (SETTINGS.embedding_name and SETTINGS.activate_inference) else None

# model for generation (transformers / vllm / llama.cpp / tgi)
GENERATE_ENGINE = None
if (not SETTINGS.only_embedding) and SETTINGS.activate_inference:
    if SETTINGS.engine == "default":
        GENERATE_ENGINE = create_generate_model()
    elif SETTINGS.engine == "vllm":
        GENERATE_ENGINE = create_vllm_engine()
    elif SETTINGS.engine == "llama.cpp":
        GENERATE_ENGINE = create_llama_cpp_engine()
    elif SETTINGS.engine == "tgi":
        GENERATE_ENGINE = create_tgi_engine()

# model names for special processing
EXCLUDE_MODELS = ["baichuan-13b", "baichuan2-13b", "qwen", "chatglm3"]
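
# Launch sketch (assumption: this module is importable as `api.models` and the HTTP
# routes are registered against `app` elsewhere in the package):
#     uvicorn api.models:app --host 0.0.0.0 --port 8000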