# Spaces status: Build error
import asyncio
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import yaml
from dotenv import load_dotenv
from langchain.agents import AgentType, initialize_agent
from langchain.agents import Tool as LangChainTool
from langchain.chains import LLMChain
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
# Populate os.environ from a .env file (if present) before the config is read,
# without overwriting variables that are already set in the environment.
load_dotenv(override=False)
# Custom Exceptions
class CodeFusionError(Exception):
    """Root of every CodeFusion-specific exception; catch this to handle them all."""


class AgentInitializationError(CodeFusionError):
    """Signals that the LLM or the LangChain agent could not be constructed."""


class ToolExecutionError(CodeFusionError):
    """Signals that a tool's underlying LLM call failed."""
# Utility Functions
def load_config(config_path: str = 'config.yaml') -> Dict:
    """Load configuration from a YAML file, falling back to built-in defaults.

    Values found in the file override the defaults key by key, so an empty
    file or a file that sets only some keys still yields a complete config.
    If the ``HUGGINGFACE_API_KEY`` environment variable is set, it overrides
    ``api_key``.

    Args:
        config_path: Path of the YAML file to read. Defaults to
            ``config.yaml`` relative to the current working directory.

    Returns:
        Dict with at least ``model_name``, ``api_key``, ``temperature`` and
        ``verbose``.
    """
    default_config = {
        'model_name': "google/flan-t5-xl",
        'api_key': "your_default_api_key_here",
        'temperature': 0.5,
        'verbose': True
    }
    try:
        with open(config_path, 'r', encoding='utf-8') as config_file:
            loaded = yaml.safe_load(config_file)
    except FileNotFoundError:
        print(f"Config file not found at {config_path}. Using default configuration.")
        loaded = None
    # safe_load returns None for an empty file; merging over the defaults also
    # keeps keys that the rest of the module indexes directly (e.g.
    # config['temperature']) present when the file only sets a few of them.
    config = {**default_config, **(loaded or {})}
    # Override with environment variables if set
    config['api_key'] = os.getenv('HUGGINGFACE_API_KEY', config['api_key'])
    return config
def setup_logging() -> logging.Logger:
    """Configure root logging (INFO level, written to ``codefusion.log``).

    ``logging.basicConfig`` is a no-op when the root logger already has
    handlers, so only the first configuration in the process takes effect.

    Returns:
        The logger named after this module.
    """
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(filename='codefusion.log', format=log_format, level=logging.INFO)
    module_logger = logging.getLogger(__name__)
    return module_logger
# Module-wide settings shared by every tool and agent defined below.
config: Dict = load_config()
# Module logger; setup_logging() configures root logging to write to codefusion.log.
logger: logging.Logger = setup_logging()
# Tool Classes
class Tool(ABC):
    """Abstract base class for all tools used by the agent.

    Each tool builds its own HuggingFaceHub LLM client from the module-level
    ``config`` (model name, temperature, API key).

    Attributes:
        name: Human-readable tool name.
        description: One-line description of what the tool does.
        llm: LLM client the tool's prompts are run against.
    """

    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description
        self.llm = HuggingFaceHub(
            repo_id=config['model_name'],
            model_kwargs={"temperature": config['temperature']},
            huggingfacehub_api_token=config['api_key']
        )

    @abstractmethod
    async def run(self, arguments: Dict[str, Any]) -> Dict[str, str]:
        """Execute the tool's functionality.

        Marked abstract so a subclass that forgets to implement it fails at
        instantiation instead of silently returning ``None``.

        Args:
            arguments: Tool-specific named inputs.

        Returns:
            The tools in this module return ``{"output": <text>}``.
        """
        raise NotImplementedError
class CodeGenerationTool(Tool):
    """Tool for generating code snippets in various languages.

    Recognised ``arguments`` keys:
        language: Target language (default ``"python"``).
        code_description: What the generated code should do.
    """

    def __init__(self):
        super().__init__("Code Generation", "Generates code snippets in various languages.")
        self.prompt_template = PromptTemplate(
            input_variables=["language", "code_description"],
            template="Generate {language} code for: {code_description}"
        )
        self.chain = LLMChain(llm=self.llm, prompt=self.prompt_template)

    async def run(self, arguments: Dict[str, Any]) -> Dict[str, str]:
        """Generate code for the described task.

        Returns:
            ``{"output": <generated code text>}``.

        Raises:
            ToolExecutionError: If the LLM call fails; the original exception
                is attached as ``__cause__``.
        """
        language = arguments.get("language", "python")
        code_description = arguments.get("code_description", "print('Hello, World!')")
        try:
            code = await self.chain.arun(language=language, code_description=code_description)
        except Exception as e:
            logger.exception(f"Error in CodeGenerationTool: {e}")
            raise ToolExecutionError(f"Failed to generate code: {e}") from e
        return {"output": code}
