Monkey-Patching Functions

Functions for making torch.nn.Module subclass instances stateless.

higher.patch.buffer_sync(module, fmodule, device=None)

One-off sync (copy) of the buffers in fmodule with those from module.

Return type

None

higher.patch.make_functional(module, encapsulator=None)

Returns a stateless version of an nn.Module instance.

Return type

_MonkeyPatchBase

higher.patch.monkeypatch(module, device=None, copy_initial_weights=True, track_higher_grads=True)

Create a monkey-patched stateless version of a module.

This function produces a monkey-patched version of a module and initializes its fast weights with a copy of the module's parameters. Where the original module or any of its submodules have state (e.g. batch norm), this state is copied too, but further updates (e.g. during inner-loop training) will cause the copies to diverge without changing the state of the original module.

Parameters
  • module – a torch.nn.Module subclass instance.

  • device (optional) – a device to cast the fast weights and state to.

  • copy_initial_weights – if True, the weights of the original module are copied to form the initial weights of the patched module, and thus are not part of the gradient tape when unrolling the patched module. If this is set to False, the actual module weights will be the initial weights of the patched module, so gradients taken through the unrolled optimization flow back to the original module's parameters. This is useful when doing MAML, for example.

  • track_higher_grads – if True, during unrolled optimization the graph will be retained, and the fast weights will bear grad funcs, so as to permit backpropagation through the optimization process. Setting this to False allows monkeypatch to be used in “test mode”, without potentially tracking higher-order gradients. This can be useful when running the training loop at test time, e.g. in k-shot learning experiments, without incurring a significant memory overhead.

Returns

fmodule – a “stateless” version of the original module, for which calls to forward take the additional kwarg-only parameter params, which should be a list of torch tensors requiring gradients, ideally provided by this function (as the initial fast weights of the returned module) or by an update step from one of the optimizers in higher.optim.

Return type

_MonkeyPatchBase