# triAGI-Coder / supplemental.py
# Author: acecalisto3
# Last commit: 898963b ("Update supplemental.py", verified)
import os
import json
import logging
from typing import List, Dict, Optional, Any
import re
from abc import ABC, abstractmethod
from huggingface_hub import HfApi, InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from dataclasses import dataclass
@dataclass()
class ProjectConfig:
    """Dataclass describing a web project to scaffold.

    NOTE(review): a plain class with the same name is defined later in this
    module and rebinds ``ProjectConfig`` at import time, so this dataclass is
    not the one other code ends up using.
    """

    # Short project name.
    name: str
    # Brief summary of the project.
    description: str
    # Technology labels, e.g. ["HTML", "CSS"].
    technologies: List[str]
    # Directory name -> list of file names inside it.
    structure: Dict[str, List[str]]
class WebDevelopmentTool(ABC):
    """Abstract base for tools that emit web source code.

    Attributes:
        name: Human-readable tool name.
        description: Short summary of what the tool produces.
    """

    def __init__(self, name: str, description: str):
        self.name, self.description = name, description

    @abstractmethod
    def generate_code(self, *args, **kwargs):
        """Produce code; each concrete subclass defines its own arguments.

        The base implementation does nothing and returns ``None``.
        """
class HTMLGenerator(WebDevelopmentTool):
    """Tool that renders a tag -> content mapping as a minimal HTML page."""

    def __init__(self):
        super().__init__("HTML Generator", "Generates HTML code for web pages")

    def generate_code(self, structure: Dict[str, Any]) -> str:
        """Wrap each ``(tag, content)`` pair as ``<tag>content</tag>`` inside ``<html><body>``.

        Elements follow the mapping's iteration order. Content is inserted via
        f-string conversion and is not HTML-escaped.
        """
        elements = "".join(f"<{tag}>{content}</{tag}>" for tag, content in structure.items())
        return "<html><body>" + elements + "</body></html>"
class CSSGenerator(WebDevelopmentTool):
    """Tool that renders a selector -> declarations mapping as multi-line CSS rules."""

    def __init__(self):
        super().__init__("CSS Generator", "Generates CSS code for styling web pages")

    def generate_code(self, styles: Dict[str, Dict[str, str]]) -> str:
        """Emit one ``selector {`` ... ``}`` rule per selector, one declaration per line.

        Args:
            styles: Maps each selector to a mapping of property -> value.

        Returns:
            The concatenated rules, each terminated by a newline; empty for empty *styles*.
        """
        rules = []
        for selector, properties in styles.items():
            declarations = "".join(f" {prop}: {value};\n" for prop, value in properties.items())
            rules.append(f"{selector} {{\n{declarations}}}\n")
        return "".join(rules)
class JavaScriptGenerator(WebDevelopmentTool):
    """Tool that renders function specs as JavaScript function declarations."""

    def __init__(self):
        super().__init__("JavaScript Generator", "Generates JavaScript code for web functionality")

    def generate_code(self, functions: List[Dict[str, Any]]) -> str:
        """Build one ``function name(params) { body }`` declaration per spec.

        Each spec must provide ``'name'``, ``'params'`` (iterable of strings) and
        ``'body'``; a missing key raises ``KeyError``. Declarations are separated
        by a blank line.
        """
        chunks = []
        for spec in functions:
            header = f"function {spec['name']}({', '.join(spec['params'])}) {{\n"
            chunks.append(header + f" {spec['body']}\n" + "}\n\n")
        return "".join(chunks)
class ProjectConfig:
    """Plain-attribute container for a project's name, description, technologies and layout.

    NOTE(review): this rebinds the name ``ProjectConfig`` that the module
    declared earlier as a ``@dataclass``. Unlike that version, instances of this
    class get no generated ``__eq__`` or ``__repr__``.
    """

    def __init__(self, name: str, description: str, technologies: List[str], structure: Dict[str, List[str]]):
        # structure maps a directory name to the file names it should contain.
        self.name, self.description = name, description
        self.technologies, self.structure = technologies, structure
class HTMLGenerator:
    """Wraps raw markup in a bare ``<html><body>`` shell."""

    def generate(self, content: str) -> str:
        """Return *content* placed between ``<html><body>`` and ``</body></html>``."""
        opening, closing = "<html><body>", "</body></html>"
        return "{}{}{}".format(opening, content, closing)
