Spaces:
Configuration error
Configuration error
# farmlingua/app/agents/crew_pipeline.py
import os
import re
import sys

import faiss
import joblib
import numpy as np
import requests
from sentence_transformers import SentenceTransformer
from transformers import pipeline

from app.utils import config
def _resolve_device(visible_devices):
    """Translate a ``CUDA_VISIBLE_DEVICES`` value into a pipeline device index.

    Args:
        visible_devices: Raw value of the environment variable, or ``None``
            when it is unset.

    Returns:
        ``0`` (the first visible GPU) when the variable lists at least one
        usable device, otherwise ``-1`` (CPU).
    """
    if visible_devices is None:
        return -1
    # Only the first entry matters: CUDA stops enumerating at the first
    # invalid id, so a leading "-1" (the usual "disable GPU" idiom) or an
    # empty/blank value means no device is visible.
    first_entry = visible_devices.split(",")[0].strip()
    if not first_entry or first_entry == "-1":
        return -1
    return 0


# Directory one level above the folder that contains this file, pushed to the
# front of sys.path.
# NOTE(review): `from app.utils import config` at the top of the file runs
# before this path tweak, so that import cannot benefit from it — confirm the
# intended order.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

# Device index handed to the transformers pipelines: 0 = first GPU, -1 = CPU.
DEVICE = _resolve_device(os.environ.get("CUDA_VISIBLE_DEVICES"))
# Optional intent classifier. When it cannot be loaded, `classifier` is None.
# NOTE: joblib.load unpickles the file, so config.CLASSIFIER_PATH must point
# to a trusted artifact.
try:
    classifier = joblib.load(config.CLASSIFIER_PATH)
except Exception as exc:  # best-effort: any load failure disables the classifier
    classifier = None
    # Report the failure instead of hiding it, so a missing or corrupt model
    # file is visible in the logs.
    print(f"Could not load intent classifier: {exc!r}")
# --- Model loading: everything below runs once, when the module is imported ---

# Expert model: free-form text generation.
# NOTE(review): `temperature` normally only matters when sampling is enabled
# and no `do_sample` is passed here — whether it takes effect depends on the
# model's own generation config; confirm.
print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
expert_pipeline = pipeline(
    task="text-generation",
    model=config.EXPERT_MODEL_NAME,
    device=DEVICE,
    max_new_tokens=700,
    temperature=0.3,
    repetition_penalty=1.1,
)

# Formatter model: text-to-text generation, used for both weather summaries
# and for restructuring the expert answer.
print(f"Loading formatter/weather model ({config.FORMATTER_MODEL_NAME})...")
formatter_pipeline = pipeline(
    task="text2text-generation",
    model=config.FORMATTER_MODEL_NAME,
    device=DEVICE,
)

# Sentence embedder used to encode queries for raw FAISS index lookups.
embedder = SentenceTransformer(config.EMBEDDING_MODEL)
def retrieve_docs(query, vs_path):
    """Look up the top-3 passages for *query* in the vector store at *vs_path*.

    Two on-disk layouts are handled:

    * a directory, loaded as a LangChain FAISS store;
    * a single file, loaded as a raw FAISS index whose passages live in a
      sibling ``<vs_path>_meta.npy`` file.

    Args:
        query: Text to search for.
        vs_path: Path to the vector store (directory or index file).

    Returns:
        The matching passages joined by blank lines, or ``None`` when the
        path is missing, the store cannot be loaded, or nothing is found.
    """
    if not (vs_path and os.path.exists(vs_path)):
        return None

    store_location = str(vs_path)

    if os.path.isdir(vs_path):
        # Directory layout: delegate to LangChain. Any failure (missing
        # package, unreadable store, search error) yields None.
        try:
            from langchain.vectorstores import FAISS as LCFAISS
            from langchain.embeddings import SentenceTransformerEmbeddings

            embeddings = SentenceTransformerEmbeddings(model_name=config.EMBEDDING_MODEL)
            # NOTE: allow_dangerous_deserialization unpickles store contents;
            # only point this at trusted stores.
            store = LCFAISS.load_local(
                store_location, embeddings, allow_dangerous_deserialization=True
            )
            hits = store.similarity_search(query, k=3)
            if not hits:
                return None
            return "\n\n".join(hit.page_content for hit in hits)
        except Exception:
            return None

    # Single-file layout: raw FAISS index.
    try:
        index = faiss.read_index(store_location)
    except Exception:
        return None

    query_matrix = np.array([embedder.encode(query)], dtype=np.float32)
    distances, neighbor_ids = index.search(query_matrix, k=3)
    # NOTE(review): a best score of exactly 0 is treated as "no result".
    # For an L2 index that would be an exact match — confirm the metric.
    if distances[0][0] == 0:
        return None

    meta_path = store_location + "_meta.npy"
    if not os.path.exists(meta_path):
        return None

    # NOTE: allow_pickle=True executes pickle on load; the metadata file must
    # be trusted. The stored object is expected to be a dict keyed by the
    # stringified index id.
    metadata = np.load(meta_path, allow_pickle=True).item()

    passages = []
    for neighbor in neighbor_ids[0]:
        key = str(neighbor)
        if key not in metadata:
            continue
        text = metadata.get(key, "")
        if text:
            passages.append(text)

    if not passages:
        return None
    return "\n\n".join(passages)
