Hadiil committed on
Commit 3bd7faf · verified · 1 Parent(s): 97e0748

Update app.py

Files changed (1)
  1. app.py +123 -101
app.py CHANGED
@@ -1,7 +1,8 @@
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
  from fastapi.staticfiles import StaticFiles
- from fastapi.responses import RedirectResponse
  from transformers import pipeline, MarianMTModel, MarianTokenizer
  from typing import Optional
  import logging
  from PIL import Image
@@ -10,18 +11,18 @@ from docx import Document
  import fitz  # PyMuPDF
  import pandas as pd
  from functools import lru_cache
- import os

  # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- app = FastAPI(title="Vion AI Chatbot")

  # Mount static files
  app.mount("/static", StaticFiles(directory="static"), name="static")

- # ---- Model Initialization ----
  MODELS = {
      "summarization": "t5-small",
      "translation": {
@@ -30,12 +31,11 @@ MODELS = {
          "de": "Helsinki-NLP/opus-mt-en-de"
      },
      "image_captioning": "Salesforce/blip-image-captioning-base",
-     "qa": "deepset/roberta-base-squad2"  # Better for QA than t5-small
  }

- @lru_cache(maxsize=1)
- def get_pipeline(task: str, model_name: str = None):
-     """Cached model loader with error handling"""
      try:
          if task == "translation" and model_name:
              tokenizer = MarianTokenizer.from_pretrained(model_name)
@@ -43,108 +43,130 @@ def get_pipeline(task: str, model_name: str = None):
              return pipeline("translation", model=model, tokenizer=tokenizer)
          return pipeline(task, model=model_name or MODELS.get(task))
      except Exception as e:
-         logger.error(f"Failed to load {task} model: {str(e)}")
-         raise HTTPException(status_code=500, detail=f"Model loading failed: {task}")

- # ---- Core Endpoints ----
- @app.post("/summarize")
- async def summarize_text(file: UploadFile = File(None), text: str = Form(None)):
-     """Improved summarization endpoint"""
      try:
-         if file:
-             text = await extract_text_from_file(file)
-         elif not text:
-             raise HTTPException(status_code=400, detail="No content provided")
-
-         summarizer = get_pipeline("summarization")
-         summary = summarizer(
-             f"summarize: {text[:2000]}",  # Truncate long texts
-             max_length=150,
-             min_length=30,
-             do_sample=False
-         )
-         return {"summary": summary[0]['summary_text']}
-     except Exception as e:
-         logger.error(f"Summarization error: {str(e)}")
-         raise HTTPException(status_code=500, detail="Summarization failed")
-
- @app.post("/answer")
- async def answer_question(
-     question: str = Form(...),
-     context: str = Form(None),
-     file: UploadFile = File(None)
- ):
-     """Fixed QA endpoint with proper answer extraction"""
-     try:
-         if file:
-             context = await extract_text_from_file(file)
-         elif not context:
-             raise HTTPException(status_code=400, detail="Missing context")
-
-         qa_pipeline = get_pipeline("qa")
-         result = qa_pipeline(question=question, context=context[:2000])  # Truncate long contexts
-         return {"answer": result["answer"]}
-     except Exception as e:
-         logger.error(f"QA error: {str(e)}")
-         raise HTTPException(status_code=500, detail="Question answering failed")
-
- @app.post("/caption")
- async def caption_image(file: UploadFile = File(...)):
-     """Image captioning endpoint"""
-     try:
-         if file.size > 5 * 1024 * 1024:  # 5MB limit
-             raise HTTPException(status_code=413, detail="File too large (max 5MB)")
-
-         image = Image.open(io.BytesIO(await file.read()))
-         if image.format not in ["JPEG", "PNG"]:
-             raise HTTPException(status_code=400, detail="Only JPEG/PNG supported")
-
-         captioner = get_pipeline("image_captioning")
-         result = captioner(image)
-         return {"caption": result[0]['generated_text']}
-     except Exception as e:
-         logger.error(f"Captioning error: {str(e)}")
-         raise HTTPException(status_code=500, detail="Image processing failed")

- @app.post("/translate")
- async def translate_text(
-     text: str = Form(...),
-     target_lang: str = Form(...),
      file: UploadFile = File(None)
  ):
-     """Translation endpoint"""
      try:
-         if file:
-             text = await extract_text_from_file(file)
-
-         if target_lang not in MODELS["translation"]:
-             raise HTTPException(status_code=400, detail="Unsupported language")
-
-         translator = get_pipeline("translation", MODELS["translation"][target_lang])
-         translated = translator(text[:1000])  # Limit translation length
-         return {"translation": translated[0]['translation_text']}
      except Exception as e:
-         logger.error(f"Translation error: {str(e)}")
-         raise HTTPException(status_code=500, detail="Translation failed")

- # ---- Helper Functions ----
  async def extract_text_from_file(file: UploadFile) -> str:
-     """Extracts text from PDF/DOCX/TXT files"""
-     try:
-         content = await file.read()
-         if file.filename.endswith(".pdf"):
-             doc = fitz.open(stream=content, filetype="pdf")
-             return " ".join([page.get_text() for page in doc])
-         elif file.filename.endswith(".docx"):
-             doc = Document(io.BytesIO(content))
-             return "\n".join([para.text for para in doc.paragraphs])
-         elif file.filename.endswith(".txt"):
-             return content.decode("utf-8")
-         else:
-             raise HTTPException(status_code=400, detail="Unsupported file type")
-     except Exception as e:
-         logger.error(f"File extraction error: {str(e)}")
-         raise HTTPException(status_code=500, detail="File processing failed")

  @app.get("/", include_in_schema=False)
  async def home():
 
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
  from fastapi.staticfiles import StaticFiles
+ from fastapi.responses import RedirectResponse, JSONResponse
  from transformers import pipeline, MarianMTModel, MarianTokenizer
+ from langdetect import detect, LangDetectException
  from typing import Optional
  import logging
  from PIL import Image

  import fitz  # PyMuPDF
  import pandas as pd
  from functools import lru_cache
+ import re

  # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ app = FastAPI(title="Auto-Detect AI Chatbot")

  # Mount static files
  app.mount("/static", StaticFiles(directory="static"), name="static")

+ # Model configurations
  MODELS = {
      "summarization": "t5-small",
      "translation": {

          "de": "Helsinki-NLP/opus-mt-en-de"
      },
      "image_captioning": "Salesforce/blip-image-captioning-base",
+     "qa": "deepset/roberta-base-squad2"
  }

+ @lru_cache(maxsize=4)
+ def load_model(task: str, model_name: str = None):
      try:
          if task == "translation" and model_name:
              tokenizer = MarianTokenizer.from_pretrained(model_name)

              return pipeline("translation", model=model, tokenizer=tokenizer)
          return pipeline(task, model=model_name or MODELS.get(task))
      except Exception as e:
+         logger.error(f"Model load failed: {str(e)}")
+         raise HTTPException(status_code=500, detail="Model loading error")

+ def detect_intent(text: str = None, file: UploadFile = None) -> str:
+     """Auto-detects user intent from input"""
+     # File-based detection
+     if file:
+         if file.content_type.startswith('image/'):
+             return "image_caption"
+         elif file.filename.endswith(('.xlsx', '.xls')):
+             return "visualize"
+         elif file.filename.endswith(('.pdf', '.docx', '.txt')):
+             return "summarize"
+
+     # Text analysis
+     if not text:
+         return "unknown"
+
+     text_lower = text.lower()
+
+     # Translation detection
+     lang_codes = ['fr', 'es', 'de', 'translate', 'traduire']
+     if any(re.search(rf'\b{lang}\b', text_lower) for lang in lang_codes):
+         return "translate"
+
+     # Question detection
+     question_words = ['what', 'when', 'why', 'how', '?', 'explain']
+     if any(word in text_lower for word in question_words):
+         return "qa"
+
+     # Language detection for non-English text
      try:
+         if detect(text) != 'en' and len(text.split()) > 3:
+             return "translate"
+     except LangDetectException:
+         pass
+
+     # Default to summarization for long text
+     if len(text) > 100:
+         return "summarize"
+
+     return "unknown"

+ @app.post("/process")
+ async def process_input(
+     text: str = Form(None),
      file: UploadFile = File(None)
  ):
+     """Unified endpoint for all processing"""
+     intent = detect_intent(text, file)
+     logger.info(f"Detected intent: {intent}")
+
      try:
+         if intent == "summarize":
+             content = await extract_text_from_file(file) if file else text
+             summarizer = load_model("summarization")
+             summary = summarizer(
+                 f"summarize: {content[:2000]}",
+                 max_length=150,
+                 min_length=30
+             )
+             return {"response": summary[0]['summary_text'], "type": "summary"}
+
+         elif intent == "translate":
+             content = await extract_text_from_file(file) if file else text
+             # Extract target language
+             target_lang = "fr"  # Default
+             if text:
+                 match = re.search(r'\b(fr|es|de)\b', text.lower())
+                 if match:
+                     target_lang = match.group(1)
+             translator = load_model("translation", MODELS["translation"][target_lang])
+             translated = translator(content[:1000])
+             return {"response": translated[0]['translation_text'], "type": "translation"}
+
+         elif intent == "qa":
+             context = await extract_text_from_file(file) if file else None
+             qa_pipeline = load_model("qa")
+             result = qa_pipeline(question=text, context=context[:2000] if context else "")
+             return {"response": result["answer"], "type": "answer"}
+
+         elif intent == "image_caption":
+             image = Image.open(io.BytesIO(await file.read()))
+             captioner = load_model("image_captioning")
+             caption = captioner(image)
+             return {"response": caption[0]['generated_text'], "type": "caption"}
+
+         elif intent == "visualize":
+             df = pd.read_excel(io.BytesIO(await file.read()))
+             code = generate_visualization_code(df, text)
+             return {"response": code, "type": "visualization_code"}
+
+         else:
+             return {"response": "Please clarify your request", "type": "clarification"}
+
      except Exception as e:
+         logger.error(f"Processing error: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))

  async def extract_text_from_file(file: UploadFile) -> str:
+     """Extracts text from supported files"""
+     content = await file.read()
+     if file.filename.endswith('.pdf'):
+         doc = fitz.open(stream=content, filetype="pdf")
+         return " ".join(page.get_text() for page in doc)
+     elif file.filename.endswith('.docx'):
+         doc = Document(io.BytesIO(content))
+         return "\n".join(para.text for para in doc.paragraphs)
+     elif file.filename.endswith('.txt'):
+         return content.decode('utf-8')
+     raise HTTPException(status_code=400, detail="Unsupported file type")
+
+ def generate_visualization_code(df: pd.DataFrame, request: str) -> str:
+     """Generates Python visualization code"""
+     if "bar" in request.lower():
+         return f"""import matplotlib.pyplot as plt
+ plt.bar(df['{df.columns[0]}'], df['{df.columns[1]}'])
+ plt.title('Bar Chart')
+ plt.show()"""
+     else:
+         return f"""import seaborn as sns
+ sns.pairplot(df)
+ plt.title('Data Visualization')
+ plt.show()"""

  @app.get("/", include_in_schema=False)
  async def home():
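
For reviewers who want to sanity-check the new routing logic, here is a minimal sketch (not part of the commit) that exercises detect_intent directly. It assumes app.py is importable as `app` and is run from the repository root so the static/ mount created at import time succeeds; the sample strings and the expected intents in the comments are illustrative.

# Illustrative only: exercising detect_intent() from this commit outside FastAPI.
from app import detect_intent  # assumes app.py is importable and ./static exists

print(detect_intent("What is this document about?"))    # expected: "qa" (question cue)
print(detect_intent("translate this sentence to de"))   # expected: "translate" ('de'/'translate' cue)
print(detect_intent(
    "Large language models are increasingly used for document processing, "
    "and this commit routes such requests automatically."
))                                                      # expected: "summarize" (English text over 100 chars)
print(detect_intent("hi"))                              # expected: "unknown" (short, no cues)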
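
A rough client-side sketch of how the unified /process endpoint could be called once the app is running. The base URL and photo.jpg are placeholders (port 7860 is a common Spaces default, not something this commit specifies); the field names `text` and `file` come from the process_input signature above.

# Hypothetical client for the new /process endpoint; adjust BASE_URL to your deployment.
import requests

BASE_URL = "http://localhost:7860"

# Plain text with no question or translation cues and more than 100 characters
# should be routed to the summarization branch.
r = requests.post(
    f"{BASE_URL}/process",
    data={"text": "Large language models are increasingly used for document "
                  "processing, and this commit routes such requests automatically."},
)
print(r.json())  # expected shape: {"response": "...", "type": "summary"}

# An uploaded image is routed to captioning via its multipart content type.
with open("photo.jpg", "rb") as fh:
    r = requests.post(
        f"{BASE_URL}/process",
        files={"file": ("photo.jpg", fh, "image/jpeg")},
    )
print(r.json())  # expected shape: {"response": "...", "type": "caption"}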
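
One behavioural note on the new loader: functools.lru_cache(maxsize=4) memoizes load_model per (task, model_name) argument tuple, so repeated requests reuse an already-constructed pipeline and only the four most recently used combinations stay cached. A small illustration follows (again a sketch, not part of the commit; running it downloads the underlying models).

# Illustration of the lru_cache semantics on load_model (not part of app.py).
from app import load_model, MODELS

a = load_model("summarization")   # first call builds the t5-small pipeline
b = load_model("summarization")   # cache hit: the same pipeline object is returned
assert a is b

# A different argument tuple occupies its own slot (up to maxsize=4 slots total).
de = load_model("translation", MODELS["translation"]["de"])
assert de is not a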