Really-amin committed
Commit 56a6340
1 Parent(s): ce45c06

Update app.py

Files changed (1): app.py (+185, -300)
app.py CHANGED
@@ -1,312 +1,197 @@
- import os
- import sys
- import gc
- import json
  import logging
- import traceback
- from datetime import datetime
- from pathlib import Path
- from typing import Dict, Optional, Set, List, Any
- from collections import deque
- from contextlib import contextmanager
- import psutil
- import tempfile
-
- # Set base directory to current file's directory
- BASE_DIR = Path(__file__).parent
-
- # Set environment variables before any other imports
- temp_cache_dir = tempfile.gettempdir()
- os.environ["HF_HOME"] = str(temp_cache_dir)
- if "TRANSFORMERS_CACHE" in os.environ:
-     del os.environ["TRANSFORMERS_CACHE"]
-
- os.environ["TRANSFORMERS_PARALLELISM"] = "false"
- os.environ["TORCH_HOME"] = str(Path(temp_cache_dir) / "torch")
-
- # Critical ML imports
  import torch
- import transformers
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

- # Additional imports
- import aiofiles
- import aiodns
- import httpx
- import uvicorn
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException, Response, Query
- from fastapi.responses import HTMLResponse
- from fastapi.staticfiles import StaticFiles
- from fastapi.templating import Jinja2Templates
- from prometheus_client import Counter, Gauge, Histogram, CollectorRegistry
- from prometheus_fastapi_instrumentator import Instrumentator
- import hazm
- import backoff
- from logging.handlers import RotatingFileHandler
- from cachetools import TTLCache
-
- class AppConfig:
-     BASE_DIR = Path("/app")
-     STATIC_DIR = BASE_DIR / "static"
-     TEMPLATES_DIR = BASE_DIR / "templates"
-     CACHE_DIR = Path(temp_cache_dir) / "cache"
-     LOG_DIR = Path(temp_cache_dir) / "logs"
-
-     # Performance settings
-     MAX_RETRIES = 3
-     TIMEOUT = 60
-     MAX_CONNECTIONS = 5
-     CACHE_TTL = 300
-     MEMORY_THRESHOLD = 85
-     CLEANUP_THRESHOLD = 80
-     MAX_MESSAGE_LENGTH = 512
-     MAX_THREADS = 1
-     MODEL_MAX_LENGTH = 128  # Increased from 16 to allow longer responses
-
-     # Model settings
-     MODEL_NAME = "bigscience/bloom-560m"  # Upgraded from tiny-gpt2
-     MODEL_BATCH_SIZE = 1
-     MODEL_MAX_LENGTH = 128
-
-     # Network settings
-     PROXY_URL = "http://your-proxy-url:port"  # Replace with your proxy if needed
-
-     @classmethod
-     def setup_directories(cls) -> None:
-         try:
-             for path in [cls.CACHE_DIR, cls.LOG_DIR, cls.STATIC_DIR, cls.TEMPLATES_DIR]:
-                 path.mkdir(exist_ok=True, parents=True)
-             logging.info("Directory setup completed successfully")
-         except Exception as e:
-             logging.error(f"Failed to setup directories: {e}")
-             raise

- class BloomAI:
      def __init__(self):
-         try:
-             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-             torch.set_num_threads(AppConfig.MAX_THREADS)
-             if self.device.type == 'cpu':
-                 torch.set_num_interop_threads(1)
-             self.pipeline = None
-             self.model_loaded = False
-             self.load_attempts = 0
-             self.max_attempts = AppConfig.MAX_RETRIES
-             self.model_name = AppConfig.MODEL_NAME
-             self.tokenizer = None
-             logging.info(f"BloomAI initialized with device: {self.device}")
-         except Exception as e:
-             logging.error(f"Failed to initialize BloomAI: {e}")
-             raise
-
-     def initialize(self) -> bool:
-         try:
-             logging.info(f"Loading AI model {self.model_name} on {self.device}")
-
-             # Load tokenizer first
-             self.tokenizer = AutoTokenizer.from_pretrained(
-                 self.model_name,
-                 cache_dir=AppConfig.CACHE_DIR,
-                 use_fast=True
-             )
-
-             # Configure model loading with optimizations
-             model = AutoModelForCausalLM.from_pretrained(
-                 self.model_name,
-                 cache_dir=AppConfig.CACHE_DIR,
-                 torch_dtype=torch.float32,
-                 low_cpu_mem_usage=True,
-                 device_map="auto" if torch.cuda.is_available() else None
-             )
-
-             # Create pipeline with optimized configuration
-             self.pipeline = pipeline(
-                 "text-generation",
-                 model=model,
-                 tokenizer=self.tokenizer,
-                 device=self.device,
-                 framework="pt",
-                 model_kwargs={
-                     "pad_token_id": self.tokenizer.eos_token_id
-                 }
-             )
-
-             # Clean up memory
-             gc.collect()
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-
-             self.model_loaded = True
-             return True
-
-         except Exception as e:
-             logging.error(f"Model load error: {e}")
-             traceback.print_exc()
-             return False
-
-     def generate_response(self, text: str) -> str:
-         try:
-             if not self.pipeline:
-                 return "مدل در دسترس نیست."  # "The model is not available."
-
-             with RESPONSE_TIME.labels(endpoint="/generate_response").time():
-                 outputs = self.pipeline(
-                     text,
-                     max_length=AppConfig.MODEL_MAX_LENGTH,
-                     min_length=20,
-                     do_sample=True,
-                     top_k=50,
-                     top_p=0.92,
-                     temperature=0.7,
-                     num_return_sequences=1,
-                     no_repeat_ngram_size=3,
-                     pad_token_id=self.tokenizer.eos_token_id,
-                     attention_mask=None,
-                     early_stopping=True,
-                     repetition_penalty=1.2
-                 )
-
-             # Post-process the response
-             response = outputs[0]['generated_text']
-             response = response.strip()
-
-             # Remove the input prompt if included
-             if response.startswith(text):
-                 response = response[len(text):].strip()
-
-             REQUESTS.labels(endpoint="/generate_response").inc()
-             return response
-
-         except Exception as e:
-             logging.error(f"Error generating response: {e}", exc_info=True)
-             return "خطا در تولید پاسخ."  # "Error generating a response."
-
- # FastAPI Application Setup
- app = FastAPI(title="BLOOM AI Assistant")
-
- # Mount static files and templates
- app.mount("/static", StaticFiles(directory=AppConfig.STATIC_DIR), name="static")
- templates = Jinja2Templates(directory=AppConfig.TEMPLATES_DIR)
-
- # Setup Prometheus metrics
- custom_registry = CollectorRegistry()
-
- REQUESTS = Counter(
-     "bloom_http_requests_total",
-     "Total number of HTTP requests made",
-     labelnames=["endpoint"],
-     registry=custom_registry
- )
-
- RESPONSE_TIME = Histogram(
-     "bloom_http_response_seconds",
-     "HTTP response time in seconds",
-     labelnames=["endpoint"],
-     registry=custom_registry
- )
-
- MEMORY_USAGE = Gauge(
-     "bloom_system_memory_bytes",
-     "System memory usage in bytes",
-     registry=custom_registry
- )
-
- # Setup Prometheus instrumentation
- instrumentator = Instrumentator(
-     registry=custom_registry,
-     should_group_status_codes=False,
-     should_ignore_untemplated=True,
-     should_respect_env_var=True,
-     should_instrument_requests_inprogress=True,
-     excluded_handlers=["/metrics"]
- )
-
- @app.middleware("http")
- def add_custom_metrics(request: Request, call_next):
-     response = call_next(request)
-     REQUESTS.labels(endpoint=request.url.path).inc()
-     return response
-
- instrumentator.instrument(app).expose(app)
-
- # Default HTML template
- DEFAULT_HTML = """
- <!DOCTYPE html>
- <html dir="rtl" lang="fa">
- <head>
-     <meta charset="UTF-8">
-     <title>هوش مصنوعی BLOOM</title>  <!-- "BLOOM AI" -->
-     <style>
-         body {
-             font-family: Tahoma, Arial;
-             text-align: center;
-             margin-top: 50px;
-             background-color: #f0f2f5;
-         }
-         .container {
-             max-width: 600px;
-             margin: 0 auto;
-             padding: 20px;
-             background-color: white;
-             border-radius: 10px;
-             box-shadow: 0 2px 4px rgba(0,0,0,0.1);
-         }
-         h1 { color: #1a73e8; }
-         .status { margin: 20px 0; }
-     </style>
- </head>
- <body>
-     <div class="container">
-         <h1>سیستم هوش مصنوعی BLOOM</h1>  <!-- "BLOOM AI System" -->
-         <div class="status">
-             <p>وضعیت سیستم: <span id="status">در حال بارگذاری...</span></p>  <!-- "System status: Loading..." -->
-         </div>
-     </div>
- </body>
- </html>
- """
-
- # Route handlers
- @app.get("/", response_class=HTMLResponse)
- def home(request: Request):
-     try:
-         ai_ready = request.app.state.assistant.is_ready.is_set()
-         return templates.TemplateResponse(
-             "index.html",
-             {
-                 "request": request,
-                 "ai_status": "آماده" if ai_ready else "در حال آماده‌سازی"  # "Ready" / "Preparing"
-             }
-         )
-     except Exception as e:
-         logging.error(f"Template error: {e}")
-         return HTMLResponse(content=DEFAULT_HTML)

- @app.exception_handler(404)
- def not_found_handler(request: Request, exc: Exception):
-     return HTMLResponse(content=DEFAULT_HTML, status_code=404)

- @app.exception_handler(500)
- def server_error_handler(request: Request, exc: Exception):
-     return HTMLResponse(content=DEFAULT_HTML, status_code=500)

  @app.websocket("/ws")
- def websocket_endpoint(websocket: WebSocket, client_id: str = Query(...)):
-     websocket.accept()
      try:
          while True:
-             data = websocket.receive_text()
-             response = app.state.assistant.process_message(data, client_id)
-             websocket.send_text(response)
      except WebSocketDisconnect:
-         app.state.assistant.websocket_manager.disconnect(client_id)
-
- if __name__ == "__main__":
-     uvicorn.run(
-         app,
-         host="0.0.0.0",
-         port=7860,
-         reload=False,
-         workers=1,
-         log_level="info"
-     )
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
+ from fastapi.responses import HTMLResponse
+ from fastapi.templating import Jinja2Templates
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ import asyncio
  import logging
+ import httpx
  import torch

+ # Initialize FastAPI
+ app = FastAPI()

+ # Logging setup
+ logging.basicConfig(level=logging.INFO)
+
+ # Telegram Token and Chat ID (Replace with your actual values)
+ TELEGRAM_TOKEN = "7437859619:AAGeGG3ZkLM0OVaw-Exx1uMRE55JtBCZZCY"
+ CHAT_ID = "-1002228627548"
+
+ # Templating setup
+ templates = Jinja2Templates(directory="templates")
+
+
+ # WebSocket Manager
+ class WebSocketManager:
      def __init__(self):
+         self.active_connection: WebSocket = None
+
+     async def connect(self, websocket: WebSocket):
+         """Connects the WebSocket"""
+         await websocket.accept()
+         self.active_connection = websocket
+         logging.info("WebSocket connected.")
+
+     async def disconnect(self):
+         """Disconnects the WebSocket"""
+         if self.active_connection:
+             await self.active_connection.close()
+             self.active_connection = None
+             logging.info("WebSocket disconnected.")
+
+     async def send_message(self, message: str):
+         """Sends a message through WebSocket"""
+         if self.active_connection:
+             await self.active_connection.send_text(message)
+             logging.info(f"Sent via WebSocket: {message}")
+

+ websocket_manager = WebSocketManager()


+ # BLOOM Model Manager
+ class BloomAI:
+     def __init__(self):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.pipeline = None
+
+     def load_model(self):
+         """Loads BLOOM AI Model"""
+         logging.info("Loading BLOOM model...")
+         tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+         model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
+         self.pipeline = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             device=0 if torch.cuda.is_available() else -1
+         )
+         logging.info("BLOOM model loaded successfully.")
+
+     async def generate_response(self, prompt: str) -> str:
+         """Generates a response using BLOOM"""
+         if not prompt.strip():
+             return "⚠️ Please send a valid message."
+         logging.info(f"Generating response for prompt: {prompt}")
+         outputs = self.pipeline(
+             prompt,
+             max_length=100,
+             do_sample=True,
+             temperature=0.7,
+             top_k=50,
+             top_p=0.9,
+             num_return_sequences=1,
+             no_repeat_ngram_size=2
+         )
+         response = outputs[0]["generated_text"]
+         return response.strip()
+
+
+ # Initialize BLOOM
+ bloom_ai = BloomAI()
+ bloom_ai.load_model()
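Note that generate_response above is declared async but invokes the pipeline synchronously, so the event loop blocks for the whole generation; asyncio is imported but never used in this revision. A non-blocking sketch of the same method, assuming Python 3.9+ for asyncio.to_thread:

    # Sketch only, not part of this commit: run the synchronous pipeline
    # call in a worker thread so WebSocket and webhook handlers stay live.
    async def generate_response(self, prompt: str) -> str:
        if not prompt.strip():
            return "⚠️ Please send a valid message."
        outputs = await asyncio.to_thread(
            self.pipeline,
            prompt,
            max_length=100,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
        )
        return outputs[0]["generated_text"].strip()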
+
+
+ # Telegram Message Handling
+ async def send_telegram_message(text: str):
+     """Sends a message to Telegram"""
+     async with httpx.AsyncClient() as client:
+         url = f"https://api.telegram.org/bot{TELEGRAM_TOKEN}/sendMessage"
+         payload = {"chat_id": CHAT_ID, "text": text}
+         response = await client.post(url, json=payload)
+         if response.status_code == 200:
+             logging.info(f"Sent to Telegram: {text}")
+         else:
+             logging.error(f"Failed to send message to Telegram: {response.text}")
+
+
+ @app.post("/telegram")
+ async def telegram_webhook(update: dict):
+     """Handles Telegram Webhook messages"""
+     if "message" in update:
+         chat_id = str(update["message"]["chat"]["id"])
+         if chat_id != CHAT_ID:
+             return {"status": "Unauthorized"}
+
+         user_message = update["message"]["text"]
+         logging.info(f"Received from Telegram: {user_message}")
+
+         # Process the message
+         response = await bloom_ai.generate_response(user_message)
+         await send_telegram_message(response)
+     return {"status": "ok"}
+
+
+ # WebSocket Endpoint
  @app.websocket("/ws")
+ async def websocket_endpoint(websocket: WebSocket):
+     """WebSocket communication for real-time interaction"""
+     await websocket_manager.connect(websocket)
      try:
          while True:
+             # Receive message from WebSocket
+             data = await websocket.receive_text()
+             logging.info(f"Received from WebSocket: {data}")
+
+             # Process the message
+             response = await bloom_ai.generate_response(data)
+
+             # Send response back through WebSocket
+             await websocket_manager.send_message(response)
      except WebSocketDisconnect:
+         # Handle WebSocket disconnection
+         await websocket_manager.disconnect()
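Note that WebSocketManager holds a single active connection, so a second client connecting simply replaces the first. A minimal command-line client for exercising /ws, assuming the third-party websockets package and a server listening locally on port 8000:

    import asyncio
    import websockets  # pip install websockets

    async def chat():
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            await ws.send("Hello, BLOOM!")
            print(await ws.recv())  # the model-generated reply

    asyncio.run(chat())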
+
+
+ # HTML Test UI
+ @app.get("/")
+ async def get_ui(request: Request):
+     """Displays the WebSocket HTML UI"""
+     return templates.TemplateResponse("index.html", {"request": request})
+
+
+ # Simple UI (fallback in case the templates folder is not available)
+ @app.get("/simple-ui")
+ async def simple_ui():
+     """Fallback HTML for WebSocket Test"""
+     return HTMLResponse(content="""
+     <!DOCTYPE html>
+     <html>
+     <head>
+         <title>WebSocket Test</title>
+         <script>
+             let ws = new WebSocket("ws://localhost:8000/ws");
+
+             ws.onopen = () => {
+                 console.log("WebSocket connection opened.");
+             };
+
+             ws.onmessage = (event) => {
+                 console.log("Message from server:", event.data);
+                 const msgContainer = document.getElementById("messages");
+                 const msg = document.createElement("div");
+                 msg.innerText = event.data;
+                 msgContainer.appendChild(msg);
+             };
+
+             ws.onclose = () => {
+                 console.log("WebSocket connection closed.");
+             };
+
+             function sendMessage() {
+                 const input = document.getElementById("messageInput");
+                 const message = input.value;
+                 ws.send(message);
+                 input.value = "";
+             }
+         </script>
+     </head>
+     <body>
+         <h1>WebSocket Test</h1>
+         <div id="messages" style="border: 1px solid black; height: 200px; overflow-y: scroll;"></div>
+         <input id="messageInput" type="text" />
+         <button onclick="sendMessage()">Send</button>
+     </body>
+     </html>
+     """)
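The old file's if __name__ == "__main__" block (uvicorn on port 7860) was dropped, so this revision must be launched externally, e.g. with "uvicorn app:app --host 0.0.0.0 --port 8000". A programmatic equivalent, assuming port 8000 to match the ws://localhost:8000 URL hard-coded in the fallback UI above:

    # Mirrors the removed __main__ block; uvicorn is no longer imported in
    # app.py, so it must be installed and imported separately.
    import uvicorn

    if __name__ == "__main__":
        uvicorn.run("app:app", host="0.0.0.0", port=8000)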