Spaces:
Sleeping
Sleeping
import os | |
import json | |
import logging | |
from typing import List, Dict, Optional, Any | |
import re | |
from abc import ABC, abstractmethod | |
from huggingface_hub import HfApi, InferenceClient | |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
from dataclasses import dataclass | |
@dataclass
class ProjectConfig:
    """Description of a web project: name, summary, tech stack and file layout.

    NOTE(review): this module defines ``ProjectConfig`` a second time further
    down (a plain class taking the same four constructor parameters); that later
    definition replaces this one at import time.
    """

    name: str  # short, descriptive project name
    description: str  # brief summary of the project
    technologies: List[str]  # e.g. ["HTML", "CSS", "JavaScript"]
    structure: Dict[str, List[str]]  # directory -> list of file names
class WebDevelopmentTool(ABC):
    """Abstract base for code-generating tools with a display name and description.

    Concrete tools must implement ``generate_code``; the base class itself
    cannot be instantiated.
    """

    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description

    @abstractmethod
    def generate_code(self, *args, **kwargs) -> str:
        """Return generated source code; each concrete tool defines its own arguments."""
        raise NotImplementedError
class HTMLGenerator(WebDevelopmentTool):
    """Tool that turns a flat ``{tag: content}`` mapping into a one-line HTML page.

    NOTE(review): a second ``HTMLGenerator`` (with a ``generate`` method) is
    defined later in this module and replaces this class at import time.
    """

    def __init__(self):
        super().__init__("HTML Generator", "Generates HTML code for web pages")

    def generate_code(self, structure: Dict[str, Any]) -> str:
        """Wrap each value in its tag, in mapping order, inside ``<html><body>``.

        Neither tag names nor content are escaped.
        """
        inner = "".join(f"<{name}>{payload}</{name}>" for name, payload in structure.items())
        return f"<html><body>{inner}</body></html>"
class CSSGenerator(WebDevelopmentTool):
    """Tool that renders ``{selector: {property: value}}`` as multi-line CSS rules.

    NOTE(review): a second ``CSSGenerator`` (with a ``generate`` method) is
    defined later in this module and replaces this class at import time.
    """

    def __init__(self):
        super().__init__("CSS Generator", "Generates CSS code for styling web pages")

    def generate_code(self, styles: Dict[str, Dict[str, str]]) -> str:
        """Emit one block per selector with one ``prop: value;`` line per property."""
        rules = []
        for sel, decls in styles.items():
            parts = [f"{sel} {{\n"]
            parts.extend(f" {key}: {val};\n" for key, val in decls.items())
            parts.append("}\n")
            rules.append("".join(parts))
        return "".join(rules)
class JavaScriptGenerator(WebDevelopmentTool):
    """Tool that renders function specs as JavaScript function declarations.

    NOTE(review): a second ``JavaScriptGenerator`` (with a ``generate`` method)
    is defined later in this module and replaces this class at import time.
    """

    def __init__(self):
        super().__init__("JavaScript Generator", "Generates JavaScript code for web functionality")

    def generate_code(self, functions: List[Dict[str, Any]]) -> str:
        """Emit ``function name(params) { body }`` for each spec, blank-line separated.

        Each spec is read for ``name``, ``params`` (joined with ``", "``, so an
        iterable of str) and ``body``.
        """
        chunks = []
        for spec in functions:
            fn_name = spec['name']
            arg_list = ', '.join(spec['params'])
            fn_body = spec['body']
            chunks.append(f"function {fn_name}({arg_list}) {{\n {fn_body}\n}}\n\n")
        return "".join(chunks)
class ProjectConfig:
    """Mutable record describing a web project planned by the agent."""

    def __init__(self, name: str, description: str, technologies: List[str], structure: Dict[str, List[str]]):
        # Human-readable identity of the project.
        self.name, self.description = name, description
        # Tech stack, and layout as directory -> list of file names.
        self.technologies, self.structure = technologies, structure
class HTMLGenerator:
    """Minimal HTML page wrapper used by the agent's ``generate_html``."""

    def generate(self, content: str) -> str:
        """Return ``content`` placed verbatim (unescaped) between ``<html><body>`` tags."""
        page = f"<html><body>{content}</body></html>"
        return page
class CSSGenerator:
    """Single-line-per-rule CSS renderer used by the agent's ``generate_css``."""

    def generate(self, styles: Dict[str, Dict[str, str]]) -> str:
        """Render ``{selector: {property: value}}`` as one CSS rule per line.

        Example: ``{"body": {"margin": "0", "color": "red"}}`` becomes
        ``body { margin: 0; color: red }``. Rules are separated by newlines and
        an empty mapping yields an empty string.

        The previous annotation ``Dict[str, str]`` was wrong: every value is
        iterated with ``.items()``, so it must itself be a property mapping.
        """
        rules = []
        for selector, properties in styles.items():
            declarations = "; ".join(f"{prop}: {value}" for prop, value in properties.items())
            rules.append(f"{selector} {{ {declarations} }}")
        return "\n".join(rules)
