Juna190825 committed
Commit 4931ec7 · verified · 1 Parent(s): 1d6b569

Update app.py

Files changed (1):
  1. app.py +83 -39

app.py CHANGED
@@ -1,13 +1,12 @@
+import os
+import torch
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
 from fastapi.staticfiles import StaticFiles
 from fastapi.middleware.cors import CORSMiddleware
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from autoawq import AutoAWQForCausalLM # Add this import
 from transformers import pipeline
-import torch
-import os
 import time
 
 # Ensure cache directories exist
@@ -17,26 +16,28 @@ os.makedirs(os.getenv('MPLCONFIGDIR', '/app/cache/matplotlib'), exist_ok=True)
 # Initialize FastAPI app
 app = FastAPI()
 
-# Mount Gradio app
-gradio_app = gr.Blocks()
-
-# Model loading function
+# Model loading function without autoawq
 def load_model():
     model_name = "trillionlabs/Trillion-7B-preview-AWQ"
 
-    # Load tokenizer with special handling
+    # Load tokenizer with error handling
     try:
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-    except:
-        # Fallback to using the model's tokenizer.json directly
-        from transformers import PreTrainedTokenizerFast
-        tokenizer = PreTrainedTokenizerFast(tokenizer_file=f"{model_name}/tokenizer.json")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            trust_remote_code=True
+        )
+    except Exception as e:
+        print(f"Error loading tokenizer: {e}")
+        # Fallback to a more basic tokenizer if needed
+        from transformers import LlamaTokenizer
+        tokenizer = LlamaTokenizer.from_pretrained(model_name)
 
-    # Load model with CPU support
+    # Load model with CPU configuration
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         device_map="cpu",
-        torch_dtype=torch.float32
+        torch_dtype=torch.float32,
+        trust_remote_code=True
     )
 
     # Create text generation pipeline
@@ -49,8 +50,13 @@ def load_model():
 
     return text_generator
 
-# Load model (this will happen when the server starts)
-text_generator = load_model()
+# Load model
+try:
+    text_generator = load_model()
+except Exception as e:
+    print(f"Failed to load model: {e}")
+    # You might want to exit here or load a smaller model instead
+    raise
 
 # API endpoint for text generation
 @app.post("/api/generate")
@@ -58,9 +64,9 @@ async def generate_text(request: Request):
     try:
         data = await request.json()
         prompt = data.get("prompt", "")
-        max_length = data.get("max_length", 100)
+        max_length = min(int(data.get("max_length", 100)), 500)  # Limit to 500 tokens
 
-        # Generate text
+        # Generate text with timing
        start_time = time.time()
         outputs = text_generator(
             prompt,
@@ -68,14 +74,15 @@ async def generate_text(request: Request):
             do_sample=True,
             temperature=0.7,
             top_k=50,
-            top_p=0.95
+            top_p=0.95,
+            pad_token_id=0  # Might be needed for some models
         )
         generation_time = time.time() - start_time
 
         return JSONResponse({
             "generated_text": outputs[0]["generated_text"],
-            "generation_time": generation_time,
-            "model": "trillionlabs/Trillion-7B-preview-AWQ",
+            "generation_time": round(generation_time, 2),
+            "model": "Trillion-7B-preview-AWQ",
             "device": "cpu"
         })
     except Exception as e:
@@ -83,26 +90,58 @@ async def generate_text(request: Request):
 
 # Gradio interface
 def gradio_generate(prompt, max_length=100):
-    outputs = text_generator(
-        prompt,
-        max_length=max_length,
-        do_sample=True,
-        temperature=0.7,
-        top_k=50,
-        top_p=0.95
-    )
-    return outputs[0]["generated_text"]
+    try:
+        max_length = min(int(max_length), 500)  # Limit to 500 tokens
+        outputs = text_generator(
+            prompt,
+            max_length=max_length,
+            do_sample=True,
+            temperature=0.7,
+            top_k=50,
+            top_p=0.95,
+            pad_token_id=0
+        )
+        return outputs[0]["generated_text"]
+    except Exception as e:
+        return f"Error generating text: {str(e)}"
 
-with gradio_app:
-    gr.Markdown("# Trillion-7B-preview-AWQ Demo (CPU)")
-    gr.Markdown("This is a CPU-only demo of the Trillion-7B-preview-AWQ model running with 16GB RAM.")
+with gr.Blocks() as gradio_app:
+    gr.Markdown("""
+    # Trillion-7B-preview-AWQ Demo (CPU)
+    *Running on CPU with 16GB RAM - responses may be slow*
+    """)
+
+    with gr.Row():
+        input_prompt = gr.Textbox(
+            label="Input Prompt",
+            lines=5,
+            placeholder="Enter your prompt here..."
+        )
+        output_text = gr.Textbox(
+            label="Generated Text",
+            lines=5,
+            interactive=False
+        )
 
     with gr.Row():
-        input_prompt = gr.Textbox(label="Input Prompt", lines=5)
-        output_text = gr.Textbox(label="Generated Text", lines=5)
+        length_slider = gr.Slider(
+            minimum=50,
+            maximum=500,
+            value=100,
+            step=10,
+            label="Max Length"
+        )
+        generate_btn = gr.Button("Generate", variant="primary")
 
-    length_slider = gr.Slider(50, 500, value=100, label="Max Length")
-    generate_btn = gr.Button("Generate")
+    # Additional examples
+    examples = gr.Examples(
+        examples=[
+            ["Explain quantum computing in simple terms."],
+            ["Write a short poem about artificial intelligence."],
+            ["How do I make a good cup of coffee?"]
+        ],
+        inputs=input_prompt
+    )
 
     generate_btn.click(
         fn=gradio_generate,
@@ -120,4 +159,9 @@ app.add_middleware(
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
-)
+)
+
+# Health check endpoint
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy", "model_loaded": text_generator is not None}
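The pipeline construction inside load_model() sits in context lines this diff never shows (old lines 43-48 / new lines 44-49). For orientation only, a CPU-bound text-generation pipeline is typically built as in the sketch below; the actual arguments used in app.py are not visible here, so treat this as a hypothetical reconstruction.

# Hypothetical sketch of the elided pipeline construction, not the diff's code.
# model and tokenizer are the objects created earlier in load_model().
from transformers import pipeline

text_generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 selects CPU for transformers pipelines
)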
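None of the hunks show where gradio_app is attached to the FastAPI app; that presumably happens in the unchanged region around old line 110 / new line 148. If the file follows Gradio's standard pattern, the step would look roughly like this (the mount path is an assumption):

# Assumed mounting step, not visible in this diff; gr.mount_gradio_app is
# Gradio's stock helper for serving Blocks under an existing FastAPI app.
app = gr.mount_gradio_app(app, gradio_app, path="/")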
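For manual testing, POST /api/generate accepts a JSON body with prompt and max_length (now clamped to 500 server-side) and returns the completion plus timing metadata. A minimal client sketch, assuming the app is served locally on port 7860 (the host and port never appear in the diff):

# Client sketch for the /api/generate endpoint; host and port are assumptions.
import requests

resp = requests.post(
    "http://localhost:7860/api/generate",
    json={"prompt": "Write a short poem about artificial intelligence.",
          "max_length": 200},
)
data = resp.json()
print(data["generated_text"])   # the completion
print(data["generation_time"])  # seconds, rounded to two decimals
print(data["model"], data["device"])  # Trillion-7B-preview-AWQ cpu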
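The new GET /health endpoint gives orchestrators something cheap to poll, under the same host/port assumption as above:

print(requests.get("http://localhost:7860/health").json())
# -> {"status": "healthy", "model_loaded": true}

Note that because the module-level load_model() call re-raises on failure, the process exits before the server ever starts, so model_loaded will be true whenever this endpoint answers at all.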