Spaces:
Sleeping
Sleeping
| """Integration tests for LLM providers.""" | |
| import pytest | |
| import os | |
| from pathlib import Path | |
| import tempfile | |
| from src.rag import ProjectRAG | |
| from src.agent import ProjectAgent | |
# Sample meeting note that serves as the entire indexed corpus for these
# integration tests. It names two participants (Alice, Bob), one decision,
# two action items and one blocker, so the "action items" / "blockers"
# queries issued by the tests have matching content to retrieve.
# Assembled line by line; the trailing empty entry yields a final newline.
SAMPLE_MEETING = "\n".join(
    (
        "# Meeting: Test Sprint Planning",
        "Date: 2025-01-15",
        "Participants: Alice, Bob",
        "## Discussion",
        "Discussed the test implementation.",
        "## Decisions",
        "- Use pytest for testing",
        "## Action Items",
        "- [ ] Alice: Write unit tests by 2025-01-20",
        "- [ ] Bob: Review code by 2025-01-18",
        "## Blockers",
        "- Waiting for API access",
        "",
    )
)
@pytest.fixture
def test_rag():
    """Yield a ``ProjectRAG`` indexed over a single sample meeting.

    Builds this layout inside a fresh temporary directory::

        <tmp>/data/test_project/meetings/2025-01-15-sprint.md   (SAMPLE_MEETING)
        <tmp>/chroma                                            (persist_dir)

    then constructs ``ProjectRAG(data_dir, persist_dir=...)`` and calls
    ``load_and_index()`` before handing the instance to the requesting test.
    Leaving the ``with`` block after the test finishes deletes the whole
    temporary tree, including anything persisted under ``<tmp>/chroma``.

    The ``@pytest.fixture`` marker is required. The test classes receive this
    by parameter name ``test_rag``. Without the marker, pytest cannot inject
    it and would instead try to collect this generator as a ``test_*``
    function. Fixture-marked functions are excluded from test collection.

    Yields:
        The indexed ``ProjectRAG`` instance.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        root = Path(temp_dir)
        data_dir = root / "data"

        # One project with one meeting file. Presumably ProjectRAG scans
        # <data_dir>/<project>/meetings/*.md -- confirm against src.rag.
        meetings_dir = data_dir / "test_project" / "meetings"
        meetings_dir.mkdir(parents=True)
        (meetings_dir / "2025-01-15-sprint.md").write_text(
            SAMPLE_MEETING, encoding="utf-8"
        )

        # Only the path is computed here; nothing creates this directory in
        # this fixture. Presumably ProjectRAG / ChromaDB creates it on demand.
        persist_dir = root / "chroma"

        rag = ProjectRAG(data_dir, persist_dir=persist_dir)
        rag.load_and_index()
        yield rag
class TestHuggingFaceProvider:
    """Integration tests for HuggingFace provider.

    Tests that need a real credential skip when ``HF_TOKEN`` is unset. The
    invalid-token test sets its bogus token through ``monkeypatch`` so the
    change is undone after that test.
    """

    @staticmethod
    def _require_token():
        """Skip the calling test unless ``HF_TOKEN`` is set (non-empty)."""
        if not os.getenv("HF_TOKEN"):
            pytest.skip("HF_TOKEN not set")

    def test_hf_agent_creation(self, test_rag):
        """Test that HuggingFace agent can be created with valid token."""
        self._require_token()
        agent = ProjectAgent(test_rag, provider="huggingface")
        assert agent is not None
        assert agent.provider == "huggingface"
        assert agent.llm is not None

    def test_hf_simple_query(self, test_rag):
        """Test a simple query with HuggingFace."""
        self._require_token()
        agent = ProjectAgent(test_rag, provider="huggingface")
        response = agent.query("What are the action items?")
        assert response is not None
        assert len(response) > 0
        # Should mention Alice or Bob from the test data. "test" is also
        # accepted, which is permissive because the sample meeting uses
        # that word throughout.
        lowered = response.lower()
        assert "alice" in lowered or "bob" in lowered or "test" in lowered

    def test_hf_blockers_query(self, test_rag):
        """Test blockers query with HuggingFace."""
        self._require_token()
        agent = ProjectAgent(test_rag, provider="huggingface")
        response = agent.query("What blockers do we have?")
        assert response is not None
        assert len(response) > 0

    def test_hf_invalid_token(self, test_rag, monkeypatch):
        """Test that an invalid token makes a query raise.

        The bogus token is installed with ``monkeypatch.setenv``, which
        restores the previous value (or removes the variable) at teardown.
        Assigning ``os.environ`` directly would leak the fake token into every
        test that runs afterwards in the session. It would also overwrite a
        real token, so later HF tests would run against a bad credential
        instead of passing or skipping.
        """
        monkeypatch.setenv("HF_TOKEN", "invalid_token_12345")
        agent = ProjectAgent(test_rag, provider="huggingface")
        with pytest.raises(Exception) as exc_info:
            agent.query("What are the action items?")
        # Intended to confirm an authentication failure.
        # NOTE(review): accepting any message containing "error" makes this
        # check nearly always pass; consider tightening once the provider's
        # actual auth-error message/type is known.
        error_msg = str(exc_info.value).lower()
        assert (
            "401" in error_msg
            or "unauthorized" in error_msg
            or "invalid" in error_msg
            or "error" in error_msg
        )
class TestGoogleProvider:
    """Integration tests for Google provider.

    Tests that need a real credential skip when ``GOOGLE_API_KEY`` is unset.
    The invalid-key test sets its bogus key through ``monkeypatch`` so the
    change is undone after that test.
    """

    @staticmethod
    def _require_api_key():
        """Skip the calling test unless ``GOOGLE_API_KEY`` is set (non-empty)."""
        if not os.getenv("GOOGLE_API_KEY"):
            pytest.skip("GOOGLE_API_KEY not set")

    def test_google_agent_creation(self, test_rag):
        """Test that Google agent can be created with valid key."""
        self._require_api_key()
        agent = ProjectAgent(test_rag, provider="google")
        assert agent is not None
        assert agent.provider == "google"
        assert agent.llm is not None

    def test_google_simple_query(self, test_rag):
        """Test a simple query with Google."""
        self._require_api_key()
        agent = ProjectAgent(test_rag, provider="google")
        response = agent.query("What are the action items?")
        assert response is not None
        assert len(response) > 0

    def test_google_blockers_query(self, test_rag):
        """Test blockers query with Google."""
        self._require_api_key()
        agent = ProjectAgent(test_rag, provider="google")
        response = agent.query("What blockers do we have?")
        assert response is not None
        assert len(response) > 0

    def test_google_invalid_key(self, test_rag, monkeypatch):
        """Test that an invalid key makes a query raise.

        The bogus key is installed with ``monkeypatch.setenv``, which restores
        the previous value (or removes the variable) at teardown. Assigning
        ``os.environ`` directly would leak the fake key into every test that
        runs afterwards in the session and overwrite a real key.
        """
        monkeypatch.setenv("GOOGLE_API_KEY", "invalid_key_12345")
        agent = ProjectAgent(test_rag, provider="google")
        # Only asserts that *some* exception is raised; its type and message
        # are not checked.
        with pytest.raises(Exception):
            agent.query("What are the action items?")