이번 게시물에서는 Retrieval-Augmented Language Model 중 하나인 RAG model을 제시한 Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks 논문에 대해 다뤄보겠다.
해당 논문이 발표되었던 시기에는, REALM과 같은 Retrieval-Augmented Language Model이 이미 제안되었던 상황이었다. RAG model은 기존 REALM과 같은 model과는 달리, Encoder-Decoder 구조를 차용하면서 output을 산출하는 과정을 text generation task로 변형하였다는 특징이 있다.
원문 링크는 아래와 같다
Introduction
Pre-train 된 model들은 parameter안에 knowledge를 저장하는데, 해당 knowledge는 외부 지식에 접근하지 않고도 model 스스로 여러 task를 수행할 수 있는 능력을 가지게 하지만, knowledge에 대한 update나 확장이 쉽지 않다는 단점을 가진다. 또한, output을 생성할 때 관련된 knowledge를 직접적으로 활용하지는 못하여 hallucination과 같은 문제가 발생하곤 한다.
(hallucination이란, model이 사실과 다른 응답을 생성하는 문제이며, 이는 아래 링크와 연결된 REPLUG 논문 리뷰에서 다루었으니 참고 바란다.)
이러한 문제를 해결하기 위해, 기존 pre-trained model의 parametic knowledge와 함께 non-parametric knowledge를 같이 활용할 수 있는 hybrid model들이 연구되어 왔고, REALM이나 ORQA와 같은 retrieval-augmented language model도 그중의 하나이다.
(현재 글을 작성중인 2023년 초에는 더 많은 model들이 등장했지만, 본 논문이 처음 공개된 2020년 중반에는 REALM과 ORQA가 대표적인 retrieval-augmented language model이었다)
이러한 retrieval-augmented language model는 좋은 결과들을 많이 보여왔지만, open-domain extractive question answering task 관련해서만 주목해 왔었다.
(기존 model들은 Encoder-Only 구조의 BERT를 기반으로 만들어졌기에, 특정 corpus안에서 알맞은 부분만 추출하는 방식으로 Open-QA task를 수행하게 된다)
따라서, 본 논문에서는 sequence-to-sequence 구조를 채택함으로써 text-to-text로 task를 변형시켜, 기존 REALM, ORQA이 가졌던 한계를 극복할 수 있는 retrieval-augmented language model인 RAG를 제안한다.
(text-to-text는 RAG와 같은 Encoder-Decoder구조(sequence-to-sequence 구조)를 가진 T5 논문 리뷰에서 다루었다. 하단 링크 참고 바란다)
RAG model은 retriever로 DPR(Dense Passage Retriever)를 사용하고, 기본적인 Encoder-Decoder 구조 (지금부터는 "generator"라고 명명한다.)는 BART를 사용했다고 밝힌다. 또한 기존 retrieval-augmented language model과 동일하게, retriever과 generator는 따로 학습되는 것이 아닌, 동시에 학습된다.
Methods
RAG model은 기본적으로 input sequence $x$를 input으로 받게 되면, 이를 이용하여 external knowledge corpus로부터 관련성이 높은 document $d$를 $k$개 retrieve 한다. 이후 $x$와 $d$를 이용하여 target sequence인 $y$를 generate하게 된다.
아래는 이 과정을 시각화한 figure이다.
RAG model의 전반적인 흐름을 자세히 살펴보자.
앞서 언급했듯이 input sequence $x$를 input으로 받게 되면, 이를 이용하여 external knowledge corpus로부터 관련성이 높은 document $z$를 $k$개 retrieve한다. 이후 $x$와 $d$를 이용하여 target sequence인 $y$를 generate 하면서, 두 가지의 요소를 다루게 된다.
- retriever $p_{\eta}(z|x)$ with parameter $\eta$ : $x$에 대한 document의 distribution
- generator $p_{\theta}(y_i|x,z,y_{1:i-1})$ parametrized by $\theta$ : $x$와 $z$를 바탕으로 $y$생성
이 retriever와 generator는 훈련 과정에서 동시에 학습된다.
그리고, output을 산출하기 위해 retrieve 된 document에 대해 marginalize 하게 되는데, 저자들은 이 marginalize 방식을 다르게 한 두 가지 모델인 RAG-Sequence와 RAG-Token을 제안한다.
(marginalize에 대한 부연 설명을 하자면, 동전 A, B가 있으며, 동전 A를 먼저 던진 이후 동전 B를 던지며 B의 앞뒷면 결과는 A의 앞뒷면 결과에 영향을 받는다고 가정해 보자. (두 동전은 독립적이지 않음)
동전 A는 앞면이 나올 확률이 1/4, 뒷면이 나올 확률이 3/4이며, 동전 B는 A의 결과에 따라 다음과 같은 확률을 가진다고 가정해 보자.
동전 B - 앞면 | 동전 B - 뒷면 | ||
동전 A - 앞면 | 동전 A - 뒷면 | 동전 A - 앞면 | 동전 A - 뒷면 |
1/12 | 3/12 | 2/12 | 6/12 |
이러한 가정 속에서, 우리는 오로지 동전 B의 결과에만 관심이 있다고 해보자 (마치 위에서 $y$에만 관심이 있듯이)
그래서 동전 B의 앞면 뒷면 결과에 대한 확률을 추정하고 싶은데, 이를 어떻게 구할 수 있을까?
이는 동전 B가 앞면이 나왔을 때의 확률과, 뒷면이 나왔을 때의 확률을 각각 더해주면 구할 수 있다.
그러면 동전 B가 앞면이 나올 확률은 1/12 + 3/12 = 1/3, 뒷면이 나올 확률은 2/12 + 6/12 = 2/3이라는 것을 추정할 수 있다.
이와 같은 원리로 output $y$에 대한 확률만을 구하기 위해 각 document $z$에 대해 구해진 곱사건의 확률을 모두 더하는 것이 위 수식에서의 marginalize이다. 이는 아래의 REALM 리뷰에서도 다룬 적이 있으니 참고 바란다.
Models
그렇다면, RAG-Sequence model과 RAG-Token model은 어떻게 다를까?
우선 RAG-Sequence model부터 살펴보자. 아래는 RAG-Sequence model이 output을 산출하는 과정을 나타내는 수식이다.
수식을 보면, 하나의 document $z$에 대해 sequence 안의 모든 token에 대한 확률을 계산한 뒤, top-k document에 이 과정을 모두 적용하여 더하는 과정을 진행하는 것을 알 수 있다.
따라서, RAG-Sequence model은 각각의 document를 이용하여 output sequence 전체를 대상으로 값을 산출하고, document에 대해 marginalize 함으로써 최종 값을 산출하는 model이다.
다음으로는 RAG-Token model이다.
RAG-Token model은 RAG-Sequence와 다르게, 하나의 token을 생성할 때 모든 document에 대해 다루고 이후 document에 대해 marginalize 한 다음, 모든 token에 대해 동일한 과정을 진행함으로써 output sequence를 생성하는 model이다
즉, RAG-Sequence는 document에 대한 값을 sequence 단위로 고려한 다음 marginalize 하고,
RAG-Token은 document에 대한 값을 token단위로 고려한 다음 marginalize 하고, 다음 token을 생성하면서 sequence를 생성하는 model인 것이다.
Retriever: DPR
앞서 잠깐 언급한 것처럼, RAG에서 retriever $p_{\eta}(z|x)$로 DPR을 사용한다고 하였다. 그렇다면, DPR은 과연 무엇일까?
DPR은 기본적으로 bi-encoder 구조를 따르며, 아래와 같은 수식으로 나타낼 수 있다.
$d(z)$는 BERT-BASE 구조로 구성된 document encoder를 통해 산출되는 dense representation of a document이다.
$q(x)$는 마찬가지로, BERT-BASE 구조로 구성된 query encoder를 통해 산출되는 query representation이다.
(여기서의 query는 input을 뜻한다)
기본적으로, input $x$에 대한 document $z$의 분포는 위에서 산출한 $d(z)$와 $q(x)$의 내적 연산을 기반으로 하여 산출된다. 이 값들을 바탕으로, 내적 값이 높은 순서대로 top-k document를 골라 retrieve 하게 되는데, 이 과정은 REALM에서 사용되었던 MIPS 알고리즘을 사용하여 효율적인(sub-linear time) 탐색이 가능하게 하였다고 한다.
또한, 저자들은 해당 bi-encoder를 TriviaQA question과 Natural Question dataset을 바탕으로 document를 가져오도록 pre-train 된 model로 초기화하였다고 밝힌다.
Generator: BART
논문에서는 retriever로 retrieve 된 document $z$와 input $x$를 통해 output sequence를 생성하는 generator $p_{\theta}(y_i|x,z,y_{1:i-1})$는 어떠한 encoder-decoder 구조도 사용 가능하다고 밝힌다. 그러면서 본 연구에서는 BART-large를 사용하였다고 밝힌다.(400M parameters)
Training
DPR기반의 retriever와 BART-large 기반의 generator는 training 과정에서 동시에 학습된다. 이때, 어떠한 document가 retrieve 되어야 하는지에 대한 direct supervision은 주어지지 않는다. 오로지 output sequence에 대한 NLL(Negative Log-Likelihood)를 최소화하는 방향으로 학습되고, retriever는 이 과정에서 NLL을 최소화하는 방향으로 학습되는 것이다.
즉, retriever가 어떠한 document를 가져와 하는지에 대한 직접적인 supervision은 없지만, 결국 NLL을 최소화하는, 양질의 output을 산출할 수 있도록 학습되며 retriever는 generator가 더 좋은 output을 산출할 수 있는 document를 가져오게끔 학습이 진행되는 것이다.
또한, 저자들은 REALM에서 진행했던 방식인, pre-training에서 document encoder를 update 하는 방법은 cost가 많이 필요하며, 성능상으로도 큰 이점을 가져다주지 않는다고 판단하여 RAG model에서는 document encoder를 fix 하고, 오로지 query encoder와 generator만 학습한다고 밝힌다.
(REALM에서는 document encoder도 학습하였다. 그러나 document encoder를 학습한다는 것은 document embedding도 변한다는 것이고, 이에 대한 MIPS index값도 변경된다는 것을 의미한다.
Encoder의 parameter가 update 될 때마다 document embedding과 MIPS index값을 재계산한다는 것은 매우 큰 cost를 필요로 하기에, REALM 논문에서는 이를 asynchronous refreshing, 비동기적인 update로 해결하였다
그러나, RAG에서는 이마저도 채택하지 않고 초기 Encoder parameter로 계산된 document embedding과 MIPS index값을 사용한다는 것이다)
Decoding
우리는 앞서 RAG-Sequence와 RAG-Token에 대해 살펴보았다. 두 model은 output 산출 방법이 다르기에, $\text {argmax}_{y} p(y|x)$ 과정(output distribution을 바탕으로 token decoding 하는 과정)에서도 차이가 발생한다.
먼저, RAG-Token의 decoding 과정이다.
RAG-Token model의 경우, token별로 marginize 하여 개별 token에 대한 값을 구하는 구조이기에, 기존 seq2 seq generator와 크게 다른 점이 없다. 다만 token별 각 document에 대한 output distribution을 더하여 최종 output을 산출하기에 아래와 같은 과정으로 진행된다.
이러한 방식은 기존 beam search 방법론에도 그대로 활용이 가능하다.
다만, RAG-Sequence의 경우에는 상황이 좀 다르다. RAG-Sequence의 likelihood $p(y|x)$를 구하는 과정을 다시 한번 살펴보자.
RAG-Sequence의 경우, sequence를 끝까지 생성한 이후 document에 대해 marginalize를 진행하기에 기존 beam search 방법론을 적용할 수 없다는 문제가 생긴다.
그래서 논문에서는 document별로 beam search를 진행한다고 한다. 아래 예시를 봐보자.
input $x$가 주어질 때, 이에 대한 top-k document를 retrieve 하게 된다. 여기서는 top-k를 3으로 설정하였으니, 3개의 document가 활용되는 것을 확인할 수 있다.
그다음으로는, 각각의 document별로 document별로 beam search를 진행한다. 즉, 하나의 input $x$와 document $z$별로 logit이 높은 순서로 $k$ (여기서의 $k$는 앞선 top-k와 별개의 $k$이다.) sequence를 산출한다.
이 과정을 거치면 예시의 맨 오른쪽처럼 각각의 document와 output에 대한 $p(y|x,z)$가 구해지게 된다. 이를 document $z$에 대해 marginalize를 함으로써 최종 $p(y|x)$를 구하게 되는데 이는 아래와 같다.
각각의 output $y$에 대해, 모든 document의 확률을 더함으로써 marginalize 하는 과정이다.
그런데, 여기서 문제가 발생한다. 아래 수식에서 빨간색 확률들을 살펴보자. 해당 확률들은 beam search 과정에서 발견되지 않은 값들이다.
$p(y_2|x,z_2)$의 경우를 살펴보자. 해당 확률은 $x$와 $z_2$가 주어졌을 때 output으로 $y_2$가 나올 확률이다. 그러나 좌측 상단의 beam search 과정, 혹은 이전 예시의 사진에서 볼 수 있듯이 $y_2$는 document $z_2$로부터 산출된 적이 없다.
($z_2$로부터는 오로지 $y_1, y_4, y_5$만 산출되었다.)
따라서 이에 대한 확률 값을 모른다는 문제가 발생한다. $p(y_2|x,z_2)$값을 모르기 때문에 당연히 이를 marginalize 할 수 없고, 결과적으로는 $p(y_2|x)$값을 구할 수 없게 되는 것이다.
이러한 문제를 해결하기 위해서는 $p(y_2|x,z_2)$와 같이 발견되지 않은 값들에 대해 additional forward, 즉 추가적으로 model에 넣어서 해당 값을 산출한 뒤 marginalize와 같은 후속 계산을 진행해야 한다.
그러나, 이 과정은 비용적으로나, 시간적으로나 매우 비효율적이다.
따라서, 논문에서는 보다 효율적인 decoding을 위해 $p(y_2|x,z_2)$와 같이 발견되지 않은 값들을 0 값으로 처리한 다음 marginalize와 같은 후속 계산을 진행하는 방법인 "Fast Decoding"을 제시한다.
즉, 아래와 같은 과정을 거치게 되는 것이다.
결론적으로, RAG-Sequence model은 이러한 Fast Decoding을 거쳐서 최종 output sequence를 생성하게 된다
Experiment & Result
지금까지는 RAG model의 내부 작동 원리에 대해 살펴보았다. 그렇다면, 이러한 RAG model은 실제로 여러 knowledge-intensive task에서 잘 작동할까?
저자들은 다양한 knowledge-intensive task에 대해 RAG model을 실험해 보았다. 각 task별 간단한 설명과 함께 RAG와 그 외 model의 결과를 나열해 보도록 하겠다
먼저 Open-domain Question Answering이다. Open-domain Question Answering task는 많은 정보들을 포함하고 있는 corpus들로부터 주어진 질문에 대한 답변을 찾는 task이다.
RAG 이전의 REALM과 같은 model들은 question에 대한 answer를 주어진 knowledge corpus 내에서 extract, 추출해 내는 방법론을 사용하였다.
(REALM과 같은 기존 model들은 BERT를 기반으로 하였기 때문에 text-to-text framework로 task를 수행하지 않는다)
RAG는 이와 다르게 answer를 직접 generate 하고, 이에 대해 NLL loss를 최소화하는 방법으로 학습을 진행한다.
저자들은 이러한 기존 방법론들과 RAG의 성능을 실험을 통해 비교하였다. 또한 external knowledge 없이, 즉 retrieval 없이 parametic knowledge에만 의존하여 task를 수행하는 Closed-Book QA 방법론과도 RAG model의 성능을 비교하였다고 한다.
실험에 사용된 dataset은 아래와 같다.
- Natural Questions
- TriviaQA
- WebQuestions
- CuratedTrec
결과는 아래와 같다.
우선 external knowledge 없이 task를 수행한 Closed-Book setting에서는, 타 model들에 비해 몇 배 많은 parameter 수를 가진 T5-11B model의 성능이 뒤떨어지는 것을 확인할 수 있다. 이 결과를 통해 knowledge intensive task에서 retrieval이 유효한 효과를 가진다는 것을 확인할 수 있다.
그다음으로는 Open-Book setting(with retrieval)의 결과를 비교해 보겠다
우선, 전반적으로 기존 방법론들보다 RAG(RAG-Token and RAG-Seq)의 성능이 좋은 것을 확인할 수 있다.
논문에서는 QA task 이외에도 Jeopardy Question Generation task, fact verification task에 대해서도 실험을 진행하였다.
먼저, Jeopardy Question Generation task는 특정 entitiy가 주어지면 해당 entity에 알맞은 question을 생성하는 task이다.
예를 들어, "The World Cup"이라는 entity(혹은 answer)가 주어지면, 이에 알맞는 질문인 "In 1986 Mexico scored as the first country to host this international sports competition twice"라는 sequence를 생성해야 하는 것이다.
저자들은 본 실험에서 Jeopardy Question Generation task에 대한 metric으로 Q-BLEU-1 metric을 사용했다고 밝힌다.
Q-BLEU란, 기존 BLEU를 matching entity에 더 많은 가중치를 두는 방식으로 수정하여 human judgement에 조금 더 상관관계를 보이는 metric이다.
또한, generation factuality와 specificity 방면으로 human evaluation을 진행하였다고 한다.
Fact verification task는 input으로 받는 "claim"에 대해 wikipedia 정보를 바탕으로 해당 claim이 support 되는지, refute 되는지, 혹은 wikipedia에 충분한 정보가 없는지를 판별하는 task이다.
이를 측정하기 위한 dataset으로 FEVER dataset이 있으며, 이는 카이스트의 James Thorne 교수님께서 제안하신 dataset이기도 하다.
해당 task들에 대한 결과는 아래와 같다.
먼저 Jeopardy Question Generation task 결과부터 살펴보겠다. retrieval을 사용하지 않은 BART에 비해, RAG model이 Q-BLEU t상으로나, human assessment상으로나 더 좋은 성능을 내는 것을 확인할 수 있다.
Fact verification task에서는 SOTA를 달성하진 못하였다. 그러나 저자들은 그 당시 SOTA model의 경우 복잡한 구조를 가지고 있으며, domain-specific architecture를 가지고 있지만, RAG는 이와 다르게 범용적으로 설계되었음에도 불구하고 SOTA에 근접하는 성능을 가진다고 말한다.
댓글