# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||
|
"""Tests to ensure correct reading and writing of NoteSequence record files."""
|
||
|
|
||
|
import tempfile
|
||
|
|
||
|
# internal imports
|
||
|
import tensorflow as tf
|
||
|
|
||
|
from magenta.music import note_sequence_io
|
||
|
from magenta.protobuf import music_pb2
|
||
|
|
||
|
|
||
|
class NoteSequenceIoTest(tf.test.TestCase):
  """Tests for note_sequence_io ID generation and record reading/writing."""

  def _split_id(self, sequence_id, expected_prefix):
    """Asserts `sequence_id` begins with `expected_prefix`; returns the rest.

    Deriving the slice offset from the prefix itself avoids hard-coding
    prefix lengths that silently break when the prefix text changes.

    Args:
      sequence_id: The ID string returned by `note_sequence_io.generate_id`.
      expected_prefix: The exact prefix the ID is expected to start with.

    Returns:
      The portion of `sequence_id` that follows `expected_prefix`.
    """
    self.assertEqual(expected_prefix, sequence_id[:len(expected_prefix)])
    return sequence_id[len(expected_prefix):]

  def testGenerateId(self):
    """Checks the prefix layout and suffix behavior of generated IDs.

    The expected layout is '/id/<source_type>/<collection_name>/' followed
    by a suffix. The suffix must match for the same source file name and
    differ for different source file names.
    """
    sequence_id_1 = note_sequence_io.generate_id(
        '/my/file/name', 'my_collection', 'midi')
    suffix_1 = self._split_id(sequence_id_1, '/id/midi/my_collection/')

    sequence_id_2 = note_sequence_io.generate_id(
        '/my/file/name', 'your_collection', 'abc')
    suffix_2 = self._split_id(sequence_id_2, '/id/abc/your_collection/')

    # Same file name, different collection/type: suffix must be unchanged.
    self.assertEqual(suffix_1, suffix_2)

    sequence_id_3 = note_sequence_io.generate_id(
        '/your/file/name', 'my_collection', 'abc')
    suffix_3 = self._split_id(sequence_id_3, '/id/abc/my_collection/')

    # A different file name must yield a different suffix.
    self.assertNotEqual(suffix_3, suffix_1)
    self.assertNotEqual(suffix_3, suffix_2)

  def testNoteSequenceRecordWriterAndIterator(self):
    """Writes NoteSequences to a record file and reads them back in order."""
    sequences = []
    # `range` (not `xrange`) so the test runs on both Python 2 and 3.
    for i in range(4):
      sequence = music_pb2.NoteSequence()
      sequence.id = str(i)
      sequence.notes.add().pitch = i
      sequences.append(sequence)

    with tempfile.NamedTemporaryFile(prefix='NoteSequenceIoTest') as temp_file:
      # Close the writer before reading so all records are flushed to disk.
      with note_sequence_io.NoteSequenceRecordWriter(temp_file.name) as writer:
        for sequence in sequences:
          writer.write(sequence)

      # Read everything back while the temp file still exists. Comparing the
      # full lists catches missing records (an empty read would otherwise pass
      # vacuously) as well as unexpected extra records.
      read_sequences = list(
          note_sequence_io.note_sequence_record_iterator(temp_file.name))
      self.assertEqual(sequences, read_sequences)
|
||
|
|
||
|
if __name__ == '__main__':
  # Executed as a script: hand off to TensorFlow's test runner so the
  # TestCase defined above is run.
  tf.test.main()
|