
Can I define a method as an attribute?

  • ilovewt  ·  5 days ago

    from collections import defaultdict

    import torch

    # callbacks, models, and global_params are project-local modules.


    class Trainer:
        """Object used to facilitate training."""

        def __init__(
            self,
            # params: Namespace,
            params,
            model,
            device=torch.device("cpu"),
            optimizer=None,
            scheduler=None,
            wandb_run=None,
            early_stopping: callbacks.EarlyStopping = None,
        ):
            # Set params
            self.params = params
            self.model = model
            self.device = device

            # self.optimizer = optimizer
            # OPTIMIZER_PARAMS is a module-level config (see fit() below).
            self.optimizer = self.get_optimizer(
                model=self.model, optimizer_params=OPTIMIZER_PARAMS
            )
            self.scheduler = scheduler
            self.wandb_run = wandb_run
            self.early_stopping = early_stopping

            # dict of lists holding the various train metrics
            # TODO: how to add more metrics? wandb log too. Maybe save to model artifacts?
            self.history = defaultdict(list)

        @staticmethod
        def get_optimizer(
            model: models.CustomNeuralNet,
            optimizer_params: global_params.OptimizerParams,
        ):
            """Get the optimizer for the model.

            Args:
                model (models.CustomNeuralNet): model whose parameters are optimized.
                optimizer_params (global_params.OptimizerParams): holds the optimizer
                    name and its keyword arguments.

            Returns:
                torch.optim.Optimizer: the instantiated optimizer.
            """
            # Look up the optimizer class by name, e.g. torch.optim.Adam,
            # and instantiate it with the configured keyword arguments.
            return getattr(torch.optim, optimizer_params.optimizer_name)(
                model.parameters(), **optimizer_params.optimizer_params
            )
    

    Note that initially I passed optimizer in through the constructor, and I would instantiate it outside of this class. However, I am now proposing get_optimizer instead. Should I set self.optimizer = self.get_optimizer() in the constructor, or just call self.get_optimizer() at the designated places inside the class? The former seems more readable to me.
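
    For illustration, here is a self-contained sketch of the two styles, with a plain dict standing in for global_params.OptimizerParams (the names build_optimizer, StoredStyle, and OnDemandStyle are hypothetical):

        import torch

        OPTIMIZER_PARAMS = {"optimizer_name": "Adam", "optimizer_params": {"lr": 1e-3}}

        def build_optimizer(model, cfg=OPTIMIZER_PARAMS):
            # Same lookup as get_optimizer above, e.g. torch.optim.Adam.
            return getattr(torch.optim, cfg["optimizer_name"])(
                model.parameters(), **cfg["optimizer_params"]
            )

        class StoredStyle:
            """Style 1: build once in __init__ and reuse self.optimizer."""

            def __init__(self, model):
                self.model = model
                self.optimizer = build_optimizer(model)

        class OnDemandStyle:
            """Style 2: no attribute; build a fresh optimizer where needed."""

            def __init__(self, model):
                self.model = model

            def train_one_epoch(self):
                optimizer = build_optimizer(self.model)  # new instance per call
                # ... run one epoch with `optimizer` ...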


    Addendum: I have now put the instantiation inside the .fit() method, which I will call, say, 5 times to train the model 5 times. In this scenario, even though nothing obviously goes wrong with each call building its own optimizer, would it be better not to define self.optimizer at all? (A usage sketch follows the code below.)

        def fit(
            self,
            train_loader: torch.utils.data.DataLoader,
            valid_loader: torch.utils.data.DataLoader,
            fold: int = None,
        ):
            """Fit the model on the given train/validation loaders.

            Args:
                train_loader (torch.utils.data.DataLoader): training set loader.
                valid_loader (torch.utils.data.DataLoader): validation set loader.
                fold (int, optional): fold index when cross-validating. Defaults to None.
            """
            # OPTIMIZER_PARAMS and SCHEDULER_PARAMS are module-level configs;
            # a fresh optimizer/scheduler pair is built on every call to fit().
            self.optimizer = self.get_optimizer(
                model=self.model, optimizer_params=OPTIMIZER_PARAMS
            )
            self.scheduler = self.get_scheduler(
                optimizer=self.optimizer, scheduler_params=SCHEDULER_PARAMS
            )
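
    For context, a minimal usage sketch, assuming the loaders and configs already exist (the variable names here are placeholders):

        # fit() is called once per fold; each call builds its own fresh
        # optimizer and scheduler, so no state leaks between folds.
        trainer = Trainer(params=params, model=model)
        for fold in range(5):
            trainer.fit(train_loader, valid_loader, fold=fold)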
    
    1 reply  |  5 days ago

  •   Shai  ·  5 days ago

    There is one difference between the two: calling your get_optimizer instantiates a brand-new torch.optim.<optimizer> every single time. In contrast, setting self.optimizer stores a single optimizer instance that is reused.
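
    To make that concrete, a small self-contained sketch (the SGD/momentum choice is arbitrary): each call to the factory returns a distinct object, so internal state such as momentum buffers starts empty every time, whereas a stored instance keeps its state across steps.

        import torch
        from torch import nn

        model = nn.Linear(4, 2)

        def get_optimizer():
            # A brand-new optimizer on every call; its internal state
            # (e.g. SGD momentum buffers) starts empty each time.
            return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

        print(get_optimizer() is get_optimizer())  # False: two separate instances

        # Storing one instance (e.g. as self.optimizer) reuses the same object,
        # so momentum accumulated during training carries over between steps.
        optimizer = get_optimizer()
        for _ in range(3):
            loss = model(torch.randn(8, 4)).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()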