mirror of
https://github.com/QwenLM/Qwen.git
synced 2026-05-20 08:25:47 +08:00
update cache GC in demo and add vocab expansion example
This commit is contained in:
10
cli_demo.py
10
cli_demo.py
@@ -11,6 +11,7 @@ import platform
|
|||||||
import shutil
|
import shutil
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from transformers.generation import GenerationConfig
|
from transformers.generation import GenerationConfig
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
@@ -64,6 +65,13 @@ def _load_model_tokenizer(args):
|
|||||||
return model, tokenizer, config
|
return model, tokenizer, config
|
||||||
|
|
||||||
|
|
||||||
|
def _gc():
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
def _clear_screen():
|
def _clear_screen():
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
os.system("cls")
|
os.system("cls")
|
||||||
@@ -129,10 +137,12 @@ def main():
|
|||||||
elif command in ['clear', 'cl']:
|
elif command in ['clear', 'cl']:
|
||||||
_clear_screen()
|
_clear_screen()
|
||||||
print(_WELCOME_MSG)
|
print(_WELCOME_MSG)
|
||||||
|
_gc()
|
||||||
continue
|
continue
|
||||||
elif command in ['clear-history', 'clh']:
|
elif command in ['clear-history', 'clh']:
|
||||||
print(f'[INFO] All {len(history)} history cleared')
|
print(f'[INFO] All {len(history)} history cleared')
|
||||||
history.clear()
|
history.clear()
|
||||||
|
_gc()
|
||||||
continue
|
continue
|
||||||
elif command in ['help', 'h']:
|
elif command in ['help', 'h']:
|
||||||
print(_HELP_MSG)
|
print(_HELP_MSG)
|
||||||
|
|||||||
226
examples/add_merges.py
Normal file
226
examples/add_merges.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
import collections
|
||||||
|
import logging
|
||||||
|
import unicodedata
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
from tqdm.contrib.logging import tqdm_logging_redirect
|
||||||
|
|
||||||
|
# Pre-tokenization pattern. load_expand_vocab() rejects any word that this
# pattern splits into more than one piece.
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""


logger = logging.getLogger(__name__)

