Replace js-tiktoken BPE merge algorithm with faster heap based algorithm#101
Replace js-tiktoken BPE merge algorithm with faster heap-based algorithm (#101) — mikolalysenko wants to merge 3 commits into dqbd:main from their fork
Conversation
|
@dqbd are we considering merging this? We have noticed performance issues when the input is large as well |
|
@mikolalysenko have you benchmarked this vs baseline vs wasm? curious for 1, 1000, 1M, 100M tokens. |
If there is a perf difference at different scales, we should consider toggling algorithms based on length. |
|
It would be lovely if this got merged - this module has been causing issues for my application in prod. |
|
+1 |
|
I am not sure if this fully works, but modify the implementation as follows:

import base64 from "base64-js";
import type { TiktokenModel } from "./ranks/ranks";
import { never } from "./utils";
// Node in the doubly linked list of current token parts; the same object
// doubles as a binary-heap entry representing the candidate merge of this
// part with its `listNext` neighbor.
type BPEMergeNode = {
  listNext: BPEMergeNode | null; // right neighbor in the part list
  listPrev: BPEMergeNode | null; // left neighbor in the part list
  deleted: boolean; // part was absorbed into its left neighbor; skip on pop
  updated: boolean; // rank changed while queued; re-push with updatedRank on pop
  updatedRank: number; // replacement rank applied when `updated` is handled
  removed: boolean; // node is currently NOT in the heap
  rank: number; // rank of merging this part with listNext (the heap key)
  start: number; // inclusive byte offset of the part within `piece`
  end: number; // exclusive byte offset of the part within `piece`
};
// Orders merge candidates by ascending rank; ties broken by start offset
// so that equal-rank merges are applied left to right.
function compareNode(a: BPEMergeNode, b: BPEMergeNode) {
  const byRank = a.rank - b.rank;
  return byRank !== 0 ? byRank : a.start - b.start;
}
// Exchange the heap entries at indices i and j in place.
function swap(heap: BPEMergeNode[], i: number, j: number) {
  [heap[i], heap[j]] = [heap[j], heap[i]];
}
// Standard binary min-heap insertion: append the element, then sift it up
// until its parent is no larger (per compareNode).
function heapPush(heap: BPEMergeNode[], part: BPEMergeNode) {
  heap.push(part);
  let i = heap.length - 1;
  while (i > 0) {
    const parent = (i - 1) >> 1;
    if (compareNode(heap[i], heap[parent]) >= 0) {
      break; // heap property restored
    }
    swap(heap, i, parent);
    i = parent;
  }
}
// Standard binary min-heap extraction: return the root, move the last
// element to the root, and sift it down to its correct position.
function heapPop(heap: BPEMergeNode[]) {
  if (heap.length === 0) {
    return undefined; // nothing to pop
  }
  const top = heap[0];
  const tail = heap.pop();
  if (heap.length > 0 && tail) {
    heap[0] = tail;
    let i = 0;
    for (;;) {
      const left = 2 * i + 1;
      const right = left + 1;
      let min = i;
      if (left < heap.length && compareNode(heap[left], heap[min]) < 0) {
        min = left;
      }
      if (right < heap.length && compareNode(heap[right], heap[min]) < 0) {
        min = right;
      }
      if (min === i) {
        break; // both children are >= current; heap property restored
      }
      swap(heap, i, min);
      i = min;
    }
  }
  return top;
}
/**
 * Greedy byte-pair merge over `piece`, returning the [start, end) byte spans
 * of the resulting tokens, in order.
 *
 * Data structures:
 *  - a doubly linked list of parts (one node per current token), and
 *  - a binary min-heap of merge candidates ordered by (rank, start).
 * A heap entry represents merging that part with its `listNext` neighbor;
 * the rank is looked up in `ranks`, keyed by the comma-joined byte values of
 * the combined span.
 *
 * Stale heap entries are resolved lazily on pop:
 *  - `deleted`: the part was absorbed into its left neighbor — skip it.
 *  - `updated`: the candidate's rank changed while it sat in the heap — on
 *    pop, re-push it with `updatedRank` instead of merging.
 *  - `removed`: the node is currently NOT in the heap, so a rank change must
 *    push it rather than set the `updated` flag.
 */
function bytePairMerge(
  piece: Uint8Array,
  ranks: Map<string, number>
): Array<{ start: number; end: number }> {
  // One part per input byte; all start out of the heap (`removed: true`)
  // with rank Infinity until a valid merge is recorded below.
  const parts: BPEMergeNode[] = Array.from(
    { length: piece.length },
    (_, i) => ({
      start: i,
      end: i + 1,
      rank: Infinity,
      deleted: false,
      updated: false,
      updatedRank: 0,
      removed: true,
      listNext: null,
      listPrev: null,
    })
  );
  if (parts.length === 0) {
    return [];
  }
  // Initialize linked list (parts[-1] / parts[length] are undefined, so the
  // ?? null fallback handles both ends).
  const head = parts[0];
  for (let i = 0; i < parts.length; ++i) {
    parts[i].listPrev = parts[i - 1] ?? null;
    parts[i].listNext = parts[i + 1] ?? null;
  }
  // Initialize heap with every adjacent pair that is a known merge.
  const heap: BPEMergeNode[] = [];
  for (let i = 0; i < parts.length - 1; ++i) {
    const slice = piece.slice(parts[i].start, parts[i + 1].end);
    const rank = ranks.get(slice.join(","));
    if (rank == null) continue;
    const part = parts[i];
    part.removed = false;
    part.rank = rank;
    heapPush(heap, part);
  }
  while (heap.length > 0) {
    const part = heapPop(heap);
    if (!part) break;
    // Absorbed parts, and the last part (no right neighbor), cannot merge.
    if (part.deleted || !part.listNext) {
      continue;
    }
    // Rank changed while queued: re-queue at the new rank instead of merging.
    if (part.updated) {
      part.rank = part.updatedRank;
      part.updated = false;
      heapPush(heap, part);
      continue;
    }
    // Verify the merge is still valid — a stale entry's span may no longer
    // look up to the rank it was queued with (the fresh candidate, if any,
    // was pushed separately).
    const currentSlice = piece.slice(part.start, part.listNext.end);
    const currentRank = ranks.get(currentSlice.join(","));
    if (currentRank !== part.rank) {
      continue;
    }
    // Perform merge: widen this part over its right neighbor and unlink it.
    part.end = part.listNext.end;
    part.listNext.deleted = true;
    part.listNext = part.listNext.listNext;
    if (part.listNext) {
      part.listNext.listPrev = part;
    }
    // Check for new possible merges created by the widened part:
    // (a) this part with its new right neighbor...
    let addedNewMerge = false;
    if (part.listNext) {
      const slice = piece.slice(part.start, part.listNext.end);
      const rank = ranks.get(slice.join(","));
      if (rank != null) {
        part.rank = rank;
        part.removed = false;
        heapPush(heap, part);
        addedNewMerge = true;
      }
    }
    // ...and (b) the left neighbor with this (now wider) part.
    if (part.listPrev && !part.listPrev.deleted) {
      const slice = piece.slice(part.listPrev.start, part.end);
      const rank = ranks.get(slice.join(","));
      if (rank != null) {
        if (!part.listPrev.removed) {
          // Still queued: flag it so the stale rank is corrected on pop.
          part.listPrev.updated = true;
          part.listPrev.updatedRank = rank;
        } else {
          // Not in the heap: push it directly with the fresh rank.
          part.listPrev.removed = false;
          part.listPrev.rank = rank;
          heapPush(heap, part.listPrev);
        }
        addedNewMerge = true;
      }
    }
    if (!addedNewMerge) {
      part.removed = true;
    }
  }
  // Walk the surviving list nodes from the head to emit final token spans.
  const result: Array<{ start: number; end: number }> = [];
  let current: BPEMergeNode | null = head;
  while (current) {
    if (!current.deleted) {
      result.push({ start: current.start, end: current.end });
    }
    current = current.listNext;
  }
  return result;
}
// rest of code unchanged

It will then pass the tests, since in the current PR there is a decoding error with non-latin chars:

FAIL test/compatibility.test.ts > LiteTokenizer matches the behavior of tiktoken > Emojis and non-latin characters
js-tiktoken:test: AssertionError: expected [ 9468, 239, 102, 378, 235, …(109) ] to deeply equal [ 9468, 239, 102, 378, 235, …(111) ]
js-tiktoken:test: ❯ test/compatibility.test.ts:50:38
js-tiktoken:test: 48|
js-tiktoken:test: 49| for (const text of fixtures) {
js-tiktoken:test: 50| expect([...lite.encode(text)]).toEqual([...full.encode(text)]);
js-tiktoken:test: | ^
js-tiktoken:test: 51| }
js-tiktoken:test: 52| });

