SAMのヤバさ SoTAを総なめ!衝撃のオプティマイザー「SAM」爆誕&解説! ICLR2021に衝撃的な手法が登場しました。その名もSharpness-Aware Minimization、通称SAMです。どれくらい衝撃かというと、画像分類タスクにおいて、SAMがImageNet(88.61%)/CIFAR-10(99.70%)/CIFAR-100(96.08%)などを含む9つものデータセットでSoTAを更新したくらいです(カッコ内はSAMによる精度)。話題のVision Transformer(ViT)のImageNetの結果(88.55%)を早速超しました(SoTA更新早すぎます)。
簡単に要約すると、損失が最小かつ平坦なパラメータを探しに行くようなoptimizerです。そうすることで汎用性が高まります。
とにかく強いです。
公式実装は下にあります。
https://github.com/davda54/sam
見ればわかりますが、optimizerを2回呼び出す必要があるので計算量は少し増えてしまいます。(自分環境だと元と比較して約1.2倍、時間がかかる)
PyTorch Lightningで動かすときの問題 そもそもPyTorch Lightningはいちいち backward()やstep()とかを書くのが嫌で、kerasっぽい学習をしたいけどPyTorchの拡張性も使いたいっていう良い所どりしたい人が使うもの(偏見強め)なので、新しい手法だと問題が生じることが多々有ります。
今回のSAMはまさにそうで、first_step()とsecond_step()を呼び出す必要がありますが、PyTorch Lightningのデフォルトではもちろん1回しか呼ぶことができません。
(また、SAMの公式実装で呼ぶ回数が1回で済む関数が用意されてましたが、エラーを吐かれてしまいました)
結局良い書き方が分からず、Twitterで呼びかけた所、有識者が現れました。
https://twitter.com/kuto_bopro/status/1363406456469422083
solution 改良したSAMのコードとrunnerのコードを貼っておきます。ほとんどは先程のツイートのリプライに貼ってあるコードをコピペしたところです。(少し修正はしてあります)
import torch """ https://github.com/kuto5046/kaggle-rainforest/blob/main/src/sam.py#L16 """ class SAM(torch.optim.Optimizer): def __init__(self, params, base_optimizer, rho=0.05, **kwargs): assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" defaults = dict(rho=rho, **kwargs) super(SAM, self).__init__(params, defaults) self.base_optimizer = base_optimizer(self.param_groups, **kwargs) self.param_groups = self.base_optimizer.param_groups @torch.no_grad() def first_step(self, closure=None, zero_grad=False): loss = None if closure is not None: with torch.