Ashok75 committed on
Commit
1728e7f
·
verified ·
1 Parent(s): f160ca6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +128 -540
  2. requirements.txt +6 -19
app.py CHANGED
@@ -1,333 +1,176 @@
1
  """
2
- Nanbeige4.1-3B Inference Server for Hugging Face Space
3
- Lightweight API server exposing /chat endpoint for remote agent communication
4
  """
5
 
6
  import os
7
  import json
8
- import torch
9
- from typing import AsyncGenerator, Dict, List, Optional
10
  from contextlib import asynccontextmanager
11
- from fastapi import FastAPI, Request, HTTPException
12
- from fastapi.responses import StreamingResponse, HTMLResponse
 
13
  from fastapi.middleware.cors import CORSMiddleware
14
- from pydantic import BaseModel, Field
15
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
16
  from threading import Thread
17
- import asyncio
 
 
 
 
18
 
19
- # Global model instances
20
  model = None
21
  tokenizer = None
22
 
23
- # Model configuration
24
- MODEL_ID = "Nanbeige/Nanbeige4.1-3B"
25
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
- DEFAULT_MAX_TOKENS = 2048
27
- DEFAULT_TEMPERATURE = 0.6
28
- DEFAULT_TOP_P = 0.95
29
 
30
-
31
- class ChatMessage(BaseModel):
32
- role: str = Field(..., description="Message role: system, user, assistant, or tool")
33
- content: str = Field(..., description="Message content")
34
- tool_calls: Optional[List[Dict]] = Field(None, description="Tool calls from assistant")
35
- tool_call_id: Optional[str] = Field(None, description="Tool call ID for tool responses")
36
 
37
 
38
  class ChatRequest(BaseModel):
39
- messages: List[ChatMessage] = Field(..., description="Conversation history")
40
- tools: Optional[List[Dict]] = Field(None, description="Available tools for function calling")
41
- stream: bool = Field(default=False, description="Enable streaming response")
42
- max_tokens: int = Field(default=DEFAULT_MAX_TOKENS, ge=1, le=8192)
43
- temperature: float = Field(default=DEFAULT_TEMPERATURE, ge=0.0, le=2.0)
44
- top_p: float = Field(default=DEFAULT_TOP_P, ge=0.0, le=1.0)
45
- stop: Optional[List[str]] = Field(default=None, description="Stop sequences")
46
-
47
 
48
- class ChatResponse(BaseModel):
49
- id: str
50
- object: str = "chat.completion"
51
- created: int
52
- model: str
53
- choices: List[Dict]
54
- usage: Optional[Dict] = None
55
 
56
-
57
- def load_model():
58
- """Load Nanbeige4.1-3B model and tokenizer."""
59
  global model, tokenizer
60
-
61
- print(f"Loading {MODEL_ID} on {DEVICE}...")
62
-
63
- tokenizer = AutoTokenizer.from_pretrained(
64
- MODEL_ID,
65
- trust_remote_code=True,
66
- padding_side="left"
67
- )
68
-
69
- # Set pad token if not present
70
- if tokenizer.pad_token is None:
71
- tokenizer.pad_token = tokenizer.eos_token
72
-
73
  model = AutoModelForCausalLM.from_pretrained(
74
- MODEL_ID,
75
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
76
- device_map="auto" if DEVICE == "cuda" else None,
77
  trust_remote_code=True,
78
- low_cpu_mem_usage=True
 
79
  )
80
-
81
- if DEVICE == "cpu":
82
- model = model.to(DEVICE)
83
-
84
- model.eval()
85
- print(f"Model loaded successfully on {DEVICE}")
86
-
87
-
88
- @asynccontextmanager
89
- async def lifespan(app: FastAPI):
90
- """Application lifespan manager."""
91
- # Startup
92
- load_model()
93
  yield
94
- # Shutdown - cleanup happens automatically
 
 
 
 
95
 
96
 
97
  app = FastAPI(
98
  title="Nanbeige4.1-3B Inference API",
99
- description="Remote LLM inference service for Enterprise ReAct Agent",
100
  version="1.0.0",
101
  lifespan=lifespan
102
  )
