[논문 리뷰] Efficient Diffusion Training via Min-SNR Weighting Strategy
Min-SNR-감마 : 논문 리뷰
Oct 09, 2023
Efficient Diffusion Training via Min-SNR Weighting Strategy
선결론 : 실제로 사용하기에도 너무 좋다. LoRA 학습에서 뛰어난 속도 향상의 결과를 보임.
- 디퓨전 모델 학습은 느린 수렴속도로 인해 문제를 겪는다.
- 우리는 본 연구를 통해서 느린 수렴속도가 부분적으로는 timestep간의 최적화 방향이 달라서 일어나는 충돌이 원인이라는 사실을 알아냈다.
- 이 문제를 해결하기 위해서 우리는 디퓨전 학습을 멀티태스크 학습으로 접근한다. 우리 방법을 Min-SNR- 라고 부른다.
이 방법은 clamped signal-to-noise ratios에 근거해서 타임스텝들의 loss weights를 조절한다.
clamped signal-to-noise ratios는 효과적으로 타임스텝간의 충돌을 밸런싱한다.
이 방법으로 3.4배 빨리 학습하고, (이미지넷 256*256 기준) 이전 SOTA보다 더 작은 모델을 써서 FID score 2.06으로 최고를 찍었다.
느린 수렴속도는 학습에서 각 여러 타임스텝간의 최적화 방향이 달라서 일어나는 문제다.
특정 noise level에만 맞게 denoising function을 최적화하면 다른 노이즈 레벨에서는 부작용을 끼칠 수 있다.
현재 DDPM이 다양한 Nosie 레벨에 대해 모델 가중치를 공유한다는 점을 감안했을 때 이런 noise timestep간의 균형을 신중하게 고려하지 않으면 충돌하는 가중치 기울기가 전체 수렴 속도를 방해할 것이다.
본 논문에서 제공하는 Min-SNR- 이라는 방법은 각 timestep의 디노이징 프로세스를 개별 task로 취급하므로 diffusion 학습이 multi-task 학습으로 간주된다. 다양한 task의 균형을 맞추기 위해서, 각 task의 난이도에 따라 다른 loss 가중치를 적용한다.
특히, clamped signal-to-noise ratio (SNR)를 loss 가중치로 적용해서 gradient 충돌 문제를 완화했다.
여러 타임스텝을 이 새로운 가중치 전략으로 조절해서, 수렴이 더 빨라졌다.
What is SNR?
SNR : 높을수록 고퀄리티 이미지인 값이다.
노란색이 중요한 시그널, 빨간색이 노이즈
내 신호가 무엇인지 파악하기 힘들게 만든다.
거의 대부분의 측정에서 일어나는 현상.
디퓨전 프로세서에서 이걸 생각해보자.
Reverse Diffusion process에서는 완전한 가우시안 노이즈 → 이미지로 간다.
그럼 여러 timestep 중
- timestep → T에서는 노이즈가 크다 = SNR이 작다.
- timestep → 0에서는 노이즈가 작다 = SNR이 크다.
이걸 가중치로 사용한다는 뜻은?
일반적인 multi-task 학습 방법은 기울기에 따라 각 task의 loss 가중치를 조정해서 task간의 충돌을 완화하려고 한다.
One classical 접근법인 Pareto 최적화는 모든 태스크를 향상시키는 방향으로 기울기를 감소시키는 것을 목표로 한다.
하지만 우리의 접근법은 이와는 3가지 측면에서 다르다.
Sparsity
. 일반적인 multi-task 학습 연구에서는 적은 수의 task 시나리오에 중점을 두었다. 그러나 diffusion의 timestep을 task로 여기면, 수천개의 태스크가 있다.
따라서 Pareto 최적해를 diffusion에 적용하면 대부분의 타임스텝에서의 가중치가 0이 된다. 이 방법으로는 많은 timestep에서 학습이 일어나지 않고, 전체 디노이징 프로세스에 해를 끼친다.
Instability
. 각 timestep에 대해 계산된 기울기는 타임스텝마다의 샘플 수가 적기때문에 noisy한 경우가 많다. 이는 Pareto 최적해의 정확한 계산을 어렵게한다.
Inefficiency
. Pareto 최적 해의 계산은 시간이 많이 걸리므로 전체 학습 속도가 상당히 느려진다.
Pareto는 런타임에 각 이터레이션마다 알맞은 loss 가중치를 설정한다.
우리의 방법은 사전에 global하게 step별로 적용될 loss를 세팅해두고 진행하는 방식이다. 따라서 sparisty문제가 해결된다.
그리고 보다 효율적이고 안정적이다.
그리고 이런 점도 좋을 수 있다. 1) 각 디노이징 태스크의 최적화 dynamics는 각 개별 샘플에 영향을 크게 받지 않고 주로 태스크의 노이즈 레벨에 의해 형성된다. 2) 적당한 수의 iteration 후에 대부분의 후속 학습 프로세스의 기울기가 더 안정적이게 되기 때문에 고정 가중치를 사용하는 방법으로 근사화할 수 있다.
Methods
다른 step은 각기 다른 가중치가 적용되어야한다.
예를 들어,
- 쉬운 디노이징 태스크(t → 0)은 더 낮은 디노이징 loss를 달성하기 위해서 입력을 간단하게 재구성할 필요가 있을 수 있다.
- 어려운 디노이징 태스크(t → T)에서는 불행하게도 위의 방법이 적용되지 않는다.
- 따라서 서로 다른 timestep간의 상관관계를 분석하는 것이 매우 중요하다.
이를 위해서 저자들은 간단한 실험을 수행했다.
전체 denoising process를 몇개의 분리된 bin으로 클러스터링 했다. 그리고 디퓨전 모델을 각 bin에 맞게 파인튜닝 했다.
마지막으로, 그 효율을 다른 bin에 어떤 영향을 끼쳤는지 위주로 확인했다.
위의 이미지에서 볼 수 있듯이 최적화한 그 스텝들에서는 좋은 영향을 끼쳤다. 하지만 다른 스텝에 대해서는 안좋은 영향을 끼쳤다.
이 결과가 우리가 모든 타입스텝에 동시에 긍정적으로 작용하는 효과적인 솔루션을 찾을 수 있을지에 대한 영감을 줬다.
저자는 멀티 태스크 학습의 관점에서 목적을 재구성했다.
디노이징 디퓨전 모델의 학습 프로세스는 T개의 각기다른 task를 포함한다.(라고 해석할 수 있다는 뜻.)
모델 파라미터 일 때, loss는 가 된다. t = (1,2, …, T)
우리의 목표는 아래 식을 만족하는 업데이트 방향을 찾아내는 것이다.
: 파라미터가 delta 방향으로 갔을 때, Loss가 감소한다.
1차 Taylor expansion을 적용하면 이렇게 된다.
따라서 이상적인 업데이트 방향은 다음과 같다.
Min-SNR- Loss Weight Strategy
각 이터레이션에서 iterative optimization에 의해 발생하는 비효율성과 불안정성을 피하기 위해서, 한가지 가능한 시도는 일관된 loss 가중치 전략을 사용하는 것이다.
discussion을 간단하게 하기 위해서 noise가 없는 상태 을 예측하기위해 네트워크가 re-parameterize 되었다고 가정한다.
하지만, 서로다른 prediction 목적함수들이 서로로 변환될 수 있다는 점은 주목할 가치가 있다.
이제, 저자들은 아래의 대체 training loss 가중치를 고려한다.
- Constant weighting :
- SNR weighting : , SNR(t) =
이건 가장 널리 사용되는 가중치 전략이다. 예측 대상이 완전한 노이즈일때는 위의 constant와 같은 수치이다.
- Max-SNR- weighting :
감마는 1을 디폴트로 둔다. 하지만, 여전히 가중치는 작은 noise level에 집중한다.
- Min-SNR- weighting :
저자는 작은 noise level에 너무 집중하는 것을 방지하기 위해서 이 방법을 사용한다.
- UGD optimization weighting : w_t는 각 타임스텝마다 최적화된다. 위와는 다르게 학습과정에서 가중치를 매번 계산해서 다르게 적용한다.
결과는 이렇다고 한다. 근데 UGD는 시간이 오래 걸리겠지?
구현을 확인해보자
def compute_snr(noise_scheduler, timesteps): alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 # Expand the tensors. # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] alpha = sqrt_alphas_cumprod.expand(timesteps.shape) sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) # Compute SNR. snr = (alpha / sigma) ** 2 return snr
snr = compute_snr(noise_scheduler, timesteps) base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr )
추가 생각
실제로 적용했을 때 3배의 속도 상승이 있었고, 이론과 개념에 대해서도 이해가 되었다. 근데 어떤 timestep에 가중치를 높게 둘지를 모델이 알아서 학습하고 결정하도록 해야하는건 아닌가? 머신러닝이라면. UGD optimization weighting이 이미지 자체를 입력으로 받아서 계산하기 때문에 오래걸림?
Share article
Subscribe to our newsletter