Loren commited on
Commit
17eb81a
·
verified ·
1 Parent(s): fe4e560

Upload main.py

Browse files
Files changed (1) hide show
  1. app/main.py +188 -173
app/main.py CHANGED
@@ -1,173 +1,188 @@
1
- from fastapi import FastAPI, Query
2
- from typing import List, Optional, Dict, Any
3
- from app import database
4
- from fastapi.middleware.cors import CORSMiddleware
5
-
6
- from pydantic import BaseModel
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
- import torch
9
- from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
10
-
11
- app = FastAPI(
12
- title="Articles API",
13
- description="API pour récupérer articles et tags depuis SQLite",
14
- version="1.0"
15
- )
16
-
17
- # Chargement du modèle génératif
18
- MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
19
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
21
- torch_dtype=torch.float16,
22
- device_map="auto"
23
- )
24
-
25
- # CORS pour permettre l'accès depuis le navigateur
26
- app.add_middleware(
27
- CORSMiddleware,
28
- allow_origins=["*"], # autorise toutes les origines
29
- allow_credentials=True,
30
- allow_methods=["*"],
31
- allow_headers=["*"],
32
- )
33
-
34
- @app.get("/get_tags")
35
- def get_tags():
36
- """
37
- Récupère la liste de tous les tags disponibles via l'API.
38
-
39
- Returns:
40
- Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
41
- - Si succès :
42
- {
43
- "status": "ok",
44
- "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
45
- }
46
- - En cas d'erreur :
47
- {
48
- "status": "error",
49
- "code": str, # Nom de l'exception
50
- "message": str # Message de l'exception
51
- }
52
-
53
- Notes:
54
- - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
55
- - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
56
- """
57
- try:
58
- dict_result = database.fetch_tags()
59
- if dict_result["status"] == "ok":
60
- return {"status": "ok", "tags": dict_result["result"]}
61
- else:
62
- return dict_result
63
- except Exception as e:
64
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
65
-
66
- @app.get("/get_articles_with_tags")
67
- def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
68
- """
69
- Récupère les articles associés à une ou plusieurs tags spécifiés.
70
-
71
- Args:
72
- tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
73
- Doit contenir au moins un tag.
74
-
75
- Returns:
76
- Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
77
- - Si succès :
78
- {
79
- "status": "ok",
80
- "tags": List[str], # Tags utilisés pour filtrer
81
- "articles": List[Dict] # Liste des articles correspondants
82
- }
83
- Chaque article est un dictionnaire contenant :
84
- - 'article_id': int, ID de l'article
85
- - 'article_title': str, Titre de l'article
86
- - 'article_url': str, URL de l'article
87
- - En cas d'erreur :
88
- {
89
- "status": "error",
90
- "code": str, # Code d'erreur ou nom de l'exception
91
- "message": str # Message d'erreur
92
- }
93
-
94
- Notes:
95
- - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
96
- - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
97
- """
98
- try:
99
- dict_result = database.fetch_articles_by_tags(tags)
100
- if dict_result["status"] == "ok":
101
- return {"status": "ok",
102
- "tags": tags,
103
- "articles": dict_result["result"]}
104
- else:
105
- return dict_result
106
- except Exception as e:
107
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
108
-
109
-
110
- @app.get("/get_query_results")
111
- def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
112
- k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
113
- k_cross: int = Query(5, description="Nombre de résultats conservés après reranking")
114
- ) -> Dict[str, Any]:
115
- """
116
- Récupère les résultats d'une requête en utilisant deux modèles de recherche.
117
-
118
- Args:
119
- query (str): La requête utilisateur pour laquelle récupérer les résultats.
120
- k_model (int, optional): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
121
- k_cross (int, optional): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
122
-
123
- Returns:
124
- Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
125
-
126
- Notes:
127
- - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
128
- - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
129
- """
130
- try:
131
- dict_result = database.fetch_query_results(query, k_model, k_cross)
132
- if dict_result["status"] == "ok":
133
- return {"status": "ok",
134
- "results": dict_result["result"]}
135
- else:
136
- return dict_result
137
- except Exception as e:
138
- return {"status": "error", "code": type(e).__name__, "message": str(e)}
139
-
140
- # 🔹 Exemple de modèle d'entrée utilisateur
141
- class QueryRequest(BaseModel):
142
- question: str
143
-
144
- @app.post("/ask")
145
- async def ask_question(request: QueryRequest):
146
- try:
147
- user_query = request.question.strip()
148
- dict_result = database.fetch_query_results(user_query, k_model=10, k_cross=5)
149
- if dict_result["status"] == "ok":
150
- list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
151
- if not list_chunks:
152
- answer = ("Je ne dispose pas d’informations sur ce sujet. "
153
- "Je peux uniquement répondre à des questions sur les articles " \
154
- "du jeu de données.")
155
- else:
156
- # Construction du prompt
157
- prompt = RAG_PROMPT_TEMPLATE.format(
158
- context="\n".join(list_chunks),
159
- question=user_query
160
- )
161
- # Génération de la réponse
162
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
163
- outputs = model.generate(**inputs, max_new_tokens=500)
164
- generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
165
- answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
166
- else:
167
- answer = f"Une erreur est survenue lors de la récupération des informations : \
168
- {dict_result['code']} - {dict_result['message']}."
169
- return {"answer": answer}
170
- except Exception as e:
171
- answer = f"Une erreur est survenue lors de la récupération des informations : \
172
- {type(e).__name__} - {str(e)}."
173
- return {"answer": answer}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Query
2
+ from typing import List, Dict, Any
3
+ from app import database
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from fastapi.responses import HTMLResponse
6
+
7
+ from pydantic import BaseModel
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ import torch
10
+ from app.templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
11
+
12
+ app = FastAPI(
13
+ title="Articles API",
14
+ description="API pour récupérer articles et tags depuis SQLite",
15
+ version="1.0"
16
+ )
17
+
18
+ # Chargement du modèle génératif
19
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
20
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
22
+ torch_dtype=torch.float16,
23
+ device_map="auto"
24
+ )
25
+
26
+ # CORS pour permettre l'accès depuis le navigateur
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"], # autorise toutes les origines
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
+
35
+ @app.get("/", response_class=HTMLResponse)
36
+ def home():
37
+ return """
38
+ <html>
39
+ <head><title>Page d'accueil</title></head>
40
+ <body>
41
+ <h1>Welcome on the API search articles !</h1>
42
+ </body>
43
+ </html>
44
+ """
45
+
46
+ @app.get("/get_tags")
47
+ def get_tags():
48
+ """
49
+ Récupère la liste de tous les tags disponibles via l'API.
50
+
51
+ Returns:
52
+ Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
53
+ - Si succès :
54
+ {
55
+ "status": "ok",
56
+ "tags": List[str] # Liste des noms de tags triés par ordre alphabétique
57
+ }
58
+ - En cas d'erreur :
59
+ {
60
+ "status": "error",
61
+ "code": str, # Nom de l'exception
62
+ "message": str # Message de l'exception
63
+ }
64
+
65
+ Notes:
66
+ - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
67
+ - En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
68
+ """
69
+ try:
70
+ dict_result = database.fetch_tags()
71
+ if dict_result["status"] == "ok":
72
+ return {"status": "ok", "tags": dict_result["result"]}
73
+ else:
74
+ return dict_result
75
+ except Exception as e:
76
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
77
+
78
+ @app.get("/get_articles_with_tags")
79
+ def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
80
+ """
81
+ Récupère les articles associés à une ou plusieurs tags spécifiés.
82
+
83
+ Args:
84
+ tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
85
+ Doit contenir au moins un tag.
86
+
87
+ Returns:
88
+ Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
89
+ - Si succès :
90
+ {
91
+ "status": "ok",
92
+ "tags": List[str], # Tags utilisés pour filtrer
93
+ "articles": List[Dict] # Liste des articles correspondants
94
+ }
95
+ Chaque article est un dictionnaire contenant :
96
+ - 'article_id': int, ID de l'article
97
+ - 'article_title': str, Titre de l'article
98
+ - 'article_url': str, URL de l'article
99
+ - En cas d'erreur :
100
+ {
101
+ "status": "error",
102
+ "code": str, # Code d'erreur ou nom de l'exception
103
+ "message": str # Message d'erreur
104
+ }
105
+
106
+ Notes:
107
+ - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
108
+ - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
109
+ """
110
+ try:
111
+ dict_result = database.fetch_articles_by_tags(tags)
112
+ if dict_result["status"] == "ok":
113
+ return {"status": "ok",
114
+ "tags": tags,
115
+ "articles": dict_result["result"]}
116
+ else:
117
+ return dict_result
118
+ except Exception as e:
119
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
120
+
121
+
122
+ @app.get("/get_query_results")
123
+ def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
124
+ k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
125
+ k_cross: int = Query(5, description="Nombre de résultats conservés après reranking")
126
+ ) -> Dict[str, Any]:
127
+ """
128
+ Récupère les résultats d'une requête en utilisant deux modèles de recherche.
129
+
130
+ Args:
131
+ query (str): La requête utilisateur pour laquelle récupérer les résultats.
132
+ k_model (int, optional): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
133
+ k_cross (int, optional): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
134
+
135
+ Returns:
136
+ Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
137
+
138
+ Notes:
139
+ - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
140
+ - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
141
+ """
142
+ try:
143
+ dict_result = database.fetch_query_results(query, k_model, k_cross)
144
+ if dict_result["status"] == "ok":
145
+ return {"status": "ok",
146
+ "results": dict_result["result"]}
147
+ else:
148
+ return dict_result
149
+ except Exception as e:
150
+ return {"status": "error", "code": type(e).__name__, "message": str(e)}
151
+
152
+ # 🔹 Exemple de modèle d'entrée utilisateur
153
+ class QueryRequest(BaseModel):
154
+ question: str
155
+
156
+ @app.post("/ask")
157
+ async def ask_question(request: QueryRequest):
158
+ try:
159
+ user_query = request.question.strip()
160
+ dict_result = database.fetch_query_results(user_query, k_model=10, k_cross=5)
161
+ if dict_result["status"] == "ok":
162
+ list_chunks = [resp['chunk_text'] for resp in dict_result['result']]
163
+ if not list_chunks:
164
+ answer = ("Je ne dispose pas d’informations sur ce sujet. "
165
+ "Je peux uniquement répondre à des questions sur les articles " \
166
+ "du jeu de données.")
167
+ else:
168
+ # Construction du prompt
169
+ prompt = RAG_PROMPT_TEMPLATE.format(
170
+ context="\n".join(list_chunks),
171
+ question=user_query
172
+ )
173
+ # Génération de la réponse
174
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
175
+ outputs = model.generate(**inputs, max_new_tokens=500)
176
+ generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:] # uniquement la partie générée
177
+ answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
178
+ return {"status": "ok",
179
+ "results": dict_result["result"],
180
+ "answer": answer}
181
+ else:
182
+ answer = f"Une erreur est survenue lors de la récupération des informations : \
183
+ {dict_result['code']} - {dict_result['message']}."
184
+ return {"status": "error", "answer": answer}
185
+ except Exception as e:
186
+ answer = f"Une erreur est survenue lors de la récupération des informations : \
187
+ {type(e).__name__} - {str(e)}."
188
+ return {"status": "error", "answer": answer}