Vitrous committed on
Commit
7a208d9
·
verified ·
1 Parent(s): ec357c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -36,6 +36,7 @@ def load_quantized_model(model_id, model_basename):
36
  model, tokenizer = load_quantized_model(model_name_or_path, "model.safetensors")
37
 
38
 
 
39
  def load_model_norm():
40
  if torch.cuda.is_available():
41
  print("CUDA is available. GPU will be used.")
@@ -51,6 +52,7 @@ def load_model_norm():
51
 
52
  return model, tokenizer
53
 
 
54
  # Function to generate a response using the model
55
  def generate_response(prompt: str) -> str:
56
  PERSONA_NAME = "Ivana"
@@ -94,6 +96,7 @@ async def api_home():
94
  return {'detail': 'Welcome to Eren Bot!'}
95
 
96
 
 
97
  # Endpoint to start a new conversation thread
98
  @app.post('/api/start_conversation')
99
  async def start_conversation(request: Request):
@@ -110,6 +113,7 @@ async def start_conversation(request: Request):
110
  return {'thread_id': thread_id, 'response': response}
111
 
112
 
 
113
  # Endpoint to get the response of a conversation thread
114
  @app.get('/api/get_response/{thread_id}')
115
  async def get_response(thread_id: int):
@@ -124,9 +128,11 @@ async def get_response(thread_id: int):
124
 
125
  return {'response': response}
126
 
 
127
  @app.post('/api/chat')
128
  async def chat(request: Request, chat_input: ChatInput):
129
- prompt = chat_input.prompt
 
130
 
131
  # Generate a response based on the prompt
132
  response = generate_response(prompt)
 
36
  model, tokenizer = load_quantized_model(model_name_or_path, "model.safetensors")
37
 
38
 
39
+
40
  def load_model_norm():
41
  if torch.cuda.is_available():
42
  print("CUDA is available. GPU will be used.")
 
52
 
53
  return model, tokenizer
54
 
55
+
56
  # Function to generate a response using the model
57
  def generate_response(prompt: str) -> str:
58
  PERSONA_NAME = "Ivana"
 
96
  return {'detail': 'Welcome to Eren Bot!'}
97
 
98
 
99
+
100
  # Endpoint to start a new conversation thread
101
  @app.post('/api/start_conversation')
102
  async def start_conversation(request: Request):
 
113
  return {'thread_id': thread_id, 'response': response}
114
 
115
 
116
+
117
  # Endpoint to get the response of a conversation thread
118
  @app.get('/api/get_response/{thread_id}')
119
  async def get_response(thread_id: int):
 
128
 
129
  return {'response': response}
130
 
131
+
132
  @app.post('/api/chat')
133
  async def chat(request: Request, chat_input: ChatInput):
134
+ data = await request.json()
135
+ prompt = data.get('prompt')
136
 
137
  # Generate a response based on the prompt
138
  response = generate_response(prompt)