Autoregressive Image Generation without Vector Quantization
Autoregressive Image Generation without Vector Quantization 논문 리뷰
Aug 13, 2024
읽은 이유 : https://arxiv.org/pdf/2406.11838
diffusion으로 patch by patch generation하는 프로젝트를 해보고 싶다고 생각하던 와중 발견
요약
이미지를 auto-regressive하게 생성하는데 VQ를 써서 discrete token을 배운다음 categorical 분포 예측으로 하는게 아니라, 바로 continuous vector를 예측하게하고 diffusion loss를 사용해서 학습한다. google scholar타고 들어가다 정말 우연히 발견한건데 제시하는 개념과 성능이 너무 좋아서 놀랐다.
language가 아닌 이미지를 왜 굳이 discrete하게 바꾼다음 학습/생성해야해?라는 간단한 질문으로 시작한 연구이고 bidirectional attention, set-of-tokens prediction 같은 디테일한 개념도 잘 제시했다.
논문 제목에서 알 수 있듯이, natural language processing의 경우 언어라는 데이터의 자연적 특성상 discrete한 value로 표현하는게 맞는데, visual generation도 auto-regressive로 하고싶다고 해서 LM처럼 vector quantization해서 discrete token 만들어서 학습해야하나?라는 의문에서 시작한 연구이다.
auto-regressive하게 generation하는게 꼭 categorical distribution을 예측하는 태스크가 될 필요없다는게 주장인데 사실 지극히 맞는 말이다.
AR의 속성은 “predicting next tokens based on previous ones”일 뿐이고 그게 이산값인지 연속값인지 중요하진 않다.
결론적으로는 모든 토큰(토큰보다는 패치의 개념으로 이해해도 될 것 같다.)을 continuous한 vector로 다루고, 앞의 토큰에 기반해서 뒤의 토큰도 continuous한 값을 예측한다. 기존에는 Vector quantized tokenizer를 사용하는 것에 비해.
기존의 discrete-valued token 방식들을 보자
기존의 discrete한 접근들에서는 다음 위치로 예측할 토큰은 0 ~ K(vocab size) 사이의 정수값 x로 표현될 수 있다.
AR model은 D dimenstion의 벡터를 출력값으로 뱉고, 이게 K-way classifier를 타서 K차원의 categorical 확률 분포(0~K까지의 값 각각의 확률 분포)가 된다.
generative modeling의 관점에서 이 확률 분포(출력값)는 두가지 중요한 속성을 가져야한다.
- GT 분포와 예측 분포 사이의 차이를 측정하는 loss function이 있어야 한다. categorical 분포에서는 단순하게 cross-entropy loss를 사용한다.
- inference 타임에 데이터 분포에서 샘플링을 할 수 있는 샘플러가 있어야 한다. categorical 분포에서는 softmax를 적용해서 얻는다.
→ 다시 말해 이거 두개만 만족하면 discrete 아니어도 된다. 물론 이산값이 loss func와 sampler가 매우매우 단순하다는 장점이 있긴함.
본 연구에서는 Diffusion loss를 사용한다.
다음 위치에 예측해야한 GT token이 x라고 치자. 여기서 AR model은 앞의 토큰들을 입력으로 받아 z라는 벡터를 뱉는다. 그리고 이 z를 diffusion process의 condition으로 이용해서 noise를 예측할 distribution(x)으로 denoising 한다.
그러면 여기서 loss는 기존의 Diffusion model에서의 noise prediction loss가 된다.
대신 condition으로 쓰이는 z 자체가 x를 위해 네트워크에서 출력된 값이기 때문에 noise estimator 는 작은 MLP network다.
사설 : 근데 왜 x를 바로 예측하도록 하는게 아니라 이렇게 condition z를 뱉고 denoising해서 얻도록 했을까?
→ 이런식의 Diffusion loss는 diffmae라는 연구에서 먼저 제시했는데 직접 x를 모델링하는건 어렵지만 더 의미론적 공간에 위치한 z를 예측하는게 쉽고, x의 모델링은 그거 기막히게 잘하는 diffusion model에게 하도록하는 식의 접근이 아닐까 생각했다. 위 논문을 보지않아서 정확하지 않다.
자연스럽게 sampler는 reverse diffusion 과정이 된다.
+기존의 LM에서 사용하는 temperature 조절을 통한 diversity 컨트롤이 매우 중요하다는걸 알고있다. 여기서는 뒤의 σtδ 를 temperature로 나누어서 그 컨트롤을 구현했다.
결론적으로 Auto Regressive model은 아래 조건부 분포를 예측해서 next token prediction을 수행한다.
Unifying Autoregressive and Masked Generative Models
보통 AR model들은 transformer를 사용해서 causal attention 방식으로 만들지만, 우리는 causal이 아니라 bidirectional attention으로도 할 수 있다는걸 보인다.
중요한건 이전 토큰을 사용해서 다음 토큰을 예측하는 거지, 이전 토큰(ex. 3번째)을 처리할 때 1,2번째 토큰만을 사용해야한다는걸 의미하는건 아니다.
여기서는 MAE에서 한 것처럼 bidirectional attention을 도입한다. positional embedding을 주고 mask token을 사용해서 decoder가 어떤 position이 예측해야하는 토큰인지 알도록하는 방법.
→ 이렇게하면 알고있는 토큰들끼리는 전부 서로를 참고하도록 하게된다.
여기서 끝이 아니라!
기존처럼 순서대로 예측이 아니라 랜덤한 순서대로 예측하는 방식도 구현해봤다. 이건 아래 그림처럼 causal attention이어도 positional embedding을 통구현 적용가능하다.
이렇게하면 masked generative modeling처럼 된다.
근데 Masked 방법에서는 한 토큰이 아니라 토큰 뭉치를 예측하는 방법도 쓰는데, 본 연구에서도 그 방법을 도입했다. predicting multiple tokens based on previous tokens
→ next set-of-tokens prediction이 된다. Masked Autoregressive(MAR) 방법이라고 이름붙임. 이전의 MAGE라는 논문과 비슷하다.
구현 디테일들.
결과가 너무 중요한데, 아래에서 볼 수 있듯이 MAR + bidirectional을 합쳤을 때 결과가 꽤 좋다.
diffusion loss는 매우 flexible하다는 장점이 있다(그럴거 같음)
noise predict MLP는 진짜 작다.
속도는 AR인만큼 diffusion보다 빠르다.
다른 생성 모델들과 비교한 결과. 왜이렇게 좋지..??
정성적 결과
Share article
Subscribe to our newsletter