aledraa committed · Commit f7cc5b0 (verified) · Parent: 211f8ae

Update app.py

Files changed (1)
  1. app.py +58 -166
app.py CHANGED
@@ -1,185 +1,77 @@
- from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
- from transformers import AutoModelForCausalLM, AutoTokenizer
  import torch
- import json
- import random
- import os
- from typing import List, Optional

- app = FastAPI(title="Qwen Data Generator API")

- # Global variables for model and tokenizer
- model = None
- tokenizer = None
  model_name = "Qwen/Qwen2.5-3B-Instruct"

- def load_model():
-     """Load model and tokenizer with proper error handling"""
-     global model, tokenizer
-
-     try:
-         print("Loading model...")
-         print(f"Cache directory: {os.environ.get('HF_HOME', 'Not set')}")
-
-         # Load tokenizer first (smaller download)
-         tokenizer = AutoTokenizer.from_pretrained(
-             model_name,
-             trust_remote_code=True
-         )
-         print("Tokenizer loaded successfully!")
-
-         # Load model with specific configurations for better compatibility
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.float16,  # Use float16 to save memory
-             device_map="auto",
-             trust_remote_code=True,
-             low_cpu_mem_usage=True
-         )
-         print("Model loaded successfully!")
-
-     except Exception as e:
-         print(f"Error loading model: {str(e)}")
-         raise e
-
- # Load model on startup
- load_model()
-
  class GenerationRequest(BaseModel):
-     llm_commands: List[str]
      batch_size: int = 50
-     seed: Optional[int] = None

  class GenerationResponse(BaseModel):
-     success: bool
-     data: List[List[str]]
-     error: Optional[str] = None

- def generate_data_prompt(llm_commands: List[str], batch_size: int) -> str:
-     columns_description = "\n".join([
-         f"Column {i+1}: {cmd}" for i, cmd in enumerate(llm_commands)
-     ])
-
-     return f"""Generate {batch_size} unique random rows of data based on these specifications:
- {columns_description}
-
- Requirements:
- - Each row must be different and realistic
- - Return ONLY a JSON array format: [["value1","value2"],["value1","value2"],...]
- - No additional text, explanations, or formatting
- - Values should be diverse and not repetitive
-
- JSON Array:"""
-
- @app.post("/generate", response_model=GenerationResponse)
- async def generate_data(request: GenerationRequest):
-     global model, tokenizer
-
-     if model is None or tokenizer is None:
-         raise HTTPException(status_code=503, detail="Model not loaded")
-
      try:
-         # Set seed for reproducibility if provided
-         if request.seed:
-             torch.manual_seed(request.seed)
-             random.seed(request.seed)
-
-         # Build prompt
-         prompt = generate_data_prompt(request.llm_commands, request.batch_size)
-
-         # Prepare messages for chat template
-         messages = [
-             {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant that generates structured data."},
-             {"role": "user", "content": prompt}
-         ]
-
-         # Apply chat template
-         text = tokenizer.apply_chat_template(
-             messages,
-             tokenize=False,
-             add_generation_prompt=True
-         )
-
-         # Tokenize and generate
-         model_inputs = tokenizer([text], return_tensors="pt")
-
-         # Move inputs to same device as model
-         if torch.cuda.is_available():
-             model_inputs = model_inputs.to('cuda')
-
-         with torch.no_grad():
-             generated_ids = model.generate(
-                 **model_inputs,
-                 max_new_tokens=2048,
-                 temperature=0.8,
-                 do_sample=True,
-                 pad_token_id=tokenizer.eos_token_id,
-                 eos_token_id=tokenizer.eos_token_id
-             )
-
-         # Decode response
-         generated_ids = [
-             output_ids[len(input_ids):]
-             for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-         ]
-
-         response_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-         # Parse JSON from response
-         try:
-             # Find JSON array in the response
-             start_idx = response_text.find('[')
-             end_idx = response_text.rfind(']') + 1
-
-             if start_idx == -1 or end_idx == 0:
-                 raise ValueError("No JSON array found in response")
-
-             json_str = response_text[start_idx:end_idx]
-             parsed_data = json.loads(json_str)
-
-             # Validate data structure
-             if not isinstance(parsed_data, list):
-                 raise ValueError("Response is not a list")
-
-             # Filter and validate rows
-             valid_rows = []
-             expected_columns = len(request.llm_commands)
-
-             for row in parsed_data:
-                 if isinstance(row, list) and len(row) == expected_columns:
-                     # Convert all values to strings
-                     valid_rows.append([str(cell) for cell in row])
-
-             return GenerationResponse(
-                 success=True,
-                 data=valid_rows
-             )
-
-         except json.JSONDecodeError as e:
-             return GenerationResponse(
-                 success=False,
-                 data=[],
-                 error=f"Failed to parse JSON: {str(e)}"
-             )
-         except Exception as e:
-             return GenerationResponse(
-                 success=False,
-                 data=[],
-                 error=f"Data processing error: {str(e)}"
-             )
-
      except Exception as e:
-         return GenerationResponse(
-             success=False,
-             data=[],
-             error=f"Generation error: {str(e)}"
-         )
-
- @app.get("/health")
- async def health_check():
-     return {"status": "healthy", "model": model_name}

- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=7860)

+ from fastapi import FastAPI
  from pydantic import BaseModel
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import json  # needed to parse the model's JSON output

+ # --- App and Model Loading ---
+ app = FastAPI()

  model_name = "Qwen/Qwen2.5-3B-Instruct"
+ print("Loading model...")
+ # To leverage a GPU on Hugging Face Spaces, device_map="auto" is key
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype="auto",
+     device_map="auto"
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ print("Model loaded successfully.")

+ # --- API Request and Response Models ---
  class GenerationRequest(BaseModel):
+     llm_commands: list[str]
      batch_size: int = 50

  class GenerationResponse(BaseModel):
+     data: list

+ # --- API Endpoint ---
+ @app.post("/generate", response_model=GenerationResponse)
+ async def generate_data(request: GenerationRequest):
+     prompt = f"""
+ You are a data generator. Your task is to generate {request.batch_size} random, non-similar rows of data based on the following commands.
+ Each command corresponds to a column.
+ Commands: {request.llm_commands}
+ Return the data as a valid JSON array of arrays, where each inner array represents a row.
+ For example, for the commands ["an age between 20 and 30", "a random city in California"], the output should look like:
+ [[25, "Los Angeles"], [22, "San Francisco"]]
+ Do not include any extra text, explanations, or markdown formatting in your response. Only output the raw JSON array.
+ """

+     messages = [
+         {"role": "system", "content": "You are a helpful assistant that generates structured data."},
+         {"role": "user", "content": prompt}
+     ]

+     text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )

+     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

+     generated_ids = model.generate(
+         **model_inputs,
+         max_new_tokens=2048  # Increased to handle larger batches
+     )

+     generated_ids = [
+         output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+     ]

+     response_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

      try:
+         # The model may still wrap the array in stray text; strip whitespace and
+         # parse as JSON (json.loads is safer than eval and, unlike torch.tensor,
+         # handles rows that mix numbers and strings)
+         json_response = json.loads(response_text.strip())
+         return {"data": json_response}
      except Exception as e:
+         print(f"Error parsing model output: {e}")
+         print(f"Raw output was: {response_text}")
+         # Return empty on failure to prevent crashing the Inngest job
+         return {"data": []}

+ @app.get("/")
+ def read_root():
+     return {"status": "ok"}
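
A quick way to sanity-check the rewritten endpoint is a small client script. This is a sketch, not part of the commit, and the base URL and port are assumptions: the `if __name__ == "__main__"` uvicorn entrypoint was removed, so the Space's actual run command lives outside app.py.

    import requests  # third-party HTTP client, assumed installed

    # Base URL is an assumption; point it at wherever the Space serves the API.
    resp = requests.post(
        "http://localhost:7860/generate",
        json={
            "llm_commands": ["an age between 20 and 30", "a random city in California"],
            "batch_size": 5,
        },
        timeout=300,  # generation with max_new_tokens=2048 can take a while
    )
    resp.raise_for_status()
    print(resp.json())  # expected shape: {"data": [[25, "Los Angeles"], ...]}

Note that on a parse failure the endpoint still returns HTTP 200 with {"data": []}, so a client should treat an empty list, not a non-200 status, as the failure signal.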