Update app.py
Browse files
app.py
CHANGED
@@ -36,6 +36,7 @@ def load_quantized_model(model_id, model_basename):
|
|
36 |
model, tokenizer = load_quantized_model(model_name_or_path, "model.safetensors")
|
37 |
|
38 |
|
|
|
39 |
def load_model_norm():
|
40 |
if torch.cuda.is_available():
|
41 |
print("CUDA is available. GPU will be used.")
|
@@ -51,6 +52,7 @@ def load_model_norm():
|
|
51 |
|
52 |
return model, tokenizer
|
53 |
|
|
|
54 |
# Function to generate a response using the model
|
55 |
def generate_response(prompt: str) -> str:
|
56 |
PERSONA_NAME = "Ivana"
|
@@ -94,6 +96,7 @@ async def api_home():
|
|
94 |
return {'detail': 'Welcome to Eren Bot!'}
|
95 |
|
96 |
|
|
|
97 |
# Endpoint to start a new conversation thread
|
98 |
@app.post('/api/start_conversation')
|
99 |
async def start_conversation(request: Request):
|
@@ -110,6 +113,7 @@ async def start_conversation(request: Request):
|
|
110 |
return {'thread_id': thread_id, 'response': response}
|
111 |
|
112 |
|
|
|
113 |
# Endpoint to get the response of a conversation thread
|
114 |
@app.get('/api/get_response/{thread_id}')
|
115 |
async def get_response(thread_id: int):
|
@@ -124,9 +128,11 @@ async def get_response(thread_id: int):
|
|
124 |
|
125 |
return {'response': response}
|
126 |
|
|
|
127 |
@app.post('/api/chat')
|
128 |
async def chat(request: Request, chat_input: ChatInput):
|
129 |
-
|
|
|
130 |
|
131 |
# Generate a response based on the prompt
|
132 |
response = generate_response(prompt)
|
|
|
36 |
model, tokenizer = load_quantized_model(model_name_or_path, "model.safetensors")
|
37 |
|
38 |
|
39 |
+
|
40 |
def load_model_norm():
|
41 |
if torch.cuda.is_available():
|
42 |
print("CUDA is available. GPU will be used.")
|
|
|
52 |
|
53 |
return model, tokenizer
|
54 |
|
55 |
+
|
56 |
# Function to generate a response using the model
|
57 |
def generate_response(prompt: str) -> str:
|
58 |
PERSONA_NAME = "Ivana"
|
|
|
96 |
return {'detail': 'Welcome to Eren Bot!'}
|
97 |
|
98 |
|
99 |
+
|
100 |
# Endpoint to start a new conversation thread
|
101 |
@app.post('/api/start_conversation')
|
102 |
async def start_conversation(request: Request):
|
|
|
113 |
return {'thread_id': thread_id, 'response': response}
|
114 |
|
115 |
|
116 |
+
|
117 |
# Endpoint to get the response of a conversation thread
|
118 |
@app.get('/api/get_response/{thread_id}')
|
119 |
async def get_response(thread_id: int):
|
|
|
128 |
|
129 |
return {'response': response}
|
130 |
|
131 |
+
|
132 |
@app.post('/api/chat')
|
133 |
async def chat(request: Request, chat_input: ChatInput):
|
134 |
+
data = await request.json()
|
135 |
+
prompt = data.get('prompt')
|
136 |
|
137 |
# Generate a response based on the prompt
|
138 |
response = generate_response(prompt)
|