3.4 自定义层
上一节简要地提到了nn.Module
在PyTorch中是所有NN构建块的基础父类。它不仅仅是现存层的统一父类,它远不止于此。通过将nn.Module
子类化,可以创建自己的构建块,它们可以组合在一起,后续可以复用,并且可以完美地集成到PyTorch框架中。
作为核心,nn.Module
为其子类提供了相当丰富的功能:
- 它记录当前模块的所有子模块。例如,构建块可以具有两个前馈层,可以以某种方式使用它们来执行代码块的转换。
- 提供处理已注册子模块的所有参数的函数。可以获取模块参数的完整列表(
parameters()
方法)将其梯度置零(zero_grads()
方法),将其移至CPU或GPU(to(device)
方法),序列化和反序列化模块(state_dict()
和load_state_dict()
),甚至可以用自己的callable执行通用的转换逻辑(apply()
方法)。 - 建立了
Module
针对数据的约定。每个模块都需要覆盖forward()
方法来执行数据的转换。 - 还有更多的函数,例如注册钩子函数以调整模块转换逻辑或梯度流,它们更加适合高级的使用场景。
这些功能允许我们通过统一的方式将子模型嵌套到更高层次的模型中,在处理复杂的情况时非常有用。它可以是简单的单层线性变换,也可以是1001层的residual NN(ResNet)
,但是如果它们遵循nn.Module
的约定,则可以用相同的方式处理它们。这对于代码的简洁性和可重用性非常有帮助。
为了简化工作,PyTorch的作者遵循上述约定,通过精心设计和大量Python魔术方法简化了模块的创建。因此,要创建自定义模块,通常只需要做两件事——注册子模块并实现forward()
方法。
我们来看上一节中Sequential
的例子是如何使用更加通用和可复用的方式做到这一点的(完整的示例见Chapter03/01_modules.py
):
这是继承了nn.Module
的模块。在构造函数中,我们传递了三个参数:输入大小、输出大小和可选的dropout概率。我们要做的第一件事就是调用父类的构造函数来初始化。
第二步,我们需要创建一个已经熟悉的nn.Sequential
,包含一些不同的层,并将其赋给类中名为pipe
的字段。通过为字段分配一个Sequential
实例,自动注册该模块(nn.Sequential
继承自nn.Module
,与nn
包中的其他类一样)。注册它不需要任何调用,只需将子模块分配给字段即可。构造函数完成后,所有字段会被自动注册(如果确实想要手动注册,nn.Module
中也有函数可用)。
在这里,我们必须覆写forward
函数并实现自己的数据转换逻辑。由于模块是对其他层的非常简单的包装,因此只需让它们转换数据即可。请注意,要将模块应用于数据,我们需要调用该模块(即假设模块实例为一个函数并使用参数调用它)而不使用nn.Module
类的forward()
方法。这是因为nn.Module
会覆盖__call__()
方法(将实例视为可调用实例时,会使用该方法)。该方法执行了nn.Module
中的一些神奇的操作,并调用forward()
方法。如果直接调用forward()
,则将干预nn.Module
的职责,这可能会导致错误的结果。
因此,这就是定义自己的模块所需要做的。现在,我们来使用它:
我们创建模块,为输入和输出赋值,然后创建张量,让模块对其进行转换(遵守约定,将其视为callable)。之后,打印网络结构(nn.Module
覆写了__str__()
和__repr__()
方法),以更好的方式来展示内部结构。最后,展示运行的结果。
代码输出应如下所示:
当然,之前说了PyTorch支持动态特性。每一批数据都会调用forward()
方法,因此如果要根据所需处理的数据进行一些复杂的转换,例如分层softmax或要应用网络随机选择,那么你也可以这样做。模块参数的数量也不只限于一个。因此,如果需要,可以编写一个带有多个必需参数和几十个可选参数的模块,这都是可以的。
接下来,我们需要熟悉PyTorch库的两个重要部分(损失函数和优化器),它们将简化我们的生活。