Update app.py
app.py CHANGED
@@ -26,13 +26,25 @@ gcp_credentials = json.loads(GOOGLE_CLOUD_CREDENTIALS)
 storage_client = storage.Client.from_service_account_info(gcp_credentials)
 bucket = storage_client.bucket(GOOGLE_CLOUD_BUCKET)
 
+MODEL_NAMES = {
+    "starcoder": "starcoder2-3b-q2_k.gguf",
+    "gemma_2b_it": "gemma-2-2b-it-q2_k.gguf",
+    "llama_3_2_1b": "Llama-3.2-1B.Q2_K.gguf",
+    "gemma_2b_imat": "gemma-2-2b-iq1_s-imat.gguf",
+    "phi_3_mini": "phi-3-mini-128k-instruct-iq2_xxs-imat.gguf",
+    "qwen2_0_5b": "qwen2-0.5b-iq1_s-imat.gguf",
+    "gemma_9b_it": "gemma-2-9b-it-q2_k.gguf",
+    "gpt2_xl": "gpt2-xl-q2_k.gguf",
+}
+
 class ModelManager:
     def __init__(self):
         self.params = {"n_ctx": 2048, "n_batch": 512, "n_predict": 512, "repeat_penalty": 1.1, "n_threads": 1, "seed": -1, "stop": ["</s>"], "tokens": []}
-        # self.tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf") #Load
+        # self.tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf") # Load from GCS for production
         self.request_queue = Queue()
         self.response_queue = Queue()
-        self.
+        self.models = {} # Dictionary to hold multiple models
+        self.load_models()
         self.start_processing_processes()
 
     def load_model_from_bucket(self, bucket_path):
@@ -44,6 +56,12 @@ class ModelManager:
             print(f"Error loading model: {e}")
             return None
 
+    def load_models(self):
+        for name, path in MODEL_NAMES.items():
+            model = self.load_model_from_bucket(path)
+            if model:
+                self.models[name] = model
+
     def save_model_to_bucket(self, model, bucket_path):
         blob = bucket.blob(bucket_path)
         try:
@@ -72,14 +90,15 @@ class ModelManager:
             print(f"Error during training: {e}")
 
 
-    def generate_text(self, prompt):
-        if self.
+    def generate_text(self, prompt, model_name):
+        if model_name in self.models:
+            model = self.models[model_name]
             inputs = self.tokenizer(prompt, return_tensors="pt")
-            outputs =
+            outputs = model.generate(**inputs, max_new_tokens=100)
             generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
             return generated_text
         else:
-            return "Error
+            return "Error: Model not found."
 
     def start_processing_processes(self):
         p = Process(target=self.process_requests)
@@ -90,9 +109,9 @@ class ModelManager:
             request_data = self.request_queue.get()
             if request_data is None:
                 break
-            inputs, top_p, top_k, temperature, max_tokens = request_data
+            inputs, model_name, top_p, top_k, temperature, max_tokens = request_data
             try:
-                response = self.generate_text(inputs)
+                response = self.generate_text(inputs, model_name)
                 self.response_queue.put(response)
             except Exception as e:
                 print(f"Error during inference: {e}")
@@ -102,30 +121,34 @@ model_manager = ModelManager()
 
 class ChatRequest(BaseModel):
     message: str
+    model_name: str
 
 @spaces.GPU()
-async def generate_streaming_response(inputs):
+async def generate_streaming_response(inputs, model_name):
     top_p = 0.9
     top_k = 50
     temperature = 0.7
     max_tokens = model_manager.params["n_ctx"] - len(model_manager.tokenizer.encode(inputs))
-    model_manager.request_queue.put((inputs, top_p, top_k, temperature, max_tokens))
+    model_manager.request_queue.put((inputs, model_name, top_p, top_k, temperature, max_tokens))
     full_text = model_manager.response_queue.get()
     async def stream_response():
         yield full_text
     return StreamingResponse(stream_response())
 
-async def process_message(message):
+async def process_message(message, model_name):
     inputs = message.strip()
-    return await generate_streaming_response(inputs)
+    return await generate_streaming_response(inputs, model_name)
 
 @app.post("/generate_multimodel")
 async def api_generate_multimodel(request: Request):
     data = await request.json()
     message = data["message"]
-
+    model_name = data.get("model_name", list(MODEL_NAMES.keys())[0])
+    if model_name not in MODEL_NAMES:
+        return {"error": "Invalid model name"}
+    return await process_message(message, model_name)
 
-iface = gr.Interface(fn=process_message, inputs=gr.Textbox(lines=2, placeholder="Enter your message here..."), outputs=gr.Markdown(stream=True), title="Unified Multi-Model API", description="Enter a message to get responses from a unified model.") #gradio is not suitable for production
+iface = gr.Interface(fn=process_message, inputs=[gr.Textbox(lines=2, placeholder="Enter your message here..."), gr.Dropdown(list(MODEL_NAMES.keys()), label="Select Model")], outputs=gr.Markdown(stream=True), title="Unified Multi-Model API", description="Enter a message to get responses from a unified model.") #gradio is not suitable for production
 
 if __name__ == "__main__":
     iface.launch()
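
The hunks call `load_model_from_bucket` but never show its body. Below is a minimal sketch of what such a loader could look like, assuming the GGUF files listed in MODEL_NAMES live in the same `bucket` and are loaded with llama-cpp-python; the function name, the temp-file handling, and the `Llama` constructor arguments are illustrative assumptions, not part of this commit.

```python
# Hypothetical sketch of a GGUF loader for the blobs listed in MODEL_NAMES.
# Assumes google-cloud-storage and llama-cpp-python; not part of this commit.
import os
import tempfile

from llama_cpp import Llama


def load_model_from_bucket_sketch(bucket, bucket_path, params):
    """Download a GGUF blob to a temp file and load it with llama.cpp."""
    blob = bucket.blob(bucket_path)
    try:
        # Keep the .gguf suffix so the file is recognisable as a GGUF model.
        fd, local_path = tempfile.mkstemp(suffix=".gguf")
        os.close(fd)
        blob.download_to_filename(local_path)
        return Llama(
            model_path=local_path,
            n_ctx=params["n_ctx"],
            n_batch=params["n_batch"],
            n_threads=params["n_threads"],
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
```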
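For reference, a client call against the updated `/generate_multimodel` endpoint might look like the following; the base URL is a placeholder for wherever the Space's FastAPI app is exposed, and `gemma_2b_it` is simply one of the keys from MODEL_NAMES.

```python
# Hypothetical client call; the base URL is a placeholder for the deployed app.
import requests

resp = requests.post(
    "http://localhost:7860/generate_multimodel",  # assumed host/port
    json={"message": "Write a haiku about GPUs.", "model_name": "gemma_2b_it"},
)
print(resp.text)  # streamed response body, or {"error": ...} for an unknown model_name
```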