JAX/Flax로 딥러닝 레벨업
고급 모델링과 병렬 가속화로 무장한 차세대 딥러닝 라이브러리를 만나다
Regular price
$26.97
Sale price
Regular price
✈️
Estimated delivery date 예상 배송일
Standard Shipping
불러오는 중...
주문일로부터 8-12 영업일
Express Shipping
불러오는 중...
주문일로부터 6-8 영업일
LLM 시대를 선도하는 최적의 딥러닝 라이브러리 JAX/Flax
JAX(잭스)는 대규모 계산의 확장성을 염두에 두고 설계된 고성능 라이브러리로, LLM 시대 애물단지로 전락한 파이토치를 빠르게 대체하고 있다. 모두의연구소 JAX/Flax LAB이 집필한 이 책은 JAX, 그리고 JAX와 함께 쓰이는 Flax(플랙스)를 본격적으로 다루는 국내 최초의 책이다. JAX 기초와 함수형 프로그래밍, 병렬처리 등의 특장점을 살펴보고, JAX와 Flax를 조합해서 CNN, ResNet, DCGAN, CLIP 모델을 실제로 구현해본다. 새로운 시대, 새로운 딥러닝의 방식을 익혀보자.
JAX(잭스)는 대규모 계산의 확장성을 염두에 두고 설계된 고성능 라이브러리로, LLM 시대 애물단지로 전락한 파이토치를 빠르게 대체하고 있다. 모두의연구소 JAX/Flax LAB이 집필한 이 책은 JAX, 그리고 JAX와 함께 쓰이는 Flax(플랙스)를 본격적으로 다루는 국내 최초의 책이다. JAX 기초와 함수형 프로그래밍, 병렬처리 등의 특장점을 살펴보고, JAX와 Flax를 조합해서 CNN, ResNet, DCGAN, CLIP 모델을 실제로 구현해본다. 새로운 시대, 새로운 딥러닝의 방식을 익혀보자.
Couldn't load pickup availability
출판사 리뷰
출판사 리뷰
LLM 시대 모두가 기다려온 확장 가능한 고성능 딥러닝 라이브러리 JAX/Flax
파이토치는 쓸 만한 라이브러리였다. LLM 전성시대가 닥치기 전까지는. JAX는 파이토치가 부족했던 부분을 채우며 부상했다. 대규모 계산의 병렬처리 등 바로 이 시대에 무엇보다 우선시되는 '확장성'을 염두에 두고 설계되었기 때문이다.
이 책은 국내 JAX 전문가들이 모인 모두의연구소 JAX/Flax LAB이 집필한 최초의 JAX+Flax 도서다. JAX 기초부터 시작해 함수형 프로그래밍, JIT 컴파일, 병렬처리 등 JAX의 특장점을 살펴본 다음, 현재 JAX와 가장 많이 조합되는 신경망 라이브러리인 Flax를 설명한다. CNN, ResNet, DCGAN, CLIP, DistilGPT2 모델의 관련 이론을 간단히 살펴보며 JAX와 Flax를 조합해서 우아하게 구현하는 방법을 보여준다. 새 술은 새 부대에. 파이토치는 놓아주고 새로운 시대에 맞는 새로운 라이브러리를 익혀보자.
주요 내용함수형 프로그래밍, 파이썬 라이브러리 등 JAX 사용 시 알아야 할 기초JIT 컴파일, 자동 벡터화, pytree, 병렬처리 등 JAX의 주요 특징CNN 튜토리얼로 알아보는 Flax 기초ResNet, DCGAN, CLIP 모델을 구축하며 Flax에 익숙해지기코랩, 캐글에서 TPU 환경 설정하기
파이토치는 쓸 만한 라이브러리였다. LLM 전성시대가 닥치기 전까지는. JAX는 파이토치가 부족했던 부분을 채우며 부상했다. 대규모 계산의 병렬처리 등 바로 이 시대에 무엇보다 우선시되는 '확장성'을 염두에 두고 설계되었기 때문이다.
이 책은 국내 JAX 전문가들이 모인 모두의연구소 JAX/Flax LAB이 집필한 최초의 JAX+Flax 도서다. JAX 기초부터 시작해 함수형 프로그래밍, JIT 컴파일, 병렬처리 등 JAX의 특장점을 살펴본 다음, 현재 JAX와 가장 많이 조합되는 신경망 라이브러리인 Flax를 설명한다. CNN, ResNet, DCGAN, CLIP, DistilGPT2 모델의 관련 이론을 간단히 살펴보며 JAX와 Flax를 조합해서 우아하게 구현하는 방법을 보여준다. 새 술은 새 부대에. 파이토치는 놓아주고 새로운 시대에 맞는 새로운 라이브러리를 익혀보자.
주요 내용함수형 프로그래밍, 파이썬 라이브러리 등 JAX 사용 시 알아야 할 기초JIT 컴파일, 자동 벡터화, pytree, 병렬처리 등 JAX의 주요 특징CNN 튜토리얼로 알아보는 Flax 기초ResNet, DCGAN, CLIP 모델을 구축하며 Flax에 익숙해지기코랩, 캐글에서 TPU 환경 설정하기
목차
목차
베타리더 후기 viii
지은이 소개 x
JAX/Flax LAB 소개 xii
들어가며 xiii
이 책에 대하여 xiv
CHAPTER 1 JAX/Flax를 공부하기 전에 1
1.1 JAX/Flax에 대한 소개와 예시 1
__1.1.1 JAX란 1
__1.1.2 Flax란 2
__1.1.3 JAX로 이루어진 기타 프레임워크들 3
__1.1.4 JAX 프레임워크 사용 예시 3
1.2 함수형 프로그래밍에 대한 이해 5
__1.2.1 부수 효과와 순수 함수 5
__1.2.2 불변성과 순수 함수 7
__1.2.3 정리하며 8
1.3 JAX/Flax에서 자주 사용하는 파이썬 표준 라이브러리 9
__1.3.1 functools.partial() 10
__1.3.2 typing 모듈 12
__1.3.3 정리하며 13
1.4 JAX/Flax 설치 방법 14
__1.4.1 로컬에 JAX/Flax 설치하기 14
__1.4.2 코랩에서 TPU 사용하기 14
CHAPTER 2 JAX의 특징 17
2.1 NumPy에서부터 JAX 시작하기 18
__2.1.1 JAX와 NumPy 비교하기 18
__2.1.2 JAX에서 미분 계산하기 19
__2.1.3 손실 함수의 그레이디언트 계산하기 21
__2.1.4 손실 함수의 중간 과정 확인하기 22
__2.1.5 JAX의 함수형 언어적 특징 이해하기 23
__2.1.6 JAX로 간단한 학습 돌려보기 25
2.2 JAX의 JIT 컴파일 28
__2.2.1 JAX 변환 이해하기 29
__2.2.2 함수를 JIT 컴파일하기 32
__2.2.3 JIT 컴파일이 안 되는 경우 34
__2.2.4 JIT 컴파일과 캐싱 37
2.3 자동 벡터화 39
__2.3.1 수동으로 벡터화하기 39
__2.3.2 자동으로 벡터화하기 41
2.4 자동 미분 42
__2.4.1 고차 도함수 43
__2.4.2 그레이디언트 중지 46
__2.4.3 샘플당 그레이디언트 49
2.5 JAX의 난수 52
__2.5.1 NumPy의 난수 52
__2.5.2 JAX의 난수 56
2.6 pytree 사용하기 59
__2.6.1 pytree의 정의 60
__2.6.2 pytree 함수 사용법 61
2.7 JAX에서의 병렬처리 65
2.8 상태를 유지하는 연산 69
__2.8.1 상태에 대한 이해 69
__2.8.2 모델에 적용하기 72
CHAPTER 3 Flax 소개 77
3.1 Flax CNN 튜토리얼 79
__3.1.1 패키지 로드하기 79
__3.1.2 데이터 로드하기 80
__3.1.3 모델 정의와 초기화 81
__3.1.4 메트릭 정의하기 84
__3.1.5 TrainState 초기화 84
__3.1.6 훈련 스텝과 평가 스텝 정의하기 85
__3.1.7 모델 학습하기 87
__3.1.8 모델 추론하기 89
3.2 심화 튜토리얼 90
__3.2.1 배치 정규화 적용 91
__3.2.2 드롭아웃 적용 95
__3.2.3 학습률 스케줄링 98
__3.2.4 체크포인트 관리 103
CHAPTER 4 JAX/Flax를 활용한 딥러닝 모델 만들기 105
4.1 순수 JAX로 구현하는 CNN 106
__4.1.1 패키지 로드하기 107
__4.1.2 데이터 로드하기 108
__4.1.3 레이어 구현 108
__4.1.4 네트워크 정의하기 115
__4.1.5 학습 및 평가 준비 116
__4.1.6 학습 및 평가 118
__4.1.7 추론 121
4.2 ResNet 122
__4.2.1 패키지 로드하기 123
__4.2.2 데이터 로드하기 123
__4.2.3 모델 정의 및 초기화 124
__4.2.4 메트릭 정의하기 129
__4.2.5 TrainState 초기화 129
__4.2.6 훈련 스텝과 평가 스텝 정의하기 131
__4.2.7 모델 학습하기 132
__4.2.8 결과 시각화하기 135
4.3 DCGAN 136
__4.3.1 패키지 로드하기 136
__4.3.2 데이터 로드하기 137
__4.3.3 모델 정의 및 초기화 138
__4.3.4 학습 방법 정의하기 140
__4.3.5 TrainState 초기화 143
__4.3.6 모델 학습하기 145
__4.3.7 결과 시각화하기 146
4.4 CLIP 148
__4.4.1 CIFAR10 데이터셋으로 CLIP 미세조정 진행하기 150
__4.4.2 JAX로 만들어진 데이터셋 구축 클래스 150
__4.4.3 이미지 데이터 구축 함수 뜯어보기 151
__4.4.4 CLIP 모델 불러오기 154
__4.4.5 CLIP에 사용하기 위한 전처리 및 미세조정 155
__4.4.6 모델 학습에 필요한 함수 정의하기 156
__4.4.7 하이퍼파라미터 설정과 TrainState 구축하기 160
__4.4.8 모델 저장하고 체크포인트 만들기 161
__4.4.9 요약 클래스 만들기 162
__4.4.10 학습에 필요한 스텝 정의와 랜덤 인수 복제 163
__4.4.11 모델 학습하기와 모델 저장하기 163
4.5 DistilGPT2 미세조정 학습 166
__4.5.1 패키지 설치 167
__4.5.2 환경 설정 168
__4.5.3 토크나이저 학습 169
__4.5.4 데이터셋 전처리 171
__4.5.5 학습 및 평가 173
__4.5.6 추론 178
CHAPTER 5 TPU 환경 설정 181
5.1 코랩에서 TPU 설정하기 181
5.2 캐글에서 TPU 세팅하기 182
5.3 TRC 프로그램 신청하기 183
마무리하며 186
찾아보기 188
지은이 소개 x
JAX/Flax LAB 소개 xii
들어가며 xiii
이 책에 대하여 xiv
CHAPTER 1 JAX/Flax를 공부하기 전에 1
1.1 JAX/Flax에 대한 소개와 예시 1
__1.1.1 JAX란 1
__1.1.2 Flax란 2
__1.1.3 JAX로 이루어진 기타 프레임워크들 3
__1.1.4 JAX 프레임워크 사용 예시 3
1.2 함수형 프로그래밍에 대한 이해 5
__1.2.1 부수 효과와 순수 함수 5
__1.2.2 불변성과 순수 함수 7
__1.2.3 정리하며 8
1.3 JAX/Flax에서 자주 사용하는 파이썬 표준 라이브러리 9
__1.3.1 functools.partial() 10
__1.3.2 typing 모듈 12
__1.3.3 정리하며 13
1.4 JAX/Flax 설치 방법 14
__1.4.1 로컬에 JAX/Flax 설치하기 14
__1.4.2 코랩에서 TPU 사용하기 14
CHAPTER 2 JAX의 특징 17
2.1 NumPy에서부터 JAX 시작하기 18
__2.1.1 JAX와 NumPy 비교하기 18
__2.1.2 JAX에서 미분 계산하기 19
__2.1.3 손실 함수의 그레이디언트 계산하기 21
__2.1.4 손실 함수의 중간 과정 확인하기 22
__2.1.5 JAX의 함수형 언어적 특징 이해하기 23
__2.1.6 JAX로 간단한 학습 돌려보기 25
2.2 JAX의 JIT 컴파일 28
__2.2.1 JAX 변환 이해하기 29
__2.2.2 함수를 JIT 컴파일하기 32
__2.2.3 JIT 컴파일이 안 되는 경우 34
__2.2.4 JIT 컴파일과 캐싱 37
2.3 자동 벡터화 39
__2.3.1 수동으로 벡터화하기 39
__2.3.2 자동으로 벡터화하기 41
2.4 자동 미분 42
__2.4.1 고차 도함수 43
__2.4.2 그레이디언트 중지 46
__2.4.3 샘플당 그레이디언트 49
2.5 JAX의 난수 52
__2.5.1 NumPy의 난수 52
__2.5.2 JAX의 난수 56
2.6 pytree 사용하기 59
__2.6.1 pytree의 정의 60
__2.6.2 pytree 함수 사용법 61
2.7 JAX에서의 병렬처리 65
2.8 상태를 유지하는 연산 69
__2.8.1 상태에 대한 이해 69
__2.8.2 모델에 적용하기 72
CHAPTER 3 Flax 소개 77
3.1 Flax CNN 튜토리얼 79
__3.1.1 패키지 로드하기 79
__3.1.2 데이터 로드하기 80
__3.1.3 모델 정의와 초기화 81
__3.1.4 메트릭 정의하기 84
__3.1.5 TrainState 초기화 84
__3.1.6 훈련 스텝과 평가 스텝 정의하기 85
__3.1.7 모델 학습하기 87
__3.1.8 모델 추론하기 89
3.2 심화 튜토리얼 90
__3.2.1 배치 정규화 적용 91
__3.2.2 드롭아웃 적용 95
__3.2.3 학습률 스케줄링 98
__3.2.4 체크포인트 관리 103
CHAPTER 4 JAX/Flax를 활용한 딥러닝 모델 만들기 105
4.1 순수 JAX로 구현하는 CNN 106
__4.1.1 패키지 로드하기 107
__4.1.2 데이터 로드하기 108
__4.1.3 레이어 구현 108
__4.1.4 네트워크 정의하기 115
__4.1.5 학습 및 평가 준비 116
__4.1.6 학습 및 평가 118
__4.1.7 추론 121
4.2 ResNet 122
__4.2.1 패키지 로드하기 123
__4.2.2 데이터 로드하기 123
__4.2.3 모델 정의 및 초기화 124
__4.2.4 메트릭 정의하기 129
__4.2.5 TrainState 초기화 129
__4.2.6 훈련 스텝과 평가 스텝 정의하기 131
__4.2.7 모델 학습하기 132
__4.2.8 결과 시각화하기 135
4.3 DCGAN 136
__4.3.1 패키지 로드하기 136
__4.3.2 데이터 로드하기 137
__4.3.3 모델 정의 및 초기화 138
__4.3.4 학습 방법 정의하기 140
__4.3.5 TrainState 초기화 143
__4.3.6 모델 학습하기 145
__4.3.7 결과 시각화하기 146
4.4 CLIP 148
__4.4.1 CIFAR10 데이터셋으로 CLIP 미세조정 진행하기 150
__4.4.2 JAX로 만들어진 데이터셋 구축 클래스 150
__4.4.3 이미지 데이터 구축 함수 뜯어보기 151
__4.4.4 CLIP 모델 불러오기 154
__4.4.5 CLIP에 사용하기 위한 전처리 및 미세조정 155
__4.4.6 모델 학습에 필요한 함수 정의하기 156
__4.4.7 하이퍼파라미터 설정과 TrainState 구축하기 160
__4.4.8 모델 저장하고 체크포인트 만들기 161
__4.4.9 요약 클래스 만들기 162
__4.4.10 학습에 필요한 스텝 정의와 랜덤 인수 복제 163
__4.4.11 모델 학습하기와 모델 저장하기 163
4.5 DistilGPT2 미세조정 학습 166
__4.5.1 패키지 설치 167
__4.5.2 환경 설정 168
__4.5.3 토크나이저 학습 169
__4.5.4 데이터셋 전처리 171
__4.5.5 학습 및 평가 173
__4.5.6 추론 178
CHAPTER 5 TPU 환경 설정 181
5.1 코랩에서 TPU 설정하기 181
5.2 캐글에서 TPU 세팅하기 182
5.3 TRC 프로그램 신청하기 183
마무리하며 186
찾아보기 188
저자
저자
이영빈
모두의연구소에서 AI 교육을 진행하고 있으며 JAX/Flax LAB짱을 맡고 있다.
Payment & Security
Payment methods
Your payment information is processed securely. We do not store credit card details nor have access to your credit card information.
$99 이상 무료 배송
3% 리워드 크레딧 적립
Secure Payment

