parent
6a573d1b00
commit
f33312634a
@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.path.append('./third_party')
|
||||||
import third_party.magenta.models.basic_rnn.basic_rnn_generator as basic_rnn_generator
|
import third_party.magenta.models.basic_rnn.basic_rnn_generator as basic_rnn_generator
|
||||||
from third_party.magenta.music import sequence_generator_bundle
|
from third_party.magenta.music import sequence_generator_bundle
|
||||||
from third_party.magenta.protobuf import generator_pb2
|
from third_party.magenta.protobuf import generator_pb2
|
||||||
@ -49,9 +51,9 @@ def generate_midi(midi_data, total_seconds=10):
|
|||||||
1, qpm)
|
1, qpm)
|
||||||
generate_section.end_time_seconds = total_seconds
|
generate_section.end_time_seconds = total_seconds
|
||||||
# generate_response = generator_map[generator_name].generate(generate_request)
|
# generate_response = generator_map[generator_name].generate(generate_request)
|
||||||
generate_response = basic_generator.generate(generate_request.input_sequence, generate_request.generator_options)
|
generate_response = basic_generator.generate(generate_request)
|
||||||
output = tempfile.NamedTemporaryFile()
|
output = tempfile.NamedTemporaryFile()
|
||||||
midi_io.sequence_proto_to_midi_file(
|
midi_io.sequence_proto_to_midi_file(
|
||||||
generate_response, output.name)
|
generate_response.generated_sequence, output.name)
|
||||||
output.seek(0)
|
output.seek(0)
|
||||||
return output
|
return output
|
||||||
|
Loading…
Reference in New Issue
Block a user