Juna190825 committed
Commit 65487b9 · verified · 1 Parent(s): 4931ec7

Update app.py

Files changed (1):
  1. app.py +118 -65
app.py CHANGED
@@ -2,45 +2,75 @@ import os
 import torch
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
-from fastapi.staticfiles import StaticFiles
 from fastapi.middleware.cors import CORSMiddleware
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers import pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 import time
+import warnings
+
+# Suppress specific warnings
+warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.hub")
+
+# Configure environment variables for cache
+os.environ["HF_HOME"] = os.getenv("HF_HOME", "/app/cache/huggingface")
+os.environ["MPLCONFIGDIR"] = os.getenv("MPLCONFIGDIR", "/app/cache/matplotlib")
 
 # Ensure cache directories exist
-os.makedirs(os.getenv('HUGGINGFACE_HUB_CACHE', '/app/cache/huggingface'), exist_ok=True)
-os.makedirs(os.getenv('MPLCONFIGDIR', '/app/cache/matplotlib'), exist_ok=True)
+os.makedirs(os.environ["HF_HOME"], exist_ok=True)
+os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True)
 
 # Initialize FastAPI app
 app = FastAPI()
 
-# Model loading function without autoawq
+def log_message(message: str):
+    """Helper function for logging"""
+    print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}")
+
 def load_model():
+    """Load the model with CPU optimization"""
     model_name = "trillionlabs/Trillion-7B-preview-AWQ"
 
-    # Load tokenizer with error handling
+    log_message("Loading tokenizer...")
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             trust_remote_code=True
         )
     except Exception as e:
-        print(f"Error loading tokenizer: {e}")
-        # Fallback to a more basic tokenizer if needed
+        log_message(f"Tokenizer loading failed: {e}")
+        # Fallback to LlamaTokenizer if available
        from transformers import LlamaTokenizer
        tokenizer = LlamaTokenizer.from_pretrained(model_name)
 
-    # Load model with CPU configuration
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        device_map="cpu",
-        torch_dtype=torch.float32,
-        trust_remote_code=True
-    )
+    log_message("Loading model...")
+    try:
+        # Try loading with IPEX optimization if available
+        try:
+            import intel_extension_for_pytorch as ipex
+            use_ipex = True
+        except ImportError:
+            use_ipex = False
+            log_message("IPEX not available, using standard CPU version")
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float32,
+            trust_remote_code=True
+        )
+
+        if use_ipex:
+            log_message("Applying IPEX optimization...")
+            model = ipex.optimize(model)
+
+        # Explicitly move to CPU
+        model = model.to("cpu")
+        model.eval()
+
+    except Exception as e:
+        log_message(f"Model loading failed: {e}")
+        raise
 
-    # Create text generation pipeline
+    log_message("Creating pipeline...")
     text_generator = pipeline(
         "text-generation",
         model=model,
@@ -52,21 +82,25 @@ def load_model():
 
 # Load model
 try:
+    log_message("Starting model loading process...")
     text_generator = load_model()
+    log_message("Model loaded successfully")
 except Exception as e:
-    print(f"Failed to load model: {e}")
-    # You might want to exit here or load a smaller model instead
+    log_message(f"Critical error loading model: {e}")
     raise
 
-# API endpoint for text generation
+# API endpoints
 @app.post("/api/generate")
-async def generate_text(request: Request):
+async def api_generate(request: Request):
+    """API endpoint for text generation"""
     try:
         data = await request.json()
-        prompt = data.get("prompt", "")
-        max_length = min(int(data.get("max_length", 100)), 500)  # Limit to 500 tokens
+        prompt = data.get("prompt", "").strip()
+        if not prompt:
+            return JSONResponse({"error": "Prompt cannot be empty"}, status_code=400)
+
+        max_length = min(int(data.get("max_length", 100)), 300)  # Conservative limit
 
-        # Generate text with timing
         start_time = time.time()
         outputs = text_generator(
             prompt,
@@ -75,23 +109,39 @@ async def generate_text(request: Request):
             temperature=0.7,
             top_k=50,
             top_p=0.95,
-            pad_token_id=0  # Might be needed for some models
+            pad_token_id=text_generator.tokenizer.eos_token_id if hasattr(text_generator, 'tokenizer') else 0
         )
         generation_time = time.time() - start_time
 
         return JSONResponse({
             "generated_text": outputs[0]["generated_text"],
-            "generation_time": round(generation_time, 2),
+            "time_seconds": round(generation_time, 2),
+            "tokens_generated": len(text_generator.tokenizer.tokenize(outputs[0]["generated_text"])) if hasattr(text_generator, 'tokenizer') else None,
             "model": "Trillion-7B-preview-AWQ",
             "device": "cpu"
         })
     except Exception as e:
+        log_message(f"API Error: {e}")
         return JSONResponse({"error": str(e)}, status_code=500)
 
-# Gradio interface
+@app.get("/health")
+async def health_check():
+    """Health check endpoint"""
+    return {
+        "status": "healthy",
+        "model_loaded": text_generator is not None,
+        "device": "cpu",
+        "cache_path": os.environ["HF_HOME"]
+    }
+
+# Gradio Interface
 def gradio_generate(prompt, max_length=100):
+    """Function for Gradio interface generation"""
     try:
-        max_length = min(int(max_length), 500)  # Limit to 500 tokens
+        max_length = min(int(max_length), 300)  # Same conservative limit as API
+        if not prompt.strip():
+            return "Please enter a prompt"
+
         outputs = text_generator(
             prompt,
             max_length=max_length,
@@ -99,69 +149,72 @@ def gradio_generate(prompt, max_length=100):
             temperature=0.7,
             top_k=50,
             top_p=0.95,
-            pad_token_id=0
+            pad_token_id=text_generator.tokenizer.eos_token_id if hasattr(text_generator, 'tokenizer') else 0
         )
         return outputs[0]["generated_text"]
     except Exception as e:
+        log_message(f"Gradio Error: {e}")
         return f"Error generating text: {str(e)}"
 
-with gr.Blocks() as gradio_app:
+with gr.Blocks(title="Trillion-7B CPU Demo", theme=gr.themes.Default()) as gradio_app:
     gr.Markdown("""
-    # Trillion-7B-preview-AWQ Demo (CPU)
-    *Running on CPU with 16GB RAM - responses may be slow*
+    # 🚀 Trillion-7B-preview-AWQ (CPU Version)
+    *Running on CPU with optimized settings - responses may be slower than GPU versions*
     """)
 
     with gr.Row():
-        input_prompt = gr.Textbox(
-            label="Input Prompt",
-            lines=5,
-            placeholder="Enter your prompt here..."
-        )
-        output_text = gr.Textbox(
-            label="Generated Text",
-            lines=5,
-            interactive=False
-        )
+        with gr.Column():
+            input_prompt = gr.Textbox(
+                label="Your Prompt",
+                placeholder="Enter text here...",
+                lines=5,
+                max_lines=10
+            )
+            with gr.Row():
+                max_length = gr.Slider(
+                    label="Max Length",
+                    minimum=20,
+                    maximum=300,
+                    value=100,
+                    step=10
+                )
+                generate_btn = gr.Button("Generate", variant="primary")
+        with gr.Column():
+            output_text = gr.Textbox(
+                label="Generated Text",
+                lines=10,
+                interactive=False
+            )
 
-    with gr.Row():
-        length_slider = gr.Slider(
-            minimum=50,
-            maximum=500,
-            value=100,
-            step=10,
-            label="Max Length"
-        )
-        generate_btn = gr.Button("Generate", variant="primary")
-
-    # Additional examples
-    examples = gr.Examples(
+    # Examples
+    gr.Examples(
         examples=[
-            ["Explain quantum computing in simple terms."],
-            ["Write a short poem about artificial intelligence."],
-            ["How do I make a good cup of coffee?"]
+            ["Explain quantum computing in simple terms"],
+            ["Write a haiku about artificial intelligence"],
+            ["What are the main benefits of renewable energy?"],
+            ["Suggest three ideas for a science fiction story"]
         ],
-        inputs=input_prompt
+        inputs=input_prompt,
+        label="Example Prompts"
     )
 
     generate_btn.click(
         fn=gradio_generate,
-        inputs=[input_prompt, length_slider],
+        inputs=[input_prompt, max_length],
         outputs=output_text
     )
 
 # Mount Gradio app
 app = gr.mount_gradio_app(app, gradio_app, path="/")
 
-# CORS middleware
+# CORS configuration
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
-    allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )
 
-# Health check endpoint
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy", "model_loaded": text_generator is not None}
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
 
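For reference, a minimal client sketch for exercising the endpoints this commit adds. It assumes the app is run locally via the new `__main__` block, so the base URL `http://localhost:7860` is an assumption, not something stated in the diff:

import requests

BASE_URL = "http://localhost:7860"  # assumption: default host/port from the __main__ block

# Health check endpoint added in this commit
print(requests.get(f"{BASE_URL}/health").json())
# returns a dict with "status", "model_loaded", "device", and "cache_path" keys

# Text generation endpoint; the server caps max_length at 300
resp = requests.post(
    f"{BASE_URL}/api/generate",
    json={"prompt": "Explain quantum computing in simple terms", "max_length": 150},
)
print(resp.json())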