aiexperiments-ai-duet/server/predict.py

76 lines
2.5 KiB
Python
Raw Normal View History

2016-11-11 18:53:51 +00:00
#
# Copyright 2016 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
2017-01-13 21:50:56 +00:00
import magenta
from magenta.models.melody_rnn import melody_rnn_config_flags
from magenta.models.melody_rnn import melody_rnn_model
from magenta.models.melody_rnn import melody_rnn_sequence_generator
from magenta.protobuf import generator_pb2
2017-01-13 21:50:56 +00:00
from magenta.protobuf import music_pb2
2016-11-11 18:53:51 +00:00
import os
2017-01-13 21:50:56 +00:00
import time
2016-11-11 18:53:51 +00:00
import tempfile
2017-01-13 21:50:56 +00:00
import pretty_midi
BUNDLE_NAME = 'attention_rnn'

# Quantization resolution handed to the sequence generator (steps per quarter
# note); also used by _steps_to_seconds.
steps_per_quarter = 4

# Model configuration for the bundle, plus the pre-trained bundle itself. The
# bundle path is resolved relative to the current working directory.
config = melody_rnn_model.default_configs[BUNDLE_NAME]
bundle_file = magenta.music.read_bundle_file(
    os.path.abspath(BUNDLE_NAME + '.mag'))

# Single module-level generator reused by every call to generate_midi().
generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
    model=melody_rnn_model.MelodyRnnModel(config),
    details=config.details,
    steps_per_quarter=steps_per_quarter,
    bundle=bundle_file,
)
def _steps_to_seconds(steps, qpm):
    """Convert a number of quantized steps to seconds at tempo ``qpm``.

    The step resolution is the module-level ``steps_per_quarter``.
    """
    # Same operation order as ((steps * 60.0) / qpm) / steps_per_quarter so the
    # floating-point result is unchanged.
    scaled_steps = steps * 60.0
    return scaled_steps / qpm / steps_per_quarter
2016-11-11 18:53:51 +00:00
def _estimate_qpm(midi_data, num_notes, default_qpm=120, max_qpm=240):
    """Pick a tempo (quarter notes per minute) for a primer performance.

    Args:
      midi_data: The primer MIDI object; only its ``estimate_tempo()`` method
        is used (presumably a ``pretty_midi.PrettyMIDI``).
      num_notes: Number of notes in the primer. With 4 or fewer notes the
        estimate is treated as unreliable and ``default_qpm`` is returned.
      default_qpm: Tempo used when no usable estimate is available.
      max_qpm: Estimates above this are halved repeatedly until they fit, on
        the assumption that the estimator locked onto a beat subdivision.

    Returns:
      The chosen tempo in quarter notes per minute.
    """
    if num_notes <= 4:
        return default_qpm
    try:
        qpm = midi_data.estimate_tempo()
    except ValueError:
        # pretty_midi raises ValueError when it cannot form a tempo estimate
        # (e.g. too few distinct onsets); fall back instead of failing.
        return default_qpm
    # Guard against NaN/inf/non-positive values so the halving loop terminates.
    if not 0 < qpm < float('inf'):
        return default_qpm
    while qpm > max_qpm:
        qpm /= 2.0
    return qpm


def generate_midi(midi_data, total_seconds=10):
    """Continue a MIDI primer with the melody RNN and return the result.

    Args:
      midi_data: The primer MIDI, in a form accepted by
        ``magenta.music.midi_io.midi_to_sequence_proto`` and exposing
        ``estimate_tempo()``.
      total_seconds: End time, in seconds, of the generated section (absolute
        time, not a duration). NOTE(review): if the primer already ends after
        this time, the requested section starts after it ends — confirm how
        the generator handles that case.

    Returns:
      An open ``tempfile.NamedTemporaryFile`` positioned at offset 0 that
      holds the generated MIDI. The file is deleted when the caller closes it.
    """
    primer_sequence = magenta.music.midi_io.midi_to_sequence_proto(midi_data)

    # Predict the tempo and store it on the primer (presumably the generator
    # reads the tempo from the primer sequence).
    qpm = _estimate_qpm(midi_data, len(primer_sequence.notes))
    primer_sequence.tempos[0].qpm = qpm

    generator_options = generator_pb2.GeneratorOptions()
    # Set the start time to begin on the next step after the last note ends.
    last_end_time = (max(n.end_time for n in primer_sequence.notes)
                     if primer_sequence.notes else 0)
    generator_options.generate_sections.add(
        start_time=last_end_time + _steps_to_seconds(1, qpm),
        end_time=total_seconds)

    # Generate the output sequence.
    generated_sequence = generator.generate(primer_sequence, generator_options)

    # Write the result to a named temp file; close (and thereby delete) it if
    # writing fails so the handle is not leaked.
    output = tempfile.NamedTemporaryFile()
    try:
        magenta.music.midi_io.sequence_proto_to_midi_file(
            generated_sequence, output.name)
    except Exception:
        output.close()
        raise
    output.seek(0)
    return output