Fix bug for dynamic expansion in CURL.

PiperOrigin-RevId: 356227182
This commit is contained in:
Dushyant Rao
2021-02-08 11:21:21 +00:00
committed by Diego de Las Casas
parent 91f4e2d2e6
commit 0909ded4c7

View File

@@ -591,7 +591,7 @@ def run_training(
exp_wait_steps = 100 # Steps to wait after expansion before eligible again
exp_burn_in = 100 # Steps to wait at start of learning before eligible
exp_buffer_size = 100 # Size of the buffer of poorly explained data
num_buffer_train_steps = 20 # Num steps to train component on buffer
num_buffer_train_steps = 10 # Num steps to train component on buffer
# Define a global tf variable for the number of active components.
n_y_active_np = n_y_active
@@ -749,6 +749,21 @@ def run_training(
train_step = optimizer.minimize(train_ops.elbo)
train_step_supervised = optimizer.minimize(train_ops.elbo_supervised)
# For dynamic expansion, we want to train only new-component-related params
cat_params = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES,
'cluster_encoder/mlp_cluster_encoder_final')
component_params = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES,
'latent_encoder/mlp_latent_encoder_*')
prior_params = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES,
'latent_decoder/latent_prior*')
train_step_expansion = optimizer.minimize(
train_ops.elbo_supervised,
var_list=cat_params+component_params+prior_params)
# Set up ops for generative replay
if gen_every_n > 0:
# How many generative batches will we use each period?
@@ -1078,11 +1093,10 @@ def run_training(
for bs in range(n_poor_batches):
x_batch = poor_data_buffer[bs * batch_size:(bs + 1) *
batch_size]
label_batch = poor_data_labels[bs * batch_size:(bs + 1) *
batch_size]
label_batch = [new_cluster] * batch_size
label_onehot_batch = np.eye(n_y)[label_batch]
_ = sess.run(
train_step_supervised,
train_step_expansion,
feed_dict={
x_train_raw: x_batch,
model_train.y_label: label_onehot_batch