vsp-demo / src /vsp /app /classifiers /education_classifier.py
pquiggles's picture
fixed a lot of stuff to meet spencer's requirements
9347485
from enum import Enum
from pydantic import BaseModel, Field
from vsp.app import bindings
from vsp.app.model.linkedin.linkedin_formatters import format_education, format_profile_as_resume
from vsp.app.model.linkedin.linkedin_models import Education, LinkedinProfile
from vsp.app.prompts.prompt_loader import PromptLoader
from vsp.llm.llm_service import LLMService
class SchoolType(Enum):
    """
    Closed set of categories an education entry can be classified into.

    Each member's value is a human-readable label for the category. The member
    names (e.g. ``UNDERGRAD_COMPLETED``) are the tokens the parser matches the
    model's upper-cased ``output`` field against.
    """

    # Schooling before university.
    PRIMARY_SECONDARY = "Primary / Secondary School"

    # Undergraduate study, split by whether it was completed.
    UNDERGRAD_INCOMPLETE = "Undergraduate (Incomplete)"
    UNDERGRAD_COMPLETED = "Undergraduate (Completed)"

    # Post-graduate and professional programs.
    MBA = "MBA"
    LAW_SCHOOL = "Law School"
    GRAD_SCHOOL = "Graduate School"
    PHD = "PhD"

    # Catch-all category.
    OTHER = "Other"
# Maps each SchoolType member name (e.g. "UNDERGRAD_COMPLETED") to its member.
# The classifier's output parser uses it to resolve the upper-cased "output" field.
# The table is derived from the enum itself rather than written out by hand, so
# adding a SchoolType member cannot leave this table stale.
_SCHOOL_TYPE_MAPPING: dict[str, SchoolType] = {
    name: member for name, member in SchoolType.__members__.items()
}
class EducationClassification(BaseModel):
    """
    Immutable result of classifying a single education entry.

    Attributes:
        output (SchoolType): The category assigned to the education entry.
        reasoning (str): The justification text returned alongside that category.
    """

    # frozen=True makes instances immutable, and pydantic also makes frozen models hashable.
    model_config = {"frozen": True}

    output: SchoolType = Field(description="The classified school type")
    reasoning: str = Field(description="Explanation for the classification")
class EducationClassifier:
    """
    A class for classifying education items from Linkedin profiles.

    This classifier uses a language model to determine the type of educational
    institution and program based on the information provided in a Linkedin profile.
    It works in three steps:

    1. Render the "education_classifier/1 - education_classifier" prompt template
       with the formatted resume and the specific education entry.
    2. Send the prompt to the configured LLM service.
    3. Parse the model's line-oriented ``key: value`` reply with ``_parse_output``.

    Attributes:
        _llm_service (LLMService): The language model service used for classification.
        _prompt_template (Any): The template for generating prompts for the language model.
        _prompt_loader (PromptLoader): The loader for prompt templates.
    """

    @staticmethod
    def _parse_output(output: str) -> EducationClassification:
        """
        Parse the raw output from the language model into an EducationClassification.

        The expected reply has one ``key: value`` field per line, for example::

            output: UNDERGRAD_COMPLETED
            reasoning: The profile lists a completed bachelor's degree.

        Parsing rules:
            * A line containing ":" starts a new field. The text before the first ":"
              is the key, matched case-insensitively; the rest is the value.
            * A repeated key replaces the earlier occurrence.
            * A non-blank line without ":" continues the previous field's value, so
              multi-line reasoning is kept, joined with newlines. Such lines that
              appear before any field are ignored.
            * Blank lines are ignored.
            * Surrounding quotes/backticks are stripped from the "output" value.
              The value is then upper-cased and looked up among SchoolType member names.

        Args:
            output (str): The raw output string from the language model.

        Returns:
            EducationClassification: A structured representation of the classification result.

        Raises:
            ValueError: If the "output" or "reasoning" field is missing, or if the
                output contains an unknown school type.
        """
        fields: dict[str, list[str]] = {}
        current_key: str | None = None
        for raw_line in output.strip().split("\n"):
            line = raw_line.strip()
            if not line:
                # Blank lines carry no information.
                continue
            key, separator, value = line.partition(":")
            if separator:
                # "key: value" starts a new field; only the first ":" splits, so values may contain colons.
                current_key = key.strip().lower()
                fields[current_key] = [value.strip()]
            elif current_key is not None:
                # No colon: a wrapped continuation of the previous field (typically multi-line reasoning).
                fields[current_key].append(line)

        parsed = {key: "\n".join(part for part in parts if part) for key, parts in fields.items()}

        if "output" not in parsed:
            raise ValueError(f"Missing 'output' field in model response: {output!r}")
        # Tolerate the model wrapping the type in quotes or backticks, e.g. `MBA` or "MBA".
        school_type = _SCHOOL_TYPE_MAPPING.get(parsed["output"].strip("`'\" ").upper())
        if school_type is None:
            raise ValueError(f"Unknown school type: {parsed['output']}")
        if "reasoning" not in parsed:
            raise ValueError(f"Missing 'reasoning' field in model response: {output!r}")

        return EducationClassification(output=school_type, reasoning=parsed["reasoning"])

    def __init__(
        self, llm_service: LLMService = bindings.open_ai_service, prompt_loader: PromptLoader = bindings.prompt_loader
    ):
        """
        Initialize the EducationClassifier and load its prompt template.

        Note that the default arguments are evaluated once, when the class is defined.

        Args:
            llm_service (LLMService, optional): The language model service to use.
                Defaults to ``bindings.open_ai_service``.
            prompt_loader (PromptLoader, optional): The prompt loader to use.
                Defaults to ``bindings.prompt_loader``.
        """
        self._llm_service = llm_service
        self._prompt_template = prompt_loader.load_template("education_classifier/1 - education_classifier")
        self._prompt_loader = prompt_loader

    async def classify_education(
        self, linkedin_profile: LinkedinProfile, education: Education
    ) -> EducationClassification:
        """
        Classify a single education item from a Linkedin profile.

        Builds a prompt from the loaded template with two template variables:
        ``resume`` is the whole profile formatted as a resume, and ``education`` is
        the formatted education entry. It then evaluates the prompt against the
        configured LLM service, using ``_parse_output`` as the output formatter.

        Args:
            linkedin_profile (LinkedinProfile): The full Linkedin profile of the individual.
            education (Education): The specific education item to classify.

        Returns:
            EducationClassification: The classification result for the education item.

        Raises:
            ValueError: Presumably propagated from ``_parse_output`` when the model
                response cannot be parsed; this assumes ``prompt.evaluate`` does not
                swallow formatter errors (confirm).
        """
        prompt = self._prompt_loader.create_prompt(
            self._prompt_template,
            llm_service=self._llm_service,
            output_formatter=EducationClassifier._parse_output,
            resume=format_profile_as_resume(linkedin_profile),
            education=format_education(education),
        )
        # NOTE(review): the ignore presumably exists because evaluate() is not typed to
        # return the formatter's result type — confirm against PromptLoader.
        return await prompt.evaluate()  # type: ignore