pytorch 为数值函数提供了自动求导包 torch.autograd,只需要在 Tensor 定义时增加 requires_grad=True 关键字。但是,有时我们需要在局部需要设置某些计算不进行梯度计算,这可以节省内存开销。本篇所有代码在 jupyterlab 中运行。

常用的局部计算管理器

如下的上下文管理器均是类。

管理器 说明
no_grad Context-manager that disabled gradient calculation.
enable_grad Context-manager that enables gradient calculation.
set_grad_enabled Context-manager that sets gradient calculation to on or off.
inference_mode Context-manager that enables or disables inference mode.

torch.no_grad

该上下文管理器用于禁止梯度计算,常采用如下两种方法。

with torch.no_grad()

代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
import torch

x = torch.tensor([1.], requires_grad=True)
x
# tensor([1.], requires_grad=True)

with torch.no_grad():
y = x * 2
y
# tensor([2.])
y.requires_grad
# False

@torch.no_grad()

作为装饰器,将函数中不进行梯度计算,接续上面代码:

1
2
3
4
5
6
7
8
9
@torch.no_grad()
def doubler(x):
return x * 2

z = doubler(x)
z
# tensor([2.])
z.requires_grad
# False

torch.enable_grad

该上下文管理器用于启用梯度计算,常采用如下两种方法。

with torch.enable_grad()

代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
x = torch.tensor([1.], requires_grad=True)
x
# tensor([1.], requires_grad=True)

with torch.no_grad():
with torch.enable_grad():
y = x * 2

y
# tensor([2.], grad_fn=<MulBackward0>)
y.requires_grad
# True

y.backward()
x.grad
# tensor([2.])

@torch.enable_grad()

作为装饰器,函数中进行梯度计算,接续上面代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
@torch.enable_grad()
def doubler(x):
return x * 2

with torch.no_grad():
z = doubler(x)
z
# tensor([2.], grad_fn=<MulBackward0>)
z.requires_grad
# True
z.backward()
x.grad
# tensor([4.]) # 这里因为是接续上面代码,所以x的梯度分别是y和z对x的梯度和。如果在代码块前增加对x的重新初始化,则只有z对x反向求导的梯度

torch.set_grad_enabled(mode)

将梯度计算设置为打开或关闭的上下文管理器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
x = torch.tensor([1.], requires_grad=True)
x
# tensor([1.], requires_grad=True)

is_train = False
with torch.set_grad_enabled(is_train):
y = x * 2
y
# tensor([2.])

y.requires_grad
# False

torch.set_grad_enabled(True)
y = x * 2
y
# tensor([2.], grad_fn=<MulBackward0>)

y.requires_grad
# True

就近原则

当混用前面的三个管理器时,采用就近原则。

1
2
3
4
5
6
7
8
9
10
11
12
x = torch.tensor([1.], requires_grad=True)
x
# tensor([1.], requires_grad=True)

with torch.enable_grad():
torch.set_grad_enabled(False)
y = x * 2
y
# tensor([2.])

y.requires_grad
# False
1
2
3
4
5
6
7
8
9
10
11
12
x = torch.tensor([1.], requires_grad=True)
x
# tensor([1.], requires_grad=True)

torch.set_anomaly_enabled(True)
with torch.no_grad():
z = x * 2
z
# tensor([2.])

z.requires_grad
# False

参考文献

  1. pytorch禁止/允许计算局部梯度