gordonchan committed
Commit: ca56e6a
Parent(s): 61100a9
Upload 41 files
- api/adapter/__init__.py +1 -0
- api/adapter/model.py +582 -0
- api/adapter/schema.py +375 -0
- api/adapter/template.py +1304 -0
- api/config.py +270 -0
- api/core/__init__.py +0 -0
- api/core/default.py +570 -0
- api/core/llama_cpp_engine.py +175 -0
- api/core/tgi.py +257 -0
- api/core/vllm_engine.py +170 -0
- api/generation/__init__.py +5 -0
- api/generation/baichuan.py +69 -0
- api/generation/chatglm.py +300 -0
- api/generation/qwen.py +302 -0
- api/generation/stream.py +355 -0
- api/generation/utils.py +134 -0
- api/generation/xverse.py +75 -0
- api/llama_cpp_routes/__init__.py +2 -0
- api/llama_cpp_routes/chat.py +75 -0
- api/llama_cpp_routes/completion.py +72 -0
- api/llama_cpp_routes/utils.py +21 -0
- api/models.py +172 -0
- api/routes/__init__.py +1 -0
- api/routes/chat.py +67 -0
- api/routes/completion.py +69 -0
- api/routes/embedding.py +114 -0
- api/routes/model.py +38 -0
- api/server.py +40 -0
- api/tgi_routes/__init__.py +2 -0
- api/tgi_routes/chat.py +169 -0
- api/tgi_routes/completion.py +136 -0
- api/utils/__init__.py +0 -0
- api/utils/apply_lora.py +44 -0
- api/utils/compat.py +36 -0
- api/utils/constants.py +32 -0
- api/utils/patches.py +223 -0
- api/utils/protocol.py +446 -0
- api/utils/request.py +166 -0
- api/vllm_routes/__init__.py +2 -0
- api/vllm_routes/chat.py +206 -0
- api/vllm_routes/completion.py +226 -0
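
The layout suggests an OpenAI-compatible API server: api/server.py mounts the routers under api/routes (chat, completion, embedding, model), with swappable backends under api/core (default transformers, llama.cpp, TGI, vLLM) and matching route variants in api/llama_cpp_routes, api/tgi_routes, and api/vllm_routes. Assuming the routes do mirror the OpenAI REST surface, as the module names suggest, a client call could look like the following sketch (base URL, API key, and model name are placeholders, not taken from this commit):

# Hypothetical client-side sketch; assumes the server exposes
# OpenAI-compatible /v1 routes, as the api/routes module names suggest.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")  # placeholder URL and key
resp = client.chat.completions.create(
    model="qwen-7b-chat",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)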
api/adapter/__init__.py
ADDED
@@ -0,0 +1 @@
from api.adapter.template import get_prompt_adapter
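
The package `__init__` simply re-exports `get_prompt_adapter`, so callers can import it from the package root; for example:

from api.adapter import get_prompt_adapter  # equivalent to the api.adapter.template import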
api/adapter/model.py
ADDED
@@ -0,0 +1,582 @@
import os
import sys
from typing import List, Optional, Any, Dict, Tuple

import torch
from loguru import logger
from peft import PeftModel
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    PreTrainedTokenizer,
    PreTrainedModel,
)
from transformers.utils.versions import require_version

if sys.version_info >= (3, 9):
    from functools import cache
else:
    from functools import lru_cache as cache


class BaseModelAdapter:
    """ The base and default model adapter. """

    model_names = []

    def match(self, model_name) -> bool:
        """
        Check if the given model name matches any of the predefined model names.

        Args:
            model_name (str): The model name to check.

        Returns:
            bool: True if the model name matches any of the predefined model names, False otherwise.
        """
        return any(m in model_name for m in self.model_names) if self.model_names else True

    def load_model(
        self,
        model_name_or_path: Optional[str] = None,
        adapter_model: Optional[str] = None,
        **kwargs: Any,
    ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
        """
        Load a model and tokenizer based on the provided model name or path.

        Args:
            model_name_or_path (str, optional): The name or path of the model. Defaults to None.
            adapter_model (str, optional): The adapter model to load the tokenizer from. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
        """
        model_name_or_path = model_name_or_path or self.default_model_name_or_path
        tokenizer_kwargs = {"trust_remote_code": True, "use_fast": False}
        tokenizer_kwargs.update(self.tokenizer_kwargs)

        # Load the tokenizer from the adapter model if one exists.
        if adapter_model is not None:
            try:
                tokenizer = self.tokenizer_class.from_pretrained(
                    adapter_model, **tokenizer_kwargs,
                )
            except OSError:
                tokenizer = self.tokenizer_class.from_pretrained(
                    model_name_or_path, **tokenizer_kwargs,
                )
        else:
            tokenizer = self.tokenizer_class.from_pretrained(
                model_name_or_path, **tokenizer_kwargs,
            )

        config_kwargs = self.model_kwargs
        device = kwargs.get("device", "cuda")
        num_gpus = kwargs.get("num_gpus", 1)
        dtype = kwargs.get("dtype", "half")
        if device == "cuda":
            if "torch_dtype" not in config_kwargs:
                if dtype == "half":
                    config_kwargs["torch_dtype"] = torch.float16
                elif dtype == "bfloat16":
                    config_kwargs["torch_dtype"] = torch.bfloat16
                elif dtype == "float32":
                    config_kwargs["torch_dtype"] = torch.float32

            if num_gpus != 1:
                config_kwargs["device_map"] = "auto"
                # model_kwargs["device_map"] = "sequential"  # This is important when GPUs do not have the same VRAM sizes

        # Quantization configurations (using the bitsandbytes library).
        if kwargs.get("load_in_8bit", False):
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")

            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0,
            )
            config_kwargs["device_map"] = "auto" if device == "cuda" else None

            logger.info("Quantizing model to 8 bit.")

        elif kwargs.get("load_in_4bit", False):
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")

            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            config_kwargs["device_map"] = "auto" if device == "cuda" else None

            logger.info("Quantizing model to 4 bit.")

        if kwargs.get("device_map", None) == "auto":
            config_kwargs["device_map"] = "auto"

        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        # Fix config (for Qwen)
        if hasattr(config, "fp16") and hasattr(config, "bf16"):
            setattr(config, "fp16", dtype == "half")
            setattr(config, "bf16", dtype == "bfloat16")
            config_kwargs.pop("torch_dtype", None)

        if kwargs.get("using_ptuning_v2", False) and adapter_model:
            config.pre_seq_len = kwargs.get("pre_seq_len", 128)

        # Load and prepare pretrained models (without valuehead).
        model = self.model_class.from_pretrained(
            model_name_or_path,
            config=config,
            trust_remote_code=True,
            **config_kwargs,
        )

        if device == "cpu":
            model = model.float()

        # Post-process for special tokens.
        tokenizer = self.post_tokenizer(tokenizer)
        is_chatglm = "chatglm" in str(type(model))

        if adapter_model is not None:
            model = self.load_adapter_model(model, tokenizer, adapter_model, is_chatglm, config_kwargs, **kwargs)

        if is_chatglm or "baichuan" in str(type(model)) or "xverse" in str(type(model)):
            quantize = kwargs.get("quantize", None)
            if quantize and quantize != 16:
                logger.info(f"Quantizing model to {quantize} bit.")
                model = model.quantize(quantize)

        if device == "cuda" and num_gpus == 1 and "device_map" not in config_kwargs:
            model.to(device)

        # Inference mode.
        model.eval()

        return model, tokenizer

    def load_lora_model(
        self, model: PreTrainedModel, adapter_model: str, model_kwargs: Dict,
    ) -> PeftModel:
        """
        Load a LoRA model.

        This function loads a LoRA model using the specified pretrained model and adapter model.

        Args:
            model (PreTrainedModel): The base pretrained model.
            adapter_model (str): The name or path of the adapter model.
            model_kwargs (dict): Additional keyword arguments for the model.

        Returns:
            PeftModel: The loaded LoRA model.
        """
        return PeftModel.from_pretrained(
            model,
            adapter_model,
            torch_dtype=model_kwargs.get("torch_dtype", torch.float16),
        )

    def load_adapter_model(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        adapter_model: str,
        is_chatglm: bool,
        model_kwargs: Dict,
        **kwargs: Any,
    ) -> PreTrainedModel:
        using_ptuning_v2 = kwargs.get("using_ptuning_v2", False)
        resize_embeddings = kwargs.get("resize_embeddings", False)
        if adapter_model and resize_embeddings and not is_chatglm:
            model_vocab_size = model.get_input_embeddings().weight.size(0)
            tokenizer_vocab_size = len(tokenizer)
            logger.info(f"Vocab of the base model: {model_vocab_size}")
            logger.info(f"Vocab of the tokenizer: {tokenizer_vocab_size}")

            if model_vocab_size != tokenizer_vocab_size:
                assert tokenizer_vocab_size > model_vocab_size
                logger.info("Resize model embeddings to fit tokenizer")
                model.resize_token_embeddings(tokenizer_vocab_size)

        if using_ptuning_v2:
            prefix_state_dict = torch.load(os.path.join(adapter_model, "pytorch_model.bin"))
            new_prefix_state_dict = {
                k[len("transformer.prefix_encoder."):]: v
                for k, v in prefix_state_dict.items()
                if k.startswith("transformer.prefix_encoder.")
            }
            model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
            model.transformer.prefix_encoder.float()
        else:
            model = self.load_lora_model(model, adapter_model, model_kwargs)

        return model

    def post_tokenizer(self, tokenizer) -> PreTrainedTokenizer:
        return tokenizer

    @property
    def model_class(self):
        return AutoModelForCausalLM

    @property
    def model_kwargs(self):
        return {}

    @property
    def tokenizer_class(self):
        return AutoTokenizer

    @property
    def tokenizer_kwargs(self):
        return {}

    @property
    def default_model_name_or_path(self):
        return "zpn/llama-7b"


# A global registry for all model adapters
model_adapters: List[BaseModelAdapter] = []


def register_model_adapter(cls):
    """ Register a model adapter. """
    model_adapters.append(cls())


@cache
def get_model_adapter(model_name: str) -> BaseModelAdapter:
    """
    Get a model adapter for a given model name.

    Args:
        model_name (str): The name of the model.

    Returns:
        ModelAdapter: The model adapter that matches the given model name.
    """
    for adapter in model_adapters:
        if adapter.match(model_name):
            return adapter
    raise ValueError(f"No valid model adapter for {model_name}")


def load_model(
    model_name: str,
    model_name_or_path: Optional[str] = None,
    adapter_model: Optional[str] = None,
    quantize: Optional[int] = 16,
    device: Optional[str] = "cuda",
    load_in_8bit: Optional[bool] = False,
    **kwargs: Any,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Load a pre-trained model and tokenizer.

    Args:
        model_name (str): The name of the model.
        model_name_or_path (Optional[str], optional): The path or name of the pre-trained model. Defaults to None.
        adapter_model (Optional[str], optional): The name of the adapter model. Defaults to None.
        quantize (Optional[int], optional): The quantization level. Defaults to 16.
        device (Optional[str], optional): The device to load the model on. Defaults to "cuda".
        load_in_8bit (Optional[bool], optional): Whether to load the model in 8-bit mode. Defaults to False.
        **kwargs (Any): Additional keyword arguments.

    Returns:
        Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
    """
    model_name = model_name.lower()

    if "tiger" in model_name:
        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

    # Get the model adapter.
    adapter = get_model_adapter(model_name)
    model, tokenizer = adapter.load_model(
        model_name_or_path,
        adapter_model,
        device=device,
        quantize=quantize,
        load_in_8bit=load_in_8bit,
        **kwargs,
    )
    return model, tokenizer


class ChatglmModelAdapter(BaseModelAdapter):
    """ https://github.com/THUDM/ChatGLM-6B """

    model_names = ["chatglm"]

    @property
    def model_class(self):
        return AutoModel

    @property
    def default_model_name_or_path(self):
        return "THUDM/chatglm2-6b"


class Chatglm3ModelAdapter(ChatglmModelAdapter):
    """ https://github.com/THUDM/ChatGLM-6B """

    model_names = ["chatglm3"]

    @property
    def tokenizer_kwargs(self):
        return {"encode_special_tokens": True}

    @property
    def default_model_name_or_path(self):
        return "THUDM/chatglm3-6b"


class LlamaModelAdapter(BaseModelAdapter):
    """ https://github.com/project-baize/baize-chatbot """

    model_names = ["alpaca", "baize", "openbuddy-llama", "ziya-llama", "guanaco", "llama2"]

    def post_tokenizer(self, tokenizer):
        tokenizer.bos_token = "<s>"
        tokenizer.eos_token = "</s>"
        tokenizer.unk_token = "<unk>"
        return tokenizer

    @property
    def model_kwargs(self):
        return {"low_cpu_mem_usage": True}


class MossModelAdapter(BaseModelAdapter):
    """ https://github.com/OpenLMLab/MOSS """

    model_names = ["moss"]

    @property
    def default_model_name_or_path(self):
        return "fnlp/moss-moon-003-sft-int4"


class PhoenixModelAdapter(BaseModelAdapter):
    """ https://github.com/FreedomIntelligence/LLMZoo """

    model_names = ["phoenix"]

    @property
    def model_kwargs(self):
        return {"low_cpu_mem_usage": True}

    @property
    def tokenizer_kwargs(self):
        return {"use_fast": True}

    @property
    def default_model_name_or_path(self):
        return "FreedomIntelligence/phoenix-inst-chat-7b"


class FireflyModelAdapter(BaseModelAdapter):
    """ https://github.com/yangjianxin1/Firefly """

    model_names = ["firefly"]

    @property
    def model_kwargs(self):
        return {"torch_dtype": torch.float32}

    @property
    def tokenizer_kwargs(self):
        return {"use_fast": True}

    @property
    def default_model_name_or_path(self):
        return "YeungNLP/firefly-2b6"


class YuLanChatModelAdapter(BaseModelAdapter):
    """ https://github.com/RUC-GSAI/YuLan-Chat """

    model_names = ["yulan"]

    def post_tokenizer(self, tokenizer):
        tokenizer.bos_token = "<s>"
        tokenizer.eos_token = "</s>"
        tokenizer.unk_token = "<unk>"
        return tokenizer

    @property
    def model_kwargs(self):
        return {"low_cpu_mem_usage": True}

    def load_adapter_model(self, model, tokenizer, adapter_model, is_chatglm, model_kwargs, **kwargs):
        adapter_model = AutoModelForCausalLM.from_pretrained(
            adapter_model, torch_dtype=torch.float16, low_cpu_mem_usage=True
        )
        if model.model.embed_tokens.weight.size(0) + 1 == adapter_model.model.embed_tokens.weight.size(0):
            model.resize_token_embeddings(len(tokenizer))
            model.model.embed_tokens.weight.data[-1, :] = 0

        logger.info("Applying the delta")
        for name, param in tqdm(model.state_dict().items(), desc="Applying delta"):
            # Add the delta weights from the adapter model onto the base weights.
            assert name in adapter_model.state_dict()
            param.data += adapter_model.state_dict()[name]

        return model


class TigerBotModelAdapter(BaseModelAdapter):
    """ https://github.com/TigerResearch/TigerBot """

    model_names = ["tiger"]

    @property
    def tokenizer_kwargs(self):
        return {"use_fast": True}

    @property
    def default_model_name_or_path(self):
        return "TigerResearch/tigerbot-7b-sft"


class OpenBuddyFalconModelAdapter(BaseModelAdapter):
    """ https://github.com/OpenBuddy/OpenBuddy """

    model_names = ["openbuddy-falcon"]

    @property
    def default_model_name_or_path(self):
        return "OpenBuddy/openbuddy-falcon-7b-v5-fp16"


class AnimaModelAdapter(LlamaModelAdapter):

    model_names = ["anima"]

    def load_lora_model(self, model, adapter_model, model_kwargs):
        return PeftModel.from_pretrained(model, adapter_model)


class BaiChuanModelAdapter(BaseModelAdapter):
    """ https://github.com/baichuan-inc/Baichuan-13B """

    model_names = ["baichuan"]

    def load_lora_model(self, model, adapter_model, model_kwargs):
        return PeftModel.from_pretrained(model, adapter_model)

    @property
    def default_model_name_or_path(self):
        return "baichuan-inc/Baichuan-13B-Chat"


class InternLMModelAdapter(BaseModelAdapter):
    """ https://github.com/InternLM/InternLM """

    model_names = ["internlm"]

    @property
    def default_model_name_or_path(self):
        return "internlm/internlm-chat-7b"


class StarCodeModelAdapter(BaseModelAdapter):
    """ https://github.com/bigcode-project/starcoder """

    model_names = ["starcode", "starchat"]

    @property
    def tokenizer_kwargs(self):
        return {}

    @property
    def default_model_name_or_path(self):
        return "HuggingFaceH4/starchat-beta"


class AquilaModelAdapter(BaseModelAdapter):
    """ https://github.com/FlagAI-Open/FlagAI """

    model_names = ["aquila"]

    @property
    def default_model_name_or_path(self):
        return "BAAI/AquilaChat-7B"


class QwenModelAdapter(BaseModelAdapter):
    """ https://github.com/QwenLM/Qwen-7B """

    model_names = ["qwen"]

    @property
    def default_model_name_or_path(self):
        return "Qwen/Qwen-7B-Chat"


class XverseModelAdapter(BaseModelAdapter):
    """ https://github.com/xverse-ai/XVERSE-13B """

    model_names = ["xverse"]

    @property
    def default_model_name_or_path(self):
        return "xverse/XVERSE-13B-Chat"


class CodeLlamaModelAdapter(LlamaModelAdapter):
    """ https://github.com/project-baize/baize-chatbot """

    model_names = ["code-llama"]

    @property
    def tokenizer_class(self):
        require_version("transformers>=4.33.1", "To fix: pip install transformers>=4.33.1")
        from transformers import CodeLlamaTokenizer

        return CodeLlamaTokenizer

    @property
    def default_model_name_or_path(self):
        return "codellama/CodeLlama-7b-Instruct-hf"


register_model_adapter(ChatglmModelAdapter)
register_model_adapter(Chatglm3ModelAdapter)
register_model_adapter(LlamaModelAdapter)
register_model_adapter(MossModelAdapter)
register_model_adapter(PhoenixModelAdapter)
register_model_adapter(FireflyModelAdapter)
register_model_adapter(YuLanChatModelAdapter)
register_model_adapter(TigerBotModelAdapter)
register_model_adapter(OpenBuddyFalconModelAdapter)
register_model_adapter(AnimaModelAdapter)
register_model_adapter(BaiChuanModelAdapter)
register_model_adapter(InternLMModelAdapter)
register_model_adapter(AquilaModelAdapter)
register_model_adapter(QwenModelAdapter)
register_model_adapter(XverseModelAdapter)
register_model_adapter(CodeLlamaModelAdapter)

# After all adapters, try the default base adapter.
register_model_adapter(BaseModelAdapter)
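
A note on how this registry resolves: `match` is substring containment against the lowercased model name, and `get_model_adapter` (memoized via `@cache`) returns the first registered adapter that matches, so registration order matters; `BaseModelAdapter`, registered last with an empty `model_names`, is the catch-all. One consequence is that `ChatglmModelAdapter`, registered before `Chatglm3ModelAdapter`, also claims any name containing "chatglm3". A minimal illustrative sketch (model names are placeholders; assumes the package and its dependencies are importable):

# Illustrative only: shows first-match resolution in the adapter registry.
from api.adapter.model import get_model_adapter

print(type(get_model_adapter("qwen-7b-chat")).__name__)     # QwenModelAdapter
print(type(get_model_adapter("my-custom-model")).__name__)  # BaseModelAdapter (fallback)
print(type(get_model_adapter("chatglm3-6b")).__name__)      # ChatglmModelAdapter, not Chatglm3ModelAdapter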
api/adapter/schema.py
ADDED
@@ -0,0 +1,375 @@
from typing import Any, Dict, List, Optional

from openai.types.chat.completion_create_params import Function
from pydantic import BaseModel

from api.utils.compat import model_dump


def convert_data_type(param_type: str) -> str:
    """ Convert a JSON schema data type to a TypeScript data type. """
    return "number" if param_type in {"integer", "float"} else param_type


def get_param_type(param: Dict[str, Any]) -> str:
    """ Get the param_type of a parameter. """
    param_type = "any"
    if "type" in param:
        raw_param_type = param["type"]
        param_type = (
            " | ".join(raw_param_type)
            if type(raw_param_type) is list
            else raw_param_type
        )
    elif "oneOf" in param:
        one_of_types = [
            convert_data_type(item["type"])
            for item in param["oneOf"]
            if "type" in item
        ]
        one_of_types = list(set(one_of_types))
        param_type = " | ".join(one_of_types)
    return convert_data_type(param_type)


def get_format_param(param: Dict[str, Any]) -> Optional[str]:
    """ Get "format" from param. There are cases where format is not directly in param but in oneOf. """
    if "format" in param:
        return param["format"]
    if "oneOf" in param:
        formats = [item["format"] for item in param["oneOf"] if "format" in item]
        if formats:
            return " or ".join(formats)
    return None


def get_param_info(param: Dict[str, Any]) -> Optional[str]:
    """ Get additional information about a parameter, such as format, default value, min, max, ... """
    param_type = param.get("type", "any")
    info_list = []
    if "description" in param:
        desc = param["description"]
        if not desc.endswith("."):
            desc += "."
        info_list.append(desc)

    if "default" in param:
        default_value = param["default"]
        if param_type == "string":
            default_value = f'"{default_value}"'  # if string --> add quotes
        info_list.append(f"Default={default_value}.")

    format_param = get_format_param(param)
    if format_param is not None:
        info_list.append(f"Format={format_param}")

    info_list.extend(
        f"{field_name}={str(param[field])}"
        for field, field_name in [
            ("maximum", "Maximum"),
            ("minimum", "Minimum"),
            ("maxLength", "Maximum length"),
            ("minLength", "Minimum length"),
        ]
        if field in param
    )
    if info_list:
        result = "// " + " ".join(info_list)
        return result.replace("\n", " ")
    return None


def append_new_param_info(info_list: List[str], param_declaration: str, comment_info: Optional[str], depth: int):
    """ Append a new parameter with its comment to the info_list. """
    offset = "".join(["    " for _ in range(depth)]) if depth >= 1 else ""
    if comment_info is not None:
        # if depth == 0:  # format: //comment\nparam: type
        info_list.append(f"{offset}{comment_info}")
    info_list.append(f"{offset}{param_declaration}")


def get_enum_option_str(enum_options: List) -> str:
    """Get enum options separated by "|".

    Args:
        enum_options (List): list of options

    Returns:
        _type_: concatenation of options separated by "|"
    """
    # if each option is a string --> add quotes
    return " | ".join([f'"{v}"' if type(v) is str else str(v) for v in enum_options])


def get_array_typescript(param_name: Optional[str], param_dic: dict, depth: int = 0) -> str:
    """Recursive implementation for generating the TypeScript of an array.

    Args:
        param_name (Optional[str]): name of the param, optional
        param_dic (dict): param_dic
        depth (int, optional): nested level. Defaults to 0.

    Returns:
        _type_: TypeScript of the array
    """
    offset = "".join(["    " for _ in range(depth)]) if depth >= 1 else ""
    items_info = param_dic.get("items", {})

    if len(items_info) == 0:
        return f"{offset}{param_name}: []" if param_name is not None else "[]"
    array_type = get_param_type(items_info)
    if array_type == "object":
        info_lines = []
        child_lines = get_parameter_typescript(
            items_info.get("properties", {}), items_info.get("required", []), depth + 1
        )
        # if comment_info is not None:
        #     info_lines.append(f"{offset}{comment_info}")
        if param_name is not None:
            info_lines.append(f"{offset}{param_name}" + ": {")
        else:
            info_lines.append(f"{offset}" + "{")
        info_lines.extend(child_lines)
        info_lines.append(f"{offset}" + "}[]")
        return "\n".join(info_lines)

    elif array_type == "array":
        item_info = get_array_typescript(None, items_info, depth + 1)
        if param_name is None:
            return f"{item_info}[]"
        return f"{offset}{param_name}: {item_info.strip()}[]"

    else:
        if "enum" not in items_info:
            return (
                f"{array_type}[]"
                if param_name is None
                else f"{offset}{param_name}: {array_type}[],"
            )
        item_type = get_enum_option_str(items_info["enum"])
        if param_name is None:
            return f"({item_type})[]"
        else:
            return f"{offset}{param_name}: ({item_type})[]"


def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
    """Recursively return information about parameters, including data type, description and other information.
    This information will be put into the prompt.

    Args:
        properties (_type_): properties in parameters
        required_params (_type_): list of required parameters
        depth (int, optional): the depth of params (nested level). Defaults to 0.

    Returns:
        _type_: list of lines containing information about all parameters
    """
    tp_lines = []
    for param_name, param in properties.items():
        # Sometimes properties have a "required" field as a list of strings,
        # even though it is not supposed to be under properties, so we skip it.
        if not isinstance(param, dict):
            continue
        # Param description
        comment_info = get_param_info(param)
        # Param name declaration
        param_declaration = f"{param_name}"
        if isinstance(required_params, list) and param_name not in required_params:
            param_declaration += "?"
        param_type = get_param_type(param)

        offset = ""
        if depth >= 1:
            offset = "".join(["    " for _ in range(depth)])

        if param_type == "object":  # param_type is an object
            child_lines = get_parameter_typescript(param.get("properties", {}), param.get("required", []), depth + 1)
            if comment_info is not None:
                tp_lines.append(f"{offset}{comment_info}")

            param_declaration += ": {"
            tp_lines.append(f"{offset}{param_declaration}")
            tp_lines.extend(child_lines)
            tp_lines.append(f"{offset}" + "},")

        elif param_type == "array":  # param_type is an array
            item_info = param.get("items", {})
            if "type" not in item_info:  # don't know the type of the array
                param_declaration += ": [],"
                append_new_param_info(tp_lines, param_declaration, comment_info, depth)
            else:
                array_declaration = get_array_typescript(param_declaration, param, depth)
                if not array_declaration.endswith(","):
                    array_declaration += ","
                if comment_info is not None:
                    tp_lines.append(f"{offset}{comment_info}")
                tp_lines.append(array_declaration)
        else:
            if "enum" in param:
                param_type = " | ".join([f'"{v}"' for v in param["enum"]])
            param_declaration += f": {param_type},"
            append_new_param_info(tp_lines, param_declaration, comment_info, depth)

    return tp_lines


def generate_schema_from_functions(functions: List[Function], namespace="functions") -> str:
    """
    Convert functions schema to a schema that language models can understand.
    """
    schema = "// Supported function definitions that should be called when necessary.\n"
    schema += f"namespace {namespace} {{\n\n"

    for function in functions:
        # Convert a Function object to a dict, if necessary
        if isinstance(function, BaseModel):
            function = model_dump(function)
        function_name = function.get("name", None)
        if function_name is None:
            continue

        description = function.get("description", "")
        schema += f"// {description}\n"
        schema += f"type {function_name}"

        parameters = function.get("parameters", None)
        if parameters is not None and parameters.get("properties") is not None:
            schema += " = (_: {\n"
            required_params = parameters.get("required", [])
            tp_lines = get_parameter_typescript(parameters.get("properties"), required_params, 0)
            schema += "\n".join(tp_lines)
            schema += "\n}) => any;\n\n"
        else:
            # Doesn't have any parameters
            schema += " = () => any;\n\n"

    schema += f"}} // namespace {namespace}"

    return schema


def generate_schema_from_openapi(specification: Dict[str, Any], description: str, namespace: str) -> str:
    """
    Convert an OpenAPI specification object to a schema that language models can understand.

    Input:
    specification: can be obtained by json.loads of any OpenAPI json spec, or yaml.safe_load for yaml OpenAPI specs

    Example output:

    // General Description
    namespace functions {

    // Simple GET endpoint
    type getEndpoint = (_: {
    // This is a string parameter
    param_string: string,
    param_integer: number,
    param_boolean?: boolean,
    param_enum: "value1" | "value2" | "value3",
    }) => any;

    } // namespace functions
    """
    description_clean = description.replace("\n", "")

    schema = f"// {description_clean}\n"
    schema += f"namespace {namespace} {{\n\n"

    for path_name, paths in specification.get("paths", {}).items():
        for method_name, method_info in paths.items():
            operationId = method_info.get("operationId", None)
            if operationId is None:
                continue
            description = method_info.get("description", method_info.get("summary", ""))
            schema += f"// {description}\n"
            schema += f"type {operationId}"

            if ("requestBody" in method_info) or (method_info.get("parameters") is not None):
                schema += f" = (_: {{\n"
                # Body
                if "requestBody" in method_info:
                    try:
                        body_schema = (
                            method_info.get("requestBody", {})
                            .get("content", {})
                            .get("application/json", {})
                            .get("schema", {})
                        )
                    except AttributeError:
                        body_schema = {}
                    for param_name, param in body_schema.get("properties", {}).items():
                        # Param description
                        description = param.get("description")
                        if description is not None:
                            schema += f"// {description}\n"

                        # Param name
                        schema += f"{param_name}"
                        if (
                            (not param.get("required", False))
                            or (param.get("nullable", False))
                            or (param_name in body_schema.get("required", []))
                        ):
                            schema += "?"

                        # Param type
                        param_type = param.get("type", "any")
                        if param_type == "integer":
                            param_type = "number"
                        if "enum" in param:
                            param_type = " | ".join([f'"{v}"' for v in param["enum"]])
                        schema += f": {param_type},\n"

                # URL
                for param in method_info.get("parameters", []):
                    # Param description
                    if description := param.get("description"):
                        schema += f"// {description}\n"

                    # Param name
                    schema += f"{param['name']}"
                    if (not param.get("required", False)) or (param.get("nullable", False)):
                        schema += "?"
                    if param.get("schema") is None:
                        continue
                    # Param type
                    param_type = param["schema"].get("type", "any")
                    if param_type == "integer":
                        param_type = "number"
                    if "enum" in param["schema"]:
                        param_type = " | ".join([f'"{v}"' for v in param["schema"]["enum"]])
                    schema += f": {param_type},\n"

                schema += f"}}) => any;\n\n"
            else:
                # Doesn't have any parameters
                schema += f" = () => any;\n\n"

    schema += f"}} // namespace {namespace}"

    return schema


if __name__ == "__main__":
    functions = [
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        }
    ]
    print(generate_schema_from_functions(functions))
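
Tracing the `__main__` example above, `generate_schema_from_functions` should print roughly the following TypeScript-style schema: the description becomes a `//` comment (with a trailing period appended), the required `location` stays bare, and the optional `unit` gets a `?` with its enum rendered as a union:

// Supported function definitions that should be called when necessary.
namespace functions {

// Get the current weather in a given location
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA.
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;

} // namespace functions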
api/adapter/template.py
ADDED
@@ -0,0 +1,1304 @@