Juna190825 committed on
Commit b0e1169 · verified · 1 Parent(s): 65487b9

Update app.py

Files changed (1): app.py (+10, -21)
app.py CHANGED
@@ -1,12 +1,12 @@
 import os
 import torch
+import time
+import warnings
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-import time
-import warnings
 
 # Suppress specific warnings
 warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.hub")
@@ -44,24 +44,12 @@ def load_model():
 
     log_message("Loading model...")
     try:
-        # Try loading with IPEX optimization if available
-        try:
-            import intel_extension_for_pytorch as ipex
-            use_ipex = True
-        except ImportError:
-            use_ipex = False
-            log_message("IPEX not available, using standard CPU version")
-
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=torch.float32,
             trust_remote_code=True
         )
 
-        if use_ipex:
-            log_message("Applying IPEX optimization...")
-            model = ipex.optimize(model)
-
         # Explicitly move to CPU
         model = model.to("cpu")
         model.eval()
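
This hunk removes the optional IPEX fast path wholesale. For reference, the pattern it deletes can be written as a self-contained guard; a minimal sketch, assuming intel_extension_for_pytorch as the optional dependency, not code from this commit:

    # Optional-IPEX guard, sketched standalone: the import only succeeds
    # on hosts where intel_extension_for_pytorch is installed.
    try:
        import intel_extension_for_pytorch as ipex
        HAS_IPEX = True
    except ImportError:
        HAS_IPEX = False

    def maybe_optimize(model):
        # ipex.optimize returns a CPU-optimized module for inference when
        # IPEX is present; otherwise the model passes through unchanged.
        return ipex.optimize(model.eval()) if HAS_IPEX else model

Dropping the branch simplifies load_model() at the cost of the CPU speed-up on hosts that do have IPEX installed.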
@@ -78,12 +66,12 @@ def load_model():
             device="cpu"
         )
 
-        return text_generator
+        return text_generator, tokenizer
 
 # Load model
 try:
     log_message("Starting model loading process...")
-    text_generator = load_model()
+    text_generator, tokenizer = load_model()
     log_message("Model loaded successfully")
 except Exception as e:
     log_message(f"Critical error loading model: {e}")
@@ -109,17 +97,18 @@ async def api_generate(request: Request):
             temperature=0.7,
             top_k=50,
             top_p=0.95,
-            pad_token_id=tokenizer.eos_token_id if hasattr(text_generator, 'tokenizer') else 0
+            pad_token_id=tokenizer.eos_token_id
         )
         generation_time = time.time() - start_time
 
-        return JSONResponse({
+        response_data = {
             "generated_text": outputs[0]["generated_text"],
             "time_seconds": round(generation_time, 2),
-            "tokens_generated": len(text_generator.tokenizer.tokenize(outputs[0]["generated_text"]) if hasattr(text_generator, 'tokenizer') else None,
+            "tokens_generated": len(tokenizer.tokenize(outputs[0]["generated_text"])),
             "model": "Trillion-7B-preview-AWQ",
             "device": "cpu"
-        })
+        }
+        return JSONResponse(response_data)
     except Exception as e:
         log_message(f"API Error: {e}")
         return JSONResponse({"error": str(e)}, status_code=500)
@@ -149,7 +138,7 @@ def gradio_generate(prompt, max_length=100):
             temperature=0.7,
             top_k=50,
             top_p=0.95,
-            pad_token_id=tokenizer.eos_token_id if hasattr(text_generator, 'tokenizer') else 0
+            pad_token_id=tokenizer.eos_token_id
         )
         return outputs[0]["generated_text"]
     except Exception as e:
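
Same pad_token_id fix as in api_generate. Many causal LMs ship without a pad token, so generate() warns and falls back to EOS on every call; passing tokenizer.eos_token_id makes the choice explicit, whereas the old else 0 fallback could have padded with an arbitrary real token id. An equivalent one-time setup, a common convention rather than anything this commit does:

    # Set the pad token once so the per-call kwarg becomes unnecessary.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token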
 
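
With the response path fixed, the endpoint can be exercised with a plain POST. This is a sketch only: the route decorator and request parsing sit outside these hunks, so the path, port, and JSON field names below are assumptions, not read from the diff:

    import requests

    resp = requests.post(
        "http://localhost:7860/api/generate",         # hypothetical route and port
        json={"prompt": "Hello", "max_length": 100},  # hypothetical field names
    )
    print(resp.json())  # expect generated_text, time_seconds, tokens_generated, model, device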