284 lines
8.8 KiB
Python
284 lines
8.8 KiB
Python
# Copyright 2016 Google Inc. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Defines statistics objects for pipelines."""
|
|
|
|
import abc
|
|
import bisect
|
|
import copy
|
|
|
|
# internal imports
|
|
import tensorflow as tf
|
|
|
|
|
|
class MergeStatisticsException(Exception):
|
|
pass
|
|
|
|
|
|
class Statistic(object):
|
|
"""Holds statistics about a Pipeline run.
|
|
|
|
Pipelines produce statistics on each call to `transform`.
|
|
Statistic objects can be merged together to aggregate
|
|
statistics over the course of many calls to `transform`.
|
|
|
|
A `Statistic` also has a string name which is used during merging. Any two
|
|
`Statistic` instances with the same name may be merged together. The name
|
|
should also be informative about what the `Statistic` is measuring. Names
|
|
do not need to be unique globally (outside of the `Pipeline` objects that
|
|
produce them) because a `Pipeline` that returns statistics will prepend
|
|
its own name, effectively creating a namespace for each `Pipeline`.
|
|
"""
|
|
|
|
__metaclass__ = abc.ABCMeta
|
|
|
|
def __init__(self, name):
|
|
"""Constructs a `Statistic`.
|
|
|
|
Subclass constructors are expected to call this constructor.
|
|
|
|
Args:
|
|
name: The string name for this `Statistic`. Any two `Statistic` objects
|
|
with the same name will be merged together. The name should also
|
|
describe what this statistic is measuring.
|
|
"""
|
|
self._name = name
|
|
|
|
@abc.abstractmethod
|
|
def _merge_from(self, other):
|
|
"""Merge another Statistic into this instance.
|
|
|
|
Takes another Statistic of the same type, and merges its information into
|
|
this instance.
|
|
|
|
Args:
|
|
other: Another Statistic instance.
|
|
"""
|
|
pass
|
|
|
|
@abc.abstractmethod
|
|
def _pretty_print(self, name):
|
|
"""Return a string representation of this instance using the given name.
|
|
|
|
Returns a human readable and nicely presented representation of this
|
|
instance. Since this instance does not know what its measuring, a string
|
|
name is given to use in the string representation.
|
|
|
|
For example, if this Statistic held a count, say 5, and the given name was
|
|
'error_count', then the string representation might be 'error_count: 5'.
|
|
|
|
Args:
|
|
name: A string name for this instance.
|
|
|
|
Returns:
|
|
A human readable and preferably a nicely presented string representation
|
|
of this instance.
|
|
"""
|
|
pass
|
|
|
|
@abc.abstractmethod
|
|
def copy(self):
|
|
"""Returns a new copy of `self`."""
|
|
pass
|
|
|
|
def merge_from(self, other):
|
|
if not isinstance(other, Statistic):
|
|
raise MergeStatisticsException(
|
|
'Cannot merge with non-Statistic of type %s' % type(other))
|
|
if self.name != other.name:
|
|
raise MergeStatisticsException(
|
|
'Name "%s" does not match this name "%s"' % (other.name, self.name))
|
|
self._merge_from(other)
|
|
|
|
@property
|
|
def name(self):
|
|
"""String name of this statistic.
|
|
|
|
This name is used to uniquely identify a statistic.
|
|
|
|
Returns:
|
|
The string name of `self`.
|
|
"""
|
|
return self._name
|
|
|
|
@name.setter
|
|
def name(self, value):
|
|
assert isinstance(value, basestring) and value
|
|
self._name = value
|
|
|
|
def __str__(self):
|
|
return self._pretty_print(self._name)
|
|
|
|
|
|
def merge_statistics(stats_list):
|
|
"""Merge together statistics of the same name in the given list.
|
|
|
|
Any two statistics in the list with the same name will be merged into a
|
|
single statistic using the `merge_from` method.
|
|
|
|
Args:
|
|
stats_list: A list of `Statistic` objects.
|
|
|
|
Returns:
|
|
A list of merged statistics. Each name will appear only once.
|
|
"""
|
|
name_map = {}
|
|
for stat in stats_list:
|
|
if stat.name not in name_map:
|
|
name_map[stat.name] = stat
|
|
else:
|
|
name_map[stat.name].merge_from(stat)
|
|
return name_map.values()
|
|
|
|
|
|
def log_statistics_list(stats_list, logger_fn=tf.logging.info):
|
|
"""Calls the given logger function on each `Statistic` in the list.
|
|
|
|
Args:
|
|
stats_list: A list of `Statistic` objects.
|
|
logger_fn: The function which will be called on the string representation
|
|
of each `Statistic`.
|
|
"""
|
|
for stat in stats_list:
|
|
logger_fn(str(stat))
|
|
|
|
|
|
class Counter(Statistic):
|
|
"""Holds a count.
|
|
|
|
Use `Counter` to count occurrences or sum values together.
|
|
"""
|
|
|
|
def __init__(self, name, start_value=0):
|
|
"""Constructs a Counter.
|
|
|
|
Args:
|
|
name: String name of this counter.
|
|
start_value: What value to start the count at.
|
|
"""
|
|
super(Counter, self).__init__(name)
|
|
self.count = start_value
|
|
|
|
def increment(self, inc=1):
|
|
"""Increment the count.
|
|
|
|
Args:
|
|
inc: (defaults to 1) How much to increment the count by.
|
|
"""
|
|
self.count += inc
|
|
|
|
def _merge_from(self, other):
|
|
"""Adds the count of another Counter into this instance."""
|
|
if not isinstance(other, Counter):
|
|
raise MergeStatisticsException(
|
|
'Cannot merge %s into Counter' % other.__class__.__name__)
|
|
self.count += other.count
|
|
|
|
def _pretty_print(self, name):
|
|
return '%s: %d' % (name, self.count)
|
|
|
|
def copy(self):
|
|
return copy.copy(self)
|
|
|
|
|
|
class Histogram(Statistic):
|
|
"""Represents a histogram.
|
|
|
|
A histogram is a list of counts, each over a range of values.
|
|
For example, given this list of values [0.5, 0.0, 1.0, 0.6, 1.5, 2.4, 0.1],
|
|
a histogram over 3 ranges [0, 1), [1, 2), [2, 3) would be:
|
|
[0, 1): 4
|
|
[1, 2): 2
|
|
[2, 3): 1
|
|
Each range is inclusive in the lower bound and exclusive in the upper bound
|
|
(hence the square open bracket but curved close bracket).
|
|
"""
|
|
|
|
def __init__(self, name, buckets, verbose_pretty_print=False):
|
|
"""Initializes the histogram with the given ranges.
|
|
|
|
Args:
|
|
name: String name of this histogram.
|
|
buckets: The ranges the histogram counts over. This is a list of values,
|
|
where each value is the inclusive lower bound of the range. An extra
|
|
range will be implicitly defined which spans from negative infinity
|
|
to the lowest given lower bound. The highest given lower bound
|
|
defines a range spaning to positive infinity. This way any value will
|
|
be included in the histogram counts. For example, if `buckets` is
|
|
[4, 6, 10] the histogram will have ranges
|
|
[-inf, 4), [4, 6), [6, 10), [10, inf).
|
|
verbose_pretty_print: If True, self.pretty_print will print the count for
|
|
every bucket. If False, only buckets with positive counts will be
|
|
printed.
|
|
"""
|
|
super(Histogram, self).__init__(name)
|
|
|
|
# List of inclusive lowest values in each bucket.
|
|
self.buckets = [float('-inf')] + sorted(set(buckets))
|
|
self.counters = dict([(bucket_lower, 0)
|
|
for bucket_lower in self.buckets])
|
|
self.verbose_pretty_print = verbose_pretty_print
|
|
|
|
# https://docs.python.org/2/library/bisect.html#searching-sorted-lists
|
|
def _find_le(self, x):
|
|
"""Find rightmost bucket less than or equal to x."""
|
|
i = bisect.bisect_right(self.buckets, x)
|
|
if i:
|
|
return self.buckets[i-1]
|
|
raise ValueError
|
|
|
|
def increment(self, value, inc=1):
|
|
"""Increment the bucket containing the given value.
|
|
|
|
The bucket count for which ever range `value` falls in will be incremented.
|
|
|
|
Args:
|
|
value: Any number.
|
|
inc: An integer. How much to increment the bucket count by.
|
|
"""
|
|
bucket_lower = self._find_le(value)
|
|
self.counters[bucket_lower] += inc
|
|
|
|
def _merge_from(self, other):
|
|
"""Adds the counts of another Histogram into this instance.
|
|
|
|
`other` must have the same buckets as this instance. The counts
|
|
from `other` are added to the counts for this instance.
|
|
|
|
Args:
|
|
other: Another Histogram instance with the same buckets as this instance.
|
|
|
|
Raises:
|
|
MergeStatisticsException: If `other` is not a Histogram or the buckets
|
|
are not the same.
|
|
"""
|
|
if not isinstance(other, Histogram):
|
|
raise MergeStatisticsException(
|
|
'Cannot merge %s into Histogram' % other.__class__.__name__)
|
|
if self.buckets != other.buckets:
|
|
raise MergeStatisticsException(
|
|
'Histogram buckets do not match. Expected %s, got %s'
|
|
% (self.buckets, other.buckets))
|
|
for bucket_lower, count in other.counters.items():
|
|
self.counters[bucket_lower] += count
|
|
|
|
def _pretty_print(self, name):
|
|
b = self.buckets + [float('inf')]
|
|
return ('%s:\n' % name) + '\n'.join(
|
|
[' [%s,%s): %d' % (lower, b[i+1], self.counters[lower])
|
|
for i, lower in enumerate(self.buckets)
|
|
if self.verbose_pretty_print or self.counters[lower]])
|
|
|
|
def copy(self):
|
|
return copy.copy(self)
|