From 4cab7df53c41e078c7a99dbb9be2306e6deabd64 Mon Sep 17 00:00:00 2001 From: Uspectacle Date: Sat, 8 Nov 2025 15:00:42 +0100 Subject: [PATCH] feat: :art: Add type annotations, docstrings and formating This PR introduces extensive type annotations, docstrings, and code improvements throughout the MAIA2 codebase. All changes have been made with minimal impact on runtime behavior and backward compatibility, though some evolving-type variable renames and functional clarifications may affect code that relies on previous names. --- maia2/dataset.py | 85 +++- maia2/inference.py | 325 ++++++++++---- maia2/main.py | 886 +++++++++++++++++++++++++++---------- maia2/model.py | 108 +++-- maia2/requirements-dev.txt | 5 + maia2/requirements.txt | 18 +- maia2/train.py | 196 +++++--- maia2/utils.py | 494 +++++++++++++++------ pyproject.toml | 7 +- 9 files changed, 1563 insertions(+), 561 deletions(-) create mode 100644 maia2/requirements-dev.txt diff --git a/maia2/dataset.py b/maia2/dataset.py index 4f82abf..3569643 100644 --- a/maia2/dataset.py +++ b/maia2/dataset.py @@ -1,37 +1,82 @@ -import gdown +"""Dataset loading utilities for MAIA2. + +Provides functions to download and load chess datasets containing +positions, moves, and Elo ratings for training and testing. +""" + import os +from typing import Final + +import gdown # type: ignore import pandas as pd -def load_example_test_dataset(save_root = "./maia2_data"): - - url = "https://drive.google.com/uc?id=1fSu4Yp8uYj7xocbHAbjBP6DthsgiJy9X" - if os.path.exists(save_root) == False: +# Constants +DEFAULT_SAVE_ROOT: Final[str] = "./maia2_data" +TEST_DATASET_URL: Final[str] = ( + "https://drive.google.com/uc?id=1fSu4Yp8uYj7xocbHAbjBP6DthsgiJy9X" +) +TRAIN_DATASET_URL: Final[str] = ( + "https://drive.google.com/uc?id=1XBeuhB17z50mFK4tDvPG9rQRbxLSzNqB" +) + + +def load_example_test_dataset( + save_root: str = DEFAULT_SAVE_ROOT, +) -> pd.DataFrame: + """Download and load example test dataset. + + Args: + save_root: Directory to save dataset. + + Returns: + DataFrame with columns [board, move, active_elo, opponent_elo]. + Filtered to positions after move 10. + + Raises: + OSError: If download or directory creation fails. + pd.errors.EmptyDataError: If dataset is empty or corrupted. + """ + + if not os.path.exists(save_root): os.makedirs(save_root) + output_path = os.path.join(save_root, "example_test_dataset.csv") - + if os.path.exists(output_path): print("Example test dataset already downloaded.") else: - gdown.download(url, output_path, quiet=False) + gdown.download(TEST_DATASET_URL, output_path, quiet=False) print("Example test dataset downloaded.") - + data = pd.read_csv(output_path) - data = data[data.move_ply > 10][['board', 'move', 'active_elo', 'opponent_elo']] - - return data - -def load_example_train_dataset(save_root = "./maia2_data"): - - url = "https://drive.google.com/uc?id=1XBeuhB17z50mFK4tDvPG9rQRbxLSzNqB" - if os.path.exists(save_root) == False: + filtered_data = data[data.move_ply > 10][ + ["board", "move", "active_elo", "opponent_elo"] + ] + + return filtered_data + + +def load_example_train_dataset(save_root: str = DEFAULT_SAVE_ROOT) -> str: + """Download example training dataset. + + Args: + save_root: Directory to save dataset. + + Returns: + Path to downloaded training dataset CSV file. + + Raises: + OSError: If download or directory creation fails. + """ + if not os.path.exists(save_root): os.makedirs(save_root) + output_path = os.path.join(save_root, "example_train_dataset.csv") - + if os.path.exists(output_path): print("Example train dataset already downloaded.") else: - gdown.download(url, output_path, quiet=False) + gdown.download(TRAIN_DATASET_URL, output_path, quiet=False) print("Example train dataset downloaded.") - + return output_path - \ No newline at end of file diff --git a/maia2/inference.py b/maia2/inference.py index f4f2cfa..8a52888 100644 --- a/maia2/inference.py +++ b/maia2/inference.py @@ -1,59 +1,147 @@ -from .utils import * -from .main import * +"""Inference functions for MAIA2 model. -def preprocessing(fen, elo_self, elo_oppo, elo_dict, all_moves_dict): - - if fen.split(' ')[1] == 'w': +Provides preprocessing, dataset creation, and inference utilities +for running predictions on chess positions. +""" + +from typing import Dict, List, Tuple, cast + +import chess +import pandas as pd +import torch +import torch.utils.data +import tqdm + +from .main import MAIA2Model +from .utils import ( + BoardPosition, + ChessMove, + EloRangeDict, + EloRating, + MovesDict, + ReverseMovesDict, + board_to_tensor, + create_elo_dict, + get_all_possible_moves, + map_to_category, + mirror_move, +) + +TestDatasetItem = Tuple[BoardPosition, torch.Tensor, + EloRating, EloRating, torch.Tensor] +DictMoveProb = Dict[ChessMove, float] +PreparedDicts = Tuple[MovesDict, EloRangeDict, ReverseMovesDict] +DeprecatedPreparedDicts = List[(MovesDict | EloRangeDict | ReverseMovesDict)] + + +def preprocessing( + fen: BoardPosition, + elo_self: EloRating, + elo_oppo: EloRating, + elo_dict: EloRangeDict, + all_moves_dict: MovesDict, +) -> Tuple[torch.Tensor, EloRating, EloRating, torch.Tensor]: + """Preprocess FEN and Elo ratings into model tensors. + + Args: + fen: FEN string of chess position. + elo_self: Elo rating of active player. + elo_oppo: Elo rating of opponent. + elo_dict: Mapping of Elo ratings to categories. + all_moves_dict: Mapping of moves to indices. + + Returns: + Tuple of (board_tensor, elo_self_cat, elo_oppo_cat, legal_moves_mask). + """ + if fen.split(" ")[1] == "w": board = chess.Board(fen) - elif fen.split(' ')[1] == 'b': + elif fen.split(" ")[1] == "b": board = chess.Board(fen).mirror() else: raise ValueError(f"Invalid fen: {fen}") - + board_input = board_to_tensor(board) - + elo_self = map_to_category(elo_self, elo_dict) elo_oppo = map_to_category(elo_oppo, elo_dict) - - legal_moves = torch.zeros(len(all_moves_dict)) - legal_moves_idx = torch.tensor([all_moves_dict[move.uci()] for move in board.legal_moves]) + + legal_moves = torch.zeros(len(all_moves_dict), dtype=torch.float32) + legal_moves_idx = torch.tensor( + [all_moves_dict[move.uci()] for move in board.legal_moves] + ) legal_moves[legal_moves_idx] = 1 - + return board_input, elo_self, elo_oppo, legal_moves -class TestDataset(torch.utils.data.Dataset): - - def __init__(self, data, all_moves_dict, elo_dict): - +class TestDataset(torch.utils.data.Dataset[TestDatasetItem]): + """PyTorch Dataset for MAIA2 test data.""" + + def __init__( + self, + data: pd.DataFrame, + all_moves_dict: MovesDict, + elo_dict: EloRangeDict, + ): + """Initialize dataset. + + Args: + data: DataFrame with [fen, move, elo_self, elo_oppo]. + all_moves_dict: UCI moves to model indices. + elo_dict: Raw Elo to binned categories. + """ self.all_moves_dict = all_moves_dict self.data = data.values.tolist() self.elo_dict = elo_dict - - def __len__(self): - + + def __len__(self) -> int: + """Return number of samples.""" return len(self.data) - - def __getitem__(self, idx): - + + def __getitem__(self, idx: int) -> TestDatasetItem: + """Get preprocessed tensors for position. + + Args: + idx: Sample index. + + Returns: + Tuple of (fen, board_tensor, elo_self_cat, elo_oppo_cat, legal_moves_mask). + """ fen, _, elo_self, elo_oppo = self.data[idx] - board_input, elo_self, elo_oppo, legal_moves = preprocessing(fen, elo_self, elo_oppo, self.elo_dict, self.all_moves_dict) - + board_input, elo_self, elo_oppo, legal_moves = preprocessing( + fen, elo_self, elo_oppo, self.elo_dict, self.all_moves_dict + ) + return fen, board_input, elo_self, elo_oppo, legal_moves -def get_preds(model, dataloader, all_moves_dict_reversed): - - move_probs = [] - win_probs = [] - + +def get_preds( + model: MAIA2Model, + dataloader: torch.utils.data.DataLoader, + all_moves_dict_reversed: ReverseMovesDict, +) -> Tuple[List[DictMoveProb], List[float]]: + """Compute move and win probabilities for dataset. + + Args: + model: Trained MAIA2 model. + dataloader: DataLoader yielding test data. + all_moves_dict_reversed: Move indices to UCI strings. + + Returns: + Tuple of (move_probs_list, win_probs_list). + move_probs_list: List of dicts mapping UCI moves to probabilities. + win_probs_list: List of win probabilities. + """ + move_probs: List[DictMoveProb] = [] + win_probs: List[float] = [] + device = next(model.parameters()).device - + model.eval() + with torch.no_grad(): - for fens, boards, elos_self, elos_oppo, legal_moves in dataloader: - boards = boards.to(device) elos_self = elos_self.to(device) elos_oppo = elos_oppo.to(device) @@ -62,110 +150,169 @@ def get_preds(model, dataloader, all_moves_dict_reversed): logits_maia, _, logits_value = model(boards, elos_self, elos_oppo) logits_maia_legal = logits_maia * legal_moves probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() - logits_value = (logits_value / 2 + 0.5).clamp(0, 1).cpu().tolist() - - for i in range(len(fens)): - - fen = fens[i] + + for i, fen in enumerate(fens): black_flag = False - + # calculate win probability logit_value = logits_value[i] if fen.split(" ")[1] == "b": logit_value = 1 - logit_value black_flag = True win_probs.append(round(logit_value, 4)) - + # calculate move probabilities move_probs_each = {} - legal_move_indices = legal_moves[i].nonzero().flatten().cpu().numpy().tolist() + legal_move_indices = ( + legal_moves[i].nonzero().flatten().cpu().numpy().tolist() + ) legal_moves_mirrored = [] for move_idx in legal_move_indices: move = all_moves_dict_reversed[move_idx] if black_flag: move = mirror_move(move) legal_moves_mirrored.append(move) - - for j in range(len(legal_move_indices)): - move_probs_each[legal_moves_mirrored[j]] = round(probs[i][legal_move_indices[j]], 4) - - move_probs_each = dict(sorted(move_probs_each.items(), key=lambda item: item[1], reverse=True)) + + for j, legal_move_index in enumerate(legal_move_indices): + move_probs_each[legal_moves_mirrored[j]] = round( + probs[i][legal_move_index], 4 + ) + + move_probs_each = dict( + sorted( + move_probs_each.items(), key=lambda item: item[1], reverse=True + ) + ) move_probs.append(move_probs_each) - + return move_probs, win_probs -def inference_batch(data, model, verbose, batch_size, num_workers): +def inference_batch( + data: pd.DataFrame, + model: MAIA2Model, + verbose: bool, + batch_size: int, + num_workers: int, +) -> Tuple[pd.DataFrame, float]: + """Run inference on batch of chess positions. + + Args: + data: DataFrame with [fen, move, elo_self, elo_oppo]. + model: Trained MAIA2 model. + verbose: Show progress bar if True. + batch_size: Batch size for DataLoader. + num_workers: Number of DataLoader workers. + Returns: + Tuple of (updated_dataframe, accuracy). + updated_dataframe: Input data with added win_probs and move_probs columns. + accuracy: Move prediction accuracy. + """ all_moves = get_all_possible_moves() all_moves_dict = {move: i for i, move in enumerate(all_moves)} elo_dict = create_elo_dict() - all_moves_dict_reversed = {v: k for k, v in all_moves_dict.items()} + dataset = TestDataset(data, all_moves_dict, elo_dict) - dataloader = torch.utils.data.DataLoader(dataset, - batch_size=batch_size, - shuffle=False, - drop_last=False, - num_workers=num_workers) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers, + ) + if verbose: - dataloader = tqdm.tqdm(dataloader) - - move_probs, win_probs = get_preds(model, dataloader, all_moves_dict_reversed) - + dataloader = cast( + torch.utils.data.DataLoader[TestDatasetItem], tqdm.tqdm(dataloader) + ) + + move_probs, win_probs = get_preds( + model, dataloader, all_moves_dict_reversed) + data["win_probs"] = win_probs data["move_probs"] = move_probs - + acc = 0 for i in range(len(data)): - highest_prob_move = max(move_probs[i], key=move_probs[i].get) + highest_prob_move, _highest_prob = max( + move_probs[i].items(), key=lambda item: item[1] + ) + if highest_prob_move == data.iloc[i]["move"]: acc += 1 - acc = round(acc / len(data), 4) - - return data, acc + accuracy = round(acc / len(data), 4) + + return data, accuracy -def prepare(): +def prepare() -> DeprecatedPreparedDicts: + """Initialize dictionaries for model inference. + Returns: + List of [all_moves_dict, elo_dict, all_moves_dict_reversed]. + all_moves_dict: UCI moves to indices. + elo_dict: Raw Elo to categories. + all_moves_dict_reversed: Indices to UCI moves. + """ all_moves = get_all_possible_moves() all_moves_dict = {move: i for i, move in enumerate(all_moves)} elo_dict = create_elo_dict() - all_moves_dict_reversed = {v: k for k, v in all_moves_dict.items()} - + return [all_moves_dict, elo_dict, all_moves_dict_reversed] -def inference_each(model, prepared, fen, elo_self, elo_oppo): - - all_moves_dict, elo_dict, all_moves_dict_reversed = prepared - - board_input, elo_self, elo_oppo, legal_moves = preprocessing(fen, elo_self, elo_oppo, elo_dict, all_moves_dict) - +def inference_each( + model: MAIA2Model, + prepared: PreparedDicts | DeprecatedPreparedDicts, + fen: BoardPosition, + elo_self: EloRating, + elo_oppo: EloRating, +) -> Tuple[DictMoveProb, float]: + """Analyze single chess position with MAIA2. + + Args: + model: Trained MAIA2 model. + prepared: Tuple from prepare() with mapping dicts. + fen: FEN string of position. + elo_self: Elo of player to move. + elo_oppo: Elo of opponent. + + Returns: + Tuple of (move_probs, win_prob). + move_probs: Dict mapping UCI moves to probabilities (sorted). + win_prob: Win probability (0-1). + """ + all_moves_dict, elo_dict, all_moves_dict_reversed = cast( + PreparedDicts, prepared) + board_input, elo_self, elo_oppo, legal_moves = preprocessing( + fen, elo_self, elo_oppo, elo_dict, all_moves_dict + ) + device = next(model.parameters()).device - model.eval() - + board_input = board_input.unsqueeze(dim=0).to(device) - elo_self = torch.tensor([elo_self]).to(device) - elo_oppo = torch.tensor([elo_oppo]).to(device) + elo_self_tensor = torch.tensor([elo_self]).to(device) + elo_oppo_tensor = torch.tensor([elo_oppo]).to(device) legal_moves = legal_moves.unsqueeze(dim=0).to(device) - - logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo) + + logits_maia, _, logits_value = model( + board_input, elo_self_tensor, elo_oppo_tensor) logits_maia_legal = logits_maia * legal_moves probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() - logits_value = (logits_value / 2 + 0.5).clamp(0, 1).item() - + black_flag = False if fen.split(" ")[1] == "b": logits_value = 1 - logits_value black_flag = True win_prob = round(logits_value, 4) - - move_probs = {} + + move_probs: DictMoveProb = {} legal_move_indices = legal_moves.nonzero().flatten().cpu().numpy().tolist() legal_moves_mirrored = [] for move_idx in legal_move_indices: @@ -173,11 +320,13 @@ def inference_each(model, prepared, fen, elo_self, elo_oppo): if black_flag: move = mirror_move(move) legal_moves_mirrored.append(move) - - for j in range(len(legal_move_indices)): - move_probs[legal_moves_mirrored[j]] = round(probs[0][legal_move_indices[j]], 4) - - move_probs = dict(sorted(move_probs.items(), key=lambda item: item[1], reverse=True)) - - return move_probs, win_prob + for j, legal_move_index in enumerate(legal_move_indices): + move_probs[legal_moves_mirrored[j]] = round( + probs[0][legal_move_index], 4) + + move_probs = dict( + sorted(move_probs.items(), key=lambda item: item[1], reverse=True) + ) + + return move_probs, win_prob diff --git a/maia2/main.py b/maia2/main.py index 32b061d..5e21ec4 100644 --- a/maia2/main.py +++ b/maia2/main.py @@ -1,141 +1,239 @@ -import chess.pgn +# pylint: disable=too-many-lines +"""Main training and model implementation for MAIA2. + +Provides core functionality including data processing, model architecture +(ResNet + Transformer), training loop, and evaluation utilities. +""" + +import threading +from multiprocessing import Pool, Queue +from typing import Any, Dict, List, Optional, Tuple, cast + import chess -import pdb -from multiprocessing import Pool, cpu_count, Queue, Process +import chess.pgn +import pandas as pd import torch -import tqdm -from .utils import * -import torch.nn as nn import torch.nn.functional as F -from tqdm.contrib.concurrent import process_map -import os -import pandas as pd -import time +import tqdm from einops import rearrange +from torch import nn +from tqdm.contrib.concurrent import process_map - -def process_chunks(cfg, pgn_path, pgn_chunks, elo_dict): - +from .utils import ( + BoardPosition, + ChessMove, + Chunk, + Config, + EloRangeDict, + EloRating, + FileOffset, + MovesDict, + board_to_tensor, + extract_clock_time, + get_side_info, + map_to_category, + mirror_move, +) + +# Type aliases +ResultScore = int +EloPair = Tuple[EloRating, EloRating] +DictFrequency = Dict[EloPair, int] +TrainingPositionData = Tuple[ + BoardPosition, ChessMove, EloRating, EloRating, ResultScore +] +ProcessPosition = Tuple[List[TrainingPositionData], int, DictFrequency] +ProcessChunks = Tuple[List[ProcessPosition], int, int] +GameResult = Tuple[chess.pgn.Game, EloRating, EloRating, ResultScore] +ModelOutput = Tuple[torch.Tensor, torch.Tensor, torch.Tensor] +MAIA1DatasetItem = Tuple[torch.Tensor, int, + int, int, torch.Tensor, torch.Tensor] +MAIA2DatasetItem = Tuple[torch.Tensor, int, + int, int, torch.Tensor, torch.Tensor, int] + + +def process_chunks( + cfg: Config, + pgn_path: str, + pgn_chunks: List[Chunk], + elo_dict: EloRangeDict, +) -> ProcessChunks: + """Process PGN file chunks in parallel. + + Args: + cfg: Configuration with processing parameters. + pgn_path: Path to PGN file. + pgn_chunks: List of (start_pos, end_pos) byte positions. + elo_dict: Elo ratings to category indices. + + Returns: + Tuple of (processed_positions, valid_games_count, chunks_count). + """ # process_per_chunk((pgn_chunks[0][0], pgn_chunks[0][1], pgn_path, elo_dict, cfg)) - + if cfg.verbose: - results = process_map(process_per_chunk, [(start, end, pgn_path, elo_dict, cfg) for start, end in pgn_chunks], max_workers=len(pgn_chunks), chunksize=1) + results = process_map( + process_per_chunk, + [(start, end, pgn_path, elo_dict, cfg) + for start, end in pgn_chunks], + max_workers=len(pgn_chunks), + chunksize=1, + ) else: with Pool(processes=len(pgn_chunks)) as pool: - results = pool.map(process_per_chunk, [(start, end, pgn_path, elo_dict, cfg) for start, end in pgn_chunks]) - - ret = [] + results = pool.map( + process_per_chunk, + [(start, end, pgn_path, elo_dict, cfg) + for start, end in pgn_chunks], + ) + + ret: List[ProcessPosition] = [] count = 0 - list_of_dicts = [] + list_of_dicts: List[DictFrequency] = [] for result, game_count, frequency in results: ret.extend(result) count += game_count list_of_dicts.append(frequency) - - total_counts = {} + + total_counts: DictFrequency = {} for d in list_of_dicts: for key, value in d.items(): total_counts[key] = total_counts.get(key, 0) + value print(total_counts, flush=True) - + return ret, count, len(pgn_chunks) -def process_per_game(game, white_elo, black_elo, white_win, cfg): +def process_per_game( + game: chess.pgn.Game, + white_elo: EloRating, + black_elo: EloRating, + white_win: ResultScore, + cfg: Config, +) -> List[TrainingPositionData]: + """Extract training positions from single game. + + Args: + game: Chess game with move history. + white_elo: White's Elo category index. + black_elo: Black's Elo category index. + white_win: Result from white's perspective (+1/0/-1). + cfg: Configuration with first_n_moves, clock_threshold, max_ply. + + Returns: + List of (fen, move_uci, elo_self, elo_oppo, result) tuples. + """ + ret: List[TrainingPositionData] = [] - ret = [] - board = game.board() moves = list(game.mainline_moves()) - + for i, node in enumerate(game.mainline()): - move = moves[i] - + if i >= cfg.first_n_moves: - comment = node.comment clock_info = extract_clock_time(comment) - - if i % 2 == 0: + + if i % 2 == 0: # White to move board_input = board.fen() move_input = move.uci() elo_self = white_elo elo_oppo = black_elo active_win = white_win - - else: + else: # Black to move board_input = board.mirror().fen() move_input = mirror_move(move.uci()) elo_self = black_elo elo_oppo = white_elo - active_win = - white_win + active_win = -white_win + + if clock_info and clock_info > cfg.clock_threshold: + ret.append((board_input, move_input, + elo_self, elo_oppo, active_win)) - if clock_info > cfg.clock_threshold: - ret.append((board_input, move_input, elo_self, elo_oppo, active_win)) - board.push(move) if i == cfg.max_ply: break - + return ret -def game_filter(game): - - white_elo = game.headers.get("WhiteElo", "?") - black_elo = game.headers.get("BlackElo", "?") +def game_filter(game: chess.pgn.Game) -> Optional[GameResult]: + """Filter games based on metadata and format. + + Args: + game: Chess game with headers and moves. + + Returns: + Tuple of (game, white_elo, black_elo, white_win) if valid, else None. + Returns None if game fails any criteria. + """ + white_elo_str = game.headers.get("WhiteElo", "?") + black_elo_str = game.headers.get("BlackElo", "?") time_control = game.headers.get("TimeControl", "?") result = game.headers.get("Result", "?") event = game.headers.get("Event", "?") - - if white_elo == "?" or black_elo == "?" or time_control == "?" or result == "?" or event == "?": - return - - if 'Rated' not in event: - return - - if 'Rapid' not in event: - return - + + if ( + white_elo_str == "?" + or black_elo_str == "?" + or time_control == "?" + or result == "?" + or event == "?" + ): + return None + + if "Rated" not in event: + return None + + if "Rapid" not in event: + return None + for _, node in enumerate(game.mainline()): - if 'clk' not in node.comment: - return - - white_elo = int(white_elo) - black_elo = int(black_elo) + if "clk" not in node.comment: + return None + + white_elo = int(white_elo_str) + black_elo = int(black_elo_str) - if result == '1-0': + if result == "1-0": white_win = 1 - elif result == '0-1': + elif result == "0-1": white_win = -1 - elif result == '1/2-1/2': + elif result == "1/2-1/2": white_win = 0 else: - return - + return None + return game, white_elo, black_elo, white_win -def process_per_chunk(args): +def process_per_chunk( + args: Tuple[FileOffset, FileOffset, str, EloRangeDict, Config], +) -> ProcessPosition: + """Process chunk of games from PGN file. + + Args: + args: Tuple of (start_pos, end_pos, pgn_path, elo_dict, cfg). + Returns: + Tuple of (position_list, game_count, frequency_dict). + """ start_pos, end_pos, pgn_path, elo_dict, cfg = args - - ret = [] + + ret: List[TrainingPositionData] = [] game_count = 0 - - frequency = {} - - with open(pgn_path, 'r', encoding='utf-8') as pgn_file: - + frequency: DictFrequency = {} + + with open(pgn_path, "r", encoding="utf-8") as pgn_file: pgn_file.seek(start_pos) while pgn_file.tell() < end_pos: - game = chess.pgn.read_game(pgn_file) - + if game is None: break @@ -144,45 +242,68 @@ def process_per_chunk(args): game, white_elo, black_elo, white_win = filtered_game white_elo = map_to_category(white_elo, elo_dict) black_elo = map_to_category(black_elo, elo_dict) - + + # Ensure consistent Elo pair ordering if white_elo < black_elo: range_1, range_2 = black_elo, white_elo else: range_1, range_2 = white_elo, black_elo - + freq = frequency.get((range_1, range_2), 0) if freq >= cfg.max_games_per_elo_range: continue - - ret_per_game = process_per_game(game, white_elo, black_elo, white_win, cfg) + + ret_per_game = process_per_game( + game, white_elo, black_elo, white_win, cfg + ) ret.extend(ret_per_game) - if len(ret_per_game): - + if len(ret_per_game) > 0: if (range_1, range_2) in frequency: frequency[(range_1, range_2)] += 1 else: frequency[(range_1, range_2)] = 1 - + game_count += 1 - + return ret, game_count, frequency -class MAIA1Dataset(torch.utils.data.Dataset): - - def __init__(self, data, all_moves_dict, elo_dict, cfg): - +class MAIA1Dataset(torch.utils.data.Dataset[MAIA1DatasetItem]): + """Dataset for MAIA1 evaluation data.""" + + def __init__( + self, + data: pd.DataFrame, + all_moves_dict: MovesDict, + elo_dict: EloRangeDict, + cfg: Config, + ) -> None: + """Initialize dataset from DataFrame. + + Args: + data: DataFrame with [board, move, active_elo, opponent_elo, white_active]. + all_moves_dict: UCI moves to model indices. + elo_dict: Elo ratings to categories. + cfg: Configuration object. + """ self.all_moves_dict = all_moves_dict self.cfg = cfg self.data = data.values.tolist() self.elo_dict = elo_dict - - def __len__(self): - + + def __len__(self) -> int: + """Return number of positions.""" return len(self.data) - - def __getitem__(self, idx): - + + def __getitem__(self, idx: int) -> MAIA1DatasetItem: + """Get single training example. + + Args: + idx: Position index. + + Returns: + Tuple of (board, move_idx, elo_self, elo_oppo, legal_moves, side_info). + """ fen, move, elo_self, elo_oppo, white_active = self.data[idx] if white_active: @@ -190,60 +311,107 @@ def __getitem__(self, idx): else: board = chess.Board(fen).mirror() move = mirror_move(move) - + board_input = board_to_tensor(board) move_input = self.all_moves_dict[move] - + elo_self = map_to_category(elo_self, self.elo_dict) elo_oppo = map_to_category(elo_oppo, self.elo_dict) - - legal_moves, side_info = get_side_info(board, move, self.all_moves_dict) - + + legal_moves, side_info = get_side_info( + board, move, self.all_moves_dict) + return board_input, move_input, elo_self, elo_oppo, legal_moves, side_info -class MAIA2Dataset(torch.utils.data.Dataset): - - - def __init__(self, data, all_moves_dict, cfg): - +class MAIA2Dataset(torch.utils.data.Dataset[MAIA2DatasetItem]): + """Dataset for MAIA2 training data.""" + + def __init__( + self, + data: List[TrainingPositionData], + all_moves_dict: MovesDict, + cfg: Config, + ) -> None: + """Initialize dataset from processed games. + + Args: + data: List of (fen, move_uci, elo_self, elo_oppo, result) tuples. + all_moves_dict: UCI moves to model indices. + cfg: Configuration object. + """ self.all_moves_dict = all_moves_dict self.data = data self.cfg = cfg - - def __len__(self): - + + def __len__(self) -> int: + """Return number of positions.""" return len(self.data) - - def __getitem__(self, idx): - + + def __getitem__(self, idx: int) -> MAIA2DatasetItem: + """Get single training example. + + Args: + idx: Position index. + + Returns: + Tuple of (board, move_idx, elo_self, elo_oppo, legal_moves, side_info, result). + """ board_input, move_uci, elo_self, elo_oppo, active_win = self.data[idx] - + board = chess.Board(board_input) - board_input = board_to_tensor(board) - - legal_moves, side_info = get_side_info(board, move_uci, self.all_moves_dict) - + board_input_tensor = board_to_tensor(board) + + legal_moves, side_info = get_side_info( + board, move_uci, self.all_moves_dict) + move_input = self.all_moves_dict[move_uci] - - return board_input, move_input, elo_self, elo_oppo, legal_moves, side_info, active_win + + return ( + board_input_tensor, + move_input, + elo_self, + elo_oppo, + legal_moves, + side_info, + active_win, + ) class BasicBlock(torch.nn.Module): + """Basic residual block with dropout.""" + + def __init__(self, in_planes: int, planes: int, stride: int = 1) -> None: + """Initialize block. - def __init__(self, in_planes, planes, stride=1): + Args: + in_planes: Input channels. + planes: Output channels. + stride: Convolution stride. + """ super(BasicBlock, self).__init__() - + mid_planes = planes - - self.conv1 = torch.nn.Conv2d(in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + self.conv1 = torch.nn.Conv2d( + in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) self.bn1 = torch.nn.BatchNorm2d(mid_planes) - self.conv2 = torch.nn.Conv2d(mid_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.conv2 = torch.nn.Conv2d( + mid_planes, planes, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn2 = torch.nn.BatchNorm2d(planes) self.dropout = nn.Dropout(p=0.5) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through block. + Args: + x: Input tensor [batch, channels, height, width]. + + Returns: + Output tensor with same shape. + """ out = self.conv1(x) out = self.bn1(out) out = F.relu(out) @@ -258,36 +426,80 @@ def forward(self, x): class ChessResNet(torch.nn.Module): - - def __init__(self, block, cfg): + """ResNet-based CNN for chess board processing.""" + + def __init__(self, block: type, cfg: Config) -> None: + """Initialize CNN. + + Args: + block: Residual block class. + cfg: Configuration with network parameters. + """ super(ChessResNet, self).__init__() - - self.conv1 = torch.nn.Conv2d(cfg.input_channels, cfg.dim_cnn, kernel_size=3, stride=1, padding=1, bias=False) + + self.conv1 = torch.nn.Conv2d( + cfg.input_channels, + cfg.dim_cnn, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) self.bn1 = torch.nn.BatchNorm2d(cfg.dim_cnn) self.layers = self._make_layer(block, cfg.dim_cnn, cfg.num_blocks_cnn) - self.conv_last = torch.nn.Conv2d(cfg.dim_cnn, cfg.vit_length, kernel_size=3, stride=1, padding=1, bias=False) + self.conv_last = torch.nn.Conv2d( + cfg.dim_cnn, cfg.vit_length, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn_last = torch.nn.BatchNorm2d(cfg.vit_length) - def _make_layer(self, block, planes, num_blocks, stride=1): - + def _make_layer( + self, block: type, planes: int, num_blocks: int, stride: int = 1 + ) -> torch.nn.Sequential: + """Create layer of stacked residual blocks. + + Args: + block: Residual block class. + planes: Number of channels. + num_blocks: Number of blocks to stack. + stride: Convolution stride. + + Returns: + Sequential container of blocks. + """ layers = [] for _ in range(num_blocks): layers.append(block(planes, planes, stride)) - + return torch.nn.Sequential(*layers) - def forward(self, x): - + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input [batch, channels, 8, 8]. + + Returns: + Output [batch, vit_length, 8, 8]. + """ out = F.relu(self.bn1(self.conv1(x))) out = self.layers(out) out = self.conv_last(out) out = self.bn_last(out) - + return out class FeedForward(nn.Module): - def __init__(self, dim, hidden_dim, dropout = 0.): + """MLP with normalization and dropout.""" + + def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None: + """Initialize feed-forward network. + + Args: + dim: Input/output dimension. + hidden_dim: Hidden layer dimension. + dropout: Dropout probability. + """ super().__init__() self.net = nn.Sequential( nn.LayerNorm(dim), @@ -295,68 +507,145 @@ def __init__(self, dim, hidden_dim, dropout = 0.): nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), - nn.Dropout(dropout) + nn.Dropout(dropout), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input [batch, seq_len, dim]. + + Returns: + Output [batch, seq_len, dim]. + """ return self.net(x) class EloAwareAttention(nn.Module): - def __init__(self, dim, heads=8, dim_head=64, dropout=0., elo_dim=64): + """Multi-head attention with Elo conditioning.""" + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + elo_dim: int = 64, + ) -> None: + """Initialize attention layer. + + Args: + dim: Input dimension. + heads: Number of attention heads. + dim_head: Dimension per head. + dropout: Dropout probability. + elo_dim: Elo embedding dimension. + """ super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.norm = nn.LayerNorm(dim) - self.attend = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) - self.elo_query = nn.Linear(elo_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim), - nn.Dropout(dropout) - ) if project_out else nn.Identity() + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) - def forward(self, x, elo_emb): + def forward(self, x: torch.Tensor, elo_emb: torch.Tensor) -> torch.Tensor: + """Forward pass with Elo conditioning. + + Args: + x: Input sequence [batch, seq_len, dim]. + elo_emb: Elo embeddings [batch, elo_dim]. + + Returns: + Output [batch, seq_len, dim]. + """ x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) + q, k, v = map(lambda t: rearrange( + t, "b n (h d) -> b h n d", h=self.heads), qkv) + # Condition attention with Elo elo_effect = self.elo_query(elo_emb).view(x.size(0), self.heads, 1, -1) q = q + elo_effect + # Scaled dot-product attention dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale - attn = self.attend(dots) attn = self.dropout(attn) out = torch.matmul(attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) class Transformer(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., elo_dim=64): + """Transformer with Elo-aware attention.""" + + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + elo_dim: int = 64, + ) -> None: + """Initialize transformer. + + Args: + dim: Model dimension. + depth: Number of layers. + heads: Number of attention heads. + dim_head: Dimension per head. + mlp_dim: MLP hidden dimension. + dropout: Dropout probability. + elo_dim: Elo embedding dimension. + """ super().__init__() self.norm = nn.LayerNorm(dim) self.layers = nn.ModuleList([]) self.elo_layers = nn.ModuleList([]) for _ in range(depth): - self.elo_layers.append(nn.ModuleList([ - EloAwareAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, elo_dim = elo_dim), - FeedForward(dim, mlp_dim, dropout = dropout) - ])) - - def forward(self, x, elo_emb): + self.elo_layers.append( + nn.ModuleList( + [ + EloAwareAttention( + dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + elo_dim=elo_dim, + ), + FeedForward(dim, mlp_dim, dropout=dropout), + ] + ) + ) + + def forward(self, x: torch.Tensor, elo_emb: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input [batch, seq_len, dim]. + elo_emb: Elo embeddings [batch, elo_dim]. + + Returns: + Output [batch, seq_len, dim]. + """ for attn, ff in self.elo_layers: x = attn(x, elo_emb) + x x = ff(x) + x @@ -365,108 +654,177 @@ def forward(self, x, elo_emb): class MAIA2Model(torch.nn.Module): - - def __init__(self, output_dim, elo_dict, cfg): + """MAIA2 chess move prediction model. + + Hybrid CNN-Transformer with Elo-aware attention for move prediction. + """ + + def __init__(self, output_dim: int, elo_dict: EloRangeDict, cfg: Config) -> None: + """Initialize MAIA2 model. + + Args: + output_dim: Number of possible moves. + elo_dict: Elo ranges to indices. + cfg: Configuration with model parameters. + """ super(MAIA2Model, self).__init__() - + self.cfg = cfg self.chess_cnn = ChessResNet(BasicBlock, cfg) - + heads = 16 dim_head = 64 self.to_patch_embedding = nn.Sequential( nn.Linear(8 * 8, cfg.dim_vit), nn.LayerNorm(cfg.dim_vit), ) - self.transformer = Transformer(cfg.dim_vit, cfg.num_blocks_vit, heads, dim_head, mlp_dim=cfg.dim_vit, dropout = 0.1, elo_dim = cfg.elo_dim * 2) - self.pos_embedding = nn.Parameter(torch.randn(1, cfg.vit_length, cfg.dim_vit)) - + self.transformer = Transformer( + cfg.dim_vit, + cfg.num_blocks_vit, + heads, + dim_head, + mlp_dim=cfg.dim_vit, + dropout=0.1, + elo_dim=cfg.elo_dim * 2, + ) + self.pos_embedding = nn.Parameter( + torch.randn(1, cfg.vit_length, cfg.dim_vit)) + + # Output heads self.fc_1 = nn.Linear(cfg.dim_vit, output_dim) # self.fc_1_1 = nn.Linear(cfg.dim_vit, cfg.dim_vit) self.fc_2 = nn.Linear(cfg.dim_vit, output_dim + 6 + 6 + 1 + 64 + 64) # self.fc_2_1 = nn.Linear(cfg.dim_vit, cfg.dim_vit) self.fc_3 = nn.Linear(128, 1) self.fc_3_1 = nn.Linear(cfg.dim_vit, 128) - + self.elo_embedding = torch.nn.Embedding(len(elo_dict), cfg.elo_dim) - self.dropout = nn.Dropout(p=0.1) self.last_ln = nn.LayerNorm(cfg.dim_vit) + def forward( + self, boards: torch.Tensor, elos_self: torch.Tensor, elos_oppo: torch.Tensor + ) -> ModelOutput: + """Forward pass. + + Args: + boards: Board tensors [batch, channels, 8, 8]. + elos_self: Player Elo indices [batch]. + elos_oppo: Opponent Elo indices [batch]. - def forward(self, boards, elos_self, elos_oppo): - + Returns: + Tuple of (move_logits, side_info_logits, value_logits). + """ batch_size = boards.size(0) boards = boards.view(batch_size, self.cfg.input_channels, 8, 8) + + # Process board with CNN embs = self.chess_cnn(boards) embs = embs.view(batch_size, embs.size(1), 8 * 8) x = self.to_patch_embedding(embs) x += self.pos_embedding x = self.dropout(x) - + + # Combine Elo embeddings and process elos_emb_self = self.elo_embedding(elos_self) elos_emb_oppo = self.elo_embedding(elos_oppo) elos_emb = torch.cat((elos_emb_self, elos_emb_oppo), dim=1) x = self.transformer(x, elos_emb).mean(dim=1) - x = self.last_ln(x) + # Generate predictions logits_maia = self.fc_1(x) logits_side_info = self.fc_2(x) - logits_value = self.fc_3(self.dropout(torch.relu(self.fc_3_1(x)))).squeeze(dim=-1) - + logits_value = self.fc_3(self.dropout(torch.relu(self.fc_3_1(x)))).squeeze( + dim=-1 + ) + return logits_maia, logits_side_info, logits_value -def read_monthly_data_path(cfg): - - print('Training Data:', flush=True) - pgn_paths = [] - +def read_monthly_data_path(cfg: Config) -> List[str]: + """Get paths to monthly PGN files in date range. + + Args: + cfg: Configuration with start_year, end_year, start_month, end_month, data_root. + + Returns: + List of PGN file paths. + """ + print("Training Data:", flush=True) + pgn_paths: List[str] = [] + for year in range(cfg.start_year, cfg.end_year + 1): start_month = cfg.start_month if year == cfg.start_year else 1 end_month = cfg.end_month if year == cfg.end_year else 12 for month in range(start_month, end_month + 1): formatted_month = f"{month:02d}" - pgn_path = cfg.data_root + f"/lichess_db_standard_rated_{year}-{formatted_month}.pgn" + pgn_path = ( + cfg.data_root + + f"/lichess_db_standard_rated_{year}-{formatted_month}.pgn" + ) # skip 2019-12 if year == 2019 and month == 12: continue print(pgn_path, flush=True) pgn_paths.append(pgn_path) - + return pgn_paths -def evaluate(model, dataloader): - +def evaluate( + model: MAIA2Model, dataloader: torch.utils.data.DataLoader[MAIA1DatasetItem] +) -> Tuple[int, int]: + """Evaluate model accuracy on dataset. + + Args: + model: MAIA2 model. + dataloader: DataLoader with evaluation data. + + Returns: + Tuple of (correct_predictions, total_positions). + """ counter = 0 correct_move = 0 - + model.eval() with torch.no_grad(): - - for boards, labels, elos_self, elos_oppo, legal_moves, side_info in dataloader: - + for boards, labels, elos_self, elos_oppo, legal_moves, _side_info in dataloader: boards = boards.cuda() labels = labels.cuda() elos_self = elos_self.cuda() elos_oppo = elos_oppo.cuda() legal_moves = legal_moves.cuda() - logits_maia, logits_side_info, logits_value = model(boards, elos_self, elos_oppo) + logits_maia, _logits_side_info, _logits_value = model( + boards, elos_self, elos_oppo + ) logits_maia_legal = logits_maia * legal_moves preds = logits_maia_legal.argmax(dim=-1) correct_move += (preds == labels).sum().item() - + counter += len(labels) return correct_move, counter -def evaluate_MAIA1_data(model, all_moves_dict, elo_dict, cfg, tiny=False): - +def evaluate_MAIA1_data( # pylint: disable=invalid-name + model: MAIA2Model, + all_moves_dict: MovesDict, + elo_dict: EloRangeDict, + cfg: Config, + tiny: bool = False, +) -> None: + """Evaluate model on MAIA1 test dataset. + + Args: + model: MAIA2 model. + all_moves_dict: Moves to indices mapping. + elo_dict: Elo rating binning. + cfg: Configuration object. + tiny: Test only first Elo range if True. + """ elo_list = range(1000, 2600, 100) for i in elo_list: @@ -474,40 +832,85 @@ def evaluate_MAIA1_data(model, all_moves_dict, elo_dict, cfg, tiny=False): end = i + 100 file_path = f"../data/test/KDDTest_{start}-{end}.csv" data = pd.read_csv(file_path) - data = data[data.type == 'Rapid'][['board', 'move', 'active_elo', 'opponent_elo', 'white_active']] + data = data[data.type == "Rapid"][ + ["board", "move", "active_elo", "opponent_elo", "white_active"] + ] dataset = MAIA1Dataset(data, all_moves_dict, elo_dict, cfg) - dataloader = torch.utils.data.DataLoader(dataset, - batch_size=cfg.batch_size, - shuffle=False, - drop_last=False, - num_workers=cfg.num_workers) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=cfg.batch_size, + shuffle=False, + drop_last=False, + num_workers=cfg.num_workers, + ) if cfg.verbose: - dataloader = tqdm.tqdm(dataloader) - print(f'Testing Elo Range {start}-{end} with MAIA 1 data:', flush=True) + dataloader = cast( + torch.utils.data.DataLoader[MAIA1DatasetItem], tqdm.tqdm( + dataloader) + ) + print(f"Testing Elo Range {start}-{end} with MAIA 1 data:", flush=True) correct_move, counter = evaluate(model, dataloader) - print(f'Accuracy Move Prediction: {round(correct_move / counter, 4)}', flush=True) + print( + f"Accuracy Move Prediction: {round(correct_move / counter, 4)}", flush=True + ) if tiny: break -def train_chunks(cfg, data, model, optimizer, all_moves_dict, criterion_maia, criterion_side_info, criterion_value): - +def train_chunks( + cfg: Config, + data: List[TrainingPositionData], + model: torch.nn.DataParallel[MAIA2Model], + optimizer: torch.optim.Optimizer, + all_moves_dict: MovesDict, + criterion_maia: torch.nn.Module, + criterion_side_info: torch.nn.Module, + criterion_value: torch.nn.Module, +) -> Tuple[float, float, float, float]: + """Train model on batch of game chunks. + + Args: + cfg: Configuration with training parameters. + data: List of (fen, move, elos, result) tuples. + model: MAIA2 model. + optimizer: Optimizer. + all_moves_dict: Moves to indices. + criterion_maia: Move prediction loss. + criterion_side_info: Side info loss. + criterion_value: Value prediction loss. + + Returns: + Tuple of (total_loss, move_loss, side_info_loss, value_loss). + """ dataset_train = MAIA2Dataset(data, all_moves_dict, cfg) - dataloader_train = torch.utils.data.DataLoader(dataset_train, - batch_size=cfg.batch_size, - shuffle=True, - drop_last=False, - num_workers=cfg.num_workers) + dataloader_train = torch.utils.data.DataLoader( + dataset_train, + batch_size=cfg.batch_size, + shuffle=True, + drop_last=False, + num_workers=cfg.num_workers, + ) if cfg.verbose: - dataloader_train = tqdm.tqdm(dataloader_train) - - avg_loss = 0 - avg_loss_maia = 0 - avg_loss_side_info = 0 - avg_loss_value = 0 + dataloader_train = cast( + torch.utils.data.DataLoader[MAIA2DatasetItem], tqdm.tqdm( + dataloader_train) + ) + + avg_loss: float = 0 + avg_loss_maia: float = 0 + avg_loss_side_info: float = 0 + avg_loss_value: float = 0 step = 0 - for boards, labels, elos_self, elos_oppo, legal_moves, side_info, wdl in dataloader_train: + for ( + boards, + labels, + elos_self, + elos_oppo, + _legal_moves, + side_info, + wdl, + ) in dataloader_train: model.train() boards = boards.cuda() labels = labels.cuda() @@ -515,26 +918,30 @@ def train_chunks(cfg, data, model, optimizer, all_moves_dict, criterion_maia, cr elos_oppo = elos_oppo.cuda() side_info = side_info.cuda() wdl = wdl.float().cuda() - - logits_maia, logits_side_info, logits_value = model(boards, elos_self, elos_oppo) - - loss = 0 + + logits_maia, logits_side_info, logits_value = model( + boards, elos_self, elos_oppo + ) + loss_maia = criterion_maia(logits_maia, labels) - loss += loss_maia - + loss = 0 + loss_maia + if cfg.side_info: - - loss_side_info = criterion_side_info(logits_side_info, side_info) * cfg.side_info_coefficient + loss_side_info = ( + criterion_side_info(logits_side_info, side_info) + * cfg.side_info_coefficient + ) loss += loss_side_info - + if cfg.value: - loss_value = criterion_value(logits_value, wdl) * cfg.value_coefficient + loss_value = criterion_value( + logits_value, wdl) * cfg.value_coefficient loss += loss_value optimizer.zero_grad() loss.backward() optimizer.step() - + avg_loss += loss.item() avg_loss_maia += loss_maia.item() if cfg.side_info: @@ -542,18 +949,45 @@ def train_chunks(cfg, data, model, optimizer, all_moves_dict, criterion_maia, cr if cfg.value: avg_loss_value += loss_value.item() step += 1 - - return round(avg_loss / step, 3), round(avg_loss_maia / step, 3), round(avg_loss_side_info / step, 3), round(avg_loss_value / step, 3) - -def preprocess_thread(queue, cfg, pgn_path, pgn_chunks_sublist, elo_dict): - - data, game_count, chunk_count = process_chunks(cfg, pgn_path, pgn_chunks_sublist, elo_dict) + return ( + round(avg_loss / step, 3), + round(avg_loss_maia / step, 3), + round(avg_loss_side_info / step, 3), + round(avg_loss_value / step, 3), + ) + + +def preprocess_thread( + queue: Queue, + cfg: Config, + pgn_path: str, + pgn_chunks_sublist: List[Chunk], + elo_dict: EloRangeDict, +) -> None: + """Process PGN chunks in separate thread. + + Args: + queue: Queue for storing results. + cfg: Configuration object. + pgn_path: Path to PGN file. + pgn_chunks_sublist: List of chunk positions. + elo_dict: Elo rating binning. + """ + data, game_count, chunk_count = process_chunks( + cfg, pgn_path, pgn_chunks_sublist, elo_dict + ) queue.put([data, game_count, chunk_count]) del data -def worker_wrapper(semaphore, *args, **kwargs): +def worker_wrapper(semaphore: threading.Semaphore, *args: Any, **kwargs: Any) -> None: + """Thread worker with semaphore protection. + + Args: + semaphore: Semaphore for controlling access. + args: Positional arguments for preprocess_thread. + kwargs: Keyword arguments for preprocess_thread. + """ with semaphore: preprocess_thread(*args, **kwargs) - diff --git a/maia2/model.py b/maia2/model.py index 0a65e94..fb6d4ac 100644 --- a/maia2/model.py +++ b/maia2/model.py @@ -1,61 +1,93 @@ -import gdown +"""Model loading utilities for MAIA2. + +Provides functions to load pre-trained MAIA2 models +for blitz and rapid time controls. +""" + import os -from .main import MAIA2Model -from .utils import get_all_possible_moves, create_elo_dict, parse_args +import warnings +from typing import Final, Literal + +import gdown # type: ignore import torch from torch import nn -import warnings + +from .main import MAIA2Model +from .utils import create_elo_dict, get_all_possible_moves, parse_args + warnings.filterwarnings("ignore") -import pdb -def from_pretrained(type, device, save_root = "./maia2_models"): - - if os.path.exists(save_root) == False: +# Constants +DEFAULT_SAVE_ROOT: Final[str] = "./maia2_models" +CONFIG_URL: Final[str] = ( + "https://drive.google.com/uc?id=1GQTskYMVMubNwZH2Bi6AmevI15CS6gk0" +) +MODEL_BLITZ_URL: Final[str] = ( + "https://drive.google.com/uc?id=1X-Z4J3PX3MQFJoa8gRt3aL8CIH0PWoyt" +) +MODEL_RAPID_URL: Final[str] = ( + "https://drive.google.com/uc?id=1gbC1-c7c0EQOPPAVpGWubezeEW8grVwc" +) + + +def from_pretrained( + model_type: Literal["blitz", "rapid"], + device: Literal["gpu", "cpu"], + save_root: str = DEFAULT_SAVE_ROOT, +) -> MAIA2Model: + """Load pre-trained MAIA2 model. + + Args: + model_type: Type of model ("blitz" or "rapid"). + device: Device to load on ("gpu" or "cpu"). + save_root: Directory to save model files. + + Returns: + Loaded MAIA2 model. + + Raises: + ValueError: If model_type is invalid. + OSError: If download or directory creation fails. + RuntimeError: If model loading fails. + """ + if not os.path.exists(save_root): os.makedirs(save_root) - - if type == "blitz": - url = "https://drive.google.com/uc?id=1X-Z4J3PX3MQFJoa8gRt3aL8CIH0PWoyt" + + if model_type == "blitz": + url = MODEL_BLITZ_URL output_path = os.path.join(save_root, "blitz_model.pt") - - elif type == "rapid": - url = "https://drive.google.com/uc?id=1gbC1-c7c0EQOPPAVpGWubezeEW8grVwc" + + elif model_type == "rapid": + url = MODEL_RAPID_URL output_path = os.path.join(save_root, "rapid_model.pt") - + else: raise ValueError("Invalid model type. Choose between 'blitz' and 'rapid'.") if os.path.exists(output_path): - print(f"Model for {type} games already downloaded.") + print(f"Model for {model_type} games already downloaded.") else: - print(f"Downloading model for {type} games.") + print(f"Downloading model for {model_type} games.") gdown.download(url, output_path, quiet=False) - cfg_url = "https://drive.google.com/uc?id=1GQTskYMVMubNwZH2Bi6AmevI15CS6gk0" cfg_path = os.path.join(save_root, "config.yaml") if not os.path.exists(cfg_path): - gdown.download(cfg_url, cfg_path, quiet=False) + gdown.download(CONFIG_URL, cfg_path, quiet=False) cfg = parse_args(cfg_path) - all_moves = get_all_possible_moves() elo_dict = create_elo_dict() - model = MAIA2Model(len(all_moves), elo_dict, cfg) - model = nn.DataParallel(model) - - checkpoint = torch.load(output_path, map_location='cpu') - model.load_state_dict(checkpoint['model_state_dict']) - model = model.module - + maia2_model = MAIA2Model(len(all_moves), elo_dict, cfg) + model = nn.DataParallel(maia2_model) + + checkpoint = torch.load(output_path, map_location="cpu") + model.load_state_dict(checkpoint["model_state_dict"]) + model_module = model.module + if device == "gpu": - model = model.cuda() - - print(f"Model for {type} games loaded to {device}.") - - return model - - - - - - + model_module = model_module.cuda() + + print(f"Model for {model_type} games loaded to {device}.") + + return model_module diff --git a/maia2/requirements-dev.txt b/maia2/requirements-dev.txt new file mode 100644 index 0000000..df1f38b --- /dev/null +++ b/maia2/requirements-dev.txt @@ -0,0 +1,5 @@ +types-requests==2.32.4.20250913 +types-PyYAML==6.0.12.20250915 +types-tqdm==4.67.0.20250809 +pandas-stubs==2.3.2.250926 +types-pytz==2025.2.0.20250809 \ No newline at end of file diff --git a/maia2/requirements.txt b/maia2/requirements.txt index a94f2cc..64890c0 100644 --- a/maia2/requirements.txt +++ b/maia2/requirements.txt @@ -1,10 +1,10 @@ -chess==1.10.0 -einops==0.8.0 -gdown==5.2.0 -numpy==2.1.3 -pandas==2.2.3 -pyyaml==6.0.2 -pyzstd==0.15.9 -Requests==2.32.3 -torch==2.4.0 +chess==1.10.0 +einops==0.8.0 +gdown==5.2.0 +numpy==2.1.3 +pandas==2.2.3 +PyYAML==6.0.2 +pyzstd==0.15.9 +requests==2.32.3 +torch==2.4.0 tqdm==4.65.0 \ No newline at end of file diff --git a/maia2/train.py b/maia2/train.py index 570b4f8..f6e55bf 100644 --- a/maia2/train.py +++ b/maia2/train.py @@ -1,109 +1,203 @@ -import argparse +"""Training script for MAIA2. + +Handles complete training pipeline including data preprocessing, +model initialization, training loop, and checkpointing. +""" + import os -from multiprocessing import Process, Queue, cpu_count import time -from .utils import seed_everything, readable_time, readable_num, count_parameters -from .utils import get_all_possible_moves, create_elo_dict -from .utils import decompress_zst, read_or_create_chunks -from .main import MAIA2Model, preprocess_thread, train_chunks, read_monthly_data_path +from multiprocessing import Process, Queue, cpu_count +from typing import List + import torch import torch.nn as nn -import pdb +from .main import ( + MAIA2Model, + preprocess_thread, + read_monthly_data_path, + train_chunks, +) +from .utils import ( + Chunk, + Config, + count_parameters, + create_elo_dict, + decompress_zst, + get_all_possible_moves, + read_or_create_chunks, + readable_num, + readable_time, + seed_everything, +) -def run(cfg): - - print('Configurations:', flush=True) + +def run(cfg: Config) -> None: + """Execute complete MAIA2 training pipeline. + + Args: + cfg: Configuration object with training parameters. + Required attributes: seed, num_cpu_left, lr, batch_size, wd, + from_checkpoint, checkpoint_*, max_epochs, queue_length. + """ + # Print configuration + print("Configurations:", flush=True) for arg in vars(cfg): - print(f'\t{arg}: {getattr(cfg, arg)}', flush=True) + print(f"\t{arg}: {getattr(cfg, arg)}", flush=True) seed_everything(cfg.seed) + + # Set up multiprocessing num_processes = cpu_count() - cfg.num_cpu_left - save_root = f'../saves/{cfg.lr}_{cfg.batch_size}_{cfg.wd}/' + # Create save directory + save_root = f"../saves/{cfg.lr}_{cfg.batch_size}_{cfg.wd}/" if not os.path.exists(save_root): os.makedirs(save_root) + # Initialize dictionaries all_moves = get_all_possible_moves() all_moves_dict = {move: i for i, move in enumerate(all_moves)} elo_dict = create_elo_dict() - model = MAIA2Model(len(all_moves), elo_dict, cfg) + # Initialize model + maia2_model = MAIA2Model(len(all_moves), elo_dict, cfg) + + # Setup model for training + print(maia2_model, flush=True) + maia2_model = maia2_model.cuda() + model = nn.DataParallel(maia2_model) - print(model, flush=True) - model = model.cuda() - model = nn.DataParallel(model) + # Initialize loss functions criterion_maia = nn.CrossEntropyLoss() criterion_side_info = nn.BCEWithLogitsLoss() criterion_value = nn.MSELoss() - optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) - N_params = count_parameters(model) - print(f'Trainable Parameters: {N_params}', flush=True) + # Initialize optimizer + optimizer = torch.optim.AdamW( + model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) + n_params = count_parameters(model) + print(f"Trainable Parameters: {n_params}", flush=True) + # Initialize counters accumulated_samples = 0 accumulated_games = 0 + # Load checkpoint if requested if cfg.from_checkpoint: formatted_month = f"{cfg.checkpoint_month:02d}" - checkpoint = torch.load(save_root + f'epoch_{cfg.checkpoint_epoch}_{cfg.checkpoint_year}-{formatted_month}.pgn.pt') - model.load_state_dict(checkpoint['model_state_dict']) - optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - accumulated_samples = checkpoint['accumulated_samples'] - accumulated_games = checkpoint['accumulated_games'] + checkpoint_path = ( + save_root + + f"epoch_{cfg.checkpoint_epoch}_{cfg.checkpoint_year}-{formatted_month}.pgn.pt" + ) + checkpoint = torch.load(checkpoint_path) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + accumulated_samples = checkpoint["accumulated_samples"] + accumulated_games = checkpoint["accumulated_games"] + # Training loop for epoch in range(cfg.max_epochs): - - print(f'Epoch {epoch + 1}', flush=True) + print(f"Epoch {epoch + 1}", flush=True) pgn_paths = read_monthly_data_path(cfg) - + num_file = 0 for pgn_path in pgn_paths: - + # Decompress and prepare data start_time = time.time() - decompress_zst(pgn_path + '.zst', pgn_path) - print(f'Decompressing {pgn_path} took {readable_time(time.time() - start_time)}', flush=True) + decompress_zst(pgn_path + ".zst", pgn_path) + print( + f"Decompressing {pgn_path} took {readable_time(time.time() - start_time)}", + flush=True, + ) + # Read/create chunks pgn_chunks = read_or_create_chunks(pgn_path, cfg) - print(f'Training {pgn_path} with {len(pgn_chunks)} chunks.', flush=True) - - queue = Queue(maxsize=cfg.queue_length) - - pgn_chunks_sublists = [] + print( + f"Training {pgn_path} with {len(pgn_chunks)} chunks.", flush=True) + + # Setup multiprocessing queue + queue: Queue = Queue(maxsize=cfg.queue_length) + + # Split chunks for parallel processing + pgn_chunks_sublists: List[List[Chunk]] = [] for i in range(0, len(pgn_chunks), num_processes): - pgn_chunks_sublists.append(pgn_chunks[i:i + num_processes]) - + pgn_chunks_sublists.append(pgn_chunks[i: i + num_processes]) + + # Start first worker pgn_chunks_sublist = pgn_chunks_sublists[0] # For debugging only # process_chunks(cfg, pgn_path, pgn_chunks_sublist, elo_dict) - worker = Process(target=preprocess_thread, args=(queue, cfg, pgn_path, pgn_chunks_sublist, elo_dict)) + worker = Process( + target=preprocess_thread, + args=(queue, cfg, pgn_path, pgn_chunks_sublist, elo_dict), + ) worker.start() - + + # Process chunks and train num_chunk = 0 offset = 0 while True: if not queue.empty(): + # Start next worker if available if offset + 1 < len(pgn_chunks_sublists): pgn_chunks_sublist = pgn_chunks_sublists[offset + 1] - worker = Process(target=preprocess_thread, args=(queue, cfg, pgn_path, pgn_chunks_sublist, elo_dict)) + worker = Process( + target=preprocess_thread, + args=(queue, cfg, pgn_path, + pgn_chunks_sublist, elo_dict), + ) worker.start() offset += 1 + + # Get preprocessed data and train data, game_count, chunk_count = queue.get() - loss, loss_maia, loss_side_info, loss_value = train_chunks(cfg, data, model, optimizer, all_moves_dict, criterion_maia, criterion_side_info, criterion_value) + loss, loss_maia, loss_side_info, loss_value = train_chunks( + cfg, + data, + model, + optimizer, + all_moves_dict, + criterion_maia, + criterion_side_info, + criterion_value, + ) + + # Update counters and log num_chunk += chunk_count accumulated_samples += len(data) accumulated_games += game_count - print(f'[{num_chunk}/{len(pgn_chunks)}]', flush=True) - print(f'[# Positions]: {readable_num(accumulated_samples)}', flush=True) - print(f'[# Games]: {readable_num(accumulated_games)}', flush=True) - print(f'[# Loss]: {loss} | [# Loss MAIA]: {loss_maia} | [# Loss Side Info]: {loss_side_info} | [# Loss Value]: {loss_value}', flush=True) + print(f"[{num_chunk}/{len(pgn_chunks)}]", flush=True) + print( + f"[# Positions]: {readable_num(accumulated_samples)}", + flush=True, + ) + print( + f"[# Games]: {readable_num(accumulated_games)}", flush=True) + print( + f"[# Loss]: {loss} | [# Loss MAIA]: {loss_maia} | " + f"[# Loss Side Info]: {loss_side_info} | [# Loss Value]: {loss_value}", + flush=True, + ) if num_chunk == len(pgn_chunks): break + # Log completion and cleanup num_file += 1 - print(f'### ({num_file} / {len(pgn_paths)}) Took {readable_time(time.time() - start_time)} to train {pgn_path} with {len(pgn_chunks)} chunks.', flush=True) + print( + f"### ({num_file} / {len(pgn_paths)}) " + f"Took {readable_time(time.time() - start_time)} to train {pgn_path} " + f"with {len(pgn_chunks)} chunks.", + flush=True, + ) os.remove(pgn_path) - - torch.save({'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'accumulated_samples': accumulated_samples, - 'accumulated_games': accumulated_games}, f'{save_root}epoch_{epoch + 1}_{pgn_path[-11:]}.pt') + + # Save checkpoint + torch.save( + { + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "accumulated_samples": accumulated_samples, + "accumulated_games": accumulated_games, + }, + f"{save_root}epoch_{epoch + 1}_{pgn_path[-11:]}.pt", + ) diff --git a/maia2/utils.py b/maia2/utils.py index 7b64eb2..6b3db75 100644 --- a/maia2/utils.py +++ b/maia2/utils.py @@ -1,39 +1,126 @@ -import pdb -import chess -import pickle +"""Utility functions for MAIA2. + +Provides functions for chess position handling, Elo rating mapping, +model configuration, and data processing utilities. +""" + import os +import pickle import random -import numpy as np -import torch +import re import time -import requests -import tqdm +from typing import Dict, Final, Generator, List, Optional, Tuple, TypeVar, Union + +import chess +import numpy as np import pyzstd -import re +import torch import yaml +# Type aliases +BoardPosition = str +ChessMove = str +EloRating = int +TimeSeconds = float +FileOffset = int +EloRangeDict = Dict[str, int] +ConfigDict = Dict[str, Union[str, int, float, bool]] +MovesDict = Dict[ChessMove, int] +ReverseMovesDict = Dict[int, ChessMove] +Chunk = Tuple[FileOffset, FileOffset] +SideInfo = Tuple[torch.Tensor, torch.Tensor] + +# Constants +ELO_INTERVAL: Final[EloRating] = 100 +ELO_START: Final[EloRating] = 1100 +ELO_END: Final[EloRating] = 2000 +PIECE_TYPES: Final[List[chess.PieceType]] = [ + chess.PAWN, + chess.KNIGHT, + chess.BISHOP, + chess.ROOK, + chess.QUEEN, + chess.KING, +] + class Config: - - def __init__(self, config_dict): + """Dynamic configuration container for MAIA2.""" + + input_channels: int = 1 + elo_dim: int = 1 + dim_cnn: int = 1 + dim_vit: int = 1 + num_blocks_cnn: int = 1 + num_blocks_vit: int = 1 + vit_length: int = 1 + batch_size: Optional[int] + chunk_size: int = 1 + verbose: Optional[bool] + start_year: int = 1900 + end_year: int = 2027 + start_month: int = 1 + end_month: int = 12 + first_n_moves: int = 0 + clock_threshold: float = 0 + max_ply: Optional[int] + data_root: str = "~" + num_workers: int = 1 + side_info: Optional[int] + side_info_coefficient: Optional[float] + value: Optional[int] + value_coefficient: Optional[float] + seed: int = 123456789 + num_cpu_left: int = 1 + lr: float = 1e-3 + wd: float = 1e-2 + from_checkpoint: Optional[bool] + checkpoint_epoch: Optional[int] + checkpoint_year: Optional[str] + checkpoint_month: Optional[str] + max_epochs: int = 1 + queue_length: int = 1 + max_games_per_elo_range: int = 1 + + def __init__(self, config_dict: ConfigDict) -> None: + """Initialize from dictionary. + + Args: + config_dict: Configuration key-value pairs. + """ for key, value in config_dict.items(): setattr(self, key, value) -def parse_args(cfg_file_path): +def parse_args(cfg_file_path: str) -> Config: + """Parse YAML configuration file. - with open(cfg_file_path, 'r') as f: + Args: + cfg_file_path: Path to YAML config file. + + Returns: + Config object with settings. + + Raises: + OSError: If file cannot be read. + yaml.YAMLError: If YAML is malformed. + """ + with open(cfg_file_path, "r", encoding="utf-8") as f: cfg_dict = yaml.safe_load(f) - + cfg = Config(cfg_dict) return cfg -def seed_everything(seed: int): +def seed_everything(seed: int) -> None: + """Set random seeds for reproducibility. + Args: + seed: Seed value for all RNGs. + """ random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -42,8 +129,12 @@ def seed_everything(seed: int): torch.backends.cudnn.benchmark = False -def delete_file(filename): - +def delete_file(filename: str) -> None: + """Delete file if it exists. + + Args: + filename: Path to file. + """ if os.path.exists(filename): os.remove(filename) print(f"Data {filename} has been deleted.") @@ -51,20 +142,34 @@ def delete_file(filename): print(f"The file '{filename}' does not exist.") -def readable_num(num): - +def readable_num(num: int) -> str: + """Convert large number to readable format. + + Args: + num: Number to format. + + Returns: + Formatted string with suffix (K/M/B). + """ if num >= 1e9: # if parameters are in the billions - return f'{num / 1e9:.2f}B' + return f"{num / 1e9:.2f}B" elif num >= 1e6: # if parameters are in the millions - return f'{num / 1e6:.2f}M' + return f"{num / 1e6:.2f}M" elif num >= 1e3: # if parameters are in the thousands - return f'{num / 1e3:.2f}K' + return f"{num / 1e3:.2f}K" else: return str(num) -def readable_time(elapsed_time): +def readable_time(elapsed_time: TimeSeconds) -> str: + """Format elapsed time in readable format. + + Args: + elapsed_time: Duration in seconds. + Returns: + Formatted time string (e.g., "1h 30m 45.50s"). + """ hours, rem = divmod(elapsed_time, 3600) minutes, seconds = divmod(rem, 60) @@ -76,54 +181,89 @@ def readable_time(elapsed_time): return f"{seconds:.2f}s" -def count_parameters(model): - - total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - +def count_parameters(model: torch.nn.Module) -> str: + """Count trainable parameters in model. + + Args: + model: PyTorch model. + + Returns: + Formatted parameter count string. + """ + if not isinstance(model, torch.nn.Module): + raise TypeError(f"Expected PyTorch Module, got {type(model)}") + + total_params = sum(p.numel() + for p in model.parameters() if p.requires_grad) return readable_num(total_params) -def create_elo_dict(): - - inteval = 100 - start = 1100 - end = 2000 - - range_dict = {f"<{start}": 0} +def create_elo_dict() -> EloRangeDict: + """Create Elo rating ranges to category indices mapping. + + Returns: + Dict mapping Elo range strings to indices. + """ + range_dict: EloRangeDict = {f"<{ELO_START}": 0} range_index = 1 - for lower_bound in range(start, end - 1, inteval): - upper_bound = lower_bound + inteval + for lower_bound in range(ELO_START, ELO_END - 1, ELO_INTERVAL): + upper_bound = lower_bound + ELO_INTERVAL range_dict[f"{lower_bound}-{upper_bound - 1}"] = range_index range_index += 1 - range_dict[f">={end}"] = range_index - + range_dict[f">={ELO_END}"] = range_index + # print(range_dict, flush=True) - + return range_dict -def map_to_category(elo, elo_dict): +def map_to_category(elo: EloRating, elo_dict: EloRangeDict) -> int: + """Map Elo rating to category index. - inteval = 100 - start = 1100 - end = 2000 - - if elo < start: - return elo_dict[f"<{start}"] - elif elo >= end: - return elo_dict[f">={end}"] + Args: + elo: Player's Elo rating. + elo_dict: Elo ranges to indices mapping. + + Returns: + Category index for the rating. + + Raises: + TypeError: If elo is not integer. + ValueError: If elo cannot be categorized. + """ + if not isinstance(elo, int): + raise TypeError(f"Elo rating must be an integer, got {type(elo)}") + + if elo < ELO_START: + return elo_dict[f"<{ELO_START}"] + elif elo >= ELO_END: + return elo_dict[f">={ELO_END}"] else: - for lower_bound in range(start, end - 1, inteval): - upper_bound = lower_bound + inteval + for lower_bound in range(ELO_START, ELO_END - 1, ELO_INTERVAL): + upper_bound = lower_bound + ELO_INTERVAL if lower_bound <= elo < upper_bound: return elo_dict[f"{lower_bound}-{upper_bound - 1}"] + raise ValueError(f"Elo {elo} could not be categorized.") + -def get_side_info(board, move_uci, all_moves_dict): +def get_side_info( + board: chess.Board, move_uci: ChessMove, all_moves_dict: MovesDict +) -> SideInfo: + """Generate feature vectors for chess move. + + Args: + board: Current chess position. + move_uci: Move in UCI format. + all_moves_dict: UCI moves to indices. + + Returns: + Tuple of (legal_moves_mask, side_info_vector). + """ move = chess.Move.from_uci(move_uci) - + moving_piece = board.piece_at(move.from_square) captured_piece = board.piece_at(move.to_square) @@ -132,82 +272,115 @@ def get_side_info(board, move_uci, all_moves_dict): to_square_encoded = torch.zeros(64) to_square_encoded[move.to_square] = 1 - - if move_uci == 'e1g1': - rook_move = chess.Move.from_uci('h1f1') + + if move_uci == "e1g1": + rook_move = chess.Move.from_uci("h1f1") from_square_encoded[rook_move.from_square] = 1 to_square_encoded[rook_move.to_square] = 1 - - if move_uci == 'e1c1': - rook_move = chess.Move.from_uci('a1d1') + + if move_uci == "e1c1": + rook_move = chess.Move.from_uci("a1d1") from_square_encoded[rook_move.from_square] = 1 to_square_encoded[rook_move.to_square] = 1 board.push(move) is_check = board.is_check() board.pop() - + # Order: Pawn, Knight, Bishop, Rook, Queen, King side_info = torch.zeros(6 + 6 + 1) + assert moving_piece is not None side_info[moving_piece.piece_type - 1] = 1 - if move_uci in ['e1g1', 'e1c1']: + if move_uci in ["e1g1", "e1c1"]: side_info[3] = 1 if captured_piece: side_info[6 + captured_piece.piece_type - 1] = 1 if is_check: side_info[-1] = 1 - + legal_moves = torch.zeros(len(all_moves_dict)) - legal_moves_idx = torch.tensor([all_moves_dict[move.uci()] for move in board.legal_moves]) + legal_moves_idx = torch.tensor( + [all_moves_dict[move.uci()] for move in board.legal_moves] + ) legal_moves[legal_moves_idx] = 1 - - side_info = torch.cat([side_info, from_square_encoded, to_square_encoded, legal_moves], dim=0) - + + side_info = torch.cat( + [side_info, from_square_encoded, to_square_encoded, legal_moves], dim=0 + ) + return legal_moves, side_info -def extract_clock_time(comment): - - match = re.search(r'\[%clk (\d+):(\d+):(\d+)\]', comment) +def extract_clock_time(comment: str) -> Optional[int]: + """Extract remaining clock time from PGN comment. + + Args: + comment: PGN comment string. + + Returns: + Remaining time in seconds, or None if not found. + """ + match = re.search(r"\[%clk (\d+):(\d+):(\d+)\]", comment) if match: hours, minutes, seconds = map(int, match.groups()) return hours * 3600 + minutes * 60 + seconds return None - -def read_or_create_chunks(pgn_path, cfg): - cache_file = pgn_path.replace('.pgn', '_chunks.pkl') +def read_or_create_chunks(pgn_path: str, cfg: Config) -> List[Chunk]: + """Load or create file offset chunks for PGN. + + Args: + pgn_path: Path to PGN file. + cfg: Configuration with chunk_size. + + Returns: + List of (start_offset, end_offset) tuples. + + Raises: + OSError: If file access fails. + """ + cache_file = pgn_path.replace(".pgn", "_chunks.pkl") if os.path.exists(cache_file): print(f"Loading cached chunks from {cache_file}") - with open(cache_file, 'rb') as f: + with open(cache_file, "rb") as f: pgn_chunks = pickle.load(f) else: print(f"Cache not found. Creating chunks for {pgn_path}") start_time = time.time() pgn_chunks = get_chunks(pgn_path, cfg.chunk_size) - print(f'Chunking took {readable_time(time.time() - start_time)}', flush=True) - - with open(cache_file, 'wb') as f: + print( + f"Chunking took {readable_time(time.time() - start_time)}", flush=True) + + with open(cache_file, "wb") as f: pickle.dump(pgn_chunks, f) - + return pgn_chunks -def board_to_tensor(board): - - piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING] - num_piece_channels = 12 # 6 piece types * 2 colors - additional_channels = 6 # 1 for player's turn, 4 for castling rights, 1 for en passant - tensor = torch.zeros((num_piece_channels + additional_channels, 8, 8), dtype=torch.float32) +def board_to_tensor(board: chess.Board) -> torch.Tensor: + """Convert chess position to feature tensor. + + Args: + board: Chess position. + + Returns: + Tensor [18, 8, 8] with board representation. + """ + num_piece_channels: Final[int] = 12 # 6 piece types * 2 colors + # 1 for player's turn, 4 for castling rights, 1 for en passant + additional_channels: Final[int] = 6 + tensor = torch.zeros( + (num_piece_channels + additional_channels, 8, 8), dtype=torch.float32 + ) # Precompute indices for each piece type - piece_indices = {piece: i for i, piece in enumerate(piece_types)} + piece_indices = {piece: i for i, piece in enumerate(PIECE_TYPES)} # Fill tensor for each piece type - for piece_type in piece_types: - for color in [True, False]: # True is White, False is Black + for piece_type in PIECE_TYPES: + for color in [True, False]: # True=White, False=Black piece_map = board.pieces(piece_type, color) index = piece_indices[piece_type] + (0 if color else 6) for square in piece_map: @@ -220,10 +393,12 @@ def board_to_tensor(board): tensor[turn_channel, :, :] = 1.0 # Castling rights channels - castling_rights = [board.has_kingside_castling_rights(chess.WHITE), - board.has_queenside_castling_rights(chess.WHITE), - board.has_kingside_castling_rights(chess.BLACK), - board.has_queenside_castling_rights(chess.BLACK)] + castling_rights = [ + board.has_kingside_castling_rights(chess.WHITE), + board.has_queenside_castling_rights(chess.WHITE), + board.has_kingside_castling_rights(chess.BLACK), + board.has_queenside_castling_rights(chess.BLACK), + ] for i, has_right in enumerate(castling_rights): if has_right: tensor[num_piece_channels + 1 + i, :, :] = 1.0 @@ -236,47 +411,71 @@ def board_to_tensor(board): return tensor -def generate_pawn_promotions(): + +def generate_pawn_promotions() -> List[ChessMove]: + """Generate all possible pawn promotion moves. + + Returns: + List of UCI promotion moves. + """ # Define the promotion rows for both colors and the promotion pieces # promotion_rows = {'white': '7', 'black': '2'} - promotion_rows = {'white': '7'} - promotion_pieces = ['q', 'r', 'b', 'n'] - promotions = [] + promotion_rows = {"white": "7"} + promotion_pieces = ["q", "r", "b", "n"] + promotions: List[ChessMove] = [] # Iterate over each color for color, row in promotion_rows.items(): # Target rows for promotion (8 for white, 1 for black) - target_row = '8' if color == 'white' else '1' + target_row = "8" if color == "white" else "1" # Each file from 'a' to 'h' - for file in 'abcdefgh': + for file in "abcdefgh": # Direct move to promotion for piece in promotion_pieces: - promotions.append(f'{file}{row}{file}{target_row}{piece}') + promotions.append(f"{file}{row}{file}{target_row}{piece}") # Capturing moves to the left and right (if not on the edges of the board) - if file != 'a': - left_file = chr(ord(file) - 1) # File to the left + if file != "a": + left_file = chr(ord(file) - 1) for piece in promotion_pieces: - promotions.append(f'{file}{row}{left_file}{target_row}{piece}') + promotions.append( + f"{file}{row}{left_file}{target_row}{piece}") - if file != 'h': - right_file = chr(ord(file) + 1) # File to the right + # Capture right + if file != "h": + right_file = chr(ord(file) + 1) for piece in promotion_pieces: - promotions.append(f'{file}{row}{right_file}{target_row}{piece}') + promotions.append( + f"{file}{row}{right_file}{target_row}{piece}") return promotions -def mirror_square(square): - +def mirror_square(square: str) -> str: + """Mirror chess square vertically. + + Args: + square: Square in algebraic notation. + + Returns: + Mirrored square. + """ file = square[0] rank = str(9 - int(square[1])) - + return file + rank -def mirror_move(move_uci): +def mirror_move(move_uci: ChessMove) -> ChessMove: + """Mirror chess move vertically. + + Args: + move_uci: Move in UCI notation. + + Returns: + Mirrored move in UCI notation. + """ # Check if the move is a promotion (length of UCI string will be more than 4) is_promotion = len(move_uci) > 4 @@ -293,10 +492,22 @@ def mirror_move(move_uci): return mirrored_start + mirrored_end + promotion_piece -def get_chunks(pgn_path, chunk_size): +def get_chunks(pgn_path: str, chunk_size: int) -> List[Chunk]: + """Divide PGN file into chunks by game count. + + Args: + pgn_path: Path to PGN file. + chunk_size: Target games per chunk. + + Returns: + List of (start_offset, end_offset) tuples. - chunks = [] - with open(pgn_path, 'r', encoding='utf-8') as pgn_file: + Raises: + ValueError: If PGN format is invalid. + OSError: If file cannot be read. + """ + chunk_list: List[Chunk] = [] + with open(pgn_path, "r", encoding="utf-8") as pgn_file: while True: start_pos = pgn_file.tell() game_count = 0 @@ -314,45 +525,72 @@ def get_chunks(pgn_path, chunk_size): if line not in ["\n", ""]: raise ValueError end_pos = pgn_file.tell() - chunks.append((start_pos, end_pos)) + chunk_list.append((start_pos, end_pos)) if not line: break - return chunks + return chunk_list -def decompress_zst(file_path, decompressed_path): - """ Decompress a .zst file using pyzstd """ - with open(file_path, 'rb') as compressed_file, open(decompressed_path, 'wb') as decompressed_file: +def decompress_zst(file_path: str, decompressed_path: str) -> None: + """Decompress Zstandard (.zst) file. + + Args: + file_path: Path to .zst file. + decompressed_path: Output path for decompressed file. + + Raises: + OSError: If file access fails. + pyzstd.ZstdError: If decompression fails. + """ + with ( + open(file_path, "rb") as compressed_file, + open(decompressed_path, "wb") as decompressed_file, + ): pyzstd.decompress_stream(compressed_file, decompressed_file) -def get_all_possible_moves(): - - all_moves = [] +def get_all_possible_moves() -> List[ChessMove]: + """Generate all possible legal chess moves. + + Returns: + List of all moves in UCI notation. + """ + all_moves: List[chess.Move] = [] for rank in range(8): - for file in range(8): + for file in range(8): square = chess.square(file, rank) - + board = chess.Board(None) board.set_piece_at(square, chess.Piece(chess.QUEEN, chess.WHITE)) legal_moves = list(board.legal_moves) all_moves.extend(legal_moves) - + board = chess.Board(None) board.set_piece_at(square, chess.Piece(chess.KNIGHT, chess.WHITE)) legal_moves = list(board.legal_moves) all_moves.extend(legal_moves) - - all_moves = [all_moves[i].uci() for i in range(len(all_moves))] - + + all_moves_uci = [all_moves[i].uci() for i in range(len(all_moves))] + pawn_promotions = generate_pawn_promotions() - - return all_moves + pawn_promotions + return all_moves_uci + pawn_promotions + + +T = TypeVar("T") + + +def chunks(lst: List[T], n: int) -> Generator[List[T], None, None]: + """Split list into fixed-size chunks. + + Args: + lst: List to divide. + n: Chunk size. -def chunks(lst, n): - """Yield successive n-sized chunks from lst.""" + Yields: + Sublists of size n (or smaller for last chunk). + """ for i in range(0, len(lst), n): - yield lst[i:i + n] + yield lst[i: i + n] diff --git a/pyproject.toml b/pyproject.toml index 0ac599c..e5efdaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "gdown==5.2.0", "numpy==2.1.3", "pandas==2.2.3", - "pyyaml>=6.0.2", + "PyYAML>=6.0.2", "pyzstd==0.15.9", "requests==2.32.3", "torch==2.4.0", @@ -24,3 +24,8 @@ dependencies = [ [project.urls] Home = "https://github.com/CSSLab/maia2" + +[tool.mypy] +files = "maia2" +ignore_missing_imports = true +strict = true \ No newline at end of file