class CodeExplanationTool(Tool):
    """Tool for explaining code snippets.

    Recognised ``arguments`` keys:
        code: The source code to explain.
    """

    def __init__(self):
        super().__init__("Code Explanation", "Explains code snippets in simple terms.")
        self.prompt_template = PromptTemplate(
            input_variables=["code"],
            template="Explain the following code in simple terms:\n\n{code}"
        )
        self.chain = LLMChain(llm=self.llm, prompt=self.prompt_template)

    async def run(self, arguments: Dict[str, Any]) -> Dict[str, str]:
        """Explain the given code in plain language.

        Returns:
            ``{"output": <explanation text>}``.

        Raises:
            ToolExecutionError: If the LLM call fails; the original exception
                is attached as ``__cause__``.
        """
        code = arguments.get("code", "print('Hello, World!')")
        try:
            explanation = await self.chain.arun(code=code)
        except Exception as e:
            logger.exception(f"Error in CodeExplanationTool: {e}")
            raise ToolExecutionError(f"Failed to explain code: {e}") from e
        return {"output": explanation}
class DebuggingTool(Tool):
    """Tool for debugging code snippets.

    Recognised ``arguments`` keys:
        code: The failing source code.
        error_message: The error it produced (may be empty).
    """

    def __init__(self):
        super().__init__("Debugging", "Helps identify and fix issues in code snippets.")
        self.prompt_template = PromptTemplate(
            input_variables=["code", "error_message"],
            template="Debug the following code:\n\n{code}\n\nError message: {error_message}"
        )
        self.chain = LLMChain(llm=self.llm, prompt=self.prompt_template)

    async def run(self, arguments: Dict[str, Any]) -> Dict[str, str]:
        """Ask the LLM to diagnose and fix the given code.

        Returns:
            ``{"output": <debugging suggestions>}``.

        Raises:
            ToolExecutionError: If the LLM call fails; the original exception
                is attached as ``__cause__``.
        """
        code = arguments.get("code", "")
        error_message = arguments.get("error_message", "")
        try:
            debug_result = await self.chain.arun(code=code, error_message=error_message)
        except Exception as e:
            logger.exception(f"Error in DebuggingTool: {e}")
            raise ToolExecutionError(f"Failed to debug code: {e}") from e
        return {"output": debug_result}
# Agent Class
class Agent:
    """Represents an AI agent with specific tools and capabilities.

    Builds a LangChain zero-shot ReAct agent over the given CodeFusion tools.
    A ReAct agent hands each tool a single string, while CodeFusion tools take
    a dict of named arguments. Each tool is therefore wrapped so that its
    string input is parsed as a JSON object.

    Attributes:
        name: Display name of the agent.
        role: Short role description, shown by ``__str__``.
        tools: The CodeFusion tools supplied by the caller.
        memory: ``(prompt, context)`` pairs passed to ``act``, in call order.
        llm: HuggingFaceHub client that drives the agent.
        agent: The executor returned by LangChain's ``initialize_agent``.
    """

    def __init__(self, name: str, role: str, tools: List[Tool]):
        self.name = name
        self.role = role
        self.tools = tools
        self.memory: List[tuple] = []
        try:
            self.llm = HuggingFaceHub(
                repo_id=config['model_name'],
                model_kwargs={"temperature": config['temperature']},
                huggingfacehub_api_token=config['api_key']
            )
            self.agent = initialize_agent(
                # The agent executor only accepts LangChain tool objects, not
                # instances of this module's Tool ABC, so wrap each one.
                tools=[self._as_langchain_tool(tool) for tool in self.tools],
                llm=self.llm,
                # The agent type is selected with ``agent=``; an ``agent_type=``
                # keyword is not recognised by initialize_agent and is
                # forwarded to the executor as an unknown field instead.
                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                verbose=config['verbose']
            )
        except Exception as e:
            logger.exception(f"Error initializing agent: {e}")
            raise AgentInitializationError(f"Failed to initialize agent: {e}") from e

    @staticmethod
    def _as_langchain_tool(tool: Tool) -> "LangChainTool":
        """Wrap a CodeFusion tool as a single-string-input LangChain tool."""
        # Argument names come from the tool's prompt template, when it has one.
        template = getattr(tool, "prompt_template", None)
        arg_names = list(getattr(template, "input_variables", None) or [])

        def parse(tool_input: str) -> Optional[Dict[str, Any]]:
            """Return the JSON object in *tool_input*, or None if there is none."""
            try:
                parsed = json.loads(tool_input)
            except (TypeError, ValueError):
                return None
            return parsed if isinstance(parsed, dict) else None

        async def call(tool_input: str) -> str:
            """Run the wrapped tool on the parsed arguments and return its text."""
            arguments = parse(tool_input)
            if arguments is None:
                # Returned as the observation so the agent can retry with valid input.
                return f"Invalid input: expected a JSON object with keys {arg_names}."
            result = await tool.run(arguments)
            return result["output"]

        def call_sync(tool_input: str) -> str:
            """Synchronous entry point, used only if the agent is run synchronously."""
            return asyncio.run(call(tool_input))

        if arg_names:
            description = (
                f"{tool.description} Input must be a JSON object with keys: "
                f"{', '.join(arg_names)}."
            )
        else:
            description = tool.description
        return LangChainTool(
            name=tool.name,
            func=call_sync,
            coroutine=call,
            description=description,
        )

    async def act(self, prompt: str, context: str) -> str:
        """Perform an action based on the given prompt and context.

        Args:
            prompt: The user's current request.
            context: Prior conversation transcript; may be empty.

        Returns:
            The agent's final answer.

        Raises:
            Exception: Whatever the underlying agent raises, after it is logged.
        """
        self.memory.append((prompt, context))
        # A chain's arun accepts exactly one positional input, so the context
        # is folded into the prompt text rather than passed as a second argument.
        agent_input = f"{context}\nHuman: {prompt}" if context else prompt
        try:
            return await self.agent.arun(agent_input)
        except Exception as e:
            logger.exception(f"Error during agent action: {e}")
            raise

    def __str__(self) -> str:
        return f"Agent: {self.name} (Role: {self.role})"
