from collections import Counter

from text import tokens

number_of_tokens = 256  # Number of possible integer values in a byte.


def get_pairs(tokens):
    """Count occurrences of each adjacent pair of tokens.

    Args:
        tokens: A sliceable sequence of token ids.

    Returns:
        A mapping (a ``Counter``, which is a ``dict`` subclass) from each
        ``(left, right)`` pair of consecutive tokens to its occurrence count,
        in order of first occurrence. Empty when ``len(tokens) < 2``.
    """
    return Counter(zip(tokens, tokens[1:]))


def replace_pair(tokens, pair, new_token):
    """Return a copy of *tokens* with every occurrence of *pair* replaced.

    Occurrences are matched greedily from left to right and do not overlap,
    so ``[a, a, a]`` with pair ``(a, a)`` becomes ``[new_token, a]``.

    Args:
        tokens: A sequence of token ids.
        pair: The ``(left, right)`` pair to replace.
        new_token: The single token that replaces each matched pair.

    Returns:
        A new list of token ids; *tokens* itself is not modified.
    """
    new_tokens = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            new_tokens.append(new_token)
            i += 2  # Skip both halves of the matched pair.
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens


def merge(tokens, number_of_merges=20):
    """Merge tokens according to the byte pair encoding algorithm.

    On each step the most frequent adjacent pair is replaced by a fresh
    token id. Ids start at ``number_of_tokens`` and increase by one per
    merge. Ties are broken in favour of the pair that occurs first.

    Args:
        tokens: A sequence of token ids.
        number_of_merges: Maximum number of merge steps to perform. Fewer
            merges are done if the sequence runs out of adjacent pairs.

    Returns:
        A tuple ``(merged_tokens, merges)``. ``merged_tokens`` is the
        resulting list of token ids. ``merges`` maps each merged pair to the
        id that replaced it, in merge order.
    """
    merged_tokens = list(tokens)
    merges = {}
    for i in range(number_of_merges):
        pairs = get_pairs(merged_tokens)
        if not pairs:
            # Fewer than two tokens remain, so there is nothing to merge;
            # calling max() on an empty mapping would raise ValueError.
            break
        most_frequent_pair = max(pairs, key=pairs.get)
        new_token = number_of_tokens + i
        merged_tokens = replace_pair(merged_tokens, most_frequent_pair,
                                     new_token)
        merges[most_frequent_pair] = new_token
    return merged_tokens, merges


if __name__ == "__main__":
    merged_tokens, merges = merge(tokens)
    print('Merges:', merges)
    # Guard against empty input, which would otherwise divide by zero.
    if merged_tokens:
        print('Compression Ratio:', len(tokens) / len(merged_tokens))