<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>
"); //-->

博客專(zhuān)欄

EEPW首頁(yè) > 博客 > Github1.3萬(wàn)星,迅猛發(fā)展的JAX對比TensorFlow、PyTorch

Github1.3萬(wàn)星,迅猛發(fā)展的JAX對比TensorFlow、PyTorch

發(fā)布人:機器之心 時(shí)間:2021-08-15 來(lái)源:工程師 發(fā)布文章

JAX 是機器學(xué)習 (ML) 領(lǐng)域的新生力量,它有望使 ML 編程更加直觀(guān)、結構化和簡(jiǎn)潔。

在機器學(xué)習領(lǐng)域,大家可能對 TensorFlow 和 PyTorch 已經(jīng)耳熟能詳,但除了這兩個(gè)框架,一些新生力量也不容小覷,它就是谷歌推出的 JAX。很對研究者對其寄予厚望,希望它可以取代 TensorFlow 等眾多機器學(xué)習框架。

JAX 最初由谷歌大腦團隊的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人發(fā)起。

目前,JAX 在 GitHub 上已累積 13.7K 星。

1.png

項目地址:https://github.com/google/jax

迅速發(fā)展的 JAX

JAX 的前身是 Autograd,其借助 Autograd 的更新版本,并且結合了 XLA,可對 Python 程序與 NumPy 運算執行自動(dòng)微分,支持循環(huán)、分支、遞歸、閉包函數求導,也可以求三階導數;依賴(lài)于 XLA,JAX 可以在 GPU 和 TPU 上編譯和運行 NumPy 程序;通過(guò) grad,可以支持自動(dòng)模式反向傳播和正向傳播,且二者可以任意組合成任何順序。

2.png

開(kāi)發(fā) JAX 的出發(fā)點(diǎn)是什么?說(shuō)到這,就不得不提 NumPy。NumPy 是 Python 中的一個(gè)基礎數值運算庫,被廣泛使用。但是 numpy 不支持 GPU 或其他硬件加速器,也沒(méi)有對反向傳播的內置支持,此外,Python 本身的速度限制阻礙了 NumPy 使用,所以少有研究者在生產(chǎn)環(huán)境下直接用 numpy 訓練或部署深度學(xué)習模型。

在此情況下,出現了眾多的深度學(xué)習框架,如 PyTorch、TensorFlow 等。但是 numpy 具有靈活、調試方便、API 穩定等獨特的優(yōu)勢。而 JAX 的主要出發(fā)點(diǎn)就是將 numpy 的以上優(yōu)勢與硬件加速結合。

目前,基于 JAX 已有很多優(yōu)秀的開(kāi)源項目,如谷歌的神經(jīng)網(wǎng)絡(luò )庫團隊開(kāi)發(fā)了 Haiku,這是一個(gè)面向 Jax 的深度學(xué)習代碼庫,通過(guò) Haiku,用戶(hù)可以在 Jax 上進(jìn)行面向對象開(kāi)發(fā);又比如 RLax,這是一個(gè)基于 Jax 的強化學(xué)習庫,用戶(hù)使用 RLax 就能進(jìn)行 Q-learning 模型的搭建和訓練;此外還包括基于 JAX 的深度學(xué)習庫 JAXnet,該庫一行代碼就能定義計算圖、可進(jìn)行 GPU 加速??梢哉f(shuō),在過(guò)去幾年中,JAX 掀起了深度學(xué)習研究的風(fēng)暴,推動(dòng)了科學(xué)研究迅速發(fā)展。

JAX 的安裝

如何使用 JAX 呢?首先你需要在 Python 環(huán)境或 Google colab 中安裝 JAX,使用 pip 進(jìn)行安裝:

$ pip install --upgrade jax jaxlib

注意,上述安裝方式只是支持在 CPU 上運行,如果你想在 GPU 執行程序,首先你需要有 CUDA、cuDNN ,然后運行以下命令(確保將 jaxlib 版本映射到 CUDA 版本):

$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

現在將 JAX 與 Numpy 一起導入:

import jax
import jax.numpy as jnp
import numpy as np

JAX 的一些特性

