# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests to ensure correct reading and writing of NoteSequence record files."""

import tempfile

# internal imports
import tensorflow as tf

from magenta.music import note_sequence_io
from magenta.protobuf import music_pb2


class NoteSequenceIoTest(tf.test.TestCase):
  """Tests for ID generation and NoteSequence record file I/O."""

  def _assertPrefixAndGetSuffix(self, prefix, sequence_id):
    """Asserts `sequence_id` starts with `prefix` and returns the remainder.

    Deriving the slice offset from `len(prefix)` avoids hard-coded magic
    indices that silently go stale if a prefix string is edited.

    Args:
      prefix: The expected leading portion of `sequence_id`.
      sequence_id: An ID string returned by `note_sequence_io.generate_id`.

    Returns:
      The part of `sequence_id` that follows `prefix`.
    """
    self.assertEqual(prefix, sequence_id[:len(prefix)])
    return sequence_id[len(prefix):]

  def testGenerateId(self):
    """Checks `generate_id` prefixes and how the suffix depends on inputs.

    Each ID is expected to start with '/id/<source_type>/<collection>/'. The
    remainder must be identical for the same file name even when collection
    and source type differ, and must differ for a different file name.
    """
    suffix_1 = self._assertPrefixAndGetSuffix(
        '/id/midi/my_collection/',
        note_sequence_io.generate_id('/my/file/name', 'my_collection', 'midi'))
    suffix_2 = self._assertPrefixAndGetSuffix(
        '/id/abc/your_collection/',
        note_sequence_io.generate_id('/my/file/name', 'your_collection', 'abc'))
    # Same file name, different collection and source type: same suffix.
    self.assertEqual(suffix_1, suffix_2)

    # The prefix is asserted here too, so the suffix comparison below is
    # made against a correctly aligned slice.
    suffix_3 = self._assertPrefixAndGetSuffix(
        '/id/abc/my_collection/',
        note_sequence_io.generate_id('/your/file/name', 'my_collection', 'abc'))
    # Different file name: suffix must differ from both earlier IDs.
    self.assertNotEqual(suffix_3, suffix_1)
    self.assertNotEqual(suffix_3, suffix_2)

  def testNoteSequenceRecordWriterAndIterator(self):
    """Round-trips NoteSequences through a record file.

    The whole read-back list is compared with the written list. This also
    verifies that the iterator yields exactly as many sequences as were
    written, in the same order. Comparing only while enumerating the
    iterator would pass vacuously if it yielded nothing.
    """
    sequences = []
    for i in range(4):
      sequence = music_pb2.NoteSequence()
      sequence.id = str(i)
      sequence.notes.add().pitch = i
      sequences.append(sequence)

    # NOTE: re-opening a NamedTemporaryFile by name while it is still open
    # works on Unix but not on Windows (see the `tempfile` docs).
    with tempfile.NamedTemporaryFile(prefix='NoteSequenceIoTest') as temp_file:
      with note_sequence_io.NoteSequenceRecordWriter(temp_file.name) as writer:
        for sequence in sequences:
          writer.write(sequence)

      # Materialize while the temp file still exists; it is deleted when the
      # enclosing `with` exits.
      read_sequences = list(
          note_sequence_io.note_sequence_record_iterator(temp_file.name))

    self.assertEqual(sequences, read_sequences)

if __name__ == '__main__':
  # Run the test cases defined in this module via TensorFlow's test runner.
  tf.test.main()