Phi2-Fine-Tuning
/
phivenv
/Lib
/site-packages
/torch
/distributed
/pipelining
/_schedule_visualizer.py
| # Copyright (c) Meta Platforms, Inc. and affiliates | |
| """ | |
| This visualizer requires matplotlib to be installed. | |
| Example usage: | |
| ops = get_schedule_ops("InterleavedZeroBubble", 4, 8) | |
| visualize_schedule(ops, "test.png") | |
| """ | |
| from typing import Optional, Union | |
| from unittest import mock | |
| from torch.distributed.pipelining.schedules import ( | |
| _Action, | |
| _ComputationType, | |
| _PipelineSchedule, | |
| get_schedule_class, | |
| PipelineScheduleMulti, | |
| PipelineScheduleSingle, | |
| ) | |
| from torch.distributed.pipelining.stage import PipelineStage | |
| def get_schedule_ops( | |
| schedule: Union[str, _PipelineSchedule], | |
| pp_degree: int, | |
| num_microbatches: int, | |
| num_stages_per_rank: Optional[int] = None, | |
| ) -> list[list[Optional[_Action]]]: | |
| """ | |
| Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists | |
| where each inner list represents a rank and each element in the inner list represents an action. | |
| The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance. | |
| """ | |
| if isinstance(schedule, str): | |
| schedule_class = get_schedule_class(schedule) | |
| elif type(schedule) == _PipelineSchedule: | |
| schedule_class = schedule | |
| else: | |
| raise ValueError(f"Invalid schedule: {schedule}") | |
| # Create a mock of the PipelineStage class | |
| mock_pipeline_stage = mock.create_autospec(PipelineStage, instance=True) | |
| # Set the return values for group_rank and group_size methods | |
| mock_pipeline_stage.group_rank = 0 | |
| mock_pipeline_stage.group_size = pp_degree | |
| mock_pipeline_stage.submod = None | |
| # Check num_stages_per_rank is valid | |
| if issubclass(schedule_class, PipelineScheduleSingle): | |
| if num_stages_per_rank is None: | |
| num_stages_per_rank = 1 | |
| assert num_stages_per_rank == 1 | |
| stages = mock_pipeline_stage | |
| stages.num_stages = num_stages_per_rank * pp_degree | |
| elif issubclass(schedule_class, PipelineScheduleMulti): | |
| if num_stages_per_rank is None: | |
| num_stages_per_rank = 2 | |
| assert num_stages_per_rank >= 2 | |
| stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)] | |
| for stage in stages: | |
| stage.num_stages = num_stages_per_rank * pp_degree | |
| else: | |
| raise ValueError(f"Invalid schedule: {schedule_class}") | |
| # Instantiate the schedule class | |
| schedule_instance = schedule_class(stages, num_microbatches) | |
| # Convert to List[List[_Action]] | |
| all_actions = [] | |
| for rank in range(pp_degree): | |
| all_actions.append(schedule_instance.pipeline_order[rank]) | |
| # Return the pipeline order | |
| return all_actions | |
| class _ComputationTypeColor: | |
| def __init__( | |
| self, | |
| color: str, | |
| text: str = "", | |
| width: int = 1, | |
| ): | |
| self.color = color | |
| self.width = width | |
| self.text = text | |
| # Update the mapping to use _ComputationTypeColor instances | |
| action_type_to_color_mapping = { | |
| _ComputationType.FORWARD: _ComputationTypeColor("blue", "Forward"), | |
| _ComputationType.BACKWARD_INPUT: _ComputationTypeColor("teal", "Backward Input"), | |
| _ComputationType.BACKWARD_WEIGHT: _ComputationTypeColor("green", "Backward Weight"), | |
| _ComputationType.FULL_BACKWARD: _ComputationTypeColor("orange", "Full Backward", 2), | |
| } | |
| def visualize_schedule( | |
| schedule: list[list[Optional[_Action]]], filename: Optional[str] = None | |
| ) -> None: | |
| """ | |
| Visualize the schedule using matplotlib. | |
| The schedule is a list of lists where each inner list represents a rank and each element in the inner list represents an action. | |
| The actions are represented as rectangles with different colors based on their computation type. | |
| The filename is optional and if provided, the plot will be saved to that file. | |
| """ | |
| import matplotlib.pyplot as plt | |
| from matplotlib.patches import Rectangle | |
| plt.rcParams["font.family"] = ( | |
| "DejaVu Sans" # or any other font available on your system | |
| ) | |
| num_ranks = len(schedule) | |
| max_actions = max(len(rank) for rank in schedule) | |
| # Increase the figure size to provide more space for the legend | |
| fig, ax = plt.subplots(figsize=(max_actions + 2, num_ranks + 2)) | |
| max_draw_position = -1 | |
| # Calculate dynamic font size based on figure size | |
| font_size = min(max_actions, num_ranks) + 4 | |
| used_computation = set() | |
| for rank_idx, actions in enumerate(schedule): | |
| draw_position = 0 # Initialize drawing position for each rank | |
| for action in actions: | |
| if action is not None: | |
| comp_type_color = action_type_to_color_mapping.get( | |
| action.computation_type, _ComputationTypeColor("black") | |
| ) | |
| used_computation.add(action.computation_type) | |
| color = comp_type_color.color | |
| width = comp_type_color.width | |
| # Draw the rectangle to represent the action duration | |
| rect = Rectangle( | |
| (draw_position, num_ranks - rank_idx - 1), | |
| width, | |
| 1, | |
| facecolor=color, | |
| edgecolor="black", | |
| ) | |
| ax.add_patch(rect) | |
| # Draw the text centered within the rectangle | |
| ax.text( | |
| draw_position + width / 2, | |
| num_ranks - rank_idx - 1 + 0.5, | |
| str(action), | |
| ha="center", | |
| va="center", | |
| fontsize=font_size, | |
| color="white", | |
| ) | |
| # Increment the drawing position by the width of the current action | |
| draw_position += width | |
| else: | |
| draw_position += 1 # Move to the next | |
| max_draw_position = max(max_draw_position, draw_position) | |
| ax.set_xlim(-0.5, max_draw_position + 1) | |
| ax.set_ylim(-0.5, num_ranks + 0.5) # Add extra space at the top | |
| # Set y-ticks to be in the middle of each rank's row | |
| ax.set_yticks([num_ranks - rank_idx - 0.5 for rank_idx in range(num_ranks)]) | |
| ax.set_yticklabels([f"Rank {i}" for i in range(num_ranks)], fontsize=font_size) | |
| ax.set_xticklabels([]) | |
| # Remove grid lines and ticks | |
| ax.grid(False) | |
| # Add legend with larger font size | |
| legend_elements = [ | |
| Rectangle( | |
| (0, 0), | |
| 1, | |
| 1, | |
| facecolor=action_type_to_color_mapping[comp_type].color, | |
| edgecolor="black", | |
| label=action_type_to_color_mapping[comp_type].text, | |
| ) | |
| for comp_type in used_computation | |
| ] | |
| ax.legend(handles=legend_elements, loc="upper right", fontsize=font_size) | |
| # Save to file if filename is provided, otherwise display the plot | |
| if filename: | |
| plt.savefig(filename, bbox_inches="tight") | |
| else: | |
| plt.show() | |