|
|
"""Example usage of MbazaAI""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
# Make the package importable when this example is run from a source checkout:
# put the directory two levels above this file at the front of sys.path.
# abspath() matters on Python < 3.9, where __file__ of the __main__ script may
# be relative (e.g. "example.py"); dirname(dirname("example.py")) is "", which
# would insert the current working directory instead of the project root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
|
from mbaza_ai import LegalAI, CountryDatasetLoader, Config |
|
|
from mbaza_ai.utils import setup_logging |
|
|
|
|
|
|
|
|
# Configure logging once, at module import time, before main() creates any
# MbazaAI objects. NOTE(review): setup_logging is a project helper from
# mbaza_ai.utils; presumably "INFO" is the log level name — confirm there.
setup_logging("INFO")
|
|
|
|
|
def main(): |
|
|
|
|
|
print("ποΈ Initializing MbazaAI...") |
|
|
config = Config() |
|
|
dataset_loader = CountryDatasetLoader(config) |
|
|
legal_ai = LegalAI(config, dataset_loader) |
|
|
|
|
|
print(f"π Current country: {dataset_loader.current_country}") |
|
|
print(f"π Available countries: {', '.join(config.get_available_countries())}") |
|
|
|
|
|
|
|
|
questions = [ |
|
|
"What are the requirements for Rwandan citizenship?", |
|
|
"What is the penalty for theft in Rwanda?", |
|
|
"What are the basic constitutional rights?", |
|
|
] |
|
|
|
|
|
print("\n" + "="*60) |
|
|
print("π€ MbazaAI Legal Question Answering Demo") |
|
|
print("="*60) |
|
|
|
|
|
|
|
|
for question in questions: |
|
|
print(f"\nβ Question: {question}") |
|
|
result = legal_ai.predict(question) |
|
|
print(f"β
Answer: {result['answer']}") |
|
|
print(f"π― Confidence: {result['confidence']:.2f}") |
|
|
print(f"π Category: {result.get('category', 'N/A')}") |
|
|
print("-" * 40) |
|
|
|
|
|
|
|
|
print("\nπ Adding a new country (Kenya)...") |
|
|
config.add_country( |
|
|
country_code="kenya", |
|
|
name="Kenya", |
|
|
dataset_path="datasets/kenya/legal_data.json", |
|
|
language="en", |
|
|
legal_system="common_law" |
|
|
) |
|
|
|
|
|
print(f"β
Updated available countries: {', '.join(config.get_available_countries())}") |
|
|
|
|
|
|
|
|
print("\nπ Switching to Kenya legal dataset...") |
|
|
if legal_ai.switch_country("kenya"): |
|
|
print("β
Successfully switched to Kenya!") |
|
|
|
|
|
|
|
|
result = legal_ai.predict("What are the basic rights in Kenya?") |
|
|
print(f"\nβ Question: What are the basic rights in Kenya?") |
|
|
print(f"β
Answer: {result['answer']}") |
|
|
print(f"π Country: {result['country']}") |
|
|
|
|
|
|
|
|
print("\n" + "="*40) |
|
|
print("π Model Information") |
|
|
print("="*40) |
|
|
info = legal_ai.get_model_info() |
|
|
print(f"Model: {info['model_name']}") |
|
|
print(f"Current Country: {info['country']}") |
|
|
print(f"Legal Categories: {', '.join(info['legal_categories'])}") |
|
|
|
|
|
print("\nπ Demo completed successfully!") |
|
|
|
|
|
if __name__ == "__main__":
    # Only run the demo when executed directly, never on import.
    main()