arthuroe committed (verified)
Commit 3bbc58b · Parent: e85f548

Update app.py

Files changed (1)
  1. app.py +454 -0
app.py CHANGED
@@ -0,0 +1,454 @@
+ from openrouter_llm import OpenRouterFreeAdapter, OpenRouterFreeChain
+ from langchain.schema import Document as LangchainDocument
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ import os
+ import re
+ import uuid
+ import shutil
+ import logging
+ from typing import List, Optional, Dict, Any
+ from pathlib import Path
+
+ import fitz  # PyMuPDF
+ import markdown
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Depends, BackgroundTasks
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from pydantic import BaseModel
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Initialize FastAPI app
+ app = FastAPI(title="AskMyDocs API - Free LLM Edition")
+
+ # Add CORS middleware for frontend integration
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Set to specific domains in production
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Configuration
+ OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
+ HF_MODEL_NAME = os.getenv(
+     "HF_MODEL_NAME", "sentence-transformers/all-mpnet-base-v2")
+ UPLOAD_DIR = os.getenv("UPLOAD_DIR", "./uploads")
+ DB_DIR = os.getenv("DB_DIR", "./vectordb")
+
+ logger.info(f"Using embedding model: {HF_MODEL_NAME}")
+
+ # Ensure directories exist
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+ os.makedirs(DB_DIR, exist_ok=True)
+
+ # Initialize OpenRouter adapter (singleton)
+ openrouter_adapter = None
+
+ # Pydantic models
+
+
+ class QueryRequest(BaseModel):
+     query: str
+     collection_id: str
+
+
+ class QueryResponse(BaseModel):
+     answer: str
+     sources: List[str]
+
+
+ class Document(BaseModel):
+     id: str
+     filename: str
+     content_type: str
+
+
+ class DocumentList(BaseModel):
+     documents: List[Document]
+
+
+ class LLMInfo(BaseModel):
+     model: str
+     is_free: bool = True
+     provider: str = "openrouter"
+
+
+ class LLMModelsList(BaseModel):
+     current_model: str
+     free_models: List[Dict[str, Any]]
+
+
+ # Global variable to store vector databases (in memory for simplicity)
+ # In production, you would use persistent storage
+ vector_dbs = {}
+
+ # Helper functions
+
+
+ def get_embeddings():
+     """Get HuggingFace embedding model."""
+     return HuggingFaceEmbeddings(model_name=HF_MODEL_NAME)
+
+
+ def get_openrouter_adapter():
+     """Get or initialize the OpenRouter adapter for free models."""
+     global openrouter_adapter
+
+     if openrouter_adapter is None:
+         openrouter_adapter = OpenRouterFreeAdapter(api_key=OPENROUTER_API_KEY)
+
+     return openrouter_adapter
+
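+ # NOTE: the local `openrouter_llm` module is not included in this commit. This
+ # file only assumes it provides OpenRouterFreeAdapter(api_key=...) with a
+ # mutable `model` attribute and a `list_free_models()` method, plus
+ # OpenRouterFreeChain(adapter) exposing `run(query, context)`.
+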
+
+ def extract_text_from_pdf(file_path):
+     """Extract text content from PDF files."""
+     text = ""
+     try:
+         # Use a context manager so the PDF is closed after reading
+         with fitz.open(file_path) as doc:
+             for page in doc:
+                 text += page.get_text()
+         return text
+     except Exception as e:
+         logger.error(f"Error extracting text from PDF: {e}")
+         raise HTTPException(
+             status_code=500, detail=f"Error processing PDF: {str(e)}")
+
+
+ def extract_text_from_markdown(file_path):
+     """Convert Markdown to plain text."""
+     try:
+         with open(file_path, 'r', encoding='utf-8') as f:
+             md_content = f.read()
+         html = markdown.markdown(md_content)
+         # Simple HTML to text conversion - in production use a more robust method
+         text = html.replace('<p>', '\n\n').replace(
+             '</p>', '').replace('<br>', '\n')
+         text = text.replace('<h1>', '\n\n# ').replace('</h1>', '\n')
+         text = text.replace('<h2>', '\n\n## ').replace('</h2>', '\n')
+         text = text.replace('<h3>', '\n\n### ').replace('</h3>', '\n')
+         # Remove any remaining HTML tags
+         text = re.sub('<[^<]+?>', '', text)
+         return text
+     except Exception as e:
+         logger.error(f"Error processing Markdown: {e}")
+         raise HTTPException(
+             status_code=500, detail=f"Error processing Markdown: {str(e)}")
+
+
+ def extract_text_from_file(file_path, content_type):
+     """Extract text based on file type."""
+     if content_type == "application/pdf":
+         return extract_text_from_pdf(file_path)
+     elif content_type == "text/markdown":
+         return extract_text_from_markdown(file_path)
+     elif content_type == "text/plain":
+         with open(file_path, 'r', encoding='utf-8') as f:
+             return f.read()
+     else:
+         raise HTTPException(
+             status_code=400, detail=f"Unsupported file type: {content_type}")
+
+
+ def process_documents(collection_id: str, file_paths: List[tuple]):
+     """Process documents and create vector store."""
+     try:
+         # Create text splitter
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=1000,
+             chunk_overlap=100,
+             length_function=len,
+         )
+
+         all_docs = []
+         for file_path, content_type, filename in file_paths:
+             text_content = extract_text_from_file(file_path, content_type)
+             chunks = text_splitter.split_text(text_content)
+
+             # Create Document objects with metadata
+             docs = [
+                 LangchainDocument(
+                     page_content=chunk,
+                     metadata={"source": filename, "chunk": i}
+                 )
+                 for i, chunk in enumerate(chunks)
+             ]
+             all_docs.extend(docs)
+
+         # Create vector store
+         embeddings = get_embeddings()
+         vector_db = FAISS.from_documents(all_docs, embeddings)
+
+         # Save vector store
+         collection_path = os.path.join(DB_DIR, collection_id)
+         os.makedirs(collection_path, exist_ok=True)
+         vector_db.save_local(collection_path)
+
+         # Store in memory (would be replaced by database lookup in production)
+         vector_dbs[collection_id] = vector_db
+
+         logger.info(
+             f"Successfully processed {len(all_docs)} chunks from {len(file_paths)} documents")
+     except Exception as e:
+         logger.error(f"Error processing documents: {e}")
+         raise HTTPException(
+             status_code=500, detail=f"Error processing documents: {str(e)}")
+
+
+ @app.get("/")
+ async def index():
+     return {"message": "Welcome to the AskMyDocs API"}
+
+
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy"}
+
+
+ @app.post("/upload", response_model=Document)
+ async def upload_file(
+     background_tasks: BackgroundTasks,
+     collection_id: str = Form(...),
+     file: UploadFile = File(...),
+ ):
+     """Upload a document and process it for querying."""
+     try:
+         # Generate a unique ID for the document
+         doc_id = str(uuid.uuid4())
+
+         # Create collection directory if it doesn't exist
+         collection_dir = os.path.join(UPLOAD_DIR, collection_id)
+         os.makedirs(collection_dir, exist_ok=True)
+
+         # Define the file path
+         file_path = os.path.join(collection_dir, file.filename)
+
+         # Determine content type
+         content_type = file.content_type
+         if not content_type:
+             if file.filename.endswith('.pdf'):
+                 content_type = "application/pdf"
+             elif file.filename.endswith('.md'):
+                 content_type = "text/markdown"
+             elif file.filename.endswith('.txt'):
+                 content_type = "text/plain"
+             else:
+                 raise HTTPException(
+                     status_code=400, detail="Unsupported file type")
+
+         # Save the file
+         with open(file_path, "wb") as f:
+             shutil.copyfileobj(file.file, f)
+
+         # Process the document in the background
+         background_tasks.add_task(
+             process_documents,
+             collection_id,
+             [(file_path, content_type, file.filename)]
+         )
+
+         return Document(
+             id=doc_id,
+             filename=file.filename,
+             content_type=content_type
+         )
+     except HTTPException:
+         # Let intentional HTTP errors (e.g. the 400 above) pass through unchanged
+         raise
+     except Exception as e:
+         logger.error(f"Error uploading file: {e}")
+         raise HTTPException(
+             status_code=500, detail=f"Error uploading file: {str(e)}")
+
+
+ @app.get("/collections/{collection_id}/documents", response_model=DocumentList)
+ async def list_documents(collection_id: str):
+     """List all documents in a collection."""
+     try:
+         collection_dir = os.path.join(UPLOAD_DIR, collection_id)
+         if not os.path.exists(collection_dir):
+             return DocumentList(documents=[])
+
+         documents = []
+         for filename in os.listdir(collection_dir):
+             file_path = os.path.join(collection_dir, filename)
+             if os.path.isfile(file_path):
+                 content_type = "application/octet-stream"
+                 if filename.endswith('.pdf'):
+                     content_type = "application/pdf"
+                 elif filename.endswith('.md'):
+                     content_type = "text/markdown"
+                 elif filename.endswith('.txt'):
+                     content_type = "text/plain"
+
+                 documents.append(Document(
+                     # In production, store and retrieve actual IDs
+                     id=str(uuid.uuid4()),
+                     filename=filename,
+                     content_type=content_type
+                 ))
+
+         return DocumentList(documents=documents)
+     except Exception as e:
+         logger.error(f"Error listing documents: {e}")
+         raise HTTPException(
+             status_code=500, detail=f"Error listing documents: {str(e)}")
+
+
+ @app.post("/query", response_model=QueryResponse)
+ async def query_documents(request: QueryRequest):
+     """Query documents using natural language."""
+     try:
+         collection_id = request.collection_id
+
+         # Check if vector DB exists in memory
+         if collection_id in vector_dbs:
+             vector_db = vector_dbs[collection_id]
+         else:
+             # Load from disk
+             collection_path = os.path.join(DB_DIR, collection_id)
+             if not os.path.exists(collection_path):
+                 raise HTTPException(
+                     status_code=404, detail=f"Collection {collection_id} not found")
+
+             embeddings = get_embeddings()
+             vector_db = FAISS.load_local(collection_path, embeddings)
+             vector_dbs[collection_id] = vector_db
+
+         # Get the retriever
+         retriever = vector_db.as_retriever(search_kwargs={"k": 3})
+
+         # Get relevant documents
+         docs = retriever.get_relevant_documents(request.query)
+
+         # Extract sources
+         sources = []
+         for doc in docs:
+             if doc.metadata.get("source") not in sources:
+                 sources.append(doc.metadata.get("source"))
+
+         # Get context from documents
+         context = [doc.page_content for doc in docs]
+
+         # Get OpenRouter adapter for free LLMs
+         adapter = get_openrouter_adapter()
+         chain = OpenRouterFreeChain(adapter)
+
+         # Generate answer
+         answer = chain.run(request.query, context)
+
+         return QueryResponse(
+             answer=answer,
+             sources=sources
+         )
+     except HTTPException:
+         # Preserve the 404 above instead of converting it into a 500
+         raise
+     except Exception as e:
+         logger.error(f"Error querying documents: {e}")
+         raise HTTPException(
+             status_code=500, detail=f"Error querying documents: {str(e)}")
+
+
+ @app.delete("/collections/{collection_id}/documents/{filename}")
+ async def delete_document(collection_id: str, filename: str):
+     """Delete a document from a collection."""
+     try:
+         file_path = os.path.join(UPLOAD_DIR, collection_id, filename)
+         if not os.path.exists(file_path):
+             raise HTTPException(
+                 status_code=404, detail=f"Document {filename} not found")
+
+         os.remove(file_path)
+
+         # Rebuild vector store if needed
+         collection_path = os.path.join(DB_DIR, collection_id)
+         if os.path.exists(collection_path):
+             # In production, you would selectively remove documents rather than rebuilding
+             shutil.rmtree(collection_path)
+
+             # If there are still documents, rebuild the vector store
+             collection_dir = os.path.join(UPLOAD_DIR, collection_id)
+             if os.path.exists(collection_dir) and os.listdir(collection_dir):
+                 file_paths = []
+                 for fname in os.listdir(collection_dir):
+                     fpath = os.path.join(collection_dir, fname)
+                     if os.path.isfile(fpath):
+                         content_type = "application/octet-stream"
+                         if fname.endswith('.pdf'):
+                             content_type = "application/pdf"
+                         elif fname.endswith('.md'):
+                             content_type = "text/markdown"
+                         elif fname.endswith('.txt'):
+                             content_type = "text/plain"
+                         file_paths.append((fpath, content_type, fname))
+
+                 if file_paths:
+                     process_documents(collection_id, file_paths)
+
+         # Remove from in-memory cache
+         if collection_id in vector_dbs:
+             del vector_dbs[collection_id]
+
+         return JSONResponse(content={"message": f"Document {filename} deleted"})
+     except HTTPException:
+         # Let the 404 above propagate with its original status code
+         raise
+     except Exception as e:
+         logger.error(f"Error deleting document: {e}")
+         raise HTTPException(
+             status_code=500, detail=f"Error deleting document: {str(e)}")
+
+
+ @app.get("/llm/info", response_model=LLMInfo)
+ async def get_llm_info():
+     """Get the current LLM information."""
+     adapter = get_openrouter_adapter()
+
+     return LLMInfo(
+         model=adapter.model,
+         is_free=True,
+         provider="openrouter"
+     )
+
+
+ @app.get("/llm/models", response_model=LLMModelsList)
+ async def list_free_models():
+     """List all available free models."""
+     adapter = get_openrouter_adapter()
+     free_models = adapter.list_free_models()
+
+     # Create a simplified list for the frontend
+     model_list = []
+     for model in free_models:
+         model_id = model.get("id", "")
+         model_info = {
+             "id": model_id,
+             "name": model.get("name", model_id),
+             "context_length": model.get("context_length", 4096),
+             "provider": model_id.split("/")[0] if "/" in model_id else "unknown"
+         }
+         model_list.append(model_info)
+
+     return LLMModelsList(
+         current_model=adapter.model,
+         free_models=model_list
+     )
+
+
+ @app.post("/llm/change-model")
+ async def change_model(model_info: LLMInfo):
+     """Change the LLM model (only to another free model)."""
+     adapter = get_openrouter_adapter()
+
+     # Make sure the model ID carries the :free suffix
+     model_id = model_info.model
+     if ":free" not in model_id:
+         model_id = f"{model_id}:free"
+
+     # Set the new model
+     adapter.model = model_id
+
+     return JSONResponse(content={"message": f"Model changed to {model_id}"})
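
For reference, a minimal way to exercise this API might look like the sketch below. It assumes the server is started separately (for example with uvicorn app:app --port 8000), that a file named notes.pdf exists locally, and that "demo" is an arbitrary collection ID; the requests library is used purely for illustration.

    import requests

    BASE = "http://localhost:8000"

    # Upload a PDF into the "demo" collection; indexing runs as a background task
    with open("notes.pdf", "rb") as fh:
        upload = requests.post(
            f"{BASE}/upload",
            data={"collection_id": "demo"},
            files={"file": ("notes.pdf", fh, "application/pdf")},
        )
    print(upload.json())  # {"id": ..., "filename": "notes.pdf", "content_type": "application/pdf"}

    # Ask a question once indexing has finished
    reply = requests.post(
        f"{BASE}/query",
        json={"query": "What is this document about?", "collection_id": "demo"},
    )
    print(reply.json())  # {"answer": ..., "sources": ["notes.pdf"]}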