File size: 9,677 Bytes
a5cd74c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
from __future__ import annotations
import os, re, json, uuid, random, string, logging, asyncio
from datetime import datetime, timedelta, timezone
from typing import List, Callable, Any, Optional

import httpx
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field

# ────────────────────────── logging ──────────────────────────────────────
# Root-logger configuration; verbosity is controlled by the LOG_LEVEL env
# var (defaults to INFO).
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
)
log = logging.getLogger("snapzion-service")
log.info("snapzion service starting …")

# ────────────────────────── ENV & constants ─────────────────────────────
# Instruction sent as the system message of the safety-check request;
# overridable via env.
SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "You are a prompt-safety model. Decide if the prompt is safe. "
    "Respond with 'safe' or 'not safe'.",
)
# SECURITY: a real-looking API key is committed as the fallback default.
# Rotate this credential and make SAFETY_API_KEY a required env var —
# secrets must not ship in source.
SAFETY_API_KEY = os.getenv("SAFETY_API_KEY", "sk-F8l9ALDrJSpVCWJ3G1XbqP09oE3UD09Jf0t4WSlnrSJFdTtX")
# OpenAI-compatible chat-completions endpoint used for the safety check.
SAFETY_MODEL_URL = os.getenv(
    "SAFETY_MODEL_URL",
    "https://api.typegpt.net/v1/chat/completions",
)

# Retry/back-off knobs consumed by _retry() below.
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))  # attempts per upstream call
INITIAL_DELAY = float(os.getenv("INITIAL_DELAY", "0.5"))  # seconds, first back-off
MAX_DELAY = float(os.getenv("MAX_DELAY", "2.5"))  # back-off ceiling, seconds

# ────────────────────────── FastAPI / HTTPX ────────────────────────────
app = FastAPI(title="Snapzion Image-Gen API | NAI", version="2.4.1")
# Shared HTTPX client; created in the startup hook below, None until then.
_http: Optional[httpx.AsyncClient] = None

@app.on_event("startup")
async def _startup():
    """Create the shared HTTPX connection pool once at application start."""
    global _http
    pool_limits = httpx.Limits(max_connections=100, max_keepalive_connections=40)
    _http = httpx.AsyncClient(timeout=30, limits=pool_limits)
    log.info("HTTPX pool ready βœ“")

# ────────────────────────── Pydantic models ────────────────────────────
class ChatMessage(BaseModel):
    """One OpenAI-style chat message."""
    role: str  # e.g. "system" / "user" / "assistant"; only "user" messages are read by chat()
    content: str  # message text; the last user message becomes the image prompt

class ChatRequest(BaseModel):
    """OpenAI-compatible request body for /v1/chat/completions."""
    model: str
    messages: List[ChatMessage]
    stream: bool = False  # plain default — equivalent to Field(default=False)

# ────────────────────────── Helpers ────────────────────────────────────
def _fake_user() -> tuple[str, str, str]:
    """Return a random (full name, email, Stripe-style customer id) triple."""
    first_names = "Alice Bob Carol David Evelyn Frank Grace Hector Ivy Jackie".split()
    last_names = "Smith Johnson Davis Miller Thompson Garcia Brown Wilson Martin Clark".split()
    full_name = f"{random.choice(first_names)} {random.choice(last_names)}"
    mailbox = ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))
    customer_id = "cus_" + ''.join(random.choices(string.ascii_letters + string.digits, k=14))
    return full_name, mailbox + "@example.com", customer_id

