aiexperiments-ai-duet/server/magenta/music/midi_io_test.py
Yotam Mann ff837cec16 server
2016-11-11 13:53:51 -05:00

234 lines
9.9 KiB
Python

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test to ensure correct midi input and output."""
from collections import defaultdict
import os.path
import tempfile
# internal imports
import midi as py_midi
import pretty_midi
import tensorflow as tf
from magenta.music import midi_io
# self.midi_simple_filename contains a c-major scale of 8 quarter notes each
# with a sustain of .95 of the entire note. Here are the first two notes dumped
# using mididump.py:
# midi.NoteOnEvent(tick=0, channel=0, data=[60, 100]),
# midi.NoteOnEvent(tick=209, channel=0, data=[60, 0]),
# midi.NoteOnEvent(tick=11, channel=0, data=[62, 100]),
# midi.NoteOnEvent(tick=209, channel=0, data=[62, 0]),
# Note-on velocity of the notes in the simple file (see the dump above).
_SIMPLE_MIDI_FILE_VELO = 100
# Number of notes in the simple file's c-major scale.
_SIMPLE_MIDI_FILE_NUM_NOTES = 8
# Fraction of each note's full duration for which the note is sustained.
_SIMPLE_MIDI_FILE_SUSTAIN = .95
# self.midi_complex_filename contains many instruments including percussion as
# well as control change and pitch bend events.
# self.midi_is_drum_filename contains 41 tracks, two of which are on channel 9.
# self.midi_event_order_filename contains notes ordered
# non-monotonically by pitch. Here are relevant events as printed by
# mididump.py:
# midi.NoteOnEvent(tick=0, channel=0, data=[1, 100]),
# midi.NoteOnEvent(tick=0, channel=0, data=[3, 100]),
# midi.NoteOnEvent(tick=0, channel=0, data=[2, 100]),
# midi.NoteOnEvent(tick=4400, channel=0, data=[3, 0]),
# midi.NoteOnEvent(tick=0, channel=0, data=[1, 0]),
# midi.NoteOnEvent(tick=0, channel=0, data=[2, 0]),
class MidiIoTest(tf.test.TestCase):
  """Checks conversions between MIDI files, PrettyMIDI objects and protos."""

  def setUp(self):
    """Builds the paths of the MIDI fixture files shared by the tests."""
    data_dir = tf.resource_loader.get_data_files_path()
    self.midi_simple_filename = os.path.join(
        data_dir, '../testdata/example.mid')
    self.midi_complex_filename = os.path.join(
        data_dir, '../testdata/example_complex.mid')
    self.midi_is_drum_filename = os.path.join(
        data_dir, '../testdata/example_is_drum.mid')
    self.midi_event_order_filename = os.path.join(
        data_dir, '../testdata/example_event_order.mid')

  def CheckPrettyMidiAndSequence(self, midi, sequence_proto):
    """Compares PrettyMIDI object against a sequence proto.

    Time signatures, key signatures, tempos, and the notes and pitch bends of
    every instrument are compared. Any mismatch fails the calling test.

    Args:
      midi: A pretty_midi.PrettyMIDI object.
      sequence_proto: A tensorflow.magenta.Sequence proto.
    """
    # Test time signature changes.
    self.assertEqual(len(midi.time_signature_changes),
                     len(sequence_proto.time_signatures))
    for midi_time, sequence_time in zip(midi.time_signature_changes,
                                        sequence_proto.time_signatures):
      self.assertEqual(midi_time.numerator, sequence_time.numerator)
      self.assertEqual(midi_time.denominator, sequence_time.denominator)
      self.assertAlmostEqual(midi_time.time, sequence_time.time)

    # Test key signature changes.
    self.assertEqual(len(midi.key_signature_changes),
                     len(sequence_proto.key_signatures))
    for midi_key, sequence_key in zip(midi.key_signature_changes,
                                      sequence_proto.key_signatures):
      self.assertEqual(midi_key.key_number % 12, sequence_key.key)
      # Floor division keeps the result an integer on Python 3 as well; true
      # division would produce a float that never matches the proto's mode.
      self.assertEqual(midi_key.key_number // 12, sequence_key.mode)
      self.assertAlmostEqual(midi_key.time, sequence_key.time)

    # Test tempos.
    midi_times, midi_qpms = midi.get_tempo_changes()
    self.assertEqual(len(midi_times), len(sequence_proto.tempos))
    self.assertEqual(len(midi_qpms), len(sequence_proto.tempos))
    for midi_time, midi_qpm, sequence_tempo in zip(
        midi_times, midi_qpms, sequence_proto.tempos):
      self.assertAlmostEqual(midi_qpm, sequence_tempo.qpm)
      self.assertAlmostEqual(midi_time, sequence_tempo.time)

    # Test instruments. Sequence events are grouped by the
    # (instrument, program, is_drum) triple they carry.
    seq_instruments = defaultdict(lambda: defaultdict(list))
    for seq_note in sequence_proto.notes:
      seq_instruments[
          (seq_note.instrument, seq_note.program, seq_note.is_drum)][
              'notes'].append(seq_note)
    for seq_bend in sequence_proto.pitch_bends:
      seq_instruments[
          (seq_bend.instrument, seq_bend.program, seq_bend.is_drum)][
              'bends'].append(seq_bend)
    # NOTE: control changes are grouped here (which also makes them count
    # towards the number of instruments) but their values are not compared.
    for seq_control in sequence_proto.control_changes:
      seq_instruments[
          (seq_control.instrument, seq_control.program, seq_control.is_drum)][
              'controls'].append(seq_control)

    # The keys are plain tuples, so the default ordering sorts them by
    # instrument, then program, then is_drum.
    sorted_seq_instrument_keys = sorted(seq_instruments.keys())

    self.assertEqual(len(midi.instruments), len(seq_instruments))
    # Instruments are paired positionally: the i-th PrettyMIDI instrument is
    # compared with the i-th key in sorted order.
    for midi_instrument, seq_instrument_key in zip(
        midi.instruments, sorted_seq_instrument_keys):
      seq_instrument_notes = seq_instruments[seq_instrument_key]['notes']
      self.assertEqual(len(midi_instrument.notes), len(seq_instrument_notes))
      for midi_note, sequence_note in zip(midi_instrument.notes,
                                          seq_instrument_notes):
        self.assertEqual(midi_note.pitch, sequence_note.pitch)
        self.assertEqual(midi_note.velocity, sequence_note.velocity)
        self.assertAlmostEqual(midi_note.start, sequence_note.start_time)
        self.assertAlmostEqual(midi_note.end, sequence_note.end_time)

      seq_instrument_pitch_bends = seq_instruments[seq_instrument_key]['bends']
      self.assertEqual(len(midi_instrument.pitch_bends),
                       len(seq_instrument_pitch_bends))
      for midi_pitch_bend, sequence_pitch_bend in zip(
          midi_instrument.pitch_bends, seq_instrument_pitch_bends):
        self.assertEqual(midi_pitch_bend.pitch, sequence_pitch_bend.bend)
        self.assertAlmostEqual(midi_pitch_bend.time, sequence_pitch_bend.time)

  def CheckMidiToSequence(self, filename):
    """Test the translation from PrettyMIDI to Sequence proto.

    Args:
      filename: Path of the MIDI file to load and convert.
    """
    source_midi = pretty_midi.PrettyMIDI(filename)
    sequence_proto = midi_io.midi_to_sequence_proto(source_midi)
    self.CheckPrettyMidiAndSequence(source_midi, sequence_proto)

  def CheckSequenceToPrettyMidi(self, filename):
    """Test the translation from Sequence proto to PrettyMIDI.

    Args:
      filename: Path of the MIDI file used to build the source proto.
    """
    source_midi = pretty_midi.PrettyMIDI(filename)
    sequence_proto = midi_io.midi_to_sequence_proto(source_midi)
    translated_midi = midi_io.sequence_proto_to_pretty_midi(sequence_proto)
    self.CheckPrettyMidiAndSequence(translated_midi, sequence_proto)

  def CheckReadWriteMidi(self, filename):
    """Test writing to a MIDI file and comparing it to the original Sequence.

    Args:
      filename: Path of the MIDI file to round-trip through a temp file.
    """
    # TODO(deck): The input MIDI file is opened in pretty-midi and
    # re-written to a temp file, sanitizing the MIDI data (reordering
    # note ons, etc). Issue 85 in the pretty-midi GitHub
    # (http://github.com/craffel/pretty-midi/issues/85) requests that
    # this sanitization be available outside of the context of a file
    # write. If that is implemented, this rewrite code should be
    # modified or deleted.
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as rewrite_file:
      original_midi = pretty_midi.PrettyMIDI(filename)
      original_midi.write(rewrite_file.name)
      # The temp file is removed when the `with` block exits, so it has to be
      # read back while the block is still open.
      source_midi = pretty_midi.PrettyMIDI(rewrite_file.name)
      sequence_proto = midi_io.midi_to_sequence_proto(source_midi)

    # Translate the NoteSequence to MIDI and write to a file.
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as temp_file:
      midi_io.sequence_proto_to_midi_file(sequence_proto, temp_file.name)
      # Read it back in and compare to source.
      created_midi = pretty_midi.PrettyMIDI(temp_file.name)

    self.CheckPrettyMidiAndSequence(created_midi, sequence_proto)

  def testSimplePrettyMidiToSequence(self):
    self.CheckMidiToSequence(self.midi_simple_filename)

  def testSimpleSequenceToPrettyMidi(self):
    self.CheckSequenceToPrettyMidi(self.midi_simple_filename)

  def testSimpleReadWriteMidi(self):
    self.CheckReadWriteMidi(self.midi_simple_filename)

  def testComplexPrettyMidiToSequence(self):
    self.CheckMidiToSequence(self.midi_complex_filename)

  def testComplexSequenceToPrettyMidi(self):
    self.CheckSequenceToPrettyMidi(self.midi_complex_filename)

  def testIsDrumDetection(self):
    """Verify that is_drum instruments are properly tracked.

    self.midi_is_drum_filename is a MIDI file containing two tracks
    set to channel 9 (is_drum == True). Each contains one NoteOn. This
    test is designed to catch a bug where the second track would lose
    is_drum, remapping the drum track to an instrument track.
    """
    sequence_proto = midi_io.midi_file_to_sequence_proto(
        self.midi_is_drum_filename)
    with tempfile.NamedTemporaryFile(prefix='MidiDrumTest') as temp_file:
      midi_io.sequence_proto_to_midi_file(sequence_proto, temp_file.name)
      midi_data1 = py_midi.read_midifile(self.midi_is_drum_filename)
      # Must be read before the temp file is removed on leaving the block.
      midi_data2 = py_midi.read_midifile(temp_file.name)

    # Count number of channel 9 Note Ons, original file first, rewritten
    # file second. Velocity-0 Note Ons are skipped.
    channel_counts = [0, 0]
    for index, midi_data in enumerate([midi_data1, midi_data2]):
      for track in midi_data:
        for event in track:
          if (event.name == 'Note On' and
              event.velocity > 0 and event.channel == 9):
            channel_counts[index] += 1
    self.assertEqual(channel_counts, [2, 2])

  def testComplexReadWriteMidi(self):
    self.CheckReadWriteMidi(self.midi_complex_filename)

  def testEventOrdering(self):
    self.CheckReadWriteMidi(self.midi_event_order_filename)
# Run the test suite only when this module is executed as a script.
if __name__ == '__main__':
  tf.test.main()