Upload model.py
model.py
ADDED
@@ -0,0 +1,294 @@
import tensorflow as tf
import numpy as np
import miditoolkit
import modules
import pickle
import utils
import time

class PopMusicTransformer(object):
    ########################################
    # initialize
    ########################################
    def __init__(self, checkpoint, is_training=False):
        # load dictionary
        self.dictionary_path = '{}/dictionary.pkl'.format(checkpoint)
        self.event2word, self.word2event = pickle.load(open(self.dictionary_path, 'rb'))
        # model settings
        self.x_len = 512
        self.mem_len = 512
        self.n_layer = 12
        self.d_embed = 512
        self.d_model = 512
        self.dropout = 0.1
        self.n_head = 8
        self.d_head = self.d_model // self.n_head
        self.d_ff = 2048
        self.n_token = len(self.event2word)
        self.learning_rate = 0.0002
        # load model
        self.is_training = is_training
        if self.is_training:
            self.batch_size = 4
        else:
            self.batch_size = 1
        self.checkpoint_path = '{}/model'.format(checkpoint)
        self.load_model()

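    # Note: these settings describe a 12-layer Transformer-XL (512-d embeddings,
    # 8 heads, 2048-d feed-forward) and must match the checkpoint being restored;
    # mem_len is the number of cached hidden states per layer carried across
    # segments, and n_token is derived from the loaded event dictionary.
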
    ########################################
    # load model
    ########################################
    def load_model(self):
        # placeholders
        self.x = tf.compat.v1.placeholder(tf.int32, shape=[self.batch_size, None])
        self.y = tf.compat.v1.placeholder(tf.int32, shape=[self.batch_size, None])
        self.mems_i = [tf.compat.v1.placeholder(tf.float32, [self.mem_len, self.batch_size, self.d_model]) for _ in range(self.n_layer)]
        # model
        self.global_step = tf.compat.v1.train.get_or_create_global_step()
        initializer = tf.compat.v1.initializers.random_normal(stddev=0.02, seed=None)
        proj_initializer = tf.compat.v1.initializers.random_normal(stddev=0.01, seed=None)
        with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope()):
            xx = tf.transpose(self.x, [1, 0])
            yy = tf.transpose(self.y, [1, 0])
            loss, self.logits, self.new_mem = modules.transformer(
                dec_inp=xx,
                target=yy,
                mems=self.mems_i,
                n_token=self.n_token,
                n_layer=self.n_layer,
                d_model=self.d_model,
                d_embed=self.d_embed,
                n_head=self.n_head,
                d_head=self.d_head,
                d_inner=self.d_ff,
                dropout=self.dropout,
                dropatt=self.dropout,
                initializer=initializer,
                proj_initializer=proj_initializer,
                is_training=self.is_training,
                mem_len=self.mem_len,
                cutoffs=[],
                div_val=-1,
                tie_projs=[],
                same_length=False,
                clamp_len=-1,
                input_perms=None,
                target_perms=None,
                head_target=None,
                untie_r=False,
                proj_same_dim=True)
        self.avg_loss = tf.reduce_mean(loss)
        # vars
        all_vars = tf.compat.v1.trainable_variables()
        grads = tf.gradients(self.avg_loss, all_vars)
        grads_and_vars = list(zip(grads, all_vars))
        all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.compat.v1.trainable_variables()])
        # optimizer
        decay_lr = tf.compat.v1.train.cosine_decay(
            self.learning_rate,
            global_step=self.global_step,
            decay_steps=400000,
            alpha=0.004)
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=decay_lr)
        self.train_op = optimizer.apply_gradients(grads_and_vars, self.global_step)
        # saver
        self.saver = tf.compat.v1.train.Saver()
        config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.compat.v1.Session(config=config)
        self.saver.restore(self.sess, self.checkpoint_path)

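    # The same graph serves both training and inference: inputs are transposed
    # to [time, batch] before entering modules.transformer, and the returned
    # per-layer memories are fed back in through the mems_i placeholders on
    # the next call.
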
    ########################################
    # temperature sampling
    ########################################
    def temperature_sampling(self, logits, temperature, topk):
        probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature))
        if topk == 1:
            prediction = np.argmax(probs)
        else:
            sorted_index = np.argsort(probs)[::-1]
            candi_index = sorted_index[:topk]
            candi_probs = [probs[i] for i in candi_index]
            # normalize probs
            candi_probs /= sum(candi_probs)
            # choose by predicted probs
            prediction = np.random.choice(candi_index, size=1, p=candi_probs)[0]
        return prediction

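    # Toy illustration (made-up numbers): with logits [2.0, 1.0, 0.5] and
    # temperature 1.2, the softmax flattens slightly toward uniform; topk=2
    # then keeps only the two most probable tokens, renormalizes their
    # probabilities, and samples among them.
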
    ########################################
    # extract events for prompt continuation
    ########################################
    def extract_events(self, input_path):
        note_items, tempo_items = utils.read_items(input_path)
        note_items = utils.quantize_items(note_items)
        max_time = note_items[-1].end
        if 'chord' in self.checkpoint_path:
            chord_items = utils.extract_chords(note_items)
            items = chord_items + tempo_items + note_items
        else:
            items = tempo_items + note_items
        groups = utils.group_items(items, max_time)
        events = utils.item2event(groups)
        return events

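    # Note: the checkpoint folder name doubles as a feature flag; a path
    # containing 'chord' selects the event vocabulary that includes chords.
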
    ########################################
    # generate
    ########################################
    def generate(self, n_target_bar, temperature, topk, output_path, prompt=None):
        # if a prompt is given, continue from it; otherwise, start from a random seed
        if prompt:
            events = self.extract_events(prompt)
            words = [[self.event2word['{}_{}'.format(e.name, e.value)] for e in events]]
            words[0].append(self.event2word['Bar_None'])
        else:
            words = []
            for _ in range(self.batch_size):
                ws = [self.event2word['Bar_None']]
                if 'chord' in self.checkpoint_path:
                    tempo_classes = [v for k, v in self.event2word.items() if 'Tempo Class' in k]
                    tempo_values = [v for k, v in self.event2word.items() if 'Tempo Value' in k]
                    chords = [v for k, v in self.event2word.items() if 'Chord' in k]
                    ws.append(self.event2word['Position_1/16'])
                    ws.append(np.random.choice(chords))
                    ws.append(self.event2word['Position_1/16'])
                    ws.append(np.random.choice(tempo_classes))
                    ws.append(np.random.choice(tempo_values))
                else:
                    tempo_classes = [v for k, v in self.event2word.items() if 'Tempo Class' in k]
                    tempo_values = [v for k, v in self.event2word.items() if 'Tempo Value' in k]
                    ws.append(self.event2word['Position_1/16'])
                    ws.append(np.random.choice(tempo_classes))
                    ws.append(np.random.choice(tempo_values))
                words.append(ws)
        # initialize mem
        batch_m = [np.zeros((self.mem_len, self.batch_size, self.d_model), dtype=np.float32) for _ in range(self.n_layer)]
        # generate
        original_length = len(words[0])
        initial_flag = 1
        current_generated_bar = 0
        while current_generated_bar < n_target_bar:
            # input
            if initial_flag:
                temp_x = np.zeros((self.batch_size, original_length))
                for b in range(self.batch_size):
                    for z, t in enumerate(words[b]):
                        temp_x[b][z] = t
                initial_flag = 0
            else:
                temp_x = np.zeros((self.batch_size, 1))
                for b in range(self.batch_size):
                    temp_x[b][0] = words[b][-1]
            # prepare feed dict
            feed_dict = {self.x: temp_x}
            for m, m_np in zip(self.mems_i, batch_m):
                feed_dict[m] = m_np
            # model (prediction)
            _logits, _new_mem = self.sess.run([self.logits, self.new_mem], feed_dict=feed_dict)
            # sampling
            _logit = _logits[-1, 0]
            word = self.temperature_sampling(
                logits=_logit,
                temperature=temperature,
                topk=topk)
            words[0].append(word)
            # if bar event (only works for batch_size=1)
            if word == self.event2word['Bar_None']:
                current_generated_bar += 1
            # update memory
            batch_m = _new_mem
        # write
        if prompt:
            utils.write_midi(
                words=words[0][original_length:],
                word2event=self.word2event,
                output_path=output_path,
                prompt_path=prompt)
        else:
            utils.write_midi(
                words=words[0],
                word2event=self.word2event,
                output_path=output_path,
                prompt_path=None)

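    # Generation is autoregressive: after the first step only the newest token
    # is fed in, since everything earlier is already cached in batch_m, and
    # bars are counted by watching for the Bar_None token.
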
    ########################################
    # prepare training data
    ########################################
    def prepare_data(self, midi_paths):
        # extract events
        all_events = []
        for path in midi_paths:
            events = self.extract_events(path)
            all_events.append(events)
        # event to word
        all_words = []
        for events in all_events:
            words = []
            for event in events:
                e = '{}_{}'.format(event.name, event.value)
                if e in self.event2word:
                    words.append(self.event2word[e])
                else:
                    # OOV
                    if event.name == 'Note Velocity':
                        # replace with the max velocity seen in our training data
                        words.append(self.event2word['Note Velocity_21'])
                    else:
                        # something is wrong; you should handle it for your own purposes
                        print('something is wrong! {}'.format(e))
            all_words.append(words)
        # to training data
        self.group_size = 5
        segments = []
        for words in all_words:
            pairs = []
            for i in range(0, len(words)-self.x_len-1, self.x_len):
                x = words[i:i+self.x_len]
                y = words[i+1:i+self.x_len+1]
                pairs.append([x, y])
            pairs = np.array(pairs)
            # abandon the last, incomplete group
            for i in np.arange(0, len(pairs)-self.group_size, self.group_size*2):
                data = pairs[i:i+self.group_size]
                if len(data) == self.group_size:
                    segments.append(data)
        segments = np.array(segments)
        return segments

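    # Resulting shape: (n_segments, group_size, 2, x_len); index 0 of the
    # third axis holds the inputs x and index 1 the targets y, which are the
    # same tokens shifted by one position.
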
    ########################################
    # finetune
    ########################################
    def finetune(self, training_data, output_checkpoint_folder):
        # shuffle
        index = np.arange(len(training_data))
        np.random.shuffle(index)
        training_data = training_data[index]
        num_batches = len(training_data) // self.batch_size
        st = time.time()
        for e in range(200):
            total_loss = []
            for i in range(num_batches):
                segments = training_data[self.batch_size*i:self.batch_size*(i+1)]
                batch_m = [np.zeros((self.mem_len, self.batch_size, self.d_model), dtype=np.float32) for _ in range(self.n_layer)]
                for j in range(self.group_size):
                    batch_x = segments[:, j, 0, :]
                    batch_y = segments[:, j, 1, :]
                    # prepare feed dict
                    feed_dict = {self.x: batch_x, self.y: batch_y}
                    for m, m_np in zip(self.mems_i, batch_m):
                        feed_dict[m] = m_np
                    # run
                    _, gs_, loss_, new_mem_ = self.sess.run([self.train_op, self.global_step, self.avg_loss, self.new_mem], feed_dict=feed_dict)
                    batch_m = new_mem_
                    total_loss.append(loss_)
                    print('>>> Epoch: {}, Step: {}, Loss: {:.5f}, Time: {:.2f}'.format(e, gs_, loss_, time.time()-st))
            self.saver.save(self.sess, '{}/model-{:03d}-{:.3f}'.format(output_checkpoint_folder, e, np.mean(total_loss)))
            # stop early once the loss is low enough
            if np.mean(total_loss) <= 0.1:
                break

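    # Memory is reset at the start of each batch of songs, then carried across
    # the group_size consecutive segments of each song, Transformer-XL style.
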
    ########################################
    # close
    ########################################
    def close(self):
        self.sess.close()
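
# Minimal usage sketch (the checkpoint and output paths below are
# illustrative, not shipped with this file):
#
#   model = PopMusicTransformer('./REMI-tempo-checkpoint', is_training=False)
#   model.generate(n_target_bar=16, temperature=1.2, topk=5,
#                  output_path='./result/from_scratch.midi', prompt=None)
#   model.close()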