gordonchan committed
Commit ca56e6a
1 Parent(s): 61100a9

Upload 41 files

api/adapter/__init__.py ADDED
@@ -0,0 +1 @@
+ from api.adapter.template import get_prompt_adapter
api/adapter/model.py ADDED
@@ -0,0 +1,582 @@
+ import os
+ import sys
+ from typing import List, Optional, Any, Dict, Tuple
+
+ import torch
+ from loguru import logger
+ from peft import PeftModel
+ from tqdm import tqdm
+ from transformers import (
+     AutoModel,
+     AutoConfig,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     BitsAndBytesConfig,
+     PreTrainedTokenizer,
+     PreTrainedModel,
+ )
+ from transformers.utils.versions import require_version
+
+ if sys.version_info >= (3, 9):
+     from functools import cache
+ else:
+     from functools import lru_cache as cache
+
+
+ class BaseModelAdapter:
+     """ The base and default model adapter. """
+
+     model_names = []
+
+     def match(self, model_name) -> bool:
+         """
+         Check if the given model name matches any of the predefined model names.
+
+         Args:
+             model_name (str): The model name to check.
+
+         Returns:
+             bool: True if the model name matches any of the predefined model names, False otherwise.
+         """
+         return any(m in model_name for m in self.model_names) if self.model_names else True
+
+     def load_model(
+         self,
+         model_name_or_path: Optional[str] = None,
+         adapter_model: Optional[str] = None,
+         **kwargs: Any,
+     ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+         """
+         Load a model and tokenizer based on the provided model name or path.
+
+         Args:
+             model_name_or_path (str, optional): The name or path of the model. Defaults to None.
+             adapter_model (str, optional): The adapter model to load the tokenizer from. Defaults to None.
+             **kwargs: Additional keyword arguments.
+
+         Returns:
+             Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+         """
+         model_name_or_path = model_name_or_path or self.default_model_name_or_path
+         tokenizer_kwargs = {"trust_remote_code": True, "use_fast": False}
+         tokenizer_kwargs.update(self.tokenizer_kwargs)
+
+         # Load the tokenizer from the adapter model if it exists.
+         if adapter_model is not None:
+             try:
+                 tokenizer = self.tokenizer_class.from_pretrained(
+                     adapter_model, **tokenizer_kwargs,
+                 )
+             except OSError:
+                 tokenizer = self.tokenizer_class.from_pretrained(
+                     model_name_or_path, **tokenizer_kwargs,
+                 )
+         else:
+             tokenizer = self.tokenizer_class.from_pretrained(
+                 model_name_or_path, **tokenizer_kwargs,
+             )
+
+         config_kwargs = self.model_kwargs
+         device = kwargs.get("device", "cuda")
+         num_gpus = kwargs.get("num_gpus", 1)
+         dtype = kwargs.get("dtype", "half")
+         if device == "cuda":
+             if "torch_dtype" not in config_kwargs:
+                 if dtype == "half":
+                     config_kwargs["torch_dtype"] = torch.float16
+                 elif dtype == "bfloat16":
+                     config_kwargs["torch_dtype"] = torch.bfloat16
+                 elif dtype == "float32":
+                     config_kwargs["torch_dtype"] = torch.float32
+
+             if num_gpus != 1:
+                 config_kwargs["device_map"] = "auto"
+                 # model_kwargs["device_map"] = "sequential"  # This is important when the GPUs do not have the same VRAM size.
+
+         # Quantization configurations (using the bitsandbytes library).
+         if kwargs.get("load_in_8bit", False):
+             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+
+             config_kwargs["load_in_8bit"] = True
+             config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                 load_in_8bit=True,
+                 llm_int8_threshold=6.0,
+             )
+             config_kwargs["device_map"] = "auto" if device == "cuda" else None
+
+             logger.info("Quantizing model to 8 bit.")
+
+         elif kwargs.get("load_in_4bit", False):
+             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+             require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
+
+             config_kwargs["load_in_4bit"] = True
+             config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.float16,
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_quant_type="nf4",
+             )
+             config_kwargs["device_map"] = "auto" if device == "cuda" else None
+
+             logger.info("Quantizing model to 4 bit.")
+
+         if kwargs.get("device_map", None) == "auto":
+             config_kwargs["device_map"] = "auto"
+
+         config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+
+         # Fix config (for Qwen)
+         if hasattr(config, "fp16") and hasattr(config, "bf16"):
+             setattr(config, "fp16", dtype == "half")
+             setattr(config, "bf16", dtype == "bfloat16")
+             config_kwargs.pop("torch_dtype", None)
+
+         if kwargs.get("using_ptuning_v2", False) and adapter_model:
+             config.pre_seq_len = kwargs.get("pre_seq_len", 128)
+
+         # Load and prepare pretrained models (without valuehead).
+         model = self.model_class.from_pretrained(
+             model_name_or_path,
+             config=config,
+             trust_remote_code=True,
+             **config_kwargs
+         )
+
+         if device == "cpu":
+             model = model.float()
+
+         # Post-process for special tokens.
+         tokenizer = self.post_tokenizer(tokenizer)
+         is_chatglm = "chatglm" in str(type(model))
+
+         if adapter_model is not None:
+             model = self.load_adapter_model(model, tokenizer, adapter_model, is_chatglm, config_kwargs, **kwargs)
+
+         if is_chatglm or "baichuan" in str(type(model)) or "xverse" in str(type(model)):
+             quantize = kwargs.get("quantize", None)
+             if quantize and quantize != 16:
+                 logger.info(f"Quantizing model to {quantize} bit.")
+                 model = model.quantize(quantize)
+
+         if device == "cuda" and num_gpus == 1 and "device_map" not in config_kwargs:
+             model.to(device)
+
+         # Inference mode.
+         model.eval()
+
+         return model, tokenizer
+
+     def load_lora_model(
+         self, model: PreTrainedModel, adapter_model: str, model_kwargs: Dict,
+     ) -> PeftModel:
+         """
+         Load a LoRA model.
+
+         This function loads a LoRA model using the specified pretrained model and adapter model.
+
+         Args:
+             model (PreTrainedModel): The base pretrained model.
+             adapter_model (str): The name or path of the adapter model.
+             model_kwargs (dict): Additional keyword arguments for the model.
+
+         Returns:
+             PeftModel: The loaded LoRA model.
+         """
+         return PeftModel.from_pretrained(
+             model,
+             adapter_model,
+             torch_dtype=model_kwargs.get("torch_dtype", torch.float16),
+         )
+
+     def load_adapter_model(
+         self,
+         model: PreTrainedModel,
+         tokenizer: PreTrainedTokenizer,
+         adapter_model: str,
+         is_chatglm: bool,
+         model_kwargs: Dict,
+         **kwargs: Any,
+     ) -> PreTrainedModel:
+         using_ptuning_v2 = kwargs.get("using_ptuning_v2", False)
+         resize_embeddings = kwargs.get("resize_embeddings", False)
+         if adapter_model and resize_embeddings and not is_chatglm:
+             model_vocab_size = model.get_input_embeddings().weight.size(0)
+             tokenzier_vocab_size = len(tokenizer)
+             logger.info(f"Vocab of the base model: {model_vocab_size}")
+             logger.info(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
+
+             if model_vocab_size != tokenzier_vocab_size:
+                 assert tokenzier_vocab_size > model_vocab_size
+                 logger.info("Resize model embeddings to fit tokenizer")
+                 model.resize_token_embeddings(tokenzier_vocab_size)
+
+         if using_ptuning_v2:
+             prefix_state_dict = torch.load(os.path.join(adapter_model, "pytorch_model.bin"))
+             new_prefix_state_dict = {
+                 k[len("transformer.prefix_encoder."):]: v
+                 for k, v in prefix_state_dict.items()
+                 if k.startswith("transformer.prefix_encoder.")
+             }
+             model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+             model.transformer.prefix_encoder.float()
+         else:
+             model = self.load_lora_model(model, adapter_model, model_kwargs)
+
+         return model
+
+     def post_tokenizer(self, tokenizer) -> PreTrainedTokenizer:
+         return tokenizer
+
+     @property
+     def model_class(self):
+         return AutoModelForCausalLM
+
+     @property
+     def model_kwargs(self):
+         return {}
+
+     @property
+     def tokenizer_class(self):
+         return AutoTokenizer
+
+     @property
+     def tokenizer_kwargs(self):
+         return {}
+
+     @property
+     def default_model_name_or_path(self):
+         return "zpn/llama-7b"
+
+
+ # A global registry for all model adapters
+ model_adapters: List[BaseModelAdapter] = []
+
+
+ def register_model_adapter(cls):
+     """ Register a model adapter. """
+     model_adapters.append(cls())
+
+
+ @cache
+ def get_model_adapter(model_name: str) -> BaseModelAdapter:
+     """
+     Get a model adapter for a given model name.
+
+     Args:
+         model_name (str): The name of the model.
+
+     Returns:
+         ModelAdapter: The model adapter that matches the given model name.
+     """
+     for adapter in model_adapters:
+         if adapter.match(model_name):
+             return adapter
+     raise ValueError(f"No valid model adapter for {model_name}")
+
+
+ def load_model(
+     model_name: str,
+     model_name_or_path: Optional[str] = None,
+     adapter_model: Optional[str] = None,
+     quantize: Optional[int] = 16,
+     device: Optional[str] = "cuda",
+     load_in_8bit: Optional[bool] = False,
+     **kwargs: Any,
+ ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+     """
+     Load a pre-trained model and tokenizer.
+
+     Args:
+         model_name (str): The name of the model.
+         model_name_or_path (Optional[str], optional): The path or name of the pre-trained model. Defaults to None.
+         adapter_model (Optional[str], optional): The name of the adapter model. Defaults to None.
+         quantize (Optional[int], optional): The quantization level. Defaults to 16.
+         device (Optional[str], optional): The device to load the model on. Defaults to "cuda".
+         load_in_8bit (Optional[bool], optional): Whether to load the model in 8-bit mode. Defaults to False.
+         **kwargs (Any): Additional keyword arguments.
+
+     Returns:
+         Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+     """
+     model_name = model_name.lower()
+
+     if "tiger" in model_name:
+         def skip(*args, **kwargs):
+             pass
+
+         torch.nn.init.kaiming_uniform_ = skip
+         torch.nn.init.uniform_ = skip
+         torch.nn.init.normal_ = skip
+
+     # Get the model adapter and delegate loading to it.
+     adapter = get_model_adapter(model_name)
+     model, tokenizer = adapter.load_model(
+         model_name_or_path,
+         adapter_model,
+         device=device,
+         quantize=quantize,
+         load_in_8bit=load_in_8bit,
+         **kwargs
+     )
+     return model, tokenizer
+
+
+ class ChatglmModelAdapter(BaseModelAdapter):
+     """ https://github.com/THUDM/ChatGLM-6B """
+
+     model_names = ["chatglm"]
+
+     @property
+     def model_class(self):
+         return AutoModel
+
+     @property
+     def default_model_name_or_path(self):
+         return "THUDM/chatglm2-6b"
+
+
+ class Chatglm3ModelAdapter(ChatglmModelAdapter):
+     """ https://github.com/THUDM/ChatGLM-6B """
+
+     model_names = ["chatglm3"]
+
+     @property
+     def tokenizer_kwargs(self):
+         return {"encode_special_tokens": True}
+
+     @property
+     def default_model_name_or_path(self):
+         return "THUDM/chatglm3-6b"
+
+
+ class LlamaModelAdapter(BaseModelAdapter):
+     """ https://github.com/project-baize/baize-chatbot """
+
+     model_names = ["alpaca", "baize", "openbuddy-llama", "ziya-llama", "guanaco", "llama2"]
+
+     def post_tokenizer(self, tokenizer):
+         tokenizer.bos_token = "<s>"
+         tokenizer.eos_token = "</s>"
+         tokenizer.unk_token = "<unk>"
+         return tokenizer
+
+     @property
+     def model_kwargs(self):
+         return {"low_cpu_mem_usage": True}
+
+
+ class MossModelAdapter(BaseModelAdapter):
+     """ https://github.com/OpenLMLab/MOSS """
+
+     model_names = ["moss"]
+
+     @property
+     def default_model_name_or_path(self):
+         return "fnlp/moss-moon-003-sft-int4"
+
+
+ class PhoenixModelAdapter(BaseModelAdapter):
+     """ https://github.com/FreedomIntelligence/LLMZoo """
+
+     model_names = ["phoenix"]
+
+     @property
+     def model_kwargs(self):
+         return {"low_cpu_mem_usage": True}
+
+     @property
+     def tokenizer_kwargs(self):
+         return {"use_fast": True}
+
+     @property
+     def default_model_name_or_path(self):
+         return "FreedomIntelligence/phoenix-inst-chat-7b"
+
+
+ class FireflyModelAdapter(BaseModelAdapter):
+     """ https://github.com/yangjianxin1/Firefly """
+
+     model_names = ["firefly"]
+
+     @property
+     def model_kwargs(self):
+         return {"torch_dtype": torch.float32}
+
+     @property
+     def tokenizer_kwargs(self):
+         return {"use_fast": True}
+
+     @property
+     def default_model_name_or_path(self):
+         return "YeungNLP/firefly-2b6"
+
+
+ class YuLanChatModelAdapter(BaseModelAdapter):
+     """ https://github.com/RUC-GSAI/YuLan-Chat """
+
+     model_names = ["yulan"]
+
+     def post_tokenizer(self, tokenizer):
+         tokenizer.bos_token = "<s>"
+         tokenizer.eos_token = "</s>"
+         tokenizer.unk_token = "<unk>"
+         return tokenizer
+
+     @property
+     def model_kwargs(self):
+         return {"low_cpu_mem_usage": True}
+
+     def load_adapter_model(self, model, tokenizer, adapter_model, is_chatglm, model_kwargs, **kwargs):
+         adapter_model = AutoModelForCausalLM.from_pretrained(
+             adapter_model, torch_dtype=torch.float16, low_cpu_mem_usage=True
+         )
+         if model.model.embed_tokens.weight.size(0) + 1 == adapter_model.model.embed_tokens.weight.size(0):
+             model.resize_token_embeddings(len(tokenizer))
+             model.model.embed_tokens.weight.data[-1, :] = 0
+
+         # Add the delta weights from the adapter model onto the base model weights.
+         logger.info("Applying the delta")
+         for name, param in tqdm(model.state_dict().items(), desc="Applying delta"):
+             assert name in adapter_model.state_dict()
+             param.data += adapter_model.state_dict()[name]
+
+         return model
+
+
+ class TigerBotModelAdapter(BaseModelAdapter):
+     """ https://github.com/TigerResearch/TigerBot """
+
+     model_names = ["tiger"]
+
+     @property
+     def tokenizer_kwargs(self):
+         return {"use_fast": True}
+
+     @property
+     def default_model_name_or_path(self):
+         return "TigerResearch/tigerbot-7b-sft"
+
+
+ class OpenBuddyFalconModelAdapter(BaseModelAdapter):
+     """ https://github.com/OpenBuddy/OpenBuddy """
+
+     model_names = ["openbuddy-falcon"]
+
+     @property
+     def default_model_name_or_path(self):
+         return "OpenBuddy/openbuddy-falcon-7b-v5-fp16"
+
+
+ class AnimaModelAdapter(LlamaModelAdapter):
+
+     model_names = ["anima"]
+
+     def load_lora_model(self, model, adapter_model, model_kwargs):
+         return PeftModel.from_pretrained(model, adapter_model)
+
+
+ class BaiChuanModelAdapter(BaseModelAdapter):
+     """ https://github.com/baichuan-inc/Baichuan-13B """
+
+     model_names = ["baichuan"]
+
+     def load_lora_model(self, model, adapter_model, model_kwargs):
+         return PeftModel.from_pretrained(model, adapter_model)
+
+     @property
+     def default_model_name_or_path(self):
+         return "baichuan-inc/Baichuan-13B-Chat"
+
+
+ class InternLMModelAdapter(BaseModelAdapter):
+     """ https://github.com/InternLM/InternLM """
+
+     model_names = ["internlm"]
+
+     @property
+     def default_model_name_or_path(self):
+         return "internlm/internlm-chat-7b"
+
+
+ class StarCodeModelAdapter(BaseModelAdapter):
+     """ https://github.com/bigcode-project/starcoder """
+
+     model_names = ["starcode", "starchat"]
+
+     @property
+     def tokenizer_kwargs(self):
+         return {}
+
+     @property
+     def default_model_name_or_path(self):
+         return "HuggingFaceH4/starchat-beta"
+
+
+ class AquilaModelAdapter(BaseModelAdapter):
+     """ https://github.com/FlagAI-Open/FlagAI """
+
+     model_names = ["aquila"]
+
+     @property
+     def default_model_name_or_path(self):
+         return "BAAI/AquilaChat-7B"
+
+
+ class QwenModelAdapter(BaseModelAdapter):
+     """ https://github.com/QwenLM/Qwen-7B """
+
+     model_names = ["qwen"]
+
+     @property
+     def default_model_name_or_path(self):
+         return "Qwen/Qwen-7B-Chat"
+
+
+ class XverseModelAdapter(BaseModelAdapter):
+     """ https://github.com/xverse-ai/XVERSE-13B """
+
+     model_names = ["xverse"]
+
+     @property
+     def default_model_name_or_path(self):
+         return "xverse/XVERSE-13B-Chat"
+
+
+ class CodeLlamaModelAdapter(LlamaModelAdapter):
+     """ https://github.com/project-baize/baize-chatbot """
+
+     model_names = ["code-llama"]
+
+     @property
+     def tokenizer_class(self):
+         require_version("transformers>=4.33.1", "To fix: pip install transformers>=4.33.1")
+         from transformers import CodeLlamaTokenizer
+
+         return CodeLlamaTokenizer
+
+     @property
+     def default_model_name_or_path(self):
+         return "codellama/CodeLlama-7b-Instruct-hf"
+
+
+ register_model_adapter(ChatglmModelAdapter)
+ register_model_adapter(Chatglm3ModelAdapter)
+ register_model_adapter(LlamaModelAdapter)
+ register_model_adapter(MossModelAdapter)
+ register_model_adapter(PhoenixModelAdapter)
+ register_model_adapter(FireflyModelAdapter)
+ register_model_adapter(YuLanChatModelAdapter)
+ register_model_adapter(TigerBotModelAdapter)
+ register_model_adapter(OpenBuddyFalconModelAdapter)
+ register_model_adapter(AnimaModelAdapter)
+ register_model_adapter(BaiChuanModelAdapter)
+ register_model_adapter(InternLMModelAdapter)
+ register_model_adapter(AquilaModelAdapter)
+ register_model_adapter(QwenModelAdapter)
+ register_model_adapter(XverseModelAdapter)
+ register_model_adapter(CodeLlamaModelAdapter)
+
+ # After all adapters, try the default base adapter.
+ register_model_adapter(BaseModelAdapter)
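
Not part of the diff: a minimal usage sketch of the registry defined in api/adapter/model.py. get_model_adapter matches the lowercased model name against each adapter's model_names by substring, and load_model delegates to the matched adapter. The model name, checkpoint path, and keyword values below are illustrative assumptions, not settings taken from this commit.

    from api.adapter.model import load_model

    # "qwen" is matched to QwenModelAdapter by BaseModelAdapter.match (substring check).
    model, tokenizer = load_model(
        "qwen",
        model_name_or_path="Qwen/Qwen-7B-Chat",  # falls back to the adapter's default_model_name_or_path if None
        adapter_model=None,                      # optional LoRA / p-tuning-v2 checkpoint directory
        device="cuda",
        load_in_8bit=False,
        dtype="half",                            # forwarded via **kwargs and read in BaseModelAdapter.load_model
    )
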
api/adapter/schema.py ADDED
@@ -0,0 +1,375 @@
+ from typing import Any, Dict, List, Optional
+
+ from openai.types.chat.completion_create_params import Function
+ from pydantic import BaseModel
+
+ from api.utils.compat import model_dump
+
+
+ def convert_data_type(param_type: str) -> str:
+     """ Convert data_type to a TypeScript data type. """
+     return "number" if param_type in {"integer", "float"} else param_type
+
+
+ def get_param_type(param: Dict[str, Any]) -> str:
+     """ Get the param_type of a parameter. """
+     param_type = "any"
+     if "type" in param:
+         raw_param_type = param["type"]
+         param_type = (
+             " | ".join(raw_param_type)
+             if type(raw_param_type) is list
+             else raw_param_type
+         )
+     elif "oneOf" in param:
+         one_of_types = [
+             convert_data_type(item["type"])
+             for item in param["oneOf"]
+             if "type" in item
+         ]
+         one_of_types = list(set(one_of_types))
+         param_type = " | ".join(one_of_types)
+     return convert_data_type(param_type)
+
+
+ def get_format_param(param: Dict[str, Any]) -> Optional[str]:
+     """ Get "format" from param. There are cases where format is not directly in param but in oneOf. """
+     if "format" in param:
+         return param["format"]
+     if "oneOf" in param:
+         formats = [item["format"] for item in param["oneOf"] if "format" in item]
+         if formats:
+             return " or ".join(formats)
+     return None
+
+
+ def get_param_info(param: Dict[str, Any]) -> Optional[str]:
+     """ Get additional information about a parameter such as: format, default value, min, max, ... """
+     param_type = param.get("type", "any")
+     info_list = []
+     if "description" in param:
+         desc = param["description"]
+         if not desc.endswith("."):
+             desc += "."
+         info_list.append(desc)
+
+     if "default" in param:
+         default_value = param["default"]
+         if param_type == "string":
+             default_value = f'"{default_value}"'  # if string --> add ""
+         info_list.append(f"Default={default_value}.")
+
+     format_param = get_format_param(param)
+     if format_param is not None:
+         info_list.append(f"Format={format_param}")
+
+     info_list.extend(
+         f"{field_name}={str(param[field])}"
+         for field, field_name in [
+             ("maximum", "Maximum"),
+             ("minimum", "Minimum"),
+             ("maxLength", "Maximum length"),
+             ("minLength", "Minimum length"),
+         ]
+         if field in param
+     )
+     if info_list:
+         result = "// " + " ".join(info_list)
+         return result.replace("\n", " ")
+     return None
+
+
+ def append_new_param_info(info_list: List[str], param_declaration: str, comment_info: Optional[str], depth: int):
+     """ Append a new parameter with its comment to the info_list. """
+     offset = "".join([" " for _ in range(depth)]) if depth >= 1 else ""
+     if comment_info is not None:
+         # if depth == 0:  # format: //comment\nparam: type
+         info_list.append(f"{offset}{comment_info}")
+     info_list.append(f"{offset}{param_declaration}")
+
+
+ def get_enum_option_str(enum_options: List) -> str:
+     """Get the enum options separated by "|".
+
+     Args:
+         enum_options (List): list of options
+
+     Returns:
+         _type_: concatenation of options separated by "|"
+     """
+     # if each option is a string --> add quotes
+     return " | ".join([f'"{v}"' if type(v) is str else str(v) for v in enum_options])
+
+
+ def get_array_typescript(param_name: Optional[str], param_dic: dict, depth: int = 0) -> str:
+     """Recursive implementation for generating the TypeScript type of an array.
+
+     Args:
+         param_name (Optional[str]): name of param, optional
+         param_dic (dict): param_dic
+         depth (int, optional): nested level. Defaults to 0.
+
+     Returns:
+         _type_: typescript of array
+     """
+     offset = "".join([" " for _ in range(depth)]) if depth >= 1 else ""
+     items_info = param_dic.get("items", {})
+
+     if len(items_info) == 0:
+         return f"{offset}{param_name}: []" if param_name is not None else "[]"
+     array_type = get_param_type(items_info)
+     if array_type == "object":
+         info_lines = []
+         child_lines = get_parameter_typescript(
+             items_info.get("properties", {}), items_info.get("required", []), depth + 1
+         )
+         # if comment_info is not None:
+         #     info_lines.append(f"{offset}{comment_info}")
+         if param_name is not None:
+             info_lines.append(f"{offset}{param_name}" + ": {")
+         else:
+             info_lines.append(f"{offset}" + "{")
+         info_lines.extend(child_lines)
+         info_lines.append(f"{offset}" + "}[]")
+         return "\n".join(info_lines)
+
+     elif array_type == "array":
+         item_info = get_array_typescript(None, items_info, depth + 1)
+         if param_name is None:
+             return f"{item_info}[]"
+         return f"{offset}{param_name}: {item_info.strip()}[]"
+
+     else:
+         if "enum" not in items_info:
+             return (
+                 f"{array_type}[]"
+                 if param_name is None
+                 else f"{offset}{param_name}: {array_type}[],"
+             )
+         item_type = get_enum_option_str(items_info["enum"])
+         if param_name is None:
+             return f"({item_type})[]"
+         else:
+             return f"{offset}{param_name}: ({item_type})[]"
+
+
+ def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
+     """Recursively collect information about parameters, including data type, description and other details.
+     These kinds of information will be put into the prompt.
+
+     Args:
+         properties (_type_): properties in parameters
+         required_params (_type_): List of required parameters
+         depth (int, optional): the depth of params (nested level). Defaults to 0.
+
+     Returns:
+         _type_: list of lines containing information about all parameters
+     """
+     tp_lines = []
+     for param_name, param in properties.items():
+         # Sometimes properties have a "required" field as a list of strings, even though
+         # it is not supposed to be under properties, so we skip it.
+         if not isinstance(param, dict):
+             continue
+         # Param Description
+         comment_info = get_param_info(param)
+         # Param Name declaration
+         param_declaration = f"{param_name}"
+         if isinstance(required_params, list) and param_name not in required_params:
+             param_declaration += "?"
+         param_type = get_param_type(param)
+
+         offset = ""
+         if depth >= 1:
+             offset = "".join([" " for _ in range(depth)])
+
+         if param_type == "object":  # param_type is object
+             child_lines = get_parameter_typescript(param.get("properties", {}), param.get("required", []), depth + 1)
+             if comment_info is not None:
+                 tp_lines.append(f"{offset}{comment_info}")
+
+             param_declaration += ": {"
+             tp_lines.append(f"{offset}{param_declaration}")
+             tp_lines.extend(child_lines)
+             tp_lines.append(f"{offset}" + "},")
+
+         elif param_type == "array":  # param_type is an array
+             item_info = param.get("items", {})
+             if "type" not in item_info:  # don't know the type of the array
+                 param_declaration += ": [],"
+                 append_new_param_info(tp_lines, param_declaration, comment_info, depth)
+             else:
+                 array_declaration = get_array_typescript(param_declaration, param, depth)
+                 if not array_declaration.endswith(","):
+                     array_declaration += ","
+                 if comment_info is not None:
+                     tp_lines.append(f"{offset}{comment_info}")
+                 tp_lines.append(array_declaration)
+         else:
+             if "enum" in param:
+                 param_type = " | ".join([f'"{v}"' for v in param["enum"]])
+             param_declaration += f": {param_type},"
+             append_new_param_info(tp_lines, param_declaration, comment_info, depth)
+
+     return tp_lines
+
+
+ def generate_schema_from_functions(functions: List[Function], namespace="functions") -> str:
+     """
+     Convert a functions schema to a schema that language models can understand.
+     """
+
+     schema = "// Supported function definitions that should be called when necessary.\n"
+     schema += f"namespace {namespace} {{\n\n"
+
+     for function in functions:
+         # Convert a Function object to a dict, if necessary.
+         if isinstance(function, BaseModel):
+             function = model_dump(function)
+         function_name = function.get("name", None)
+         if function_name is None:
+             continue
+
+         description = function.get("description", "")
+         schema += f"// {description}\n"
+         schema += f"type {function_name}"
+
+         parameters = function.get("parameters", None)
+         if parameters is not None and parameters.get("properties") is not None:
+             schema += " = (_: {\n"
+             required_params = parameters.get("required", [])
+             tp_lines = get_parameter_typescript(parameters.get("properties"), required_params, 0)
+             schema += "\n".join(tp_lines)
+             schema += "\n}) => any;\n\n"
+         else:
+             # Doesn't have any parameters
+             schema += " = () => any;\n\n"
+
+     schema += f"}} // namespace {namespace}"
+
+     return schema
+
+
+ def generate_schema_from_openapi(specification: Dict[str, Any], description: str, namespace: str) -> str:
+     """
+     Convert an OpenAPI specification object to a schema that language models can understand.
+
+     Input:
+         specification: can be obtained by json.loads of any OpenAPI JSON spec, or yaml.safe_load for YAML OpenAPI specs
+
+     Example output:
+
+     // General Description
+     namespace functions {
+
+     // Simple GET endpoint
+     type getEndpoint = (_: {
+     // This is a string parameter
+     param_string: string,
+     param_integer: number,
+     param_boolean?: boolean,
+     param_enum: "value1" | "value2" | "value3",
+     }) => any;
+
+     } // namespace functions
+     """
+
+     description_clean = description.replace("\n", "")
+
+     schema = f"// {description_clean}\n"
+     schema += f"namespace {namespace} {{\n\n"
+
+     for path_name, paths in specification.get("paths", {}).items():
+         for method_name, method_info in paths.items():
+             operationId = method_info.get("operationId", None)
+             if operationId is None:
+                 continue
+             description = method_info.get("description", method_info.get("summary", ""))
+             schema += f"// {description}\n"
+             schema += f"type {operationId}"
+
+             if ("requestBody" in method_info) or (method_info.get("parameters") is not None):
+                 schema += f" = (_: {{\n"
+                 # Body
+                 if "requestBody" in method_info:
+                     try:
+                         body_schema = (
+                             method_info.get("requestBody", {})
+                             .get("content", {})
+                             .get("application/json", {})
+                             .get("schema", {})
+                         )
+                     except AttributeError:
+                         body_schema = {}
+                     for param_name, param in body_schema.get("properties", {}).items():
+                         # Param Description
+                         description = param.get("description")
+                         if description is not None:
+                             schema += f"// {description}\n"
+
+                         # Param Name
+                         schema += f"{param_name}"
+                         if (
+                             (not param.get("required", False))
+                             or (param.get("nullable", False))
+                             or (param_name in body_schema.get("required", []))
+                         ):
+                             schema += "?"
+
+                         # Param Type
+                         param_type = param.get("type", "any")
+                         if param_type == "integer":
+                             param_type = "number"
+                         if "enum" in param:
+                             param_type = " | ".join([f'"{v}"' for v in param["enum"]])
+                         schema += f": {param_type},\n"
+
+                 # URL
+                 for param in method_info.get("parameters", []):
+                     # Param Description
+                     if description := param.get("description"):
+                         schema += f"// {description}\n"
+
+                     # Param Name
+                     schema += f"{param['name']}"
+                     if (not param.get("required", False)) or (param.get("nullable", False)):
+                         schema += "?"
+                     if param.get("schema") is None:
+                         continue
+                     # Param Type
+                     param_type = param["schema"].get("type", "any")
+                     if param_type == "integer":
+                         param_type = "number"
+                     if "enum" in param["schema"]:
+                         param_type = " | ".join([f'"{v}"' for v in param["schema"]["enum"]])
+                     schema += f": {param_type},\n"
+
+                 schema += f"}}) => any;\n\n"
+             else:
+                 # Doesn't have any parameters
+                 schema += f" = () => any;\n\n"
+
+     schema += f"}} // namespace {namespace}"
+
+     return schema
+
+
+ if __name__ == "__main__":
+     functions = [
+         {
+             "name": "get_current_weather",
+             "description": "Get the current weather in a given location",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "location": {
+                         "type": "string",
+                         "description": "The city and state, e.g. San Francisco, CA",
+                     },
+                     "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                 },
+                 "required": ["location"],
+             },
+         }
+     ]
+     print(generate_schema_from_functions(functions))
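
For reference (not part of the diff): tracing generate_schema_from_functions over the example above, the __main__ block should print a TypeScript-style schema along these lines. This is reconstructed from the generation logic, not captured output.

    // Supported function definitions that should be called when necessary.
    namespace functions {

    // Get the current weather in a given location
    type get_current_weather = (_: {
    // The city and state, e.g. San Francisco, CA.
    location: string,
    unit?: "celsius" | "fahrenheit",
    }) => any;

    } // namespace functions
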
api/adapter/template.py ADDED
@@ -0,0 +1,1304 @@