/**
 * Trigo AI Agent - Language Model-based Move Selection (Frontend/Backend Common)
 *
 * Platform-agnostic AI agent that accepts an ONNX session from platform-specific code.
 * No direct dependency on onnxruntime packages - uses dependency injection pattern.
 *
 * Uses ONNX language model to score and select moves by:
 * 1. Getting all valid moves for current position
 * 2. Scoring each move by appending it to TGN and computing token probabilities
 * 3. Selecting move with highest probability (argmax)
 */

import { ModelInferencer } from "./modelInferencer";
import { TrigoGame, StoneType } from "./trigo/game";
import type { Move, Stone } from "./trigo/types";

/**
 * Configuration for the AI agent
 */
export interface TrigoAgentConfig {
  vocabSize?: number;
  seqLen?: number;
  temperature?: number;
}

/**
 * Move score result
 */
export interface MoveScore {
  move: Move;
  /** exp(logProb): the (unnormalized over moves) probability of the move tokens. */
  score: number;
  /** Sum of per-token log-probabilities of the move's TGN characters. */
  logProb: number;
}

/**
 * Trigo AI Agent for move generation
 * Compatible with both frontend (onnxruntime-web) and backend (onnxruntime-node)
 */
export class TrigoAgent {
  private inferencer: ModelInferencer;

  constructor(inferencer: ModelInferencer) {
    this.inferencer = inferencer;
  }

  /**
   * Check if agent is initialized.
   *
   * Only checks that an inferencer object was supplied; it does not verify
   * that the inferencer's underlying session is ready.
   */
  isInitialized(): boolean {
    // `!= null` also rejects `undefined`, which plain-JS callers may pass.
    return this.inferencer != null;
  }

  /**
   * Convert Stone type to player string.
   *
   * @throws Error if the stone is neither BLACK nor WHITE.
   */
  private stoneToPlayer(stone: Stone): "black" | "white" {
    if (stone === StoneType.BLACK) return "black";
    if (stone === StoneType.WHITE) return "white";
    throw new Error(`Invalid stone type: ${stone}`);
  }

  /**
   * Convert string to token IDs by mapping each character to its char code
   * (byte-level encoding with direct ASCII mapping).
   *
   * NOTE(review): every character, including newline (10), is mapped to its
   * char code here; non-ASCII characters would produce ids above 127 —
   * presumably TGN is pure ASCII, confirm against the tokenizer/vocab.
   */
  private stringToTokens(text: string): number[] {
    return Array.from(text).map((char) => char.charCodeAt(0));
  }

  /**
   * Compute softmax probabilities over the first `vocabSize` logits.
   * Subtracts the max logit before exponentiating for numerical stability.
   */
  private softmax(logits: Float32Array, vocabSize: number): Float32Array {
    const probs = new Float32Array(vocabSize);
    let maxLogit = -Infinity;

    // Find max for numerical stability
    for (let i = 0; i < vocabSize; i++) {
      if (logits[i] > maxLogit) {
        maxLogit = logits[i];
      }
    }

    // Compute exp and sum
    let sum = 0;
    for (let i = 0; i < vocabSize; i++) {
      probs[i] = Math.exp(logits[i] - maxLogit);
      sum += probs[i];
    }

    // Normalize
    for (let i = 0; i < vocabSize; i++) {
      probs[i] /= sum;
    }

    return probs;
  }

  /**
   * Render a move for human-readable logging.
   * (Interpolating the Move object directly would print "[object Object]".)
   */
  private formatMove(move: Move): string {
    if (move.isPass) {
      return `${move.player} pass`;
    }
    return `${move.player} (${move.x}, ${move.y}, ${move.z})`;
  }

  /**
   * Score a candidate move by computing token probabilities.
   *
   * Clones the game, applies the move to the clone, generates the clone's TGN,
   * and sums the model's log-probabilities for the trailing move characters.
   *
   * @returns The summed log-probability of the move tokens; -1000 if the move
   *          is malformed or rejected by the game; -100 if no move characters
   *          could be extracted from the resulting TGN.
   */
  async scoreMove(game: TrigoGame, move: Move): Promise<number> {
    // Clone so the caller's game state is never mutated
    const clonedGame = game.clone();

    // Apply the move to the cloned game
    let success: boolean;
    if (move.isPass) {
      success = clonedGame.pass();
    } else if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
      success = clonedGame.drop({ x: move.x, y: move.y, z: move.z });
    } else {
      // Invalid move format
      return -1000;
    }

    if (!success) {
      // Invalid move, return very low probability
      return -1000;
    }

    // Generate TGN for the position after the move (only the clone's TGN is needed)
    const newTGN = clonedGame.toTGN().trim();

    // The move is taken to be the trailing run of move characters in the new TGN
    const moveTokens = this.extractMoveTokens(newTGN);