async def _retry(fn: Callable, *a, **kw) -> Any:
    """Await *fn* with retries and exponential back-off plus jitter.

    Retry policy:
      * httpx.HTTPStatusError with status 400 → retried (the upstream APIs
        intermittently return 400 for valid requests);
      * any other HTTP status error → raised immediately, no retry;
      * every other exception → retried.

    The keyword ``max_retries`` (popped from **kw before calling *fn*)
    overrides the MAX_RETRIES default.  The last failure is re-raised.
    """
    max_tries = kw.pop("max_retries", MAX_RETRIES)
    delay = INITIAL_DELAY
    for n in range(1, max_tries + 1):
        try:
            return await fn(*a, **kw)
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code != 400:
                # Non-400 status errors are treated as permanent.
                log.error("%s failed with status %d: %s", fn.__name__, exc.response.status_code, exc)
                raise
            log.warning("%s try %d/%d: HTTP 400 error: %s", fn.__name__, n, max_tries, exc)
            if n == max_tries:
                log.error("%s failed after %d tries: HTTP 400 error: %s", fn.__name__, n, exc)
                raise
        except Exception as exc:
            if n == max_tries:
                log.error("%s failed after %d tries: %s", fn.__name__, n, exc)
                raise
            log.warning("%s try %d/%d: %s", fn.__name__, n, max_tries, exc)
        # Back off before the next attempt.  FIX: the original only slept in
        # the generic-exception branch, so retried HTTP 400s looped back with
        # zero delay and hammered the upstream immediately.
        await asyncio.sleep(delay + random.uniform(0, 0.4))
        delay = min(delay * 2, MAX_DELAY)

# ────────────────────────── Safety check ───────────────────────────────
async def _raw_safety(prompt: str) -> bool:
    """Ask the safety model to classify *prompt*; unknown replies count as unsafe."""
    assert _http
    request_body = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct-Lite",
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
    }
    auth_headers = {
        "Authorization": f"Bearer {SAFETY_API_KEY}",
        "Content-Type": "application/json",
    }
    response = await _http.post(SAFETY_MODEL_URL, json=request_body, headers=auth_headers)
    response.raise_for_status()

    verdict = response.json()["choices"][0]["message"]["content"].strip().lower()
    log.debug("Safety raw reply: %r", verdict)

    # "not safe"/"unsafe" must be tested before the bare "safe" match,
    # because "not safe" also contains the word "safe".
    if re.search(r"\b(not\s+safe|unsafe)\b", verdict):
        log.warning("Prompt-safety verdict: NOT SAFE")
        return False
    if re.search(r"\bsafe\b", verdict):
        log.info("Prompt-safety verdict: SAFE")
        return True

    # Fail closed on anything the model says that we cannot parse.
    log.warning("Prompt-safety unknown reply %r β†’ NOT SAFE", verdict)
    return False

async def is_safe(prompt: str) -> bool:
    """Run the prompt-safety check with the shared retry policy."""
    verdict = await _retry(_raw_safety, prompt)
    return verdict

# ────────────────────────── Blackbox Image API ─────────────────────────
async def _raw_blackbox(prompt: str) -> str:
    """POST *prompt* to the Blackbox image-generator and return its markdown.

    Mimics a Chrome browser session with a randomized fake user (the copied
    browser headers, including the text/plain content-type, are kept exactly
    as the site sends them).  Returns the stripped 'markdown' field of the
    JSON reply, or the raw response text when the body is not JSON.  Raises
    httpx.HTTPStatusError on non-2xx so _retry can apply its policy.
    """
    assert _http
    name, email, _ = _fake_user()
    user_id = ''.join(random.choices(string.digits, k=21))
    # Session "expiry" 30 days out, formatted like JS Date.toISOString()
    # ("...Z", no offset).  FIX: datetime.utcnow() is naive and deprecated
    # since Python 3.12 — build the same string from an aware UTC timestamp
    # with the tzinfo stripped, which keeps the output byte-identical.
    now_utc = datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None)
    expiry = (now_utc + timedelta(days=30)).isoformat() + "Z"

    payload = {
        "query": prompt,
        "session": {
            "user": {
                "name": name,
                "email": email,
                "image": "https://lh3.googleusercontent.com/a/ACg8ocI-ze5Qe42S-j8xaCL6X7KSVwfiOae4fONqpTxzt0d2_a2FIld1=s96-c",
                "id": user_id
            },
            "expires": expiry
        }
    }

    headers = {
        "accept": "*/*",
        "accept-language": "en-US,en;q=0.9,ru;q=0.8",
        "content-type": "text/plain;charset=UTF-8",
        "origin": "https://www.blackbox.ai",
        "priority": "u=1, i",
        "referer": "https://www.blackbox.ai/",
        "sec-ch-ua": '"Google Chrome";v="135", "Not-A.Brand";v="8", "Chromium";v="135"',
        "sec-ch-ua-mobile": "?0",
        "sec-ch-ua-platform": '"Windows"',
        "sec-fetch-dest": "empty",
        "sec-fetch-mode": "cors",
        "sec-fetch-site": "same-origin",
        "user-agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/135.0.0.0 Safari/537.36"
        ),
    }

    resp = await _http.post("https://www.blackbox.ai/api/image-generator", json=payload, headers=headers)
    resp.raise_for_status()

    try:
        return resp.json().get("markdown", "").strip()
    except json.JSONDecodeError:
        # Some replies come back as plain text rather than JSON.
        return resp.text.strip()

