aledraa committed on
Commit
a7a61ee
·
verified ·
1 Parent(s): 098863d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +186 -99
main.py CHANGED
@@ -1,13 +1,86 @@
1
- from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  import json
7
  import re
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # --- App and Model Loading ---
10
- app = FastAPI()
11
  app.add_middleware(
12
  CORSMiddleware,
13
  allow_origins=['*'],
@@ -16,18 +89,6 @@ app.add_middleware(
16
  allow_headers=['*'],
17
  )
18
 
19
- model_name = "Qwen/Qwen2.5-0.5B-Instruct"
20
- print("Loading model...")
21
-
22
- # Load model directly - simple approach for HF Spaces
23
- tokenizer = AutoTokenizer.from_pretrained(model_name)
24
- model = AutoModelForCausalLM.from_pretrained(model_name)
25
-
26
- # Set pad token if not exists
27
- if tokenizer.pad_token is None:
28
- tokenizer.pad_token = tokenizer.eos_token
29
-
30
- print("Model loaded successfully.")
31
 
32
  # --- API Request and Response Models ---
33
  class GenerationRequest(BaseModel):
@@ -36,129 +97,155 @@ class GenerationRequest(BaseModel):
36
 
37
  class GenerationResponse(BaseModel):
38
  data: list
 
 
 
39
 
40
  # --- Helper Functions ---
41
  def extract_json_from_text(text: str):
42
- """Extract JSON array from model output, handling extra text."""
43
- # Look for JSON array pattern
44
- json_pattern = r'\[\s*\[.*?\]\s*\]'
45
- matches = re.findall(json_pattern, text, re.DOTALL)
46
-
47
- if matches:
48
- try:
49
- return json.loads(matches[0])
50
- except:
51
- pass
 
 
52
 
53
- # Fallback: try to find anything that looks like nested arrays
54
  try:
55
- # Find content between first [ and last ]
56
- start = text.find('[')
57
- end = text.rfind(']') + 1
58
- if start != -1 and end != 0:
59
- json_candidate = text[start:end]
60
- return json.loads(json_candidate)
61
- except:
62
- pass
63
-
64
- return None
 
 
 
 
 
 
 
65
 
66
- def create_optimized_prompt(commands: list[str], batch_size: int) -> str:
67
- """Create a more structured prompt to reduce hallucination."""
68
- return f"""Generate exactly {batch_size} rows of data. Each row has {len(commands)} columns:
69
- {chr(10).join([f'Column {i+1}: {cmd}' for i, cmd in enumerate(commands)])}
 
 
 
 
70
 
71
- Output format: JSON array only, no explanations.
72
- Example: [[value1, value2], [value3, value4]]
73
 
74
- Generate {batch_size} rows:"""
 
 
 
75
 
76
- # --- API Endpoint ---
77
  @app.post("/generate", response_model=GenerationResponse)
78
  async def generate_data(request: GenerationRequest):
 
 
 
 
79
  try:
80
- # Create optimized prompt
81
- prompt = create_optimized_prompt(request.llm_commands, request.batch_size)
82
 
83
  messages = [
84
- {"role": "system", "content": "You are a precise data generator. Output only valid JSON arrays with no extra text."},
85
  {"role": "user", "content": prompt}
86
  ]
87
 
88
- # Apply chat template
89
- text = tokenizer.apply_chat_template(
90
  messages,
91
  tokenize=False,
92
  add_generation_prompt=True
93
  )
94
 
95
- # Tokenize with optimized settings
96
- model_inputs = tokenizer(
97
- text,
98
- return_tensors="pt",
99
- truncation=True,
100
- max_length=2048, # Limit input length
101
- padding=False
102
- ).to(model.device)
103
 
104
- # Generate with optimized parameters
105
- with torch.no_grad(): # Disable gradients for inference
 
 
106
  generated_ids = model.generate(
107
  **model_inputs,
108
- max_new_tokens=min(1024, request.batch_size * 20), # Dynamic max tokens
109
- min_new_tokens=10,
110
  do_sample=True,
111
- temperature=0.7, # Balanced creativity/consistency
112
- top_p=0.9,
113
- top_k=50,
114
- repetition_penalty=1.1,
115
  pad_token_id=tokenizer.pad_token_id,
116
- eos_token_id=tokenizer.eos_token_id,
117
- use_cache=True,
118
- num_beams=1, # Faster than beam search
119
- early_stopping=True
120
  )
121
 
122
- # Extract generated text
123
- generated_ids = [
124
- output_ids[len(input_ids):]
125
- for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
126
- ]
127
-
128
- response_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
129
- print(f"Raw model output: {response_text[:200]}...") # Debug print
130
 
131
- # Extract JSON data
132
  json_data = extract_json_from_text(response_text)
133
 
 
134
  if json_data and isinstance(json_data, list):
135
- # Validate data structure
136
- if len(json_data) > 0 and isinstance(json_data[0], list):
137
- # Ensure we have the right number of columns
138
- expected_cols = len(request.llm_commands)
139
- filtered_data = [
140
- row for row in json_data
141
- if isinstance(row, list) and len(row) == expected_cols
142
- ]
143
-
144
- if filtered_data:
145
- return {"data": filtered_data[:request.batch_size]}
146
-
147
- print(f"Failed to parse JSON. Raw output: {response_text}")
148
- return {"data": []}
149
-
 
150
  except Exception as e:
151
- print(f"Error in generation: {e}")
152
- return {"data": []}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- @app.get("/")
 
155
  def read_root():
156
- return {"status": "ok", "model": model_name}
157
 
158
- @app.get("/health")
159
  def health_check():
160
  return {
161
  "status": "healthy",
162
  "model_loaded": model is not None,
163
- "device": str(model.device) if model else "unknown"
 
164
  }
 
1
+ from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  import json
7
  import re
8
+ import time
9
+ from contextlib import asynccontextmanager
10
+
11
# --- Performance Optimizations & Model Loading ---

# Pick the compute device and matching precision together: a CUDA GPU gets
# half precision (faster matmuls, half the memory); CPU stays in float32.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

print("--- System Info ---")
print(f"Using device: {device}")
print(f"Using dtype: {torch_dtype}")
print("--------------------")

# --- App State and Model Placeholders ---
# Populated by the lifespan handler at startup; None means "not loaded yet".
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = None
model = None
28
# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handle FastAPI startup and shutdown.

    On startup: load the tokenizer, load the model (trying Flash Attention 2
    first when running on CUDA), then attempt torch.compile(). On shutdown:
    drop the global references so the weights can be garbage-collected.
    """
    global tokenizer, model

    print("Loading model and tokenizer...")
    start_time = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Set pad token if it's not already set (needed for generate()'s padding)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = None
    if device == "cuda":
        # 3. Attention Mechanism: Flash Attention 2 gives a large speedup on
        # compatible GPUs. It is CUDA-only, so don't even attempt it on CPU.
        try:
            print("Attempting to load model with Flash Attention 2...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                attn_implementation="flash_attention_2"
            ).to(device)
            print("Successfully loaded model with Flash Attention 2.")
        except (ImportError, RuntimeError, ValueError, OSError) as e:
            # transformers reports a missing/incompatible flash-attn install as
            # ValueError (sometimes OSError), not only ImportError/RuntimeError.
            # Catching only the latter two crashed startup on such machines.
            print(f"Flash Attention 2 not available ({e}), falling back to default attention.")
            model = None

    if model is None:
        # Default attention path (CPU, or GPU without flash-attn).
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
        ).to(device)

    # 4. Model Compilation (PyTorch 2.0+): JIT-compiles the model for faster
    # execution; best-effort, since compile can fail on some backends.
    print("Compiling model with torch.compile()...")
    try:
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        print("Model compiled successfully.")
    except Exception as e:
        print(f"torch.compile() failed: {e}. Running with uncompiled model.")

    end_time = time.time()
    print(f"Model loading and compilation finished in {end_time - start_time:.2f} seconds.")

    yield

    # Clean up resources on shutdown (optional)
    print("Cleaning up and shutting down.")
    model = None
    tokenizer = None
79
+
80
+
81
# --- FastAPI App Initialization ---
# Attach the lifespan handler so the model is loaded at startup and released
# at shutdown, instead of loading at import time.
app = FastAPI(lifespan=lifespan)
83
 
 
 
84
  app.add_middleware(
85
  CORSMiddleware,
86
  allow_origins=['*'],
 
89
  allow_headers=['*'],
90
  )
91
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # --- API Request and Response Models ---
94
  class GenerationRequest(BaseModel):
 
97
 
98
class GenerationResponse(BaseModel):
    """Response payload for the /generate and /test endpoints."""
    data: list  # parsed rows (list of lists), filtered to the expected column count
    raw_output: str  # raw decoded model text before JSON extraction — debugging aid
    duration_s: float  # wall-clock time of the generation request, in seconds
102
+
103
 
104
  # --- Helper Functions ---
105
def extract_json_from_text(text: str):
    """
    Extract a JSON array from the model's raw text output.

    Bounds the candidate between the first '[' and the last ']', attempts a
    straight parse, and falls back to recovering individual rows when the
    array is truncated or malformed (common with LLM output).

    Returns the parsed list, a list of recovered rows, or None when no
    usable JSON array is found.
    """
    # Find the first '[' and the last ']' to bound the JSON content
    start_bracket = text.find('[')
    end_bracket = text.rfind(']')

    # Also reject the degenerate case where ']' precedes '[' (e.g. "] ... [")
    if start_bracket == -1 or end_bracket == -1 or end_bracket < start_bracket:
        return None  # No JSON array found

    json_str = text[start_bracket : end_bracket + 1]

    try:
        # Attempt to parse the primary JSON string
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Fallback for malformed JSON: recover as many complete rows as possible
        print("Warning: Initial JSON parsing failed. Attempting to recover partial data.")
        # Split on the row boundary, tolerating whitespace between rows.
        # (A literal '],[' split missed '], [' and silently merged rows.)
        potential_rows = re.split(r'\]\s*,\s*\[', json_str.strip()[1:-1])
        valid_rows = []
        for row_str in potential_rows:
            try:
                # Reconstruct and parse each potential row
                clean_row_str = row_str.replace('[', '').replace(']', '').strip()
                if clean_row_str:
                    valid_rows.append(json.loads(f'[{clean_row_str}]'))
            except json.JSONDecodeError:
                continue  # Skip malformed rows
        return valid_rows if valid_rows else None
136
+
137
 
138
def create_structured_prompt(commands: list[str], batch_size: int) -> str:
    """Build a strict prompt that pushes the model toward pure-JSON output."""
    # One bullet per requested column, in the order given by the caller.
    cols_description = '\n'.join(
        f'- Column {idx + 1}: {command}' for idx, command in enumerate(commands)
    )
    return f"""
Generate exactly {batch_size} rows of data.
Each inner array must have exactly {len(commands)} columns.

The columns are defined as follows:
{cols_description}

Your entire response must be ONLY the JSON array of arrays, with no additional text, explanations, or markdown.
Example of a valid response:
[["value1", "value2"], ["value3", "value4"]]
"""
154
 
155
# --- API Endpoints ---
@app.post("/generate", response_model=GenerationResponse)
async def generate_data(request: GenerationRequest):
    """
    Generate tabular data rows from the per-column LLM commands.

    Returns the parsed rows (filtered to the expected column count and capped
    at request.batch_size), the raw model text, and the elapsed time.
    Raises 503 while the model is still loading and 500 on generation errors.
    """
    # Model/tokenizer are loaded asynchronously by the lifespan handler.
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again in a moment.")

    start_time = time.time()
    try:
        # Create a more reliable prompt
        prompt = create_structured_prompt(request.llm_commands, request.batch_size)

        messages = [
            {"role": "system", "content": "You are a precise data generation machine. Your sole purpose is to return a valid JSON array of arrays. You will not deviate from this role."},
            {"role": "user", "content": prompt}
        ]

        # Apply the chat template
        text_input = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        model_inputs = tokenizer([text_input], return_tensors="pt").to(device)

        # Generate with no_grad context for better performance
        with torch.no_grad():
            # Dynamically set max_new_tokens based on expected output size with a buffer
            # (~10 tokens per cell plus 50 slack), hard-capped at 4096 below.
            max_new_tokens = int(request.batch_size * len(request.llm_commands) * 10 + 50)
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=min(4096, max_new_tokens),
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the newly generated portion: slice off the prompt tokens
        # by the input sequence length before decoding.
        response_text = tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]

        # Extract and validate JSON data
        json_data = extract_json_from_text(response_text)

        final_data = []
        if json_data and isinstance(json_data, list):
            expected_cols = len(request.llm_commands)
            # Filter for valid rows and cap at the requested batch size
            final_data = [
                row for row in json_data
                if isinstance(row, list) and len(row) == expected_cols
            ][:request.batch_size]
        else:
            print(f"Failed to parse JSON. Raw output: {response_text}")

        end_time = time.time()
        # raw_output is returned alongside the data for client-side debugging.
        return {
            "data": final_data,
            "raw_output": response_text,
            "duration_s": round(end_time - start_time, 2)
        }

    except Exception as e:
        print(f"An error occurred during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
220
+
221
# --- New Test Route ---
@app.get("/test", response_model=GenerationResponse, summary="Run a predefined test generation")
async def test_generation():
    """
    Exercise the generation pipeline with a fixed two-column request.

    Generates 10 sample rows, which makes this endpoint handy for quick
    performance checks and smoke-testing a deployment.
    """
    sample_commands = [
        "a common first name starting with the letter A",
        "an age as an integer between 20 and 30"
    ]
    test_request = GenerationRequest(llm_commands=sample_commands, batch_size=10)
    print("--- Running /test endpoint ---")
    return await generate_data(test_request)
237
+
238
 
239
# --- Health and Status Routes ---
@app.get("/", summary="Root status check")
def read_root():
    """Liveness probe: report service status plus the configured model and device."""
    status_payload = {
        "status": "ok",
        "model_name": model_name,
        "device": device,
    }
    return status_payload
243
 
244
+ @app.get("/health", summary="Health check for the service")
245
  def health_check():
246
  return {
247
  "status": "healthy",
248
  "model_loaded": model is not None,
249
+ "tokenizer_loaded": tokenizer is not None,
250
+ "device": device
251
  }