# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
"""Tests for consistency between PrettyMusic21 and NoteSequence proto."""
|
||
|
|
||
|
import os
|
||
|
|
||
|
# internal imports
|
||
|
import music21
|
||
|
import tensorflow as tf
|
||
|
|
||
|
from magenta.music import pretty_music21
|
||
|
from magenta.music.music21_to_note_sequence_io import _MUSIC21_TO_NOTE_SEQUENCE_MODE
|
||
|
from magenta.music.music21_to_note_sequence_io import _PRETTY_MUSIC21_TO_NOTE_SEQUENCE_KEY_NAME
|
||
|
from magenta.music.music21_to_note_sequence_io import music21_to_sequence_proto
|
||
|
from magenta.music.music21_to_note_sequence_io import pretty_music21_to_sequence_proto
|
||
|
from magenta.protobuf import music_pb2
|
||
|
|
||
|
|
||
|
class Music21ScoretoNoteSequenceTest(tf.test.TestCase):
  """Checks NoteSequence protos built from MusicXML against PrettyMusic21."""

  def setUp(self):
    """Get the file path to the test MusicXML file."""
    fname = 'bach-one_phrase-4_voices.xml'
    self.source_fpath = os.path.join(tf.resource_loader.get_data_files_path(),
                                     'testdata', fname)

  def _ParseTestScore(self):
    """Parses the test MusicXML file at `self.source_fpath`.

    Returns:
      A (music21_score, simple_score) tuple: the score returned by music21's
      MusicXMLImporter, and a pretty_music21.PrettyMusic21 built from it with
      the file's basename as its filename argument.
    """
    parser = music21.musicxml.xmlToM21.MusicXMLImporter()
    music21_score = parser.scoreFromFile(self.source_fpath)
    simple_score = pretty_music21.PrettyMusic21(
        music21_score, os.path.basename(self.source_fpath))
    return music21_score, simple_score

  def testMusic21ToSequenceFromMusicXML(self):
    """Test music21_to_sequence_proto output against PrettyMusic21."""
    music21_score, simple_score = self._ParseTestScore()
    sequence_proto = music21_to_sequence_proto(
        music21_score, os.path.basename(self.source_fpath))
    self.CompareNoteSequenceAndMusic21Score(sequence_proto, simple_score)

  def testPrettyMusic21ToSequenceFromMusicXML(self):
    """Test pretty_music21_to_sequence_proto output against PrettyMusic21."""
    _, simple_score = self._ParseTestScore()
    sequence_proto = pretty_music21_to_sequence_proto(
        simple_score, os.path.basename(self.source_fpath))
    self.CompareNoteSequenceAndMusic21Score(sequence_proto, simple_score)

  def testPrettyMusic21ToSequenceFromMusicXMLWithSourceFnamePassedToFormer(
      self):
    """Test the proto filename matches PrettyMusic21's when none is passed."""
    _, simple_score = self._ParseTestScore()
    sequence_proto = pretty_music21_to_sequence_proto(simple_score)
    self.assertEqual(sequence_proto.filename, simple_score.filename)

  def CompareNoteSequenceAndMusic21Score(self, sequence_proto, score):
    """Compares a NoteSequence proto to a PrettyMusic21 object.

    Args:
      sequence_proto: A music_pb2.NoteSequence proto.
      score: A pretty_music21.PrettyMusic21 object.
    """
    # Test score info.
    self.assertEqual(sequence_proto.source_info.parser,
                     music_pb2.NoteSequence.SourceInfo.MUSIC21)
    self.assertEqual(sequence_proto.filename, score.filename)

    # Test time signature changes.
    self.assertEqual(
        len(score.time_signature_changes), len(sequence_proto.time_signatures))
    for score_time, sequence_time in zip(score.time_signature_changes,
                                         sequence_proto.time_signatures):
      self.assertEqual(score_time.numerator, sequence_time.numerator)
      self.assertEqual(score_time.denominator, sequence_time.denominator)
      self.assertAlmostEqual(score_time.time, sequence_time.time)

    # Test key signature changes. The expected proto values are found with the
    # forward name -> enum mappings. A reverse search of dict values would fail
    # on Python 3 (dict views have no .index() and cannot be indexed). It would
    # also depend on dict order if two names ever mapped to the same value.
    self.assertEqual(
        len(score.key_signature_changes), len(sequence_proto.key_signatures))
    for score_key, sequence_key in zip(score.key_signature_changes,
                                       sequence_proto.key_signatures):
      key_name = score_key.key.upper()
      self.assertIn(key_name, _PRETTY_MUSIC21_TO_NOTE_SEQUENCE_KEY_NAME)
      self.assertEqual(_PRETTY_MUSIC21_TO_NOTE_SEQUENCE_KEY_NAME[key_name],
                       sequence_key.key)
      self.assertIn(score_key.mode, _MUSIC21_TO_NOTE_SEQUENCE_MODE)
      self.assertEqual(_MUSIC21_TO_NOTE_SEQUENCE_MODE[score_key.mode],
                       sequence_key.mode)
      self.assertAlmostEqual(score_key.time, sequence_key.time)

    # Test tempos.
    self.assertEqual(len(score.tempo_changes), len(sequence_proto.tempos))
    for score_tempo, sequence_tempo in zip(score.tempo_changes,
                                           sequence_proto.tempos):
      self.assertAlmostEqual(score_tempo.qpm, sequence_tempo.qpm)
      self.assertAlmostEqual(score_tempo.time, sequence_tempo.time)

    # Test part info.
    self.assertEqual(len(score.part_infos), len(sequence_proto.part_infos))
    for score_part_info, sequence_part_info in zip(
        score.part_infos, sequence_proto.part_infos):
      self.assertEqual(score_part_info.index, sequence_part_info.part)
      self.assertEqual(score_part_info.name, sequence_part_info.name)

    # Test parts and notes. Check the counts first. zip() stops at the shorter
    # input, so missing or extra notes would otherwise go unnoticed.
    score_notes = list(score.sorted_notes)
    self.assertEqual(len(score_notes), len(sequence_proto.notes))
    for score_note, sequence_note in zip(score_notes, sequence_proto.notes):
      self.assertAlmostEqual(score_note.pitch_midi, sequence_note.pitch)
      self.assertAlmostEqual(score_note.start_time, sequence_note.start_time)
      self.assertAlmostEqual(score_note.end_time, sequence_note.end_time)
      self.assertEqual(score_note.part_index, sequence_note.part)
||
|
|
||
|
|
||
|
if __name__ == '__main__':
  # When executed as a script, hand off to TensorFlow's unittest-based runner.
  tf.test.main()
|