With this test file:

import { test, expect, describe, } from "vitest";
import { getEncoding } from "../src/index";
// Wall-clock budget (ms) for encoding the stress-test input.
const TARGET_TIME = 30_000;
// Crazy high number to test the limits of the lite tokenizer.
const TARGET_STRING_LENGTH = 1_000_000;
// Random chars with code points 0-255 to provoke worst-case BPE behavior.
const EVIL_STRING = Array.from(
  { length: TARGET_STRING_LENGTH },
  () => String.fromCharCode((Math.random() * 256) | 0)
).join("");
// This test will be flaky - so perhaps we should run it externally
// from the main CI pipeline since it depends on the machine it's run on
describe(`Lite tokenizer resolves ${EVIL_STRING.length / 1000}K string in acceptable time (${TARGET_TIME}ms)`, () => {
  const lite = getEncoding("cl100k_base");
  // Timing guard: the heap-based merge should stay under TARGET_TIME even on
  // pathological random input (the previous algorithm was quadratic here).
  test("Test lite performance", () => {
    const start = Date.now();
    const result = lite.encode(EVIL_STRING); // NOTE(review): result is unused; only the elapsed time is checked
    const end = Date.now();
    console.log(`Lite encoding time: ${end - start}ms`);
    expect(end - start).toBeLessThanOrEqual(TARGET_TIME);
  });
  // Round-trip check: decode(encode(s)) must reproduce the input exactly.
  // NOTE(review): assumes chars 0-255 survive the encoder's byte mapping
  // losslessly — confirm against the tokenizer's UTF-8 handling.
  test("Test encoding/decoding", () => {
    const result = lite.encode(EVIL_STRING);
    const decoded = lite.decode(result);
    expect(decoded).toEqual(EVIL_STRING);
  });
});With a crazy length of |
There are several open issues noting that in the worst case BPE merge algorithm in js-tiktoken takes quadratic time in the number of input characters for certain pathological inputs.
This PR fixes this problem using a heap to avoid recalculating the ranks of all tokens at each character. This technique should also work for the rust/wasm tokenizer but it seems less important in those cases since the native parsers are already pretty fast.
I also added a new test fixture and an example string which causes pathological behavior.
Related issues:
Note: This should be a mild CVE since an attacker may use this behavior to cause a denial of service against services that check user input with js-tiktoken.