103
 
104
- # CORS for local agent communication
105
  app.add_middleware(
106
  CORSMiddleware,
107
- allow_origins=["*"], # Configure for production
108
  allow_credentials=True,
109
  allow_methods=["*"],
110
  allow_headers=["*"],
111
  )
112
 
113
 
114
- def format_messages_for_model(messages: List[ChatMessage], tools: Optional[List[Dict]] = None) -> str:
115
- """Format messages using Nanbeige chat template."""
116
- formatted_messages = []
117
-
118
  for msg in messages:
119
  if msg.role == "system":
120
- formatted_messages.append({"role": "system", "content": msg.content})
121
  elif msg.role == "user":
122
- formatted_messages.append({"role": "user", "content": msg.content})
123
  elif msg.role == "assistant":
124
- content = msg.content
125
- if msg.tool_calls:
126
- # Append tool calls to content
127
- tool_calls_str = json.dumps(msg.tool_calls)
128
- content = f"{content}\n<tool_calls>{tool_calls_str}</tool_calls>"
129
- formatted_messages.append({"role": "assistant", "content": content})
130
- elif msg.role == "tool":
131
- formatted_messages.append({
132
- "role": "tool",
133
- "content": msg.content,
134
- "tool_call_id": msg.tool_call_id
135
- })
136
-
137
- # Add tools to system message if provided
138
- if tools:
139
- tools_description = "\n\nAvailable tools:\n" + json.dumps(tools, indent=2)
140
- if formatted_messages and formatted_messages[0]["role"] == "system":
141
- formatted_messages[0]["content"] += tools_description
142
- else:
143
- formatted_messages.insert(0, {"role": "system", "content": tools_description})
144
-
145
- # Apply chat template
146
- prompt = tokenizer.apply_chat_template(
147
- formatted_messages,
148
- tokenize=False,
149
- add_generation_prompt=True
150
- )
151
-
152
- return prompt
153
 
154
 
155
- def parse_tool_calls(response_text: str) -> tuple[str, Optional[List[Dict]]]:
156
- """Parse tool calls from model response."""
157
- tool_calls = None
158
- content = response_text
159
-
160
- # Look for tool_calls in the response
161
- if "<tool_calls>" in response_text and "</tool_calls>" in response_text:
162
- try:
163
- start = response_text.find("<tool_calls>") + len("<tool_calls>")
164
- end = response_text.find("</tool_calls>")
165
- tool_calls_json = response_text[start:end]
166
- tool_calls = json.loads(tool_calls_json)
167
- content = response_text[:response_text.find("<tool_calls>")].strip()
168
- except (json.JSONDecodeError, ValueError):
169
- pass
170
-
171
- return content, tool_calls
172
-
173
-
174
- def generate_stream(
175
- prompt: str,
176
- max_tokens: int,
177
- temperature: float,
178
- top_p: float,
179
- stop: Optional[List[str]]
180
- ) -> AsyncGenerator[str, None]:
181
- """Generate streaming response."""
182
- inputs = tokenizer(prompt, return_tensors="pt", padding=True)
183
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
184
-
185
  streamer = TextIteratorStreamer(
186
  tokenizer,
187
  skip_prompt=True,
188
  skip_special_tokens=True
189
  )
190
-
191
- generation_kwargs = {
192
- "input_ids": inputs["input_ids"],
193
- "attention_mask": inputs["attention_mask"],
194
- "max_new_tokens": max_tokens,
195
- "temperature": temperature,
196
- "top_p": top_p,
197
- "do_sample": temperature > 0,
198
- "streamer": streamer,
199
- "pad_token_id": tokenizer.pad_token_id,
200
- "eos_token_id": tokenizer.eos_token_id,
201
- }
202
-
203
- if stop:
204
- generation_kwargs["stopping_criteria"] = create_stopping_criteria(stop)
205
-
206
  # Run generation in separate thread
207
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
208
  thread.start()
209
-
210
  generated_text = ""
211
  for new_text in streamer:
212
  generated_text += new_text