# Main application functions
async def run(message: str, history: List[Tuple[str, str]]) -> str:
    """Process user input and generate a response using the agent system.

    Args:
        message: The user's latest chat message.
        history: Earlier ``(user, assistant)`` exchanges, oldest first.

    Returns:
        The agent's reply. If building the agent or running it fails, the
        error is logged and a fixed apology message is returned instead.
    """
    context = "\n".join(f"Human: {h[0]}\nAI: {h[1]}" for h in history)
    try:
        # Construction is inside the try: AgentInitializationError must also
        # turn into the apology reply rather than escape the chat handler.
        agent = Agent(
            name="CodeFusion",
            role="AI Coding Assistant",
            tools=[CodeGenerationTool(), CodeExplanationTool(), DebuggingTool()]
        )
        return await agent.act(message, context)
    except Exception as e:
        logger.exception(f"Error processing request: {e}")
        return "I apologize, but an error occurred while processing your request. Please try again."
async def main():
    """Build the Gradio chat interface backed by ``run`` and launch it."""
    # ChatInterface examples are user messages only; without
    # additional_inputs they must be plain strings, not [message, reply] pairs.
    examples = [
        "What is the purpose of this AI agent?",
        "Can you help me generate a Python function to calculate the factorial of a number?",
        "Explain the concept of recursion in programming.",
    ]
    gr.ChatInterface(
        fn=run,
        title="CodeFusion: Your AI Coding Assistant",
        description="Ask me about code generation, explanation, debugging, or any other coding task!",
        examples=examples,
        theme="default"
    ).launch()
# Simple testing framework
def run_tests():
    """Run the CodeFusion smoke tests in order: three tools, then the agent.

    Each check calls the configured model and asserts that a keyword appears
    in its output. It prints a "Passed" line on success; the first failing
    assertion stops the remaining checks.
    """
    async def test_code_generation():
        generator = CodeGenerationTool()
        outcome = await generator.run({"language": "python", "code_description": "function to add two numbers"})
        text = outcome["output"]
        assert "def" in text, "Code generation failed to produce a function"
        print("Code Generation Test: Passed")

    async def test_code_explanation():
        explainer = CodeExplanationTool()
        snippet = "def factorial(n):\n return 1 if n == 0 else n * factorial(n-1)"
        outcome = await explainer.run({"code": snippet})
        text = outcome["output"].lower()
        assert "recursive" in text, "Code explanation failed to mention recursion"
        print("Code Explanation Test: Passed")

    async def test_debugging():
        debugger = DebuggingTool()
        payload = {"code": "def divide(a, b):\n return a / b", "error_message": "ZeroDivisionError"}
        outcome = await debugger.run(payload)
        text = outcome["output"].lower()
        assert "zero" in text, "Debugging failed to address division by zero"
        print("Debugging Test: Passed")

    async def test_agent():
        toolset = [CodeGenerationTool(), CodeExplanationTool(), DebuggingTool()]
        tester = Agent("TestAgent", "Tester", toolset)
        reply = await tester.act("Generate a Python function to calculate the square of a number", "")
        assert "def" in reply and "return" in reply, "Agent failed to generate a proper function"
        print("Agent Test: Passed")

    async def run_all_tests():
        for check in (test_code_generation, test_code_explanation, test_debugging, test_agent):
            await check()

    asyncio.run(run_all_tests())
if __name__ == "__main__":
    import sys
    # `--test` as the first CLI argument runs the smoke tests; otherwise start the UI.
    if sys.argv[1:2] == ["--test"]:
        run_tests()
    else:
        asyncio.run(main())