sarmadsiddiqui29 commited on
Commit
c99b386
·
verified ·
1 Parent(s): 98a1f99

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +379 -0
main.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
+ from pymongo import MongoClient
3
+ from pydantic import BaseModel
4
+ from passlib.context import CryptContext
5
+ from bson import ObjectId
6
+ from datetime import datetime, timedelta
7
+ import jwt
8
+ from collections import Counter
9
+ from fastapi.responses import JSONResponse
10
+ app = FastAPI()
11
+
12
+ # MongoDB connection
13
+ client = MongoClient(
14
+ "mongodb+srv://sarmadsiddiqui29:Rollno169@cluster0.uchmc.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0",
15
+ tls=True,
16
+ tlsAllowInvalidCertificates=True # For testing only, disable for production
17
+ )
18
+ db = client["annotations_db"]
19
+
20
+ # Password hashing context
21
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
22
+
23
+ # Secret key for JWT
24
+ SECRET_KEY = "your_secret_key" # Replace with a secure secret key
25
+ ALGORITHM = "HS256"
26
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30 # Token expiration time
27
+
28
+ # In-memory variable to store the token
29
+ current_token = None
30
+
31
+ # MongoDB Collections
32
+ users_collection = db["users"]
33
+ stories_collection = db["stories"]
34
+ prompts_collection = db["prompts"]
35
+ summaries_collection = db["summaries"]
36
+
37
+
38
+ # Models
39
+ class User(BaseModel):
40
+ email: str
41
+ password: str
42
+
43
+
44
+ class Story(BaseModel):
45
+ story_id: str
46
+ story: str
47
+ # annotator_id is removed from the Story model
48
+
49
+
50
+ class Prompt(BaseModel):
51
+ story_id: str
52
+ prompt: str
53
+ annotator_id: int = None # Will be set automatically
54
+
55
+
56
+ class Summary(BaseModel):
57
+ story_id: str
58
+ summary: str
59
+ annotator_id: int =None # Add annotator_id to Summary model
60
+
61
+
62
+ # Serialize document function
63
+ def serialize_document(doc):
64
+ """Convert a MongoDB document into a serializable dictionary."""
65
+ if isinstance(doc, ObjectId):
66
+ return str(doc)
67
+ if isinstance(doc, dict):
68
+ return {k: serialize_document(v) for k, v in doc.items()}
69
+ if isinstance(doc, list):
70
+ return [serialize_document(i) for i in doc]
71
+ return doc
72
+
73
+
74
+ # Helper Functions
75
+ def hash_password(password: str) -> str:
76
+ return pwd_context.hash(password)
77
+
78
+
79
+ def verify_password(plain_password: str, hashed_password: str) -> bool:
80
+ return pwd_context.verify(plain_password, hashed_password)
81
+
82
+
83
+ def create_access_token(data: dict, expires_delta: timedelta = None):
84
+ to_encode = data.copy()
85
+ if expires_delta:
86
+ expire = datetime.utcnow() + expires_delta
87
+ else:
88
+ expire = datetime.utcnow() + timedelta(minutes=15)
89
+ to_encode.update({"exp": expire})
90
+ return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
91
+
92
+
93
+ def get_annotator_id() -> int:
94
+ if current_token is None:
95
+ raise HTTPException(status_code=401, detail="User not logged in")
96
+
97
+ try:
98
+ payload = jwt.decode(current_token, SECRET_KEY, algorithms=[ALGORITHM])
99
+ return payload["annotator_id"]
100
+ except jwt.PyJWTError:
101
+ raise HTTPException(status_code=401, detail="Invalid token")
102
+
103
+
104
+ # Endpoints for user, story, prompt, and summary operations
105
+
106
+ # Register User
107
+ @app.post("/register")
108
+ async def register_user(user: User):
109
+ if db.users.find_one({"email": user.email}):
110
+ raise HTTPException(status_code=400, detail="Email already registered")
111
+
112
+ user_data = {
113
+ "email": user.email,
114
+ "password": hash_password(user.password),
115
+ "annotator_id": db.users.count_documents({}) + 1
116
+ }
117
+ db.users.insert_one(user_data)
118
+
119
+ return {"message": "User registered successfully", "annotator_id": user_data["annotator_id"]}
120
+
121
+
122
+ # Login User
123
+ @app.post("/login")
124
+ async def login_user(user: User):
125
+ found_user = db.users.find_one({"email": user.email})
126
+ if not found_user or not verify_password(user.password, found_user["password"]):
127
+ raise HTTPException(status_code=400, detail="Invalid email or password")
128
+
129
+ # Create access token and store it
130
+ global current_token
131
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
132
+ current_token = create_access_token(data={"email": found_user["email"], "annotator_id": found_user["annotator_id"]},
133
+ expires_delta=access_token_expires)
134
+
135
+ return {"access_token": current_token, "token_type": "bearer"}
136
+
137
+
138
+ # Add Story
139
+ @app.post("/story")
140
+ async def add_story(story: Story):
141
+ # annotator_id is not needed when adding a story
142
+ if db.stories.find_one({"story_id": story.story_id}):
143
+ raise HTTPException(status_code=400, detail="Story already exists")
144
+
145
+ db.stories.insert_one(story.dict())
146
+ return {"message": "Story added successfully"}
147
+
148
+
149
+ # Add Prompt
150
+ @app.post("/prompt")
151
+ async def add_prompt(prompt: Prompt):
152
+ annotator_id = get_annotator_id() # Automatically get the annotator ID
153
+ prompt.annotator_id = annotator_id # Assign annotator ID to the prompt
154
+
155
+ db.prompts.insert_one(prompt.dict())
156
+ return {"message": "Prompt added successfully"}
157
+
158
+
159
+ # Add Summary
160
+ @app.post("/summary")
161
+ async def add_summary(summary: Summary):
162
+ annotator_id = get_annotator_id() # Automatically get the annotator ID
163
+ summary.annotator_id = annotator_id # Assign annotator ID to the summary
164
+
165
+ db.summaries.insert_one(summary.dict())
166
+ return {"message": "Summary added successfully"}
167
+
168
+
169
+ # Delete All Users
170
+ @app.delete("/users")
171
+ async def delete_all_users():
172
+ result = db.users.delete_many({})
173
+ return {"message": f"{result.deleted_count} users deleted"}
174
+
175
+
176
+ # Delete All Stories
177
+ @app.delete("/stories")
178
+ async def delete_all_stories():
179
+ result = db.stories.delete_many({})
180
+ return {"message": f"{result.deleted_count} stories deleted"}
181
+
182
+
183
+ # Delete All Prompts
184
+ @app.delete("/prompts")
185
+ async def delete_all_prompts():
186
+ result = db.prompts.delete_many({})
187
+ return {"message": f"{result.deleted_count} prompts deleted"}
188
+
189
+
190
+ # Delete All Summaries
191
+ @app.delete("/summaries")
192
+ async def delete_all_summaries():
193
+ result = db.summaries.delete_many({})
194
+ return {"message": f"{result.deleted_count} summaries deleted"}
195
+
196
+
197
+ # Test MongoDB Connection
198
+ @app.get("/test")
199
+ async def test_connection():
200
+ try:
201
+ db.list_collection_names()
202
+ return {"message": "Connected to MongoDB successfully"}
203
+ except Exception as e:
204
+ raise HTTPException(status_code=500, detail=str(e))
205
+
206
+
207
+ # Display Story by ID
208
+ @app.get("/story/{story_id}")
209
+ async def display_story(story_id: str):
210
+ story = db.stories.find_one({"story_id": story_id})
211
+ if story:
212
+ return serialize_document(story) # Serialize the story document
213
+ raise HTTPException(status_code=404, detail="Story not found")
214
+
215
+
216
+ # Display All for a Given Annotator ID
217
+ from fastapi import Query
218
+
219
+ from fastapi import Query, HTTPException
220
+
221
+ @app.get("/display_all")
222
+ async def display_all(story_id: str = Query(...)):
223
+ annotator_id = get_annotator_id() # Automatically get the annotator ID from the token
224
+
225
+ # Fetch the specific prompt associated with the provided story_id for the current annotator
226
+ prompt = db.prompts.find_one({"story_id": story_id, "annotator_id": annotator_id})
227
+ if not prompt:
228
+ raise HTTPException(status_code=404, detail="Prompt not found for this annotator and story ID")
229
+
230
+ # Fetch the corresponding story
231
+ story = db.stories.find_one({"story_id": story_id}) or {"story": ""}
232
+
233
+ # Fetch the summary for the specific annotator
234
+ summary = db.summaries.find_one({"story_id": story_id, "annotator_id": annotator_id}) or {"summary": ""}
235
+
236
+ # Prepare the result
237
+ result = {
238
+ "story_id": story_id,
239
+ "story": story["story"], # Get the story text
240
+ "annotator_id": prompt["annotator_id"],
241
+ "summary": summary.get("summary", ""), # Use empty string if summary not found
242
+ "prompt": prompt.get("prompt", "") # Use empty string if prompt not found
243
+ }
244
+
245
+ return serialize_document(result) # Serialize the story document
246
+
247
+
248
+
249
+
250
+ @app.delete("/prompt/{story_id}")
251
+ async def delete_prompt(story_id: str):
252
+ annotator_id = get_annotator_id() # Automatically get the annotator ID from the token
253
+
254
+ # Find and delete all prompts associated with the provided story_id for the current annotator
255
+ result = db.prompts.delete_many({"story_id": story_id, "annotator_id": annotator_id})
256
+
257
+ if result.deleted_count > 0:
258
+ return {"message": f"{result.deleted_count} prompt(s) deleted successfully"}
259
+ else:
260
+ raise HTTPException(status_code=404, detail="No prompts found for this annotator and story ID")
261
+ @app.delete("/summary/{story_id}")
262
+ async def delete_summary(story_id: str):
263
+ annotator_id = get_annotator_id() # Automatically get the annotator ID from the token
264
+
265
+ # Find and delete all summaries associated with the provided story_id for the current annotator
266
+ result = db.summaries.delete_many({"story_id": story_id, "annotator_id": annotator_id})
267
+
268
+ if result.deleted_count > 0:
269
+ return {"message": f"{result.deleted_count} summary(ies) deleted successfully"}
270
+ else:
271
+ raise HTTPException(status_code=404, detail="No summaries found for this annotator and story ID")
272
+ @app.delete("/story/{story_id}")
273
+ async def delete_story(story_id: str):
274
+ annotator_id = get_annotator_id() # Automatically get the annotator ID from the token
275
+
276
+ # Find and delete the story associated with the provided story_id for the current annotator
277
+ story_result = db.stories.delete_one({"story_id": story_id})
278
+
279
+ # Delete all prompts associated with the provided story_id for the current annotator
280
+ prompts_result = db.prompts.delete_many({"story_id": story_id, "annotator_id": annotator_id})
281
+
282
+ # Delete all summaries associated with the provided story_id for the current annotator
283
+ summaries_result = db.summaries.delete_many({"story_id": story_id, "annotator_id": annotator_id})
284
+
285
+ if story_result.deleted_count > 0:
286
+ return {
287
+ "message": f"Story deleted successfully",
288
+ "deleted_prompts": prompts_result.deleted_count,
289
+ "deleted_summaries": summaries_result.deleted_count,
290
+ }
291
+ else:
292
+ raise HTTPException(status_code=404, detail="Story not found for this annotator")
293
+ @app.put("/story/{story_id}")
294
+ async def update_story(story_id: str, updated_story: Story):
295
+ annotator_id = get_annotator_id() # Automatically get the annotator ID from the token
296
+
297
+ # Check if the story exists and belongs to the current annotator
298
+ existing_story = db.stories.find_one({"story_id": story_id, "annotator_id": annotator_id})
299
+ if not existing_story:
300
+ raise HTTPException(status_code=404, detail="Story not found or does not belong to this annotator")
301
+
302
+ # Update the story
303
+ db.stories.update_one({"story_id": story_id}, {"$set": {"story": updated_story.story}})
304
+ return {"message": "Story updated successfully"}
305
+ @app.put("/prompt/{story_id}")
306
+ async def update_prompt(story_id: str, updated_prompt: Prompt):
307
+ annotator_id = get_annotator_id() # Automatically get the annotator ID from the token
308
+
309
+ # Check if the prompt exists and belongs to the current annotator
310
+ existing_prompt = db.prompts.find_one({"story_id": story_id, "annotator_id": annotator_id})
311
+ if not existing_prompt:
312
+ raise HTTPException(status_code=404, detail="Prompt not found or does not belong to this annotator")
313
+
314
+ # Update the prompt
315
+ db.prompts.update_one({"story_id": story_id, "annotator_id": annotator_id}, {"$set": {"prompt": updated_prompt.prompt}})
316
+ return {"message": "Prompt updated successfully"}
317
+ @app.put("/summary/{story_id}")
318
+ async def update_summary(story_id: str, updated_summary: Summary):
319
+ annotator_id = get_annotator_id() # Automatically get the annotator ID from the token
320
+
321
+ # Check if the summary exists and belongs to the current annotator
322
+ existing_summary = db.summaries.find_one({"story_id": story_id, "annotator_id": annotator_id})
323
+ if not existing_summary:
324
+ raise HTTPException(status_code=404, detail="Summary not found or does not belong to this annotator")
325
+
326
+ # Update the summary
327
+ db.summaries.update_one({"story_id": story_id, "annotator_id": annotator_id}, {"$set": {"summary": updated_summary.summary}})
328
+ return {"message": "Summary updated successfully"}
329
+
330
+
331
+ @app.get("/prompt/{story_id}")
332
+ async def get_prompt(story_id: str):
333
+ annotator_id = get_annotator_id() # Automatically get the annotator ID from the token
334
+
335
+ # Retrieve the prompt associated with the story_id for the current annotator
336
+ prompt = db.prompts.find_one({"story_id": story_id, "annotator_id": annotator_id})
337
+
338
+ if prompt:
339
+ return {"story_id": story_id, "prompt": prompt.get("prompt", "")} # Return prompt or empty string
340
+ else:
341
+ return {"story_id": story_id, "prompt": ""} # Return empty if no prompt found
342
+
343
+
344
+ @app.get("/summary/{story_id}")
345
+ async def get_summary(story_id: str):
346
+ annotator_id = get_annotator_id() # Automatically get the annotator ID from the token
347
+
348
+ # Retrieve the summary associated with the story_id for the current annotator
349
+ summary = db.summaries.find_one({"story_id": story_id, "annotator_id": annotator_id})
350
+
351
+ if summary:
352
+ return {"story_id": story_id, "summary": summary.get("summary", "")} # Return summary or empty string
353
+ else:
354
+ return {"story_id": story_id, "summary": ""} # Return empty if no summary found
355
+
356
+
357
+ @app.get("/story/{story_id}")
358
+ async def get_story(story_id: str):
359
+ # Retrieve the story associated with the story_id
360
+ story = db.stories.find_one({"story_id": story_id})
361
+
362
+ if story:
363
+ return {"story_id": story_id, "story": story.get("story", "")} # Return story text or empty string
364
+ else:
365
+ return {"story_id": story_id, "story": ""} # Return empty if no story found
366
+
367
+
368
+ @app.get("/annotators")
369
+ async def get_annotators():
370
+ # Fetch all prompts synchronously
371
+ prompts = prompts_collection.find() # Get cursor
372
+
373
+ # Count prompts by annotator_id
374
+ annotator_counts = Counter(prompt['annotator_id'] for prompt in prompts if 'annotator_id' in prompt)
375
+
376
+ # Convert the Counter to a list of dictionaries
377
+ annotators = [{"annotator_id": annotator_id, "prompt_count": count} for annotator_id, count in annotator_counts.items()]
378
+
379
+ return JSONResponse(content=annotators)