Как я могу установить скорость обучения для каждого конкретного параметра (веса и смещения) в сети?

В документах PyTorch я обнаружил следующее:

optim.SGD([{'params': model.base.parameters()}, 
           {'params': model.classifier.parameters(), 'lr': 1e-3}], 
           lr=1e-2, momentum=0.9)

Где model.classifier.parameters(), который определяет группу параметров, получает конкретную скорость обучения 1e-3.

Но как я могу перевести это на уровень параметров?

1
oezguensi 24 Ноя 2019 в 04:28

1 ответ

Лучший ответ

Вы можете установить скорость обучения для конкретного параметра, используя имена параметров для установки скорости обучения, например

Для данной сети, взятой из Форум PyTorch:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer1.weight.data.fill_(1)
        self.layer1.bias.data.fill_(1)
        self.layer2 = nn.Linear(1, 1)
        self.layer2.weight.data.fill_(1)
        self.layer2.bias.data.fill_(1)

    def forward(self, x):
        x = self.layer1(x)
        return self.layer2(x)

net = Net()
for name, param in net.named_parameters():
    print(name)

Параметры:

layer1.weight
layer1.bias
layer2.weight
layer2.bias

Затем вы можете использовать имена параметров, чтобы установить их конкретную скорость обучения следующим образом:

optimizer = optim.Adam([
            {'params': net.layer1.weight},
            {'params': net.layer1.bias, 'lr': 0.01},
            {'params': net.layer2.weight, 'lr': 0.001}
        ], lr=0.1, weight_decay=0.0001)

out = net(torch.Tensor([[1]]))
out.backward()
optimizer.step()
print("weight", net.layer1.weight.data.numpy(), "grad", net.layer1.weight.grad.data.numpy())
print("bias", net.layer1.bias.data.numpy(), "grad", net.layer1.bias.grad.data.numpy())
print("weight", net.layer2.weight.data.numpy(), "grad", net.layer2.weight.grad.data.numpy())
print("bias", net.layer2.bias.data.numpy(), "grad", net.layer2.bias.grad.data.numpy())

Выход:

weight [[0.9]] grad [[1.0001]]
bias [0.99] grad [1.0001]
weight [[0.999]] grad [[2.0001]]
bias [1.] grad [1.]
2
kHarshit 24 Ноя 2019 в 06:59