Mirror of https://github.com/google-deepmind/deepmind-research.git
Fix bug for dynamic expansion in CURL.
PiperOrigin-RevId: 356227182
Commit 0909ded4c7 (parent 91f4e2d2e6), committed by Diego de Las Casas
@@ -591,7 +591,7 @@ def run_training(
   exp_wait_steps = 100  # Steps to wait after expansion before eligible again
   exp_burn_in = 100  # Steps to wait at start of learning before eligible
   exp_buffer_size = 100  # Size of the buffer of poorly explained data
-  num_buffer_train_steps = 20  # Num steps to train component on buffer
+  num_buffer_train_steps = 10  # Num steps to train component on buffer
 
   # Define a global tf variable for the number of active components.
   n_y_active_np = n_y_active
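
For orientation only, a minimal sketch (not part of the commit) of how the three schedule constants above could gate when dynamic expansion is allowed; the function and the names `step`, `steps_since_expansion`, and `poor_buffer` are hypothetical.

# Hypothetical sketch: gating dynamic expansion with the constants defined above.
# `step`, `steps_since_expansion` and `poor_buffer` are illustrative names only.
def eligible_for_expansion(step, steps_since_expansion, poor_buffer):
  return (step > exp_burn_in                          # past the initial burn-in
          and steps_since_expansion > exp_wait_steps  # waited since last expansion
          and len(poor_buffer) >= exp_buffer_size)    # enough poorly explained data buffered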
@@ -749,6 +749,21 @@ def run_training(
   train_step = optimizer.minimize(train_ops.elbo)
   train_step_supervised = optimizer.minimize(train_ops.elbo_supervised)
 
+  # For dynamic expansion, we want to train only new-component-related params
+  cat_params = tf.get_collection(
+      tf.GraphKeys.TRAINABLE_VARIABLES,
+      'cluster_encoder/mlp_cluster_encoder_final')
+  component_params = tf.get_collection(
+      tf.GraphKeys.TRAINABLE_VARIABLES,
+      'latent_encoder/mlp_latent_encoder_*')
+  prior_params = tf.get_collection(
+      tf.GraphKeys.TRAINABLE_VARIABLES,
+      'latent_decoder/latent_prior*')
+
+  train_step_expansion = optimizer.minimize(
+      train_ops.elbo_supervised,
+      var_list=cat_params+component_params+prior_params)
+
   # Set up ops for generative replay
   if gen_every_n > 0:
     # How many generative batches will we use each period?
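
The core of the fix is the new train_step_expansion op above, which restricts gradient updates, via var_list, to the variable scopes holding the new component's parameters. A minimal, self-contained TF1 sketch of that pattern, using toy scopes rather than the CURL model:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Two toy variable scopes stand in for the scoped collections gathered above
# (e.g. 'cluster_encoder/...', 'latent_encoder/...').
with tf.variable_scope('encoder'):
  w_enc = tf.get_variable('w', shape=[4, 4])
with tf.variable_scope('decoder'):
  w_dec = tf.get_variable('w', shape=[4, 4])

loss = tf.reduce_sum(tf.square(w_enc)) + tf.reduce_sum(tf.square(w_dec))
optimizer = tf.train.AdamOptimizer(1e-3)

# Collect only the encoder's trainable variables by scope prefix, as the
# tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, <scope>) calls do above.
encoder_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'encoder')

# This train op updates only encoder_params; w_dec is never touched, which is
# the mechanism train_step_expansion relies on to train just the new component.
train_encoder_only = optimizer.minimize(loss, var_list=encoder_params)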
@@ -1078,11 +1093,10 @@ def run_training(
           for bs in range(n_poor_batches):
             x_batch = poor_data_buffer[bs * batch_size:(bs + 1) *
                                        batch_size]
-            label_batch = poor_data_labels[bs * batch_size:(bs + 1) *
-                                           batch_size]
+            label_batch = [new_cluster] * batch_size
             label_onehot_batch = np.eye(n_y)[label_batch]
             _ = sess.run(
-                train_step_supervised,
+                train_step_expansion,
                 feed_dict={
                     x_train_raw: x_batch,
                     model_train.y_label: label_onehot_batch
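
The other half of the fix is visible in the hunk above: every example drawn from the buffer of poorly explained data is now labelled with the newly activated cluster, and the restricted train_step_expansion op replaces train_step_supervised. A NumPy-only illustration of the relabeling step (the values of n_y, new_cluster and batch_size are made up):

import numpy as np

n_y = 5          # hypothetical total number of mixture components
new_cluster = 3  # hypothetical index of the newly activated component
batch_size = 4   # hypothetical batch size

# Every buffered example gets the new component's label, then np.eye turns the
# integer labels into the one-hot targets fed to model_train.y_label.
label_batch = [new_cluster] * batch_size        # [3, 3, 3, 3]
label_onehot_batch = np.eye(n_y)[label_batch]   # shape (4, 5), ones in column 3
print(label_onehot_batch)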