使用 grad() 函數自動(dòng)微分:這對深度學(xué)習應用非常有用,這樣就可以很容易地運行反向傳播,下面為一個(gè)簡(jiǎn)單的二次函數并在點(diǎn) 1.0 上求導的示例:

from jax import grad
def f(x):
  return 3*x**2 + 2*x + 5
def f_prime(x):
  return 6*x +2
grad(f)(1.0)
# DeviceArray(8., dtype=float32)
f_prime(1.0)
# 8.0

jit(Just in time) :為了利用 XLA 的強大功能,必須將代碼編譯到 XLA 內核中。這就是 jit 發(fā)揮作用的地方。要使用 XLA 和 jit,用戶(hù)可以使用 jit() 函數或 @jit 注釋。

from jax import jit
x = np.random.rand(1000,1000)
y = jnp.array(x)
def f(x):
  for _ in range(10):
      x = 0.5*x + 0.1* jnp.sin(x)
  return x
g = jit(f)
%timeit -n 5 -r 5 f(y).block_until_ready()
# 5 loops, best of 5: 10.8 ms per loop
%timeit -n 5 -r 5 g(y).block_until_ready()
# 5 loops, best of 5: 341 μs per loop

pmap:自動(dòng)將計算分配到所有當前設備,并處理它們之間的所有通信。JAX 通過(guò) pmap 轉換支持大規模的數據并行,從而將單個(gè)處理器無(wú)法處理的大數據進(jìn)行處理。要檢查可用設備,可以運行 jax.devices():

from jax import pmap
def f(x):
  return jnp.sin(x) + x**2
f(np.arange(4))
#DeviceArray([0.       , 1.841471 , 4.9092975, 9.14112  ], dtype=float32)
pmap(f)(np.arange(4))
#ShardedDeviceArray([0.       , 1.841471 , 4.9092975, 9.14112  ], dtype=float32)

vmap:是一種函數轉換,JAX 通過(guò) vmap 變換提供了自動(dòng)矢量化算法,大大簡(jiǎn)化了這種類(lèi)型的計算,這使得研究人員在處理新算法時(shí)無(wú)需再去處理批量化的問(wèn)題。示例如下:

from jax import vmap
def f(x):
  return jnp.square(x)
f(jnp.arange(10))
#DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)
vmap(f)(jnp.arange(10))
#DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)

TensorFlow vs PyTorch vs Jax

在深度學(xué)習領(lǐng)域有幾家巨頭公司,他們所提出的框架被廣大研究者使用。比如谷歌的 TensorFlow、Facebook 的 PyTorch、微軟的 CNTK、亞馬遜 AWS 的 MXnet 等。

每種框架都有其優(yōu)缺點(diǎn),選擇的時(shí)候需要根據自身需求進(jìn)行選擇。

3.png

我們以 Python 中的 3 個(gè)主要深度學(xué)習框架——TensorFlow、PyTorch 和 Jax 為例進(jìn)行比較。這些框架雖然不同,但有兩個(gè)共同點(diǎn):

它們是開(kāi)源的。這意味著(zhù)如果庫中存在錯誤,使用者可以在 GitHub 中發(fā)布問(wèn)題(并修復),此外你也可以在庫中添加自己的功能;

由于全局解釋器鎖,Python 在內部運行緩慢。所以這些框架使用 C/C++ 作為后端來(lái)處理所有的計算和并行過(guò)程。

那么它們的不同體現在哪些方面呢?如下表所示,為 TensorFlow、PyTorch、JAX 三個(gè)框架的比較。

4.png

TensorFlow

TensorFlow 由谷歌開(kāi)發(fā),最初版本可追溯到 2015 年開(kāi)源的 TensorFlow0.1,之后發(fā)展穩定,擁有強大的用戶(hù)群體,成為最受歡迎的深度學(xué)習框架。但是用戶(hù)在使用時(shí),也暴露了 TensorFlow 缺點(diǎn),例如 API 穩定性不足、靜態(tài)計算圖編程復雜等缺陷。因此在 TensorFlow2.0 版本,谷歌將 Keras 納入進(jìn)來(lái),成為 tf.keras。

目前 TensorFlow 主要特點(diǎn)包括以下:

這是一個(gè)非常友好的框架,高級 API-Keras 的可用性使得模型層定義、損失函數和模型創(chuàng )建變得非常容易;

