computer/AI

[AI/ML] Attention is all you need

ketrewq 2024. 12. 3. 11:24

서론 

나는 ML/AI 라는 필드를 전혀 모른다. 내가 경험해본 것들은 다음과 같다 

  • 챗봇 파인튜닝하고 놀아봄 (친구가 없음)
  • stable diffusion LoRA 만들기 
  • 보이스 트레이닝 후 AI 목소리 만들기 

그럼에도 내 귀에 들어올만큼 자주 얘기되는 논문이 있다. Transformer를 소개한 Attention is all you need 라는 논문이다. 

 

[1706.03762] Attention Is All You Need

 

Attention Is All You Need

The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new

arxiv.org

 

나는 수학에 대한 사전지식 또한 전무하다. 고등학생 때 문과였기 때문이다. 하지만 이 논문만큼은 이해해보고 싶었다. 그래서 오늘은 이걸 이해하고, 문과 또한 알아들을 수 있게 정리해보려 한다. 이해하지 못한 부분들은 오독을 막기 위해 적당히 건너뛰고, 핵심 아이디어만 정리했다. 

 

간략한 소개 

영어에서 독일어로 텍스트를 번역한다고 쳐보자. RNN/CNN같은 전통적인 모델은 한쪽한쪽 책을 읽어내려간다. 하지만 책에서 중요한 정보를 잡아서 그것에만 집중할 수 있다면 얼마나 좋을까! 그래서 나온 것이 이 논문에서 소개하는 Transformer이다. 다른 말로, 트랜스포머는 기존의 AI모델에 대한 ADHD 치료제다. 

 

Encoder / Decoder 

 

대부분의 모델들은 인코딩 - 디코딩 구조를 가지고 있다. 인풋을 받으면 encoder는 매핑하고, decoder는 매핑된 심볼에 대한 아웃풋을 하나하나 출력한다. 이때 모델은 auto-regressive하다. 전에 나온 output을 input으로 다시 받는다는 뜻이다. 

 

트랜스포머의 구조이다. 

 

 

 

Attention 

Attention 함수는 query와 key-value 쌍의 집합을 output으로 매핑하는 것으로 설명할 수 있다. 여기서 query, keys, values, output은 모두 벡터이다. output은 value들의 가중 합으로 계산되며, 각 value에 할당된 가중치는 query와 해당 key 간의 compatibility function에 의해 계산된다.

 

여기서 소개하는 attention을, scaled dot-product attention이라고 부를 수 있다. 

 

Scaled Dot-product attention 

다음의 수식을 이해하려 해보자. 

 

\( \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \)

 

  • Queries (Q): 도서관에서 정보를 검색한다고 상상해보자. query는 "내가 찾고 싶은 정보가 뭘까?" 이다. 
  • Keys (K): 도서관의 각 책에는 key(제목이나 키워드 등)가 있어서 해당 책이 query와 관련이 있는지를 판단하는 데 도움이 된다.
  • Values (V): value는 읽고 싶어하는 책의 실제 내용이다.

Dot Product는 각 query가 각 key와 얼마나 유사하거나 관련이 있는지를 측정하는 방법이다. query가 key와 밀접하게 일치할수록 dot product 값이 커져 높은 관련성을 나타낸다.

 

값이 너무 커지는 것을 방지하기 위해(이로 인해 계산에 문제가 생길 수 있음) keys의 차원의 제곱근으로 나눈다. 이는 노이즈를 조절하여 다루기 쉬운 뭉텅이로 맞추는 것에 비유할 수 있다.

 

스케일링 후, softmax function을 적용하여 원시 점수를 1로 합산되는 확률로 변환한다. 이는 query에 비해 각 key의 가중치나 중요성을 결정하는 데 도움이 된다.

 

마지막으로, 이러한 확률을 values (V)와 곱한다. 이 단계는 query에 대한 관련성에 따라 가중된 values의 정보를 결합하는 과정이다.

 

Multi-Head attention 

 

Multi-Head Attention은 모델이 서로 다른 표현 서브스페이스(representation subspaces)에서 다양한 위치의 정보에 동시에 집중(attend)할 수 있게 한다. 단일 attention head의 경우, averaging(평균화)이 이를 억제하는 반면, muti-head attention 모델에서는 다중 head를 사용함으로써 다양한 관계와 정보를 포착할 수 있다.

 

\(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\)

\(\text{where head}_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i)\)

 

이 수식을 이해해보자. 

 

MultiHead(Q, K, V): Multi-Head Attention 함수는 여러 개의 attention head를 병렬로 연결한 후, linear transformation \( W^O \)을 적용하여 최종 output을 생성한다.

 

\( head_i \): 각 attention head는 query, keys, values 에 대해 각각 다른 학습된 linear transformation 행렬 \( (QW^Q_i, KW^K_i, VW^V_i)\) 을 곱한 후, Scaled Dot-Product Attention을 수행한다. 이게 뭔소리냐면 각 헤드가 서로 다른 linear transformation을 통해 queries, keys, values를 투영함으로써, 입력 데이터의 다양한 측면을 동시에 학습하고 포착할 수 있다는 소리이다. 단일 attention 메커니즘이 모든 정보를 하나의 관점에서만 처리하는 것과 차이가 난다.

 

Self-attention layer in the decoder

 

  • auto-regressive (그니까 전의 output을 input으로 또 받는) 을 유지하기 위해, decoder에서는 왼쪽(이전) 위치로의 정보 흐름을 방지해야 한다.
  • 이는 scaled dot-product attention 내부에서 illegal connections에 해당하는 softmax 입력 값을 -∞으로 마스킹(masking)함으로써 구현된다. 

 

Position-wise Feed-Forward Networks

 

각 encoder와 decoder의 각 층에는 attention sub-layer 외에도 완전히 연결된 (fully connected) feed-forward network가 포함되어 있다.

 

\( FFN(x) = max(0, xW1 + b1)W2 + b2 \)

 

 

Positional Encoding 

 

RNN이나 CNN의 방식을 쓰지 않기 떄문에, 순서 정보를 활용하기 위해 상대적/절대적 위치에 대한 정보 또한 줘야한다. 이에는 learned 와 fixed 방식이 있는데, 다음의 함수를 이용했다고 한다. 

 

\( PE(pos, 2i) = sin(pos / 10000^{2i / d_{\text{model}}}) \)
\( PE(pos, 2i+1) = cos(pos / 10000^{2i / d_{\text{model}}}) \)

 

Why self attention?

이러한 개념을 설계하며 self-attention 레이어와 recurrent / convolutional 레이어 간을 비교한 기준은 다음의 셋과 같다. 

 

  • 첫째, 층(layer)당 전체 계산 복잡성(total computational complexity)이다.
  • 둘째, 병렬화할 수 있는 계산의 양(amount of computation that can be parallelized)으로, 이는 필요한 최소한의 순차 연산 수(minimum number of sequential operations)로 측정된다.
  • 셋째, 네트워크 내 장거리 의존성(long-range dependencies) 간의 경로 길이(path length)이다.

 

결론 

 

대충은 이해했다. 나중에 AI/ML쪽을 더 파게된다면 다시 읽고 싶다. 모든 리서치가 그렇듯이, 기존의 리서치 (CNN과 RNN)에 대한 이해 또한 필요해보인다.