def get_weather(state_name, *, timeout=10):
    """Fetch the current weather for a Nigerian state and format it as text.

    Args:
        state_name: Name of the state; queried as ``"<state_name>, Nigeria"``.
        timeout: Seconds to wait for the weather service before giving up.

    Returns:
        A multi-line summary (condition, temperature, humidity, wind), or
        ``"Unable to retrieve weather for <state_name>."`` when the request
        fails, the service answers with a non-200 status, or the payload
        lacks the expected fields.
    """
    failure_message = f"Unable to retrieve weather for {state_name}."

    # HTTPS: the API key travels in the query string and must not be sent in
    # clear text.
    url = "https://api.weatherapi.com/v1/current.json"
    params = {
        "key": config.WEATHER_API_KEY,
        "q": f"{state_name}, Nigeria",
        "aqi": "no",
    }

    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, params=params, timeout=timeout)
    except requests.RequestException:
        return failure_message

    if response.status_code != 200:
        return failure_message

    try:
        current = response.json()["current"]
        condition = current["condition"]["text"]
        temp_c = current["temp_c"]
        humidity = current["humidity"]
        wind_kph = current["wind_kph"]
    except (ValueError, KeyError, TypeError):
        # Body was not JSON, or not shaped like a current-weather response.
        return failure_message

    return (
        f"Weather in {state_name}:\n"
        f"- Condition: {condition}\n"
        f"- Temperature: {temp_c}°C\n"
        f"- Humidity: {humidity}%\n"
        f"- Wind: {wind_kph} kph"
    )
# Keyword stems are matched at the START of a word only, so "rainfall" and
# "updates" still count while "grain", "terrain" or "train" no longer trigger
# the weather intent through the embedded "rain".
_WEATHER_KEYWORDS_RE = re.compile(r"\b(?:weather|temperature|rain|forecast)")
_LIVE_UPDATE_KEYWORDS_RE = re.compile(
    r"\b(?:latest|update|breaking|news|current|predict)"
)


def detect_intent(query):
    """Classify a user query into an intent label.

    Checks run in this order:

    1. Weather keywords -> ``("weather", <state>)`` where ``<state>`` is the
       first entry of ``config.STATES`` mentioned as a whole word, or
       ``None`` when no state is mentioned.
    2. Live-update keywords -> ``("live_update", None)``.
    3. Loaded classifier -> ``("low_confidence", None)`` when its best
       probability is below ``config.CLASSIFIER_CONFIDENCE_THRESHOLD``,
       otherwise ``(<predicted label>, None)``.
    4. Otherwise -> ``("normal", None)``.

    Args:
        query: Raw user question.

    Returns:
        A ``(intent, extra)`` tuple; ``extra`` is only set for weather
        queries that name a state.
    """
    q_lower = query.lower()

    if _WEATHER_KEYWORDS_RE.search(q_lower):
        for state in config.STATES:
            # Whole-word match so short state names are not found inside
            # unrelated words.
            state_pattern = rf"(?<!\w){re.escape(state.lower())}(?!\w)"
            if re.search(state_pattern, q_lower):
                return "weather", state
        return "weather", None

    if _LIVE_UPDATE_KEYWORDS_RE.search(q_lower):
        return "live_update", None

    # `classifier` is None when loading failed; hasattr then returns False.
    if hasattr(classifier, "predict") and hasattr(classifier, "predict_proba"):
        confidence = max(classifier.predict_proba([query])[0])
        if confidence < config.CLASSIFIER_CONFIDENCE_THRESHOLD:
            return "low_confidence", None
        return classifier.predict([query])[0], None

    return "normal", None
def run_pipeline(user_query: str):
    """Answer a farmer's question end to end.

    Flow:

    * Weather query naming a state: fetch the weather text and pass it
      through the formatter model.
    * ``live_update`` / ``low_confidence`` intents: append retrieved context
      (live or static vector store respectively) to the question when any is
      found.
    * Every non-weather path, and weather queries without a state: generate
      an answer with the expert model, then restructure it with the
      formatter model.

    Args:
        user_query: The raw question.

    Returns:
        The formatter model's generated text.
    """
    intent, extra = detect_intent(user_query)

    if intent == "weather" and extra:
        weather_text = get_weather(extra)
        return formatter_pipeline(
            weather_text, max_length=256, do_sample=False
        )[0]["generated_text"]

    # The two retrieval intents are mutually exclusive, hence elif.
    prompt_query = user_query
    if intent == "live_update":
        context = retrieve_docs(user_query, config.LIVE_VS_PATH)
        if context:
            prompt_query += f"\n\nLatest agricultural updates:\n{context}"
    elif intent == "low_confidence":
        context = retrieve_docs(user_query, config.STATIC_VS_PATH)
        if context:
            prompt_query += f"\n\nReference information:\n{context}"

    expert_response = expert_pipeline(
        f"Provide a detailed agricultural answer for: {prompt_query}",
        max_new_tokens=700,
        temperature=0.3,
        # Text-generation pipelines return prompt + continuation by default;
        # without this the instruction and retrieved context would be fed to
        # the formatter as if they were part of the answer.
        return_full_text=False,
    )[0]["generated_text"]

    formatted_response = formatter_pipeline(
        f"Format the following answer to be clear, structured, and easy to understand for Nigerian farmers:\n\n{expert_response}",
        max_length=512,
        do_sample=False,
    )[0]["generated_text"]
    return formatted_response