class CSSGenerator:
    """Renders a selector -> declarations mapping as compact, one-line CSS rules."""

    def generate(self, styles: Dict[str, Dict[str, str]]) -> str:
        """Return one ``selector { prop: value; prop: value }`` line per selector.

        Args:
            styles: Maps each CSS selector to a mapping of property -> value.
                Each value is itself a mapping, since its ``.items()`` are iterated.

        Returns:
            The rules joined with newlines (no trailing newline), or an empty
            string when *styles* is empty.
        """
        rules = []
        for selector, properties in styles.items():
            declarations = "; ".join(f"{prop}: {value}" for prop, value in properties.items())
            rules.append(f"{selector} {{ {declarations} }}")
        return "\n".join(rules)
class JavaScriptGenerator:
    """Wraps a JavaScript snippet in a single ``main`` function."""

    def generate(self, functionality: str) -> str:
        """Return ``function main() { <functionality> }`` on one line."""
        template = "function main() {{ {} }}"
        return template.format(functionality)
class EnhancedAIAgent:
    """Agent that asks a Hugging Face model for a project layout and scaffolds file contents.

    Prompts go through a remote ``InferenceClient`` (see ``generate_agent_response``).
    The reply is parsed into a ``ProjectConfig``, and every file listed in its
    ``structure`` is filled with placeholder HTML, CSS or JavaScript from the
    generator helpers.
    """

    def __init__(self, name: str, description: str, skills: List[str], model_name: str):
        """Set up the code generators, Hugging Face clients, local model and logger.

        Args:
            name: Display name of the agent.
            description: Free-text description of the agent.
            skills: Skill labels; stored as-is.
            model_name: Hugging Face model id. It is used for the remote inference
                client and for loading the local tokenizer and model.
        """
        self.name = name
        self.description = description
        self.skills = skills
        self.model_name = model_name
        self.html_gen_tool = HTMLGenerator()
        self.css_gen_tool = CSSGenerator()
        self.js_gen_tool = JavaScriptGenerator()
        self.hf_api = HfApi()
        # HF_API_TOKEN is optional; when unset, token=None is passed.
        self.inference_client = InferenceClient(model=model_name, token=os.environ.get("HF_API_TOKEN"))
        # NOTE(review): the local tokenizer, model and pipeline are loaded eagerly,
        # but no method of this class uses them; generation goes through
        # inference_client. They are kept because callers may read these attributes.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
        self.text_generation = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer, clean_up_tokenization_spaces=True)
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # getLogger(__name__) returns the same logger object for every instance.
        # Attach the handler only once; otherwise each new agent duplicates every log line.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            self.logger.addHandler(handler)

    def generate_agent_response(self, prompt: str) -> str:
        """Send *prompt* to the remote model and return the generated text.

        Args:
            prompt: Prompt forwarded to ``inference_client.text_generation``
                (limited to 100 new tokens).

        Returns:
            The generated text. If the request or reading its result raises, a
            string starting with ``"Error: Unable to generate response."`` is
            returned instead.
        """
        try:
            response = self.inference_client.text_generation(prompt, max_new_tokens=100)
            self.logger.info("Generated response for prompt: %s...", prompt[:50])
            # With default arguments (details=False, stream=False), text_generation
            # returns a plain str. Only detailed responses expose .generated_text.
            # Reading .generated_text unconditionally made every call fail.
            if isinstance(response, str):
                return response
            return response.generated_text
        except Exception as e:
            self.logger.exception("Error generating response: %s", e)
            return f"Error: Unable to generate response. {str(e)}"

    def generate_project_config(self, project_description: str) -> Optional["ProjectConfig"]:
        """Ask the model for a project configuration and parse it.

        The text from the first ``{`` to the last ``}`` of the reply is parsed as
        JSON and passed as keyword arguments to ``ProjectConfig``. If that fails
        (no object, invalid JSON, or missing/unexpected keys), individual fields
        are scraped with ``extract_partial_config``.

        Args:
            project_description: Free-text description inserted into the prompt.

        Returns:
            A ``ProjectConfig``, or ``None`` when neither strict parsing nor
            partial extraction produced one.
        """
        prompt = "\n".join([
            "",
            "Based on the following project description, generate a ProjectConfig object:",
            f"Description: {project_description}",
            "The ProjectConfig should include:",
            "- name: A short, descriptive name for the project",
            "- description: A brief summary of the project",
            '- technologies: A list of technologies to be used (e.g., ["HTML", "CSS", "JavaScript", "React"])',
            "- structure: A dictionary representing the file structure, where keys are directories and values are lists of files",
            "Respond with a JSON object representing the ProjectConfig.",
            "",
        ])
        response = self.generate_agent_response(prompt)
        json_start = response.find('{')
        # rfind returns -1 when there is no '}', so json_end is 0 (never -1) in
        # that case. Hence the ordering check below instead of a -1 test.
        json_end = response.rfind('}') + 1
        try:
            if json_start == -1 or json_end <= json_start:
                raise ValueError("No JSON object found in the response")
            config_dict = json.loads(response[json_start:json_end])
            return ProjectConfig(**config_dict)
        except (ValueError, TypeError) as e:
            # ValueError also covers json.JSONDecodeError. TypeError means the
            # object's keys do not match the ProjectConfig constructor.
            self.logger.error("Error parsing JSON from response: %s", e)
            self.logger.error("Full response from model: %s", response)
        try:
            partial_config = self.extract_partial_config(response)
            if partial_config:
                self.logger.warning("Extracted partial config from malformed response")
                return partial_config
        except Exception as ex:
            self.logger.error("Failed to extract partial config: %s", ex)
        return None

    def extract_partial_config(self, response: str) -> Optional["ProjectConfig"]:
        """Scrape config fields out of a malformed JSON-ish *response*.

        Returns:
            A ``ProjectConfig`` when both ``name`` and ``description`` are found
            as non-empty strings, otherwise ``None``. Missing ``technologies``
            default to ``[]`` and a missing ``structure`` to ``{}``.
        """
        name = self.extract_field(response, "name")
        description = self.extract_field(response, "description")
        technologies = self.extract_list(response, "technologies")
        structure = self.extract_dict(response, "structure")
        if name and description:
            return ProjectConfig(
                name=name,
                description=description,
                technologies=technologies or [],
                structure=structure or {}
            )
        return None

    def extract_field(self, text: str, field: str) -> Optional[str]:
        """Return the first ``"field": "value"`` string value in *text*, or ``None``.

        Escaped quotes inside the value are not supported.
        """
        match = re.search(rf'"{re.escape(field)}"\s*:\s*"([^"]*)"', text)
        return match.group(1) if match else None

    def extract_list(self, text: str, field: str) -> Optional[List[str]]:
        """Return the quoted strings inside the first ``"field": [...]`` in *text*, or ``None``."""
        match = re.search(rf'"{re.escape(field)}"\s*:\s*\[(.*?)\]', text, re.DOTALL)
        if match:
            items = re.findall(r'"([^"]*)"', match.group(1))
            return items
        return None

    def extract_dict(self, text: str, field: str) -> Optional[Dict[str, List[str]]]:
        """Parse ``"field": {"key": ["a", ...], ...}`` from *text* into a dict of string lists.

        The object body ends at the first ``}`` (non-greedy), so nested objects
        are not supported. Returns ``None`` when *field* is not found.
        """
        match = re.search(rf'"{re.escape(field)}"\s*:\s*\{{(.*?)\}}', text, re.DOTALL)
        if match:
            dict_str = match.group(1)
            result = {}
            for item in re.finditer(r'"([^"]*)"\s*:\s*\[(.*?)\]', dict_str, re.DOTALL):
                key = item.group(1)
                values = re.findall(r'"([^"]*)"', item.group(2))
                result[key] = values
            return result
        return None

    def generate_html(self, content: str) -> str:
        """Delegate to the HTML generator tool."""
        return self.html_gen_tool.generate(content)

    def generate_css(self, styles: Dict[str, Dict[str, str]]) -> str:
        """Delegate to the CSS generator tool; *styles* maps selector -> {property: value}."""
        return self.css_gen_tool.generate(styles)

    def generate_javascript(self, functionality: str) -> str:
        """Delegate to the JavaScript generator tool."""
        return self.js_gen_tool.generate(functionality)

    def create_project_files(self, config: "ProjectConfig") -> Dict[str, str]:
        """Build placeholder contents for every file listed in ``config.structure``.

        Args:
            config: Configuration whose ``structure`` maps directory -> file names.

        Returns:
            Mapping of ``os.path.join(directory, file)`` to generated text.
            ``.html``, ``.css`` and ``.js`` files get output from the matching
            generator; any other file gets a ``"Content for <file>"`` stub.
            Nothing is written to disk.
        """
        files = {}
        for directory, file_list in config.structure.items():
            for file in file_list:
                file_path = os.path.join(directory, file)
                if file.endswith('.html'):
                    files[file_path] = self.generate_html(f"Content for {file}")
                elif file.endswith('.css'):
                    files[file_path] = self.generate_css({"body": {"font-family": "Arial, sans-serif"}})
                elif file.endswith('.js'):
                    files[file_path] = self.generate_javascript(f"console.log('Script for {file}');")
                else:
                    files[file_path] = f"Content for {file}"
        return files

    def execute_project(self, project_description: str) -> Dict[str, str]:
        """Generate a config for *project_description* and return its file contents.

        Returns:
            The mapping from ``create_project_files``, or ``{}`` (after logging an
            error) when no configuration could be produced.
        """
        config = self.generate_project_config(project_description)
        if config:
            return self.create_project_files(config)
        self.logger.error("Failed to generate project configuration")
        return {}