zakihassan04 committed on
Commit
ee5f022
·
verified ·
1 Parent(s): 3828e70

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
# Point both Hugging Face cache variables at /tmp. This runs before the
# transformers / sentence_transformers imports further down.
# NOTE(review): presumably needed because only /tmp is writable in the
# deployment environment -- confirm.
os.environ.update({"TRANSFORMERS_CACHE": "/tmp", "HF_HOME": "/tmp"})
4
+
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel
7
+ import json
8
+ import torch
9
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer
10
+ from sentence_transformers import SentenceTransformer, util
11
+
12
+ # Load dataset
13
# The dataset is read as JSON Lines: one JSON object per line. Blank or
# whitespace-only lines are skipped so that a stray empty line cannot abort
# startup with a JSONDecodeError.
with open("data/gpt2_ready_filtered.jsonl", "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]

# The "text" field of every record, in file order; positions in this list
# are what the semantic search later reports as corpus ids.
texts = [item["text"] for item in data]
16
+
17
+ # Load model
18
# Hub identifier shared by the tokenizer and the seq2seq model.
model_name = "nurfarah57/SQ-MT5"
tokenizer = MT5Tokenizer.from_pretrained(model_name)
# eval() switches the module to evaluation mode and returns the module
# itself, so it can be chained onto the load call.
model = MT5ForConditionalGeneration.from_pretrained(model_name).eval()
22
+
23
+ # Load sentence embedder
24
# Sentence encoder used for matching incoming questions against the dataset.
embedder = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)
# Encode every dataset text once at startup; requests are compared against
# this precomputed tensor instead of re-encoding the corpus.
embeddings = embedder.encode(texts, convert_to_tensor=True)
26
+
27
+ # FastAPI app
28
# Application object; the metadata below is passed straight to FastAPI.
app = FastAPI(
    version="1.0",
    title="Somali QA API",
    description="Su’aal weydii oo hel jawaab laga raadshay dataset-ka ama laga sameeyay model.",
)
33
+
34
+ # Input schema
35
# Request body accepted by POST /qa.
# (Documented with comments rather than a docstring so the generated
# schema for this model stays exactly as before.)
class QuestionRequest(BaseModel):
    # The user's question as free text.
    question: str
37
+
38
+ # Extract question/answer from dataset line
39
def extract_qa(text):
    """Split one dataset entry into its question and answer parts.

    An entry is accepted only when it contains the marker ``"\\nJawaab:"``
    exactly once. Everything before the marker is the question (with every
    ``"Su'aal:"`` label removed), everything after it is the answer; both
    are stripped of surrounding whitespace.

    Returns:
        A ``(question, answer)`` tuple, or ``(None, None)`` when the marker
        is missing or appears more than once.
    """
    answer_marker = "\nJawaab:"
    if text.count(answer_marker) != 1:
        return None, None

    question_part, _, answer_part = text.partition(answer_marker)
    question = question_part.replace("Su'aal:", "").strip()
    return question, answer_part.strip()
44
+
45
+ # Match dataset semantically
46
def find_semantic_match(question, threshold=0.90):
    """Look up the dataset answer whose entry best matches *question*.

    The question is encoded with the module-level ``embedder`` and compared
    against the precomputed ``embeddings`` via ``util.semantic_search``,
    keeping only the single best hit.

    Args:
        question: The user's question text.
        threshold: Minimum score the best hit must reach to be accepted.

    Returns:
        The answer part of the best-matching dataset text as produced by
        ``extract_qa`` (which may itself be ``None`` if that text cannot be
        split), or ``None`` when there is no hit or its score is below
        *threshold*.
    """
    query_embedding = embedder.encode(question, convert_to_tensor=True)
    results = util.semantic_search(query_embedding, embeddings, top_k=1)

    # results holds one hit list per query; only one query was sent.
    best_hit = results[0][0] if results else None
    if best_hit is not None and best_hit["score"] >= threshold:
        matched_text = texts[best_hit["corpus_id"]]
        return extract_qa(matched_text)[1]
    return None
54
+
55
+ # Fallback generation
56
def generate_with_mt5(question):
    """Generate an answer for *question* with the MT5 model.

    The question is wrapped in the ``su'aal: ...`` prompt, tokenized with
    truncation enabled, and passed to ``model.generate`` with a maximum
    length of 128. The first generated sequence is decoded with special
    tokens removed.

    Args:
        question: The user's question text.

    Returns:
        The decoded model output as a string.
    """
    prompt = f"su'aal: {question}"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; no gradients needed
        # Pass the attention mask together with the ids instead of dropping
        # it, so generate() does not have to infer which tokens are real.
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=128,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
62
+
63
+ # API endpoint
64
# POST /qa: answer a question from the dataset when a close enough match
# exists, otherwise fall back to model generation. The "source" field of the
# response tells the caller which path produced the answer.
# (Comments are used instead of a docstring so the published endpoint
# description stays exactly as before.)
@app.post("/qa")
def answer_question(req: QuestionRequest):
    question = req.question

    # Reject empty or whitespace-only questions up front.
    if question.strip() == "":
        raise HTTPException(status_code=400, detail="Su’aal lama helin.")

    dataset_answer = find_semantic_match(question)
    if dataset_answer:
        answer, source = dataset_answer, "dataset"
    else:
        # No usable dataset match (none found, or an empty answer).
        answer, source = generate_with_mt5(question), "model"
    return {"answer": answer, "source": source}
75
+
76
+ # Root
77
# GET /: returns a fixed status message.
@app.get("/")
def root():
    status = {"message": "✅ Somali QA API is running!"}
    return status