Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions src/__tests__/embeddings/ollama.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ describe('OllamaBackend', () => {
body: JSON.stringify({
model: DEFAULT_OLLAMA_MODEL,
input: ['test text'],
keep_alive: '10m',
}),
})
);
Expand Down Expand Up @@ -211,6 +212,7 @@ describe('OllamaBackend', () => {
body: JSON.stringify({
model: DEFAULT_OLLAMA_MODEL,
input: ['text1', 'text2', 'text3'],
keep_alive: '10m',
}),
})
);
Expand Down Expand Up @@ -248,22 +250,24 @@ describe('OllamaBackend', () => {
expect(result).toEqual([[0.1], [0.2], [0.3], [0.4], [0.5]]);
});

it('should use default batch size of 100', async () => {
const embeddings = Array.from({ length: 150 }, (_, i) => [i * 0.01]);
const mockFetch = vi
.fn()
.mockResolvedValueOnce(createOllamaBatchEmbeddingResponse(embeddings.slice(0, 100)))
.mockResolvedValueOnce(createOllamaBatchEmbeddingResponse(embeddings.slice(100)));
it('should use default batch size of 10', async () => {
const mockFetch = vi.fn().mockImplementation(async (_url, options) => {
const body = JSON.parse(options.body as string);
const inputLen = body.input.length;
return createOllamaBatchEmbeddingResponse(
Array.from({ length: inputLen }, (_, i) => [i * 0.01])
);
});
vi.stubGlobal('fetch', mockFetch);

// Create backend without custom batchSize (should use default 100)
// Create backend without custom batchSize (should use default 10)
const backend = new OllamaBackend({ backend: 'ollama' });
const texts = Array.from({ length: 150 }, (_, i) => `text${i}`);
const texts = Array.from({ length: 25 }, (_, i) => `text${i}`);
const result = await backend.embedBatch(texts);

// Should make 2 batch requests (100 + 50)
expect(mockFetch).toHaveBeenCalledTimes(2);
expect(result).toHaveLength(150);
// Should make 3 batch requests (10+10+5)
expect(mockFetch).toHaveBeenCalledTimes(3);
expect(result).toHaveLength(25);
});

it('should process batches in parallel based on concurrency', async () => {
Expand Down
11 changes: 11 additions & 0 deletions src/__tests__/mocks/fetch.mock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export interface MockFetchResponse {
ok: boolean;
status: number;
statusText?: string;
headers?: { get: (name: string) => string | null };
json?: () => Promise<unknown>;
text?: () => Promise<string>;
}
Expand All @@ -17,13 +18,19 @@ export function createMockFetch(handler: FetchHandler) {
return vi.fn().mockImplementation(handler);
}

/** Helper to create mock headers */
function createMockHeaders(): { get: (name: string) => string | null } {
return { get: () => null };
}

/**
* Creates a mock fetch that returns a successful JSON response
*/
export function createSuccessFetch(data: unknown): ReturnType<typeof createMockFetch> {
return createMockFetch(async () => ({
ok: true,
status: 200,
headers: createMockHeaders(),
json: async () => data,
text: async () => JSON.stringify(data),
}));
Expand All @@ -41,6 +48,7 @@ export function createErrorFetch(
ok: false,
status,
statusText,
headers: createMockHeaders(),
json: async () => ({ error: body }),
text: async () => body,
}));
Expand Down Expand Up @@ -76,6 +84,7 @@ export function createJinaEmbeddingResponse(embeddings: number[][]): MockFetchRe
return {
ok: true,
status: 200,
headers: createMockHeaders(),
json: async () => ({
data: embeddings.map((embedding) => ({ embedding })),
}),
Expand All @@ -93,6 +102,7 @@ export function createOllamaEmbeddingResponse(embedding: number[]): MockFetchRes
return {
ok: true,
status: 200,
headers: createMockHeaders(),
json: async () => ({ embedding }),
text: async () => JSON.stringify({ embedding }),
};
Expand All @@ -105,6 +115,7 @@ export function createOllamaBatchEmbeddingResponse(embeddings: number[][]): Mock
return {
ok: true,
status: 200,
headers: createMockHeaders(),
json: async () => ({ embeddings }),
text: async () => JSON.stringify({ embeddings }),
};
Expand Down
Loading