unknown committed
Commit 8ecbd6b · 1 Parent(s): e06bc75

Fixed the model optimization speed: replaced llama-cpp-python with a transformers text-generation pipeline
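
For context, a minimal standalone sketch of the new generation path follows. It is illustrative only, not part of the commit: the model id and sampling values are assumptions, it relies on the chat-style pipeline output available in recent transformers releases, and it assumes a transformers-compatible (non-GGUF) checkpoint rather than the GGUF repo that REPO_ID defaults to in app.py.

# Minimal sketch of the new generation path (illustrative, not part of the commit).
from transformers import pipeline

MODEL_ID = "google/gemma-3-270m-it"  # assumption: a non-GGUF checkpoint that pipeline() can load

generator = pipeline("text-generation", model=MODEL_ID)

messages = [{"role": "user", "content": "Hello"}]
out = generator(messages, max_new_tokens=32, do_sample=True, temperature=0.7, top_p=0.9)

# With chat-style input, recent transformers versions return the whole conversation
# under "generated_text"; the last entry is the assistant reply, which is what the
# updated generate_sync() extracts.
reply = out[0]["generated_text"][-1]["content"]
print(reply)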

Files changed (2):
  1. app.py +22 -25
  2. requirements.txt +2 -1
app.py CHANGED
@@ -6,14 +6,13 @@ from typing import List, Optional, Dict, Any
 from fastapi import FastAPI, HTTPException, Request, status
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
-from llama_cpp import Llama
+from transformers import pipeline
 from concurrent.futures import ThreadPoolExecutor
 
 # -------------------------
 # Configuration (via env)
 # -------------------------
 REPO_ID = os.getenv("REPO_ID", "unsloth/gemma-3-270m-it-GGUF")
-MODEL_FILENAME = os.getenv("MODEL_FILENAME", "gemma-3-270m-it-F16.gguf")
 MAX_WORKERS = int(os.getenv("MAX_WORKERS", "2")) # ThreadPool workers (reduced for speed)
 MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", "1")) # Reduced for speed
 RATE_LIMIT_PER_MIN = int(os.getenv("RATE_LIMIT_PER_MIN", "60"))
@@ -70,7 +69,7 @@ class GenerationResponse(BaseModel):
 # -------------------------
 # Global objects
 # -------------------------
-LLM_MODEL: Optional[Llama] = None
+LLM_MODEL: Optional[Any] = None
 executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
 model_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
 
@@ -108,20 +107,21 @@ rate_limiter = RateLimiter(RATE_LIMIT_PER_MIN)
 # build_prompt_from_messages function removed - using chat completion format directly
 
 def generate_sync(messages: List[Dict[str, str]], max_new_tokens: int, temperature: float, top_p: float, do_sample: bool, num_beams: int = 1, early_stopping: bool = True, use_cache: bool = True) -> str:
-    # llama-cpp-python generation parameters
+    # transformers pipeline generation parameters
     generation_kwargs = {
-        "max_tokens": max_new_tokens,
+        "max_new_tokens": max_new_tokens,
         "temperature": temperature,
         "top_p": top_p,
+        "do_sample": do_sample,
+        "num_beams": num_beams,
+        "early_stopping": early_stopping,
+        "use_cache": use_cache,
     }
 
-    # Create chat completion using llama-cpp-python
-    response = LLM_MODEL.create_chat_completion(
-        messages=messages,
-        **generation_kwargs
-    )
+    # Generate using transformers pipeline
+    response = LLM_MODEL(messages, **generation_kwargs)
 
-    return response["choices"][0]["message"]["content"]
+    return response[0]["generated_text"][-1]["content"] if isinstance(response[0]["generated_text"], list) else response[0]["generated_text"]
 
 async def generate_async(messages: List[Dict[str, str]], max_new_tokens: int, temperature: float, top_p: float, do_sample: bool, num_beams: int = 1, early_stopping: bool = True, use_cache: bool = True) -> str:
     loop = asyncio.get_event_loop()
@@ -138,28 +138,25 @@ async def on_startup():
     global LLM_MODEL
 
     try:
-        logger.info(f"Loading model from {REPO_ID}/{MODEL_FILENAME}...")
-        LLM_MODEL = Llama.from_pretrained(
-            repo_id=REPO_ID,
-            filename=MODEL_FILENAME,
-            n_ctx=N_CTX,
-            n_threads=N_THREADS,
-            n_gpu_layers=N_GPU_LAYERS,
-            verbose=False
+        logger.info(f"Loading model from {REPO_ID}...")
+        LLM_MODEL = pipeline(
+            "text-generation",
+            model=REPO_ID,
+            device_map="auto" if N_GPU_LAYERS > 0 else "cpu"
         )
         logger.info("Model loaded successfully.")
 
         # Warm up the model with a dummy request for faster first inference
        logger.info("Warming up model...")
         dummy_messages = [{"role": "user", "content": "Hello"}]
-        _ = LLM_MODEL.create_chat_completion(
-            messages=dummy_messages,
-            max_tokens=5,
+        _ = LLM_MODEL(
+            dummy_messages,
+            max_new_tokens=5,
             temperature=0.1
         )
         logger.info("Model warmed up successfully.")
     except Exception as e:
-        logger.error(f"Failed to load model {REPO_ID}/{MODEL_FILENAME}: {e}")
+        logger.error(f"Failed to load model {REPO_ID}: {e}")
         raise RuntimeError(f"Model loading failed: {e}") from e
 
 # -------------------------
@@ -167,7 +164,7 @@ async def on_startup():
 # -------------------------
 @app.get("/")
 async def root():
-    return {"status": "Gemma 3 API is running 🎉", "model": f"{REPO_ID}/{MODEL_FILENAME}"}
+    return {"status": "Gemma 3 API is running 🎉", "model": REPO_ID}
 
 @app.get("/health")
 async def health():
@@ -220,6 +217,6 @@ async def generate(req: GenerationRequest, request: Request):
 
     return GenerationResponse(
         generated_text=generated_text,
-        model=f"{REPO_ID}/{MODEL_FILENAME}",
+        model=REPO_ID,
         runtime_seconds=round(runtime, 3)
     )
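
Once the Space is running, the updated endpoints can be exercised as in the sketch below. This is a hedged usage example: the payload field names are inferred from generate_async()'s signature rather than from the GenerationRequest model (which is not shown in the diff), and port 7860 is assumed from the usual Hugging Face Spaces default.

# Hedged usage sketch for the updated API (payload field names and port are assumptions).
import requests

BASE_URL = "http://localhost:7860"  # assumption: default Spaces port

# Endpoints shown in the diff
print(requests.get(f"{BASE_URL}/health").json())
print(requests.get(f"{BASE_URL}/").json())  # now reports "model": REPO_ID only

payload = {
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "max_new_tokens": 64,   # forwarded as max_new_tokens in generation_kwargs
    "temperature": 0.7,
    "top_p": 0.9,
    "do_sample": True,      # now forwarded to the pipeline
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=120)
print(resp.json())  # expected keys: generated_text, model, runtime_seconds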
requirements.txt CHANGED
@@ -1,6 +1,7 @@
 fastapi
 uvicorn
-llama-cpp-python
+transformers
+torch
 pydantic
 python-multipart
 starlette
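
A quick environment check for the new dependency stack, as a sketch: the "0" default for N_GPU_LAYERS is an assumption, since its definition falls outside the hunks shown above.

# Environment check sketch for the transformers/torch stack (not part of the commit).
import os

import torch
import transformers

print("transformers:", transformers.__version__)
print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# app.py keeps the llama-cpp-era N_GPU_LAYERS env var to pick the device;
# the "0" default here is an assumption (its definition is not in the diff).
n_gpu_layers = int(os.getenv("N_GPU_LAYERS", "0"))
print("pipeline device_map would be:", "auto" if n_gpu_layers > 0 else "cpu")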