Monkey-Patching Functions

Functions for making torch.nn.Module subclass instances stateless.

higher.patch.buffer_sync(module, fmodule, device=None)

One-off sync (copy) of the buffers in fmodule with those from module.

Return type

None
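To make "one-off" concrete, here is a plain-PyTorch sketch of a buffer copy between two ordinary nn.Module instances. The helper sync_buffers is hypothetical and is not higher's implementation, which operates on the patched fmodule. It assumes both modules have the same submodule structure. Because each buffer is detached and cloned, later in-place updates to the source (such as batch norm running statistics) do not reach the destination.

```python
import torch
from torch import nn


def sync_buffers(src: nn.Module, dst: nn.Module, device=None) -> None:
    """Hypothetical helper: one-off copy of every buffer of `src` into `dst`.

    Assumes `dst` has the same submodule structure as `src`.
    """
    for name, buf in src.named_buffers():
        # Detach and clone so later in-place updates to `src` do not reach `dst`.
        copied = buf.detach().clone()
        if device is not None:
            copied = copied.to(device)
        # Split e.g. "1.running_mean" into owner path "1" and attribute "running_mean".
        owner_path, _, attr = name.rpartition(".")
        owner = dst.get_submodule(owner_path) if owner_path else dst
        setattr(owner, attr, copied)  # rebinds the existing buffer slot


torch.manual_seed(0)
src = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
dst = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))

src.train()
src(torch.randn(8, 3))  # updates src's running statistics
sync_buffers(src, dst)
print(torch.equal(dst[1].running_mean, src[1].running_mean))  # True

src(torch.randn(8, 3))  # a later update is not propagated to dst
print(torch.equal(dst[1].running_mean, src[1].running_mean))  # False
```

A second call to the helper would be needed to bring dst up to date again. That is the sense in which such a sync is one-off rather than a persistent link.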

higher.patch.make_functional(module, encapsulator=None)

Returns a stateless version of an nn.Module instance.

Return type

_MonkeyPatchBase
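The idea of a "stateless version" of a module can be illustrated without higher. PyTorch 2.0+ ships torch.func.functional_call, which runs a module's forward with a parameter dictionary supplied by the caller. This is a different API from make_functional and is shown only to make the concept concrete. The module's own parameters are left untouched.

```python
import torch
from torch import nn
from torch.func import functional_call  # available in PyTorch 2.0+

torch.manual_seed(0)
net = nn.Linear(2, 1)
x = torch.ones(4, 2)

# Parameters held outside the module: an all-zero weight and bias.
external = {name: torch.zeros_like(p) for name, p in net.named_parameters()}

# Forward pass using the external parameters instead of net.weight / net.bias.
out = functional_call(net, external, (x,))
print(out.squeeze(1))  # all zeros

# The module itself still holds its randomly initialised parameters.
print(net(x).squeeze(1))
```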

higher.patch.monkeypatch(module, device=None, copy_initial_weights=True)

Create a monkey-patched stateless version of a module.

This function produces a monkey-patched version of a module and returns a copy of its parameters for use as fast weights. If the original module or any of its submodules has state (e.g. batch norm running statistics), that state is copied too; further updates (e.g. during inner-loop training) cause the copies to diverge without changing the state of the original module.
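The fast-weight pattern described above can be sketched in plain PyTorch, assuming torch.func.functional_call (PyTorch 2.0+) as a stand-in for the patched module. This is not higher's code. The parameters are copied into a dictionary of fast weights, a short inner loop of plain SGD updates only those copies, and the original module's parameters stay unchanged. The learning rate and loop length are arbitrary choices for the example.

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)
net = nn.Linear(3, 1)
x, y = torch.randn(16, 3), torch.randn(16, 1)

# Fast weights: detached copies of the module's parameters.
fast = {n: p.detach().clone().requires_grad_() for n, p in net.named_parameters()}
# Snapshot of the original parameters, for comparison afterwards.
original = {n: p.detach().clone() for n, p in net.named_parameters()}

lr = 0.1
for _ in range(3):  # a short inner loop of plain SGD on the fast weights
    loss = nn.functional.mse_loss(functional_call(net, fast, (x,)), y)
    # create_graph=True keeps the graph so an outer loss could differentiate
    # through these updates.
    grads = torch.autograd.grad(loss, list(fast.values()), create_graph=True)
    fast = {n: w - lr * g for (n, w), g in zip(fast.items(), grads)}

# The fast weights moved; the module's own parameters did not.
print(all(torch.equal(p, original[n]) for n, p in net.named_parameters()))  # True
```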

Parameters
  • module – a torch.nn.Module subclass instance.

  • device (optional) – a device to cast the fast weights and state to.

  • copy_initial_weights – if True (the default), the weights of the original module are copied to form the initial weights of the patched module, and thus are not part of the gradient tape when unrolling the patched module. If set to False, the actual module weights are the initial weights of the patched module, so gradients can flow back to them through the unrolled loop. This is useful when doing MAML, for example.
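The effect of copy_initial_weights on the gradient tape can be sketched in plain PyTorch, again using torch.func.functional_call in place of higher. The helper meta_grad_reaches_module is hypothetical. It mimics the two settings by either using detached copies of the parameters or using the module's own parameters as the initial fast weights. It then runs one inner SGD step and an outer loss and reports whether the original weight receives a gradient.

```python
import torch
from torch import nn
from torch.func import functional_call


def meta_grad_reaches_module(copy_initial_weights: bool) -> bool:
    """Hypothetical helper: one inner SGD step, then an outer loss.

    Returns True if the original module's weight receives a gradient.
    """
    torch.manual_seed(0)
    net = nn.Linear(3, 1)
    x, y = torch.randn(8, 3), torch.randn(8, 1)

    if copy_initial_weights:
        # Detached copies: new leaf tensors, cut off from net's parameters.
        fast = {n: p.detach().clone().requires_grad_() for n, p in net.named_parameters()}
    else:
        # The module's own parameters are the initial fast weights.
        fast = dict(net.named_parameters())

    inner = nn.functional.mse_loss(functional_call(net, fast, (x,)), y)
    grads = torch.autograd.grad(inner, list(fast.values()), create_graph=True)
    fast = {n: w - 0.1 * g for (n, w), g in zip(fast.items(), grads)}

    outer = nn.functional.mse_loss(functional_call(net, fast, (x,)), y)
    outer.backward()
    return net.weight.grad is not None


print(meta_grad_reaches_module(True))   # False: the copies break the tape
print(meta_grad_reaches_module(False))  # True: MAML-style meta-gradient
```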

Returns

fmodule – a “stateless” version of the original module, whose forward method takes the additional keyword-only argument params. params should be a list of torch tensors requiring gradients, ideally those provided by this function or produced by an update step of one of the optimizers in higher.optim.


Return type

_MonkeyPatchBase
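The calling convention of the returned module, with a keyword-only params list, can be mimicked with a small wrapper. ListParamsModule is hypothetical and is not higher's _MonkeyPatchBase. It assumes params is ordered like module.parameters() and delegates to torch.func.functional_call (PyTorch 2.0+). Gradients from the forward pass land on the supplied tensors rather than on the wrapped module's parameters.

```python
import torch
from torch import nn
from torch.func import functional_call


class ListParamsModule:
    """Hypothetical wrapper: forward takes a keyword-only `params` list
    ordered like module.parameters()."""

    def __init__(self, module: nn.Module):
        self.module = module
        # named_parameters() and parameters() iterate in the same order.
        self.names = [n for n, _ in module.named_parameters()]

    def __call__(self, *args, params):
        if len(params) != len(self.names):
            raise ValueError(f"expected {len(self.names)} tensors, got {len(params)}")
        return functional_call(self.module, dict(zip(self.names, params)), args)


net = nn.Linear(2, 2)
fnet = ListParamsModule(net)
params = [p.detach().clone().requires_grad_() for p in net.parameters()]

out = fnet(torch.ones(1, 2), params=params)
out.sum().backward()
print(params[0].grad)  # gradient lands on the supplied list entries
print(net.weight.grad)  # None: the wrapped module's weight is untouched
```

Making params keyword-only means that a call such as fnet(x, params) fails loudly with a TypeError instead of silently treating the parameter list as a second input.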