87 lines
3.1 KiB
Python
87 lines
3.1 KiB
Python
from dataclasses import dataclass, field
|
|
from typing import Type
|
|
|
|
import torch
|
|
from prototorch.y import BaseYArchitecture
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
class SingleLearningRateMixin(BaseYArchitecture):
    """
    Single Learning Rate

    Builds one optimizer over everything returned by ``self.parameters()``,
    using one shared learning rate.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        lr: Learning rate passed to the optimizer. Default: 0.1.
        optimizer: Optimizer class to instantiate. Default: torch.optim.Adam.
        """
        lr: float = 0.1
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def __init__(self, hparams: HyperParameters) -> None:
        """Initialise the parent architecture and keep the optimizer settings.

        Args:
            hparams: Hyperparameters; ``lr`` and ``optimizer`` are copied onto
                the instance.
        """
        super().__init__(hparams)

        # Stored on the instance so configure_optimizers can read them later.
        learning_rate = hparams.lr
        self.lr = learning_rate
        optimizer_cls = hparams.optimizer
        self.optimizer = optimizer_cls

    # Hooks
    # ----------------------------------------------------------------------------------------------------
    def configure_optimizers(self):
        """Return a single optimizer covering all parameters of the model.

        Returns:
            An instance of ``self.optimizer`` constructed from
            ``self.parameters()`` with learning rate ``self.lr``.
        """
        make_optimizer = self.optimizer
        trainable = self.parameters()
        return make_optimizer(trainable, lr=self.lr)  # type: ignore
|
|
|
|
|
|
class MultipleLearningRateMixin(BaseYArchitecture):
    """
    Multiple Learning Rates

    Define different learning rates for different parameters.

    ``hparams.lr`` maps attribute names of the model to learning rates. One
    optimizer is created per entry, covering either that single parameter or
    all parameters of the named sub-module.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        lr: Mapping from attribute name to its learning rate.
            Default: empty dict (no optimizers are created).
        optimizer: The optimizer to use. Default: torch.optim.Adam.
        """
        # A fresh dict per instance; `dict` itself is the factory.
        lr: dict = field(default_factory=dict)
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def __init__(self, hparams: HyperParameters) -> None:
        """Initialise the parent architecture and keep the optimizer settings.

        Args:
            hparams: Hyperparameters; ``lr`` (name -> learning rate mapping)
                and ``optimizer`` are copied onto the instance.
        """
        super().__init__(hparams)
        self.lr = hparams.lr
        self.optimizer = hparams.optimizer

    # Hooks
    # ----------------------------------------------------------------------------------------------------
    def configure_optimizers(self):
        """Create one optimizer per entry in ``self.lr``.

        For each ``name -> lr`` pair the attribute ``name`` is looked up on
        the model. A ``Parameter`` is optimized on its own; any object that
        exposes ``parameters`` contributes ``model_part.parameters()``.

        Returns:
            A list of optimizer instances, in the iteration order of
            ``self.lr``. Empty when ``self.lr`` is empty.

        Raises:
            ValueError: If a name in ``self.lr`` is not an attribute of the
                model.

        Warns:
            UserWarning: If a named attribute is neither a ``Parameter`` nor
                provides ``parameters``; the entry is skipped, as before, but
                no longer silently.
        """
        import warnings

        optimizers = []
        for name, lr in self.lr.items():
            if not hasattr(self, name):
                raise ValueError(f"{name} is not a parameter of {self}")

            model_part = getattr(self, name)
            if isinstance(model_part, Parameter):
                # Optimizers expect an iterable of parameters, so wrap it.
                params = [model_part]
            elif hasattr(model_part, "parameters"):
                params = model_part.parameters()
            else:
                # Skipping keeps the previous behaviour; the warning makes a
                # misconfigured learning-rate entry visible to the user.
                warnings.warn(
                    f"No optimizer created for {name!r}: attribute of type "
                    f"{type(model_part).__name__} is neither a Parameter nor "
                    f"provides parameters().",
                    stacklevel=2,
                )
                continue

            optimizers.append(self.optimizer(params, lr=lr))  # type: ignore
        return optimizers
|