from text import text, tokens

number_of_tokens = 256  # Number of possible integer values in a byte.
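# `text.py` is assumed to supply a sample string `text` and its byte-level
# tokens `tokens`; presumably tokens == list(text.encode('utf-8')),
# e.g. 'hi' -> [104, 105].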

def get_pairs(tokens):
    """Count how often each pair of adjacent tokens occurs."""
    pairs = {}
    for pair in zip(tokens, tokens[1:]):
        pairs[pair] = pairs.get(pair, 0) + 1
    return pairs
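# A quick illustration with made-up byte values (not from `text.py`):
# get_pairs([104, 105, 104, 105, 104]) == {(104, 105): 2, (105, 104): 2}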

def replace_pair(tokens, pair, new_token):
    """Replace every occurrence of `pair` with `new_token`, scanning left to right."""
    new_tokens = []
    i = 0
    while i < len(tokens):
        # On a match, emit the merged token and skip both members of the pair.
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            new_tokens.append(new_token)
            i += 2
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens
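# For example: replace_pair([104, 105, 104, 105], (104, 105), 256) == [256, 256].
# Overlapping occurrences merge left to right:
# replace_pair([97, 97, 97], (97, 97), 256) == [256, 97]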

def merge(tokens, number_of_merges=20):
    """Merge tokens according to the byte pair encoding algorithm."""
    merged_tokens = list(tokens)
    merges = {}
    for i in range(number_of_merges):
        pairs = get_pairs(merged_tokens)
        # Replace the currently most frequent pair with a fresh token id.
        most_frequent_pair = max(pairs, key=pairs.get)
        merged_tokens = replace_pair(merged_tokens, most_frequent_pair, number_of_tokens + i)
        merges[most_frequent_pair] = number_of_tokens + i
    return merged_tokens, merges
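# A minimal training sketch on a made-up string (not the data from `text.py`):
# merge(list('abababab'.encode('utf-8')), number_of_merges=2)
# -> ([257, 257], {(97, 98): 256, (256, 256): 257})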

def encode(text, merges):
    """Encode the text into a sequence of merged tokens."""
    tokens = list(text.encode('utf-8'))
    while len(tokens) > 1:
        pairs = get_pairs(tokens)
        # Pick the pair that was merged earliest during training (lowest
        # token id); pairs that were never merged rank as infinity.
        pair = min(pairs, key=lambda pair: merges.get(pair, float('inf')))
        if pair not in merges:
            break  # Nothing else to merge.
        token = merges[pair]
        tokens = replace_pair(tokens, pair, token)
    return tokens
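# Continuing the sketch above:
# encode('abab', {(97, 98): 256, (256, 256): 257}) == [257]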

def decode(merged_tokens, merges):
    """Decode the merged tokens back into a UTF-8 string."""
    # Start from the 256 single-byte tokens, then expand each merged token into
    # its two constituents. Because dicts preserve insertion order, every
    # constituent is already in the vocabulary when it is needed.
    vocabulary = {token: bytes([token]) for token in range(number_of_tokens)}
    for (token1, token2), new_token in merges.items():
        vocabulary[new_token] = vocabulary[token1] + vocabulary[token2]
    text_bytes = b''.join(vocabulary[token] for token in merged_tokens)
    return text_bytes.decode('utf-8', errors='replace')
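# Decoding inverts encoding:
# decode([257], {(97, 98): 256, (256, 256): 257}) == 'abab'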

if __name__ == "__main__":
    merged_tokens, merges = merge(tokens)
    print('Merges:', merges)
    print('Compression Ratio:', len(tokens) / len(merged_tokens))

    encoded_text = encode(text, merges)
    decoded_text = decode(encoded_text, merges)
    print('Encoded = Decoded?', text == decoded_text)