TensorFlow2.0 帶有 Eager Execution(動(dòng)態(tài)圖機制),這使得該庫更加用戶(hù)友好,并且是對以前版本的重大升級;

Keras 這種高級接口有一定的缺點(diǎn),由于 TensorFlow 抽象了許多底層機制(只是為了方便最終用戶(hù)),這讓研究人員在處理模型方面的自由度更??;

Tensorflow 提供了 TensorBoard,它實(shí)際上是 Tensorflow 可視化工具包。它允許研究者可視化損失函數、模型圖、模型分析等。

PyTorch

PyTorch(Python-Torch) 是來(lái)自 Facebook 的機器學(xué)習庫。用 TensorFlow 還是 PyTorch?在一年前,這個(gè)問(wèn)題毫無(wú)爭議,研究者大部分會(huì )選擇 TensorFlow。但現在的情況大不一樣了,使用 PyTorch 的研究者越來(lái)越多。PyTorch 的一些最重要的特性包括:

5.png

與 TensorFlow 不同,PyTorch 使用動(dòng)態(tài)類(lèi)型圖,這意味著(zhù)執行圖是在運行中創(chuàng )建的。它允許我們隨時(shí)修改和檢查圖的內部結構;

除了用戶(hù)友好的高級 API 之外,PyTorch 還包括精心構建的低級 API,允許對機器學(xué)習模型進(jìn)行越來(lái)越多的控制。我們可以在訓練期間對模型的前向和后向傳遞進(jìn)行檢查和修改輸出。這被證明對于梯度裁剪和神經(jīng)風(fēng)格遷移非常有效;

PyTorch 允許用戶(hù)擴展代碼,可以輕松添加新的損失函數和用戶(hù)定義的層。PyTorch 的 Autograd 模塊實(shí)現了深度學(xué)習算法中的反向傳播求導數,在 Tensor 類(lèi)上的所有操作, Autograd 都能自動(dòng)提供微分,簡(jiǎn)化了手動(dòng)計算導數的復雜過(guò)程;

PyTorch 對數據并行和 GPU 的使用具有廣泛的支持;

PyTorch 比 TensorFlow 更 Python 化。PyTorch 非常適合 Python 生態(tài)系統,它允許使用 Python 類(lèi)調試器工具來(lái)調試 PyTorch 代碼。

JAX 

JAX 是來(lái)自 Google 的一個(gè)相對較新的機器學(xué)習庫。它更像是一個(gè) autograd 庫,可以區分原生的 python 和 NumPy 代碼。JAX 的一些特性主要包括:

正如官方網(wǎng)站所描述的那樣,JAX 能夠執行 Python+NumPy 程序的可組合轉換:向量化、JIT 到 GPU/TPU 等等;

與 PyTorch 相比,JAX 最重要的方面是如何計算梯度。在 Torch 中,圖是在前向傳遞期間創(chuàng )建的,梯度在后向傳遞期間計算, 另一方面,在 JAX 中,計算表示為函數。在函數上使用 grad() 返回一個(gè)梯度函數,該函數直接計算給定輸入的函數梯度;

JAX 是一個(gè) autograd 工具,不建議單獨使用。有各種基于 JAX 的機器學(xué)習庫,其中值得注意的是 ObJax、Flax 和 Elegy。由于它們都使用相同的核心并且接口只是 JAX 庫的 wrapper,因此可以將它們放在同一個(gè) bracket 下;

Flax 最初是在 PyTorch 生態(tài)系統下開(kāi)發(fā)的,更注重使用的靈活性。另一方面,Elegy 受 Keras 啟發(fā)。ObJAX 主要是為以研究為導向的目的而設計的,它更注重簡(jiǎn)單性和可理解性。 

參考鏈接:

https://www.askpython.com/python-modules/tensorflow-vs-pytorch-vs-jax

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

https://www.zhihu.com/question/306496943/answer/557876584

*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。



關(guān)鍵詞: 機器學(xué)習

相關(guān)推薦

技術(shù)專(zhuān)區

關(guān)閉
国产精品自在自线亚洲|国产精品无圣光一区二区|国产日产欧洲无码视频|久久久一本精品99久久K精品66|欧美人与动牲交片免费播放
<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>