from datetime import date, datetime, time, timedelta

from solverforge_legacy.solver.score import (
    ConstraintCollectors,
    ConstraintFactory,
    HardSoftDecimalScore,
    Joiners,
    constraint_provider,
)

from .domain import Employee, Shift
|
|
|
|
def get_minute_overlap(shift1: Shift, shift2: Shift) -> int:
    """Return the number of whole minutes during which two shifts overlap.

    The overlap window runs from the later of the two starts to the earlier
    of the two ends. No clamping is applied, so the result is negative when
    the shifts do not overlap at all.

    Args:
        shift1: First shift; its ``start`` and ``end`` are read.
        shift2: Second shift; its ``start`` and ``end`` are read.

    Returns:
        The overlap in minutes, floored, as an ``int``.
    """
    overlap_start = max(shift1.start, shift2.start)
    overlap_end = min(shift1.end, shift2.end)
    # ``total_seconds()`` is a float, so ``// 60`` on its own yields a float
    # (e.g. ``60.0``); convert so the value matches the declared return type.
    return int((overlap_end - overlap_start).total_seconds() // 60)
|
|
|
|
def is_overlapping_with_date(shift: Shift, dt: date) -> bool:
    """Tell whether a shift starts or ends on the given calendar date.

    Only the dates of ``shift.start`` and ``shift.end`` are compared, so a
    date lying strictly between them is not reported as overlapping.

    Args:
        shift: Shift whose ``start`` and ``end`` datetimes are inspected.
        dt: Calendar date to test against.

    Returns:
        ``True`` if the shift's start date or end date equals ``dt``.
    """
    if shift.start.date() == dt:
        return True
    return shift.end.date() == dt
|
|
|
|
def overlapping_in_minutes(
    first_start_datetime: datetime,
    first_end_datetime: datetime,
    second_start_datetime: datetime,
    second_end_datetime: datetime,
) -> int:
    """Return how many whole minutes two time intervals have in common.

    Args:
        first_start_datetime: Start of the first interval.
        first_end_datetime: End of the first interval.
        second_start_datetime: Start of the second interval.
        second_end_datetime: End of the second interval.

    Returns:
        The length of the intersection in minutes, rounded down, as an
        ``int``; ``0`` when the intervals are disjoint or merely touch.
    """
    latest_start = max(first_start_datetime, second_start_datetime)
    earliest_end = min(first_end_datetime, second_end_datetime)
    if earliest_end <= latest_start:
        return 0
    # ``total_seconds()`` is a float; convert so the result is always an int
    # as the signature promises, instead of a float such as ``60.0``.
    return int((earliest_end - latest_start).total_seconds() // 60)
|
|
|
|
def get_shift_overlapping_duration_in_minutes(shift: Shift, dt: date) -> int:
    """Return how many whole minutes of a shift fall on a calendar date.

    The date is treated as the interval from its midnight up to the next
    midnight, and intersected with ``[shift.start, shift.end]`` via
    ``overlapping_in_minutes``.

    Args:
        shift: Shift whose ``start`` and ``end`` datetimes are read.
        dt: Calendar date to measure the overlap against.

    Returns:
        The overlap in minutes as an ``int`` (``0`` if there is none).
    """
    # NOTE(review): the bounds are naive datetimes, so this presumably expects
    # naive ``shift.start``/``shift.end`` as well - confirm against the domain.
    day_start = datetime.combine(dt, datetime.min.time())
    # Use the following midnight as the end of the day. Ending the window at
    # 23:59:59.999999 loses the final minute for a shift that runs up to or
    # past midnight (e.g. 22:00-06:00 would count 119 minutes, not 120).
    day_end = day_start + timedelta(days=1)
    overlap = overlapping_in_minutes(day_start, day_end, shift.start, shift.end)
    return int(overlap)
|
|
|
|
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
    """Assemble the list of constraints used to score a schedule.

    Args:
        constraint_factory: Factory handed to every constraint builder.

    Returns:
        A list holding the hard constraints followed by the soft ones.
    """
    # NOTE(review): max_shifts_per_employee is defined in this module but is
    # not registered here - confirm whether that omission is intentional.

    # Constraints that penalize with ONE_HARD.
    hard_constraints = [
        required_skill(constraint_factory),
        no_overlapping_shifts(constraint_factory),
        at_least_10_hours_between_two_shifts(constraint_factory),
        one_shift_per_day(constraint_factory),
        unavailable_employee(constraint_factory),
    ]
    # Constraints that penalize or reward with ONE_SOFT.
    soft_constraints = [
        undesired_day_for_employee(constraint_factory),
        desired_day_for_employee(constraint_factory),
        balance_employee_shift_assignments(constraint_factory),
    ]
    return hard_constraints + soft_constraints
|
|
|
|
def required_skill(constraint_factory: ConstraintFactory):
    """Hard constraint: a shift must be covered with its required skill.

    Every shift for which ``has_required_skill()`` is false costs one hard
    point.

    Args:
        constraint_factory: Factory used to start the constraint stream.

    Returns:
        The constraint named "Missing required skill".
    """
    shifts = constraint_factory.for_each(Shift)
    unskilled = shifts.filter(lambda shift: not shift.has_required_skill())
    penalized = unskilled.penalize(HardSoftDecimalScore.ONE_HARD)
    return penalized.as_constraint("Missing required skill")
|
|
|
|
def no_overlapping_shifts(constraint_factory: ConstraintFactory):
    """Hard constraint: one employee cannot work two shifts at the same time.

    Each unique pair of shifts whose employees share a name and whose
    ``[start, end]`` ranges overlap is penalized by the number of overlapping
    minutes, as computed by ``get_minute_overlap``.

    Args:
        constraint_factory: Factory used to start the constraint stream.

    Returns:
        The constraint named "Overlapping shift".
    """
    same_employee = Joiners.equal(lambda shift: shift.employee.name)
    overlapping_times = Joiners.overlapping(
        lambda shift: shift.start, lambda shift: shift.end
    )
    clashing_pairs = constraint_factory.for_each_unique_pair(
        Shift, same_employee, overlapping_times
    )
    penalized = clashing_pairs.penalize(
        HardSoftDecimalScore.ONE_HARD, get_minute_overlap
    )
    return penalized.as_constraint("Overlapping shift")
|
|
|
|
def at_least_10_hours_between_two_shifts(constraint_factory: ConstraintFactory):
    """Hard constraint: an employee rests 10 hours between consecutive shifts.

    A shift is joined with every shift of the same employee name that starts
    at or after the first one ends. Pairs whose gap is below 10 whole hours
    are penalized by the number of minutes missing from 600.

    Args:
        constraint_factory: Factory used to start the constraint stream.

    Returns:
        The constraint named "At least 10 hours between 2 shifts".
    """

    def rest_is_too_short(first_shift, second_shift):
        # Gap between the end of the earlier shift and the start of the later.
        gap = second_shift.start - first_shift.end
        return gap.total_seconds() // (60 * 60) < 10

    def missing_rest_minutes(first_shift, second_shift):
        gap = second_shift.start - first_shift.end
        # 600 minutes == 10 hours; the shortfall is the penalty weight.
        return 600 - (gap.total_seconds() // 60)

    same_employee = Joiners.equal(lambda shift: shift.employee.name)
    first_ends_before_second_starts = Joiners.less_than_or_equal(
        lambda shift: shift.end, lambda shift: shift.start
    )
    consecutive_pairs = constraint_factory.for_each(Shift).join(
        Shift, same_employee, first_ends_before_second_starts
    )
    too_close = consecutive_pairs.filter(rest_is_too_short)
    penalized = too_close.penalize(
        HardSoftDecimalScore.ONE_HARD, missing_rest_minutes
    )
    return penalized.as_constraint("At least 10 hours between 2 shifts")
|
|
|
|
def one_shift_per_day(constraint_factory: ConstraintFactory):
    """Hard constraint: an employee works at most one shift per day.

    Each unique pair of shifts with the same employee name and the same
    start date costs one hard point.

    Args:
        constraint_factory: Factory used to start the constraint stream.

    Returns:
        The constraint named "Max one shift per day".
    """
    same_employee = Joiners.equal(lambda shift: shift.employee.name)
    same_start_day = Joiners.equal(lambda shift: shift.start.date())
    same_day_pairs = constraint_factory.for_each_unique_pair(
        Shift, same_employee, same_start_day
    )
    penalized = same_day_pairs.penalize(HardSoftDecimalScore.ONE_HARD)
    return penalized.as_constraint("Max one shift per day")
|
|
|
|
def unavailable_employee(constraint_factory: ConstraintFactory):
    """Hard constraint: do not schedule an employee on an unavailable date.

    Each shift is joined with its assigned employee and expanded into one
    match per entry of ``employee.unavailable_dates``. Matches for which
    ``is_overlapping_with_date`` holds are penalized by the number of minutes
    of the shift that fall on that date.

    Args:
        constraint_factory: Factory used to start the constraint stream.

    Returns:
        The constraint named "Unavailable employee".
    """

    def minutes_on_date(shift, unavailable_date):
        day_start = datetime.combine(unavailable_date, time(0, 0, 0))
        # End the day at the following midnight. Bounding it at 23:59:59
        # drops the final minute of a shift that runs up to or through
        # midnight (22:00-06:00 would weigh 119 minutes instead of 120).
        day_end = day_start + timedelta(days=1)
        overlap = min(shift.end, day_end) - max(shift.start, day_start)
        return int(overlap.total_seconds() / 60)

    return (
        constraint_factory.for_each(Shift)
        .join(
            Employee,
            Joiners.equal(lambda shift: shift.employee, lambda employee: employee),
        )
        .flatten_last(lambda employee: employee.unavailable_dates)
        .filter(is_overlapping_with_date)
        .penalize(HardSoftDecimalScore.ONE_HARD, minutes_on_date)
        .as_constraint("Unavailable employee")
    )
|
|
|
|
def max_shifts_per_employee(constraint_factory: ConstraintFactory):
    """Hard constraint: cap every employee at 12 shifts.

    Shifts are counted per employee; an employee above the cap is penalized
    by one hard point for each shift beyond 12.

    The cap of 12 was picked from the demo data dimensions: the SMALL
    dataset has 139 shifts for 15 employees (about 9.3 on average), so 12
    leaves headroom while still preventing extreme imbalance. A cap that is
    too low (e.g. 5) would make the problem infeasible, so keep it
    compatible with the size of the data.

    Args:
        constraint_factory: Factory used to start the constraint stream.

    Returns:
        The constraint named "Max 12 shifts per employee".
    """
    shifts_per_employee = constraint_factory.for_each(Shift).group_by(
        lambda shift: shift.employee, ConstraintCollectors.count()
    )
    overloaded = shifts_per_employee.filter(
        lambda employee, shift_count: shift_count > 12
    )
    penalized = overloaded.penalize(
        HardSoftDecimalScore.ONE_HARD,
        lambda employee, shift_count: shift_count - 12,
    )
    return penalized.as_constraint("Max 12 shifts per employee")
|
|
|
|
def undesired_day_for_employee(constraint_factory: ConstraintFactory):
    """Soft constraint: avoid scheduling an employee on an undesired date.

    Each shift is joined with its assigned employee and expanded into one
    match per entry of ``employee.undesired_dates``. Matches for which
    ``shift.is_overlapping_with_date`` holds are penalized by the number of
    minutes of the shift that fall on that date.

    Args:
        constraint_factory: Factory used to start the constraint stream.

    Returns:
        The constraint named "Undesired day for employee".
    """

    def minutes_on_date(shift, undesired_date):
        day_start = datetime.combine(undesired_date, time(0, 0, 0))
        # End the day at the following midnight. Bounding it at 23:59:59
        # drops the final minute of a shift that runs up to or through
        # midnight (22:00-06:00 would weigh 119 minutes instead of 120).
        day_end = day_start + timedelta(days=1)
        overlap = min(shift.end, day_end) - max(shift.start, day_start)
        return int(overlap.total_seconds() / 60)

    return (
        constraint_factory.for_each(Shift)
        .join(
            Employee,
            Joiners.equal(lambda shift: shift.employee, lambda employee: employee),
        )
        .flatten_last(lambda employee: employee.undesired_dates)
        .filter(
            lambda shift, undesired_date: shift.is_overlapping_with_date(undesired_date)
        )
        .penalize(HardSoftDecimalScore.ONE_SOFT, minutes_on_date)
        .as_constraint("Undesired day for employee")
    )
|
|
|
|
def desired_day_for_employee(constraint_factory: ConstraintFactory):
    """Soft constraint: prefer scheduling an employee on a desired date.

    Each shift is joined with its assigned employee and expanded into one
    match per entry of ``employee.desired_dates``. Matches for which
    ``shift.is_overlapping_with_date`` holds are rewarded by the number of
    minutes of the shift that fall on that date.

    Args:
        constraint_factory: Factory used to start the constraint stream.

    Returns:
        The constraint named "Desired day for employee".
    """

    def minutes_on_date(shift, desired_date):
        day_start = datetime.combine(desired_date, time(0, 0, 0))
        # End the day at the following midnight. Bounding it at 23:59:59
        # drops the final minute of a shift that runs up to or through
        # midnight (22:00-06:00 would weigh 119 minutes instead of 120).
        day_end = day_start + timedelta(days=1)
        overlap = min(shift.end, day_end) - max(shift.start, day_start)
        return int(overlap.total_seconds() / 60)

    return (
        constraint_factory.for_each(Shift)
        .join(
            Employee,
            Joiners.equal(lambda shift: shift.employee, lambda employee: employee),
        )
        .flatten_last(lambda employee: employee.desired_dates)
        .filter(
            lambda shift, desired_date: shift.is_overlapping_with_date(desired_date)
        )
        .reward(HardSoftDecimalScore.ONE_SOFT, minutes_on_date)
        .as_constraint("Desired day for employee")
    )
|
|
|
|
def balance_employee_shift_assignments(constraint_factory: ConstraintFactory):
    """Soft constraint: spread shifts evenly across employees.

    Shifts are counted per employee, employees missing from that grouping
    are added with a count of 0, and the counts are fed into a load-balance
    collector. The collector's ``unfairness()`` value is the soft penalty.

    Args:
        constraint_factory: Factory used to start the constraint stream.

    Returns:
        The constraint named "Balance employee shift assignments".
    """
    counts_for_assigned = constraint_factory.for_each(Shift).group_by(
        lambda shift: shift.employee, ConstraintCollectors.count()
    )
    # Give every Employee not already in the stream a count of zero.
    counts_for_everyone = counts_for_assigned.complement(Employee, lambda e: 0)
    balance = counts_for_everyone.group_by(
        ConstraintCollectors.load_balance(
            lambda employee, shift_count: employee,
            lambda employee, shift_count: shift_count,
        )
    )
    penalized = balance.penalize_decimal(
        HardSoftDecimalScore.ONE_SOFT,
        lambda load_balance: load_balance.unfairness(),
    )
    return penalized.as_constraint("Balance employee shift assignments")
|
|