본문 바로가기

boostcamp AI tech/boostcamp AI

Transformer - Multi-Head attention

728x90

http://jalammar.github.io/illustrated-transformer/

1. Multi-head attention

self-attention을 유연하게 더 확장한 multihead attention에 대해 알아본다.

그림에서 보면 중첩된 블럭들이 보인다. self-attention에서는 한번 Q, K, V쌍이 만들어지고 attention 블럭을 한번 통과하는데 반해 self-attention에서는 여러개의 Q, K, V를 만들고 여러번 attention 블럭을 통과한다.

즉 Wq, Wk, Wv가 여러개 존재한다. 서로 다른 attention 블럭마다 서로 다른 encoding vector가 만들어지고 이를 모두 concat하여 하나의 결론도출에 사용한다.

이를 여러개의 head를 갖는다고 하여 multi head attention이라고 하고 각 head마다 (Wq, Wk, Wv) 쌍을 한개씩 갖는다.

주어진 하나의 sequence에 대해 여러 관점으로 정보를 병렬적으로 추출하는 것으로 볼 수 있다.

https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb

its가 query를 날렸을 때 head별로 집중한 부분을 나타내는 그림이다.

http://jalammar.github.io/illustrated-transformer/

각 attention head마다 self attention을 수행한다.

http://jalammar.github.io/illustrated-transformer/

그럼 z가 head수만큼 나온다.

http://jalammar.github.io/illustrated-transformer/

이걸 모두 concat한 후에 W-Linear transform을 거치고 하나의 최종 결과 Z를 얻는다.

 

2. Complexity

n : sequence length
d : dimension of representation
k : kernel size
r : size of the neighborhodd in restricted self-attention

recurrent unit이 필요로하는 계산량은 n*d*d (n번의 timestep * d*d번의 weight와 hidden 행렬 계산)이다. 각 timestep마다 backpropagation을 위한 hidden state값들을 저장해야한다.

재귀적 연산으로 인해 sequence 길이에 따라 O(n)의 시간이 걸린다. (병렬화 불가)

또한 maximum path length는 최대 n번을 통과해야 정보가 전달된다. (long term dependency)

 

self-attention은 softmax(QK/scale)V 이다. 필요료하는 계산량은 n*n*d (n개의 q, n*d번의 계산)이다. 대신 gpu병렬처리로 인해 sequence길이가 얼마가 되는 한번에 계산할 수 있다. 

backpropagation을 위해 저장해야 하는 공간도 역시 아래 rnn보다 훤씬 크다. 학습(forward,backward)은 빨리 될 수 있지만 메모리를 훨씬 많이 잡아먹는다.

maximum path length는 O(1)이다. long term dependency를 해결한다.

 

3. Multi-head attention block

transformer는 muti-head attention을 활용하여 위와 같은 하나의 블럭을 사용한다.

Q, K, V가 multi-head attention으로 들어가고 있다.

한가지 더 주목할 점은 input이 바로 attention의 output에 바로 더해진다는 점이다. (Add, residual connection) 여기에 layer norm을 한번 더 거친다.

그 결과물은 feed forward, add & Norm을 마지막으로 거친다.

이것저것 더해지긴 했지만 결국엔 Muti-head attention을 거친 결과 Z1, Z2, ..., Zh와 최종 결과의 모양은 같다.

* Zi = {z_i0, z_i1, ..., z_in}

 

4. Residual connection

이런 residual connection을 예전에 resNet에서 본적이 있다. layer를 깊게 계속 쌓는 바람에 생기는 gradient vanishing을 해결하기 위해 고안되었다. 

결과 값은 x + sublayer(x)인데 x는 내가 그냥 더해주었으니 attention layer는 차이값(결과-입력)만을 만들어주면 된다. 이게 학습 안정화 효과가 있다.

이 덧셈을 수행하려면 input embedding vector와 attention output vector size가 같아야 한다.

 

5. Layernorm

batch norm : batch size가 64라고 하면 forward propagation 결과가 64개가 나올 것이다. 이 64개의 값은 나름대로 평균과 분산을 가질텐데 이를 버리고 N(0,1)로 정규화를 한다. 이 값들은 y=ax+b에 넣어서 새로운 y들로 변환한다. 이 말은 즉, 평균을 b, 분산을 a^2으로 만들겠다는 뜻이다. a, b는 역전파를 통해 학습한다. 

그럼 모델은 Layer의 output node가 특정 평균과 분산을 가지도록 학습할 수 있게된다.

layer norm : thinking, machine 두 단어로 구성된 sequence가 있다고 해보자. 각 단어가 갖는 vector를 N(0,1)로 만든다. 그 이후 각 node별로 평균 분산을 주입한다.

 

6. Positional Encoding

rnn과 달리 attention 연산은 Input sequence의 순서랑 무관하게 연산한다. 

I go home에서 {I, home}의 연산결과와 home go I에서 {home, I}의 연산결과는 동일하고 결국 순서정보를 고려하지 않는 셈이 된다.

 

다시 "I go home"이 있을 때 첫번째 단어 "I" = [3, -2, 4]  첫번째 원소에 1000을 더해준다. "I" = [1003, -2, 4]

세번째 단어 "home"=[1, 2, 3]의 경우는 home =[1,2,1003]이 된다.

즉 순서 정보를 주입해준다. 실제로는 위의 예시처럼 단순하게 상수를 더하지는 않는다.

실제로는 이렇게 sin, cos그래프를 활용한다.

0번으로 들어온 128차원의 입력 (단어)벡터에는 위에 초록 row를 더해준다. 1번째 단어에는 그 아래 row 벡터를 더해준다. 이렇게 위치 정보를 삽입해주면 똑같은 단어라도 위치에 따라 값이 달라지고 결국엔 attention 연산 결과도 달라진다.

 

7. learning rate scheduling

학습 전 과정 동안 하나의 lr만 사용하지 않는다. 동적으로 보폭을 조절해서 성능을 높였다고 한다.

 

8. Decoder

다시 transformer로 돌아와서, encoder에서는 N개의 동일한 MH-attention블럭을 거친다.

최종 결과만 decoder에게 전달된다.

그런데 연결 선을 자세히 보면 encoder의 최종 output이 V, K 계산을 위해서만 decoder의 MH-attention으로 들어온다.

모자란 Q는 이전 layer의 output을 가지고 구한다.

드디어 최종 결과를 target word dim으로 변환시켜서 예측을 수행한다.

 

9. Masked self-attention

학습시에는 ("I go home", "<sos> 나는 집에 간다.") 이 쌍을 encoder, decoder에게 전달한다. 

그런데 "나는"이 decoder에 들어갈 때 attention은 "나는" 뿐만 아니라 "간다"도 볼 수 있다. Q, K, V 연산 이후 attention score는아래처럼 생겼을 것이다.

뒤에 올 단어들을 보면 안되기 때문에 attention score에서 뒷부분을 다 가려버린다. 그리고 row별로 합이 1이 되도록 조정을 한다.(softmax) 그래야 나중에 가중합을 구할 때 가중치 합이 1이 될 수 있기 때문이다. 

이러면 attention distribution을 가지고 weighted sum을 구할 때, 보지 말아야 하는 단어들은 안 볼 수 있다.

 

728x90
반응형