diff --git a/discojs/src/models/onnx.ts b/discojs/src/models/onnx.ts index 8227b1ebd..449beb039 100644 --- a/discojs/src/models/onnx.ts +++ b/discojs/src/models/onnx.ts @@ -1,14 +1,17 @@ -import { AutoModelForCausalLM, PreTrainedModel, Tensor } from '@xenova/transformers'; -import { Model } from './index.js'; -import type { WeightsContainer } from '../index.js'; -import { List } from 'immutable'; -import type { CausalLMOutput} from '@xenova/transformers'; -import type { GenerationConfig as TFJSGenerationConfig } from './gpt/config.js'; -import { DefaultGenerationConfig } from './gpt/config.js'; +import { + AutoModelForCausalLM, + PreTrainedModel, + Tensor, +} from "@xenova/transformers"; +import { Model } from "./index.js"; +import type { WeightsContainer } from "../index.js"; +import { List } from "immutable"; +import type { CausalLMOutput } from "@xenova/transformers"; +import type { GenerationConfig as TFJSGenerationConfig } from "./gpt/config.js"; +import { DefaultGenerationConfig } from "./gpt/config.js"; import type { Batched, DataFormat } from "../index.js"; - -export class ONNXModel extends Model<'text'> { +export class ONNXModel extends Model<"text"> { private model: PreTrainedModel; private constructor(model: PreTrainedModel) { @@ -16,12 +19,12 @@ export class ONNXModel extends Model<'text'> { this.model = model; } - static async init_pretrained(modelName = 'Xenova/gpt2'): Promise { + static async init_pretrained(modelName = "Xenova/gpt2"): Promise { const model = await AutoModelForCausalLM.from_pretrained(modelName); return new ONNXModel(model); } - getConfig(): Record { + get config(): Record { return this.model.config as Record; } @@ -30,62 +33,76 @@ export class ONNXModel extends Model<'text'> { options?: Partial ): Promise> { const config = Object.assign({}, DefaultGenerationConfig, options); - + return List( await Promise.all( - batch.map(tokens => this.#predictSingle(tokens, config)) + batch.map((tokens) => this.#predictSingle(tokens, config)) ) ); } - async #predictSingle( tokens: DataFormat.ModelEncoded["text"][0], config: TFJSGenerationConfig ): Promise { - const contextLength = (this.model.config as { max_position_embeddings?: number }).max_position_embeddings ?? 1024; + const contextLength = + (this.model.config as { max_position_embeddings?: number }) + .max_position_embeddings ?? 1024; const truncated = tokens.slice(-contextLength).toArray(); - + if (truncated.length === 0) { - throw new Error('Token list is empty. Cannot run generate().'); + throw new Error("Token list is empty. Cannot run generate()."); } - - const input_ids = new Tensor('int64', truncated.map(BigInt), [1, truncated.length]); - - const output = await this.model.generate(input_ids, { + + const input_ids = new Tensor("int64", truncated.map(BigInt), [ + 1, + truncated.length, + ]); + + const output = (await this.model.generate(input_ids, { max_new_tokens: 1, temperature: config.temperature, do_sample: config.doSample, top_k: config.topk, - }) as number[][]; - - if (!Array.isArray(output) || output.length === 0 || !Array.isArray(output[0])) { - throw new Error('ONNX model.generate() did not return valid sequences.'); - } - + })) as number[][]; + + if ( + !Array.isArray(output) || + output.length === 0 || + !Array.isArray(output[0]) + ) { + throw new Error("ONNX model.generate() did not return valid sequences."); + } + const predicted_id = output[0].at(-1) as number; return Number(predicted_id); - } - - async getLogits(batch: List>): Promise { - const input_ids_array: number[][] = batch.toArray().map(seq => seq.toArray()); + const input_ids_array: number[][] = batch + .toArray() + .map((seq) => seq.toArray()); const attention_mask_array: number[][] = input_ids_array.map( (seq): number[] => new Array(seq.length).fill(1) ); - + const input_ids_flat = input_ids_array.flat(); const attention_mask_flat = attention_mask_array.flat(); const shape = [input_ids_array.length, input_ids_array[0].length]; - + // use BigInt for int64 compatibility - const input_ids = new Tensor('int64', input_ids_flat.map(BigInt), shape); - const attention_mask = new Tensor('int64', attention_mask_flat.map(BigInt), shape); + const input_ids = new Tensor("int64", input_ids_flat.map(BigInt), shape); + const attention_mask = new Tensor( + "int64", + attention_mask_flat.map(BigInt), + shape + ); // run model forward - const outputs = await this.model.forward({ input_ids, attention_mask }) as CausalLMOutput; + const outputs = (await this.model.forward({ + input_ids, + attention_mask, + })) as CausalLMOutput; return outputs.logits; } @@ -93,18 +110,19 @@ export class ONNXModel extends Model<'text'> { await Promise.resolve(); // dummy await const yieldFlag = false; if (yieldFlag) yield undefined as never; // satisfy 'require-yield' - throw new Error('Training not supported for ONNX models'); + throw new Error("Training not supported for ONNX models"); } get weights(): WeightsContainer { - throw new Error('Weights access not supported in ONNX models'); + throw new Error("Weights access not supported in ONNX models"); } set weights(_: WeightsContainer) { - throw new Error('Weights setting not supported in ONNX models'); + throw new Error("Weights setting not supported in ONNX models"); } [Symbol.dispose](): void { // Dispose of the model to free up memory - void this.model.dispose();} + void this.model.dispose(); + } } diff --git a/package-lock.json b/package-lock.json index 52b95cdc7..daa6993c8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -4789,9 +4789,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001727", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001727.tgz", - "integrity": "sha512-pB68nIHmbN6L/4C6MH1DokyR3bYqFwjaSs/sWDHGj4CTcFtQUQMuJftVwWkXq7mNWOybD3KhUv3oWHoGxgP14Q==", + "version": "1.0.30001764", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001764.tgz", + "integrity": "sha512-9JGuzl2M+vPL+pz70gtMF9sHdMFbY9FJaQBi186cHKH3pSzDvzoUJUPV6fqiKIMyXbud9ZLg4F3Yza1vJ1+93g==", "dev": true, "funding": [ { diff --git a/webapp/src/assets/svg/CleanIcon.vue b/webapp/src/assets/svg/CleanIcon.vue new file mode 100644 index 000000000..864c91f87 --- /dev/null +++ b/webapp/src/assets/svg/CleanIcon.vue @@ -0,0 +1,24 @@ + + diff --git a/webapp/src/assets/svg/MessageArrow.vue b/webapp/src/assets/svg/MessageArrow.vue new file mode 100644 index 000000000..9db1de518 --- /dev/null +++ b/webapp/src/assets/svg/MessageArrow.vue @@ -0,0 +1,23 @@ + + diff --git a/webapp/src/assets/svg/StopIcon.vue b/webapp/src/assets/svg/StopIcon.vue new file mode 100644 index 000000000..72148ac0c --- /dev/null +++ b/webapp/src/assets/svg/StopIcon.vue @@ -0,0 +1,23 @@ + + diff --git a/webapp/src/components/testing/Benchmarcks.vue b/webapp/src/components/testing/Benchmarcks.vue new file mode 100644 index 000000000..1eeb5c8b3 --- /dev/null +++ b/webapp/src/components/testing/Benchmarcks.vue @@ -0,0 +1,3 @@ + diff --git a/webapp/src/components/testing/Chat.vue b/webapp/src/components/testing/Chat.vue new file mode 100644 index 000000000..b8c364e06 --- /dev/null +++ b/webapp/src/components/testing/Chat.vue @@ -0,0 +1,589 @@ +