feat(sailor-chat): add sail/Sailor-4B-Chat with the same context length
main.py CHANGED

@@ -29,6 +29,18 @@ engine_llama_3_2: LLM = LLM(
 )
 
 
+engine_sailor_chat: LLM = LLM(
+    model='sail/Sailor-4B-Chat',
+    revision="89a866a7041e6ec023dd462adeca8e28dd53c83e",
+    max_num_batched_tokens=512,  # Reduced for T4
+    max_num_seqs=16,  # Reduced for T4
+    gpu_memory_utilization=0.85,  # Slightly increased, adjust if needed
+    max_model_len=32768,
+    enforce_eager=True,  # Disable CUDA graph
+    dtype='auto',  # Use 'half' if you want half precision
+)
+
+
 @app.get("/")
 def greet_json():
     cuda_info: dict[str, Any] = {}
@@ -49,7 +61,13 @@ def greet_json():
             {
                 "name": "meta-llama/Llama-3.2-3B-Instruct",
                 "revision": "0cb88a4f764b7a12671c53f0838cd831a0843b95",
-            },
+                "max_model_len": engine_llama_3_2.llm_engine.model_config.max_model_len,
+            },
+            {
+                "name": "sail/Sailor-4B-Chat",
+                "revision": "89a866a7041e6ec023dd462adeca8e28dd53c83e",
+                "max_model_len": engine_sailor_chat.llm_engine.model_config.max_model_len,
+            },
         ]
     }
 
@@ -85,3 +103,25 @@ def generate_text(request: GenerationRequest) -> list[RequestOutput] | dict[str,
         return {
             "error": str(e)
         }
+
+
+@app.post("/generate-sailor-chat")
+def generate_text_sailor_chat(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
+    try:
+        sampling_params: SamplingParams = SamplingParams(
+            temperature=request.temperature,
+            max_tokens=request.max_tokens,
+            logit_bias=request.logit_bias,
+        )
+
+        # Generate text with the Sailor chat engine
+        return engine_sailor_chat.generate(
+            prompts=request.prompt,
+            sampling_params=sampling_params,
+        )
+
+    except Exception as e:
+        return {
+            "error": str(e)
+        }
+
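For a quick smoke test of the new route, here is a minimal client sketch. It is illustrative rather than part of the commit: the base URL is hypothetical, and the request body assumes GenerationRequest exposes exactly the prompt, temperature, max_tokens, and logit_bias fields that the handler reads.

import requests

BASE_URL = "http://localhost:7860"  # hypothetical; substitute the Space's actual URL

response = requests.post(
    f"{BASE_URL}/generate-sailor-chat",
    json={
        "prompt": "Apa ibu kota Indonesia?",  # Sailor targets South-East Asian languages
        "temperature": 0.7,
        "max_tokens": 128,
        "logit_bias": {},  # forwarded to SamplingParams by the handler
    },
    timeout=120,
)
response.raise_for_status()
print(response.json())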
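A similar sketch for the updated root endpoint, which now reports each engine's max_model_len alongside its name and revision. The "models" key is an assumption; the hunk shows the list entries but not the key they sit under.

import requests

info = requests.get("http://localhost:7860/", timeout=30).json()  # hypothetical URL
for model in info["models"]:  # "models" key assumed; not visible in the hunk
    print(model["name"], model["revision"], model["max_model_len"])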