oki692 commited on
Commit
a5d12c9
·
verified ·
1 Parent(s): f1a8248

Upload 6 files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. Dockerfile +14 -0
  3. gateway +3 -0
  4. go.mod +3 -0
  5. main.go +502 -0
  6. prompts.go +0 -0
  7. provider.ts +398 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ gateway filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM golang:1.21-alpine AS builder
2
+ WORKDIR /app
3
+ COPY go.mod ./
4
+ COPY *.go ./
5
+ RUN go build -o gateway .
6
+
7
+ FROM alpine:latest
8
+ RUN apk --no-cache add ca-certificates
9
+ WORKDIR /app
10
+ COPY --from=builder /app/gateway .
11
+
12
+ ENV PORT=7860
13
+ EXPOSE 7860
14
+ CMD ["./gateway"]
gateway ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:951e544c947d6e3d9a37eac1e4f1d5bc5e3c8c0616313836a2bd4dba0e00b1e3
3
+ size 9429715
go.mod ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ module gateway
2
+
3
+ go 1.21
main.go ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "bufio"
5
+ "bytes"
6
+ "encoding/json"
7
+ "fmt"
8
+ "io"
9
+ "log"
10
+ "net/http"
11
+ "os"
12
+ "sort"
13
+ "strings"
14
+ "time"
15
+ )
16
+
17
// Provider endpoints and credentials.
//
// SECURITY NOTE(review): API keys are hard-coded in source. They should be
// loaded from environment variables (or a secret store) and the committed
// values rotated — anyone with read access to this repository can use them.
const (
	NvidiaBaseURL = "https://integrate.api.nvidia.com/v1"
	NvidiaAPIKey  = "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"
	NovaBaseURL   = "https://api.nova.amazon.com/v1"
	NovaAPIKey    = "fdbdcf6a-a2f3-4201-9488-89f94ea528a3"
	// GatewayAPIKey is the key clients must present to this gateway.
	GatewayAPIKey = "connect"
	// MaxToolIterations bounds the agentic tool-call loop in handleChat.
	MaxToolIterations = 10
)
25
+
26
// modelAliases maps the short, client-facing model names exposed by this
// gateway to the full upstream model identifiers.
var modelAliases = map[string]string{
	"Bielik-11b":            "speakleash/bielik-11b-v2.6-instruct",
	"Mistral-Small-4":       "mistralai/mistral-small-4-119b-2603",
	"DeepSeek-V3.1":         "deepseek-ai/deepseek-v3.1",
	"Kimi-K2":               "moonshotai/kimi-k2-instruct",
	"Amazon-Nova-2-lite-v1": "nova-2-lite-v1",
	"Minimax-m2.5":          "minimaxai/minimax-m2.5",
	"GLM-4.7":               "z-ai/glm4.7",
	"GPT-OSS-120b":          "openai/gpt-oss-120b",
	"Step-3.5-Flash":        "stepfun-ai/step-3.5-flash",
	"Qwen-3.5":              "qwen/qwen3.5-122b-a10b",
	"Kimi-K2.5":             "moonshotai/kimi-k2.5",
}

// novaModels marks upstream model IDs served by the Amazon Nova API
// instead of NVIDIA (see getProviderConfig).
var novaModels = map[string]bool{
	"nova-2-lite-v1": true,
}

// noThinkingModels marks upstream model IDs for which the "thinking" field
// is explicitly disabled in upstream request payloads.
var noThinkingModels = map[string]bool{
	"deepseek-ai/deepseek-v3.1": true,
}
49
+
50
+ func getProviderConfig(modelID string) (baseURL, apiKey string) {
51
+ if novaModels[modelID] {
52
+ return NovaBaseURL, NovaAPIKey
53
+ }
54
+ return NvidiaBaseURL, NvidiaAPIKey
55
+ }
56
+
57
// --- STRUCTURES ---

// Message is a single OpenAI-style chat message. Content is interface{}
// because it may be a plain string or a structured content array;
// ToolCalls is kept opaque for the same reason (forwarded as-is).
type Message struct {
	Role       string      `json:"role"`
	Content    interface{} `json:"content"`
	ToolCallID string      `json:"tool_call_id,omitempty"`
	ToolCalls  interface{} `json:"tool_calls,omitempty"`
	Name       string      `json:"name,omitempty"`
}

