Top-Level Functions
higher.innerloop_ctx(model, opt, device=None, copy_initial_weights=True, override=None, track_higher_grads=True)
A context manager for writing differentiable inner loops.
- Parameters
model – a torch.nn.Module subclass instance.
opt – an existing optimizer, assumed to be an instance of torch.optim.Optimizer, of a supported type which is either defined in torch.optim, or a custom implementation which has been added to higher at runtime by using higher.register_optim. We assume this optimizer tracks the parameters (or some subset thereof) of a single torch.nn.Module instance, with support for parameter groups.
device (optional) – a device to cast the fast weights and state to. If not specified, the device used for the corresponding weights of model will be used.
copy_initial_weights – if True, the weights of model are copied to form the initial weights of the patched module, and thus are not part of the gradient tape when unrolling the patched module. If this is set to False, the actual module weights will be the initial weights of the patched module. This is useful when doing MAML, for example.
override (optional) – a dictionary mapping optimizer settings (i.e. those which would be passed to the optimizer constructor or provided within parameter groups) to either a singleton list holding one override value, or a list of override values of length equal to the number of parameter groups. If a single override is provided for a keyword, it is used for all parameter groups. If a list is provided, the ith element of the list overrides the corresponding setting in the ith parameter group. This permits the passing of tensors requiring gradient to differentiable optimizers for use as optimizer settings.
track_higher_grads – if True, during unrolled optimization the graph will be retained, and the fast weights will bear grad funcs, so as to permit backpropagation through the optimization process. Setting this to False allows innerloop_ctx to be used in “test mode”, without tracking higher order gradients. This can be useful when running the training loop at test time, e.g. in k-shot learning experiments, without incurring a significant memory overhead.
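The broadcasting rule described for override can be illustrated with a small standalone sketch. The helper expand_override below is hypothetical and is not part of higher; it only mirrors the rule as stated in the parameter description (a singleton list applies to every parameter group, a longer list must have one entry per group).

```python
from typing import Any, Dict, List


def expand_override(override: Dict[str, List[Any]], num_groups: int) -> List[Dict[str, Any]]:
    """Hypothetical helper (not part of higher): expand an override
    dictionary into one settings dictionary per parameter group."""
    per_group: List[Dict[str, Any]] = [{} for _ in range(num_groups)]
    for key, values in override.items():
        if len(values) == 1:
            # A singleton list is used for all parameter groups.
            values = values * num_groups
        elif len(values) != num_groups:
            raise ValueError(
                f"override for {key!r} has {len(values)} values, "
                f"expected 1 or {num_groups}"
            )
        # The ith value overrides the setting in the ith parameter group.
        for group, value in zip(per_group, values):
            group[key] = value
    return per_group


# One learning rate shared by both groups, one momentum value per group.
expanded = expand_override({"lr": [0.1], "momentum": [0.9, 0.0]}, num_groups=2)
```

In real use the values would typically be tensors requiring gradient, e.g. a learnable learning rate passed as a singleton list.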
- Yields
A (fmodule, diffopt) tuple, where fmodule is a “stateless” version of the original module, for which calls to forward take the additional kwarg-only parameter params, which should be a list of torch tensors requiring gradients, ideally provided by this function (see below) or by an update step from one of the optimizers in higher.optim, and diffopt is an initialized DifferentiableOptimizer instance of the subtype matching opt.