# aiexperiments-ai-duet/server/third_party/magenta/pipelines/statistics.py
# 2016-11-11 15:34:34 -05:00
#
# 284 lines
# 8.8 KiB
# Python

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines statistics objects for pipelines."""
import abc
import bisect
import copy
# internal imports
import tensorflow as tf
class MergeStatisticsException(Exception):
  """Signals that two statistics cannot be merged together.

  `Statistic.merge_from` raises it when handed a non-`Statistic` or a
  statistic with a different name; the concrete statistic classes raise it
  when the other statistic is of an incompatible type or layout.
  """
# The base class is built by calling `abc.ABCMeta` directly (the same trick as
# `six.with_metaclass`) so the abstract methods below are enforced on both
# Python 2 and Python 3; a `__metaclass__` class attribute is silently ignored
# by Python 3.
class Statistic(abc.ABCMeta('_StatisticBase', (object,), {})):
  """Holds statistics about a Pipeline run.

  Pipelines produce statistics on each call to `transform`.
  `Statistic` objects can be merged together to aggregate
  statistics over the course of many calls to `transform`.

  A `Statistic` also has a string name which is used during merging. Any two
  `Statistic` instances with the same name may be merged together. The name
  should also be informative about what the `Statistic` is measuring. Names
  do not need to be unique globally (outside of the `Pipeline` objects that
  produce them) because a `Pipeline` that returns statistics will prepend
  its own name, effectively creating a namespace for each `Pipeline`.
  """

  def __init__(self, name):
    """Constructs a `Statistic`.

    Subclass constructors are expected to call this constructor.

    Args:
      name: The string name for this `Statistic`. Any two `Statistic` objects
          with the same name will be merged together. The name should also
          describe what this statistic is measuring.
    """
    self._name = name

  @abc.abstractmethod
  def _merge_from(self, other):
    """Merge another Statistic into this instance.

    Takes another Statistic of the same type, and merges its information into
    this instance. Called by `merge_from` after the name has been checked.

    Args:
      other: Another Statistic instance.
    """
    pass

  @abc.abstractmethod
  def _pretty_print(self, name):
    """Return a string representation of this instance using the given name.

    Returns a human readable and nicely presented representation of this
    instance. Since this instance does not know what it is measuring, a string
    name is given to use in the string representation.

    For example, if this Statistic held a count, say 5, and the given name was
    'error_count', then the string representation might be 'error_count: 5'.

    Args:
      name: A string name for this instance.

    Returns:
      A human readable and preferably a nicely presented string representation
      of this instance.
    """
    pass

  @abc.abstractmethod
  def copy(self):
    """Returns a new copy of `self`."""
    pass

  def merge_from(self, other):
    """Merges `other` into this instance.

    Validates `other`, then delegates to the subclass's `_merge_from`.

    Args:
      other: Another `Statistic` with the same name as this one.

    Raises:
      MergeStatisticsException: If `other` is not a `Statistic`, or its name
          differs from this instance's name.
    """
    if not isinstance(other, Statistic):
      raise MergeStatisticsException(
          'Cannot merge with non-Statistic of type %s' % type(other))
    if self.name != other.name:
      raise MergeStatisticsException(
          'Name "%s" does not match this name "%s"' % (other.name, self.name))
    self._merge_from(other)

  @property
  def name(self):
    """String name of this statistic.

    Statistics are matched by this name when they are merged.

    Returns:
      The string name of `self`.
    """
    return self._name

  @name.setter
  def name(self, value):
    # `type(u'')` is `unicode` on Python 2 and `str` on Python 3, so this
    # accepts text on both versions without `basestring`, which does not exist
    # on Python 3. Explicit raises (instead of `assert`) keep the check active
    # under `python -O`.
    if not isinstance(value, (str, type(u''))):
      raise TypeError('Statistic name must be a string, got %s' % type(value))
    if not value:
      raise ValueError('Statistic name must be a non-empty string')
    self._name = value

  def __str__(self):
    return self._pretty_print(self._name)
def merge_statistics(stats_list):
  """Merge together statistics of the same name in the given list.

  Any two statistics in the list with the same name will be merged into a
  single statistic using the `merge_from` method. The objects in `stats_list`
  are not modified: the first statistic seen for each name is copied (via its
  `copy` method) and later statistics with that name are merged into the copy.

  Args:
    stats_list: A list of `Statistic` objects.

  Returns:
    A list of merged statistics, ordered by the first appearance of each name
    in `stats_list`. Each name will appear only once.

  Raises:
    MergeStatisticsException: Propagated from `merge_from` when two statistics
        with the same name cannot be merged (e.g. they are of different types).
  """
  merged_by_name = {}
  name_order = []
  for stat in stats_list:
    if stat.name in merged_by_name:
      merged_by_name[stat.name].merge_from(stat)
    else:
      # Copy so merging later statistics never mutates the caller's objects.
      merged_by_name[stat.name] = stat.copy()
      name_order.append(stat.name)
  # An explicit list (not `dict.values()`, a view on Python 3) in a
  # deterministic order.
  return [merged_by_name[name] for name in name_order]
def log_statistics_list(stats_list, logger_fn=tf.logging.info):
  """Logs the string form of every `Statistic` in `stats_list`.

  Args:
    stats_list: A list of `Statistic` objects.
    logger_fn: A callable taking one string. It is called once per statistic,
        in list order, with `str(stat)`. Defaults to `tf.logging.info`.
  """
  for rendered in (str(entry) for entry in stats_list):
    logger_fn(rendered)
class Counter(Statistic):
  """Holds a count.

  Use `Counter` to count occurrences or sum values together.
  """

  def __init__(self, name, start_value=0):
    """Constructs a Counter.

    Args:
      name: String name of this counter.
      start_value: What value to start the count at.
    """
    super(Counter, self).__init__(name)
    self.count = start_value

  def increment(self, inc=1):
    """Increment the count.

    Args:
      inc: (defaults to 1) How much to increment the count by.
    """
    self.count += inc

  def _merge_from(self, other):
    """Adds the count of another Counter into this instance.

    Raises:
      MergeStatisticsException: If `other` is not a `Counter`.
    """
    if not isinstance(other, Counter):
      raise MergeStatisticsException(
          'Cannot merge %s into Counter' % other.__class__.__name__)
    self.count += other.count

  def _pretty_print(self, name):
    # '%s' rather than '%d': '%d' truncates floats, so a counter used to sum
    # non-integer values would print a wrong total. Integer counts render
    # identically either way.
    return '%s: %s' % (name, self.count)

  def copy(self):
    # A shallow copy suffices: the instance holds only a name string and a
    # count (presumably an immutable number).
    return copy.copy(self)
class Histogram(Statistic):
  """Represents a histogram.

  A histogram is a list of counts, each over a range of values.

  For example, given this list of values [0.5, 0.0, 1.0, 0.6, 1.5, 2.4, 0.1],
  a histogram over 3 ranges [0, 1), [1, 2), [2, 3) would be:

    [0, 1): 4
    [1, 2): 2
    [2, 3): 1

  Each range is inclusive in the lower bound and exclusive in the upper bound
  (hence the square open bracket but curved close bracket).
  """

  def __init__(self, name, buckets, verbose_pretty_print=False):
    """Initializes the histogram with the given ranges.

    Args:
      name: String name of this histogram.
      buckets: The ranges the histogram counts over. This is a list of values,
          where each value is the inclusive lower bound of the range. An extra
          range will be implicitly defined which spans from negative infinity
          to the lowest given lower bound. The highest given lower bound
          defines a range spanning to positive infinity. This way any value
          will be included in the histogram counts. For example, if `buckets`
          is [4, 6, 10] the histogram will have ranges
          [-inf, 4), [4, 6), [6, 10), [10, inf). Duplicate values, including
          an explicit -inf, are ignored.
      verbose_pretty_print: If True, the string representation lists the count
          for every bucket. If False, only buckets with non-zero counts are
          listed.
    """
    super(Histogram, self).__init__(name)
    neg_inf = float('-inf')
    # List of inclusive lowest values in each bucket. The implicit -inf bucket
    # is always first; a user-supplied -inf is dropped so it does not create a
    # duplicate bucket that would print the same counter twice.
    self.buckets = [neg_inf] + sorted(set(buckets) - {neg_inf})
    self.counters = {bucket_lower: 0 for bucket_lower in self.buckets}
    self.verbose_pretty_print = verbose_pretty_print

  # https://docs.python.org/2/library/bisect.html#searching-sorted-lists
  def _find_le(self, x):
    """Find rightmost bucket less than or equal to x.

    Args:
      x: The value to locate.

    Returns:
      The lower bound of the bucket containing `x`.

    Raises:
      ValueError: If no bucket lower bound is <= `x`. Since the first bucket is
          -inf, this should not happen for ordinary numbers.
    """
    i = bisect.bisect_right(self.buckets, x)
    if i:
      return self.buckets[i - 1]
    raise ValueError

  def increment(self, value, inc=1):
    """Increment the bucket containing the given value.

    The bucket count for which ever range `value` falls in will be incremented.

    Args:
      value: Any number.
      inc: An integer. How much to increment the bucket count by.
    """
    bucket_lower = self._find_le(value)
    self.counters[bucket_lower] += inc

  def _merge_from(self, other):
    """Adds the counts of another Histogram into this instance.

    `other` must have the same buckets as this instance. The counts
    from `other` are added to the counts for this instance.

    Args:
      other: Another Histogram instance with the same buckets as this instance.

    Raises:
      MergeStatisticsException: If `other` is not a Histogram or the buckets
          are not the same.
    """
    if not isinstance(other, Histogram):
      raise MergeStatisticsException(
          'Cannot merge %s into Histogram' % other.__class__.__name__)
    if self.buckets != other.buckets:
      raise MergeStatisticsException(
          'Histogram buckets do not match. Expected %s, got %s'
          % (self.buckets, other.buckets))
    for bucket_lower, count in other.counters.items():
      self.counters[bucket_lower] += count

  def _pretty_print(self, name):
    # Each bucket's exclusive upper bound is the next bucket's lower bound; the
    # last bucket extends to +inf.
    upper_bounds = self.buckets[1:] + [float('inf')]
    lines = [' [%s,%s): %d' % (lower, upper, self.counters[lower])
             for lower, upper in zip(self.buckets, upper_bounds)
             if self.verbose_pretty_print or self.counters[lower]]
    return ('%s:\n' % name) + '\n'.join(lines)

  def copy(self):
    # Deep copy: `buckets` and `counters` are mutable, and a shallow copy would
    # share them, so incrementing the copy would also change this instance.
    return copy.deepcopy(self)