// ToolFunction describes a callable tool. The non-standard "x-endpoint"
// extension, when set, is an HTTP URL this gateway POSTs tool arguments to
// in order to execute the tool automatically (see executeToolCall).
type ToolFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"`
	Endpoint    string                 `json:"x-endpoint,omitempty"`
}

// Tool wraps a ToolFunction with its OpenAI "type" discriminator.
type Tool struct {
	Type     string       `json:"type"`
	Function ToolFunction `json:"function"`
}

// ChatRequest is the subset of the OpenAI chat-completions request body this
// gateway understands. Pointer fields distinguish "absent" from zero values.
type ChatRequest struct {
	Model       string      `json:"model"`
	Messages    []Message   `json:"messages"`
	Stream      *bool       `json:"stream,omitempty"`
	Tools       []Tool      `json:"tools,omitempty"`
	ToolChoice  interface{} `json:"tool_choice,omitempty"`
	Temperature *float64    `json:"temperature,omitempty"`
	MaxTokens   *int        `json:"max_tokens,omitempty"`
}

// AccumToolCall accumulates a streamed tool call whose name/arguments
// arrive split across multiple SSE delta chunks (see streamUpstream).
type AccumToolCall struct {
	Index int
	ID    string
	Name  string
	Args  string
}
95
+
96
+ // --- POMOCNICZE ---
97
+
98
+ func resolveModel(requested string) string {
99
+ if full, ok := modelAliases[requested]; ok {
100
+ return full
101
+ }
102
+ return requested
103
+ }
104
+
105
+ func findTool(tools []Tool, name string) *Tool {
106
+ for _, t := range tools {
107
+ if t.Function.Name == name {
108
+ return &t
109
+ }
110
+ }
111
+ return nil
112
+ }
113
+
114
+ // executeToolCall wykonuje HTTP POST do x-endpoint narzędzia
115
+ func executeToolCall(tool *Tool, argsJSON string) string {
116
+ if tool == nil || tool.Function.Endpoint == "" {
117
+ return fmt.Sprintf(`{"error":"brak x-endpoint dla narzędzia %s"}`, tool.Function.Name)
118
+ }
119
+
120
+ var args interface{}
121
+ json.Unmarshal([]byte(argsJSON), &args)
122
+ body, _ := json.Marshal(args)
123
+
124
+ client := &http.Client{Timeout: 30 * time.Second}
125
+ resp, err := client.Post(tool.Function.Endpoint, "application/json", bytes.NewReader(body))
126
+ if err != nil {
127
+ return fmt.Sprintf(`{"error":"%s"}`, err.Error())
128
+ }
129
+ defer resp.Body.Close()
130
+ result, _ := io.ReadAll(resp.Body)
131
+ return string(result)
132
+ }
133
+
134
+ // --- UPSTREAM CALL (non-streaming, zbiera pełną odpowiedź) ---
135
+
136
+ func callUpstream(modelID string, messages []Message, tools []Tool, toolChoice interface{}, temperature *float64, maxTokens *int) (map[string]interface{}, error) {
137
+ payload := map[string]interface{}{
138
+ "model": modelID,
139
+ "messages": messages,
140
+ "stream": false,
141
+ }
142
+ if noThinkingModels[modelID] {
143
+ payload["thinking"] = false
144
+ }
145
+ if temperature != nil {
146
+ payload["temperature"] = *temperature
147
+ }
148
+ if maxTokens != nil {
149
+ payload["max_tokens"] = *maxTokens
150
+ }
151
+ if len(tools) > 0 {
152
+ payload["tools"] = tools
153
+ if toolChoice != nil {
154
+ payload["tool_choice"] = toolChoice
155
+ } else {
156
+ payload["tool_choice"] = "auto"
157
+ }
158
+ }
159
+
160
+ baseURL, apiKey := getProviderConfig(modelID)
161
+ body, _ := json.Marshal(payload)
162
+ req, _ := http.NewRequest("POST", baseURL+"/chat/completions", bytes.NewReader(body))
163
+ req.Header.Set("Content-Type", "application/json")
164
+ req.Header.Set("Authorization", "Bearer "+apiKey)
165
+
166
+ client := &http.Client{Timeout: 120 * time.Second}
167
+ resp, err := client.Do(req)
168
+ if err != nil {
169
+ return nil, err
170
+ }
171
+ defer resp.Body.Close()
172
+
173
+ var result map[string]interface{}
174
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
175
+ return nil, err
176
+ }
177
+ return result, nil
178
+ }
179
+
180
// --- STREAMING UPSTREAM (final response) ---