    if (moveTokens.length === 0) {
      // Could not extract move, return low probability
      return -100;
    }

    // Convert new TGN to tokens
    const tokens = this.stringToTokens(newTGN);

    // Get configuration
    const config = this.inferencer.getConfig();
    const seqLen = config.seqLen;
    const vocabSize = config.vocabSize;

    // Keep only the most recent seqLen tokens.
    // NOTE(review): the inferencer prepends START_TOKEN, so a full-length
    // `tokens` array yields seqLen + 1 inputs; confirm whether the inferencer
    // tolerates that or whether this should truncate to seqLen - 1.
    if (tokens.length > seqLen) {
      tokens.splice(0, tokens.length - seqLen);
    }

    // Run inference (START_TOKEN will be prepended by inferencer)
    const logits = await this.inferencer.runInference(tokens);

    // Input sequence is [START_TOKEN, ...tokens, PAD, ...], so tokens[j] sits at
    // position j + 1. Logits at position p predict the token at p + 1, hence the
    // distribution for tokens[j] is read from logits row j.
    const moveStartInTokens = tokens.length - moveTokens.length;

    let logProb = 0;
    for (let i = 0; i < moveTokens.length; i++) {
      // Position of this move token in the (possibly truncated) tokens array
      const tokenPos = moveStartInTokens + i;

      // Skip positions lost to truncation
      if (tokenPos < 0 || tokenPos >= tokens.length) continue;

      const offset = tokenPos * vocabSize;
      const tokenLogits = logits.slice(offset, offset + vocabSize);
      const probs = this.softmax(tokenLogits, vocabSize);

      const prob = probs[moveTokens[i]];
      if (prob > 0) {
        logProb += Math.log(prob);
      } else {
        // Probability underflowed to zero; assign a very low log-probability
        logProb += -100;
      }
    }

    return logProb;
  }

  /**
   * Extract the move tokens from the end of a TGN string.
   *
   * Takes the trailing run of characters matching `[Pa-z0]` (which covers
   * lowercase coordinate strings and "Pass") and tokenizes it. Returns an empty
   * array if the TGN does not end in such characters.
   * NOTE(review): relies on the move being the last thing in the trimmed TGN — confirm.
   */
  private extractMoveTokens(tgn: string): number[] {
    const moveCapture = tgn.match(/[Pa-z0]+$/);
    return this.stringToTokens(moveCapture ? moveCapture[0] : "");
  }

  /**
   * Select the best move using the language model.
   *
   * Scores every valid placement plus a pass and returns the one with the
   * highest log-probability (argmax).
   *
   * @returns The best move, or null if no candidate exists (cannot currently
   *          happen, since pass is always a candidate).
   * @throws Error if the agent has no inferencer.
   */
  async selectBestMove(game: TrigoGame): Promise<Move | null> {
    if (!this.isInitialized()) {
      throw new Error("Agent not initialized. Pass initialized inferencer to constructor.");
    }

    console.log("[TrigoAgent] Selecting move...");

    // Get current player as string
    const currentPlayer = this.stoneToPlayer(game.getCurrentPlayer());

    // Get all valid moves
    const validMoves: Move[] = game.validMovePositions().map((pos) => ({
      x: pos.x,
      y: pos.y,
      z: pos.z,
      player: currentPlayer
    }));
    validMoves.push({ player: currentPlayer, isPass: true }); // Add pass move

    // Defensive guard: pass is always pushed above, so this should not trigger.
    if (validMoves.length === 0) {
      console.log("[TrigoAgent] No valid moves available");
      return null;
    }

    console.log(`[TrigoAgent] Evaluating ${validMoves.length} valid moves...`);

    // Score each move sequentially (the shared inference session may not
    // support concurrent runs).
    const scores: MoveScore[] = [];
    for (const move of validMoves) {
      const logProb = await this.scoreMove(game, move);
      scores.push({
        move,
        score: Math.exp(logProb), // Convert log prob to probability
        logProb
      });
    }

    // Find best move (argmax)
    scores.sort((a, b) => b.logProb - a.logProb);
    const bestMove = scores[0];

    console.debug("scores:", scores);

    console.log("[TrigoAgent] Best move:", bestMove.move, "score:", bestMove.score.toFixed(6));
    console.log("[TrigoAgent] Top 5 moves:");
    for (let i = 0; i < Math.min(5, scores.length); i++) {
      console.log(`  ${i + 1}. ${this.formatMove(scores[i].move)}: ${scores[i].score.toFixed(6)}`);
    }

    return bestMove.move;
  }

  /**
   * Clean up resources by destroying the underlying inferencer.
   */
  destroy(): void {
    this.inferencer.destroy();
    console.log("[TrigoAgent] Destroyed");
  }
}