綜合LSTM、transformer優(yōu)勢,DeepMind強化學(xué)習智能體提高數據效率
來(lái)自 DeepMind 的研究者提出了用于強化學(xué)習的 CoBERL 智能體,它結合了新的對比損失以及混合 LSTM-transformer 架構,可以提高處理數據效率。實(shí)驗表明,CoBERL 在整個(gè) Atari 套件、一組控制任務(wù)和具有挑戰性的 3D 環(huán)境中可以不斷提高性能。
近些年,多智能體強化學(xué)習取得了突破性進(jìn)展,例如 DeepMind 開(kāi)發(fā)的 AlphaStar 在星際爭霸 II 中擊敗了職業(yè)星際玩家,超過(guò)了 99.8% 的人類(lèi)玩家;OpenAI Five 在 DOTA2 中多次擊敗世界冠軍隊伍,是首個(gè)在電子競技比賽中擊敗冠軍的人工智能系統。然而,許多強化學(xué)習(RL)智能體需要大量的實(shí)驗才能解決任務(wù)。
最近,DeepMind 的研究者提出了 CoBERL(Contrastive BERT for RL)智能體,它結合了新的對比損失和混合 LSTM-transformer 架構,以提高處理數據效率。CoBERL 使得從更廣泛領(lǐng)域使用像素級信息進(jìn)行高效、穩健學(xué)習成為可能。
具體地,研究者使用雙向掩碼預測,并且結合最近的對比方法泛化,來(lái)學(xué)習 RL 中 transformer 更好的表征,而這一過(guò)程不需要手動(dòng)進(jìn)行數據擴充。實(shí)驗表明,CoBERL 在整個(gè) Atari 套件、一組控制任務(wù)和具有挑戰性的 3D 環(huán)境中可以不斷提高性能。
論文地址:https://arxiv.org/pdf/2107.05431.pdf
方法介紹
為了解決深度強化學(xué)習中的數據效率問(wèn)題,研究者對目前的研究提出了兩種修改:
首先提出了一種新的表征學(xué)習目標,旨在通過(guò)增強掩碼輸入預測中的自注意力一致性來(lái)學(xué)習更好的表征;
其次提出了一種架構改進(jìn),該架構可以結合 LSTM 以及 transformer 的優(yōu)勢。
CoBERL 整體架構圖。
表征學(xué)習
研究者將 BERT 與對比學(xué)習結合起來(lái)?;?BERT 方法,該研究將 transformer 的雙向處理機制與掩碼預測設置相結合。雙向處理機制一方面允許智能體根據時(shí)間環(huán)境來(lái)了解特定狀態(tài)的上下文。另一方面,位于掩碼位置處的預測輸入通過(guò)降低預測后續時(shí)間步長(cháng)的概率來(lái)緩解相關(guān)輸入問(wèn)題。
研究者還使用了對比學(xué)習,雖然許多對比損失(例如 SimCLR)依賴(lài)于數據擴充來(lái)創(chuàng )建可以進(jìn)行比較的數據分組,但該研究不需要利用這些手工數據擴充來(lái)構造代理任務(wù)。
相反地,該研究依賴(lài)輸入數據的順序性質(zhì)來(lái)創(chuàng )建對比學(xué)習所需的相似和不同點(diǎn)的必要分組,不需要僅依賴(lài)圖像觀(guān)測的數據增強(如裁剪和像素變化)。對于對比損失,研究者使用了 RELIC,該損失同樣適應于時(shí)間域;他們通過(guò)對齊 GTrXL transformer 輸入和輸出創(chuàng )建數據分組,并且使用 RELIC 作為 KL 正則化改進(jìn)所用方法的性能,例如 SimCLR 在圖像分類(lèi)領(lǐng)域以及 Atari 在 RL 領(lǐng)域性能都得到提高。
CoBERL 架構
在自然語(yǔ)言處理和計算機視覺(jué)任務(wù)當中,transformer 在連接長(cháng)范圍數據依賴(lài)性方面非常有效,但在 RL 設置中,transformer 難以訓練并且容易過(guò)擬合。相反,LSTM 在 RL 中已經(jīng)被證明非常有用。盡管 LSTM 不能很好地捕獲長(cháng)范圍的依賴(lài)關(guān)系,但卻可以高效地捕獲短范圍的依賴(lài)關(guān)系。
該研究提出了一個(gè)簡(jiǎn)單但強大的架構改變:在 GTrXL 頂部添加了一個(gè) LSTM 層,同時(shí)在 LSTM 和 GTrXL 之間有一個(gè)額外的門(mén)控殘差連接,由 GTrXL 的輸入進(jìn)行調制。此外,該架構還有一個(gè)包含從 transformer 輸入到 LSTM 輸出的跳躍連接。更具體地說(shuō),Y_t 在時(shí)間 t 時(shí)編碼器網(wǎng)絡(luò )的輸出,可以用下列方程定義附加模塊:
這些模塊是互補的,因為 transformer 沒(méi)有最近偏差,而 LSTM 的偏差可以表示最近的輸入——等式 6 中的 Gate 允許編碼器表征和 transformer 輸出混合。這種內存架構與 RL 機制的選擇無(wú)關(guān),研究者在開(kāi)啟和關(guān)閉策略(on and off-policy)設置中評估了這種架構。對于 on-policy 設置,該研究使用 V-MPO 作為 RL 算法。V-MPO 使用目標分布進(jìn)行策略更新,并在 KL 約束下將參數部分移向該目標。對于 off-policy 設置,研究者使用 R2D2。
R2D2 智能體:R2D2(Recurrent Replay Distributed DQN) 演示了如何調整 replay 和 RL 學(xué)習目標,以適用于具有循環(huán)架構的智能體。鑒于其在 Atari-57 和 DMLab-30 上的競爭性能,研究者在 R2D2 的背景下實(shí)現了 CoBERL 架構。他們用門(mén)控 transformer 和 LSTM 組合有效地替換了 LSTM,并添加了對比表示學(xué)習損失。因此,通過(guò) R2D2,以及分布式經(jīng)驗收集的益處,將循環(huán)智能體狀態(tài)存儲在 replay buffer 中,并在訓練期間「燒入」(burning in)具有 replay 序列展開(kāi)網(wǎng)絡(luò )的一部分。
V-MPO 智能體:鑒于 V-MPO 在 DMLab-30 上的強大性能,特別是與作為 CoBERL 關(guān)鍵組件的 GTrXL 架構相結合,該研究使用 V-MPO 和 DMLab30 來(lái)演示 CoBERL 與 on-policy 算法的使用。V-MPO 是一種基于最大后驗概率策略?xún)?yōu)化(MPO)的 on-policy 自適應算法。為了避免策略梯度方法中經(jīng)常出現的高方差,V-MPO 使用目標分布進(jìn)行策略更新,受基于樣本的 KL 約束,計算梯度將參數部分移向目標,這樣也同樣受 KL 約束。與 MPO 不同,V-MPO 使用可學(xué)習的狀態(tài) - 價(jià)值函數 V(s) 而不是狀態(tài) - 動(dòng)作價(jià)值函數。
實(shí)驗細節
研究者證明了 1) CoBERL 在更為廣泛的環(huán)境和任務(wù)中能夠提高性能,2)最大化性能還需要所有組件。實(shí)驗展示了 CoBERL 在 Atari57 、DeepMind Control Suite 和 DMLab-30 中的性能。
下表 1 為目前可獲得的不同智能體的結果。由結果可得,CoBERL 在大多數游戲中的表現高于人類(lèi)平均水平,并且顯著(zhù)高于同類(lèi)算法平均性能。R2D2-GTrXL 的中值(median)略?xún)?yōu)于 CoBERL,表明 R2D2-GTrXL 確實(shí)是 Atari 上的強大變體。研究者還觀(guān)察到在檢查「25th Pct 以及 5th Pct」時(shí) ,CoBERL 的性能和其他算法的差異更大, 這表明 CoBERL 提高了數據效率。
為了在具有挑戰性的 3D 環(huán)境中測試 CoBERL,該研究在 DmLab30 中運行,如下圖 2 所示:
下表 3 的結果表明與沒(méi)有對比損失的 CoBERL 相比,對比損失可以顯著(zhù)提高 Atari 和 DMLab-30 的性能。此外,在 DmLab-30 這樣具有挑戰性的環(huán)境中,沒(méi)有額外損失的 CoBERL 仍然優(yōu)于基線(xiàn)方法。
下表 4 為該研究提出的對比損失與 SimCLR、CURL 之間的比較:結果表明該對比損失雖然比 SimCLR、CURL 簡(jiǎn)單,但性能更好。
下表 5 為從 CoBERL 中刪除 LSTM 的效果(如 w/o LSTM 一列),以及移除門(mén)控及其相關(guān)的跳躍連接(如 w/o Gate 一列)。在這兩種情況下 CoBERL 的性能都要差很多,這表明 CoBERL 需要這兩個(gè)組件(LSTM 和 Gate)。
下表 6 根據參數的數量對模型進(jìn)行了比較。對于 Atari,CoBERL 在 R2D2(GTrXL) 基線(xiàn)上添加的參數數量有限;然而,CoBERL 仍然在性能上產(chǎn)生了顯著(zhù)的提升。該研究還試圖將 LSTM 移到 transformer 模塊之前,在這種情況下,對比損失表征取自 LSTM 之前。
*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。