// streamUpstream proxies one streaming chat-completions call as SSE to the
// client. Tool-call fragments are accumulated across delta chunks and
// re-emitted as a single complete tool_calls chunk when upstream signals
// finish_reason "tool_calls"; all other chunks are forwarded with the model
// field rewritten to the client-facing alias (clientModel).
func streamUpstream(w http.ResponseWriter, modelID string, messages []Message, tools []Tool, toolChoice interface{}, temperature *float64, maxTokens *int, clientModel string) {
	payload := map[string]interface{}{
		"model":    modelID,
		"messages": messages,
		"stream":   true,
	}
	if noThinkingModels[modelID] {
		payload["thinking"] = false
	}
	if temperature != nil {
		payload["temperature"] = *temperature
	}
	if maxTokens != nil {
		payload["max_tokens"] = *maxTokens
	}
	if len(tools) > 0 {
		payload["tools"] = tools
		if toolChoice != nil {
			payload["tool_choice"] = toolChoice
		} else {
			// Default so the model may choose to call tools.
			payload["tool_choice"] = "auto"
		}
	}

	baseURL, apiKey := getProviderConfig(modelID)
	// NOTE(review): marshal/request errors are silently ignored here,
	// unlike the checked variants elsewhere — worth tightening.
	body, _ := json.Marshal(payload)
	req, _ := http.NewRequest("POST", baseURL+"/chat/completions", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	// NOTE(review): http.DefaultClient has no timeout (callUpstream uses
	// 120s); a stalled upstream can hold this handler open indefinitely.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		http.Error(w, err.Error(), 502)
		return
	}
	defer resp.Body.Close()

	flusher, _ := w.(http.Flusher)
	// NOTE(review): bufio.Scanner's default 64KiB line limit may truncate
	// very large SSE data lines; the provider.ts variant enlarges its
	// buffer to 1MiB — consider doing the same here.
	scanner := bufio.NewScanner(resp.Body)
	accum := make(map[int]*AccumToolCall)

	for scanner.Scan() {
		line := scanner.Text()
		// Pass through non-data lines (blank keep-alives, "event:"
		// lines) and the [DONE] sentinel unchanged.
		if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
			fmt.Fprint(w, line+"\n\n")
			if flusher != nil {
				flusher.Flush()
			}
			continue
		}

		var chunk map[string]interface{}
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &chunk); err != nil {
			// Malformed chunk — skip rather than abort the stream.
			continue
		}

		choices, ok := chunk["choices"].([]interface{})
		if !ok || len(choices) == 0 {
			continue
		}

		choice := choices[0].(map[string]interface{})
		delta, _ := choice["delta"].(map[string]interface{})
		if delta == nil {
			continue
		}
		finishReason := choice["finish_reason"]

		// Accumulate tool-call fragments instead of forwarding them;
		// they are re-emitted whole on the finish chunk below.
		if tcs, ok := delta["tool_calls"].([]interface{}); ok {
			for _, tcVal := range tcs {
				tc := tcVal.(map[string]interface{})
				idx := int(tc["index"].(float64))
				acc, exists := accum[idx]
				if !exists {
					acc = &AccumToolCall{Index: idx}
					if id, ok := tc["id"].(string); ok {
						acc.ID = id
					}
					accum[idx] = acc
				}
				if fn, ok := tc["function"].(map[string]interface{}); ok {
					// Name and arguments may arrive in pieces;
					// concatenate each fragment.
					if name, ok := fn["name"].(string); ok {
						acc.Name += name
					}
					if args, ok := fn["arguments"].(string); ok {
						acc.Args += args
					}
				}
			}
			continue
		}

		// End of a tool-call turn: emit one synthetic chunk carrying
		// the fully assembled tool calls, ordered by index.
		if (finishReason == "tool_calls" || finishReason == "function_call") && len(accum) > 0 {
			var keys []int
			for k := range accum {
				keys = append(keys, k)
			}
			sort.Ints(keys)

			finalTools := []map[string]interface{}{}
			for _, k := range keys {
				a := accum[k]
				finalTools = append(finalTools, map[string]interface{}{
					"index": a.Index, "id": a.ID, "type": "function",
					"function": map[string]interface{}{"name": a.Name, "arguments": a.Args},
				})
			}

			response := map[string]interface{}{
				"id": chunk["id"], "object": "chat.completion.chunk", "created": chunk["created"],
				"model": clientModel,
				"choices": []map[string]interface{}{{
					"index":         0,
					"delta":         map[string]interface{}{"role": "assistant", "tool_calls": finalTools},
					"finish_reason": "tool_calls",
				}},
			}
			jsonBytes, _ := json.Marshal(response)
			fmt.Fprintf(w, "data: %s\n\n", string(jsonBytes))
			if flusher != nil {
				flusher.Flush()
			}
			// Reset for a possible further tool-call turn in the
			// same stream.
			accum = make(map[int]*AccumToolCall)
			continue
		}

		// Regular content chunk: rewrite the model field to the client
		// alias and forward.
		chunk["model"] = clientModel
		out, _ := json.Marshal(chunk)
		fmt.Fprintf(w, "data: %s\n\n", string(out))
		if flusher != nil {
			flusher.Flush()
		}
	}

	// NOTE(review): scanner.Err() is not checked; an upstream read error
	// is indistinguishable from a clean end of stream here.
	fmt.Fprint(w, "data: [DONE]\n\n")
	if flusher != nil {
		flusher.Flush()
	}
}
322
+
323
// --- MAIN HANDLER ---

