# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import List, Optional, Tuple
import numpy as np
class HistoryBuffer:
    """
    Track a series of scalar values and provide access to smoothed values over a
    window or the global average of the series.

    Only the most recent ``max_length`` values are stored, but the global
    average accounts for every value ever passed to :meth:`update`.
    """

    def __init__(self, max_length: int = 1000000) -> None:
        """
        Args:
            max_length: maximal number of values that can be stored in the
                buffer. When the capacity of the buffer is exhausted, the
                oldest values are removed. A non-positive value keeps no
                history at all; :meth:`global_avg` is still updated.
        """
        self._max_length: int = max_length
        # (value, iteration) pairs, oldest first.
        self._data: List[Tuple[float, float]] = []
        # Total number of update() calls, including values already evicted.
        self._count: int = 0
        # Running mean over all `_count` values seen so far.
        self._global_avg: float = 0

    def update(self, value: float, iteration: Optional[float] = None) -> None:
        """
        Add a new scalar value produced at a certain iteration. If the buffer
        then holds more than ``max_length`` entries, the oldest ones are
        removed.

        Args:
            value: the scalar to record.
            iteration: the iteration the value belongs to. Defaults to the
                number of values added before this call.
        """
        if iteration is None:
            iteration = self._count
        self._data.append((value, iteration))
        # Trim after appending, and by the whole excess rather than one
        # element, so that max_length <= 0 works and a list that has grown
        # past max_length (e.g. through the list returned by values()) is
        # brought back within bounds.
        # Deleting from the front of a list is O(len); a list is kept because
        # values() is documented to return one.
        excess = len(self._data) - self._max_length
        if excess > 0:
            del self._data[:excess]

        self._count += 1
        # Incremental mean: no need to keep every value ever seen.
        self._global_avg += (value - self._global_avg) / self._count

    def _window(self, window_size: int) -> List[float]:
        """Return the values of the latest ``window_size`` entries, oldest first."""
        # Guard explicitly: the slice [-0:] would select the whole buffer, and
        # a negative size would drop entries from the front instead.
        if window_size <= 0:
            return []
        return [x[0] for x in self._data[-window_size:]]

    def latest(self) -> float:
        """
        Return the latest scalar value added to the buffer.

        Raises:
            IndexError: if the buffer is empty.
        """
        return self._data[-1][0]

    def median(self, window_size: int) -> float:
        """
        Return the median of the latest ``window_size`` values in the buffer.

        If fewer values are stored, all of them are used. An empty window
        (empty buffer or ``window_size <= 0``) yields NaN, as ``np.median``
        does for empty input.
        """
        return np.median(self._window(window_size))

    def avg(self, window_size: int) -> float:
        """
        Return the mean of the latest ``window_size`` values in the buffer.

        If fewer values are stored, all of them are used. An empty window
        (empty buffer or ``window_size <= 0``) yields NaN, as ``np.mean``
        does for empty input.
        """
        return np.mean(self._window(window_size))

    def global_avg(self) -> float:
        """
        Return the mean of every value ever passed to :meth:`update`,
        including values already evicted due to the limited buffer storage.
        Returns 0 before the first update.
        """
        return self._global_avg

    def values(self) -> List[Tuple[float, float]]:
        """
        Returns:
            list[(number, iteration)]: content of the current buffer, oldest
            first. This is the internal list itself, not a copy, so mutating
            it changes the buffer.
        """
        return self._data