Vitrous commited on
Commit
67c4e45
·
verified ·
1 Parent(s): cac49fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -8
app.py CHANGED
@@ -39,8 +39,7 @@ model, tokenizer = load_model_norm()
39
  #Now we can init the FlaskApi
40
  app = FastAPI(root_path="/api/v1")
41
 
42
- # Function to generate a response using the model
43
-
44
  def generate_response(prompt: str) -> str:
45
  # Define the user prompt
46
  user_prompt = f'USER: {prompt}'
@@ -68,14 +67,49 @@ def generate_response(prompt: str) -> str:
68
 
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
 
 
73
  @app.get("/", tags=["Home"])
74
  async def api_home():
75
  return {'detail': 'Welcome to Eren Bot!'}
76
 
77
 
78
  # Endpoint to start a new conversation thread
 
 
79
  @app.post('/start_conversation/')
80
  async def start_conversation(request: Request):
81
  try:
@@ -86,23 +120,54 @@ async def start_conversation(request: Request):
86
  if not prompt:
87
  raise HTTPException(status_code=400, detail="No prompt provided")
88
 
89
- # Check if conversations dictionary is empty
90
- # if not conversations:
91
- # raise HTTPException(status_code=404, detail="No chat history available")
92
-
93
  # Generate a response for the initial prompt
94
  response = generate_response(prompt)
95
 
 
 
 
96
  # Create a new conversation thread and store the prompt and response
97
- ##conversations[thread_id] = {'prompt': prompt, 'responses': [response]}
98
- #return {'thread_id': thread_id, 'response': response}
99
  return {'response': response}
100
  except HTTPException:
101
  raise # Re-raise HTTPException to return it directly
102
  except Exception as e:
103
  raise HTTPException(status_code=500, detail=str(e))
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
 
106
  @app.get('/get_response/{thread_id}')
107
  async def get_response(thread_id: int):
108
  if thread_id not in conversations:
@@ -119,6 +184,7 @@ async def get_response(thread_id: int):
119
 
120
 
121
 
 
122
  @app.post('/chat/')
123
  async def chat(request: Request):
124
  data = await request.json()
 
39
  #Now we can init the FlaskApi
40
  app = FastAPI(root_path="/api/v1")
41
 
42
+ #Generates a response from the model
 
43
  def generate_response(prompt: str) -> str:
44
  # Define the user prompt
45
  user_prompt = f'USER: {prompt}'
 
67
 
68
 
69
 
70
+ def generate_response(persona_prompt: str, prompt: str) -> dict:
71
+ try:
72
+ # Validate inputs
73
+ if not persona_prompt or not prompt:
74
+ raise ValueError("Contextual prompt template and prompt cannot be empty.")
75
+
76
+ # Define the user prompt
77
+ user_prompt = f'USER: {prompt}'
78
+
79
+ # Generate the response
80
+ pipe = pipeline(
81
+ "text-generation",
82
+ model=model,
83
+ tokenizer=tokenizer,
84
+ max_new_tokens=512,
85
+ do_sample=True,
86
+ temperature=0.7,
87
+ top_p=0.95,
88
+ top_k=40,
89
+ repetition_penalty=1.1
90
+ )
91
+ generated_text = pipe(persona_prompt + user_prompt)[0]['generated_text']
92
+
93
+ # Extract only the assistant's response from the generated text
94
+ assistant_response = generated_text.split(user_prompt)[-1].strip()
95
+
96
+ return {"user": prompt,"assistant": assistant_response}
97
+
98
+ except Exception as e:
99
+ # Handle any exceptions and return an error message
100
+ return {"error": str(e)}
101
+
102
 
103
 
104
+ #This is the Root directory of the FastApi application
105
  @app.get("/", tags=["Home"])
106
  async def api_home():
107
  return {'detail': 'Welcome to Eren Bot!'}
108
 
109
 
110
  # Endpoint to start a new conversation thread
111
+
112
+ # Waits for the User to start a conversation and replies based on persona of the model
113
  @app.post('/start_conversation/')
114
  async def start_conversation(request: Request):
115
  try:
 
120
  if not prompt:
121
  raise HTTPException(status_code=400, detail="No prompt provided")
122
 
 
 
 
 
123
  # Generate a response for the initial prompt
124
  response = generate_response(prompt)
125
 
126
+ # Generate a unique thread ID
127
+ thread_id = len(conversations) + 1
128
+
129
  # Create a new conversation thread and store the prompt and response
130
+ conversations[thread_id] = {'prompt': prompt, 'responses': [response]}
131
+
132
  return {'response': response}
133
  except HTTPException:
134
  raise # Re-raise HTTPException to return it directly
135
  except Exception as e:
136
  raise HTTPException(status_code=500, detail=str(e))
137
 
138
+ # Endpoint to start a new chat thread
139
+
140
+ # Starts a new chat thread and expects the prompt and the persona_prompt from the user
141
+ @app.post('/start_chat/')
142
+ async def start_chat(request: Request):
143
+ try:
144
+ # Read JSON data from request body
145
+ data = await request.json()
146
+ prompt = data.get('prompt')
147
+ persona_prompt = data.get('persona_prompt')
148
+
149
+ if not prompt or not persona_prompt:
150
+ raise HTTPException(status_code=400, detail="Both prompt and contextual_prompt are required")
151
+
152
+ # Generate a response for the initial prompt
153
+ response = generate_response(persona_prompt, prompt)
154
+
155
+ # Generate a unique thread ID
156
+ thread_id = len(conversations) + 1
157
+
158
+ # Create a new conversation thread and store the prompt and response
159
+ conversations[thread_id] = {'prompt': prompt, 'responses': [response]}
160
+
161
+ # Return the thread ID and response
162
+ return {'thread_id': thread_id, 'response': response}
163
+ except HTTPException:
164
+ raise # Re-raise HTTPException to return it directly
165
+ except Exception as e:
166
+ raise HTTPException(status_code=500, detail=str(e))
167
+
168
+
169
 
170
+ # Gets the response from the model and user given a specific thread id of the conversation
171
  @app.get('/get_response/{thread_id}')
172
  async def get_response(thread_id: int):
173
  if thread_id not in conversations:
 
184
 
185
 
186
 
187
+
188
  @app.post('/chat/')
189
  async def chat(request: Request):
190
  data = await request.json()