// handleChat implements POST /v1/chat/completions.
//
// Flow: CORS preflight → gateway auth → decode request → (optionally) an
// agentic loop that auto-executes tools declared with "x-endpoint" →
// either a JSON response or an SSE proxy of the upstream stream.
func handleChat(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodOptions {
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, x-api-key")
		w.WriteHeader(http.StatusNoContent)
		return
	}

	// SECURITY NOTE(review): strings.Contains accepts any Authorization
	// header that merely CONTAINS the key as a substring; an exact match
	// against "Bearer "+GatewayAPIKey would be stricter.
	auth := r.Header.Get("Authorization")
	if !strings.Contains(auth, GatewayAPIKey) && r.Header.Get("x-api-key") != GatewayAPIKey {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	var req ChatRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	clientModel := req.Model           // alias echoed back to the client
	modelID := resolveModel(req.Model) // upstream model identifier
	messages := req.Messages
	tools := req.Tools
	toolChoice := req.ToolChoice

	// Streaming is the default when the client omits "stream".
	wantStream := req.Stream == nil || *req.Stream

	// --- AGENTIC LOOP ---
	// If any tool declares an x-endpoint, tool calls are executed
	// automatically. Each iteration: non-streaming call → check
	// tool_calls → execute → append results → repeat. The final
	// response (without tool_calls) is returned to the client.

	hasAutoExec := false
	if len(tools) > 0 {
		for _, t := range tools {
			if t.Function.Endpoint != "" {
				hasAutoExec = true
				break
			}
		}
	}

	if hasAutoExec {
		for i := 0; i < MaxToolIterations; i++ {
			result, err := callUpstream(modelID, messages, tools, toolChoice, req.Temperature, req.MaxTokens)
			if err != nil {
				http.Error(w, err.Error(), 502)
				return
			}

			choices, ok := result["choices"].([]interface{})
			if !ok || len(choices) == 0 {
				break
			}

			choice := choices[0].(map[string]interface{})
			message, _ := choice["message"].(map[string]interface{})
			finishReason, _ := choice["finish_reason"].(string)

			// Append the assistant turn to the conversation history.
			assistantMsg := Message{Role: "assistant"}
			if content, ok := message["content"]; ok && content != nil {
				assistantMsg.Content = content
			}
			if tcs, ok := message["tool_calls"]; ok && tcs != nil {
				assistantMsg.ToolCalls = tcs
			}
			messages = append(messages, assistantMsg)

			if finishReason != "tool_calls" && finishReason != "function_call" {
				// No tool calls — return the result to the client.
				// NOTE(review): returned as plain JSON even when the
				// client asked for a stream.
				w.Header().Set("Content-Type", "application/json")
				w.Header().Set("Access-Control-Allow-Origin", "*")
				result["model"] = clientModel
				json.NewEncoder(w).Encode(result)
				return
			}

			// Execute every requested tool call; append each result
			// as a "tool"-role message for the next iteration.
			tcList, _ := message["tool_calls"].([]interface{})
			for _, tcVal := range tcList {
				tc, _ := tcVal.(map[string]interface{})
				if tc == nil {
					continue
				}
				tcID, _ := tc["id"].(string)
				fn, _ := tc["function"].(map[string]interface{})
				if fn == nil {
					continue
				}
				fnName, _ := fn["name"].(string)
				fnArgs, _ := fn["arguments"].(string)

				// findTool may return nil for hallucinated names;
				// executeToolCall is expected to handle that.
				tool := findTool(tools, fnName)
				toolResult := executeToolCall(tool, fnArgs)

				messages = append(messages, Message{
					Role:       "tool",
					Content:    toolResult,
					ToolCallID: tcID,
					Name:       fnName,
				})
			}
		}

		// Iteration cap reached — one final attempt without tools so
		// the model must answer directly.
		result, err := callUpstream(modelID, messages, nil, nil, req.Temperature, req.MaxTokens)
		if err != nil {
			http.Error(w, err.Error(), 502)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("Access-Control-Allow-Origin", "*")
		result["model"] = clientModel
		json.NewEncoder(w).Encode(result)
		return
	}

	// --- NORMAL MODE (no auto-exec): stream to the client ---
	// NOTE(review): SSE headers are set even on the non-streaming path
	// below; Content-Type is then overwritten with application/json.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("X-Accel-Buffering", "no")
	w.Header().Set("Cache-Control", "no-cache")

	if !wantStream {
		// Client opted out of streaming — collect the response and
		// return plain JSON.
		result, err := callUpstream(modelID, messages, tools, toolChoice, req.Temperature, req.MaxTokens)
		if err != nil {
			http.Error(w, err.Error(), 502)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		result["model"] = clientModel
		json.NewEncoder(w).Encode(result)
		return
	}

	streamUpstream(w, modelID, messages, tools, toolChoice, req.Temperature, req.MaxTokens, clientModel)
}
466
+
467
+ func handleModels(w http.ResponseWriter, r *http.Request) {
468
+ if r.Method == http.MethodOptions {
469
+ w.Header().Set("Access-Control-Allow-Origin", "*")
470
+ w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
471
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, x-api-key")
472
+ w.WriteHeader(http.StatusNoContent)
473
+ return
474
+ }
475
+ w.Header().Set("Content-Type", "application/json")
476
+ w.Header().Set("Access-Control-Allow-Origin", "*")
477
+ var data []map[string]interface{}
478
+ now := time.Now().Unix()
479
+ for alias := range modelAliases {
480
+ data = append(data, map[string]interface{}{
481
+ "id": alias,
482
+ "object": "model",
483
+ "created": now,
484
+ "owned_by": "nvidia",
485
+ })
486
+ }
487
+ json.NewEncoder(w).Encode(map[string]interface{}{"object": "list", "data": data})
488
+ }
489
+
490
+ func main() {
491
+ port := os.Getenv("PORT")
492
+ if port == "" {
493
+ port = "3000"
494
+ }
495
+ mux := http.NewServeMux()
496
+ mux.HandleFunc("/v1/chat/completions", handleChat)
497
+ mux.HandleFunc("/v1/models", handleModels)
498
+ log.Printf("Gateway running on :%s", port)
499
+ if err := http.ListenAndServe(":"+port, mux); err != nil {
500
+ log.Fatalf("Server error: %v", err)
501
+ }
502
+ }
prompts.go ADDED
The diff for this file is too large to render. See raw diff
 
provider.ts ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "bufio"
5
+ "bytes"
6
+ "encoding/json"
7
+ "fmt"
8
+ "io"
9
+ "log"
10
+ "net/http"
11
+ "os"
12
+ "sort"
13
+ "strings"
14
+ "time"
15
+ )
16
+
17
// Provider endpoint and credentials.
//
// SECURITY NOTE(review): the NVIDIA API key is hard-coded in source; it
// should be read from the environment or a secret store and the committed
// value rotated.
const (
	NvidiaBaseURL = "https://integrate.api.nvidia.com/v1"
	NvidiaAPIKey  = "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"
	// GatewayAPIKey is the key clients must present to this gateway.
	GatewayAPIKey = "connect"
)

// modelAliases maps client-facing model names to upstream model IDs.
var modelAliases = map[string]string{
	"Bielik-11b":      "speakleash/bielik-11b-v2.6-instruct",
	"GLM-4.7":         "z-ai/glm4.7",
	"Mistral-Small-4": "mistralai/mistral-small-4-119b-2603",
	"DeepSeek-V3.1":   "deepseek-ai/deepseek-v3.1",
	"Kimi-K2":         "moonshotai/kimi-k2-instruct",
}
30
+
31
// Message is a single OpenAI-style chat message. Content and ToolCalls are
// interface{} because they may be strings or structured arrays and are
// forwarded as-is.
type Message struct {
	Role       string      `json:"role"`
	Content    interface{} `json:"content"`
	ToolCallID string      `json:"tool_call_id,omitempty"`
	ToolCalls  interface{} `json:"tool_calls,omitempty"`
	Name       string      `json:"name,omitempty"`
}

// ChatRequest is the client-facing chat-completions request body. Pointer
// fields distinguish "absent" from zero values.
type ChatRequest struct {
	Model       string        `json:"model"`
	Messages    []Message     `json:"messages"`
	Stream      *bool         `json:"stream,omitempty"`
	Tools       []interface{} `json:"tools,omitempty"`
	ToolChoice  interface{}   `json:"tool_choice,omitempty"`
	Temperature *float64      `json:"temperature,omitempty"`
	MaxTokens   *int          `json:"max_tokens,omitempty"`
	TopP        *float64      `json:"top_p,omitempty"`
	Stop        interface{}   `json:"stop,omitempty"`
}

// UpstreamRequest is the body sent to the upstream API. Stream is a plain
// bool (always set by handleChat); ExtraBody carries provider-specific
// extensions such as chat_template_kwargs.
type UpstreamRequest struct {
	Model       string                 `json:"model"`
	Messages    []Message              `json:"messages"`
	Stream      bool                   `json:"stream"`
	Tools       []interface{}          `json:"tools,omitempty"`
	ToolChoice  interface{}            `json:"tool_choice,omitempty"`
	Temperature *float64               `json:"temperature,omitempty"`
	MaxTokens   *int                   `json:"max_tokens,omitempty"`
	TopP        *float64               `json:"top_p,omitempty"`
	Stop        interface{}            `json:"stop,omitempty"`
	ExtraBody   map[string]interface{} `json:"extra_body,omitempty"`
}

// StreamChoice is one choice inside a streamed chunk. FinishReason is a
// pointer so JSON null/absent is distinguishable from "".
type StreamChoice struct {
	Index        int         `json:"index"`
	Delta        StreamDelta `json:"delta"`
	FinishReason *string     `json:"finish_reason"`
}

// StreamDelta is the incremental payload of a streamed choice.
type StreamDelta struct {
	Role      string          `json:"role,omitempty"`
	Content   *string         `json:"content,omitempty"`
	ToolCalls []ToolCallChunk `json:"tool_calls,omitempty"`
}

// ToolCallChunk is one fragment of a tool call; Name/Arguments may arrive
// split across several chunks that share the same Index.
type ToolCallChunk struct {
	Index    int              `json:"index"`
	ID       string           `json:"id,omitempty"`
	Type     string           `json:"type,omitempty"`
	Function ToolCallFunction `json:"function,omitempty"`
}

// ToolCallFunction is the function part of a (possibly partial) tool call.
type ToolCallFunction struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}

// StreamChunk is one parsed SSE "data:" payload from the upstream stream.
type StreamChunk struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}

