2022-05-31 15:56:03 +00:00
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from typing import Type
|
|
|
|
|
|
|
|
import torch
|
2022-08-15 10:14:14 +00:00
|
|
|
from prototorch.models import BaseYArchitecture
|
2022-05-31 15:56:03 +00:00
|
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
|
|
|
|
|
|
class SingleLearningRateMixin(BaseYArchitecture):
    """
    Single Learning Rate

    One optimizer is built over every parameter of the model, and all of
    them share the same learning rate (``hparams.lr``).
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        lr: Learning rate shared by all parameters. Default: 0.1.
        optimizer: Optimizer class to instantiate. Default: torch.optim.Adam.
        """
        lr: float = 0.1
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

    # Hooks
    # ----------------------------------------------------------------------------------------------------
    def configure_optimizers(self):
        """Return a single ``hparams.optimizer`` over ``self.parameters()``."""
        optimizer_cls = self.hparams.optimizer
        all_params = self.parameters()
        learning_rate = self.hparams.lr
        return optimizer_cls(all_params, lr=learning_rate)  # type: ignore
|
2022-05-31 15:56:03 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MultipleLearningRateMixin(BaseYArchitecture):
    """
    Multiple Learning Rates

    Define different learning rates for different parts of the model.

    Every entry ``name -> lr`` in ``hparams.lr`` produces its own optimizer
    (of type ``hparams.optimizer``) over the attribute ``self.<name>``. That
    attribute must be either a ``torch.nn.Parameter`` or an object exposing a
    ``parameters()`` method (e.g. a ``torch.nn.Module``).
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        lr: Mapping from attribute name to the learning rate used for that
            attribute's parameters. Default: {} (empty, so no optimizers).
        optimizer: The optimizer to use. Default: torch.optim.Adam.
        """
        lr: dict = field(default_factory=dict)
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

    # Hooks
    # ----------------------------------------------------------------------------------------------------
    def configure_optimizers(self):
        """Create one optimizer per entry of ``hparams.lr``.

        Returns:
            A list of optimizers, in the iteration order of ``hparams.lr``.

        Raises:
            ValueError: if a key of ``hparams.lr`` is not an attribute of the
                model, or names an attribute that is neither a ``Parameter``
                nor has a ``parameters()`` method.
        """
        optimizers = []
        for name, lr in self.hparams.lr.items():
            if not hasattr(self, name):
                raise ValueError(f"{name} is not a parameter of {self}")
            model_part = getattr(self, name)
            if isinstance(model_part, Parameter):
                # torch optimizers reject a single tensor as `params`; wrap it.
                params = [model_part]
            elif hasattr(model_part, "parameters"):
                params = model_part.parameters()
            else:
                # Without this check the configured learning rate would be
                # silently ignored; surface the misconfiguration instead.
                raise ValueError(
                    f"{name} of {self} is neither a Parameter nor has "
                    f"parameters()")
            optimizers.append(self.hparams.optimizer(
                params,
                lr=lr,  # type: ignore
            ))
        return optimizers
|