senooken JP Social
  • FAQ
  • Login
senooken JP Socialはsenookenの専用分散SNSです。
  • Public

    • Public
    • Network
    • Groups
    • Popular
    • People

Conversation

Notices

  1. Akionux (akionux@status.akionux.net)'s status on Wednesday, 23-Jun-2021 12:38:11 JST Akionux Akionux
    PyTorch LightningでSAMを動かす - I'm chizuchizu - https://chizuchizu.com/blog/sam_lightning/
    In conversation Wednesday, 23-Jun-2021 12:38:11 JST from status.akionux.net permalink

    Attachments

    1. Domain not in remote thumbnail source whitelist: raw.githubusercontent.com
      PyTorch LightningでSAMを動かす
      SAMのヤバさ SoTAを総なめ!衝撃のオプティマイザー「SAM」爆誕&解説! ICLR2021に衝撃的な手法が登場しました。その名もSharpness-Aware Minimization、通称SAMです。どれくらい衝撃かというと、画像分類タスクにおいて、SAMがImageNet(88.61%)/CIFAR-10(99.70%)/CIFAR-100(96.08%)などを含む9つものデータセットでSoTAを更新したくらいです(カッコ内はSAMによる精度)。話題のVision Transformer(ViT)のImageNetの結果(88.55%)を早速超しました(SoTA更新早すぎます)。 簡単に要約すると、損失が最小かつ平坦なパラメータを探しに行くようなoptimizerです。そうすることで汎用性が高まります。 とにかく強いです。 公式実装は下にあります。 https://github.com/davda54/sam 見ればわかりますが、optimizerを2回呼び出す必要があるので計算量は少し増えてしまいます。(自分環境だと元と比較して約1.2倍、時間がかかる) PyTorch Lightningで動かすときの問題 そもそもPyTorch Lightningはいちいち backward()やstep()とかを書くのが嫌で、kerasっぽい学習をしたいけどPyTorchの拡張性も使いたいっていう良い所どりしたい人が使うもの(偏見強め)なので、新しい手法だと問題が生じることが多々有ります。 今回のSAMはまさにそうで、first_step()とsecond_step()を呼び出す必要がありますが、PyTorch Lightningのデフォルトではもちろん1回しか呼ぶことができません。 (また、SAMの公式実装で呼ぶ回数が1回で済む関数が用意されてましたが、エラーを吐かれてしまいました) 結局良い書き方が分からず、Twitterで呼びかけた所、有識者が現れました。 https://twitter.com/kuto_bopro/status/1363406456469422083 solution 改良したSAMのコードとrunnerのコードを貼っておきます。ほとんどは先程のツイートのリプライに貼ってあるコードをコピペしたところです。(少し修正はしてあります) import torch """ https://github.com/kuto5046/kaggle-rainforest/blob/main/src/sam.py#L16 """ class SAM(torch.optim.Optimizer): def __init__(self, params, base_optimizer, rho=0.05, **kwargs): assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" defaults = dict(rho=rho, **kwargs) super(SAM, self).__init__(params, defaults) self.base_optimizer = base_optimizer(self.param_groups, **kwargs) self.param_groups = self.base_optimizer.param_groups @torch.no_grad() def first_step(self, closure=None, zero_grad=False): loss = None if closure is not None: with torch.

    Feeds

    • Activity Streams
    • RSS 2.0
    • Atom
    • Help
    • About
    • FAQ
    • TOS
    • Privacy
    • Source
    • Version
    • Contact

    senooken JP Social is a social network, courtesy of senooken. It runs on GNU social, version 2.0.2-beta0, available under the GNU Affero General Public License.

    Creative Commons Attribution 3.0 All senooken JP Social content and data are available under the Creative Commons Attribution 3.0 license.