innerloop_ctx(model, opt, device=None, copy_initial_weights=True, override=None, track_higher_grads=True)
A context manager for writing differentiable inner loops.
Parameters:
model – a torch.nn.Module subclass instance.
opt – an existing optimizer, assumed to be an instance of torch.optim.Optimizer, of a supported type which is either defined in torch.optim or a custom implementation which has been added to higher at runtime by using higher.register_optim. We assume this optimizer tracks the parameters (or some subset thereof) of a single torch.nn.Module instance, with support for parameter groups.
device (optional) – a device to cast the fast weights and state to. If not specified, the device used for the corresponding weights of model will be used.
copy_initial_weights – if True, the weights of the original module are copied to form the initial weights of the patched module, and thus are not part of the gradient tape when unrolling the patched module. If this is set to False, the actual module weights will be the initial weights of the patched module, so gradients of an outer loss computed after unrolling flow back into the parameters of model. This is useful when doing MAML, for example.
override (optional) – a dictionary mapping optimizer settings (i.e. those which would be passed to the optimizer constructor or provided within parameter groups) either to a singleton list containing one override value, or to a list of override values whose length equals the number of parameter groups. If a single override is provided for a keyword, it is used for all parameter groups. If a list is provided, the i-th element of the list overrides the corresponding setting in the i-th parameter group. This permits the passing of tensors requiring gradient to differentiable optimizers for use as optimizer settings.
track_higher_grads – if True, the graph will be retained during unrolled optimization, and the fast weights will bear grad funcs, so as to permit backpropagation through the optimization process. Setting this to False allows innerloop_ctx to be used in “test mode”, without tracking higher-order gradients. This can be useful when running the training loop at test time, e.g. in k-shot learning experiments, without incurring a significant memory overhead.
Returns: an (fmodule, diffopt) tuple, where
fmodule is a “stateless” version of the original module, for which calls to forward take the additional kwarg-only parameter params, which should be a list of torch tensors requiring gradients, ideally provided by this function (see below) or by an update step from one of the optimizers in diffopt.
diffopt is an initialized DifferentiableOptimizer instance of the right subtype.