AurelioAguirre committed on
Commit 8083005 · Parent(s): c066c5b

Added double init, for embedding and chat models at the same time.

Files changed (3):
  1. main/api.py +61 -24
  2. main/app.py +0 -1
  3. main/routes.py +64 -35
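In a nutshell: the single `self.model` / `self.model_name` slot in `LLMApi` is split into separate `generation_model` and `embedding_model` slots, each with its own tokenizer, so a chat model and an embedding model can be resident at the same time. A minimal usage sketch of that flow — the constructor signature, config file path, and model names below are assumptions for illustration, not taken from this diff:

```python
# Hypothetical sketch of the dual-init flow this commit enables.
# LLMApi's constructor and the config layout are assumptions; the
# model names are illustrative only.
import yaml
from main.api import LLMApi

with open("config.yaml") as f:  # path assumed
    config = yaml.safe_load(f)

api = LLMApi(config)

# Load a chat model and an embedding model side by side; each call
# fills its own model/tokenizer slot, so neither evicts the other.
api.initialize_model("Qwen/Qwen2.5-7B-Instruct")          # -> api.generation_model
api.initialize_embedding_model("BAAI/bge-large-en-v1.5")  # -> api.embedding_model

vector = api.generate_embedding("some text to embed")     # served by the embedding slot
print(len(vector))
```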
main/api.py CHANGED

@@ -17,9 +17,13 @@ class LLMApi:
         self.models_path = self.base_path / config["folders"]["models"]
         self.cache_path = self.base_path / config["folders"]["cache"]

-        self.model = None
-        self.model_name = None
+        # Initialize model variables for both generation and embedding
+        self.generation_model = None
+        self.generation_model_name = None
+        self.embedding_model = None
+        self.embedding_model_name = None
         self.tokenizer = None
+        self.embedding_tokenizer = None

         # Generation parameters from config
         gen_config = config["model"]["generation"]
@@ -64,7 +68,7 @@ class LLMApi:

         # Download and save tokenizer
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.logger.info(f"Disnabling stdout logging")
+        self.logger.info(f"Disabling stdout logging")
         self.logger.disable_stream_to_logger()

         self.logger.info(f"Saving model to {model_path}")
@@ -78,14 +82,14 @@ class LLMApi:

     def initialize_model(self, model_name: str) -> None:
         """
-        Initialize a model and tokenizer, either from local storage or by downloading.
+        Initialize a model and tokenizer for text generation.

         Args:
             model_name: The name of the model to initialize
         """
-        self.logger.info(f"Initializing model: {model_name}")
+        self.logger.info(f"Initializing generation model: {model_name}")
         try:
-            self.model_name = model_name
+            self.generation_model_name = model_name
             local_model_path = self.models_path / model_name.split('/')[-1]

             # Check if model exists locally
@@ -96,7 +100,7 @@ class LLMApi:
                 self.logger.info(f"Loading model from source: {model_name}")
                 model_path = model_name

-            self.model = AutoModelForCausalLM.from_pretrained(
+            self.generation_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 device_map="auto",
                 load_in_8bit=True,
@@ -108,9 +112,42 @@ class LLMApi:
             self.generation_config["eos_token_id"] = self.tokenizer.eos_token_id
             self.generation_config["pad_token_id"] = self.tokenizer.eos_token_id

-            self.logger.info(f"Successfully initialized model: {model_name}")
+            self.logger.info(f"Successfully initialized generation model: {model_name}")
         except Exception as e:
-            self.logger.error(f"Failed to initialize model {model_name}: {str(e)}")
+            self.logger.error(f"Failed to initialize generation model {model_name}: {str(e)}")
+            raise
+
+    def initialize_embedding_model(self, model_name: str) -> None:
+        """
+        Initialize a model and tokenizer specifically for embeddings.
+
+        Args:
+            model_name: The name of the model to initialize for embeddings
+        """
+        self.logger.info(f"Initializing embedding model: {model_name}")
+        try:
+            self.embedding_model_name = model_name
+            local_model_path = self.models_path / model_name.split('/')[-1]
+
+            # Check if model exists locally
+            if local_model_path.exists():
+                self.logger.info(f"Loading embedding model from local path: {local_model_path}")
+                model_path = local_model_path
+            else:
+                self.logger.info(f"Loading embedding model from source: {model_name}")
+                model_path = model_name
+
+            self.embedding_model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                device_map="auto",
+                load_in_8bit=True,
+                torch_dtype=torch.float16
+            )
+            self.embedding_tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+            self.logger.info(f"Successfully initialized embedding model: {model_name}")
+        except Exception as e:
+            self.logger.error(f"Failed to initialize embedding model {model_name}: {str(e)}")
             raise

     def has_chat_template(self) -> bool:
@@ -158,22 +195,22 @@ class LLMApi:
         """
         self.logger.debug(f"Generating response for prompt: {prompt[:50]}...")

-        if self.model is None:
-            raise RuntimeError("Model not initialized. Call initialize_model first.")
+        if self.generation_model is None:
+            raise RuntimeError("Generation model not initialized. Call initialize_model first.")

         try:
             text = self._prepare_prompt(prompt, system_message)
             inputs = self.tokenizer([text], return_tensors="pt")

             # Remove token_type_ids if present
-            model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()
+            model_inputs = {k: v.to(self.generation_model.device) for k, v in inputs.items()
                             if k != 'token_type_ids'}

             generation_config = self.generation_config.copy()
             if max_new_tokens:
                 generation_config["max_new_tokens"] = max_new_tokens

-            generated_ids = self.model.generate(
+            generated_ids = self.generation_model.generate(
                 **model_inputs,
                 **generation_config
             )
@@ -202,15 +239,15 @@ class LLMApi:
         """
         self.logger.debug(f"Starting streaming generation for prompt: {prompt[:50]}...")

-        if self.model is None:
-            raise RuntimeError("Model not initialized. Call initialize_model first.")
+        if self.generation_model is None:
+            raise RuntimeError("Generation model not initialized. Call initialize_model first.")

         try:
             text = self._prepare_prompt(prompt, system_message)
             inputs = self.tokenizer([text], return_tensors="pt")

             # Remove token_type_ids if present
-            model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()
+            model_inputs = {k: v.to(self.generation_model.device) for k, v in inputs.items()
                             if k != 'token_type_ids'}

             # Configure generation
@@ -227,7 +264,7 @@ class LLMApi:
             )

             # Create a thread to run the generation
-            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+            thread = Thread(target=self.generation_model.generate, kwargs=generation_kwargs)
             thread.start()

             # Yield the generated text in chunks
@@ -241,21 +278,21 @@ class LLMApi:

     def generate_embedding(self, text: str) -> List[float]:
         """
-        Generate a single embedding vector for a chunk of text.
+        Generate a single embedding vector for a chunk of text using the dedicated embedding model.
         Returns a list of floats representing the text embedding.
         """
         self.logger.debug(f"Generating embedding for text: {text[:50]}...")

-        if self.model is None or self.tokenizer is None:
-            raise RuntimeError("Model not initialized. Call initialize_model first.")
+        if self.embedding_model is None or self.embedding_tokenizer is None:
+            raise RuntimeError("Embedding model not initialized. Call initialize_embedding_model first.")

         try:
             # Tokenize the input text and ensure input_ids are Long type
-            inputs = self.tokenizer(text, return_tensors='pt')
-            input_ids = inputs.input_ids.to(dtype=torch.long, device=self.model.device)
+            inputs = self.embedding_tokenizer(text, return_tensors='pt')
+            input_ids = inputs.input_ids.to(dtype=torch.long, device=self.embedding_model.device)

             # Get the model's dtype from its parameters for the attention mask
-            model_dtype = next(self.model.parameters()).dtype
+            model_dtype = next(self.embedding_model.parameters()).dtype

             # Create an attention mask with matching dtype
             attention_mask = torch.zeros(
@@ -269,7 +306,7 @@ class LLMApi:

             # Get model outputs
             with torch.no_grad():
-                outputs = self.model(
+                outputs = self.embedding_model(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
                     output_hidden_states=True,
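Two things worth noting from these hunks. First, `initialize_embedding_model` mirrors `initialize_model` almost line for line, and both load a full `AutoModelForCausalLM` with `load_in_8bit=True` and `device_map="auto"`, so the two models now share GPU memory. Second, the `generate_embedding` hunk is cropped before the pooling step; with `output_hidden_states=True`, a common pattern is to mean-pool the final hidden layer into one vector. A sketch of that pattern, offered only as an assumption about what the cropped code does:

```python
import torch

def pool_last_hidden_state(outputs, attention_mask: torch.Tensor) -> list:
    """Mean-pool the final hidden layer into a single embedding vector.

    Assumes `outputs` came from a forward pass with output_hidden_states=True.
    """
    last_hidden = outputs.hidden_states[-1]          # (batch, seq_len, hidden_dim)
    mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
    summed = (last_hidden * mask).sum(dim=1)         # sum over unmasked tokens
    counts = mask.sum(dim=1).clamp(min=1)            # avoid division by zero
    return (summed / counts).squeeze(0).tolist()
```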
main/app.py CHANGED

@@ -1,5 +1,4 @@
 import yaml
-import sys
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from .routes import router, init_router
main/routes.py CHANGED

@@ -1,3 +1,7 @@
+# routes.py for the LLM Engine.
+# This file contains the FastAPI routes for the LLM Engine API.
+# It includes routes for generating text, generating embeddings, checking system status, and validating system configuration.
+
 from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel
 from typing import Optional, List, Dict, Union
@@ -51,12 +55,12 @@ class ValidationResponse(BaseModel):
 @router.get("/system/validate",
             response_model=ValidationResponse,
             summary="Validate System Configuration",
-            description="Validates system configuration, folders, and model setup")
+            description="Validates system configuration, folders, and model setup for both generation and embedding models")
 async def validate_system():
     """
     Validates:
     - Configuration parameters
-    - Model setup
+    - Model setup for both generation and embedding models
     - Folder structure
     - Required permissions
     """
@@ -88,20 +92,31 @@ async def validate_system():
     # Validate model setup
     try:
         model_status = {
-            "model_files_exist": False,
-            "model_loadable": False,
+            "generation_model_files_exist": False,
+            "generation_model_loadable": False,
+            "embedding_model_files_exist": False,
+            "embedding_model_loadable": False,
             "tokenizer_valid": False
         }

-        if api.model_name:
-            model_path = api.models_path / api.model_name.split('/')[-1]
-            model_status["model_files_exist"] = validate_model_path(model_path)
+        if api.generation_model_name:
+            gen_model_path = api.models_path / api.generation_model_name.split('/')[-1]
+            model_status["generation_model_files_exist"] = validate_model_path(gen_model_path)
+            model_status["generation_model_loadable"] = api.generation_model is not None
+
+        if api.embedding_model_name:
+            emb_model_path = api.models_path / api.embedding_model_name.split('/')[-1]
+            model_status["embedding_model_files_exist"] = validate_model_path(emb_model_path)
+            model_status["embedding_model_loadable"] = api.embedding_model is not None

-            if not model_status["model_files_exist"]:
-                issues.append("Model files are missing or incomplete")
+        model_status["tokenizer_valid"] = (
+            api.tokenizer is not None and api.embedding_tokenizer is not None
+        )

-        model_status["model_loadable"] = api.model is not None
-        model_status["tokenizer_valid"] = api.tokenizer is not None
+        if not model_status["generation_model_files_exist"]:
+            issues.append("Generation model files are missing or incomplete")
+        if not model_status["embedding_model_files_exist"]:
+            issues.append("Embedding model files are missing or incomplete")

     except Exception as e:
         logger.error(f"Model validation failed: {str(e)}")
@@ -110,9 +125,12 @@ async def validate_system():

     # Validate folder structure and permissions
     try:
-        folder_status = {"models_folder": api.models_path.exists(), "cache_folder": api.cache_path.exists(),
-                         "logs_folder": Path(api.base_path / "logs").exists(), "write_permissions": False}
-
+        folder_status = {
+            "models_folder": api.models_path.exists(),
+            "cache_folder": api.cache_path.exists(),
+            "logs_folder": Path(api.base_path / "logs").exists(),
+            "write_permissions": False
+        }

         # Test write permissions by attempting to create a test file
         test_file = api.models_path / ".test_write"
@@ -148,7 +166,6 @@
     logger.info(f"System validation completed with status: {overall_status}")
     return validation_response

-
 @router.get("/system/status",
             response_model=SystemStatusResponse,
             summary="Check System Status",
@@ -224,12 +241,16 @@ async def check_system():

     # Check Model Status
     try:
-        current_model_path = api.models_path / api.model_name.split('/')[-1] if api.model_name else None
         status.model = {
-            "is_loaded": api.model is not None,
-            "current_model": api.model_name,
-            "is_valid": validate_model_path(current_model_path) if current_model_path else False,
-            "has_chat_template": api.has_chat_template() if api.model else False
+            "generation_model": {
+                "is_loaded": api.generation_model is not None,
+                "current_model": api.generation_model_name,
+                "has_chat_template": api.has_chat_template() if api.generation_model else False
+            },
+            "embedding_model": {
+                "is_loaded": api.embedding_model is not None,
+                "current_model": api.embedding_model_name
+            }
         }
         logger.debug(f"Model status retrieved: {status.model}")
     except Exception as e:
@@ -239,7 +260,6 @@ async def check_system():
     logger.info("System status check completed")
     return status

-
 @router.post("/generate")
 async def generate_text(request: GenerateRequest):
     """Generate text response from prompt"""
@@ -256,7 +276,6 @@ async def generate_text(request: GenerateRequest):
         logger.error(f"Error in generate_text endpoint: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))

-
 @router.post("/generate/stream")
 async def generate_stream(request: GenerateRequest):
     """Generate streaming text response from prompt"""
@@ -271,7 +290,6 @@ async def generate_stream(request: GenerateRequest):
         logger.error(f"Error in generate_stream endpoint: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))

-
 @router.post("/embedding", response_model=EmbeddingResponse)
 async def generate_embedding(request: EmbeddingRequest):
     """Generate embedding vector from text"""
@@ -287,7 +305,6 @@ async def generate_embedding(request: EmbeddingRequest):
         logger.error(f"Error in generate_embedding endpoint: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))

-
 @router.post("/model/download",
              summary="Download default or specified model",
              description="Downloads model files. Uses default model from config if none specified.")
@@ -332,18 +349,30 @@ async def initialize_model(model_name: Optional[str] = None):
         logger.error(f"Error initializing model: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))

-
-@router.get("/models/status")
-async def get_model_status():
-    """Get current model status"""
+@router.post("/model/initialize/embedding",
+             summary="Initialize embedding model",
+             description="Initialize a separate model specifically for generating embeddings")
+async def initialize_embedding_model(model_name: Optional[str] = None):
+    """Initialize a model specifically for embeddings"""
     try:
-        status = {
-            "model_loaded": api.model is not None,
-            "current_model": api.model_name if api.model_name else None,
-            "has_chat_template": api.has_chat_template() if api.model else False
+        # Use model name from config if none provided
+        embedding_model = model_name or config["model"]["defaults"].get("embedding_model_name")
+        if not embedding_model:
+            raise HTTPException(
+                status_code=400,
+                detail="No embedding model specified and no default found in config"
+            )
+
+        logger.info(f"Received request to initialize embedding model: {embedding_model}")
+
+        api.initialize_embedding_model(embedding_model)
+        logger.info(f"Successfully initialized embedding model: {embedding_model}")
+
+        return {
+            "status": "success",
+            "message": f"Embedding model {embedding_model} initialized",
+            "model_name": embedding_model
         }
-        logger.info(f"Retrieved model status: {status}")
-        return status
     except Exception as e:
-        logger.error(f"Error getting model status: {str(e)}")
+        logger.error(f"Error initializing embedding model: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
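On the routes side, the new `/model/initialize/embedding` endpoint exposes `initialize_embedding_model` over HTTP, falling back to `config["model"]["defaults"]["embedding_model_name"]` when no model name is passed (since the parameter has no Body annotation, FastAPI reads it from the query string). A hypothetical client session — the base URL, model name, and the `EmbeddingRequest`/`EmbeddingResponse` field names are assumptions, as those Pydantic models are not shown in this diff:

```python
# Hypothetical client; URL, model name, and JSON field names are assumptions.
import requests

BASE = "http://localhost:8000"

# Initialize the dedicated embedding model (omit the query parameter
# to fall back to the config default).
r = requests.post(f"{BASE}/model/initialize/embedding",
                  params={"model_name": "BAAI/bge-large-en-v1.5"})
r.raise_for_status()
print(r.json())  # {"status": "success", "message": ..., "model_name": ...}

# Subsequent embedding requests are served by the embedding slot,
# independent of whatever chat model is loaded.
r = requests.post(f"{BASE}/embedding", json={"text": "hello world"})  # field name assumed
r.raise_for_status()
print(len(r.json()["embedding"]))  # field name assumed
```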