// AccumulatedToolCall collects a tool call's fields as its fragments
// stream in (see handleChat).
type AccumulatedToolCall struct {
	ID   string
	Type string
	Name string
	Args string
}
102
+
103
+ func resolveModel(requested string) string {
104
+ if full, ok := modelAliases[requested]; ok {
105
+ return full
106
+ }
107
+ for _, full := range modelAliases {
108
+ if full == requested {
109
+ return requested
110
+ }
111
+ }
112
+ return requested
113
+ }
114
+
115
+ func injectSystemPrompt(messages []Message, modelID string) []Message {
116
+ filtered := make([]Message, 0, len(messages))
117
+ for _, m := range messages {
118
+ if m.Role != "system" {
119
+ filtered = append(filtered, m)
120
+ }
121
+ }
122
+ prompt, ok := systemPrompts[modelID]
123
+ if !ok || prompt == "" {
124
+ return filtered
125
+ }
126
+ return append([]Message{{Role: "system", Content: prompt}}, filtered...)
127
+ }
128
+
129
+ func authenticate(r *http.Request) bool {
130
+ auth := r.Header.Get("Authorization")
131
+ if len(auth) > 7 && auth[:7] == "Bearer " && auth[7:] == GatewayAPIKey {
132
+ return true
133
+ }
134
+ return r.Header.Get("x-api-key") == GatewayAPIKey
135
+ }
136
+
137
+ func handleModels(w http.ResponseWriter, r *http.Request) {
138
+ if !authenticate(r) {
139
+ http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized)
140
+ return
141
+ }
142
+ type ModelObj struct {
143
+ ID string `json:"id"`
144
+ Object string `json:"object"`
145
+ Created int64 `json:"created"`
146
+ OwnedBy string `json:"owned_by"`
147
+ }
148
+ type ModelsResponse struct {
149
+ Object string `json:"object"`
150
+ Data []ModelObj `json:"data"`
151
+ }
152
+ models := ModelsResponse{Object: "list"}
153
+ now := time.Now().Unix()
154
+ for alias := range modelAliases {
155
+ models.Data = append(models.Data, ModelObj{ID: alias, Object: "model", Created: now, OwnedBy: "nvidia"})
156
+ }
157
+ w.Header().Set("Content-Type", "application/json")
158
+ json.NewEncoder(w).Encode(models)
159
+ }
160
+
161
+ func handleBaseURL(w http.ResponseWriter, r *http.Request) {
162
+ host := os.Getenv("SPACE_HOST")
163
+ if host == "" {
164
+ host = r.Host
165
+ }
166
+ w.Header().Set("Content-Type", "application/json")
167
+ fmt.Fprintf(w, `{"url":"https://%s/v1"}`, host)
168
+ }
169
+
170
// handleChat implements POST /v1/chat/completions as a streaming proxy.
//
// The client request is always forwarded upstream with stream=true; the
// SSE response is relayed chunk-by-chunk, except that fragmented tool-call
// deltas are accumulated and re-emitted as a single complete tool_calls
// chunk when the upstream reports finish_reason "tool_calls".
func handleChat(w http.ResponseWriter, r *http.Request) {
	if !authenticate(r) {
		http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized)
		return
	}
	if r.Method != http.MethodPost {
		http.Error(w, `{"error":{"message":"Method not allowed"}}`, http.StatusMethodNotAllowed)
		return
	}
	var req ChatRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, `{"error":{"message":"Invalid request body"}}`, http.StatusBadRequest)
		return
	}

	modelID := resolveModel(req.Model)
	req.Messages = injectSystemPrompt(req.Messages, modelID)

	// NOTE(review): the upstream call is always stream=true regardless of
	// the client's req.Stream — non-streaming clients still receive SSE.
	upstream := UpstreamRequest{
		Model:       modelID,
		Messages:    req.Messages,
		Stream:      true,
		Tools:       req.Tools,
		ToolChoice:  req.ToolChoice,
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		TopP:        req.TopP,
		Stop:        req.Stop,
	}

	// GLM-4.7 requires thinking disabled via extra_body
	if modelID == "z-ai/glm4.7" {
		upstream.ExtraBody = map[string]interface{}{
			"chat_template_kwargs": map[string]interface{}{
				"enable_thinking": false,
			},
		}
	}

	body, err := json.Marshal(upstream)
	if err != nil {
		http.Error(w, `{"error":{"message":"Failed to marshal request"}}`, http.StatusInternalServerError)
		return
	}

	upstreamReq, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		http.Error(w, `{"error":{"message":"Failed to create upstream request"}}`, http.StatusInternalServerError)
		return
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	upstreamReq.Header.Set("Authorization", "Bearer "+NvidiaAPIKey)
	upstreamReq.Header.Set("Accept", "text/event-stream")

	client := &http.Client{Timeout: 300 * time.Second}
	resp, err := client.Do(upstreamReq)
	if err != nil {
		http.Error(w, fmt.Sprintf(`{"error":{"message":"Upstream error: %s"}}`, err.Error()), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// Relay upstream error bodies verbatim with the upstream status code.
	if resp.StatusCode != http.StatusOK {
		upstreamBody, _ := io.ReadAll(resp.Body)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(resp.StatusCode)
		w.Write(upstreamBody)
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")
	w.WriteHeader(http.StatusOK)

	flusher, canFlush := w.(http.Flusher)
	scanner := bufio.NewScanner(resp.Body)
	// Enlarge the line buffer to 1 MiB: SSE data lines can exceed
	// bufio.Scanner's 64 KiB default.
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

	// Accumulate tool call arguments across chunks
	accumulated := make(map[int]*AccumulatedToolCall)

	// flush writes s to the client and flushes immediately when the
	// ResponseWriter supports it.
	flush := func(s string) {
		fmt.Fprint(w, s)
		if canFlush {
			flusher.Flush()
		}
	}

	for scanner.Scan() {
		line := scanner.Text()

		// Forward non-data lines (keep-alives, comments) untouched.
		if !strings.HasPrefix(line, "data: ") {
			flush(line + "\n")
			continue
		}

		data := strings.TrimPrefix(line, "data: ")

		if data == "[DONE]" {
			flush("data: [DONE]\n\n")
			continue
		}

		var chunk StreamChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			// Unparseable data line — forward as-is.
			flush(line + "\n")
			continue
		}

		hasToolCalls := false
		for _, choice := range chunk.Choices {
			if len(choice.Delta.ToolCalls) > 0 {
				hasToolCalls = true
				for _, tc := range choice.Delta.ToolCalls {
					acc, ok := accumulated[tc.Index]
					if !ok {
						acc = &AccumulatedToolCall{}
						accumulated[tc.Index] = acc
					}
					if tc.ID != "" {
						acc.ID = tc.ID
					}
					if tc.Type != "" {
						acc.Type = tc.Type
					}
					if tc.Function.Name != "" {
						acc.Name += tc.Function.Name
					}
					acc.Args += tc.Function.Arguments
				}
			}

			// When finish_reason=tool_calls emit one complete assembled chunk
			if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
				// Sort by index for deterministic output
				indices := make([]int, 0, len(accumulated))
				for idx := range accumulated {
					indices = append(indices, idx)
				}
				sort.Ints(indices)

				assembled := make([]map[string]interface{}, 0, len(indices))
				for _, idx := range indices {
					acc := accumulated[idx]
					assembled = append(assembled, map[string]interface{}{
						"index": idx,
						"id":    acc.ID,
						"type":  "function",
						"function": map[string]string{
							"name":      acc.Name,
							"arguments": acc.Args,
						},
					})
				}

				fr := "tool_calls"
				synthetic := map[string]interface{}{
					"id":      chunk.ID,
					"object":  chunk.Object,
					"created": chunk.Created,
					"model":   chunk.Model,
					"choices": []map[string]interface{}{
						{
							"index": choice.Index,
							"delta": map[string]interface{}{
								"role":       "assistant",
								"content":    nil,
								"tool_calls": assembled,
							},
							"finish_reason": fr,
						},
					},
				}
				out, _ := json.Marshal(synthetic)
				flush("data: " + string(out) + "\n\n")
				accumulated = make(map[int]*AccumulatedToolCall)
				// NOTE(review): clearing hasToolCalls here means the
				// original finish chunk is ALSO forwarded below when
				// its own delta carried no tool calls, so clients may
				// see two finish_reason chunks — verify against real
				// clients before changing.
				hasToolCalls = false
			}
		}

		// Forward regular content chunks as-is
		if !hasToolCalls {
			flush("data: " + data + "\n\n")
		}
	}
	// NOTE(review): scanner.Err() is not checked; a mid-stream read error
	// ends the loop silently without a terminating [DONE].
}
358
+
359
+ func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
360
+ return func(w http.ResponseWriter, r *http.Request) {
361
+ start := time.Now()
362
+ log.Printf("[%s] %s %s", r.Method, r.URL.Path, r.RemoteAddr)
363
+ next(w, r)
364
+ log.Printf("[%s] %s done in %s", r.Method, r.URL.Path, time.Since(start))
365
+ }
366
+ }
367
+
368
+ func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
369
+ return func(w http.ResponseWriter, r *http.Request) {
370
+ w.Header().Set("Access-Control-Allow-Origin", "*")
371
+ w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
372
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, x-api-key")
373
+ if r.Method == http.MethodOptions {
374
+ w.WriteHeader(http.StatusNoContent)
375
+ return
376
+ }
377
+ next(w, r)
378
+ }
379
+ }
380
+
381
+ func main() {
382
+ port := os.Getenv("PORT")
383
+ if port == "" {
384
+ port = "7860"
385
+ }
386
+ mux := http.NewServeMux()
387
+ mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(handleChat)))
388
+ mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(handleModels)))
389
+ mux.HandleFunc("/v1/base-url", corsMiddleware(handleBaseURL))
390
+ mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
391
+ w.WriteHeader(http.StatusOK)
392
+ w.Write([]byte(`{"status":"ok"}`))
393
+ })
394
+ log.Printf("Gateway starting on :%s", port)
395
+ if err := http.ListenAndServe(":"+port, mux); err != nil {
396
+ log.Fatal(err)
397
+ }
398
+ }