234 lines
9.9 KiB
Python
234 lines
9.9 KiB
Python
|
# Copyright 2016 Google Inc. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""Test to ensure correct midi input and output."""
|
||
|
|
||
|
from collections import defaultdict
|
||
|
import os.path
|
||
|
import tempfile
|
||
|
|
||
|
# internal imports
|
||
|
import midi as py_midi
|
||
|
import pretty_midi
|
||
|
import tensorflow as tf
|
||
|
|
||
|
from magenta.music import midi_io
|
||
|
|
||
|
# self.midi_simple_filename contains a C-major scale of 8 quarter notes each
|
||
|
# with a sustain of .95 of the entire note. Here are the first two notes dumped
|
||
|
# using mididump.py:
|
||
|
# midi.NoteOnEvent(tick=0, channel=0, data=[60, 100]),
|
||
|
# midi.NoteOnEvent(tick=209, channel=0, data=[60, 0]),
|
||
|
# midi.NoteOnEvent(tick=11, channel=0, data=[62, 100]),
|
||
|
# midi.NoteOnEvent(tick=209, channel=0, data=[62, 0]),
|
||
|
_SIMPLE_MIDI_FILE_VELO = 100
|
||
|
_SIMPLE_MIDI_FILE_NUM_NOTES = 8
|
||
|
_SIMPLE_MIDI_FILE_SUSTAIN = .95
|
||
|
|
||
|
# self.midi_complex_filename contains many instruments including percussion as
|
||
|
# well as control change and pitch bend events.
|
||
|
|
||
|
# self.midi_is_drum_filename contains 41 tracks, two of which are on channel 9.
|
||
|
|
||
|
# self.midi_event_order_filename contains notes ordered
|
||
|
# non-monotonically by pitch. Here are relevant events as printed by
|
||
|
# mididump.py:
|
||
|
# midi.NoteOnEvent(tick=0, channel=0, data=[1, 100]),
|
||
|
# midi.NoteOnEvent(tick=0, channel=0, data=[3, 100]),
|
||
|
# midi.NoteOnEvent(tick=0, channel=0, data=[2, 100]),
|
||
|
# midi.NoteOnEvent(tick=4400, channel=0, data=[3, 0]),
|
||
|
# midi.NoteOnEvent(tick=0, channel=0, data=[1, 0]),
|
||
|
# midi.NoteOnEvent(tick=0, channel=0, data=[2, 0]),
|
||
|
|
||
|
|
||
|
class MidiIoTest(tf.test.TestCase):
  """Tests conversions between MIDI data, PrettyMIDI and NoteSequence protos."""

  def setUp(self):
    # All fixtures live in ../testdata relative to this test's data path.
    data_files_path = tf.resource_loader.get_data_files_path()
    self.midi_simple_filename = os.path.join(
        data_files_path, '../testdata/example.mid')
    self.midi_complex_filename = os.path.join(
        data_files_path, '../testdata/example_complex.mid')
    self.midi_is_drum_filename = os.path.join(
        data_files_path, '../testdata/example_is_drum.mid')
    self.midi_event_order_filename = os.path.join(
        data_files_path, '../testdata/example_event_order.mid')

  def CheckPrettyMidiAndSequence(self, midi, sequence_proto):
    """Compares PrettyMIDI object against a sequence proto.

    Checks time signatures, key signatures and tempos, then, per instrument,
    notes, pitch bends and control changes.

    Args:
      midi: A pretty_midi.PrettyMIDI object.
      sequence_proto: A NoteSequence proto.
    """
    # Test time signature changes.
    self.assertEqual(len(midi.time_signature_changes),
                     len(sequence_proto.time_signatures))
    for midi_time, sequence_time in zip(midi.time_signature_changes,
                                        sequence_proto.time_signatures):
      self.assertEqual(midi_time.numerator, sequence_time.numerator)
      self.assertEqual(midi_time.denominator, sequence_time.denominator)
      self.assertAlmostEqual(midi_time.time, sequence_time.time)

    # Test key signature changes. The proto's key is key_number % 12 and its
    # mode is key_number // 12. Floor division is required: under Python 3,
    # `/` would produce a float (e.g. 0.5) that never equals the integer mode.
    self.assertEqual(len(midi.key_signature_changes),
                     len(sequence_proto.key_signatures))
    for midi_key, sequence_key in zip(midi.key_signature_changes,
                                      sequence_proto.key_signatures):
      self.assertEqual(midi_key.key_number % 12, sequence_key.key)
      self.assertEqual(midi_key.key_number // 12, sequence_key.mode)
      self.assertAlmostEqual(midi_key.time, sequence_key.time)

    # Test tempos.
    midi_times, midi_qpms = midi.get_tempo_changes()
    self.assertEqual(len(midi_times), len(sequence_proto.tempos))
    self.assertEqual(len(midi_qpms), len(sequence_proto.tempos))
    for midi_time, midi_qpm, sequence_tempo in zip(
        midi_times, midi_qpms, sequence_proto.tempos):
      self.assertAlmostEqual(midi_qpm, sequence_tempo.qpm)
      self.assertAlmostEqual(midi_time, sequence_tempo.time)

    # Test instruments. Group the sequence's events by the
    # (instrument, program, is_drum) triple so each group can be compared
    # with one PrettyMIDI instrument.
    seq_instruments = defaultdict(lambda: defaultdict(list))
    for seq_note in sequence_proto.notes:
      seq_instruments[
          (seq_note.instrument, seq_note.program, seq_note.is_drum)][
              'notes'].append(seq_note)
    for seq_bend in sequence_proto.pitch_bends:
      seq_instruments[
          (seq_bend.instrument, seq_bend.program, seq_bend.is_drum)][
              'bends'].append(seq_bend)
    for seq_control in sequence_proto.control_changes:
      seq_instruments[
          (seq_control.instrument, seq_control.program, seq_control.is_drum)][
              'controls'].append(seq_control)

    # Tuples sort lexicographically, so sorting the keys directly orders them
    # by instrument, then program, then is_drum. (A tuple-unpacking lambda key
    # is equivalent but is a SyntaxError in Python 3.)
    sorted_seq_instrument_keys = sorted(seq_instruments)

    self.assertEqual(len(midi.instruments), len(seq_instruments))
    for midi_instrument, seq_instrument_key in zip(
        midi.instruments, sorted_seq_instrument_keys):
      seq_instrument_events = seq_instruments[seq_instrument_key]

      seq_instrument_notes = seq_instrument_events['notes']
      self.assertEqual(len(midi_instrument.notes), len(seq_instrument_notes))
      for midi_note, sequence_note in zip(midi_instrument.notes,
                                          seq_instrument_notes):
        self.assertEqual(midi_note.pitch, sequence_note.pitch)
        self.assertEqual(midi_note.velocity, sequence_note.velocity)
        self.assertAlmostEqual(midi_note.start, sequence_note.start_time)
        self.assertAlmostEqual(midi_note.end, sequence_note.end_time)

      seq_instrument_pitch_bends = seq_instrument_events['bends']
      self.assertEqual(len(midi_instrument.pitch_bends),
                       len(seq_instrument_pitch_bends))
      for midi_pitch_bend, sequence_pitch_bend in zip(
          midi_instrument.pitch_bends, seq_instrument_pitch_bends):
        self.assertEqual(midi_pitch_bend.pitch, sequence_pitch_bend.bend)
        self.assertAlmostEqual(midi_pitch_bend.time, sequence_pitch_bend.time)

      # Control changes are grouped above too; compare them so that dropped
      # or corrupted control changes are caught.
      seq_instrument_control_changes = seq_instrument_events['controls']
      self.assertEqual(len(midi_instrument.control_changes),
                       len(seq_instrument_control_changes))
      for midi_control_change, sequence_control_change in zip(
          midi_instrument.control_changes, seq_instrument_control_changes):
        self.assertEqual(midi_control_change.number,
                         sequence_control_change.control_number)
        self.assertEqual(midi_control_change.value,
                         sequence_control_change.control_value)
        self.assertAlmostEqual(midi_control_change.time,
                               sequence_control_change.time)

  def CheckMidiToSequence(self, filename):
    """Test the translation from PrettyMIDI to Sequence proto."""
    source_midi = pretty_midi.PrettyMIDI(filename)
    sequence_proto = midi_io.midi_to_sequence_proto(source_midi)
    self.CheckPrettyMidiAndSequence(source_midi, sequence_proto)

  def CheckSequenceToPrettyMidi(self, filename):
    """Test the translation from Sequence proto to PrettyMIDI."""
    source_midi = pretty_midi.PrettyMIDI(filename)
    sequence_proto = midi_io.midi_to_sequence_proto(source_midi)
    translated_midi = midi_io.sequence_proto_to_pretty_midi(sequence_proto)
    self.CheckPrettyMidiAndSequence(translated_midi, sequence_proto)

  def CheckReadWriteMidi(self, filename):
    """Test writing to a MIDI file and comparing it to the original Sequence."""

    # TODO(deck): The input MIDI file is opened in pretty-midi and
    # re-written to a temp file, sanitizing the MIDI data (reordering
    # note ons, etc). Issue 85 in the pretty-midi GitHub
    # (http://github.com/craffel/pretty-midi/issues/85) requests that
    # this sanitization be available outside of the context of a file
    # write. If that is implemented, this rewrite code should be
    # modified or deleted.
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as rewrite_file:
      original_midi = pretty_midi.PrettyMIDI(filename)
      original_midi.write(rewrite_file.name)
      source_midi = pretty_midi.PrettyMIDI(rewrite_file.name)
      sequence_proto = midi_io.midi_to_sequence_proto(source_midi)

    # Translate the NoteSequence to MIDI and write to a file.
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as temp_file:
      midi_io.sequence_proto_to_midi_file(sequence_proto, temp_file.name)

      # Read it back in and compare to source.
      created_midi = pretty_midi.PrettyMIDI(temp_file.name)

    self.CheckPrettyMidiAndSequence(created_midi, sequence_proto)

  def testSimplePrettyMidiToSequence(self):
    self.CheckMidiToSequence(self.midi_simple_filename)

  def testSimpleSequenceToPrettyMidi(self):
    self.CheckSequenceToPrettyMidi(self.midi_simple_filename)

  def testSimpleReadWriteMidi(self):
    self.CheckReadWriteMidi(self.midi_simple_filename)

  def testComplexPrettyMidiToSequence(self):
    self.CheckMidiToSequence(self.midi_complex_filename)

  def testComplexSequenceToPrettyMidi(self):
    self.CheckSequenceToPrettyMidi(self.midi_complex_filename)

  def testIsDrumDetection(self):
    """Verify that is_drum instruments are properly tracked.

    self.midi_is_drum_filename is a MIDI file containing two tracks
    set to channel 9 (is_drum == True). Each contains one NoteOn. This
    test is designed to catch a bug where the second track would lose
    is_drum, remapping the drum track to an instrument track.
    """
    sequence_proto = midi_io.midi_file_to_sequence_proto(
        self.midi_is_drum_filename)
    with tempfile.NamedTemporaryFile(prefix='MidiDrumTest') as temp_file:
      midi_io.sequence_proto_to_midi_file(sequence_proto, temp_file.name)
      midi_data1 = py_midi.read_midifile(self.midi_is_drum_filename)
      midi_data2 = py_midi.read_midifile(temp_file.name)

    # Count number of channel 9 Note Ons (velocity 0 Note Ons are note offs).
    channel_counts = [0, 0]
    for index, midi_data in enumerate([midi_data1, midi_data2]):
      for track in midi_data:
        for event in track:
          if (event.name == 'Note On' and
              event.velocity > 0 and event.channel == 9):
            channel_counts[index] += 1
    self.assertEqual(channel_counts, [2, 2])

  def testComplexReadWriteMidi(self):
    self.CheckReadWriteMidi(self.midi_complex_filename)

  def testEventOrdering(self):
    self.CheckReadWriteMidi(self.midi_event_order_filename)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
  # Delegate discovery and execution of the tests above to tf.test.main().
  tf.test.main()
|