213
- # Check for stop sequences
214
- if stop:
215
- for s in stop:
216
- if s in generated_text:
217
- generated_text = generated_text[:generated_text.find(s)]
218
- break
219
-
220
- yield new_text
221
-
222
  thread.join()
223
 
224
 
225
- def create_stopping_criteria(stop_sequences: List[str]):
226
- """Create stopping criteria for generation."""
227
- from transformers import StoppingCriteria, StoppingCriteriaList
228
-
229
- class StopSequenceCriteria(StoppingCriteria):
230
- def __init__(self, stops, tokenizer):
231
- self.stops = stops
232
- self.tokenizer = tokenizer
233
-
234
- def __call__(self, input_ids, scores, **kwargs):
235
- generated = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
236
- for stop in self.stops:
237
- if stop in generated:
238
- return True
239
- return False
240
-
241
- return StoppingCriteriaList([StopSequenceCriteria(stop_sequences, tokenizer)])
242
-
243
-
244
- def generate_non_stream(
245
- prompt: str,
246
- max_tokens: int,
247
- temperature: float,
248
- top_p: float,
249
- stop: Optional[List[str]]
250
- ) -> str:
251
- """Generate non-streaming response."""
252
- inputs = tokenizer(prompt, return_tensors="pt", padding=True)
253
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
254
-
255
- with torch.no_grad():
256
- outputs = model.generate(
257
- input_ids=inputs["input_ids"],
258
- attention_mask=inputs["attention_mask"],
259
- max_new_tokens=max_tokens,
260
- temperature=temperature,
261
- top_p=top_p,
262
- do_sample=temperature > 0,
263
- pad_token_id=tokenizer.pad_token_id,
264
- eos_token_id=tokenizer.eos_token_id,
265
- )
266
-
267
- generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
268
 
269
- # Apply stop sequences
270
- if stop:
271
- for s in stop:
272
- if s in generated:
273
- generated = generated[:generated.find(s)]
274
- break
275
 
276
- return generated
 
 
 
 
 
 
277
 
278
 
279
- @app.post("/chat", response_model=ChatResponse)
280
- async def chat_completion(request: ChatRequest):
281
  """
282
- Main chat completion endpoint.
283
- Compatible with OpenAI-style API for easy integration.
284
  """
285
- import time
286
-
287
- prompt = format_messages_for_model(request.messages, request.tools)
288
-
 
 
289
  if request.stream:
290
- async def stream_response():
291
- generated = ""
292
- async for chunk in generate_stream(
293
- prompt,
294
- request.max_tokens,
295
- request.temperature,
296
- request.top_p,
297
- request.stop
298
- ):
299
- generated += chunk
300
- data = {
301
- "id": f"chatcmpl-{int(time.time())}",
302
- "object": "chat.completion.chunk",
303
- "created": int(time.time()),
304
- "model": MODEL_ID,
305
- "choices": [{
306
- "index": 0,
307
- "delta": {"content": chunk},
308
- "finish_reason": None
309
- }]
310
- }
311
- yield f"data: {json.dumps(data)}\n\n"
312
-
313
- # Final chunk
314
- content, tool_calls = parse_tool_calls(generated)
315
- final_data = {
316
- "id": f"chatcmpl-{int(time.time())}",
317
- "object": "chat.completion.chunk",
318
- "created": int(time.time()),
319
- "model": MODEL_ID,
320
- "choices": [{
321
- "index": 0,
322
- "delta": {},
323
- "finish_reason": "stop"
324
- }]
325
- }
326
- yield f"data: {json.dumps(final_data)}\n\n"
327
- yield "data: [DONE]\n\n"
328
-
329
  return StreamingResponse(
330
- stream_response(),
331
  media_type="text/event-stream",
332
  headers={
333
  "Cache-Control": "no-cache",
@@ -335,288 +178,33 @@ async def chat_completion(request: ChatRequest):
335
  "X-Accel-Buffering": "no"
336
  }
337
  )
338
-
339
  else:
340
- generated = generate_non_stream(
341
- prompt,
342
- request.max_tokens,
343
- request.temperature,
344
- request.top_p,
345
- request.stop
346
- )
347
-
348
- content, tool_calls = parse_tool_calls(generated)
349
-
350
- # Calculate token usage
351
- input_tokens = len(tokenizer.encode(prompt))
352
- output_tokens = len(tokenizer.encode(generated))
353
-
354
- response = ChatResponse(
355
- id=f"chatcmpl-{int(time.time())}",
356
- object="chat.completion",
357
- created=int(time.time()),
358
- model=MODEL_ID,
359
- choices=[{
360
- "index": 0,
361
- "message": {
362
- "role": "assistant",
363
- "content": content,
364
- "tool_calls": tool_calls
365
- },
366
- "finish_reason": "stop"
367
- }],
368
- usage={
369
- "prompt_tokens": input_tokens,
370
- "completion_tokens": output_tokens,
371
- "total_tokens": input_tokens + output_tokens
372
- }
373
  )
374
-
375
- return response
376
-
377
-
378
- @app.get("/chat", response_class=HTMLResponse)
379
- async def chat_interface():
380
- """Simple web interface for testing."""
381
- return """
382
- <!DOCTYPE html>
383
- <html lang="en">
384
- <head>
385
- <meta charset="UTF-8">
386
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
387
- <title>Nanbeige4.1-3B Chat</title>
388
- <style>
389
- * { margin: 0; padding: 0; box-sizing: border-box; }
390
- body {
391
- font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
392
- background: #1a1a2e;
393
- color: #eee;
394
- min-height: 100vh;
395
- display: flex;
396
- flex-direction: column;
397
- }
398
- header {
399
- background: #16213e;
400
- padding: 1rem 2rem;
401
- border-bottom: 1px solid #0f3460;
402
- }
403
- header h1 { font-size: 1.25rem; color: #e94560; }
404
- header p { font-size: 0.875rem; color: #888; margin-top: 0.25rem; }
405
- .chat-container {
406
- flex: 1;
407
- display: flex;
408
- flex-direction: column;
409
- max-width: 900px;
410
- width: 100%;
411
- margin: 0 auto;
412
- padding: 1rem;
413
- }
414
- .messages {
415
- flex: 1;
416
- overflow-y: auto;
417
- padding: 1rem;
418
- display: flex;
419
- flex-direction: column;
420
- gap: 1rem;
421
- }
422
- .message {
423
- max-width: 80%;
424
- padding: 1rem;
425
- border-radius: 12px;
426
- line-height: 1.6;
427
- }
428
- .message.user {
429
- align-self: flex-end;
430
- background: #e94560;
431
- color: white;
432
- }
433
- .message.assistant {
434
- align-self: flex-start;
435
- background: #16213e;
436
- border: 1px solid #0f3460;
437
- }
438
- .message.system {
439
- align-self: center;
440
- background: #0f3460;
441
- font-size: 0.875rem;
442
- color: #888;
443
- }
444
- .input-area {
445
- display: flex;
446
- gap: 0.5rem;
447
- padding: 1rem;
448
- background: #16213e;
449
- border-top: 1px solid #0f3460;
450
- }
451
- textarea {
452
- flex: 1;
453
- padding: 0.75rem 1rem;
454
- border: 1px solid #0f3460;
455
- border-radius: 8px;
456
- background: #1a1a2e;
457
- color: #eee;
458
- font-size: 1rem;
459
- resize: none;
460
- min-height: 50px;
461
- max-height: 150px;
462
- }
463
- textarea:focus {
464
- outline: none;
465
- border-color: #e94560;
466
- }
467
- button {
468
- padding: 0.75rem 1.5rem;
469
- background: #e94560;
470
- color: white;
471
- border: none;
472
- border-radius: 8px;
473
- cursor: pointer;
474
- font-size: 1rem;
475
- transition: background 0.2s;
476
- }
477
- button:hover { background: #d63d56; }
478
- button:disabled { background: #666; cursor: not-allowed; }
479
- .loading {
480
- display: inline-block;
481
- width: 20px;
482
- height: 20px;
483
- border: 2px solid #0f3460;
484
- border-top-color: #e94560;
485
- border-radius: 50%;
486
- animation: spin 1s linear infinite;
487
- }
488
- @keyframes spin { to { transform: rotate(360deg); } }
489
- .tool-calls {
490
- margin-top: 0.5rem;
491
- padding: 0.5rem;
492
- background: #0f3460;
493
- border-radius: 6px;
494
- font-size: 0.8rem;
495
- font-family: monospace;
496
- }
497
- </style>
498
- </head>
499
- <body>
500
- <header>
501
- <h1>Nanbeige4.1-3B Inference Server</h1>
502
- <p>Remote LLM service for Enterprise ReAct Agent</p>
503
- </header>
504
- <div class="chat-container">
505
- <div class="messages" id="messages"></div>
506
- <div class="input-area">
507
- <textarea id="input" placeholder="Type your message..." rows="1"></textarea>
508
- <button id="send" onclick="sendMessage()">Send</button>
509
- </div>
510
- </div>
511
-
512
- <script>
513
- const messages = document.getElementById('messages');
514
- const input = document.getElementById('input');
515
- const sendBtn = document.getElementById('send');
516
- let conversation = [];
517
-
518
- // Auto-resize textarea
519
- input.addEventListener('input', () => {
520
- input.style.height = 'auto';
521
- input.style.height = Math.min(input.scrollHeight, 150) + 'px';
522
- });
523
-
524
- // Enter to send, Shift+Enter for new line
525
- input.addEventListener('keydown', (e) => {
526
- if (e.key === 'Enter' && !e.shiftKey) {
527
- e.preventDefault();
528
- sendMessage();
529
- }
530
- });
531
-
532
- function addMessage(role, content, toolCalls = null) {
533
- const div = document.createElement('div');
534
- div.className = `message ${role}`;
535
- div.textContent = content;
536
- if (toolCalls) {
537
- const toolDiv = document.createElement('div');
538
- toolDiv.className = 'tool-calls';
539
- toolDiv.textContent = 'Tool calls: ' + JSON.stringify(toolCalls, null, 2);
540
- div.appendChild(toolDiv);
541
  }
542
- messages.appendChild(div);
543
- messages.scrollTop = messages.scrollHeight;
544
  }
545
 
546
- async function sendMessage() {
547
- const text = input.value.trim();
548
- if (!text) return;
549
-
550
- addMessage('user', text);
551
- conversation.push({ role: 'user', content: text });
552
- input.value = '';
553
- input.style.height = 'auto';
554
- sendBtn.disabled = true;
555
- sendBtn.innerHTML = '<span class="loading"></span>';
556
-
557
- try {
558
- const response = await fetch('/chat', {
559
- method: 'POST',
560
- headers: { 'Content-Type': 'application/json' },
561
- body: JSON.stringify({
562
- messages: conversation,
563
- stream: false,
564
- max_tokens: 2048,
565
- temperature: 0.6
566
- })
567
- });
568
-
569
- const data = await response.json();
570
- const assistantMsg = data.choices[0].message;
571
-
572
- addMessage('assistant', assistantMsg.content, assistantMsg.tool_calls);
573
- conversation.push({
574
- role: 'assistant',
575
- content: assistantMsg.content,
576
- tool_calls: assistantMsg.tool_calls
577
- });
578
- } catch (error) {
579
- addMessage('system', 'Error: ' + error.message);
580
- } finally {
581
- sendBtn.disabled = false;
582
- sendBtn.textContent = 'Send';
583
- }
584
- }
585
-
586
- // Initial system message
587
- addMessage('system', 'Welcome! The model is ready for inference.');
588
- </script>
589
- </body>
590
- </html>
591
- """
592
-
593
-
594
- @app.get("/health")
595
- async def health_check():
596
- """Health check endpoint."""
597
- return {
598
- "status": "healthy",
599
- "model": MODEL_ID,
600
- "device": DEVICE,
601
- "model_loaded": model is not None and tokenizer is not None
602
- }
603
-
604
-
605
- @app.get("/")
606
- async def root():
607
- """Root endpoint - redirect to chat interface."""
608
- return {
609
- "message": "Nanbeige4.1-3B Inference Server",
610
- "endpoints": {
611
- "chat": "/chat (POST for API, GET for web interface)",
612
- "health": "/health"
613
- },
614
- "model": MODEL_ID,
615
- "device": DEVICE
616
- }
617
-
618
 