class JavaScriptGenerator:
    """Wraps a JavaScript snippet for the agent's ``generate_javascript``."""

    def generate(self, functionality: str) -> str:
        """Return ``functionality`` inserted verbatim into ``function main() { ... }``."""
        wrapped = f"function main() {{ {functionality} }}"
        return wrapped
class EnhancedAIAgent:
    """LLM-backed agent that plans a web project and renders placeholder files for it.

    ``execute_project`` asks the hosted model (through ``InferenceClient``) for a
    ``ProjectConfig`` encoded as JSON, then produces stub file contents with the
    HTML/CSS/JS generator tools held on the instance.
    """

    def __init__(self, name: str, description: str, skills: List[str], model_name: str):
        """Set up generator tools, Hugging Face clients and logging.

        Args:
            name: Display name of the agent.
            description: Free-text description of the agent.
            skills: Skill labels; stored on the instance but not read by this class.
            model_name: Hugging Face model id used by the Inference API client and
                by the locally loaded tokenizer/model/pipeline.
        """
        self.name = name
        self.description = description
        self.skills = skills
        self.model_name = model_name
        self.html_gen_tool = HTMLGenerator()
        self.css_gen_tool = CSSGenerator()
        self.js_gen_tool = JavaScriptGenerator()
        self.hf_api = HfApi()
        # HF_API_TOKEN may be unset, in which case token=None is passed.
        self.inference_client = InferenceClient(model=model_name, token=os.environ.get("HF_API_TOKEN"))
        # NOTE(review): no method in this class uses the local tokenizer, model or
        # pipeline (generation goes through inference_client), yet all three are
        # loaded here on construction. Kept because external callers may rely on
        # these attributes.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
        self.text_generation = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            clean_up_tokenization_spaces=True,
        )
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # getLogger(__name__) returns the same module-level logger to every
        # instance. Adding a handler unconditionally duplicated each log line once
        # per agent ever constructed, so attach it only the first time.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            self.logger.addHandler(handler)

    def generate_agent_response(self, prompt: str, max_new_tokens: int = 100) -> str:
        """Generate a completion for ``prompt`` through the Inference API.

        Args:
            prompt: Text sent to the model.
            max_new_tokens: Upper bound on generated tokens (default 100, as before).

        Returns:
            The generated text. On failure the error is logged and a string
            starting with ``"Error: "`` is returned instead of raising.
        """
        try:
            response = self.inference_client.text_generation(prompt, max_new_tokens=max_new_tokens)
        except Exception as e:
            self.logger.error(f"Error generating response: {str(e)}", exc_info=True)
            return f"Error: Unable to generate response. {str(e)}"
        self.logger.info(f"Generated response for prompt: {prompt[:50]}...")
        # With details=False and stream=False (the defaults used here) the client
        # returns a plain str. The old unconditional ``response.generated_text``
        # raised AttributeError on it, so every call ended up as an error string.
        # Objects that do carry ``generated_text`` are still supported.
        if isinstance(response, str):
            return response
        return getattr(response, "generated_text", str(response))

    def generate_project_config(self, project_description: str) -> Optional[ProjectConfig]:
        """Ask the model for a ``ProjectConfig`` describing ``project_description``.

        The first ``{`` ... last ``}`` span of the reply is parsed as JSON and
        passed to ``ProjectConfig``. If that fails (no JSON, invalid JSON, wrong or
        missing keys), a regex-based partial extraction is attempted.

        Returns:
            A ``ProjectConfig``, or None when nothing usable could be extracted.
        """
        # NOTE(review): generate_agent_response defaults to 100 new tokens, which
        # may truncate a full JSON config; consider passing a larger budget.
        prompt = (
            "\n"
            "Based on the following project description, generate a ProjectConfig object:\n"
            f"Description: {project_description}\n"
            "The ProjectConfig should include:\n"
            "- name: A short, descriptive name for the project\n"
            "- description: A brief summary of the project\n"
            '- technologies: A list of technologies to be used (e.g., ["HTML", "CSS", "JavaScript", "React"])\n'
            "- structure: A dictionary representing the file structure, where keys are directories and values are lists of files\n"
            "Respond with a JSON object representing the ProjectConfig.\n"
        )
        response = self.generate_agent_response(prompt)
        try:
            json_start = response.find('{')
            # rfind returns -1 when '}' is absent, so json_end is then 0 (never -1).
            # Require a closing brace that actually follows the opening one.
            json_end = response.rfind('}') + 1
            if json_start == -1 or json_end <= json_start:
                raise ValueError("No JSON object found in the response")
            config_dict = json.loads(response[json_start:json_end])
            # TypeError covers extra/missing keys and non-object JSON (e.g. a list),
            # which previously escaped and crashed execute_project.
            return ProjectConfig(**config_dict)
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            self.logger.error(f"Error parsing JSON from response: {str(e)}")
            self.logger.error(f"Full response from model: {response}")
        try:
            partial_config = self.extract_partial_config(response)
            if partial_config:
                self.logger.warning("Extracted partial config from malformed response")
                return partial_config
        except Exception as ex:
            self.logger.error(f"Failed to extract partial config: {str(ex)}")
        return None

    def extract_partial_config(self, response: str) -> Optional[ProjectConfig]:
        """Regex-salvage a config from model text that is not valid JSON.

        ``name`` and ``description`` are required. Missing ``technologies`` or
        ``structure`` default to empty containers. Returns None if a required
        field is absent or empty.
        """
        name = self.extract_field(response, "name")
        description = self.extract_field(response, "description")
        technologies = self.extract_list(response, "technologies")
        structure = self.extract_dict(response, "structure")
        if name and description:
            return ProjectConfig(
                name=name,
                description=description,
                technologies=technologies or [],
                structure=structure or {},
            )
        return None

    def extract_field(self, text: str, field: str) -> Optional[str]:
        """Return the value of ``"field": "..."`` in ``text`` (no escaped quotes), else None."""
        match = re.search(rf'"{re.escape(field)}"\s*:\s*"([^"]*)"', text)
        return match.group(1) if match else None

    def extract_list(self, text: str, field: str) -> Optional[List[str]]:
        """Return the quoted strings inside ``"field": [...]`` in ``text``, else None."""
        match = re.search(rf'"{re.escape(field)}"\s*:\s*\[(.*?)\]', text, re.DOTALL)
        if match:
            return re.findall(r'"([^"]*)"', match.group(1))
        return None

    def extract_dict(self, text: str, field: str) -> Optional[Dict[str, List[str]]]:
        """Return ``{key: [strings]}`` parsed from ``"field": { "key": [...], ... }``.

        The non-greedy match stops at the first ``}``, so only one level of
        ``"key": [list]`` entries is supported. Returns None if ``field`` is absent.
        """
        match = re.search(rf'"{re.escape(field)}"\s*:\s*\{{(.*?)\}}', text, re.DOTALL)
        if match:
            dict_str = match.group(1)
            result = {}
            for item in re.finditer(r'"([^"]*)"\s*:\s*\[(.*?)\]', dict_str, re.DOTALL):
                result[item.group(1)] = re.findall(r'"([^"]*)"', item.group(2))
            return result
        return None

    def generate_html(self, content: str) -> str:
        """Wrap ``content`` in a page using the HTML generator tool."""
        return self.html_gen_tool.generate(content)

    def generate_css(self, styles: Dict[str, Dict[str, str]]) -> str:
        """Render ``{selector: {property: value}}`` rules with the CSS generator tool."""
        return self.css_gen_tool.generate(styles)

    def generate_javascript(self, functionality: str) -> str:
        """Wrap a JavaScript snippet using the JS generator tool."""
        return self.js_gen_tool.generate(functionality)

    def create_project_files(self, config: ProjectConfig) -> Dict[str, str]:
        """Produce placeholder contents for every file listed in ``config.structure``.

        Returns:
            Mapping of ``os.path.join(directory, file)`` to generated text. Files
            are dispatched on extension: ``.html``, ``.css`` and ``.js`` get
            generator output, and anything else gets a plain placeholder string.
            Nothing is written to disk.
        """
        files = {}
        for directory, file_list in config.structure.items():
            for file in file_list:
                file_path = os.path.join(directory, file)
                if file.endswith('.html'):
                    files[file_path] = self.generate_html(f"Content for {file}")
                elif file.endswith('.css'):
                    files[file_path] = self.generate_css({"body": {"font-family": "Arial, sans-serif"}})
                elif file.endswith('.js'):
                    files[file_path] = self.generate_javascript(f"console.log('Script for {file}');")
                else:
                    files[file_path] = f"Content for {file}"
        return files

    def execute_project(self, project_description: str) -> Dict[str, str]:
        """Plan a project from ``project_description`` and return its file contents.

        Returns an empty dict (and logs an error) when no config could be produced.
        """
        config = self.generate_project_config(project_description)
        if config:
            return self.create_project_files(config)
        self.logger.error("Failed to generate project configuration")
        return {}