Soumik555 committed
Commit 5344861 · 1 Parent(s): 9ef9a4e
Files changed (7)
  1. Dockerfile +1 -1
  2. chat_routes.py +45 -0
  3. config.py +11 -0
  4. logger.py +7 -0
  5. main.py +15 -278
  6. model_service.py +84 -0
  7. models.py +25 -0
Dockerfile CHANGED
@@ -53,7 +53,7 @@ EXPOSE 7860
 
  # Health check
  HEALTHCHECK --interval=30s --timeout=30s --start-period=300s --retries=3 \
- CMD curl -f http://localhost:7860/health || exit 1
+ CMD curl -f https://cronjob-python-chatbot.hf.space/health || exit 1
 
  # Run FastAPI application
  CMD ["python", "main.py"]
chat_routes.py ADDED
@@ -0,0 +1,45 @@
+ from fastapi import APIRouter, HTTPException
+ from models import ChatRequest, ChatResponse, HealthResponse
+ # Import the module (not just the name) so model_service.model_loaded reflects updates made after import
+ import model_service
+ from model_service import generate_response, MODEL_NAME, CACHE_DIR, startup_time, is_model_cached
+ import torch
+
+ router = APIRouter()
+
+ @router.get("/")
+ def root():
+     return {"message": "FastAPI Chatbot API", "status": "running"}
+
+ @router.get("/health", response_model=HealthResponse)
+ def health():
+     return HealthResponse(
+         status="healthy" if model_service.model_loaded else "initializing",
+         is_model_loaded=model_service.model_loaded,
+         model_name=MODEL_NAME,
+         cache_directory=CACHE_DIR,
+         startup_time=startup_time
+     )
+
+ @router.post("/chat", response_model=ChatResponse)
+ def chat(request: ChatRequest):
+     if not model_service.model_loaded:
+         raise HTTPException(status_code=503, detail="Model not loaded yet.")
+     if not request.message.strip():
+         raise HTTPException(status_code=400, detail="Message cannot be empty")
+
+     resp, time_taken = generate_response(
+         request.message, request.max_length, request.temperature, request.top_p
+     )
+     return ChatResponse(response=resp, model_name=MODEL_NAME, response_time=time_taken)
+
+ @router.get("/model-info")
+ def model_info():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     return {
+         "model_name": MODEL_NAME,
+         "model_loaded": model_service.model_loaded,
+         "device": device,
+         "cache_directory": CACHE_DIR,
+         "model_cached": is_model_cached(MODEL_NAME),
+     }
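For reference, a small client sketch against the routes added above; the base URL and the use of the requests library are assumptions, and the payload fields mirror ChatRequest from models.py:

import requests  # hypothetical client dependency

BASE = "http://localhost:7860"  # assumed local run; the deployed Space URL works the same way

# /health returns status, is_model_loaded, model_name, cache_directory, startup_time
print(requests.get(f"{BASE}/health", timeout=10).json())

# /chat takes the ChatRequest fields; the defaults are spelled out explicitly here
payload = {"message": "Hello!", "max_length": 100, "temperature": 0.7, "top_p": 0.9}
reply = requests.post(f"{BASE}/chat", json=payload, timeout=120).json()
print(reply["response"], reply["response_time"])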
config.py ADDED
@@ -0,0 +1,11 @@
+ import os
+ from pathlib import Path
+
+ # Model configuration
+ MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
+ CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/app/model_cache")
+ MAX_LENGTH = int(os.getenv("MAX_LENGTH", "100"))
+ DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
+
+ # Ensure cache directory exists
+ Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
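Because every value is read with os.getenv, deployments can override the defaults without touching the code. A short sketch of overriding them from Python before config is imported (the alternative model name is illustrative only):

import os

# Must run before `import config`, since config.py reads the environment at import time
os.environ["MODEL_NAME"] = "microsoft/DialoGPT-small"  # example override
os.environ["MAX_LENGTH"] = "150"

import config
print(config.MODEL_NAME, config.MAX_LENGTH, config.CACHE_DIR)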
logger.py ADDED
@@ -0,0 +1,7 @@
+ import logging
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+ logger = logging.getLogger("chatbot")
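The shared logger is what the other new modules import; a one-line usage sketch:

from logger import logger

logger.info("service starting")  # emits: <timestamp> - chatbot - INFO - service starting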
main.py CHANGED
@@ -1,290 +1,27 @@
- import os
- from fastapi import FastAPI, HTTPException
+ from fastapi import FastAPI
  from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse
- from pydantic import BaseModel
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
- import torch
- import logging
+ from chat_routes import router as chat_router
+ from model_service import load_model
  import threading
  import uvicorn
- from pathlib import Path
- import time
+ from logger import logger
 
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
- )
- logger = logging.getLogger(__name__)
-
- # FastAPI app
- app = FastAPI(
-     title="FastAPI Chatbot",
-     description="Chatbot with FastAPI backend",
-     version="1.0.0"
- )
-
- # Add CORS middleware
+ app = FastAPI(title="FastAPI Chatbot", version="1.0.0")
  app.add_middleware(
      CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
+     allow_origins=["*"], allow_credentials=True,
+     allow_methods=["*"], allow_headers=["*"],
  )
 
- # Pydantic models with fixed namespace conflicts
- class ChatRequest(BaseModel):
-     message: str
-     max_length: int = 100
-     temperature: float = 0.7
-     top_p: float = 0.9
-
-     class Config:
-         protected_namespaces = ()
-
- class ChatResponse(BaseModel):
-     response: str
-     model_name: str
-     response_time: float
-
-     class Config:
-         protected_namespaces = ()
-
- class HealthResponse(BaseModel):
-     status: str
-     is_model_loaded: bool
-     model_name: str
-     cache_directory: str
-     startup_time: float
-
-     class Config:
-         protected_namespaces = ()
-
- # Global variables
- tokenizer = None
- model = None
- generator = None
- startup_time = time.time()
- model_loaded = False
-
- # Configuration
- MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
- CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/app/model_cache")
- MAX_LENGTH = int(os.getenv("MAX_LENGTH", "100"))
- DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
-
- def ensure_cache_dir():
-     """Ensure cache directory exists"""
-     Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
-     logger.info(f"Cache directory: {CACHE_DIR}")
-
- def is_model_cached(model_name: str) -> bool:
-     """Check if model is already cached"""
-     try:
-         model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
-         is_cached = model_path.exists() and any(model_path.iterdir())
-         logger.info(f"Model cached: {is_cached}")
-         return is_cached
-     except Exception as e:
-         logger.error(f"Error checking cache: {e}")
-         return False
-
- def load_model():
-     """Load the Hugging Face model with caching"""
-     global tokenizer, model, generator, model_loaded
-
-     try:
-         ensure_cache_dir()
-
-         logger.info(f"Loading model: {MODEL_NAME}")
-         logger.info(f"Cache dir: {CACHE_DIR}")
-         logger.info(f"CUDA available: {torch.cuda.is_available()}")
-
-         start_time = time.time()
-
-         # Load tokenizer first
-         logger.info("Loading tokenizer...")
-         tokenizer = AutoTokenizer.from_pretrained(
-             MODEL_NAME,
-             cache_dir=CACHE_DIR,
-             local_files_only=False
-         )
-
-         # Add padding token if it doesn't exist
-         if tokenizer.pad_token is None:
-             tokenizer.pad_token = tokenizer.eos_token
-
-         # Load model
-         logger.info("Loading model...")
-         model = AutoModelForCausalLM.from_pretrained(
-             MODEL_NAME,
-             cache_dir=CACHE_DIR,
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-             device_map="auto" if torch.cuda.is_available() else None,
-             low_cpu_mem_usage=True,
-             local_files_only=False
-         )
-
-         # Create text generation pipeline
-         logger.info("Creating pipeline...")
-         device = 0 if torch.cuda.is_available() else -1
-         generator = pipeline(
-             "text-generation",
-             model=model,
-             tokenizer=tokenizer,
-             device=device,
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-         )
-
-         load_time = time.time() - start_time
-         model_loaded = True
-         logger.info(f"✅ Model loaded successfully in {load_time:.2f} seconds!")
-         logger.info(f"Model device: {model.device}")
-
-         return True
-
-     except Exception as e:
-         logger.error(f"❌ Error loading model: {str(e)}", exc_info=True)
-         return False
-
- def generate_response(message: str, max_length: int = 100, temperature: float = 0.7, top_p: float = 0.9) -> str:
-     """Generate response using the loaded model"""
-     if not generator:
-         return "❌ Model not loaded. Please wait for initialization...", 0.0
-
-     try:
-         start_time = time.time()
-
-         # Generate response with parameters
-         response = generator(
-             message,
-             max_length=max_length,
-             temperature=temperature,
-             top_p=top_p,
-             num_return_sequences=1,
-             pad_token_id=tokenizer.eos_token_id,
-             do_sample=True,
-             truncation=True,
-             repetition_penalty=1.1
-         )
-
-         # Extract generated text
-         generated_text = response[0]['generated_text']
-
-         # Clean up response
-         if generated_text.startswith(message):
-             bot_response = generated_text[len(message):].strip()
-         else:
-             bot_response = generated_text.strip()
-
-         # Fallback if empty response
-         if not bot_response:
-             bot_response = "I'm not sure how to respond to that. Could you try rephrasing?"
-
-         response_time = time.time() - start_time
-         logger.info(f"Generated response in {response_time:.2f}s")
-
-         return bot_response, response_time
-
-     except Exception as e:
-         logger.error(f"Error generating response: {str(e)}", exc_info=True)
-         return f"❌ Error generating response: {str(e)}", 0.0
-
- # FastAPI endpoints
- @app.get("/")
- async def root():
-     """Root endpoint"""
-     return {"message": "FastAPI Chatbot API", "status": "running"}
-
- @app.get("/health", response_model=HealthResponse)
- async def health_check():
-     """Health check endpoint with detailed information"""
-     return HealthResponse(
-         status="healthy" if model_loaded else "initializing",
-         is_model_loaded=model_loaded,
-         model_name=MODEL_NAME,
-         cache_directory=CACHE_DIR,
-         startup_time=time.time() - startup_time
-     )
-
- @app.post("/chat", response_model=ChatResponse)
- async def chat_endpoint(request: ChatRequest):
-     """Chat endpoint for API access"""
-     if not model_loaded:
-         raise HTTPException(
-             status_code=503,
-             detail="Model not loaded yet. Please wait for initialization."
-         )
-
-     # Validate input
-     if not request.message.strip():
-         raise HTTPException(status_code=400, detail="Message cannot be empty")
-
-     if len(request.message) > 1000:
-         raise HTTPException(status_code=400, detail="Message too long (max 1000 characters)")
-
-     # Generate response
-     response_text, response_time = generate_response(
-         request.message.strip(),
-         request.max_length,
-         request.temperature,
-         request.top_p
-     )
-
-     return ChatResponse(
-         response=response_text,
-         model_name=MODEL_NAME,
-         response_time=response_time
-     )
-
- @app.get("/model-info")
- async def get_model_info():
-     """Get detailed model information"""
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     if model and hasattr(model, 'device'):
-         device = str(model.device)
-
-     return {
-         "model_name": MODEL_NAME,
-         "model_loaded": model_loaded,
-         "device": device,
-         "cache_directory": CACHE_DIR,
-         "model_cached": is_model_cached(MODEL_NAME),
-         "parameters": {
-             "max_length": MAX_LENGTH,
-             "default_temperature": DEFAULT_TEMPERATURE
-         }
-     }
+ app.include_router(chat_router)
 
  @app.on_event("startup")
- async def startup_event():
-     """Load model on startup"""
-     logger.info("🚀 Starting FastAPI Chatbot...")
-     logger.info("📦 Loading model...")
-
-     # Load model in background thread to not block startup
-     def load_model_background():
-         global model_loaded
-         model_loaded = load_model()
-         if model_loaded:
-             logger.info("✅ Model loaded successfully!")
-         else:
-             logger.error("❌ Failed to load model.")
-
-     # Start model loading in background
-     threading.Thread(target=load_model_background, daemon=True).start()
-
- def run_fastapi():
-     """Run FastAPI server"""
-     uvicorn.run(
-         app,
-         host="0.0.0.0",
-         port=7860,  # Changed to 7860 for HuggingFace
-         log_level="info",
-         access_log=True
-     )
+ def startup():
+     def load_in_bg():
+         success = load_model()
+         if success: logger.info("Model loaded on startup.")
+         else: logger.error("Model failed to load.")
+     threading.Thread(target=load_in_bg, daemon=True).start()
 
  if __name__ == "__main__":
-     run_fastapi()
+     uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)  # "main:app" matches the flat layout (main.py at the repo root)
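A quick way to exercise the slimmed-down app without starting uvicorn is FastAPI's TestClient. This sketch assumes it runs from the repository root so `main` is importable; because the model loads in a background thread, /health may still report "initializing":

from fastapi.testclient import TestClient
from main import app

# Entering the context manager fires the startup event, which kicks off the background model load
with TestClient(app) as client:
    print(client.get("/").json())        # {"message": "FastAPI Chatbot API", "status": "running"}
    print(client.get("/health").json())  # "initializing" until load_model() finishes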
model_service.py ADDED
@@ -0,0 +1,84 @@
+ import time
+ from pathlib import Path
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ import torch
+ from logger import logger
+ from config import CACHE_DIR, MODEL_NAME
+
+
+ tokenizer = None
+ model = None
+ generator = None
+ model_loaded = False
+ startup_time = time.time()
+
+ def is_model_cached(model_name: str) -> bool:
+     try:
+         model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
+         return model_path.exists() and any(model_path.iterdir())
+     except Exception as e:
+         logger.error(f"Error checking cache: {e}")
+         return False
+
+ def load_model():
+     global tokenizer, model, generator, model_loaded
+
+     try:
+         logger.info(f"Loading model: {MODEL_NAME}")
+         logger.info(f"CUDA available: {torch.cuda.is_available()}")
+         start = time.time()
+
+         tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         model = AutoModelForCausalLM.from_pretrained(
+             MODEL_NAME,
+             cache_dir=CACHE_DIR,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             device_map="auto" if torch.cuda.is_available() else None,
+             low_cpu_mem_usage=True,
+         )
+
+         device = 0 if torch.cuda.is_available() else -1
+         generator = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             device=device,
+         )
+
+         model_loaded = True
+         logger.info(f"✅ Model loaded in {time.time()-start:.2f}s on {model.device}")
+         return True
+
+     except Exception as e:
+         logger.error(f"❌ Error loading model: {e}", exc_info=True)
+         model_loaded = False
+         return False
+
+ def generate_response(message: str, max_length: int, temperature: float, top_p: float):
+     if not generator:
+         return "❌ Model not loaded yet", 0.0
+
+     start = time.time()
+     try:
+         result = generator(
+             message,
+             max_length=max_length,
+             temperature=temperature,
+             top_p=top_p,
+             num_return_sequences=1,
+             pad_token_id=tokenizer.eos_token_id,
+             do_sample=True,
+             repetition_penalty=1.1
+         )
+         text = result[0]['generated_text']
+         reply = text[len(message):].strip() if text.startswith(message) else text.strip()
+         if not reply:
+             reply = "I'm not sure how to respond to that."
+         return reply, time.time()-start
+
+     except Exception as e:
+         logger.error(f"Generation error: {e}", exc_info=True)
+         return f"❌ Error: {e}", 0.0
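The module can also be exercised on its own, outside FastAPI; a minimal sketch (the prompt and sampling values are illustrative, and the first call downloads the model into CACHE_DIR):

import model_service

if model_service.load_model():  # loads tokenizer, model and pipeline into module globals
    reply, seconds = model_service.generate_response(
        "Hello, how are you?", max_length=100, temperature=0.7, top_p=0.9
    )
    print(f"{reply!r} ({seconds:.2f}s)")
else:
    print("model failed to load; see logs")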
models.py ADDED
@@ -0,0 +1,25 @@
+ from pydantic import BaseModel
+
+ class ChatRequest(BaseModel):
+     message: str
+     max_length: int = 100
+     temperature: float = 0.7
+     top_p: float = 0.9
+     class Config:
+         protected_namespaces = ()
+
+ class ChatResponse(BaseModel):
+     response: str
+     model_name: str
+     response_time: float
+     class Config:
+         protected_namespaces = ()
+
+ class HealthResponse(BaseModel):
+     status: str
+     is_model_loaded: bool
+     model_name: str
+     cache_directory: str
+     startup_time: float
+     class Config:
+         protected_namespaces = ()
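These schemas validate request and response bodies before they reach the route handlers; the protected_namespaces override is there so field names starting with model_ do not trigger pydantic warnings. A short sketch of the defaults in action:

from models import ChatRequest

req = ChatRequest(message="Hi there")              # optional fields fall back to their defaults
print(req.max_length, req.temperature, req.top_p)  # 100 0.7 0.9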