619
  if __name__ == "__main__":
620
  import uvicorn
621
- port = int(os.environ.get("PORT", 7860))
622
- uvicorn.run(app, host="0.0.0.0", port=port)
 
1
  """
2
+ HuggingFace Space application for Nanbeige4.1-3B model inference.
3
+ Provides streaming chat completion API.
4
  """
5
 
6
  import os
7
  import json
8
+ import asyncio
9
+ from typing import AsyncGenerator, List, Dict, Any, Optional
10
  from contextlib import asynccontextmanager
11
+
12
+ from fastapi import FastAPI, HTTPException
13
+ from fastapi.responses import StreamingResponse
14
  from fastapi.middleware.cors import CORSMiddleware
15
+ from pydantic import BaseModel
16
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
17
  from threading import Thread
18
+ import torch
19
+
20
+ # Model configuration
21
+ MODEL_NAME = "Nanbeige/Nanbeige4.1-3B"
22
+ MAX_LENGTH = 8192
23
 
24
+ # Global model and tokenizer
25
  model = None
26
  tokenizer = None
27
 
 
 
 
 
 
 
28
 
29
+ class Message(BaseModel):
30
+ role: str
31
+ content: str
 
 
 
32
 
33
 
34
  class ChatRequest(BaseModel):
35
+ messages: List[Message]
36
+ stream: bool = True
37
+ max_tokens: int = 2048
38
+ temperature: float = 0.6
39
+ tools: Optional[List[Dict]] = None
 
 
 
