curobo.geom.transform module

Implements differentiable point and pose transformations using NVIDIA Warp kernels. Most of these implementations are also available through the Pose class.

transform_points(
position: Tensor,
quaternion: Tensor,
points: Tensor,
out_points: Tensor | None = None,
out_gp: Tensor | None = None,
out_gq: Tensor | None = None,
out_gpt: Tensor | None = None,
) Tensor

Transforms the given points using the provided position and quaternion.

Parameters:
  • position – The position tensor representing the translation of the transformation.

  • quaternion – The quaternion tensor representing the rotation of the transformation. Quaternion format is [w, x, y, z].

  • points – The points to be transformed.

  • out_points – If provided, the transformed points will be stored in this tensor. If not provided, a new tensor will be created.

  • out_gp – If provided, the gradient of the transformed points with respect to the position will be stored in this tensor. If not provided, a new tensor will be created.

  • out_gq – If provided, the gradient of the transformed points with respect to the quaternion will be stored in this tensor. If not provided, a new tensor will be created.

  • out_gpt – If provided, the gradient of the transformed points with respect to the original points will be stored in this tensor. If not provided, a new tensor will be created.

Returns:

The transformed points.

Return type:

torch.Tensor
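
A minimal usage sketch (not taken from the library's own examples): the shapes for the single-pose case are an assumption here, and a CUDA device is assumed since the kernels are implemented in Warp.

import torch
from curobo.geom.transform import transform_points

device = torch.device("cuda")
# Assumed shapes for this sketch: one pose applied to a small point cloud.
position = torch.tensor([[0.1, 0.0, 0.5]], device=device)         # (1, 3)
quaternion = torch.tensor([[1.0, 0.0, 0.0, 0.0]], device=device)  # (1, 4), [w, x, y, z], identity
points = torch.rand((1, 100, 3), device=device)                   # (1, num_points, 3), assumed layout

transformed = transform_points(position, quaternion, points)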

batch_transform_points(
position: Tensor,
quaternion: Tensor,
points: Tensor,
out_points: Tensor | None = None,
out_gp: Tensor | None = None,
out_gq: Tensor | None = None,
out_gpt: Tensor | None = None,
) Tensor

Transforms the given points using the provided batch of positions and quaternions.

Parameters:
  • position – The position tensor representing the translation of the transformation. Shape should be (batch_size, 3).

  • quaternion – The quaternion tensor representing the rotation of the transformation. Quaternion format is [w, x, y, z]. Shape should be (batch_size, 4).

  • points – The points to be transformed. Shape should be (batch_size, num_points, 3).

  • out_points – If provided, the transformed points will be stored in this tensor. If not provided, a new tensor will be created.

  • out_gp – If provided, the gradient of the transformed points with respect to the position will be stored in this tensor. If not provided, a new tensor will be created.

  • out_gq – If provided, the gradient of the transformed points with respect to the quaternion will be stored in this tensor. If not provided, a new tensor will be created.

  • out_gpt – If provided, the gradient of the transformed points with respect to the original points will be stored in this tensor. If not provided, a new tensor will be created.

Returns:

The transformed points with shape (batch_size, num_points, 3).

Return type:

torch.Tensor
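
A minimal sketch using the documented batch shapes; a CUDA device is assumed since the kernels are implemented in Warp, and the backward call relies on the differentiability stated for this module.

import torch
from curobo.geom.transform import batch_transform_points

device = torch.device("cuda")
batch_size, num_points = 4, 50
position = torch.zeros((batch_size, 3), device=device, requires_grad=True)  # (batch_size, 3)
quaternion = torch.zeros((batch_size, 4), device=device)                    # (batch_size, 4)
quaternion[:, 0] = 1.0  # identity rotations in [w, x, y, z] order
points = torch.rand((batch_size, num_points, 3), device=device)             # (batch_size, num_points, 3)

out = batch_transform_points(position, quaternion, points)  # (batch_size, num_points, 3)
out.sum().backward()  # gradient with respect to position accumulates in position.grad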

get_inv_transform(
w_rot_c: Tensor,
w_trans_c: Tensor,
) Tuple[Tensor, Tensor]

Get the inverse of the given transformation.

Parameters:
  • w_rot_c – Rotation matrix in world frame.

  • w_trans_c – Translation vector in world frame.

Returns:

The inverse rotation matrix and translation vector.

Return type:

Tuple[torch.Tensor, torch.Tensor]

transform_point_inverse(
point: Tensor,
rot: Tensor,
trans: Tensor,
) Tensor

Transforms the given point using the inverse of the provided transformation.

Parameters:
  • point – Input point to be transformed.

  • rot – Rotation matrix.

  • trans – Translation vector.

Returns:

The transformed point.

Return type:

torch.Tensor
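
A minimal sketch for the two inverse-transform helpers above. The batched (1, 3, 3) / (1, 3) shapes used here are an assumption, and a CUDA device is assumed since the kernels are implemented in Warp.

import torch
from curobo.geom.transform import get_inv_transform, transform_point_inverse

device = torch.device("cuda")
w_rot_c = torch.eye(3, device=device).unsqueeze(0)          # (1, 3, 3), assumed layout
w_trans_c = torch.tensor([[0.0, 0.0, 1.0]], device=device)  # (1, 3), assumed layout

# Invert the transform explicitly ...
c_rot_w, c_trans_w = get_inv_transform(w_rot_c, w_trans_c)

# ... or transform a point by the inverse in one call.
point = torch.tensor([[0.5, 0.0, 0.0]], device=device)
point_in_c = transform_point_inverse(point, w_rot_c, w_trans_c)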

matrix_to_quaternion(
matrix: Tensor,
out_quat: Tensor | None = None,
adj_matrix: Tensor | None = None,
) Tensor

Converts the given rotation matrix to quaternion.

Parameters:
  • matrix – Rotation matrices as tensor of shape (…, 3, 3).

  • out_quat – Output tensor to store the quaternions. If not provided, a new tensor will be created.

  • adj_matrix – Gradient tensor. If not provided, a new tensor will be created.

Returns:

Quaternions with real part first, as tensor of shape (…, 4), in [qw, qx, qy, qz] order.

Return type:

torch.Tensor

cuda_matrix_to_quaternion(
matrix: Tensor,
) Tensor

Convert rotations given as rotation matrices to quaternions.

This is not differentiable. Use matrix_to_quaternion for differentiable conversion.

Parameters:
  • matrix – Rotation matrices as tensor of shape (…, 3, 3).

Returns:

Quaternions with real part first, as tensor of shape (…, 4), in [qw, qx, qy, qz] order.

quaternion_to_matrix(
quaternions: Tensor,
out_mat: Tensor | None = None,
adj_quaternion: Tensor | None = None,
) Tensor

Convert quaternion to rotation matrix.

Parameters:
  • quaternions – Input quaternions with real part first, as tensor of shape (…, 4).

  • out_mat – Output rotation matrices as tensor of shape (…, 3, 3). If not provided, a new tensor will be created.

  • adj_quaternion – Gradient tensor. If not provided, a new tensor will be created.

Returns:

Rotation matrices as tensor of shape (…, 3, 3).

Return type:

torch.Tensor
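
A minimal round-trip sketch between the two differentiable converters (matrix_to_quaternion and quaternion_to_matrix); a batch of one and a CUDA device are assumptions of this example.

import torch
from curobo.geom.transform import matrix_to_quaternion, quaternion_to_matrix

device = torch.device("cuda")
# Quaternion in [qw, qx, qy, qz] order: a 90-degree rotation about z.
quat = torch.tensor([[0.7071068, 0.0, 0.0, 0.7071068]], device=device)

rot = quaternion_to_matrix(quat)       # (..., 3, 3)
quat_back = matrix_to_quaternion(rot)  # (..., 4), real part first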

torch_quaternion_to_matrix(
quaternions: Tensor,
) Tensor

Convert rotations given as quaternions to rotation matrices.

Parameters:

quaternions – quaternions with real part first, as tensor of shape (…, 4).

Returns:

Rotation matrices as tensor of shape (…, 3, 3).

pose_to_matrix(
position: Tensor,
quaternion: Tensor,
out_matrix: Tensor | None = None,
) Tensor

Converts the given pose to a transformation matrix.

Parameters:
  • position – The position tensor representing the translation of the transformation.

  • quaternion – The quaternion tensor representing the rotation of the transformation. Quaternion format is [w, x, y, z].

  • out_matrix – If provided, the transformation matrix will be stored in this tensor. If not provided, a new tensor will be created.

Returns:

The transformation matrix.

Return type:

torch.Tensor
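
A minimal sketch for pose_to_matrix; the batch of one and the CUDA device are assumptions of this example.

import torch
from curobo.geom.transform import pose_to_matrix

device = torch.device("cuda")
position = torch.tensor([[0.1, 0.2, 0.3]], device=device)         # (1, 3)
quaternion = torch.tensor([[1.0, 0.0, 0.0, 0.0]], device=device)  # (1, 4), [w, x, y, z], identity

matrix = pose_to_matrix(position, quaternion)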

pose_multiply(
position: Tensor,
quaternion: Tensor,
position2: Tensor,
quaternion2: Tensor,
out_position: Tensor | None = None,
out_quaternion: Tensor | None = None,
adj_pos: Tensor | None = None,
adj_quat: Tensor | None = None,
adj_pos2: Tensor | None = None,
adj_quat2: Tensor | None = None,
) Tuple[Tensor, Tensor]

Multiplies two poses.

The input poses can either be of shape (3,) or (batch_size, 3).

Parameters:
  • position – The position tensor representing the translation of the first transformation.

  • quaternion – The quaternion tensor representing the rotation of the first transformation. The quaternion format is [w, x, y, z].

  • position2 – The position tensor representing the translation of the second transformation.

  • quaternion2 – The quaternion tensor representing the rotation of the second transformation.

  • out_position – If provided, the position tensor of the multiplied pose will be stored in this tensor. If not provided, a new tensor will be created.

  • out_quaternion – If provided, the quaternion tensor of the multiplied pose will be stored in this tensor. If not provided, a new tensor will be created.

  • adj_pos – Gradient tensor for the position of the first pose. If not provided, a new tensor will be created.

  • adj_quat – Gradient tensor for the quaternion of the first pose. If not provided, a new tensor will be created.

  • adj_pos2 – Gradient tensor for the position of the second pose. If not provided, a new tensor will be created.

  • adj_quat2 – Gradient tensor for the quaternion of the second pose. If not provided, a new tensor will be created.

Returns:

The position and quaternion tensors of the multiplied pose.

Return type:

Tuple[torch.Tensor, torch.Tensor]
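
A minimal composition sketch using the batched layout mentioned above; the concrete values and the CUDA device are assumptions of this example.

import torch
from curobo.geom.transform import pose_multiply

device = torch.device("cuda")
pos_a = torch.tensor([[0.0, 0.0, 0.5], [0.1, 0.0, 0.0]], device=device)             # (2, 3)
quat_a = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], device=device)  # (2, 4), [w, x, y, z]
pos_b = torch.tensor([[0.2, 0.0, 0.0], [0.0, 0.3, 0.0]], device=device)
quat_b = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], device=device)

out_pos, out_quat = pose_multiply(pos_a, quat_a, pos_b, quat_b)  # pose_a composed with pose_b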

pose_inverse(
position: Tensor,
quaternion: Tensor,
out_position: Tensor | None = None,
out_quaternion: Tensor | None = None,
adj_pos: Tensor | None = None,
adj_quat: Tensor | None = None,
) Tuple[Tensor, Tensor]

Get the inverse of the given pose.

Parameters:
  • position – The position tensor representing the translation of the transformation.

  • quaternion – The quaternion tensor representing the rotation of the transformation.

  • out_position – If provided, the position tensor of the inverse pose will be stored in this tensor. If not provided, a new tensor will be created.

  • out_quaternion – If provided, the quaternion tensor of the inverse pose will be stored in this tensor. If not provided, a new tensor will be created.

  • adj_pos – Gradient tensor for the position of the pose. If not provided, a new tensor will be created.

  • adj_quat – Gradient tensor for the quaternion of the pose. If not provided, a new tensor will be created.

Returns:

The position and quaternion tensors of the inverse pose.

Return type:

Tuple[torch.Tensor, torch.Tensor]
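
A minimal sketch that inverts a pose and composes it with the original via pose_multiply, which should recover (approximately) the identity pose; the batch of one and the CUDA device are assumptions of this example.

import torch
from curobo.geom.transform import pose_inverse, pose_multiply

device = torch.device("cuda")
position = torch.tensor([[0.1, 0.2, 0.3]], device=device)                     # (1, 3)
quaternion = torch.tensor([[0.7071068, 0.0, 0.7071068, 0.0]], device=device)  # (1, 4), [w, x, y, z]

inv_pos, inv_quat = pose_inverse(position, quaternion)
id_pos, id_quat = pose_multiply(position, quaternion, inv_pos, inv_quat)  # ~ identity pose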

class TransformPoint(*args, **kwargs)

Bases: Function

A differentiable function to transform a batch of points by a pose.

static forward(
ctx,
position: Tensor,
quaternion: Tensor,
points: Tensor,
out_points: Tensor,
adj_position: Tensor,
adj_quaternion: Tensor,
adj_points: Tensor,
)

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward if they are intended to be used in jvp.

static backward(
ctx,
grad_output,
)

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward. Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

_backward_cls

alias of TransformPointBackward

_compiled_autograd_backward_state
static _compiled_autograd_key(
ctx,
)
_get_compiled_autograd_symints()
_input_metadata
_is_compiled_autograd_tracing()
_materialize_non_diff_grads
_raw_saved_tensors
static _register_hook(
backward_hooks,
hook,
)
_register_hook_dict()
_sequence_nr()
_set_sequence_nr()
classmethod apply(*args, **kwargs)
dirty_tensors
generate_vmap_rule = False
static jvp(
ctx: Any,
*grad_inputs: Any,
) Any

Define a formula for differentiating the operation with forward mode automatic differentiation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many inputs as the forward got (None will be passed in for non tensor inputs of the forward function), and it should return as many tensors as there were outputs to forward. Each argument is the gradient w.r.t the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just pass None as a gradient for that input.

You can use the ctx object to pass any value from the forward to this function.

mark_dirty(
*args: Tensor,
)

Mark given tensors as modified in an in-place operation.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be inputs.

Every tensor that’s been modified in-place in a call to forward should be given to this function, to ensure correctness of our checks. It doesn’t matter whether the function is called before or after modification.

Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Inplace(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         x_npy = x.numpy() # x_npy shares storage with x
>>>         x_npy += 1
>>>         ctx.mark_dirty(x)
>>>         return x
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_output):
>>>         return grad_output
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
>>> b = a * a
>>> Inplace.apply(a)  # This would lead to wrong gradients!
>>>                   # but the engine would not know unless we mark_dirty
>>> # xdoctest: +SKIP
>>> b.backward() # RuntimeError: one of the variables needed for gradient
>>>              # computation has been modified by an inplace operation
mark_non_differentiable(
*args: Tensor,
)

Mark outputs as non-differentiable.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be tensor outputs.

This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in backward, but it’s always going to be a zero tensor with the same shape as the shape of a corresponding output.

This is used e.g. for indices returned from a sort. See example::
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input
mark_shared_storage(
*pairs,
)
materialize_grads
maybe_clear_saved_tensors()
metadata
name()
needs_input_grad
next_functions
non_differentiable
register_hook()
register_prehook()
requires_grad
save_for_backward(
*tensors: Tensor,
)

Save given tensors for a future call to backward.

save_for_backward should be called at most once, in either the setup_context or forward methods, and only with tensors.

All tensors intended to be used in the backward pass should be saved with save_for_backward (as opposed to directly on ctx) to prevent incorrect gradients and memory leaks, and enable the application of saved tensor hooks. See torch.autograd.graph.saved_tensors_hooks.

Note that if intermediary tensors, tensors that are neither inputs nor outputs of forward, are saved for backward, your custom Function may not support double backward. Custom Functions that do not support double backward should decorate their backward method with @once_differentiable so that performing double backward raises an error. If you’d like to support double backward, you can either recompute intermediaries based on the inputs during backward or return the intermediaries as the outputs of the custom Function. See the double backward tutorial for more details.

In backward, saved tensors can be accessed through the saved_tensors attribute. Before returning them to the user, a check is made to ensure they weren’t used in any in-place operation that modified their content.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * z
>>>         out = x * y + y * z + w * y
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + w)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
save_for_forward(
*tensors: Tensor,
)

Save given tensors for a future call to jvp.

save_for_forward should be called at most once, in either the setup_context or forward methods, and all arguments should be tensors.

In jvp, saved objects can be accessed through the saved_tensors attribute.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +SKIP
>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>>     a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>>     t = torch.tensor(1., dtype=torch.double)
>>>     b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>>     c = 4
>>>
>>>     with fwAD.dual_level():
>>>         a_dual = fwAD.make_dual(a, t)
>>>         d = Func.apply(a_dual, b, c)
saved_for_forward
saved_tensors
saved_variables
set_materialize_grads(
value: bool,
)

Set whether to materialize grad tensors. Default is True.

This should be called only from either the setup_context or forward methods.

If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward and jvp methods.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
static setup_context(
ctx: Any,
inputs: Tuple[Any, ...],
output: Any,
) Any

There are two ways to define the forward pass of an autograd.Function.

Either:

  1. Override forward with the signature forward(ctx, *args, **kwargs). setup_context is not overridden. Setting up the ctx for backward happens inside the forward.

  2. Override forward with the signature forward(*args, **kwargs) and override setup_context. Setting up the ctx for backward happens inside setup_context (as opposed to inside the forward)

See torch.autograd.Function.forward and Extending torch.autograd for more details.

to_save
static vjp(
ctx: Any,
*grad_outputs: Any,
) Any

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward. Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

static vmap(
info,
in_dims,
*args,
)

Define the behavior for this autograd.Function underneath torch.vmap.

For a torch.autograd.Function to support torch.vmap, you must either override this static method, or set generate_vmap_rule to True (you may not do both).

If you choose to override this staticmethod: it must accept

  • an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap.

  • an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

  • *args, which is the same as the args to forward.

The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has the vmapped dimension and what index it is in.

Please see Extending torch.func with autograd.Function for more details.

class BatchTransformPoint(*args, **kwargs)

Bases: Function

A differentiable function to transform a batch of points by a batch of poses.

static forward(
ctx,
position: Tensor,
quaternion: Tensor,
points: Tensor,
out_points: Tensor,
adj_position: Tensor,
adj_quaternion: Tensor,
adj_points: Tensor,
)

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward if they are intended to be used in jvp.

static backward(
ctx,
grad_output,
)

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward. Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

_backward_cls

alias of BatchTransformPointBackward

_compiled_autograd_backward_state
static _compiled_autograd_key(
ctx,
)
_get_compiled_autograd_symints()
_input_metadata
_is_compiled_autograd_tracing()
_materialize_non_diff_grads
_raw_saved_tensors
static _register_hook(
backward_hooks,
hook,
)
_register_hook_dict()
_sequence_nr()
_set_sequence_nr()
classmethod apply(
*args,
**kwargs,
)
dirty_tensors
generate_vmap_rule = False
static jvp(
ctx: Any,
*grad_inputs: Any,
) Any

Define a formula for differentiating the operation with forward mode automatic differentiation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many inputs as the forward got (None will be passed in for non tensor inputs of the forward function), and it should return as many tensors as there were outputs to forward. Each argument is the gradient w.r.t the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just pass None as a gradient for that input.

You can use the ctx object to pass any value from the forward to this function.

mark_dirty(
*args: Tensor,
)

Mark given tensors as modified in an in-place operation.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be inputs.

Every tensor that’s been modified in-place in a call to forward should be given to this function, to ensure correctness of our checks. It doesn’t matter whether the function is called before or after modification.

Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Inplace(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         x_npy = x.numpy() # x_npy shares storage with x
>>>         x_npy += 1
>>>         ctx.mark_dirty(x)
>>>         return x
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_output):
>>>         return grad_output
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
>>> b = a * a
>>> Inplace.apply(a)  # This would lead to wrong gradients!
>>>                   # but the engine would not know unless we mark_dirty
>>> # xdoctest: +SKIP
>>> b.backward() # RuntimeError: one of the variables needed for gradient
>>>              # computation has been modified by an inplace operation
mark_non_differentiable(
*args: Tensor,
)

Mark outputs as non-differentiable.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be tensor outputs.

This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in backward, but it’s always going to be a zero tensor with the same shape as the shape of a corresponding output.

This is used e.g. for indices returned from a sort. See example::
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input
mark_shared_storage(
*pairs,
)
materialize_grads
maybe_clear_saved_tensors()
metadata
name()
needs_input_grad
next_functions
non_differentiable
register_hook()
register_prehook()
requires_grad
save_for_backward(
*tensors: Tensor,
)

Save given tensors for a future call to backward.

save_for_backward should be called at most once, in either the setup_context or forward methods, and only with tensors.

All tensors intended to be used in the backward pass should be saved with save_for_backward (as opposed to directly on ctx) to prevent incorrect gradients and memory leaks, and enable the application of saved tensor hooks. See torch.autograd.graph.saved_tensors_hooks.

Note that if intermediary tensors, tensors that are neither inputs nor outputs of forward, are saved for backward, your custom Function may not support double backward. Custom Functions that do not support double backward should decorate their backward method with @once_differentiable so that performing double backward raises an error. If you’d like to support double backward, you can either recompute intermediaries based on the inputs during backward or return the intermediaries as the outputs of the custom Function. See the double backward tutorial for more details.

In backward, saved tensors can be accessed through the saved_tensors attribute. Before returning them to the user, a check is made to ensure they weren’t used in any in-place operation that modified their content.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * z
>>>         out = x * y + y * z + w * y
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + w)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
save_for_forward(
*tensors: Tensor,
)

Save given tensors for a future call to jvp.

save_for_forward should be called at most once, in either the setup_context or forward methods, and all arguments should be tensors.

In jvp, saved objects can be accessed through the saved_tensors attribute.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +SKIP
>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>>     a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>>     t = torch.tensor(1., dtype=torch.double)
>>>     b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>>     c = 4
>>>
>>>     with fwAD.dual_level():
>>>         a_dual = fwAD.make_dual(a, t)
>>>         d = Func.apply(a_dual, b, c)
saved_for_forward
saved_tensors
saved_variables
set_materialize_grads(
value: bool,
)

Set whether to materialize grad tensors. Default is True.

This should be called only from either the setup_context or forward methods.

If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward and jvp methods.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
static setup_context(
ctx: Any,
inputs: Tuple[Any, ...],
output: Any,
) Any

There are two ways to define the forward pass of an autograd.Function.

Either:

  1. Override forward with the signature forward(ctx, *args, **kwargs). setup_context is not overridden. Setting up the ctx for backward happens inside the forward.

  2. Override forward with the signature forward(*args, **kwargs) and override setup_context. Setting up the ctx for backward happens inside setup_context (as opposed to inside the forward)

See torch.autograd.Function.forward and Extending torch.autograd for more details.

to_save
static vjp(
ctx: Any,
*grad_outputs: Any,
) Any

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward. Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

static vmap(
info,
in_dims,
*args,
)

Define the behavior for this autograd.Function underneath torch.vmap.

For a torch.autograd.Function to support torch.vmap, you must either override this static method, or set generate_vmap_rule to True (you may not do both).

If you choose to override this staticmethod: it must accept

  • an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap.

  • an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

  • *args, which is the same as the args to forward.

The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has the vmapped dimension and what index it is in.

Please see Extending torch.func with autograd.Function for more details.

class BatchTransformPose(*args, **kwargs)

Bases: Function

A differentiable function to transform a batch of poses by a pose.

static forward(
ctx,
position: Tensor,
quaternion: Tensor,
position2: Tensor,
quaternion2: Tensor,
out_position: Tensor,
out_quaternion: Tensor,
adj_position: Tensor,
adj_quaternion: Tensor,
adj_position2: Tensor,
adj_quaternion2: Tensor,
)

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward if they are intended to be used in jvp.

static backward(
ctx,
grad_out_position,
grad_out_quaternion,
)

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward. Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

_backward_cls

alias of BatchTransformPoseBackward

_compiled_autograd_backward_state
static _compiled_autograd_key(
ctx,
)
_get_compiled_autograd_symints()
_input_metadata
_is_compiled_autograd_tracing()
_materialize_non_diff_grads
_raw_saved_tensors
static _register_hook(
backward_hooks,
hook,
)
_register_hook_dict()
_sequence_nr()
_set_sequence_nr()
classmethod apply(
*args,
**kwargs,
)
dirty_tensors
generate_vmap_rule = False
static jvp(
ctx: Any,
*grad_inputs: Any,
) Any

Define a formula for differentiating the operation with forward mode automatic differentiation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many inputs as the forward got (None will be passed in for non tensor inputs of the forward function), and it should return as many tensors as there were outputs to forward. Each argument is the gradient w.r.t the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just pass None as a gradient for that input.

You can use the ctx object to pass any value from the forward to this function.

mark_dirty(
*args: Tensor,
)

Mark given tensors as modified in an in-place operation.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be inputs.

Every tensor that’s been modified in-place in a call to forward should be given to this function, to ensure correctness of our checks. It doesn’t matter whether the function is called before or after modification.

Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Inplace(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         x_npy = x.numpy() # x_npy shares storage with x
>>>         x_npy += 1
>>>         ctx.mark_dirty(x)
>>>         return x
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_output):
>>>         return grad_output
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
>>> b = a * a
>>> Inplace.apply(a)  # This would lead to wrong gradients!
>>>                   # but the engine would not know unless we mark_dirty
>>> # xdoctest: +SKIP
>>> b.backward() # RuntimeError: one of the variables needed for gradient
>>>              # computation has been modified by an inplace operation
mark_non_differentiable(
*args: Tensor,
)

Mark outputs as non-differentiable.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be tensor outputs.

This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in backward, but it’s always going to be a zero tensor with the same shape as the shape of a corresponding output.

This is used e.g. for indices returned from a sort. See example::
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input
mark_shared_storage(
*pairs,
)
materialize_grads
maybe_clear_saved_tensors()
metadata
name()
needs_input_grad
next_functions
non_differentiable
register_hook()
register_prehook()
requires_grad
save_for_backward(
*tensors: Tensor,
)

Save given tensors for a future call to backward.

save_for_backward should be called at most once, in either the setup_context or forward methods, and only with tensors.

All tensors intended to be used in the backward pass should be saved with save_for_backward (as opposed to directly on ctx) to prevent incorrect gradients and memory leaks, and enable the application of saved tensor hooks. See torch.autograd.graph.saved_tensors_hooks.

Note that if intermediary tensors, tensors that are neither inputs nor outputs of forward, are saved for backward, your custom Function may not support double backward. Custom Functions that do not support double backward should decorate their backward method with @once_differentiable so that performing double backward raises an error. If you’d like to support double backward, you can either recompute intermediaries based on the inputs during backward or return the intermediaries as the outputs of the custom Function. See the double backward tutorial for more details.

In backward, saved tensors can be accessed through the saved_tensors attribute. Before returning them to the user, a check is made to ensure they weren’t used in any in-place operation that modified their content.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * z
>>>         out = x * y + y * z + w * y
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + w)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
save_for_forward(
*tensors: Tensor,
)

Save given tensors for a future call to jvp.

save_for_forward should be called at most once, in either the setup_context or forward methods, and all arguments should be tensors.

In jvp, saved objects can be accessed through the saved_tensors attribute.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +SKIP
>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>>     a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>>     t = torch.tensor(1., dtype=torch.double)
>>>     b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>>     c = 4
>>>
>>>     with fwAD.dual_level():
>>>         a_dual = fwAD.make_dual(a, t)
>>>         d = Func.apply(a_dual, b, c)
saved_for_forward
saved_tensors
saved_variables
set_materialize_grads(
value: bool,
)

Set whether to materialize grad tensors. Default is True.

This should be called only from either the setup_context or forward methods.

If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward and jvp methods.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
static setup_context(
ctx: Any,
inputs: Tuple[Any, ...],
output: Any,
) Any

There are two ways to define the forward pass of an autograd.Function.

Either:

  1. Override forward with the signature forward(ctx, *args, **kwargs). setup_context is not overridden. Setting up the ctx for backward happens inside the forward.

  2. Override forward with the signature forward(*args, **kwargs) and override setup_context. Setting up the ctx for backward happens inside setup_context (as opposed to inside the forward)

See torch.autograd.Function.forward and Extending torch.autograd for more details.

to_save
static vjp(
ctx: Any,
*grad_outputs: Any,
) Any

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward. Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

static vmap(
info,
in_dims,
*args,
)

Define the behavior for this autograd.Function underneath torch.vmap.

For a torch.autograd.Function to support torch.vmap, you must either override this static method, or set generate_vmap_rule to True (you may not do both).

If you choose to override this staticmethod: it must accept

  • an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap.

  • an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

  • *args, which is the same as the args to forward.

The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has the vmapped dimension and what index it is in.

Please see Extending torch.func with autograd.Function for more details.

class TransformPose(*args, **kwargs)

Bases: Function

A differentiable function to transform a batch of poses by another batch of poses.

static forward(
ctx,
position: Tensor,
quaternion: Tensor,
position2: Tensor,
quaternion2: Tensor,
out_position: Tensor,
out_quaternion: Tensor,
adj_position: Tensor,
adj_quaternion: Tensor,
adj_position2: Tensor,
adj_quaternion2: Tensor,
)

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward if they are intended to be used in jvp.

static backward(
ctx,
grad_out_position,
grad_out_quaternion,
)

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward. Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

_backward_cls

alias of TransformPoseBackward

_compiled_autograd_backward_state
static _compiled_autograd_key(
ctx,
)
_get_compiled_autograd_symints()
_input_metadata
_is_compiled_autograd_tracing()
_materialize_non_diff_grads
_raw_saved_tensors
static _register_hook(
backward_hooks,
hook,
)
_register_hook_dict()
_sequence_nr()
_set_sequence_nr()
classmethod apply(*args, **kwargs)
dirty_tensors
generate_vmap_rule = False
static jvp(
ctx: Any,
*grad_inputs: Any,
) Any

Define a formula for differentiating the operation with forward mode automatic differentiation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many inputs as the forward got (None will be passed in for non tensor inputs of the forward function), and it should return as many tensors as there were outputs to forward. Each argument is the gradient w.r.t the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just pass None as a gradient for that input.

You can use the ctx object to pass any value from the forward to this function.

mark_dirty(
*args: Tensor,
)

Mark given tensors as modified in an in-place operation.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be inputs.

Every tensor that’s been modified in-place in a call to forward should be given to this function, to ensure correctness of our checks. It doesn’t matter whether the function is called before or after modification.

Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Inplace(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         x_npy = x.numpy() # x_npy shares storage with x
>>>         x_npy += 1
>>>         ctx.mark_dirty(x)
>>>         return x
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_output):
>>>         return grad_output
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
>>> b = a * a
>>> Inplace.apply(a)  # This would lead to wrong gradients!
>>>                   # but the engine would not know unless we mark_dirty
>>> # xdoctest: +SKIP
>>> b.backward() # RuntimeError: one of the variables needed for gradient
>>>              # computation has been modified by an inplace operation
mark_non_differentiable(
*args: Tensor,
)

Mark outputs as non-differentiable.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be tensor outputs.

This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in backward, but it’s always going to be a zero tensor with the same shape as the shape of a corresponding output.

This is used e.g. for indices returned from a sort. See example::
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input
mark_shared_storage(
*pairs,
)
materialize_grads
maybe_clear_saved_tensors()
metadata
name()
needs_input_grad
next_functions
non_differentiable
register_hook()
register_prehook()
requires_grad
save_for_backward(
*tensors: Tensor,
)

Save given tensors for a future call to backward.

save_for_backward should be called at most once, in either the setup_context or forward methods, and only with tensors.

All tensors intended to be used in the backward pass should be saved with save_for_backward (as opposed to directly on ctx) to prevent incorrect gradients and memory leaks, and enable the application of saved tensor hooks. See torch.autograd.graph.saved_tensors_hooks.

Note that if intermediary tensors (tensors that are neither inputs nor outputs of forward) are saved for backward, your custom Function may not support double backward. Custom Functions that do not support double backward should decorate their backward method with @once_differentiable so that performing double backward raises an error. If you’d like to support double backward, you can either recompute intermediaries based on the inputs during backward or return the intermediaries as the outputs of the custom Function. See the double backward tutorial for more details.

In backward, saved tensors can be accessed through the saved_tensors attribute. Before returning them to the user, a check is made to ensure they weren’t used in any in-place operation that modified their content.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * z
>>>         out = x * y + y * z + w * y
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + w)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
save_for_forward(
*tensors: Tensor,
)

Save given tensors for a future call to jvp.

save_for_forward should be called at most once, in either the setup_context or forward methods, and all arguments should be tensors.

In jvp, saved objects can be accessed through the saved_tensors attribute.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +SKIP
>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>>     a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>>     t = torch.tensor(1., dtype=torch.double)
>>>     b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>>     c = 4
>>>
>>>     with fwAD.dual_level():
>>>         a_dual = fwAD.make_dual(a, t)
>>>         d = Func.apply(a_dual, b, c)
saved_for_forward
saved_tensors
saved_variables
set_materialize_grads(
value: bool,
)

Set whether to materialize grad tensors. Default is True.

This should be called only from either the setup_context or forward methods.

If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward and jvp methods.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
static setup_context(
ctx: Any,
inputs: Tuple[Any, ...],
output: Any,
) Any

There are two ways to define the forward pass of an autograd.Function.

Either:

  1. Override forward with the signature forward(ctx, *args, **kwargs). setup_context is not overridden. Setting up the ctx for backward happens inside the forward.

  2. Override forward with the signature forward(*args, **kwargs) and override setup_context. Setting up the ctx for backward happens inside setup_context (as opposed to inside the forward).

See torch.autograd.Function.forward and Extending torch.autograd for more details.
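Illustrative sketch of the second style, where forward does not take ctx and setup_context stores what backward needs (the MulConstant class below is a generic example and not part of this module)::
>>> import torch
>>> class MulConstant(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(x, constant):  # no ctx argument in this style
>>>         return x * constant
>>>
>>>     @staticmethod
>>>     def setup_context(ctx, inputs, output):
>>>         _, constant = inputs  # inputs is the tuple of forward arguments
>>>         ctx.constant = constant
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         # constant is not a Tensor, so its gradient slot is None
>>>         return grad_output * ctx.constant, None
>>>
>>> x = torch.ones(3, requires_grad=True)
>>> y = MulConstant.apply(x, 5.5)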

to_save
static vjp(
ctx: Any,
*grad_outputs: Any,
) Any

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward. Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

static vmap(info, in_dims, *args)

Define the behavior for this autograd.Function underneath torch.vmap.

For a torch.autograd.Function to support torch.vmap, you must either override this static method, or set generate_vmap_rule to True (you may not do both).

If you choose to override this staticmethod: it must accept

  • an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap.

  • an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

  • *args, which is the same as the args to forward.

The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should have the same structure as output and contain one out_dim per output, specifying whether that output has the vmapped dimension and, if so, which index it is at.

Please see Extending torch.func with autograd.Function for more details.
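Illustrative sketch of a vmap rule for a simple elementwise function (a generic example, not part of this module; the new-style forward/setup_context split is used, as torch.func transforms expect)::
>>> import torch
>>> class Negate(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(x):
>>>         return -x
>>>
>>>     @staticmethod
>>>     def setup_context(ctx, inputs, output):
>>>         pass
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         return -grad_output
>>>
>>>     @staticmethod
>>>     def vmap(info, in_dims, x):
>>>         # Elementwise op: operate on the batched tensor directly and
>>>         # keep the batched dimension where it was on the input.
>>>         return -x, in_dims[0]
>>>
>>> xs = torch.randn(4, 3)
>>> ys = torch.vmap(Negate.apply)(xs)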

class PoseInverse(*args, **kwargs)

Bases: Function

A differentiable function to get the inverse of a pose (also supports batch).
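A hedged usage sketch follows; it assumes the output and adjoint buffers are pre-allocated with the same shapes as the inputs, that quaternions follow the [w, x, y, z] convention, that a CUDA device is available for the Warp kernels, and that apply returns the inverted position and quaternion. None of this is verified against the implementation::
>>> # illustrative only; buffer layout and return values are assumptions
>>> import torch
>>> from curobo.geom.transform import PoseInverse
>>> device = torch.device("cuda")
>>> position = torch.randn(10, 3, device=device, requires_grad=True)
>>> quaternion = torch.zeros(10, 4, device=device)
>>> quaternion[:, 0] = 1.0  # identity rotations, [w, x, y, z]
>>> quaternion.requires_grad_(True)
>>> out_position = torch.zeros(10, 3, device=device)
>>> out_quaternion = torch.zeros(10, 4, device=device)
>>> adj_position = torch.zeros(10, 3, device=device)
>>> adj_quaternion = torch.zeros(10, 4, device=device)
>>> inv_position, inv_quaternion = PoseInverse.apply(
>>>     position, quaternion, out_position, out_quaternion,
>>>     adj_position, adj_quaternion)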

static forward(
ctx,
position: Tensor,
quaternion: Tensor,
out_position: Tensor,
out_quaternion: Tensor,
adj_position: Tensor,
adj_quaternion: Tensor,
)

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details.

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward if they are intended to be used in jvp.

static backward(
ctx,
grad_out_position,
grad_out_quaternion,
)

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor that does not require gradients, you can just pass None as the gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

_backward_cls

alias of PoseInverseBackward

_compiled_autograd_backward_state
static _compiled_autograd_key(ctx)
_get_compiled_autograd_symints()
_input_metadata
_is_compiled_autograd_tracing()
_materialize_non_diff_grads
_raw_saved_tensors
static _register_hook(
backward_hooks,
hook,
)
_register_hook_dict()
_sequence_nr()
_set_sequence_nr()
classmethod apply(*args, **kwargs)
dirty_tensors
generate_vmap_rule = False
static jvp(
ctx: Any,
*grad_inputs: Any,
) Any

Define a formula for differentiating the operation with forward mode automatic differentiation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many inputs as the forward got (None will be passed in for non-tensor inputs of the forward function), and it should return as many tensors as there were outputs to forward. Each argument is the gradient w.r.t. the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just pass None as the gradient for that output.

You can use the ctx object to pass any value from the forward to this function.

mark_dirty(
*args: Tensor,
)

Mark given tensors as modified in an in-place operation.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be inputs.

Every tensor that’s been modified in-place in a call to forward should be given to this function, to ensure correctness of our checks. It doesn’t matter whether the function is called before or after modification.

Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Inplace(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         x_npy = x.numpy() # x_npy shares storage with x
>>>         x_npy += 1
>>>         ctx.mark_dirty(x)
>>>         return x
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_output):
>>>         return grad_output
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
>>> b = a * a
>>> Inplace.apply(a)  # This would lead to wrong gradients!
>>>                   # but the engine would not know unless we mark_dirty
>>> # xdoctest: +SKIP
>>> b.backward() # RuntimeError: one of the variables needed for gradient
>>>              # computation has been modified by an inplace operation
mark_non_differentiable(
*args: Tensor,
)

Mark outputs as non-differentiable.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be tensor outputs.

This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in backward, but it will always be a zero tensor with the same shape as the corresponding output.

This is used e.g. for indices returned from a sort. See example::
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input
mark_shared_storage(*pairs)
materialize_grads
maybe_clear_saved_tensors()
metadata
name()
needs_input_grad
next_functions
non_differentiable
register_hook()
register_prehook()
requires_grad
save_for_backward(
*tensors: Tensor,
)

Save given tensors for a future call to backward.

save_for_backward should be called at most once, in either the setup_context or forward methods, and only with tensors.

All tensors intended to be used in the backward pass should be saved with save_for_backward (as opposed to directly on ctx) to prevent incorrect gradients and memory leaks, and enable the application of saved tensor hooks. See torch.autograd.graph.saved_tensors_hooks.

Note that if intermediary tensors (tensors that are neither inputs nor outputs of forward) are saved for backward, your custom Function may not support double backward. Custom Functions that do not support double backward should decorate their backward method with @once_differentiable so that performing double backward raises an error. If you’d like to support double backward, you can either recompute intermediaries based on the inputs during backward or return the intermediaries as the outputs of the custom Function. See the double backward tutorial for more details.

In backward, saved tensors can be accessed through the saved_tensors attribute. Before returning them to the user, a check is made to ensure they weren’t used in any in-place operation that modified their content.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * z
>>>         out = x * y + y * z + w * y
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + w)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
save_for_forward(
*tensors: Tensor,
)

Save given tensors for a future call to jvp.

save_for_forward should be called at most once, in either the setup_context or forward methods, and all arguments should be tensors.

In jvp, saved objects can be accessed through the saved_tensors attribute.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +SKIP
>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>>     a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>>     t = torch.tensor(1., dtype=torch.double)
>>>     b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>>     c = 4
>>>
>>>     with fwAD.dual_level():
>>>         a_dual = fwAD.make_dual(a, t)
>>>         d = Func.apply(a_dual, b, c)
saved_for_forward
saved_tensors
saved_variables
set_materialize_grads(
value: bool,
)

Set whether to materialize grad tensors. Default is True.

This should be called only from either the setup_context or forward methods.

If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward and jvp methods.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
static setup_context(
ctx: Any,
inputs: Tuple[Any, ...],
output: Any,
) Any

There are two ways to define the forward pass of an autograd.Function.

Either:

  1. Override forward with the signature forward(ctx, *args, **kwargs). setup_context is not overridden. Setting up the ctx for backward happens inside the forward.

  2. Override forward with the signature forward(*args, **kwargs) and override setup_context. Setting up the ctx for backward happens inside setup_context (as opposed to inside the forward).

See torch.autograd.Function.forward and Extending torch.autograd for more details.

to_save
static vjp(
ctx: Any,
*grad_outputs: Any,
) Any

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor that does not require gradients, you can just pass None as the gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

static vmap(info, in_dims, *args)

Define the behavior for this autograd.Function underneath torch.vmap.

For a torch.autograd.Function to support torch.vmap, you must either override this static method, or set generate_vmap_rule to True (you may not do both).

If you choose to override this staticmethod: it must accept

  • an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap.

  • an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

  • *args, which is the same as the args to forward.

The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should have the same structure as output and contain one out_dim per output, specifying whether that output has the vmapped dimension and, if so, which index it is at.

Please see Extending torch.func with autograd.Function for more details.

class QuatToMatrix(*args, **kwargs)

Bases: Function

A differentiable function for converting quaternions to rotation matrices.
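A hedged usage sketch follows; it assumes a pre-allocated (b, 3, 3) output buffer, an adjoint buffer matching the quaternion shape, the [w, x, y, z] quaternion convention, a CUDA device for the Warp kernels, and that apply returns the rotation matrices. None of this is verified against the implementation::
>>> # illustrative only; buffer layout and return value are assumptions
>>> import torch
>>> from curobo.geom.transform import QuatToMatrix
>>> device = torch.device("cuda")
>>> quaternion = torch.zeros(10, 4, device=device)
>>> quaternion[:, 0] = 1.0  # identity rotations, [w, x, y, z]
>>> quaternion.requires_grad_(True)
>>> out_mat = torch.zeros(10, 3, 3, device=device)
>>> adj_quaternion = torch.zeros(10, 4, device=device)
>>> rot_mat = QuatToMatrix.apply(quaternion, out_mat, adj_quaternion)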

static forward(
ctx,
quaternion: Tensor,
out_mat: Tensor,
adj_quaternion: Tensor,
)

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details.

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward if they are intended to be used in jvp.

static backward(ctx, grad_out_mat)

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor that does not require gradients, you can just pass None as the gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

_backward_cls

alias of QuatToMatrixBackward

_compiled_autograd_backward_state
static _compiled_autograd_key(ctx)
_get_compiled_autograd_symints()
_input_metadata
_is_compiled_autograd_tracing()
_materialize_non_diff_grads
_raw_saved_tensors
static _register_hook(
backward_hooks,
hook,
)
_register_hook_dict()
_sequence_nr()
_set_sequence_nr()
classmethod apply(*args, **kwargs)
dirty_tensors
generate_vmap_rule = False
static jvp(
ctx: Any,
*grad_inputs: Any,
) Any

Define a formula for differentiating the operation with forward mode automatic differentiation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many inputs as the forward got (None will be passed in for non-tensor inputs of the forward function), and it should return as many tensors as there were outputs to forward. Each argument is the gradient w.r.t. the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just pass None as the gradient for that output.

You can use the ctx object to pass any value from the forward to this function.

mark_dirty(
*args: Tensor,
)

Mark given tensors as modified in an in-place operation.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be inputs.

Every tensor that’s been modified in-place in a call to forward should be given to this function, to ensure correctness of our checks. It doesn’t matter whether the function is called before or after modification.

Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Inplace(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         x_npy = x.numpy() # x_npy shares storage with x
>>>         x_npy += 1
>>>         ctx.mark_dirty(x)
>>>         return x
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_output):
>>>         return grad_output
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
>>> b = a * a
>>> Inplace.apply(a)  # This would lead to wrong gradients!
>>>                   # but the engine would not know unless we mark_dirty
>>> # xdoctest: +SKIP
>>> b.backward() # RuntimeError: one of the variables needed for gradient
>>>              # computation has been modified by an inplace operation
mark_non_differentiable(
*args: Tensor,
)

Mark outputs as non-differentiable.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be tensor outputs.

This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in backward, but it will always be a zero tensor with the same shape as the corresponding output.

This is used e.g. for indices returned from a sort. See example::
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input
mark_shared_storage(*pairs)
materialize_grads
maybe_clear_saved_tensors()
metadata
name()
needs_input_grad
next_functions
non_differentiable
register_hook()
register_prehook()
requires_grad
save_for_backward(
*tensors: Tensor,
)

Save given tensors for a future call to backward.

save_for_backward should be called at most once, in either the setup_context or forward methods, and only with tensors.

All tensors intended to be used in the backward pass should be saved with save_for_backward (as opposed to directly on ctx) to prevent incorrect gradients and memory leaks, and enable the application of saved tensor hooks. See torch.autograd.graph.saved_tensors_hooks.

Note that if intermediary tensors (tensors that are neither inputs nor outputs of forward) are saved for backward, your custom Function may not support double backward. Custom Functions that do not support double backward should decorate their backward method with @once_differentiable so that performing double backward raises an error. If you’d like to support double backward, you can either recompute intermediaries based on the inputs during backward or return the intermediaries as the outputs of the custom Function. See the double backward tutorial for more details.

In backward, saved tensors can be accessed through the saved_tensors attribute. Before returning them to the user, a check is made to ensure they weren’t used in any in-place operation that modified their content.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * z
>>>         out = x * y + y * z + w * y
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + w)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
save_for_forward(
*tensors: Tensor,
)

Save given tensors for a future call to jvp.

save_for_forward should be called at most once, in either the setup_context or forward methods, and all arguments should be tensors.

In jvp, saved objects can be accessed through the saved_tensors attribute.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +SKIP
>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>>     a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>>     t = torch.tensor(1., dtype=torch.double)
>>>     b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>>     c = 4
>>>
>>>     with fwAD.dual_level():
>>>         a_dual = fwAD.make_dual(a, t)
>>>         d = Func.apply(a_dual, b, c)
saved_for_forward
saved_tensors
saved_variables
set_materialize_grads(
value: bool,
)

Set whether to materialize grad tensors. Default is True.

This should be called only from either the setup_context or forward methods.

If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward and jvp methods.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
static setup_context(
ctx: Any,
inputs: Tuple[Any, ...],
output: Any,
) Any

There are two ways to define the forward pass of an autograd.Function.

Either:

  1. Override forward with the signature forward(ctx, *args, **kwargs). setup_context is not overridden. Setting up the ctx for backward happens inside the forward.

  2. Override forward with the signature forward(*args, **kwargs) and override setup_context. Setting up the ctx for backward happens inside setup_context (as opposed to inside the forward).

See torch.autograd.Function.forward and Extending torch.autograd for more details.

to_save
static vjp(
ctx: Any,
*grad_outputs: Any,
) Any

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor that does not require gradients, you can just pass None as the gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

static vmap(info, in_dims, *args)

Define the behavior for this autograd.Function underneath torch.vmap.

For a torch.autograd.Function to support torch.vmap, you must either override this static method, or set generate_vmap_rule to True (you may not do both).

If you choose to override this staticmethod: it must accept

  • an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap.

  • an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

  • *args, which is the same as the args to forward.

The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should have the same structure as output and contain one out_dim per output, specifying whether that output has the vmapped dimension and, if so, which index it is at.

Please see Extending torch.func with autograd.Function for more details.

class MatrixToQuaternion(*args, **kwargs)

Bases: Function

A differentiable function for converting rotation matrices to quaternions.
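A hedged usage sketch follows; it assumes a (b, 3, 3) input of rotation matrices, a pre-allocated (b, 4) output buffer in [w, x, y, z] order, an adjoint buffer matching the input shape, a CUDA device for the Warp kernels, and that apply returns the quaternions. None of this is verified against the implementation::
>>> # illustrative only; buffer layout and return value are assumptions
>>> import torch
>>> from curobo.geom.transform import MatrixToQuaternion
>>> device = torch.device("cuda")
>>> in_mat = torch.eye(3, device=device).expand(10, 3, 3).contiguous()
>>> in_mat.requires_grad_(True)
>>> out_quaternion = torch.zeros(10, 4, device=device)
>>> adj_mat = torch.zeros(10, 3, 3, device=device)
>>> quaternion = MatrixToQuaternion.apply(in_mat, out_quaternion, adj_mat)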

static forward(
ctx,
in_mat: Tensor,
out_quaternion: Tensor,
adj_mat: Tensor,
)

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details.

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward if they are intended to be used in jvp.

static backward(
ctx,
grad_out_q,
)

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor that does not require gradients, you can just pass None as the gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

_backward_cls

alias of MatrixToQuaternionBackward

_compiled_autograd_backward_state
static _compiled_autograd_key(
ctx,
)
_get_compiled_autograd_symints()
_input_metadata
_is_compiled_autograd_tracing()
_materialize_non_diff_grads
_raw_saved_tensors
static _register_hook(
backward_hooks,
hook,
)
_register_hook_dict()
_sequence_nr()
_set_sequence_nr()
classmethod apply(
*args,
**kwargs,
)
dirty_tensors
generate_vmap_rule = False
static jvp(
ctx: Any,
*grad_inputs: Any,
) Any

Define a formula for differentiating the operation with forward mode automatic differentiation.

This function is to be overridden by all subclasses. It must accept a context ctx as the first argument, followed by as many inputs as the forward got (None will be passed in for non-tensor inputs of the forward function), and it should return as many tensors as there were outputs to forward. Each argument is the gradient w.r.t. the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just pass None as the gradient for that output.

You can use the ctx object to pass any value from the forward to this function.

mark_dirty(
*args: Tensor,
)

Mark given tensors as modified in an in-place operation.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be inputs.

Every tensor that’s been modified in-place in a call to forward should be given to this function, to ensure correctness of our checks. It doesn’t matter whether the function is called before or after modification.

Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Inplace(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         x_npy = x.numpy() # x_npy shares storage with x
>>>         x_npy += 1
>>>         ctx.mark_dirty(x)
>>>         return x
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_output):
>>>         return grad_output
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
>>> b = a * a
>>> Inplace.apply(a)  # This would lead to wrong gradients!
>>>                   # but the engine would not know unless we mark_dirty
>>> # xdoctest: +SKIP
>>> b.backward() # RuntimeError: one of the variables needed for gradient
>>>              # computation has been modified by an inplace operation
mark_non_differentiable(
*args: Tensor,
)

Mark outputs as non-differentiable.

This should be called at most once, in either the setup_context or forward methods, and all arguments should be tensor outputs.

This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in backward, but it will always be a zero tensor with the same shape as the corresponding output.

This is used e.g. for indices returned from a sort. See example::
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input
mark_shared_storage(
*pairs,
)
materialize_grads
maybe_clear_saved_tensors()
metadata
name()
needs_input_grad
next_functions
non_differentiable
register_hook()
register_prehook()
requires_grad
save_for_backward(
*tensors: Tensor,
)

Save given tensors for a future call to backward.

save_for_backward should be called at most once, in either the setup_context or forward methods, and only with tensors.

All tensors intended to be used in the backward pass should be saved with save_for_backward (as opposed to directly on ctx) to prevent incorrect gradients and memory leaks, and enable the application of saved tensor hooks. See torch.autograd.graph.saved_tensors_hooks.

Note that if intermediary tensors (tensors that are neither inputs nor outputs of forward) are saved for backward, your custom Function may not support double backward. Custom Functions that do not support double backward should decorate their backward method with @once_differentiable so that performing double backward raises an error. If you’d like to support double backward, you can either recompute intermediaries based on the inputs during backward or return the intermediaries as the outputs of the custom Function. See the double backward tutorial for more details.

In backward, saved tensors can be accessed through the saved_tensors attribute. Before returning them to the user, a check is made to ensure they weren’t used in any in-place operation that modified their content.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * z
>>>         out = x * y + y * z + w * y
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + w)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
save_for_forward(
*tensors: Tensor,
)

Save given tensors for a future call to jvp.

save_for_forward should be called at most once, in either the setup_context or forward methods, and all arguments should be tensors.

In jvp, saved objects can be accessed through the saved_tensors attribute.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.

Example::
>>> # xdoctest: +SKIP
>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>>     a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>>     t = torch.tensor(1., dtype=torch.double)
>>>     b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>>     c = 4
>>>
>>>     with fwAD.dual_level():
>>>         a_dual = fwAD.make_dual(a, t)
>>>         d = Func.apply(a_dual, b, c)
saved_for_forward
saved_tensors
saved_variables
set_materialize_grads(
value: bool,
)

Set whether to materialize grad tensors. Default is True.

This should be called only from either the setup_context or forward methods.

If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward and jvp methods.

Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
static setup_context(
ctx: Any,
inputs: Tuple[Any, ...],
output: Any,
) Any

There are two ways to define the forward pass of an autograd.Function.

Either:

  1. Override forward with the signature forward(ctx, *args, **kwargs). setup_context is not overridden. Setting up the ctx for backward happens inside the forward.

  2. Override forward with the signature forward(*args, **kwargs) and override setup_context. Setting up the ctx for backward happens inside setup_context (as opposed to inside the forward).

See torch.autograd.Function.forward and Extending torch.autograd for more details.

to_save
static vjp(
ctx: Any,
*grad_outputs: Any,
) Any

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward. Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor that does not require gradients, you can just pass None as the gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

static vmap(
info,
in_dims,
*args,
)

Define the behavior for this autograd.Function underneath torch.vmap.

For a torch.autograd.Function to support torch.vmap, you must either override this static method, or set generate_vmap_rule to True (you may not do both).

If you choose to override this staticmethod: it must accept

  • an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap.

  • an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

  • *args, which is the same as the args to forward.

The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should have the same structure as output and contain one out_dim per output, specifying whether that output has the vmapped dimension and, if so, which index it is at.

Please see Extending torch.func with autograd.Function for more details.