updating to latest magenta and tf
This commit is contained in:
parent
9c4445abf9
commit
c94efe36b1
BIN
server/basic_rnn.mag
Normal file
BIN
server/basic_rnn.mag
Normal file
Binary file not shown.
@ -14,23 +14,40 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
import magenta.models.basic_rnn.basic_rnn_generator as basic_rnn_generator
|
|
||||||
from magenta.music import sequence_generator_bundle
|
import magenta
|
||||||
|
|
||||||
|
from magenta.models.melody_rnn import melody_rnn_config_flags
|
||||||
|
from magenta.models.melody_rnn import melody_rnn_model
|
||||||
|
from magenta.models.melody_rnn import melody_rnn_sequence_generator
|
||||||
from magenta.protobuf import generator_pb2
|
from magenta.protobuf import generator_pb2
|
||||||
from magenta.music import midi_io
|
from magenta.protobuf import music_pb2
|
||||||
from magenta.models.shared.melody_rnn_generate import _steps_to_seconds
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import pretty_midi
|
||||||
|
|
||||||
|
BUNDLE_NAME = 'basic_rnn'
|
||||||
|
|
||||||
basic_generator = basic_rnn_generator.create_generator(
|
config = magenta.models.melody_rnn.melody_rnn_model.default_configs[BUNDLE_NAME]
|
||||||
None,
|
bundle_file = magenta.music.read_bundle_file(os.path.abspath(BUNDLE_NAME+'.mag'))
|
||||||
sequence_generator_bundle.read_bundle_file(os.path.abspath('./magenta/basic_rnn.mag')),
|
steps_per_quarter = 4
|
||||||
4)
|
|
||||||
|
generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
|
||||||
|
model=melody_rnn_model.MelodyRnnModel(config),
|
||||||
|
details=config.details,
|
||||||
|
steps_per_quarter=steps_per_quarter,
|
||||||
|
bundle=bundle_file)
|
||||||
|
|
||||||
|
def _steps_to_seconds(steps, qpm):
|
||||||
|
return steps * 60.0 / qpm / steps_per_quarter
|
||||||
|
|
||||||
def generate_midi(midi_data, total_seconds=10):
|
def generate_midi(midi_data, total_seconds=10):
|
||||||
primer_sequence = midi_io.midi_to_sequence_proto(midi_data)
|
primer_sequence = magenta.music.midi_io.midi_to_sequence_proto(midi_data)
|
||||||
generate_request = generator_pb2.GenerateSequenceRequest()
|
|
||||||
|
# predict the tempo
|
||||||
if len(primer_sequence.notes) > 4:
|
if len(primer_sequence.notes) > 4:
|
||||||
estimated_tempo = midi_data.estimate_tempo()
|
estimated_tempo = midi_data.estimate_tempo()
|
||||||
if estimated_tempo > 240:
|
if estimated_tempo > 240:
|
||||||
@ -40,18 +57,19 @@ def generate_midi(midi_data, total_seconds=10):
|
|||||||
else:
|
else:
|
||||||
qpm = 120
|
qpm = 120
|
||||||
primer_sequence.tempos[0].qpm = qpm
|
primer_sequence.tempos[0].qpm = qpm
|
||||||
generate_request.input_sequence.CopyFrom(primer_sequence)
|
|
||||||
generate_section = (generate_request.generator_options.generate_sections.add())
|
generator_options = generator_pb2.GeneratorOptions()
|
||||||
# Set the start time to begin on the next step after the last note ends.
|
# Set the start time to begin on the next step after the last note ends.
|
||||||
notes_by_end_time = sorted(primer_sequence.notes, key=lambda n: n.end_time)
|
last_end_time = (max(n.end_time for n in primer_sequence.notes)
|
||||||
last_end_time = notes_by_end_time[-1].end_time if notes_by_end_time else 0
|
if primer_sequence.notes else 0)
|
||||||
generate_section.start_time_seconds = last_end_time + _steps_to_seconds(
|
generator_options.generate_sections.add(
|
||||||
1, qpm)
|
start_time=last_end_time + _steps_to_seconds(1, qpm),
|
||||||
generate_section.end_time_seconds = total_seconds
|
end_time=total_seconds)
|
||||||
# generate_response = generator_map[generator_name].generate(generate_request)
|
|
||||||
generate_response = basic_generator.generate(generate_request)
|
# generate the output sequence
|
||||||
|
generated_sequence = generator.generate(primer_sequence, generator_options)
|
||||||
|
|
||||||
output = tempfile.NamedTemporaryFile()
|
output = tempfile.NamedTemporaryFile()
|
||||||
midi_io.sequence_proto_to_midi_file(
|
magenta.music.midi_io.sequence_proto_to_midi_file(generated_sequence, output.name)
|
||||||
generate_response.generated_sequence, output.name)
|
|
||||||
output.seek(0)
|
output.seek(0)
|
||||||
return output
|
return output
|
||||||
|
@ -1,28 +1,5 @@
|
|||||||
Flask==0.11.1
|
tensorflow==0.12.1
|
||||||
Jinja2==2.8
|
magenta==0.1.8
|
||||||
MarkupSafe==0.23
|
Flask==0.12
|
||||||
Werkzeug==0.11.11
|
|
||||||
argparse==1.2.1
|
|
||||||
boto==2.34.0
|
|
||||||
chardet==2.3.0
|
|
||||||
click==6.6
|
|
||||||
colorama==0.3.2
|
|
||||||
crcmod==1.7
|
|
||||||
funcsigs==1.0.2
|
|
||||||
gunicorn==19.6.0
|
gunicorn==19.6.0
|
||||||
html5lib==0.999
|
ipython==5.1.0
|
||||||
itsdangerous==0.24
|
|
||||||
meld3==1.0.0
|
|
||||||
mido==1.1.17
|
|
||||||
mock==2.0.0
|
|
||||||
numpy==1.11.2
|
|
||||||
pbr==1.10.0
|
|
||||||
pretty-midi==0.2.6
|
|
||||||
protobuf==3.0.0
|
|
||||||
requests==2.4.3
|
|
||||||
six==1.10.0
|
|
||||||
supervisor==3.0
|
|
||||||
urllib3==1.9.1
|
|
||||||
virtualenv==1.11.6
|
|
||||||
wheel==0.29.0
|
|
||||||
wsgiref==0.1.2
|
|
Loading…
Reference in New Issue
Block a user