# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# pyre-ignore-all-errors[2,3,58]
import logging
import os
from collections import defaultdict
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from iopath.common.file_io import HTTPURLHandler, PathManager
from termcolor import colored
from torch.nn.parallel import DataParallel, DistributedDataParallel
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 11):
from torch.ao import quantization
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase
else:
    # Quantization observer classes are unavailable on this torch version.
    # Bind the names so later hasattr() checks in _load_model cannot raise NameError.
    quantization = None
    FakeQuantizeBase = ObserverBase = None


__all__ = ["Checkpointer", "PeriodicCheckpointer"]


class _IncompatibleKeys(
NamedTuple(
"IncompatibleKeys",
[
("missing_keys", List[str]),
("unexpected_keys", List[str]),
("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
],
)
):
pass


class Checkpointer:
"""
    A checkpointer that can save/load a model as well as extra checkpointable
    objects.
"""

    def __init__(
self,
model: nn.Module,
save_dir: str = "",
*,
save_to_disk: bool = True,
**checkpointables: Any,
) -> None:
"""
Args:
model (nn.Module): model.
save_dir (str): a directory to save and find checkpoints.
save_to_disk (bool): if True, save checkpoint to disk, otherwise
disable saving for this checkpointer.
checkpointables (object): any checkpointable objects, i.e., objects
that have the ``state_dict()`` and ``load_state_dict()`` method. For
example, it can be used like
`Checkpointer(model, "dir", optimizer=optimizer)`.
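
        Example (a minimal sketch; ``MyModel`` is a placeholder for any
        ``nn.Module``, not part of this module)::

            model = MyModel()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            checkpointer = Checkpointer(model, "output", optimizer=optimizer)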
"""
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = model.module
self.model = model
self.checkpointables: Dict[str, Any] = {}
for k, v in checkpointables.items():
self.add_checkpointable(k, v)
self.logger: logging.Logger = logging.getLogger(__name__)
self.save_dir = save_dir
self.save_to_disk = save_to_disk
# Default PathManager, support HTTP URLs (for backward compatibility in open source).
# A user may want to use a different project-specific PathManager
self.path_manager: PathManager = PathManager()
self.path_manager.register_handler(HTTPURLHandler())

    def add_checkpointable(self, key: str, checkpointable: Any) -> None:
"""
Add checkpointable object for this checkpointer to track.
Args:
key (str): the key used to save the object
checkpointable: any object with ``state_dict()`` and
``load_state_dict()`` method
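
        Example (a sketch; ``scheduler`` stands for any object with the two
        methods above)::

            checkpointer.add_checkpointable("scheduler", scheduler)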
"""
if key in self.checkpointables:
raise KeyError(f"Key {key} already used in the Checkpointer")
if not hasattr(checkpointable, "state_dict"):
raise TypeError(
"add_checkpointable needs an object with 'state_dict()' method."
)
self.checkpointables[key] = checkpointable

    def save(self, name: str, **kwargs: Any) -> None:
"""
Dump model and checkpointables to a file.
Args:
name (str): name of the file.
kwargs (dict): extra arbitrary data to save.
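
        Example (a sketch; the extra keyword argument is stored in the
        checkpoint dict and later returned by :meth:`load`)::

            checkpointer.save("model_0000099", iteration=99)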
"""
if not self.save_dir or not self.save_to_disk:
return
data = {}
data["model"] = self.model.state_dict()
for key, obj in self.checkpointables.items():
data[key] = obj.state_dict()
data.update(kwargs)
basename = "{}.pth".format(name)
save_file = os.path.join(self.save_dir, basename)
assert os.path.basename(save_file) == basename, basename
self.logger.info("Saving checkpoint to {}".format(save_file))
with self.path_manager.open(save_file, "wb") as f:
# pyre-fixme[6]: For 2nd param expected `Union[PathLike[typing.Any],
# IO[bytes], str, BinaryIO]` but got `Union[IO[bytes], IO[str]]`.
torch.save(data, f)
self.tag_last_checkpoint(basename)

    def load(
self, path: str, checkpointables: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
Load from the given checkpoint.
Args:
path (str): path or url to the checkpoint. If empty, will not load
anything.
checkpointables (list): List of checkpointable names to load. If not
specified (None), will load all the possible checkpointables.
Returns:
dict:
extra data loaded from the checkpoint that has not been
processed. For example, those saved with
:meth:`.save(**extra_data)`.
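
        Example (a sketch; loads the model weights plus only the "optimizer"
        checkpointable, and reads back the extra data saved above)::

            extra = checkpointer.load(
                "output/model_0000099.pth", checkpointables=["optimizer"]
            )
            start_iter = extra.get("iteration", -1) + 1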
"""
if not path:
# no checkpoint provided
self.logger.info("No checkpoint found. Initializing model from scratch")
return {}
self.logger.info("[Checkpointer] Loading from {} ...".format(path))
if not os.path.isfile(path):
path = self.path_manager.get_local_path(path)
assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
checkpoint = self._load_file(path)
incompatible = self._load_model(checkpoint)
if (
incompatible is not None
        ):  # handle some existing subclasses that return None
self._log_incompatible_keys(incompatible)
for key in self.checkpointables if checkpointables is None else checkpointables:
if key in checkpoint:
self.logger.info("Loading {} from {} ...".format(key, path))
obj = self.checkpointables[key]
obj.load_state_dict(checkpoint.pop(key))
# return any further checkpoint data
return checkpoint

    def has_checkpoint(self) -> bool:
"""
Returns:
bool: whether a checkpoint exists in the target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
return self.path_manager.exists(save_file)

    def get_checkpoint_file(self) -> str:
"""
Returns:
            str: The latest checkpoint file in the target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
try:
with self.path_manager.open(save_file, "r") as f:
last_saved = f.read().strip()
except IOError:
# if file doesn't exist, maybe because it has just been
# deleted by a separate process
return ""
# pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got
# `Union[bytes, str]`.
return os.path.join(self.save_dir, last_saved)

    def get_all_checkpoint_files(self) -> List[str]:
"""
Returns:
            list: All available checkpoint files (.pth files) in the target
                directory.
"""
all_model_checkpoints = [
os.path.join(self.save_dir, file)
for file in self.path_manager.ls(self.save_dir)
if self.path_manager.isfile(os.path.join(self.save_dir, file))
and file.endswith(".pth")
]
return all_model_checkpoints

    def resume_or_load(self, path: str, *, resume: bool = True) -> Dict[str, Any]:
"""
        If `resume` is True, this method attempts to resume from the last
        checkpoint if it exists. Otherwise, it loads the checkpoint from the
        given path. This is useful when restarting an interrupted training job.
Args:
path (str): path to the checkpoint.
            resume (bool): if True, resume from the last checkpoint if it exists,
                and load the model together with all the checkpointables. Otherwise,
                only load the model, without loading any checkpointables.
Returns:
same as :meth:`load`.
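
        Example (a sketch of the common training-script pattern; ``args.resume``
        is a hypothetical command-line flag)::

            extra = checkpointer.resume_or_load("weights.pth", resume=args.resume)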
"""
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
else:
return self.load(path, checkpointables=[])

    def tag_last_checkpoint(self, last_filename_basename: str) -> None:
"""
Tag the last checkpoint.
Args:
last_filename_basename (str): the basename of the last filename.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
with self.path_manager.open(save_file, "w") as f:
f.write(last_filename_basename) # pyre-ignore

    def _load_file(self, f: str) -> Dict[str, Any]:
"""
        Load a checkpoint file. Can be overridden by subclasses to support
different formats.
Args:
f (str): a locally mounted file path.
Returns:
            dict: with keys "model" and optionally others that are saved by
                the checkpointer. dict["model"] must be a dict which maps strings
                to torch.Tensor or numpy arrays.
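
        Example (a subclass sketch for a hypothetical non-torch format;
        ``my_format_load`` is a placeholder loader, not a real function)::

            class MyCheckpointer(Checkpointer):
                def _load_file(self, f: str) -> Dict[str, Any]:
                    with open(f, "rb") as fp:
                        return {"model": my_format_load(fp)}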
"""
return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint: Any) -> _IncompatibleKeys:
"""
Load weights from a checkpoint.
Args:
            checkpoint (Any): checkpoint dict that contains the weights.
Returns:
``NamedTuple`` with ``missing_keys``, ``unexpected_keys``,
and ``incorrect_shapes`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
* **incorrect_shapes** is a list of (key, shape in checkpoint, shape in model)
This is just like the return value of
:func:`torch.nn.Module.load_state_dict`, but with extra support
for ``incorrect_shapes``.
"""
checkpoint_state_dict = checkpoint.pop("model")
self._convert_ndarray_to_tensor(checkpoint_state_dict)
# if the state_dict comes from a model that was wrapped in a
# DataParallel or DistributedDataParallel during serialization,
        # remove the "module." prefix before performing the matching.
_strip_prefix_if_present(checkpoint_state_dict, "module.")
# workaround https://github.com/pytorch/pytorch/issues/24139
model_state_dict = self.model.state_dict()
incorrect_shapes = []
for k in list(checkpoint_state_dict.keys()):
if k in model_state_dict:
model_param = model_state_dict[k]
# Allow mismatch for uninitialized parameters
if TORCH_VERSION >= (1, 8) and isinstance(
model_param, nn.parameter.UninitializedParameter
):
continue
shape_model = tuple(model_param.shape)
shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
if shape_model != shape_checkpoint:
                    has_observer_base_classes = (
                        quantization is not None
                        and hasattr(quantization, "ObserverBase")
                        and hasattr(quantization, "FakeQuantizeBase")
                    )
if has_observer_base_classes:
# Handle the special case of quantization per channel observers,
# where buffer shape mismatches are expected.
def _get_module_for_key(
model: torch.nn.Module, key: str
) -> torch.nn.Module:
# foo.bar.param_or_buffer_name -> [foo, bar]
key_parts = key.split(".")[:-1]
cur_module = model
for key_part in key_parts:
cur_module = getattr(cur_module, key_part)
return cur_module
cls_to_skip = (
ObserverBase,
FakeQuantizeBase,
)
target_module = _get_module_for_key(self.model, k)
if isinstance(target_module, cls_to_skip):
                            # Do not exclude modules with expected shape mismatches
                            # from the state_dict loading; they have special logic
                            # in _load_from_state_dict to handle the mismatches.
continue
incorrect_shapes.append((k, shape_checkpoint, shape_model))
checkpoint_state_dict.pop(k)
incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
return _IncompatibleKeys(
missing_keys=incompatible.missing_keys,
unexpected_keys=incompatible.unexpected_keys,
incorrect_shapes=incorrect_shapes,
)

    def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
"""
Log information about the incompatible keys returned by ``_load_model``.
"""
for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
self.logger.warning(
"Skip loading parameter '{}' to the model due to incompatible "
"shapes: {} in the checkpoint but {} in the "
"model! You might want to double check if this is expected.".format(
k, shape_checkpoint, shape_model
)
)
if incompatible.missing_keys:
missing_keys = _filter_reused_missing_keys(
self.model, incompatible.missing_keys
)
if missing_keys:
self.logger.warning(get_missing_parameters_message(missing_keys))
if incompatible.unexpected_keys:
self.logger.warning(
get_unexpected_parameters_message(incompatible.unexpected_keys)
)

    def _convert_ndarray_to_tensor(self, state_dict: Dict[str, Any]) -> None:
"""
        In-place convert all numpy arrays in the state_dict to torch tensors.
Args:
state_dict (dict): a state-dict to be loaded to the model.
Will be modified.
"""
# model could be an OrderedDict with _metadata attribute
        # (as returned by PyTorch's state_dict()). We should preserve these
# properties.
for k in list(state_dict.keys()):
v = state_dict[k]
if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
raise ValueError(
"Unsupported type found in checkpoint! {}: {}".format(k, type(v))
)
if not isinstance(v, torch.Tensor):
state_dict[k] = torch.from_numpy(v)


class PeriodicCheckpointer:
"""
    Save checkpoints periodically. When `.step(iteration)` is called, it will
    execute `checkpointer.save` on the given checkpointer, if (iteration + 1) is
    a multiple of period or if `max_iter` is reached.
Attributes:
checkpointer (Checkpointer): the underlying checkpointer object
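
    Example (a minimal training-loop sketch; ``train_one_iter`` is a
    placeholder for the actual training step)::

        periodic = PeriodicCheckpointer(checkpointer, period=1000, max_iter=90000)
        for it in range(90000):
            train_one_iter()
            periodic.step(it)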
"""

    def __init__(
self,
checkpointer: Checkpointer,
period: int,
max_iter: Optional[int] = None,
max_to_keep: Optional[int] = None,
file_prefix: str = "model",
) -> None:
"""
Args:
checkpointer: the checkpointer object used to save checkpoints.
            period (int): the number of iterations between two checkpoint saves.
max_iter (int): maximum number of iterations. When it is reached,
a checkpoint named "{file_prefix}_final" will be saved.
            max_to_keep (int): maximum number of the most recent checkpoints to keep;
                older checkpoints will be deleted.
file_prefix (str): the prefix of checkpoint's filename
"""
self.checkpointer = checkpointer
self.period = int(period)
self.max_iter = max_iter
if max_to_keep is not None:
assert max_to_keep > 0
self.max_to_keep = max_to_keep
self.recent_checkpoints: List[str] = []
self.path_manager: PathManager = checkpointer.path_manager
self.file_prefix = file_prefix

    def step(self, iteration: int, **kwargs: Any) -> None:
"""
Perform the appropriate action at the given iteration.
Args:
            iteration (int): the current iteration, in the range [0, max_iter-1].
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
iteration = int(iteration)
additional_state = {"iteration": iteration}
additional_state.update(kwargs)
if (iteration + 1) % self.period == 0:
self.checkpointer.save(
"{}_{:07d}".format(self.file_prefix, iteration), **additional_state
)
if self.max_to_keep is not None:
self.recent_checkpoints.append(self.checkpointer.get_checkpoint_file())
if len(self.recent_checkpoints) > self.max_to_keep:
file_to_delete = self.recent_checkpoints.pop(0)
if self.path_manager.exists(
file_to_delete
) and not file_to_delete.endswith(f"{self.file_prefix}_final.pth"):
self.path_manager.rm(file_to_delete)
if self.max_iter is not None:
if iteration >= self.max_iter - 1:
self.checkpointer.save(f"{self.file_prefix}_final", **additional_state)

    def save(self, name: str, **kwargs: Any) -> None:
"""
Same argument as :meth:`Checkpointer.save`.
Use this method to manually save checkpoints outside the schedule.
Args:
name (str): file name.
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
self.checkpointer.save(name, **kwargs)


def _filter_reused_missing_keys(model: nn.Module, keys: List[str]) -> List[str]:
"""
Filter "missing keys" to not include keys that have been loaded with another name.
"""
keyset = set(keys)
    param_to_names = defaultdict(set)  # param -> names that point to it
for module_prefix, module in _named_modules_with_dup(model):
for name, param in list(module.named_parameters(recurse=False)) + list(
module.named_buffers(recurse=False)
):
full_name = (module_prefix + "." if module_prefix else "") + name
param_to_names[param].add(full_name)
for names in param_to_names.values():
# if one name appears missing but its alias exists, then this
# name is not considered missing
if any(n in keyset for n in names) and not all(n in keyset for n in names):
            for n in names:
                keyset.discard(n)
return list(keyset)


def get_missing_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the model but not found in a checkpoint.
Args:
keys (list[str]): List of keys that were not found in the checkpoint.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg_per_group = sorted(k + _group_to_str(v) for k, v in groups.items())
msg = "Some model parameters or buffers are not found in the checkpoint:\n"
msg += "\n".join([colored(x, "blue") for x in msg_per_group])
return msg


def get_unexpected_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the checkpoint but not found in the model.
Args:
keys (list[str]): List of keys that were not found in the model.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
)
return msg


def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
"""
    Strip the prefix from all state_dict keys (and its metadata), if every key
    carries the prefix.
Args:
state_dict (OrderedDict): a state-dict to be loaded to the model.
prefix (str): prefix.
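
    Example (a sketch; ``w`` and ``b`` are placeholder tensors)::

        sd = {"module.fc.weight": w, "module.fc.bias": b}
        _strip_prefix_if_present(sd, "module.")
        # sd is now {"fc.weight": w, "fc.bias": b}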
"""
keys = sorted(state_dict.keys())
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
return
for key in keys:
newkey = key[len(prefix) :]
state_dict[newkey] = state_dict.pop(key)
    # also strip the prefix in metadata, if any.
try:
metadata = state_dict._metadata # pyre-ignore
except AttributeError:
pass
else:
for key in list(metadata.keys()):
# for the metadata dict, the key can be:
# '': for the DDP module, which we want to remove.
# 'module': for the actual model.
# 'module.xx.xx': for the rest.
if len(key) == 0:
continue
newkey = key[len(prefix) :]
metadata[newkey] = metadata.pop(key)


def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
"""
Group keys based on common prefixes. A prefix is the string up to the final
"." in each key.
Args:
keys (list[str]): list of parameter names, i.e. keys in the model
checkpoint dict.
Returns:
dict[list]: keys with common prefixes are grouped into lists.
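
    Example (a sketch of the grouping)::

        _group_checkpoint_keys(["backbone.conv.weight", "backbone.conv.bias", "fc"])
        # -> {"backbone.conv": ["weight", "bias"], "fc": []}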
"""
groups = defaultdict(list)
for key in keys:
pos = key.rfind(".")
if pos >= 0:
head, tail = key[:pos], [key[pos + 1 :]]
else:
head, tail = key, []
groups[head].extend(tail)
return groups


def _group_to_str(group: List[str]) -> str:
"""
Format a group of parameter name suffixes into a loggable string.
Args:
group (list[str]): list of parameter name suffixes.
Returns:
        str: formatted string.
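
    Example (a sketch)::

        _group_to_str([])                  # -> ""
        _group_to_str(["weight"])          # -> ".weight"
        _group_to_str(["weight", "bias"])  # -> ".{bias, weight}"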
"""
if len(group) == 0:
return ""
if len(group) == 1:
return "." + group[0]
return ".{" + ", ".join(sorted(group)) + "}"


def _named_modules_with_dup(
model: nn.Module, prefix: str = ""
) -> Iterable[Tuple[str, nn.Module]]:
"""
The same as `model.named_modules()`, except that it includes
duplicated modules that have more than one name.
"""
yield prefix, model
for name, module in model._modules.items():
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
yield from _named_modules_with_dup(module, submodule_prefix)