In [1]:
# Make the project root importable.
import sys
sys.path.append('../..')

# Silence deprecation warnings.
import warnings
warnings.filterwarnings("ignore")

import copy
import os
import time
from tqdm import tqdm
import math

import ddsp
import ddsp.training

from data_handling.ddspdataset import DDSPDataset
from utils.training_utils import print_hparams, set_seed, save_results, str2bool
from hparams_midiae_interp_cond import hparams as hp
from midiae_interp_cond.get_model import get_model, get_fake_data

import librosa
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import pandas as pd
import qgrid

from notebook_utils import *

set_seed(1234)

# Globals
sample_rate = 16000  # audio sample rate (Hz)


print('Done!')
Done!
In [2]:
model_path = r'/data/ddsp-experiment/logs/5.13_samples/150000'
# Recover the hyperparameters recorded in the training log and apply them.
hp_dict = get_hp(os.path.join(os.path.dirname(model_path), 'train.log'))
for k, v in hp_dict.items():
    setattr(hp, k, v)
hp.sequence_length = 1000  # override the sequence length for evaluation
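`get_hp` comes from `notebook_utils` and is assumed to parse hyperparameter key/value pairs out of the training log. A minimal sketch of such a parser, assuming the log contains `key: value` lines (the exact log format is an assumption, not the helper's actual behavior):

import ast

def get_hp_sketch(log_path):
    """Hypothetical stand-in for notebook_utils.get_hp."""
    hp_dict = {}
    with open(log_path) as f:
        for line in f:
            if ':' not in line:
                continue
            key, _, value = line.partition(':')
            try:
                # literal_eval turns '1000' -> int, 'True' -> bool, etc.
                hp_dict[key.strip()] = ast.literal_eval(value.strip())
            except (ValueError, SyntaxError):
                hp_dict[key.strip()] = value.strip()
    return hp_dict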
In [3]:
# Alternative: evaluate on the URMP violin test split instead.
# from data_handling.urmp_tfrecord_dataloader import UrmpMidi
# from data_handling.get_tfrecord_length import get_tfrecord_length
# data_dir = r'/data/music_dataset/urmp_dataset/tfrecord_ddsp/batched/solo_instrument'
# test_data_loader = UrmpMidi(data_dir, instrument_key='vn', split='test')
# evaluation_data = test_data_loader.get_batch(batch_size=1, shuffle=True, repeats=1)

from data_handling.google_solo_inst_dataloader import GoogleSoloInstrument
test_data_loader = GoogleSoloInstrument(
    base_dir=r'/data/music_dataset/solo_performance_google/solo-inst_midi_features',
    instrument_key='sax', split='test')
evaluation_data = test_data_loader.get_batch(batch_size=1, shuffle=True, repeats=1)
In [4]:
evaluation_data = iter(evaluation_data)
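Before evaluating, it can help to sanity-check one batch. A quick sketch for inspecting the feature keys and tensor shapes of a sample, assuming every feature the loader yields is a batched tf.Tensor:

peek = next(iter(test_data_loader.get_batch(batch_size=1, shuffle=False, repeats=1)))
for key, value in peek.items():
    # Shapes reveal the frame rate and sequence length of each feature.
    print(f'{key}: shape={tuple(value.shape)}, dtype={value.dtype.name}')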
In [5]:
model = get_model(hp)
# Build the model's variables on fake data so the checkpoint can be restored.
_ = model._build(get_fake_data(hp))
model.load_weights(model_path)
Out[5]:
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f416445bf90>
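A subclassed Keras model has no variables until it is called, which is why `_build` runs on fake data before `load_weights`. The returned `CheckpointLoadStatus` (seen in the cell output above) can also be asserted on to verify the restore, using standard TensorFlow API:

status = model.load_weights(model_path)
# Raises if any existing model variable was not matched by a checkpointed value.
status.assert_existing_objects_matched()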
In [6]:
sample = next(evaluation_data)
In [7]:
from midiae_interp_cond.interpretable_conditioning import midi_to_hz, get_interpretable_conditioning, extract_harm_controls
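For reference, `midi_to_hz` and `hz_to_midi` implement the standard equal-temperament conversion with A4 = MIDI note 69 = 440 Hz:

def midi_to_hz_reference(midi):
    # f = 440 * 2 ** ((m - 69) / 12); e.g. MIDI 69 -> 440.0 Hz, MIDI 81 -> 880.0 Hz.
    return 440.0 * 2.0 ** ((midi - 69.0) / 12.0)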
In [8]:
plot_spec(sample['audio'][0].numpy(), sr=sample_rate)
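`plot_spec` comes from `notebook_utils`. A minimal sketch of an equivalent log-magnitude spectrogram plot with librosa (the FFT and hop sizes here are illustrative assumptions, not the helper's actual settings):

import librosa.display

def plot_spec_sketch(audio, sr=16000, n_fft=1024, hop_length=256):
    # Log-magnitude STFT displayed on a log-frequency axis.
    spec = librosa.amplitude_to_db(
        np.abs(librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)), ref=np.max)
    librosa.display.specshow(spec, sr=sr, hop_length=hop_length,
                             x_axis='time', y_axis='log')
    plt.title('Log-magnitude spectrogram')
    plt.show()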
In [9]:
# Synth coder: acoustic features (loudness, f0, mel) -> synthesis parameters and resynthesized audio.
synth_params, control_params, synth_audio = model.run_synth_coder(sample, training=False)
# Normalized synth params, MIDI-level features, and the interpretable conditioning dict.
synth_params_normalized, midi_features, conditioning_dict = model.gen_cond_dict_from_feature(sample, training=False)
In [10]:
# Synthesize audio from the interpretable conditioning and MIDI features.
midi_audio, params = model.gen_audio_from_cond_dict(conditioning_dict, midi_features, instrument_id=sample['instrument_id'])
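To audition the three versions outside the notebook (original, synth-coder resynthesis, MIDI-decoder synthesis), they can be written to disk. A sketch using the soundfile package (an extra dependency; filenames are arbitrary, and each signal is assumed to be a (batch, n_samples) tensor):

import soundfile as sf

for name, audio in [('original', sample['audio']),
                    ('synth_coder', synth_audio),
                    ('midi_decoder', midi_audio)]:
    # Write the first batch element as a 16 kHz WAV file.
    sf.write(f'{name}.wav', audio[0].numpy(), samplerate=sample_rate)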

Synth-coder Prediction (loudness, f0, mel -> synth params)

In [11]:
f0, amps, hd, noise = synth_params_normalized
# Plot f0 on the MIDI scale rather than in Hz.
f0_midi = ddsp.core.hz_to_midi(f0)
synth_params_normalized = (f0_midi, amps, hd, noise)
plot_pred_acoustic_feature(sample['audio'].numpy()[0], synth_audio.numpy()[0],
                           get_synth_params(synth_params_normalized), mask_zero_f0=True)
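`mask_zero_f0=True` presumably hides unpitched frames, where f0 is zero in Hz and non-positive after `hz_to_midi`. A sketch of that masking for plotting purposes, assuming this is what the flag does:

def mask_zero_f0_sketch(f0_midi):
    # Replace unpitched frames (f0 <= 0) with NaN so matplotlib leaves gaps in the curve.
    f0_plot = np.array(f0_midi, dtype=np.float32)
    f0_plot[f0_plot <= 0.0] = np.nan
    return f0_plot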