# Root logging configuration: DEBUG and above, timestamped prefix.
logging.basicConfig(
    format="[%(asctime)s] %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_tiktoken_bpe(tiktoken_bpe_file: str) -> "dict[bytes, int]":
    """Load a tiktoken-style BPE file into a ``{token_bytes: rank}`` mapping.

    Each non-blank line of the file holds a base64-encoded token and its
    integer rank separated by whitespace.

    Args:
        tiktoken_bpe_file: path of the file to read (opened in binary mode).

    Returns:
        Mapping from decoded token bytes to rank.

    Raises:
        ValueError: if a non-blank line does not consist of exactly two
            fields or the rank is not an integer.
    """
    # Use a context manager so the handle is closed deterministically.
    with open(tiktoken_bpe_file, "rb") as f:
        contents = f.read()
    return {
        base64.b64decode(token): int(rank)
        # Skip whitespace-only lines too; they would otherwise split into
        # zero fields and break the two-value unpacking.
        for token, rank in (
            line.split() for line in contents.splitlines() if line.strip()
        )
    }
|
||||||
|
|
||||||
|
|
||||||
|
def dump_tiktoken_bpe(bpe_ranks: "dict[bytes, int]", tiktoken_bpe_file: str) -> None:
    """Write ``bpe_ranks`` to ``tiktoken_bpe_file``, one ``<base64 token> <rank>`` line per entry, ordered by rank."""
    tokens_by_rank = sorted(bpe_ranks, key=bpe_ranks.__getitem__)
    with open(tiktoken_bpe_file, "wb") as out:
        for tok in tokens_by_rank:
            encoded_token = base64.b64encode(tok)
            encoded_rank = str(bpe_ranks[tok]).encode()
            out.write(b" ".join((encoded_token, encoded_rank)) + b"\n")
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_pieces(the_bytes: bytes) -> "tuple[bytes, ...]":
    """Split a byte string into a tuple of single-byte ``bytes`` objects.

    Args:
        the_bytes: the byte string to split.

    Returns:
        One length-1 ``bytes`` object per input byte, in order; an empty
        tuple for empty input.
    """
    # Iterating over bytes yields ints, so wrap each one back into bytes.
    return tuple(bytes([byte]) for byte in the_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pairs(pieces: "tuple[bytes, ...]") -> "set[tuple[bytes, bytes]]":
    """Return the set of distinct adjacent ``(left, right)`` pairs in ``pieces``.

    Args:
        pieces: sequence of byte pieces.

    Returns:
        A set of 2-tuples; empty when ``pieces`` has fewer than two elements.
    """
    # zip() stops at the shorter argument, so pairing with the one-shifted
    # tuple yields exactly the adjacent pairs.
    return set(zip(pieces, pieces[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
def get_stats(
    vocab: "dict[tuple[bytes, ...], int]",
) -> "dict[tuple[bytes, bytes], int]":
    """Sum, for every adjacent pair of pieces, the frequencies of the words it occurs in.

    A pair occurring several times in one word is counted once per occurrence.
    The result is a ``collections.defaultdict(int)``.
    """
    counts = collections.defaultdict(int)
    for pieces, weight in vocab.items():
        for left, right in zip(pieces, pieces[1:]):
            counts[(left, right)] += weight
    return counts
|
||||||
|
|
||||||
|
|
||||||
|
def merge_vocab(
    pair: "tuple[bytes, bytes]", vocab: "dict[tuple[bytes, ...], int]"
) -> "dict[tuple[bytes, ...], int]":
    """Return a new vocab in which ``pair`` has been merged in every word via ``apply_bp``; frequencies are carried over."""
    merged = {}
    for pieces, freq in vocab.items():
        merged[apply_bp(pieces, pair)] = freq
    return merged
|
||||||
|
|
||||||
|
|
||||||
|
def apply_bp(
    pieces: "tuple[bytes, ...]", pair: "tuple[bytes, bytes]"
) -> "tuple[bytes, ...]":
    """Merge every non-overlapping occurrence of ``pair`` in ``pieces``.

    The scan runs left to right; each position where ``pair[0]`` is
    immediately followed by ``pair[1]`` is replaced by their concatenation.

    Args:
        pieces: the current pieces of a word.
        pair: the ``(first, second)`` pieces to merge.

    Returns:
        A new tuple of pieces with the merges applied.
    """
    first, second = pair
    merged = first + second
    new_pieces = []
    total = len(pieces)
    i = 0
    while i < total:
        try:
            j = pieces.index(first, i)
        except ValueError:
            # No further occurrence of `first`: keep the tail unchanged.
            new_pieces.extend(pieces[i:])
            break

        # Copy everything before the match; pieces[j] is `first` by construction.
        new_pieces.extend(pieces[i:j])
        i = j

        if i < total - 1 and pieces[i + 1] == second:
            new_pieces.append(merged)
            i += 2
        else:
            new_pieces.append(pieces[i])
            i += 1

    return tuple(new_pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def bpe(word: bytes, merges: "dict[bytes, int]") -> "tuple[bytes, ...]":
    """Tokenize ``word`` into pieces by applying the known ``merges``.

    Starting from single bytes, the adjacent pair whose concatenation has
    the lowest rank in ``merges`` is merged repeatedly until no adjacent
    pair is a known merge.

    Args:
        word: the byte string to tokenize.
        merges: mapping from merged token bytes to rank (lower merges first).

    Returns:
        The final pieces; their concatenation equals ``word``.
    """
    pieces = bytes_to_pieces(word)
    while len(pieces) > 1:
        # Scan the adjacent pairs in order instead of going through a set:
        # min() returns the first minimal element, so equal ranks resolve to
        # the leftmost pair rather than to hash-dependent set iteration order.
        pair = min(
            zip(pieces, pieces[1:]),
            key=lambda p: merges.get(p[0] + p[1], float("inf")),
        )

        if pair[0] + pair[1] not in merges:
            # Even the best candidate is unknown: nothing left to merge.
            break
        pieces = apply_bp(pieces, pair)
    return pieces
|
||||||
|
|
||||||
|
|
||||||
|
def best_pair_sort_key(
    item: "tuple[tuple[bytes, bytes], int]",
) -> "tuple[int, int, int, str, bytes]":
    """Build the sort key used to pick the next merge (smallest key wins).

    Ordering: highest frequency first, then fewest characters, then fewest
    bytes, then lexicographic order of the decoded text and finally of the
    raw bytes. This is slightly slower than sorting on frequency alone.

    Args:
        item: a ``((first, second), freq)`` entry.

    Returns:
        ``(-freq, str_length, byte_length, decoded_str, merged_bytes)``.
    """
    pair, freq = item
    pair_bytes = pair[0] + pair[1]
    # The merged bytes need not be valid UTF-8, so undecodable sequences are
    # replaced rather than raising.
    pair_str = pair_bytes.decode("utf-8", errors="replace")
    return -freq, len(pair_str), len(pair_bytes), pair_str, pair_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def learn_bpe(
    freqs: "dict[str,int]", existing: "dict[bytes, int]"
) -> "list[tuple[bytes, bytes]]":
    """Learn the merges needed to turn every word in ``freqs`` into one token.

    Each word is first tokenized with the ``existing`` merges. Then the best
    adjacent pair (see ``best_pair_sort_key``) is merged repeatedly until
    every word has collapsed into a single piece.

    Args:
        freqs: mapping from word to frequency.
        existing: already known merges, mapping token bytes to rank.

    Returns:
        The newly learned ``(first, second)`` merges, in the order selected.
    """
    vocab = {bpe(k.encode("utf-8"), existing): v for k, v in freqs.items()}
    # Words that are already a single piece need no further merges.
    vocab = {key: value for key, value in vocab.items() if len(key) > 1}
    new_merges = []
    with tqdm_logging_redirect() as bar:
        while vocab:
            pairs = get_stats(vocab)

            best, freq = min(pairs.items(), key=best_pair_sort_key)

            logger.debug(
                f'{best} ({(best[0]+best[1]).decode("utf-8", errors="replace")}) is selected as the next merge with freq {freq}'
            )
            new_merges.append(best)

            vocab = merge_vocab(best, vocab)
            # Drop words that this merge collapsed into a single piece.
            vocab = {key: value for key, value in vocab.items() if len(key) > 1}
            bar.update()

    return new_merges
|
||||||
|
|
||||||
|
|
||||||
|
def load_expand_vocab(path: Path) -> "dict[str, int]":
    """Read the words to add and their frequencies from a text file.

    Each non-blank line holds a word and a frequency separated by a tab.
    Words are NFC-normalized. A word that ``PAT_STR`` splits into more than
    one piece is skipped with a warning. A missing or non-integer frequency
    counts as 1, and repeated words have their frequencies summed.

    Args:
        path: path of the UTF-8 encoded vocabulary file.

    Returns:
        Mapping from normalized word to accumulated frequency.
    """
    freqs = {}
    with open(path, "r", encoding="utf8") as fin:
        for line in fin:
            # NOTE: strip() also removes leading whitespace, so a word that
            # starts with a space cannot be expressed in this file format.
            entry = line.strip()
            if not entry:
                continue
            # Split on the LAST tab only, so a line without a tab or with
            # extra tabs no longer crashes the 2-tuple unpacking.
            word, sep, freq = entry.rpartition("\t")
            if not sep:
                # No tab at all: the whole entry is the word, frequency unknown.
                word, freq = entry, ""
            word = unicodedata.normalize("NFC", word)
            parts = re.findall(PAT_STR, word)
            if len(parts) > 1:
                logger.warning(
                    f"{word} would be pre-tokenized to {parts}, and thus cannot be added to vocabulary"
                )
                continue
            try:
                freq = int(freq)
            except ValueError:
                freq = 1
            if word in freqs:
                logger.warning(
                    f"{word} is repeated, the frequency is increased by this much"
                )
                freqs[word] += freq
            else:
                freqs[word] = freq
    return freqs
|
||||||
|
|
||||||
|
|
||||||
|
def make_new_merges_by_bpe(
    input_path: Path, output_path: Path, expand_path: Path, start_id: int
) -> None:
    """Learn the merges for the words in ``expand_path`` and write them to ``output_path``.

    Args:
        input_path: existing tiktoken BPE file.
        output_path: file that receives only the newly learned merges.
        expand_path: file with the words to add and their frequencies.
        start_id: id of the first new merge. A falsy value or -1 means
            "directly after the existing merges"; a value smaller than the
            number of existing merges is raised to that number with a warning.
    """
    mergeable_ranks = load_tiktoken_bpe(input_path)

    if not start_id or start_id == -1:
        start_id = len(mergeable_ranks)
    elif start_id < len(mergeable_ranks):
        logger.warning(
            f"start_id {start_id} is too small, existing merges will be overridden, DONOT DO THIS. changed to {len(mergeable_ranks)}"
        )
        start_id = len(mergeable_ranks)

    expand_vocab_freqs = load_expand_vocab(expand_path)
    # Iterate over a copy of the keys because entries are deleted in the loop.
    for word in list(expand_vocab_freqs):
        token = word.encode("utf-8")
        if token in mergeable_ranks:
            logger.warning(f"word {word} is already a token {token}, skipping")
            del expand_vocab_freqs[word]

    logger.info(f"number of existing merges: {len(mergeable_ranks)}")
    logger.info(f"number of words for expanding: {len(expand_vocab_freqs)}")

    new_merges = learn_bpe(expand_vocab_freqs, mergeable_ranks)
    logger.info(f"number of newly learned merges: {len(new_merges)}")

    # Two different pairs can concatenate to the same bytes. Give each unique
    # token the next free id so the written ids stay contiguous; the output is
    # unchanged when all merged tokens are distinct.
    extra_merges = {}
    for first, second in new_merges:
        token = first + second
        if token not in extra_merges:
            extra_merges[token] = start_id + len(extra_merges)

    dump_tiktoken_bpe(extra_merges, output_path)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command-line arguments and run ``make_new_merges_by_bpe`` with them."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # The three positional paths, in the order they appear on the command line.
    positional_help = (
        ("input_path", "Path for input tiktoken file"),
        (
            "output_path",
            "Path for output tiktoken file, containing only the new merges",
        ),
        (
            "vocab_path",
            "Path for words needed adding, each line is a word and its frequency separated by \\t",
        ),
    )
    for dest, text in positional_help:
        parser.add_argument(dest, type=str, help=text)

    # If the extended vocabulary is for fine-tuning, set this correctly
    # (the default is for qwen.tiktoken). If it is for pretraining from
    # scratch, there is no need.
    parser.add_argument(
        "--start_id",
        type=int,
        default=151851,
        help="The start id for new merges. For Qwen tokenizer, this should be 151851 (skipping the existing special tokens)",
    )

    options = parser.parse_args()

    make_new_merges_by_bpe(
        options.input_path,
        options.output_path,
        options.vocab_path,
        options.start_id,
    )


if __name__ == "__main__":
    main()
|
||||||
6
examples/qwen_extra.tiktoken
Normal file
6
examples/qwen_extra.tiktoken
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
5LiA5Y+q54yr 151851
|
||||||
|
5Y+q54yr 151852
|
||||||
|
5piv5LiA5Y+q54yr 151853
|
||||||
|
5oiR5piv5LiA5Y+q54yr 151854
|
||||||
|
5L2g5piv5LiA5Y+q54yr 151855
|
||||||
|
5LuW5piv5LiA5Y+q54yr 151856
|
||||||
6
examples/qwen_extra_vocab.txt
Normal file
6
examples/qwen_extra_vocab.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
我是一只猫 20
|
||||||
|
你是一只猫 10
|
||||||
|
他是一只猫 5
|
||||||
|
一只 200
|
||||||
|
一只猫 100
|
||||||
|
夸张的 比喻手法 20
|
||||||
@@ -9,8 +9,6 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
||||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
|
||||||
"None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
|
"None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -414,6 +412,142 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"tokenizer._convert_token_to_id('<|extra_204|>')"
|
"tokenizer._convert_token_to_id('<|extra_204|>')"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Vocabulary Expansion"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"{'input_ids': [35946, 99639, 91680, 100472], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 19,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"tokenizer(\"我是一只猫\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[99639, 91680, 100472]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 20,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"tokenizer.encode(\"是一只猫\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 21,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B', trust_remote_code=True, extra_vocab_file=\"qwen_extra.tiktoken\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 22,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"151857"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 22,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"len(tokenizer)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"{'input_ids': [151854], 'token_type_ids': [0], 'attention_mask': [1]}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 23,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"tokenizer(\"我是一只猫\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'我是一只猫'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"tokenizer.decode(tokenizer.encode(\"我是一只猫\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 25,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[151853]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 25,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"tokenizer.encode(\"是一只猫\")"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@@ -432,7 +566,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.9"
|
"version": "3.10.12"
|
||||||
},
|
},
|
||||||
"orig_nbformat": 4
|
"orig_nbformat": 4
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -21,12 +21,21 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|||||||
from transformers.generation import GenerationConfig
|
from transformers.generation import GenerationConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _gc(forced: bool = False):
    """Run a garbage-collection pass and empty the CUDA cache if CUDA is available.

    Nothing is done when ``args.disable_gc`` is set, unless ``forced`` is true.
    """
    # NOTE(review): `args` is a module-level name, presumably the namespace
    # returned by _get_args() — confirm where it is assigned.
    global args
    if not args.disable_gc or forced:
        import gc

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI): # collects GPU memory
|
async def lifespan(app: FastAPI): # collects GPU memory
|
||||||
yield
|
yield
|
||||||
if torch.cuda.is_available():
|
_gc(forced=True)
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
@@ -392,6 +401,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|||||||
**gen_kwargs
|
**gen_kwargs
|
||||||
)
|
)
|
||||||
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
|
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
|
||||||
|
_gc()
|
||||||
|
|
||||||
response = trim_stop_words(response, stop_words)
|
response = trim_stop_words(response, stop_words)
|
||||||
if request.functions:
|
if request.functions:
|
||||||
choice_data = parse_response(response)
|
choice_data = parse_response(response)
|
||||||
@@ -453,6 +464,8 @@ async def predict(
|
|||||||
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
||||||
yield "[DONE]"
|
yield "[DONE]"
|
||||||
|
|
||||||
|
_gc()
|
||||||
|
|
||||||
|
|
||||||
def _get_args():
|
def _get_args():
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
@@ -476,6 +489,8 @@ def _get_args():
|
|||||||
help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
|
help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
|
||||||
" If you want other computers to access your server, use 0.0.0.0 instead.",
|
" If you want other computers to access your server, use 0.0.0.0 instead.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--disable-gc", action="store_true",
|
||||||
|
help="Disable GC after each response generated.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|||||||
@@ -125,3 +125,122 @@ The new default is the same as
|
|||||||
{'input_ids': [1350, 445, 151643, 899], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}
|
{'input_ids': [1350, 445, 151643, 899], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Vocabulary Expansion
|
||||||
|
|
||||||
|
> WARNING: Read carefully, be aware of what you are doing, and use at your own risk.
|
||||||
|
> There are certain caveats regarding how your vocabulary is produced.
|
||||||
|
|
||||||
|
The tokenizer of Qwen models is based on BPE, so you cannot expand the vocabulary simply by adding words to it.
|
||||||
|
The intermediate merges are needed for tokenization.
|
||||||
|
Please follow the steps to obtain such information.
|
||||||
|
|
||||||
|
1. Prepare a plain text file `qwen_extra_vocab.txt`, where each line contains a token and its frequency separated by `\t`.
|
||||||
|
|
||||||
|
An example is given below:
|
||||||
|
```
|
||||||
|
我是一只猫 20
|
||||||
|
你是一只猫 10
|
||||||
|
他是一只猫 5
|
||||||
|
一只 200
|
||||||
|
一只猫 100
|
||||||
|
夸张的 比喻手法 20
|
||||||
|
```
|
||||||
|
The frequencies are needed to compute the BPE.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
2. Prepare the base vocabulary file, e.g., `qwen.tiktoken`, and determine the start index for new tokens.
|
||||||
|
|
||||||
|
There are 151,643 regular tokens and 208 control tokens in the vocabulary for Qwen models.
|
||||||
|
For simplicity, the start index can be set as 151,851, which is the default value.
|
||||||
|
You can, of course, override the many inactive control tokens, but you will need to modify the tokenizer code.
|
||||||
|
|
||||||
|
3. Run the following command:
|
||||||
|
```
|
||||||
|
python add_merges.py qwen.tiktoken qwen_extra.tiktoken qwen_extra_vocab.txt
|
||||||
|
```
|
||||||
|
`add_merges.py` can be found [here](examples/add_merges.py).
|
||||||
|
It will learn the new merges based on the provided `qwen_extra_vocab.txt`.
|
||||||
|
The new tokens and their indices will be stored in `qwen_extra.tiktoken`.
|
||||||
|
Modify the paths as you wish.
|
||||||
|
|
||||||
|
It is a pure Python implementation, so please expect it to be slow if you are adding a lot of words.
|
||||||
|
|
||||||
|
Please note that not all words can be added due to pre-tokenization.
|
||||||
|
You will get warnings if you try to add such words:
|
||||||
|
```
|
||||||
|
WARNING - 夸张的 比喻手法 would be pre-tokenized to ['夸张的', ' 比喻手法'], and thus cannot be added to vocabulary
|
||||||
|
WARNING - word 一只 is already a token b'\xe4\xb8\x80\xe5\x8f\xaa', skipping
|
||||||
|
INFO - number of existing merges: 151643
|
||||||
|
INFO - number of words for expanding: 4
|
||||||
|
DEBUG - (b'\xe4\xb8\x80\xe5\x8f\xaa', b'\xe7\x8c\xab') (一只猫) is selected as the next merge with freq 100
|
||||||
|
DEBUG - (b'\xe5\x8f\xaa', b'\xe7\x8c\xab') (只猫) is selected as the next merge with freq 35
|
||||||
|
DEBUG - (b'\xe6\x98\xaf\xe4\xb8\x80', b'\xe5\x8f\xaa\xe7\x8c\xab') (是一只猫) is selected as the next merge with freq 35
|
||||||
|
DEBUG - (b'\xe6\x88\x91', b'\xe6\x98\xaf\xe4\xb8\x80\xe5\x8f\xaa\xe7\x8c\xab') (我是一只猫) is selected as the next merge with freq 20
|
||||||
|
DEBUG - (b'\xe4\xbd\xa0', b'\xe6\x98\xaf\xe4\xb8\x80\xe5\x8f\xaa\xe7\x8c\xab') (你是一只猫) is selected as the next merge with freq 10
|
||||||
|
DEBUG - (b'\xe4\xbb\x96', b'\xe6\x98\xaf\xe4\xb8\x80\xe5\x8f\xaa\xe7\x8c\xab') (他是一只猫) is selected as the next merge with freq 5
|
||||||
|
INFO - number of newly learned merges: 6
|
||||||
|
```
|
||||||
|
|
||||||
|
The `qwen_extra.tiktoken` will contain the following lines:
|
||||||
|
```
|
||||||
|
5LiA5Y+q54yr 151851
|
||||||
|
5Y+q54yr 151852
|
||||||
|
5piv5LiA5Y+q54yr 151853
|
||||||
|
5oiR5piv5LiA5Y+q54yr 151854
|
||||||
|
5L2g5piv5LiA5Y+q54yr 151855
|
||||||
|
5LuW5piv5LiA5Y+q54yr 151856
|
||||||
|
```
|
||||||
|
|
||||||
|
You may use the file as follows in your code:
|
||||||
|
``` python
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True, extra_vocab_file="qwen_extra.tiktoken")
|
||||||
|
|
||||||
|
>>> len(tokenizer)
|
||||||
|
151857
|
||||||
|
|
||||||
|
>>> tokenizer("我是一只猫")
|
||||||
|
{'input_ids': [151854], 'token_type_ids': [0], 'attention_mask': [1]}
|
||||||
|
```
|
||||||
|
Note: You need the latest tokenizer code, i.e., after 2023-10-08, to use the `extra_vocab_file` argument.
|
||||||
|
Otherwise, you need to manually append the content of `qwen_extra.tiktoken` to `qwen.tiktoken` (whose path varies with your configuration).
|
||||||
|
|
||||||
|
Certainly, you will need to finetune the model for the new tokens to work.
|
||||||
|
|
||||||
|
|
||||||
|
### Caveats
|
||||||
|
|
||||||
|
|
||||||
|
The tokenizer of Qwen operates directly on UTF-8 byte sequences, unlike others, e.g., SentencePiece that operates on UTF-8 codepoints/characters and falls back to UTF-8 byte sequences for the unknown (IIRC).
|
||||||
|
The thing is if the frequencies are computed on limited data, the UTF-8 codepoint boundary may not be correctly recognized.
|
||||||
|
In theory, it could be a problem for fine-tuned models using the expanded vocabulary with limited data.
|
||||||
|
|
||||||
|
For example, it could happen that `b'\x80\xe5'` might be merged first for the UTF-8 byte sequence `b'\xe4\xb8\x80\xe5\x8f\xaa'` of the string `一只`, across the UTF-8 codepoint of `一`(`b'\xe4\xb8\x80'`) and `只` (`b'\xe5\x8f\xaa'`).
|
||||||
|
Normally, this would work just fine for known words, but for actually unknown words, unusual merges may happen, which may not be well understood by the pre-trained model.
|
||||||
|
|
||||||
|
Our advice is that to be safe, you should gather the UTF-8 codepoints from all the words you need to add, and also add them to the file with frequencies higher than the sum of the frequencies of the corresponding words.
|
||||||
|
But since Qwen has most of the Chinese words, it could be okay to just add the Chinese words alone.
|
||||||
|
|
||||||
|
For curious minds, you will also notice that in the given example, `一只` is a token and `只猫` is also learned as a new token.
|
||||||
|
The reason is that `是一` is also a token in Qwen and has higher merging priority than `一只`, such that the merging path for `是|一|只|猫` is `是一|只|猫 -> 是一|只猫 -> 是一只猫` (omitting the UTF-8 byte merges).
|
||||||
|
|
||||||
|
This is characteristic of plain BPE: it is based solely on distribution, meaning it does not have knowledge of which bytes can form a valid UTF-8 codepoint, character, or meaningful word.
|
||||||
|
|
||||||
|
The byproduct is that text may be sub-tokenized differently in different contexts, even for words containing only ASCII characters.
|
||||||
|
```python
|
||||||
|
>>> tokenizer.tokenize("Panda")
|
||||||
|
[b'P', b'anda']
|
||||||
|
|
||||||
|
>>> tokenizer.tokenize(" Panda")
|
||||||
|
[b' Panda']
|
||||||
|
|
||||||
|
>>> tokenizer.tokenize("Pandas")
|
||||||
|
[b'P', b'andas']
|
||||||
|
|
||||||
|
>>> tokenizer.tokenize(" Pandas")
|
||||||
|
[b' Pand', b'as']
|
||||||
|
```
|
||||||
|
This simply suggests that those combinations occur more frequently in the data.
|
||||||
|
If you have a vast amount of training data, it should not be a problem.
|
||||||
11
web_demo.py
11
web_demo.py
@@ -107,6 +107,13 @@ def _parse_text(text):
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _gc():
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
def _launch_demo(args, model, tokenizer, config):
|
def _launch_demo(args, model, tokenizer, config):
|
||||||
|
|
||||||
def predict(_query, _chatbot, _task_history):
|
def predict(_query, _chatbot, _task_history):
|
||||||
@@ -138,9 +145,7 @@ def _launch_demo(args, model, tokenizer, config):
|
|||||||
def reset_state(_chatbot, _task_history):
|
def reset_state(_chatbot, _task_history):
|
||||||
_task_history.clear()
|
_task_history.clear()
|
||||||
_chatbot.clear()
|
_chatbot.clear()
|
||||||
import gc
|
_gc()
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return _chatbot
|
return _chatbot
|
||||||
|
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
|
|||||||
Reference in New Issue
Block a user