40
 
 
 
 
 
 
 
 
41
 
42
+ @asynccontextmanager
43
+ async def lifespan(app: FastAPI):
44
+ """Application lifespan handler."""
45
  global model, tokenizer
46
+
47
+ print("Loading model...")
48
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
49
  model = AutoModelForCausalLM.from_pretrained(
50
+ MODEL_NAME,
 
 
51
  trust_remote_code=True,
52
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
53
+ device_map="auto" if torch.cuda.is_available() else None
54
  )
55
+
56
+ if not torch.cuda.is_available():
57
+ model = model.to("cpu")
58
+
59
+ print("Model loaded successfully!")
 
 
 
 
 
 
 
 
60
  yield
61
+
62
+ # Cleanup
63
+ del model
64
+ del tokenizer
65
+ torch.cuda.empty_cache()
66
 
67
 
68
  app = FastAPI(
69
  title="Nanbeige4.1-3B Inference API",
70
+ description="Streaming chat completion API for Nanbeige4.1-3B",
71
  version="1.0.0",
72
  lifespan=lifespan
73
  )
74
 
 
75
  app.add_middleware(
76
  CORSMiddleware,
77
+ allow_origins=["*"],
78
  allow_credentials=True,
79
  allow_methods=["*"],
80
  allow_headers=["*"],
81
  )
82
 
83
 
84
+ def format_messages(messages: List[Message]) -> str:
85
+ """Format messages into prompt string."""
86
+ formatted = []
 
87
  for msg in messages:
88
  if msg.role == "system":
89
+ formatted.append(f"System: {msg.content}")
90
  elif msg.role == "user":
91
+ formatted.append(f"User: {msg.content}")
92
  elif msg.role == "assistant":
93
+ formatted.append(f"Assistant: {msg.content}")
94
+
95
+ formatted.append("Assistant:")
96
+ return "\n\n".join(formatted)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
 
99
+ async def stream_tokens(prompt: str, max_tokens: int, temperature: float) -> AsyncGenerator[str, None]:
100
+ """Stream tokens from the model."""
101
+ global model, tokenizer
102
+
103
+ inputs = tokenizer(prompt, return_tensors="pt")
104
+ if torch.cuda.is_available():
105
+ inputs = inputs.to("cuda")
106
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  streamer = TextIteratorStreamer(
108
  tokenizer,
109
  skip_prompt=True,
110
  skip_special_tokens=True
111
  )