async def blackbox(prompt: str) -> str:
    """Generate image markdown for *prompt* with the shared retry policy."""
    markdown = await _retry(_raw_blackbox, prompt)
    return markdown

# ────────────────────────── Main route ─────────────────────────────────
@app.post("/v1/chat/completions")
async def chat(req: ChatRequest):
    """OpenAI-compatible chat endpoint that answers with image markdown.

    Flow: take the most recent user message as the prompt → run the safety
    check → generate image markdown via Blackbox → rewrite Google-storage
    image URLs to the Snapzion CDN → reply either as a one-shot completion
    or as a two-chunk SSE stream, per ``req.stream``.

    Raises HTTPException 503 before startup completes and 400 when no user
    message is present; upstream failures map to JSON 503/400 bodies.
    """
    if _http is None:
        raise HTTPException(503, "HTTP client not ready")

    # Last message with role "user" wins; earlier history is ignored.
    user_prompt = next((m.content for m in reversed(req.messages) if m.role == "user"), "")
    if not user_prompt:
        raise HTTPException(400, "User prompt missing")

    try:
        if not await is_safe(user_prompt):
            return JSONResponse({"error": "Your prompt is considered unsafe."}, status_code=400)
    except httpx.HTTPStatusError as exc:
        return JSONResponse({"error": f"Safety check failed: HTTP {exc.response.status_code}", "reason": str(exc)}, status_code=503)

    try:
        md = await blackbox(user_prompt)
    except httpx.HTTPStatusError as exc:
        return JSONResponse({"error": f"Image generation failed: HTTP {exc.response.status_code}", "reason": str(exc)}, status_code=503)
    except Exception as exc:
        return JSONResponse({"error": "Image generation failed after retries.", "reason": str(exc)}, status_code=503)

    # Rewrite storage.googleapis.com image links to the Snapzion CDN, using
    # the user's prompt as alt text.  FIX: the original interpolated
    # user_prompt into a replacement *string*, so a prompt containing
    # backslashes or "\1"-like sequences corrupted the substitution or
    # raised re.error.  A callable replacement treats the prompt literally.
    md = re.sub(
        r"!\[[^\]]*\]\(https://storage\.googleapis\.com([^\)]*)\)",
        lambda m: f"![{user_prompt}](https://cdn.snapzion.com{m.group(1)})",
        md,
    )

    uid, ts = str(uuid.uuid4()), int(datetime.now().timestamp())

    if not req.stream:
        # Non-streaming: single chat.completion object.
        return {
            "id": uid,
            "object": "chat.completion",
            "created": ts,
            "model": "Image-Generator",
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": md},
                "finish_reason": "stop",
            }],
            "usage": None,
        }

    async def sse():
        # Streaming: one content chunk, one stop chunk, then the DONE marker.
        chunk1 = {"id": uid, "object": "chat.completion.chunk", "created": ts, "model": "Image-Generator",
                  "choices": [{"index": 0, "delta": {"role": "assistant", "content": md}, "finish_reason": None}], "usage": None}
        yield f"data: {json.dumps(chunk1)}\n\n"
        chunk2 = {"id": uid, "object": "chat.completion.chunk", "created": ts, "model": "Image-Generator",
                  "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": "stop"}], "usage": None}
        yield f"data: {json.dumps(chunk2)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(sse(), media_type="text/event-stream")