110 lines
4.3 KiB
Python
110 lines
4.3 KiB
Python
# Copyright 2016 Google Inc. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Tests for pipelines_common."""
|
|
|
|
# internal imports
|
|
import tensorflow as tf
|
|
|
|
from magenta.common import testing_lib as common_testing_lib
|
|
from magenta.music import constants
|
|
from magenta.music import melodies_lib
|
|
from magenta.music import sequences_lib
|
|
from magenta.music import testing_lib
|
|
from magenta.pipelines import pipelines_common
|
|
from magenta.protobuf import music_pb2
|
|
|
|
|
|
# Short module-level aliases for the special melody event values, used below
# when spelling out expected melody event lists.
NO_EVENT = constants.MELODY_NO_EVENT
NOTE_OFF = constants.MELODY_NOTE_OFF
|
|
|
|
|
|
class PipelineUnitsCommonTest(tf.test.TestCase):
  """Tests for the pipeline units defined in `pipelines_common`."""

  def _unit_transform_test(self, unit, input_instance, expected_outputs):
    """Runs `unit.transform` on one input and checks its outputs and types.

    Args:
      unit: The pipeline unit under test. Its `transform` method is called
          exactly once with `input_instance`.
      input_instance: The single input passed to `unit.transform`. Its type
          must equal `unit.input_type`.
      expected_outputs: The outputs `unit.transform` should produce. They are
          compared against the actual outputs with
          `common_testing_lib.assert_set_equality`.
    """
    outputs = unit.transform(input_instance)
    self.assertIsInstance(outputs, list)
    common_testing_lib.assert_set_equality(self, expected_outputs, outputs)
    self.assertEqual(unit.input_type, type(input_instance))
    # Check every output against the declared output type, not just the
    # first one, so a unit emitting mixed types is caught.
    for output in outputs:
      self.assertEqual(unit.output_type, type(output))

  def testQuantizer(self):
    """Quantizer converts a NoteSequence into the expected QuantizedSequence."""
    steps_per_quarter = 4
    note_sequence = common_testing_lib.parse_test_proto(
        music_pb2.NoteSequence,
        """
        time_signatures: {
          numerator: 4
          denominator: 4}
        tempos: {
          qpm: 60}""")
    # Notes are (pitch, velocity, start_seconds, end_seconds).
    testing_lib.add_track(
        note_sequence, 0,
        [(12, 100, 0.01, 10.0), (11, 55, 0.22, 0.50), (40, 45, 2.50, 3.50),
         (55, 120, 4.0, 4.01), (52, 99, 4.75, 5.0)])

    expected_quantized_sequence = sequences_lib.QuantizedSequence()
    expected_quantized_sequence.qpm = 60.0
    expected_quantized_sequence.steps_per_quarter = steps_per_quarter
    # Expected notes are (pitch, velocity, start_step, end_step). At qpm=60
    # with 4 steps per quarter, one step corresponds to 0.25 seconds. Note the
    # very short (4.0s, 4.01s) note is still expected to span one full step.
    testing_lib.add_quantized_track(
        expected_quantized_sequence, 0,
        [(12, 100, 0, 40), (11, 55, 1, 2), (40, 45, 10, 14),
         (55, 120, 16, 17), (52, 99, 19, 20)])

    unit = pipelines_common.Quantizer(steps_per_quarter)
    self._unit_transform_test(unit, note_sequence,
                              [expected_quantized_sequence])

  def testMonophonicMelodyExtractor(self):
    """MonophonicMelodyExtractor yields one melody per track in this input."""
    quantized_sequence = sequences_lib.QuantizedSequence()
    quantized_sequence.steps_per_quarter = 1
    # Notes are (pitch, velocity, start_step, end_step).
    testing_lib.add_quantized_track(
        quantized_sequence, 0,
        [(12, 100, 2, 4), (11, 1, 6, 7)])
    testing_lib.add_quantized_track(
        quantized_sequence, 1,
        [(12, 127, 2, 4), (14, 50, 6, 8)])
    expected_events = [
        [NO_EVENT, NO_EVENT, 12, NO_EVENT, NOTE_OFF, NO_EVENT, 11],
        [NO_EVENT, NO_EVENT, 12, NO_EVENT, NOTE_OFF, NO_EVENT, 14, NO_EVENT]]
    expected_melodies = []
    for events_list in expected_events:
      melody = melodies_lib.MonophonicMelody()
      melody.from_event_list(events_list, steps_per_quarter=1, steps_per_bar=4)
      expected_melodies.append(melody)
    unit = pipelines_common.MonophonicMelodyExtractor(
        min_bars=1, min_unique_pitches=1, gap_bars=1)
    self._unit_transform_test(unit, quantized_sequence, expected_melodies)

  def testRandomPartition(self):
    """RandomPartition routes each input to the partition its draw selects."""
    random_partition = pipelines_common.RandomPartition(
        str, ['a', 'b', 'c'], [0.1, 0.4])
    random_nums = [0.55, 0.05, 0.34, 0.99]
    # choices[i] is the partition expected for the i-th number in random_nums.
    choices = ['c', 'a', 'b', 'c']
    # Replace the random source with a fixed sequence so the test is
    # deterministic. Use the builtin `next()`: the `iterator.next` method
    # only exists on Python 2 and raises AttributeError on Python 3.
    random_nums_iter = iter(random_nums)
    random_partition.rand_func = lambda: next(random_nums_iter)

    self.assertEqual(random_partition.input_type, str)
    self.assertEqual(random_partition.output_type,
                     {'a': str, 'b': str, 'c': str})

    inputs = ['hello', 'qwerty', '1234567890', 'zxcvbnm']
    for s, expected_choice in zip(inputs, choices):
      results = random_partition.transform(s)
      self.assertIsInstance(results, dict)
      self.assertEqual(set(results.keys()), {'a', 'b', 'c'})
      self.assertEqual(len(results.values()), 3)
      # Exactly one partition receives the input; the other two are empty.
      self.assertEqual(list(results.values()).count([]), 2)
      self.assertEqual(results[expected_choice], [s])
|
|
|
|
|
|
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test runner.
  tf.test.main()
|