Generate-Knowledge-Graphs / src /llm_extractor.py
CultriX's picture
First commit
e86199a
import json
import requests
from typing import List, Dict, Any, Optional
from config.settings import Config
class LLMExtractor:
    """Extract knowledge-graph entities and relationships from text via OpenRouter.

    All settings are read from ``Config``: the API key and base URL, the
    primary and backup model names, the entity importance threshold, and the
    caps on how many entities and relationships are returned.
    """

    def __init__(self):
        """Load the configuration and build the headers sent with every request."""
        self.config = Config()
        self.headers = {
            "Authorization": f"Bearer {self.config.OPENROUTER_API_KEY}",
            "Content-Type": "application/json"
        }

    @staticmethod
    def _normalize_name(value: Any) -> str:
        """Return the comparison key for a name: stripped and lower-cased.

        Anything that is not a string (``None``, numbers, ...) yields ``""``
        so callers can treat it as "no usable name" instead of crashing.
        """
        if not isinstance(value, str):
            return ""
        return value.strip().lower()

    @staticmethod
    def _importance_of(entity: Any) -> float:
        """Return the entity's ``importance`` as a float, or 0.0 if unusable.

        The score comes from model output, so it may be missing, ``None``,
        a numeric string, or NaN; none of those should abort processing.
        """
        if not isinstance(entity, dict):
            return 0.0
        try:
            score = float(entity.get("importance", 0))
        except (TypeError, ValueError):
            return 0.0
        # NaN compares unequal to itself and would make sorting unreliable.
        return 0.0 if score != score else score

    def extract_entities_and_relationships(self, text: str) -> Dict[str, Any]:
        """Extract entities and relationships from text using the LLM.

        The primary model is tried first. The backup model is used when the
        primary call raises *or* when its answer cannot be parsed (the parser
        reports that through an ``"error"`` key rather than an exception).

        Returns:
            The parsed result with ``"entities"`` and ``"relationships"``
            lists. If both models fail, both lists are empty and ``"error"``
            holds ``"Primary: <reason>, Backup: <reason>"``.
        """
        prompt = self._create_extraction_prompt(text)
        failures = []
        attempts = (("Primary", "EXTRACTION_MODEL"), ("Backup", "BACKUP_MODEL"))
        for label, model_attr in attempts:
            try:
                # Looked up lazily so the backup setting is only read if needed.
                model = getattr(self.config, model_attr)
                response = self._call_openrouter_api(prompt, model)
                result = self._parse_extraction_response(response)
            except Exception as exc:
                failures.append(f"{label}: {str(exc)}")
                continue
            if "error" not in result:
                return result
            failures.append(f"{label}: {result['error']}")
        return {
            "entities": [],
            "relationships": [],
            "error": ", ".join(failures)
        }

    def _create_extraction_prompt(self, text: str) -> str:
        """Create prompt for entity and relationship extraction."""
        return f"""
You are an expert knowledge graph extraction system. Analyze the following text and extract:
1. ENTITIES: Important people, organizations, locations, concepts, events, objects, etc.
2. RELATIONSHIPS: How these entities relate to each other
3. IMPORTANCE SCORES: Rate each entity's importance from 0.0 to 1.0 based on how central it is to the text
For each entity, provide:
- name: The entity name (standardized/canonical form)
- type: The entity type (PERSON, ORGANIZATION, LOCATION, CONCEPT, EVENT, OBJECT, etc.)
- importance: Score from 0.0 to 1.0
- description: Brief description of the entity's role/significance
For each relationship, provide:
- source: Source entity name
- target: Target entity name
- relationship: Type of relationship (works_at, located_in, part_of, causes, etc.)
- description: Brief description of the relationship
Only respond with a valid JSON object with this structure and nothing else. Your response must be valid, parsable JSON!!
=== JSON STRUCTURE FOR RESPONSE / RESPONSE FORMAT ===
{{
"entities": [
{{
"name": "entity_name",
"type": "ENTITY_TYPE",
"importance": 0.8,
"description": "Brief description"
}}
],
"relationships": [
{{
"source": "entity1",
"target": "entity2",
"relationship": "relationship_type",
"description": "Brief description"
}}
]
}}
=== END OF JSON STRUCTURE FOR RESPONSE / END OF RESPONSE FORMAT ===
TEXT TO ANALYZE:
{text}
Reply in valid json using the format above!
JSON OUTPUT:
"""

    def _call_openrouter_api(self, prompt: str, model: str) -> str:
        """Send the prompt to OpenRouter and return the first choice's text.

        Raises:
            ValueError: If no API key is configured.
            RuntimeError: If the HTTP status is not 200, or the response body
                does not contain a text message in its first choice.
        """
        if not self.config.OPENROUTER_API_KEY:
            raise ValueError("OpenRouter API key not configured")
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": 2048,
            "temperature": 0.1
        }
        response = requests.post(
            f"{self.config.OPENROUTER_BASE_URL}/chat/completions",
            headers=self.headers,
            json=payload,
            timeout=60
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"API call failed: {response.status_code} - {response.text}"
            )
        result = response.json()
        try:
            content = result["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as exc:
            raise RuntimeError("Invalid API response format") from exc
        if not isinstance(content, str):
            # A null/non-text content would otherwise fail later with an
            # unrelated attribute error inside the parser.
            raise RuntimeError("Invalid API response format")
        return content

    def _parse_extraction_response(self, response: str) -> Dict[str, Any]:
        """Parse the LLM response into structured data.

        The JSON object is taken from the first ``{`` to the last ``}`` so
        that surrounding chatter or code fences are ignored. Entities below
        the configured importance threshold are dropped, malformed items
        (non-dict entries, unusable importance values) are skipped rather
        than failing the whole response, and both lists are capped.

        Returns:
            The parsed dict with cleaned ``"entities"`` and
            ``"relationships"``; on failure, empty lists plus an ``"error"``
            message. This method does not raise.
        """
        try:
            if not isinstance(response, str):
                raise ValueError("No JSON found in response")
            start_idx = response.find("{")
            end_idx = response.rfind("}") + 1
            if start_idx == -1 or end_idx == 0:
                raise ValueError("No JSON found in response")
            data = json.loads(response[start_idx:end_idx])

            # The model may omit a key or emit null/non-list values.
            entities = data.get("entities")
            if not isinstance(entities, list):
                entities = []
            relationships = data.get("relationships")
            if not isinstance(relationships, list):
                relationships = []

            threshold = self.config.ENTITY_IMPORTANCE_THRESHOLD
            kept_entities = [
                entity for entity in entities
                if isinstance(entity, dict)
                and self._importance_of(entity) >= threshold
            ]
            kept_relationships = [
                rel for rel in relationships if isinstance(rel, dict)
            ]

            data["entities"] = kept_entities[:self.config.MAX_ENTITIES]
            data["relationships"] = (
                kept_relationships[:self.config.MAX_RELATIONSHIPS]
            )
            return data
        except json.JSONDecodeError as e:
            return {
                "entities": [],
                "relationships": [],
                "error": f"JSON parsing error: {str(e)}"
            }
        except Exception as e:
            return {
                "entities": [],
                "relationships": [],
                "error": f"Response parsing error: {str(e)}"
            }

    def process_chunks(self, chunks: List[str]) -> Dict[str, Any]:
        """Process multiple text chunks and combine results.

        A failing chunk is recorded (1-based index) and skipped; it does not
        stop the remaining chunks from being processed.

        Returns:
            A dict with deduplicated ``"entities"``, the ``"relationships"``
            whose endpoints both survive deduplication, and ``"errors"``
            (a list of messages, or ``None`` when every chunk succeeded).
        """
        all_entities = []
        all_relationships = []
        errors = []
        for i, chunk in enumerate(chunks):
            try:
                result = self.extract_entities_and_relationships(chunk)
            except Exception as e:
                errors.append(f"Chunk {i+1}: {str(e)}")
                continue
            if "error" in result:
                errors.append(f"Chunk {i+1}: {result['error']}")
                continue
            all_entities.extend(result.get("entities", []))
            all_relationships.extend(result.get("relationships", []))

        unique_entities = self._deduplicate_entities(all_entities)
        # Validate after deduplication so relationships only reference
        # entities that are actually returned.
        valid_relationships = self._validate_relationships(
            all_relationships, unique_entities
        )
        return {
            "entities": unique_entities,
            "relationships": valid_relationships,
            "errors": errors or None
        }

    def _deduplicate_entities(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Drop entities whose name repeats, ignoring case and padding.

        The first occurrence of each name is kept. Entries without a usable
        name are discarded. The result is sorted by importance (highest
        first) and capped at ``MAX_ENTITIES``.
        """
        seen_names = set()
        unique_entities = []
        for entity in entities:
            if not isinstance(entity, dict):
                continue
            key = self._normalize_name(entity.get("name"))
            if key and key not in seen_names:
                seen_names.add(key)
                unique_entities.append(entity)
        unique_entities.sort(key=self._importance_of, reverse=True)
        return unique_entities[:self.config.MAX_ENTITIES]

    def _validate_relationships(self, relationships: List[Dict[str, Any]], entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Keep only relationships whose source and target are known entities.

        Names are compared with the same normalisation used for
        deduplication, so differences in case or padding do not cause a
        relationship to be dropped. The result is capped at
        ``MAX_RELATIONSHIPS``.
        """
        entity_names = {
            self._normalize_name(entity.get("name"))
            for entity in entities
            if isinstance(entity, dict)
        }
        # An empty key would let relationships with missing endpoints match.
        entity_names.discard("")
        valid_relationships = [
            rel for rel in relationships
            if isinstance(rel, dict)
            and self._normalize_name(rel.get("source")) in entity_names
            and self._normalize_name(rel.get("target")) in entity_names
        ]
        return valid_relationships[:self.config.MAX_RELATIONSHIPS]