112
+
113
+ generation_kwargs = dict(
114
+ inputs,
115
+ streamer=streamer,
116
+ max_new_tokens=max_tokens,
117
+ temperature=temperature,
118
+ do_sample=temperature > 0,
119
+ pad_token_id=tokenizer.eos_token_id
120
+ )
121
+
 
 
 
 
 
 
122
  # Run generation in separate thread
123
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
124
  thread.start()
125
+
126
  generated_text = ""
127
  for new_text in streamer:
128
  generated_text += new_text
129
+ # Yield each token
130
+ data = json.dumps({"type": "token", "content": new_text})
131
+ yield f"data: {data}\n\n"
132
+
133
+ # Signal completion
134
+ yield f"data: {json.dumps({'type': 'done', 'content': ''})}\n\n"
135
+
 
 
136
  thread.join()
137
 
138
 
139
+ @app.get("/")
140
+ async def root():
141
+ """Root endpoint."""
142
+ return {
143
+ "name": "Nanbeige4.1-3B Inference API",
144
+ "version": "1.0.0",
145
+ "model": MODEL_NAME,
146
+ "status": "running"
147
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
 
 
 
 
 
 
149
 
150
+ @app.get("/health")
151
+ async def health():
152
+ """Health check endpoint."""
153
+ return {
154
+ "status": "healthy",
155
+ "model_loaded": model is not None and tokenizer is not None
156
+ }
157
 
158
 
159
+ @app.post("/chat")
160
+ async def chat(request: ChatRequest):
161
  """
162
+ Chat completion endpoint with streaming support.
 
163
  """
164
+ if model is None or tokenizer is None:
165
+ raise HTTPException(status_code=503, detail="Model not loaded yet")
166
+
167
+ # Format messages into prompt
168
+ prompt = format_messages(request.messages)
169
+
170
  if request.stream:
171
+ # Return streaming response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  return StreamingResponse(
173
+ stream_tokens(prompt, request.max_tokens, request.temperature),
174
  media_type="text/event-stream",
175
  headers={
176
  "Cache-Control": "no-cache",
 
178
  "X-Accel-Buffering": "no"
179
  }
180
  )
 
181
  else:
182
+ # Non-streaming response
183
+ inputs = tokenizer(prompt, return_tensors="pt")
184
+ if torch.cuda.is_available():
185
+ inputs = inputs.to("cuda")
186
+
187
+ outputs = model.generate(
188
+ **inputs,
189
+ max_new_tokens=request.max_tokens,
190
+ temperature=request.temperature,
191
+ do_sample=request.temperature > 0,
192
+ pad_token_id=tokenizer.eos_token_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  )
194
+
195
+ response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
196
+ # Extract only the assistant's response
197
+ response_text = response_text[len(prompt):].strip()
198
+
199
+ return {
200
+ "content": response_text,
201
+ "usage": {
202
+ "prompt_tokens": inputs.input_ids.shape[1],
203
+ "completion_tokens": outputs.shape[1] - inputs.input_ids.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  }
 
 
205
  }
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  if __name__ == "__main__":
209
  import uvicorn
210
+ uvicorn.run(app, host="0.0.0.0", port=7860)
 
requirements.txt CHANGED
@@ -1,19 +1,6 @@
1
- # Nanbeige4.1-3B Inference Server Dependencies
2
- # Optimized for Hugging Face Space deployment
3
-
4
- # Core framework
5
- fastapi>=0.115.0
6
- uvicorn[standard]>=0.32.0
7
- pydantic>=2.9.0
8
-
9
- # ML/Transformers
10
- torch>=2.1.0
11
- transformers>=4.40.0
12
- accelerate>=0.30.0
13
-
14
- # Utilities
15
- python-dotenv>=1.0.0
16
-
17
- # Note: This configuration uses the original Nanbeige4.1-3B model
18
- # from HuggingFace Hub (Nanbeige/Nanbeige4.1-3B)
19
- # The model will be downloaded on first startup
 
1
+ fastapi==0.109.0
2
+ uvicorn[standard]==0.27.0
3
+ transformers==4.37.0
4
+ torch==2.1.2
5
+ accelerate==0.26.0
6
+ pydantic==2.5.3