사과나무심기

0513 - Tokenizer (BPE)

리하이 2024. 5. 13. 23:15

일하다 모르는 개념들을 

집에서 조금씩 다지는거? 매우 괜찮네

 

일할때는 듀데잇도 있고, 

뭔가 원론적인 내용들을 머리에 잘 안들어오는데

집에서 차분히 보니 쉽게 들어오는구만.

(회사에서 애매하게 알고 있던것을 지금 공부한다는 말을 돌려돌려말한것임)

 

1. BPE (Bytepair Encoding)

BPE는 사실 압축기법이다.

가장 많이 반복되는 character-pair (2개)를 선택한다. ("aa")

aaabdaaabac

"aa"를 Z로 치환한다.

ZabdZabac
Z=aa

그다음 자주나오는 character-pair를 선택한다. (ab)

치환한다.

ZYdZYac
Y=ab
Z=aa

이걸 더이상 치환할게 없을 때까지 반복하는 로직이다.

매우 간단하고, 직관적이다.

 

이걸 자연어처리에서 사용할때는 다음과 같다.

Training data vocab(BPE에 사용할 텍스트 시퀀스)를 준비하고,

byte-pair 빈도수를 기준으로 자른다.

Input이 들어오면, BPE토크나이징 기준으로 단어를 쪼갠다. 

 

내가 궁금한 부분은 BPE encoding부분이다.

어떤 단어가 들어왔을때 그냥 greedy하게 단어를 tokenize하진 않을것이다. 그럼 어떻게 할것인가?

def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def encode(orig):
    """Encode word based on list of BPE merge operations, which are applied consecutively"""

    word = tuple(orig) + ('</w>',)
    display(Markdown("__word split into characters:__ <tt>{}</tt>".format(word)))

    pairs = get_pairs(word)    

    if not pairs:
        return orig

    iteration = 0
    while True:
        iteration += 1
        display(Markdown("__Iteration {}:__".format(iteration)))

        print("bigrams in the word: {}".format(pairs))
        bigram = min(pairs, key = lambda pair: bpe_codes.get(pair, float('inf')))
        print("candidate for merging: {}".format(bigram))
        if bigram not in bpe_codes:
            display(Markdown("__Candidate not in BPE merges, algorithm stops.__"))
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except:
                new_word.extend(word[i:])
                break

            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        print("word after merging: {}".format(word))
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    # 특별 토큰인 </w>는 출력하지 않는다.
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word = word[:-1] + (word[-1].replace('</w>',''),)

    return word

 

아하 !

일단 단어가 들어오면 그걸 다- 캐릭터로 쪼갠다음에 byte-pair로 묶는다(2개씩 묶는다)

BPE vocab에 들어있으면, 치환하고,

다시 묶고, 치환하고 이걸 반복한다.

word split into characters: ('l', 'o', 'k', 'i', '')

Iteration 1:
bigrams in the word: {('i', '</w>'), ('o', 'k'), ('l', 'o'), ('k', 'i')}
candidate for merging: ('l', 'o')
word after merging: ('lo', 'k', 'i', '</w>')

Iteration 2:
bigrams in the word: {('i', '</w>'), ('k', 'i'), ('lo', 'k')}
candidate for merging: ('i', '</w>')
Candidate not in BPE merges, algorithm stops.

('lo', 'k', 'i')

 

BPE는 병합 규칙도 저장해둔다. (왜?)