yusufs committed on
Commit 586265c · 1 Parent(s): 2425953

feat(sailor-chat): add sail/Sailor-4B-Chat with the same context length

Files changed (1)
  1. main.py +41 -1
main.py CHANGED
@@ -29,6 +29,18 @@ engine_llama_3_2: LLM = LLM(
 )
 
 
+engine_sailor_chat: LLM = LLM(
+    model='sail/Sailor-4B-Chat',
+    revision="89a866a7041e6ec023dd462adeca8e28dd53c83e",
+    max_num_batched_tokens=512,  # Reduced for T4
+    max_num_seqs=16,  # Reduced for T4
+    gpu_memory_utilization=0.85,  # Slightly increased, adjust if needed
+    max_model_len=32768,
+    enforce_eager=True,  # Disable CUDA graph
+    dtype='auto',  # Use 'half' if you want half precision
+)
+
+
 @app.get("/")
 def greet_json():
     cuda_info: dict[str, Any] = {}
@@ -49,7 +61,13 @@ def greet_json():
             {
                 "name": "meta-llama/Llama-3.2-3B-Instruct",
                 "revision": "0cb88a4f764b7a12671c53f0838cd831a0843b95",
-            }
+                "max_model_len": engine_llama_3_2.llm_engine.model_config.max_model_len,
+            },
+            {
+                "name": "sail/Sailor-4B-Chat",
+                "revision": "89a866a7041e6ec023dd462adeca8e28dd53c83e",
+                "max_model_len": engine_sailor_chat.llm_engine.model_config.max_model_len,
+            },
         ]
     }
 
@@ -85,3 +103,25 @@ def generate_text(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
         return {
             "error": str(e)
         }
+
+
+@app.post("/generate-sailor-chat")
+def generate_text(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
+    try:
+        sampling_params: SamplingParams = SamplingParams(
+            temperature=request.temperature,
+            max_tokens=request.max_tokens,
+            logit_bias=request.logit_bias,
+        )
+
+        # Generate text
+        return engine_sailor_chat.generate(
+            prompts=request.prompt,
+            sampling_params=sampling_params
+        )
+
+    except Exception as e:
+        return {
+            "error": str(e)
+        }
+
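The new /generate-sailor-chat route reads prompt, temperature, max_tokens, and logit_bias from GenerationRequest, so a request could look like the sketch below. The exact Pydantic schema is not shown in this diff, so the payload field names and the localhost:7860 address are assumptions:

import requests

# Payload fields mirror what the handler reads from GenerationRequest;
# the precise schema is an assumption since it is not part of this diff.
payload = {
    "prompt": "Apa ibu kota Indonesia?",  # Sailor is tuned for South-East Asian languages
    "temperature": 0.7,
    "max_tokens": 128,
    "logit_bias": {},
}

resp = requests.post(
    "http://localhost:7860/generate-sailor-chat",  # host/port assumed, adjust to the deployment
    json=payload,
    timeout=120,
)
print(resp.json())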