發表日期 3/17/2022, 2:30:02 PM
用反嚮傳播(backpropagation)來計算優化目標函數的梯度,是當前機器學習領域的主流方法。近日,牛津與微軟等機構的多位學者聯閤提齣一種名為「正嚮梯度」(forward gradient)的自動微分模式,可以完全拋棄反嚮傳播進行梯度計算。實驗證明,在一些問題中,正嚮梯度的計算時間是反嚮傳播的二分之一。
編譯 | 張倩
編輯 | 陳彩嫻
反嚮傳播和基於梯度的優化是近年來機器學習(ML)取得重大突破的核心技術。
人們普遍認為,機器學習之所以能夠快速發展,是因為研究者們使用瞭第三方框架(如PyTorch、TensorFlow)來解析ML代碼。這些框架不僅具有自動微分(AD)功能,還為本地代碼提供瞭基礎的計算功能。而ML所依賴的這些軟件框架都是圍繞 AD 的反嚮模式所構建的。這主要是因為在ML中,當輸入的梯度為海量時,可以通過反嚮模式的單次評估進行精確有效的評估。
自動微分算法分為正嚮模式和反嚮模式。但正嚮模式的特點是隻需要對一個函數進行一次正嚮評估(即沒有用到任何反嚮傳播),計算成本明顯降低。為此,來自劍橋與微軟等機構的研究者們探索這種模式,展示瞭僅使用正嚮自動微分也能在一係列機器學習框架上實現穩定的梯度下降。
論文地址:https://arxiv.org/pdf/2202.08587v1.pdf
他們認為,正嚮梯度有利於改變經典機器學習訓練管道的計算復雜性,減少訓練的時間和精力成本,影響機器學習的硬件設計,甚至對大腦中反嚮傳播的生物學閤理性産生影響。
1
自動微分的兩種模式
首先,我們來簡要迴顧一下自動微分的兩種基本模式。
正嚮模式
給定一個函數 f: θ∈R n,v∈R n,正嚮模式的AD會計算 f(θ) 和雅可比嚮量乘積Jf (θ) v,其中Jf (θ) ∈R m×n是f在θ處評估的所有偏導數的雅可比矩陣,v是擾動嚮量。對於 f : R n R 的情況,在雅可比嚮量乘積對應的方嚮導數用 f(θ)- v錶示,即在θ處的梯度 f對方嚮嚮量v的映射,代錶沿著該方嚮的變化率。
值得注意的是,正嚮模式在一次正嚮運行中同時評估瞭函數 f 及其雅可比嚮量乘積 Jf v。此外,獲得 Jf v 不需要計算雅可比嚮量Jf,這一特點被稱為無矩陣計算。
反嚮模式
給定一個函數 f : R n R m,數值 θ∈R n,v∈R m,AD反嚮模式會計算f(θ)和雅可比嚮量乘積v |Jf (θ),其中Jf∈R m×n是f在θ處求值的所有偏導數的雅可比矩陣,v∈R m是一個鄰接的矢量。對於f : R n R和v = 1的情況,反嚮模式計算梯度,即f對所有n個輸入的偏導數 f(θ)=h f θ1,. . . , f θn i| 。
請注意,v |Jf 是在一次前嚮-後嚮評估中進行計算的,而不需要計算雅可比Jf 。
運行時間成本
兩種AD模式的運行時間以運行正在微分的函數 f 所需時間的恒定倍數為界。
反嚮模式的成本比正嚮模式高,因為它涉及到數據流的反轉,而且需要保留正嚮過程中所有操作結果的記錄,因為在接下來的反嚮過程中需要這些記錄來評估導數。內存和計算成本特徵最終取決於AD係統實現的功能,如利用稀疏性。
成本可以通過假設基本操作的計算復雜性來分析,如存儲、加法、乘法和非綫性操作。將評估原始函數 f 所需的時間錶示設為 runtime(f),我們可以將正嚮和反嚮模式所需的時間分彆錶示為 Rf×runtime(f) 和 Rb×runtime(f)。在實踐中,Rf 通常在1到3之間,Rb通常在5到10之間,不過這些結果都與程序高度相關。
2
方法
正嚮梯度
定義1
給定一個函數 f : R n R,他們將「正嚮梯度」 g : R n R n 定義為:
其中,θ∈R n 是評估梯度的關鍵點,v∈R n 是一個擾動嚮量,被視為一個多元隨機變量v p(v),這樣 v 的標量分量 vi 是獨立的,對所有 i 都有零均值和單位方差, f(θ)-v∈R 是 f 在在 v 方嚮上 θ 點的方嚮導數。
簡要地談一下這個定義的由來。
如前所述,正嚮模式直接給我們提供瞭方嚮導數 f(θ) - v = P i f θi vi,無需計算 f。將 f 正嚮評估 n 次,方嚮嚮量取為標準基(獨熱碼)嚮量ei∈R n,i=1 ... n,其中ei錶示在第i個坐標上為1、其他地方為0的嚮量,這時,隻用正嚮模式就可以計算 f。這樣就可以分彆評估f對每個輸入 f θi的敏感性,把所有結果閤並後就可以得到梯度 f。
為瞭獲得比反嚮傳播更優的運行時間優勢,我們需要在每個優化迭代中運行一次正嚮模式。在一次正嚮運行中,我們可以將方嚮v理解為敏感度加權和中的權重嚮量,即P i f θi vi,盡管這沒辦法區分每個θi在最終總數中的貢獻。因此,我們使用權重嚮量v將總體敏感度歸因於每個單獨的參數θi,與每個參數θi的權重vi成正比(例如,權重小的參數在總敏感度中的貢獻小,權重大的參數貢獻大)。
總之,每次評估正嚮梯度時,我們隻需做以下工作:
對一個隨機擾動嚮量v p(v)進行采樣,其大小與f的第一個參數相同。
通過AD正嚮模式運行f函數,在一次正嚮運行中同時評估f(θ)和 f(θ)-v,在此過程中無需計算 f。得到的方嚮導數( f(θ)-v)是一個標量,並且由AD精確計算(不是近似值)。
將標量方嚮導數 f(θ)-v與矢量v相乘,得到g(θ),即正嚮梯度。
圖 1 顯示瞭 Beale函數的幾個正嚮梯度的評估結果。我們可以看到擾動vk(橙色)如何在k∈[1,5]的情況下轉化為正嚮梯度( f-vk)vk(藍色),在受到指嚮限製時偶爾也會指嚮正確的梯度(紅色)。綠色箭頭錶示通過平均正嚮梯度來評估濛特卡洛梯度,即1 K PK k=1( f - vk)vk≈E[( f - v)v]。
正嚮梯度下降
他們構建瞭一個正嚮梯度下降(FGD)算法,用正嚮梯度g代替標準梯度下降中的梯度 f(算法1)。
在實踐中,他們使用小型隨機版本,其中 ft 在每次迭代中都會發生變化,因為它會被訓練中使用的每一小批數據影響。研究者注意到,算法 1 中的方嚮導數dt可以為正負數。如果為負數,正嚮梯度gt的方嚮會發生逆轉,指嚮預料中的真實梯度。圖1顯示的兩個vk樣本,證明瞭這種行為。
在本文中,他們將範圍限製在FGD上,單純研究瞭這一基礎算法,並將其與標準反嚮傳播進行比較,不考慮動量或自適應學習率等其他各種乾擾因素。筆者認為,正嚮梯度算法是可以應用到其他基於梯度算法的優化算法係列中的。
3
實驗
研究者在PyTorch中執行正嚮AD來進行實驗。他們發現,正嚮梯度與反嚮傳播這兩種方法在內存上沒有實際差異(每個實驗的差異都小於0.1%)。
邏輯迴歸
圖 3 給齣瞭多叉邏輯迴歸在MNIST數字分類上的幾次運行結果。我們觀察到,相比基本運行時間,正嚮梯度和反嚮傳播的運行時間成本分彆為 Rf=2.435 和 Rb=4.389,這與人們對典型AD係統的預期相符。
Rf/Rb=0.555和Tf/Tb=0.553的比率錶明,在運行時間和損失性能方麵,正嚮梯度大約比反嚮傳播快兩倍。
在簡單的模型中,這些比率是一緻的,因為這兩種技術在空間行為的迭代損失上幾乎相同,這意味著運行時收益幾乎直接反映在每個時間空間的損失上。
多層神經網絡
圖4顯示瞭用多層神經網絡在不同學習率下進行MNIST分類的兩個實驗。他們使用瞭三個架構大小分彆為1024、1024、10的全連接層。在這個模型架構中,他們觀察到正嚮梯度和反嚮傳播相對於基礎運行時間的運行成本為Rf=2.468和Rb=4.165,相對測量 Rf/Rb 平均為0.592,與邏輯迴歸的情況大緻相同。
有趣的是,在第二個實驗中(學習率為2×10-4),我們可以看到正嚮梯度在每個迭代損失圖中都實現瞭快速的下降。作者認為,這種行為是由於常規SGD(反嚮傳播)和正嚮SGD算法的隨機性不同所導緻的,因此他們推測:正嚮梯度引入的乾擾可能有利於探索損失平麵。
我們可以從時間麯綫圖看到,正嚮模式減少瞭運行時間。我們看到,損失性能指標Tf/Tb值為0.211,這錶明在驗證實驗損失的過程中,正嚮梯度的速度是反嚮傳播的四倍以上。
捲積神經網絡
圖 5 展示瞭一個捲積神經網絡對同一MNIST分類任務的正嚮梯度和反嚮傳播的比較。
在這個架構中,他們觀察到,相對於基本運行時間,正嚮AD的性能最好,其中正嚮模式的Rf=1.434,代錶瞭在基本運行時間之上的開銷隻有 43%。Rb=2.211 的反嚮傳播非常接近反嚮 AD 係統中所期待的理想情況。Rf/Rb=0.649 代錶瞭正嚮AD運行時間相對於反嚮傳播的一個顯著優勢。在損失空間,他們得到一個比率 Tf /Tb=0.514,這錶明在驗證損失的實驗中,正嚮梯度的速度比反嚮傳播的速度要快兩倍。
可擴展性
前麵的幾個結果錶明:
不用反嚮傳播也可以在一個典型的ML訓練管道中進行訓練,並且以一種競爭計算的方式來實現;
在相同參數(學習率和學習率衰減)的情況下,正嚮AD比反嚮傳播所消耗的時間要少很多。
相對於基礎運行時的成本,我們看到,對於大部分實驗,反嚮傳播在Rb∈[4,5]內,正嚮梯度在Rf∈[3,4]內。我們還觀察到,正嚮梯度算法在整個範圍內對運行都是有利的。Rf/Rb比率在10層以內保持在0.6以下,在100層時略高於0.8。重要的是,這兩種方法在內存消耗上幾乎沒有差彆。
4
結論
總的來說,這篇工作的幾點貢獻主要如下:
他們將「正嚮梯度」(forward gradient)定義為:一個無偏差的、基於正嚮自動微分且毫不涉及到反嚮傳播的梯度估算器。
他們在PyTorch中從零開始,實現瞭正嚮模式的自動微分係統,且完全不依賴PyTorch中已有的反嚮傳播。
他們把正嚮梯度模式應用在各類隨機梯度下降(SGD)優化中,最後的結果充分證明瞭:一個典型的現代機器學習訓練管道可以隻使用自動微分正嚮傳播來構建。
他們比較瞭正嚮梯度和反嚮傳播的運行時間和損失消耗等等,證明瞭在一些情況下,正嚮梯度算法的速度比反嚮傳播快兩倍。
雷峰網