pytorch 设置梯度求导
pytorch 为数值函数提供了自动求导包 torch.autograd
,只需要在 Tensor 定义时增加 requires_grad=True
关键字。但是,有时我们需要在局部需要设置某些计算不进行梯度计算,这可以节省内存开销。本篇所有代码在 jupyterlab 中运行。
常用的局部计算管理器
如下的上下文管理器均是类。
管理器 | 说明 |
---|---|
no_grad | Context-manager that disabled gradient calculation. |
enable_grad | Context-manager that enables gradient calculation. |
set_grad_enabled | Context-manager that sets gradient calculation to on or off. |
inference_mode | Context-manager that enables or disables inference mode. |
torch.no_grad
该上下文管理器用于禁止梯度计算,常采用如下两种方法。
with torch.no_grad()
代码示例:
1 | import torch |
@torch.no_grad()
作为装饰器,将函数中不进行梯度计算,接续上面代码:
1 |
|
torch.enable_grad
该上下文管理器用于启用梯度计算,常采用如下两种方法。
with torch.enable_grad()
代码示例:
1 | x = torch.tensor([1.], requires_grad=True) |
@torch.enable_grad()
作为装饰器,函数中进行梯度计算,接续上面代码:
1 |
|
torch.set_grad_enabled(mode)
将梯度计算设置为打开或关闭的上下文管理器。
1 | x = torch.tensor([1.], requires_grad=True) |
就近原则
当混用前面的三个管理器时,采用就近原则。
1 | x = torch.tensor([1.], requires_grad=True) |
1 | x = torch.tensor([1.], requires_grad=True) |
参考文献
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 J. Xu!
评论