VAE 논문 개인적인 리뷰
Sep 24, 2024
VAE가 무엇이고, 무슨 역할을 하는지는 알고 있다. 궁금한건 디테일과, 이 때 당시에 이 연구 가설을 설정하게된 배경.
Abstract
intractable한 사후 확률 분포를 가지는 연속적인 잠재변수의 directed probabilistic model에서 효율적인 추론과 학습을 어떻게 할 수 있을까?(그리고 대규모 데이터셋일 때)
우리는 확률적 variational inference와 학습 알고리즘을 소개한다.
근사를 최적화한다.
We show how a reparameterization of the variational lower bound yields a simple differentiable unbiased estimator of the lower bound; this SGVB (Stochastic Gradient Variational Bayes) estimator can be used for efficient approximate posterior inference in almost any model with continuous latent variables and/or
parameters, and is straightforward to optimize using standard stochastic gradient ascent techniques.
Lower bound estimator(stochastic objective function)
variational inference : 복잡한 확률 분포를 근사하여 추론하는 방법 - 사후 확률 분포를 정확하게 계산하기 어려울 때, 이를 더 간단한 분포로 근사하는 것.
Problem scenario
- 연속적이거나 discrete한 독립변수 x에서 샘플링한 데이터셋 X
- 데이터는 관찰되지 않은(뭔지 모르는) 랜덤 변수 z에서 생성된다고 가정
이 때 x가 생성되는 과정은
- value z가 어떤 prior distribution 에서 생성된다(샘플링된다).
- value x가 조건부 분포 에 의해 생성된다.
실제로는 무언가 분포가 있겠지만 우리는 알지 못하고, 파라미터를 사용해서 모델링한 세타와 z에 대해서 미분 가능한 분포 2개라고 가정했다.
하지만 실제 분포를 모델링하기위한 파라미터 세타와 latent variable z를 우리가 모르기 때문에 어렵다.
가장 중요한데 - marginal or 사후 확률에 일반적인 단순화한 가정을 사용하지 않는다. 거꾸로 아래의 상황에서도 동작하는 제너럴한 알고리즘에 관심이 있다.
- Intractability
marginal likelihood 의 적분이 난해하여 평가와 미분이 어려운 경우
실제 사후 밀도 가 난해하여 EM 알고리즘이 사용될 수 없는 경우
합리적인 Mean-field 변분 베이즈(VB) 알고리즘을 위해 필요한 적분들이 모두 난해한 경우. 이런 계산 불가능성(난해함)은 꽤 흔하고, 뉴럴넷의 비선형 히든 레이어같은 복잡한 Likelihood function 에서 나타난다.
- A large dataset
너무 많은 데이터가 있어 배치 최적화가 지나치게 비용이 많이 드는 경우, 우리는 작은 미니배치나 심지어 개별 데이터 포인트를 사용해 파라미터 업데이트를 하고 싶다.
샘플링 기반 솔루션, 예를 들어 몬테카를로 EM 같은 방법은 데이터 포인트마다 비용이 많이 드는 샘플링 루프를 포함하기 때문에 일반적으로 너무 느리다.
위의 문제들을 해결하기 위해 recognition model 를 도입한다 : intractable한 실제의 사후 분포 의 근사
여기서 recognition model의 파라미터 ϕ를, generative model 파라미터 θ와 함께 학습하는 방법을 제안할 것입니다
→ 코딩의 관점에서는 관측되지 않은 변수 z는 latent representation으로 해석될 수 있다. 따라서 본 논문에서 recogniion model 을 probabilistic
encoder
라고도 부를거다! 왜냐하면 데이터포인트 x가 주어졌을 때 x를 생성했을 수 있는 코드 z의 가능한 분포를 생성해내는 모델이기 때문이다.비슷한 이유로 를 확률적 decoder라고 부를거다. z가 주어졌을 때 그에 상응하는 x의 분포를 생성하기 때문이다.
여기서 중요한 것(나에게)
현실에서 관측한 data point들로 현실의 데이터의 확률 분포를 알아내는 것이 확률모델의 목표이다. 따라서 목적은 자연스럽게 를 최대화 하는 것이다.
ϕ와 θ의 KL Divergence를 계산하는 식으로 부터 위의 값을 전개할 수 있다. 그럼 아래처럼 된다
여기서 L은
이다. KL divergence 값은 양수이기 때문에 L이 (variational) lower bound라고 불린다. 베이즈 정리에서 x가 evidence에 해당하기 때문에 ELBO라는 이름이 됨. 최소한 이 값보다는 크다는 뜻! 전개하면
로 적을 수도 있다.
- 두번째 텀은 reconstruction loss가 된다
q로 샘플링한 z에 대해서 조건부 확률 분포를 계산해서 x. 이 값이 클수록 z로부터 입력 데이터 x를 잘 모델링했다고 볼 수 있다!
근데 이렇게만 했을 때는 z에 최대한 정보가 많이 담기게 복잡하게 해서 이걸 달성할 수도 있다. 그래서 아래가 그걸 방지하는 역할을 함
- 첫번째 텀은 Regularization
data x를 기반으로 z 분포를 모델링(사후 분포를 근사하여 표현)하는데 그게 prior z 분포와 비슷하도록 파라미터 를 정규화한다.
우리는 lower bound L을 variational parameter와 generative 파라미터 둘 모두에 대해 미분하고 최적화하고 싶다. → 하지만 인코더의 파라미터에 대한 lower bound의 기울기를 계산하는건 문제가 있다.
reconstruction error를 보면 기댓값계산인데 z에 대해 적분해서 얻을 수 있다. 하지만 그건 너무 어려우니 무한번 샘플링한 값들로 기댓값을 계산하는 Monte Carlo 방법을 사용해서 계산한다.
Monte Carlo gradient estimator는 복잡한 확률분포에서 직접 기댓값을 계산하는 대신 여러 번의 샘플링을 통해 얻어낸 샘플의 평균을 사용해서 기댓값을 계산하는 방법.
근데 그럼 무한개의 샘플링? 시간이 너무 오래 걸린다. 근데 샘플링 수가 1이어도 괜찮다고 한다.
In our experiments we found that the number of samples L
per datapoint can be set to 1 as long as the minibatch size M was large enough, e.g. M = 100.
- 논문을 보다보면 reparameterization하기 전에 z를 샘플링해서 얻는 방법으로는 z의 분산이 매우 커져서 impractical하다고 말한다.
하지만 rep~를 적용하고 나면 무작위성을 반영하는 요소()가 가우시안 분포에서의 샘플링 값이기 때문에 각 샘플이 잠재공간에서 충분히 유의미한 값일 확률이 높다. 그래서 샘플 수가 적어도 학습을 할 수 있다?
- 각 데이터는 하나의 샘플만으로 재구성한 값으로 학습하지만 한번 학습할 때 minibatch단위로 하기 때문에 = 여러 데이터로 학습한걸 반영하기 때문에 모델이 안정적으로 수렴할 수 있다. 미니배치가 크긴 해야한다.
이걸 사용해서 ϕ의 기울기를 계산해야하는데 샘플링이라는게, 기울기가 역전파될 수 없는 과정이다.
The reparameterization trick
에서 샘플을 생성하는 문제를 해결하기 위해 reparameterization trick을 사용한다.
이 독립적인 분포 p()의 변수고, g가 벡터값 함수일 때, 랜덤 변수 z를 deterministic variable z = 로 표현하는게 가능한 경우가 있다. 이 reparameterization은 몬테 카를로 기댓값 추정을 해서 의 기댓값을 얻어낸 뒤에 그걸 phi에 대해 미분가능하게 해주기 때문에 중요하다.
다시 말해
어떤 함수 g는 x와 e를 입력으로 받아서 z를 출력하는 결정론적 함수다(= 출력과정에 랜덤성이 없다) = 미분이 가능하다 = 기울기를 역전파해서 학습할 수 있다.이렇게 하기위해서, 입력으로 사용하는 값 e에 확률론적 과정을 부여한다(= 랜덤성을 준다). 이 방법으로 분포에서 z를 샘플링하는 process의 의미를 유지하면서도 기울기 에 대해 미분이 가능하게 된다.
그럼 z가 가우시안 분포에서 샘플링되는 랜덤변수로 예를 들어보면
라고 얻을 수 있음.
모든 상황에서 이 재매개변수화 트릭이 가능한건 아니고 아래 조건? 접근법을 사용해서 계산해봐야한다.
The SGVB estimator and AEVB algorithm
우리는 의 형태로 사후 분포를 근사했다고 가정한다.
전체 데이터셋 X, 미니 배치의 수 M개, 각 미니 배치의 데이터 L개.
VAE 논문이지만 AE에 관한 이야기는 직접적으로 하지 않으면서 시작했는데, 7번 식을 보면 auto-encoder와의 연관성이 보인다.
datapoint x와 랜덤 노이즈 e를 사용해서 x를 사후분포로 매핑한 샘플 z를 만든다. 그리고 잠재 변수 z가 주어졌을 때 데이터 x가 발생할 확률을 최대화해서 학습 데이터 분포를 모델링한다.
최종 정리
- VAE의 최종 목적은 실제 데이터의 분포를 잘 모델링하는 것이다. 일단 확률 모델은 파라미터를 통해서 모델링한 분포의 실제데이터들이 나올 확률 총합을 최대화 하는걸 목적으로 시작한다 = x의 likelihood를 최대화하는 분포를 찾는 것
- = 로 표현된다.
- 문제는 우리가 아는건 주어진 데이터 x뿐이고, 잠재변수 z에 대해서는 아는게 없다. 따라서 모든 z와 그 때의 p(x|z)를 계산하는건 불가능하다.
- 근데 x는 알고 있으니까, 그럼 x에 대한 latent z의 사후분포 p(z|x)를 추정해보자.
잠재 변수 z의 prior 분포와 이 때의 likelihood로 표현됨.
그런데 p(z|x)도 계산하기가 매우 어렵다. 베이즈 정리로 풀었을 때 p(x)를 알아야 계산할 수 있기 때문에.
- 이걸 위해서 실제 사후 분포 를 근사하는 recognition model = Encoder를 도입한다. 그리고 와 를 jointly하게 학습한다.
- 그리고 다시 = 로 돌아와서, 우변을 기댓값 표현으로 바꿔준다.
여기서 변분 추론의 트릭을 적용해서 p(z|x)를 근사한 encoder q를 도입하여 사용했다. 그리고 베이즈 정리에 의해 식을 전개한다.
근데 여기서 제일 오른쪽 항을 우리는 계산할 수 없다. p(z|x)의 정답 값을 알 수 없기 때문에. 대신, KL Divergence이기 때문에 항상 0보다 크다는건 알고있다.
- 따라서 왼쪽 두 항을 계산한 최댓값이 이 식 전체의 하한(lower bound)이다. 따라서 그 최댓값에만 집중하고자, 왼쪽 두 항만 따로 이 하한을 최대화하도록 한다.
이는 베이즈 정리에서 Evidence에 해당되는 x에 대한 확률(marginal likelihood)을 구하는 것이기 때문에 Evidence LowerBOund 라고 해서 ELBO라고 부르고,
로 정리한다.
- 근데 여기서 계산을 위해 근사한 사후분포 q에서 z를 샘플링해내서 사용하게 되는데, 이게 확률론적 샘플링이기 때문에 q의 파라미터 로 gradient를 흘려주는게 불가능하다.
- 따라서 이 z 샘플링을, deterministic한 방법(= 미분가능하여 기울기를 흘려줄 수 있는)으로 바꾸어야 한다(= Reparameterization trick)
Share article
Subscribe to our newsletter