torch.nn.Parameter继承torch.Tensor类,其功能为将不可训练的Tensor类参数转化为可训练的Parameter类参数,并将这个参数绑定到module中,成为module中可训练的参数。
输入包括:
- data为传入Tensor类型参数。
- requires_grad为是否训练。 True表示可训练,False表示不可训练,默认值为True。
它与torch.Tensor的区别就是nn.Parameter会自动被认为是module的可训练参数,即加入到parameter()这个迭代器中去;而module中非nn.Parameter()的普通Tensor是不在parameter()中的。
注意,nn.Parameter的对象的requires_grad属性的默认值是True,即是可被训练的,这与torch.Tensor对象的默认值相反。
在nn.Module类中,pytorch也是使用nn.Parameter来对每一个module的参数进行初始化的。
test