update cache GC in demo and add vocab expansion example

This commit is contained in:
yangapku
2023-10-09 14:59:57 +08:00
parent 343017c4ce
commit cbfaada8de
8 changed files with 530 additions and 9 deletions

226
examples/add_merges.py Normal file
View File

@@ -0,0 +1,226 @@
import argparse
import base64
import collections
import logging
import unicodedata
from pathlib import Path
import regex as re
from tqdm.contrib.logging import tqdm_logging_redirect
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.DEBUG, format="[%(asctime)s] %(levelname)s - %(message)s"
)
def load_tiktoken_bpe(tiktoken_bpe_file: str) -> "dict[bytes, int]":
contents = open(tiktoken_bpe_file, "rb").read()
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
def dump_tiktoken_bpe(bpe_ranks: "dict[bytes, int]", tiktoken_bpe_file: str) -> None:
with open(tiktoken_bpe_file, "wb") as f:
for token, rank in sorted(bpe_ranks.items(), key=lambda x: x[1]):
f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")
def bytes_to_pieces(the_bytes: bytes) -> "tuple[bytes]":
return tuple(bytes([byte]) for byte in the_bytes)
def get_pairs(pieces: "tuple[bytes]") -> "set[tuple[bytes, bytes]]":
return set(zip(pieces[:-1], pieces[1:]))
def get_stats(
vocab: "dict[tuple[bytes, ...], int]",
) -> "dict[tuple[bytes, bytes], int]":
pairs = collections.defaultdict(int)
for word, freq in vocab.items():
for i in range(len(word) - 1):
pairs[(word[i], word[i + 1])] += freq
return pairs
def merge_vocab(
pair: "tuple[bytes, bytes]", vocab: "dict[tuple[bytes, ...], int]"
) -> "dict[tuple[bytes, ...], int]":
return {apply_bp(pieces, pair): freq for pieces, freq in vocab.items()}
def apply_bp(
pieces: "tuple[bytes, ...]", pair: "tuple[bytes, bytes]"
) -> "tuple[bytes, ...]":
new_pieces = []
first, second = pair
i = 0
while i < len(pieces):
try:
j = pieces.index(first, i)
new_pieces.extend(pieces[i:j])
i = j
except:
new_pieces.extend(pieces[i:])
break
if pieces[i] == first and i < len(pieces) - 1 and pieces[i + 1] == second:
new_pieces.append(first + second)
i += 2
else:
new_pieces.append(pieces[i])
i += 1
return tuple(new_pieces)
def bpe(word: bytes, merges: "dict[bytes,int]") -> "tuple[bytes, ...]":
pieces = bytes_to_pieces(word)
while len(pieces) > 1:
pairs = get_pairs(pieces)
pair = min(pairs, key=lambda pair: merges.get(pair[0] + pair[1], float("inf")))
if pair[0] + pair[1] not in merges:
break
pieces = apply_bp(pieces, pair)
# logger.debug(f"{[(p, p.decode('utf8', errors='replace')) for p in pieces]} {pair} {pieces}")
return pieces
def best_pair_sort_key(
item: "tuple[dict[bytes, bytes], int]",
) -> "tuple[int, int, int, str, bytes]":
# prefer to use the highest frequency or shortest length or lexi sort, sligtly slower
pair, freq = item
pair_bytes = pair[0] + pair[1]
pair_byte_length = len(pair_bytes)
pair_str = pair_bytes.decode("utf-8", errors="replace")
pair_str_length = len(pair_str)
return -freq, pair_str_length, pair_byte_length, pair_str, pair_bytes
def learn_bpe(
freqs: "dict[str,int]", existing: "dict[bytes, int]"
) -> "tuple[bytes, bytes]":
vocab = {bpe(k.encode("utf-8"), existing): v for k, v in freqs.items()}
vocab = {key: value for key, value in vocab.items() if len(key) > 1}
new_merges = []
with tqdm_logging_redirect() as bar:
while vocab:
pairs = get_stats(vocab)
best, freq = min(pairs.items(), key=best_pair_sort_key)
logger.debug(
f'{best} ({(best[0]+best[1]).decode("utf-8", errors="replace")}) is selected as the next merge with freq {freq}'
)
new_merges.append(best)
vocab = merge_vocab(best, vocab)
vocab = {key: value for key, value in vocab.items() if len(key) > 1}
bar.update()
return new_merges
def load_expand_vocab(path: Path) -> "dict[str, int]":
freqs = {}
with open(path, "r", encoding="utf8") as fin:
for line in fin:
if not line.strip():
continue
word, freq = line.strip().split("\t")
word = unicodedata.normalize("NFC", word)
parts = re.findall(PAT_STR, word)
if len(parts) > 1:
logger.warning(
f"{word} would be pre-tokenized to {parts}, and thus cannot be added to vocabulary"
)
continue
try:
freq = int(freq)
except ValueError as _:
freq = 1
if word in freqs:
logger.warning(
f"{word} is repeated, the frequency is increased by this much"
)
freqs[word] += freq
else:
freqs[word] = freq
return freqs
def make_new_merges_by_bpe(
input_path: Path, output_path: Path, expand_path: Path, start_id: int
) -> None:
mergeable_ranks = load_tiktoken_bpe(input_path)
if not start_id or start_id == -1:
start_id = len(mergeable_ranks)
elif start_id < len(mergeable_ranks):
logger.warning(
f"start_id {start_id} is too small, existing merges will be overridden, DONOT DO THIS. changed to {len(mergeable_ranks)}"
)
start_id = len(mergeable_ranks)
else:
start_id = start_id
expand_vocab_freqs = load_expand_vocab(expand_path)
for word in list(expand_vocab_freqs):
token = word.encode("utf-8")
if token in mergeable_ranks:
logger.warning(f"word {word} is already a token {token}, skipping")
del expand_vocab_freqs[word]
logger.info(f"number of existing merges: {len(mergeable_ranks)}")
logger.info(f"number of words for expanding: {len(expand_vocab_freqs)}")
new_merges = learn_bpe(expand_vocab_freqs, mergeable_ranks)
logger.info(f"number of newly learned merges: {len(new_merges)}")
extra_merges = {p[0] + p[1]: i for i, p in enumerate(new_merges, start=start_id)}
dump_tiktoken_bpe(extra_merges, output_path)
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("input_path", type=str, help="Path for input tiktoken file")
parser.add_argument(
"output_path",
type=str,
help="Path for output tiktoken file, containing only the new merges",
)
parser.add_argument(
"vocab_path",
type=str,
help="Path for words needed adding, each line is a word and its frequency separated by \\t",
)
# if the extended vocabulary is for fine-tuning, you better set those correctly (the default is for qwen.tiktoken)
# if the extended vocabulary is for pretraining from the start, no need
parser.add_argument(
"--start_id",
type=int,
default=151851,
help="The start id for new merges. For Qwen tokenizer, this should be 151851 (skipping the existing special tokens)",
)
args = parser.parse_args()
make_new_merges_by_bpe(
args.input_path, args.output_path, args.vocab_path, args.start_id
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,6 @@
5LiA5Y+q54yr 151851
5Y+q54yr 151852
5piv5LiA5Y+q54yr 151853
5oiR5piv5LiA5Y+q54yr 151854
5L2g5piv5LiA5Y+q54yr 151855
5LuW5piv5LiA5Y+q54yr 151856

View File

@@ -0,0 +1,6 @@
我是一只猫 20
你是一只猫 10
他是一只猫 5
一只 200
一只猫 100
夸张的 比喻手法 20

View File

@@ -9,8 +9,6 @@
"name": "stderr",
"output_type": "stream",
"text": [
"TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
]
}
@@ -414,6 +412,142 @@
"source": [
"tokenizer._convert_token_to_id('<|extra_204|>')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vocabulary Expansion"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [35946, 99639, 91680, 100472], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer(\"我是一只猫\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[99639, 91680, 100472]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.encode(\"是一只猫\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B', trust_remote_code=True, extra_vocab_file=\"qwen_extra.tiktoken\")\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"151857"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [151854], 'token_type_ids': [0], 'attention_mask': [1]}"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer(\"我是一只猫\")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'我是一只猫'"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.decode(tokenizer.encode(\"我是一只猫\"))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[151853]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.encode(\"是一只猫\")"
]
}
],
"metadata": {
@@ -432,7 +566,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.12"
},
"orig_nbformat": 4
},