diff --git a/maia2/dataset.py b/maia2/dataset.py index 4f82abf..3569643 100644 --- a/maia2/dataset.py +++ b/maia2/dataset.py @@ -1,37 +1,82 @@ -import gdown +"""Dataset loading utilities for MAIA2. + +Provides functions to download and load chess datasets containing +positions, moves, and Elo ratings for training and testing. +""" + import os +from typing import Final + +import gdown # type: ignore import pandas as pd -def load_example_test_dataset(save_root = "./maia2_data"): - - url = "https://drive.google.com/uc?id=1fSu4Yp8uYj7xocbHAbjBP6DthsgiJy9X" - if os.path.exists(save_root) == False: +# Constants +DEFAULT_SAVE_ROOT: Final[str] = "./maia2_data" +TEST_DATASET_URL: Final[str] = ( + "https://drive.google.com/uc?id=1fSu4Yp8uYj7xocbHAbjBP6DthsgiJy9X" +) +TRAIN_DATASET_URL: Final[str] = ( + "https://drive.google.com/uc?id=1XBeuhB17z50mFK4tDvPG9rQRbxLSzNqB" +) + + +def load_example_test_dataset( + save_root: str = DEFAULT_SAVE_ROOT, +) -> pd.DataFrame: + """Download and load example test dataset. + + Args: + save_root: Directory to save dataset. + + Returns: + DataFrame with columns [board, move, active_elo, opponent_elo]. + Filtered to rows with move_ply > 10 (after the first 10 plies). + + Raises: + OSError: If download or directory creation fails. + pd.errors.EmptyDataError: If dataset is empty or corrupted. 
+ """ + + if not os.path.exists(save_root): os.makedirs(save_root) + output_path = os.path.join(save_root, "example_test_dataset.csv") - + if os.path.exists(output_path): print("Example test dataset already downloaded.") else: - gdown.download(url, output_path, quiet=False) + gdown.download(TEST_DATASET_URL, output_path, quiet=False) print("Example test dataset downloaded.") - + data = pd.read_csv(output_path) - data = data[data.move_ply > 10][['board', 'move', 'active_elo', 'opponent_elo']] - - return data - -def load_example_train_dataset(save_root = "./maia2_data"): - - url = "https://drive.google.com/uc?id=1XBeuhB17z50mFK4tDvPG9rQRbxLSzNqB" - if os.path.exists(save_root) == False: + filtered_data = data[data.move_ply > 10][ + ["board", "move", "active_elo", "opponent_elo"] + ] + + return filtered_data + + +def load_example_train_dataset(save_root: str = DEFAULT_SAVE_ROOT) -> str: + """Download example training dataset. + + Args: + save_root: Directory to save dataset. + + Returns: + Path to downloaded training dataset CSV file. + + Raises: + OSError: If download or directory creation fails. + """ + if not os.path.exists(save_root): os.makedirs(save_root) + output_path = os.path.join(save_root, "example_train_dataset.csv") - + if os.path.exists(output_path): print("Example train dataset already downloaded.") else: - gdown.download(url, output_path, quiet=False) + gdown.download(TRAIN_DATASET_URL, output_path, quiet=False) print("Example train dataset downloaded.") - + return output_path - \ No newline at end of file diff --git a/maia2/inference.py b/maia2/inference.py index f4f2cfa..8a52888 100644 --- a/maia2/inference.py +++ b/maia2/inference.py @@ -1,59 +1,147 @@ -from .utils import * -from .main import * +"""Inference functions for MAIA2 model. -def preprocessing(fen, elo_self, elo_oppo, elo_dict, all_moves_dict): - - if fen.split(' ')[1] == 'w': +Provides preprocessing, dataset creation, and inference utilities +for running predictions on chess positions. 
+""" + +from typing import Dict, List, Tuple, cast + +import chess +import pandas as pd +import torch +import torch.utils.data +import tqdm + +from .main import MAIA2Model +from .utils import ( + BoardPosition, + ChessMove, + EloRangeDict, + EloRating, + MovesDict, + ReverseMovesDict, + board_to_tensor, + create_elo_dict, + get_all_possible_moves, + map_to_category, + mirror_move, +) + +TestDatasetItem = Tuple[BoardPosition, torch.Tensor, + EloRating, EloRating, torch.Tensor] +DictMoveProb = Dict[ChessMove, float] +PreparedDicts = Tuple[MovesDict, EloRangeDict, ReverseMovesDict] +DeprecatedPreparedDicts = List[(MovesDict | EloRangeDict | ReverseMovesDict)] + + +def preprocessing( + fen: BoardPosition, + elo_self: EloRating, + elo_oppo: EloRating, + elo_dict: EloRangeDict, + all_moves_dict: MovesDict, +) -> Tuple[torch.Tensor, EloRating, EloRating, torch.Tensor]: + """Preprocess FEN and Elo ratings into model tensors. + + Args: + fen: FEN string of chess position. + elo_self: Elo rating of active player. + elo_oppo: Elo rating of opponent. + elo_dict: Mapping of Elo ratings to categories. + all_moves_dict: Mapping of moves to indices. + + Returns: + Tuple of (board_tensor, elo_self_cat, elo_oppo_cat, legal_moves_mask). 
+ """ + if fen.split(" ")[1] == "w": board = chess.Board(fen) - elif fen.split(' ')[1] == 'b': + elif fen.split(" ")[1] == "b": board = chess.Board(fen).mirror() else: raise ValueError(f"Invalid fen: {fen}") - + board_input = board_to_tensor(board) - + elo_self = map_to_category(elo_self, elo_dict) elo_oppo = map_to_category(elo_oppo, elo_dict) - - legal_moves = torch.zeros(len(all_moves_dict)) - legal_moves_idx = torch.tensor([all_moves_dict[move.uci()] for move in board.legal_moves]) + + legal_moves = torch.zeros(len(all_moves_dict), dtype=torch.float32) + legal_moves_idx = torch.tensor( + [all_moves_dict[move.uci()] for move in board.legal_moves] + ) legal_moves[legal_moves_idx] = 1 - + return board_input, elo_self, elo_oppo, legal_moves -class TestDataset(torch.utils.data.Dataset): - - def __init__(self, data, all_moves_dict, elo_dict): - +class TestDataset(torch.utils.data.Dataset[TestDatasetItem]): + """PyTorch Dataset for MAIA2 test data.""" + + def __init__( + self, + data: pd.DataFrame, + all_moves_dict: MovesDict, + elo_dict: EloRangeDict, + ): + """Initialize dataset. + + Args: + data: DataFrame with [fen, move, elo_self, elo_oppo]. + all_moves_dict: UCI moves to model indices. + elo_dict: Raw Elo to binned categories. + """ self.all_moves_dict = all_moves_dict self.data = data.values.tolist() self.elo_dict = elo_dict - - def __len__(self): - + + def __len__(self) -> int: + """Return number of samples.""" return len(self.data) - - def __getitem__(self, idx): - + + def __getitem__(self, idx: int) -> TestDatasetItem: + """Get preprocessed tensors for position. + + Args: + idx: Sample index. + + Returns: + Tuple of (fen, board_tensor, elo_self_cat, elo_oppo_cat, legal_moves_mask). 
+ """ fen, _, elo_self, elo_oppo = self.data[idx] - board_input, elo_self, elo_oppo, legal_moves = preprocessing(fen, elo_self, elo_oppo, self.elo_dict, self.all_moves_dict) - + board_input, elo_self, elo_oppo, legal_moves = preprocessing( + fen, elo_self, elo_oppo, self.elo_dict, self.all_moves_dict + ) + return fen, board_input, elo_self, elo_oppo, legal_moves -def get_preds(model, dataloader, all_moves_dict_reversed): - - move_probs = [] - win_probs = [] - + +def get_preds( + model: MAIA2Model, + dataloader: torch.utils.data.DataLoader, + all_moves_dict_reversed: ReverseMovesDict, +) -> Tuple[List[DictMoveProb], List[float]]: + """Compute move and win probabilities for dataset. + + Args: + model: Trained MAIA2 model. + dataloader: DataLoader yielding test data. + all_moves_dict_reversed: Move indices to UCI strings. + + Returns: + Tuple of (move_probs_list, win_probs_list). + move_probs_list: List of dicts mapping UCI moves to probabilities. + win_probs_list: List of win probabilities. 
+ """ + move_probs: List[DictMoveProb] = [] + win_probs: List[float] = [] + device = next(model.parameters()).device - + model.eval() + with torch.no_grad(): - for fens, boards, elos_self, elos_oppo, legal_moves in dataloader: - boards = boards.to(device) elos_self = elos_self.to(device) elos_oppo = elos_oppo.to(device) @@ -62,110 +150,169 @@ def get_preds(model, dataloader, all_moves_dict_reversed): logits_maia, _, logits_value = model(boards, elos_self, elos_oppo) logits_maia_legal = logits_maia * legal_moves probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() - logits_value = (logits_value / 2 + 0.5).clamp(0, 1).cpu().tolist() - - for i in range(len(fens)): - - fen = fens[i] + + for i, fen in enumerate(fens): black_flag = False - + # calculate win probability logit_value = logits_value[i] if fen.split(" ")[1] == "b": logit_value = 1 - logit_value black_flag = True win_probs.append(round(logit_value, 4)) - + # calculate move probabilities move_probs_each = {} - legal_move_indices = legal_moves[i].nonzero().flatten().cpu().numpy().tolist() + legal_move_indices = ( + legal_moves[i].nonzero().flatten().cpu().numpy().tolist() + ) legal_moves_mirrored = [] for move_idx in legal_move_indices: move = all_moves_dict_reversed[move_idx] if black_flag: move = mirror_move(move) legal_moves_mirrored.append(move) - - for j in range(len(legal_move_indices)): - move_probs_each[legal_moves_mirrored[j]] = round(probs[i][legal_move_indices[j]], 4) - - move_probs_each = dict(sorted(move_probs_each.items(), key=lambda item: item[1], reverse=True)) + + for j, legal_move_index in enumerate(legal_move_indices): + move_probs_each[legal_moves_mirrored[j]] = round( + probs[i][legal_move_index], 4 + ) + + move_probs_each = dict( + sorted( + move_probs_each.items(), key=lambda item: item[1], reverse=True + ) + ) move_probs.append(move_probs_each) - + return move_probs, win_probs -def inference_batch(data, model, verbose, batch_size, num_workers): +def inference_batch( + data: 
pd.DataFrame, + model: MAIA2Model, + verbose: bool, + batch_size: int, + num_workers: int, +) -> Tuple[pd.DataFrame, float]: + """Run inference on batch of chess positions. + + Args: + data: DataFrame with [fen, move, elo_self, elo_oppo]. + model: Trained MAIA2 model. + verbose: Show progress bar if True. + batch_size: Batch size for DataLoader. + num_workers: Number of DataLoader workers. + Returns: + Tuple of (updated_dataframe, accuracy). + updated_dataframe: Input data with added win_probs and move_probs columns. + accuracy: Move prediction accuracy. + """ all_moves = get_all_possible_moves() all_moves_dict = {move: i for i, move in enumerate(all_moves)} elo_dict = create_elo_dict() - all_moves_dict_reversed = {v: k for k, v in all_moves_dict.items()} + dataset = TestDataset(data, all_moves_dict, elo_dict) - dataloader = torch.utils.data.DataLoader(dataset, - batch_size=batch_size, - shuffle=False, - drop_last=False, - num_workers=num_workers) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers, + ) + if verbose: - dataloader = tqdm.tqdm(dataloader) - - move_probs, win_probs = get_preds(model, dataloader, all_moves_dict_reversed) - + dataloader = cast( + torch.utils.data.DataLoader[TestDatasetItem], tqdm.tqdm(dataloader) + ) + + move_probs, win_probs = get_preds( + model, dataloader, all_moves_dict_reversed) + data["win_probs"] = win_probs data["move_probs"] = move_probs - + acc = 0 for i in range(len(data)): - highest_prob_move = max(move_probs[i], key=move_probs[i].get) + highest_prob_move, _highest_prob = max( + move_probs[i].items(), key=lambda item: item[1] + ) + if highest_prob_move == data.iloc[i]["move"]: acc += 1 - acc = round(acc / len(data), 4) - - return data, acc + accuracy = round(acc / len(data), 4) + + return data, accuracy -def prepare(): +def prepare() -> DeprecatedPreparedDicts: + """Initialize dictionaries for model inference. 
+ Returns: + List of [all_moves_dict, elo_dict, all_moves_dict_reversed]. + all_moves_dict: UCI moves to indices. + elo_dict: Raw Elo to categories. + all_moves_dict_reversed: Indices to UCI moves. + """ all_moves = get_all_possible_moves() all_moves_dict = {move: i for i, move in enumerate(all_moves)} elo_dict = create_elo_dict() - all_moves_dict_reversed = {v: k for k, v in all_moves_dict.items()} - + return [all_moves_dict, elo_dict, all_moves_dict_reversed] -def inference_each(model, prepared, fen, elo_self, elo_oppo): - - all_moves_dict, elo_dict, all_moves_dict_reversed = prepared - - board_input, elo_self, elo_oppo, legal_moves = preprocessing(fen, elo_self, elo_oppo, elo_dict, all_moves_dict) - +def inference_each( + model: MAIA2Model, + prepared: PreparedDicts | DeprecatedPreparedDicts, + fen: BoardPosition, + elo_self: EloRating, + elo_oppo: EloRating, +) -> Tuple[DictMoveProb, float]: + """Analyze single chess position with MAIA2. + + Args: + model: Trained MAIA2 model. + prepared: List (or tuple) from prepare() with mapping dicts. + fen: FEN string of position. + elo_self: Elo of player to move. + elo_oppo: Elo of opponent. + + Returns: + Tuple of (move_probs, win_prob). + move_probs: Dict mapping UCI moves to probabilities (sorted). + win_prob: Win probability (0-1). 
+ """ + all_moves_dict, elo_dict, all_moves_dict_reversed = cast( + PreparedDicts, prepared) + board_input, elo_self, elo_oppo, legal_moves = preprocessing( + fen, elo_self, elo_oppo, elo_dict, all_moves_dict + ) + device = next(model.parameters()).device - model.eval() - + board_input = board_input.unsqueeze(dim=0).to(device) - elo_self = torch.tensor([elo_self]).to(device) - elo_oppo = torch.tensor([elo_oppo]).to(device) + elo_self_tensor = torch.tensor([elo_self]).to(device) + elo_oppo_tensor = torch.tensor([elo_oppo]).to(device) legal_moves = legal_moves.unsqueeze(dim=0).to(device) - - logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo) + + logits_maia, _, logits_value = model( + board_input, elo_self_tensor, elo_oppo_tensor) logits_maia_legal = logits_maia * legal_moves probs = logits_maia_legal.softmax(dim=-1).cpu().tolist() - logits_value = (logits_value / 2 + 0.5).clamp(0, 1).item() - + black_flag = False if fen.split(" ")[1] == "b": logits_value = 1 - logits_value black_flag = True win_prob = round(logits_value, 4) - - move_probs = {} + + move_probs: DictMoveProb = {} legal_move_indices = legal_moves.nonzero().flatten().cpu().numpy().tolist() legal_moves_mirrored = [] for move_idx in legal_move_indices: @@ -173,11 +320,13 @@ def inference_each(model, prepared, fen, elo_self, elo_oppo): if black_flag: move = mirror_move(move) legal_moves_mirrored.append(move) - - for j in range(len(legal_move_indices)): - move_probs[legal_moves_mirrored[j]] = round(probs[0][legal_move_indices[j]], 4) - - move_probs = dict(sorted(move_probs.items(), key=lambda item: item[1], reverse=True)) - - return move_probs, win_prob + for j, legal_move_index in enumerate(legal_move_indices): + move_probs[legal_moves_mirrored[j]] = round( + probs[0][legal_move_index], 4) + + move_probs = dict( + sorted(move_probs.items(), key=lambda item: item[1], reverse=True) + ) + + return move_probs, win_prob diff --git a/maia2/main.py b/maia2/main.py index 32b061d..5e21ec4 100644 
--- a/maia2/main.py +++ b/maia2/main.py @@ -1,141 +1,239 @@ -import chess.pgn +# pylint: disable=too-many-lines +"""Main training and model implementation for MAIA2. + +Provides core functionality including data processing, model architecture +(ResNet + Transformer), training loop, and evaluation utilities. +""" + +import threading +from multiprocessing import Pool, Queue +from typing import Any, Dict, List, Optional, Tuple, cast + import chess -import pdb -from multiprocessing import Pool, cpu_count, Queue, Process +import chess.pgn +import pandas as pd import torch -import tqdm -from .utils import * -import torch.nn as nn import torch.nn.functional as F -from tqdm.contrib.concurrent import process_map -import os -import pandas as pd -import time +import tqdm from einops import rearrange +from torch import nn +from tqdm.contrib.concurrent import process_map - -def process_chunks(cfg, pgn_path, pgn_chunks, elo_dict): - +from .utils import ( + BoardPosition, + ChessMove, + Chunk, + Config, + EloRangeDict, + EloRating, + FileOffset, + MovesDict, + board_to_tensor, + extract_clock_time, + get_side_info, + map_to_category, + mirror_move, +) + +# Type aliases +ResultScore = int +EloPair = Tuple[EloRating, EloRating] +DictFrequency = Dict[EloPair, int] +TrainingPositionData = Tuple[ + BoardPosition, ChessMove, EloRating, EloRating, ResultScore +] +ProcessPosition = Tuple[List[TrainingPositionData], int, DictFrequency] +ProcessChunks = Tuple[List[ProcessPosition], int, int] +GameResult = Tuple[chess.pgn.Game, EloRating, EloRating, ResultScore] +ModelOutput = Tuple[torch.Tensor, torch.Tensor, torch.Tensor] +MAIA1DatasetItem = Tuple[torch.Tensor, int, + int, int, torch.Tensor, torch.Tensor] +MAIA2DatasetItem = Tuple[torch.Tensor, int, + int, int, torch.Tensor, torch.Tensor, int] + + +def process_chunks( + cfg: Config, + pgn_path: str, + pgn_chunks: List[Chunk], + elo_dict: EloRangeDict, +) -> ProcessChunks: + """Process PGN file chunks in parallel. 
+ + Args: + cfg: Configuration with processing parameters. + pgn_path: Path to PGN file. + pgn_chunks: List of (start_pos, end_pos) byte positions. + elo_dict: Elo ratings to category indices. + + Returns: + Tuple of (processed_positions, valid_games_count, chunks_count). + """ # process_per_chunk((pgn_chunks[0][0], pgn_chunks[0][1], pgn_path, elo_dict, cfg)) - + if cfg.verbose: - results = process_map(process_per_chunk, [(start, end, pgn_path, elo_dict, cfg) for start, end in pgn_chunks], max_workers=len(pgn_chunks), chunksize=1) + results = process_map( + process_per_chunk, + [(start, end, pgn_path, elo_dict, cfg) + for start, end in pgn_chunks], + max_workers=len(pgn_chunks), + chunksize=1, + ) else: with Pool(processes=len(pgn_chunks)) as pool: - results = pool.map(process_per_chunk, [(start, end, pgn_path, elo_dict, cfg) for start, end in pgn_chunks]) - - ret = [] + results = pool.map( + process_per_chunk, + [(start, end, pgn_path, elo_dict, cfg) + for start, end in pgn_chunks], + ) + + ret: List[ProcessPosition] = [] count = 0 - list_of_dicts = [] + list_of_dicts: List[DictFrequency] = [] for result, game_count, frequency in results: ret.extend(result) count += game_count list_of_dicts.append(frequency) - - total_counts = {} + + total_counts: DictFrequency = {} for d in list_of_dicts: for key, value in d.items(): total_counts[key] = total_counts.get(key, 0) + value print(total_counts, flush=True) - + return ret, count, len(pgn_chunks) -def process_per_game(game, white_elo, black_elo, white_win, cfg): +def process_per_game( + game: chess.pgn.Game, + white_elo: EloRating, + black_elo: EloRating, + white_win: ResultScore, + cfg: Config, +) -> List[TrainingPositionData]: + """Extract training positions from single game. + + Args: + game: Chess game with move history. + white_elo: White's Elo category index. + black_elo: Black's Elo category index. + white_win: Result from white's perspective (+1/0/-1). 
+ cfg: Configuration with first_n_moves, clock_threshold, max_ply. + + Returns: + List of (fen, move_uci, elo_self, elo_oppo, result) tuples. + """ + ret: List[TrainingPositionData] = [] - ret = [] - board = game.board() moves = list(game.mainline_moves()) - + for i, node in enumerate(game.mainline()): - move = moves[i] - + if i >= cfg.first_n_moves: - comment = node.comment clock_info = extract_clock_time(comment) - - if i % 2 == 0: + + if i % 2 == 0: # White to move board_input = board.fen() move_input = move.uci() elo_self = white_elo elo_oppo = black_elo active_win = white_win - - else: + else: # Black to move board_input = board.mirror().fen() move_input = mirror_move(move.uci()) elo_self = black_elo elo_oppo = white_elo - active_win = - white_win + active_win = -white_win + + if clock_info and clock_info > cfg.clock_threshold: + ret.append((board_input, move_input, + elo_self, elo_oppo, active_win)) - if clock_info > cfg.clock_threshold: - ret.append((board_input, move_input, elo_self, elo_oppo, active_win)) - board.push(move) if i == cfg.max_ply: break - + return ret -def game_filter(game): - - white_elo = game.headers.get("WhiteElo", "?") - black_elo = game.headers.get("BlackElo", "?") +def game_filter(game: chess.pgn.Game) -> Optional[GameResult]: + """Filter games based on metadata and format. + + Args: + game: Chess game with headers and moves. + + Returns: + Tuple of (game, white_elo, black_elo, white_win) if valid, else None. + Returns None if game fails any criteria. + """ + white_elo_str = game.headers.get("WhiteElo", "?") + black_elo_str = game.headers.get("BlackElo", "?") time_control = game.headers.get("TimeControl", "?") result = game.headers.get("Result", "?") event = game.headers.get("Event", "?") - - if white_elo == "?" or black_elo == "?" or time_control == "?" or result == "?" or event == "?": - return - - if 'Rated' not in event: - return - - if 'Rapid' not in event: - return - + + if ( + white_elo_str == "?" + or black_elo_str == "?" 
+ or time_control == "?" + or result == "?" + or event == "?" + ): + return None + + if "Rated" not in event: + return None + + if "Rapid" not in event: + return None + for _, node in enumerate(game.mainline()): - if 'clk' not in node.comment: - return - - white_elo = int(white_elo) - black_elo = int(black_elo) + if "clk" not in node.comment: + return None + + white_elo = int(white_elo_str) + black_elo = int(black_elo_str) - if result == '1-0': + if result == "1-0": white_win = 1 - elif result == '0-1': + elif result == "0-1": white_win = -1 - elif result == '1/2-1/2': + elif result == "1/2-1/2": white_win = 0 else: - return - + return None + return game, white_elo, black_elo, white_win -def process_per_chunk(args): +def process_per_chunk( + args: Tuple[FileOffset, FileOffset, str, EloRangeDict, Config], +) -> ProcessPosition: + """Process chunk of games from PGN file. + + Args: + args: Tuple of (start_pos, end_pos, pgn_path, elo_dict, cfg). + Returns: + Tuple of (position_list, game_count, frequency_dict). 
+ """ start_pos, end_pos, pgn_path, elo_dict, cfg = args - - ret = [] + + ret: List[TrainingPositionData] = [] game_count = 0 - - frequency = {} - - with open(pgn_path, 'r', encoding='utf-8') as pgn_file: - + frequency: DictFrequency = {} + + with open(pgn_path, "r", encoding="utf-8") as pgn_file: pgn_file.seek(start_pos) while pgn_file.tell() < end_pos: - game = chess.pgn.read_game(pgn_file) - + if game is None: break @@ -144,45 +242,68 @@ def process_per_chunk(args): game, white_elo, black_elo, white_win = filtered_game white_elo = map_to_category(white_elo, elo_dict) black_elo = map_to_category(black_elo, elo_dict) - + + # Ensure consistent Elo pair ordering if white_elo < black_elo: range_1, range_2 = black_elo, white_elo else: range_1, range_2 = white_elo, black_elo - + freq = frequency.get((range_1, range_2), 0) if freq >= cfg.max_games_per_elo_range: continue - - ret_per_game = process_per_game(game, white_elo, black_elo, white_win, cfg) + + ret_per_game = process_per_game( + game, white_elo, black_elo, white_win, cfg + ) ret.extend(ret_per_game) - if len(ret_per_game): - + if len(ret_per_game) > 0: if (range_1, range_2) in frequency: frequency[(range_1, range_2)] += 1 else: frequency[(range_1, range_2)] = 1 - + game_count += 1 - + return ret, game_count, frequency -class MAIA1Dataset(torch.utils.data.Dataset): - - def __init__(self, data, all_moves_dict, elo_dict, cfg): - +class MAIA1Dataset(torch.utils.data.Dataset[MAIA1DatasetItem]): + """Dataset for MAIA1 evaluation data.""" + + def __init__( + self, + data: pd.DataFrame, + all_moves_dict: MovesDict, + elo_dict: EloRangeDict, + cfg: Config, + ) -> None: + """Initialize dataset from DataFrame. + + Args: + data: DataFrame with [board, move, active_elo, opponent_elo, white_active]. + all_moves_dict: UCI moves to model indices. + elo_dict: Elo ratings to categories. + cfg: Configuration object. 
+ """ self.all_moves_dict = all_moves_dict self.cfg = cfg self.data = data.values.tolist() self.elo_dict = elo_dict - - def __len__(self): - + + def __len__(self) -> int: + """Return number of positions.""" return len(self.data) - - def __getitem__(self, idx): - + + def __getitem__(self, idx: int) -> MAIA1DatasetItem: + """Get single training example. + + Args: + idx: Position index. + + Returns: + Tuple of (board, move_idx, elo_self, elo_oppo, legal_moves, side_info). + """ fen, move, elo_self, elo_oppo, white_active = self.data[idx] if white_active: @@ -190,60 +311,107 @@ def __getitem__(self, idx): else: board = chess.Board(fen).mirror() move = mirror_move(move) - + board_input = board_to_tensor(board) move_input = self.all_moves_dict[move] - + elo_self = map_to_category(elo_self, self.elo_dict) elo_oppo = map_to_category(elo_oppo, self.elo_dict) - - legal_moves, side_info = get_side_info(board, move, self.all_moves_dict) - + + legal_moves, side_info = get_side_info( + board, move, self.all_moves_dict) + return board_input, move_input, elo_self, elo_oppo, legal_moves, side_info -class MAIA2Dataset(torch.utils.data.Dataset): - - - def __init__(self, data, all_moves_dict, cfg): - +class MAIA2Dataset(torch.utils.data.Dataset[MAIA2DatasetItem]): + """Dataset for MAIA2 training data.""" + + def __init__( + self, + data: List[TrainingPositionData], + all_moves_dict: MovesDict, + cfg: Config, + ) -> None: + """Initialize dataset from processed games. + + Args: + data: List of (fen, move_uci, elo_self, elo_oppo, result) tuples. + all_moves_dict: UCI moves to model indices. + cfg: Configuration object. + """ self.all_moves_dict = all_moves_dict self.data = data self.cfg = cfg - - def __len__(self): - + + def __len__(self) -> int: + """Return number of positions.""" return len(self.data) - - def __getitem__(self, idx): - + + def __getitem__(self, idx: int) -> MAIA2DatasetItem: + """Get single training example. + + Args: + idx: Position index. 
+ + Returns: + Tuple of (board, move_idx, elo_self, elo_oppo, legal_moves, side_info, result). + """ board_input, move_uci, elo_self, elo_oppo, active_win = self.data[idx] - + board = chess.Board(board_input) - board_input = board_to_tensor(board) - - legal_moves, side_info = get_side_info(board, move_uci, self.all_moves_dict) - + board_input_tensor = board_to_tensor(board) + + legal_moves, side_info = get_side_info( + board, move_uci, self.all_moves_dict) + move_input = self.all_moves_dict[move_uci] - - return board_input, move_input, elo_self, elo_oppo, legal_moves, side_info, active_win + + return ( + board_input_tensor, + move_input, + elo_self, + elo_oppo, + legal_moves, + side_info, + active_win, + ) class BasicBlock(torch.nn.Module): + """Basic residual block with dropout.""" + + def __init__(self, in_planes: int, planes: int, stride: int = 1) -> None: + """Initialize block. - def __init__(self, in_planes, planes, stride=1): + Args: + in_planes: Input channels. + planes: Output channels. + stride: Convolution stride. + """ super(BasicBlock, self).__init__() - + mid_planes = planes - - self.conv1 = torch.nn.Conv2d(in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + self.conv1 = torch.nn.Conv2d( + in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) self.bn1 = torch.nn.BatchNorm2d(mid_planes) - self.conv2 = torch.nn.Conv2d(mid_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.conv2 = torch.nn.Conv2d( + mid_planes, planes, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn2 = torch.nn.BatchNorm2d(planes) self.dropout = nn.Dropout(p=0.5) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through block. + Args: + x: Input tensor [batch, channels, height, width]. + + Returns: + Output tensor with same shape. 
+ """ out = self.conv1(x) out = self.bn1(out) out = F.relu(out) @@ -258,36 +426,80 @@ def forward(self, x): class ChessResNet(torch.nn.Module): - - def __init__(self, block, cfg): + """ResNet-based CNN for chess board processing.""" + + def __init__(self, block: type, cfg: Config) -> None: + """Initialize CNN. + + Args: + block: Residual block class. + cfg: Configuration with network parameters. + """ super(ChessResNet, self).__init__() - - self.conv1 = torch.nn.Conv2d(cfg.input_channels, cfg.dim_cnn, kernel_size=3, stride=1, padding=1, bias=False) + + self.conv1 = torch.nn.Conv2d( + cfg.input_channels, + cfg.dim_cnn, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) self.bn1 = torch.nn.BatchNorm2d(cfg.dim_cnn) self.layers = self._make_layer(block, cfg.dim_cnn, cfg.num_blocks_cnn) - self.conv_last = torch.nn.Conv2d(cfg.dim_cnn, cfg.vit_length, kernel_size=3, stride=1, padding=1, bias=False) + self.conv_last = torch.nn.Conv2d( + cfg.dim_cnn, cfg.vit_length, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn_last = torch.nn.BatchNorm2d(cfg.vit_length) - def _make_layer(self, block, planes, num_blocks, stride=1): - + def _make_layer( + self, block: type, planes: int, num_blocks: int, stride: int = 1 + ) -> torch.nn.Sequential: + """Create layer of stacked residual blocks. + + Args: + block: Residual block class. + planes: Number of channels. + num_blocks: Number of blocks to stack. + stride: Convolution stride. + + Returns: + Sequential container of blocks. + """ layers = [] for _ in range(num_blocks): layers.append(block(planes, planes, stride)) - + return torch.nn.Sequential(*layers) - def forward(self, x): - + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input [batch, channels, 8, 8]. + + Returns: + Output [batch, vit_length, 8, 8]. 
+ """ out = F.relu(self.bn1(self.conv1(x))) out = self.layers(out) out = self.conv_last(out) out = self.bn_last(out) - + return out class FeedForward(nn.Module): - def __init__(self, dim, hidden_dim, dropout = 0.): + """MLP with normalization and dropout.""" + + def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None: + """Initialize feed-forward network. + + Args: + dim: Input/output dimension. + hidden_dim: Hidden layer dimension. + dropout: Dropout probability. + """ super().__init__() self.net = nn.Sequential( nn.LayerNorm(dim), @@ -295,68 +507,145 @@ def __init__(self, dim, hidden_dim, dropout = 0.): nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), - nn.Dropout(dropout) + nn.Dropout(dropout), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input [batch, seq_len, dim]. + + Returns: + Output [batch, seq_len, dim]. + """ return self.net(x) class EloAwareAttention(nn.Module): - def __init__(self, dim, heads=8, dim_head=64, dropout=0., elo_dim=64): + """Multi-head attention with Elo conditioning.""" + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + elo_dim: int = 64, + ) -> None: + """Initialize attention layer. + + Args: + dim: Input dimension. + heads: Number of attention heads. + dim_head: Dimension per head. + dropout: Dropout probability. + elo_dim: Elo embedding dimension. 
+ """ super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.norm = nn.LayerNorm(dim) - self.attend = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) - self.elo_query = nn.Linear(elo_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim), - nn.Dropout(dropout) - ) if project_out else nn.Identity() + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) - def forward(self, x, elo_emb): + def forward(self, x: torch.Tensor, elo_emb: torch.Tensor) -> torch.Tensor: + """Forward pass with Elo conditioning. + + Args: + x: Input sequence [batch, seq_len, dim]. + elo_emb: Elo embeddings [batch, elo_dim]. + + Returns: + Output [batch, seq_len, dim]. + """ x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) + q, k, v = map(lambda t: rearrange( + t, "b n (h d) -> b h n d", h=self.heads), qkv) + # Condition attention with Elo elo_effect = self.elo_query(elo_emb).view(x.size(0), self.heads, 1, -1) q = q + elo_effect + # Scaled dot-product attention dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale - attn = self.attend(dots) attn = self.dropout(attn) out = torch.matmul(attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) class Transformer(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., elo_dim=64): + """Transformer with Elo-aware attention.""" + + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + elo_dim: int = 64, + ) -> None: + """Initialize transformer. + + Args: + dim: Model dimension. + depth: Number of layers. 
+ heads: Number of attention heads. + dim_head: Dimension per head. + mlp_dim: MLP hidden dimension. + dropout: Dropout probability. + elo_dim: Elo embedding dimension. + """ super().__init__() self.norm = nn.LayerNorm(dim) self.layers = nn.ModuleList([]) self.elo_layers = nn.ModuleList([]) for _ in range(depth): - self.elo_layers.append(nn.ModuleList([ - EloAwareAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, elo_dim = elo_dim), - FeedForward(dim, mlp_dim, dropout = dropout) - ])) - - def forward(self, x, elo_emb): + self.elo_layers.append( + nn.ModuleList( + [ + EloAwareAttention( + dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + elo_dim=elo_dim, + ), + FeedForward(dim, mlp_dim, dropout=dropout), + ] + ) + ) + + def forward(self, x: torch.Tensor, elo_emb: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x: Input [batch, seq_len, dim]. + elo_emb: Elo embeddings [batch, elo_dim]. + + Returns: + Output [batch, seq_len, dim]. + """ for attn, ff in self.elo_layers: x = attn(x, elo_emb) + x x = ff(x) + x @@ -365,108 +654,177 @@ def forward(self, x, elo_emb): class MAIA2Model(torch.nn.Module): - - def __init__(self, output_dim, elo_dict, cfg): + """MAIA2 chess move prediction model. + + Hybrid CNN-Transformer with Elo-aware attention for move prediction. + """ + + def __init__(self, output_dim: int, elo_dict: EloRangeDict, cfg: Config) -> None: + """Initialize MAIA2 model. + + Args: + output_dim: Number of possible moves. + elo_dict: Elo ranges to indices. + cfg: Configuration with model parameters. 
+ """ super(MAIA2Model, self).__init__() - + self.cfg = cfg self.chess_cnn = ChessResNet(BasicBlock, cfg) - + heads = 16 dim_head = 64 self.to_patch_embedding = nn.Sequential( nn.Linear(8 * 8, cfg.dim_vit), nn.LayerNorm(cfg.dim_vit), ) - self.transformer = Transformer(cfg.dim_vit, cfg.num_blocks_vit, heads, dim_head, mlp_dim=cfg.dim_vit, dropout = 0.1, elo_dim = cfg.elo_dim * 2) - self.pos_embedding = nn.Parameter(torch.randn(1, cfg.vit_length, cfg.dim_vit)) - + self.transformer = Transformer( + cfg.dim_vit, + cfg.num_blocks_vit, + heads, + dim_head, + mlp_dim=cfg.dim_vit, + dropout=0.1, + elo_dim=cfg.elo_dim * 2, + ) + self.pos_embedding = nn.Parameter( + torch.randn(1, cfg.vit_length, cfg.dim_vit)) + + # Output heads self.fc_1 = nn.Linear(cfg.dim_vit, output_dim) # self.fc_1_1 = nn.Linear(cfg.dim_vit, cfg.dim_vit) self.fc_2 = nn.Linear(cfg.dim_vit, output_dim + 6 + 6 + 1 + 64 + 64) # self.fc_2_1 = nn.Linear(cfg.dim_vit, cfg.dim_vit) self.fc_3 = nn.Linear(128, 1) self.fc_3_1 = nn.Linear(cfg.dim_vit, 128) - + self.elo_embedding = torch.nn.Embedding(len(elo_dict), cfg.elo_dim) - self.dropout = nn.Dropout(p=0.1) self.last_ln = nn.LayerNorm(cfg.dim_vit) + def forward( + self, boards: torch.Tensor, elos_self: torch.Tensor, elos_oppo: torch.Tensor + ) -> ModelOutput: + """Forward pass. + + Args: + boards: Board tensors [batch, channels, 8, 8]. + elos_self: Player Elo indices [batch]. + elos_oppo: Opponent Elo indices [batch]. - def forward(self, boards, elos_self, elos_oppo): - + Returns: + Tuple of (move_logits, side_info_logits, value_logits). 
+ """ batch_size = boards.size(0) boards = boards.view(batch_size, self.cfg.input_channels, 8, 8) + + # Process board with CNN embs = self.chess_cnn(boards) embs = embs.view(batch_size, embs.size(1), 8 * 8) x = self.to_patch_embedding(embs) x += self.pos_embedding x = self.dropout(x) - + + # Combine Elo embeddings and process elos_emb_self = self.elo_embedding(elos_self) elos_emb_oppo = self.elo_embedding(elos_oppo) elos_emb = torch.cat((elos_emb_self, elos_emb_oppo), dim=1) x = self.transformer(x, elos_emb).mean(dim=1) - x = self.last_ln(x) + # Generate predictions logits_maia = self.fc_1(x) logits_side_info = self.fc_2(x) - logits_value = self.fc_3(self.dropout(torch.relu(self.fc_3_1(x)))).squeeze(dim=-1) - + logits_value = self.fc_3(self.dropout(torch.relu(self.fc_3_1(x)))).squeeze( + dim=-1 + ) + return logits_maia, logits_side_info, logits_value -def read_monthly_data_path(cfg): - - print('Training Data:', flush=True) - pgn_paths = [] - +def read_monthly_data_path(cfg: Config) -> List[str]: + """Get paths to monthly PGN files in date range. + + Args: + cfg: Configuration with start_year, end_year, start_month, end_month, data_root. + + Returns: + List of PGN file paths. 
+ """ + print("Training Data:", flush=True) + pgn_paths: List[str] = [] + for year in range(cfg.start_year, cfg.end_year + 1): start_month = cfg.start_month if year == cfg.start_year else 1 end_month = cfg.end_month if year == cfg.end_year else 12 for month in range(start_month, end_month + 1): formatted_month = f"{month:02d}" - pgn_path = cfg.data_root + f"/lichess_db_standard_rated_{year}-{formatted_month}.pgn" + pgn_path = ( + cfg.data_root + + f"/lichess_db_standard_rated_{year}-{formatted_month}.pgn" + ) # skip 2019-12 if year == 2019 and month == 12: continue print(pgn_path, flush=True) pgn_paths.append(pgn_path) - + return pgn_paths -def evaluate(model, dataloader): - +def evaluate( + model: MAIA2Model, dataloader: torch.utils.data.DataLoader[MAIA1DatasetItem] +) -> Tuple[int, int]: + """Evaluate model accuracy on dataset. + + Args: + model: MAIA2 model. + dataloader: DataLoader with evaluation data. + + Returns: + Tuple of (correct_predictions, total_positions). + """ counter = 0 correct_move = 0 - + model.eval() with torch.no_grad(): - - for boards, labels, elos_self, elos_oppo, legal_moves, side_info in dataloader: - + for boards, labels, elos_self, elos_oppo, legal_moves, _side_info in dataloader: boards = boards.cuda() labels = labels.cuda() elos_self = elos_self.cuda() elos_oppo = elos_oppo.cuda() legal_moves = legal_moves.cuda() - logits_maia, logits_side_info, logits_value = model(boards, elos_self, elos_oppo) + logits_maia, _logits_side_info, _logits_value = model( + boards, elos_self, elos_oppo + ) logits_maia_legal = logits_maia * legal_moves preds = logits_maia_legal.argmax(dim=-1) correct_move += (preds == labels).sum().item() - + counter += len(labels) return correct_move, counter -def evaluate_MAIA1_data(model, all_moves_dict, elo_dict, cfg, tiny=False): - +def evaluate_MAIA1_data( # pylint: disable=invalid-name + model: MAIA2Model, + all_moves_dict: MovesDict, + elo_dict: EloRangeDict, + cfg: Config, + tiny: bool = False, +) -> None: + 
"""Evaluate model on MAIA1 test dataset. + + Args: + model: MAIA2 model. + all_moves_dict: Moves to indices mapping. + elo_dict: Elo rating binning. + cfg: Configuration object. + tiny: Test only first Elo range if True. + """ elo_list = range(1000, 2600, 100) for i in elo_list: @@ -474,40 +832,85 @@ def evaluate_MAIA1_data(model, all_moves_dict, elo_dict, cfg, tiny=False): end = i + 100 file_path = f"../data/test/KDDTest_{start}-{end}.csv" data = pd.read_csv(file_path) - data = data[data.type == 'Rapid'][['board', 'move', 'active_elo', 'opponent_elo', 'white_active']] + data = data[data.type == "Rapid"][ + ["board", "move", "active_elo", "opponent_elo", "white_active"] + ] dataset = MAIA1Dataset(data, all_moves_dict, elo_dict, cfg) - dataloader = torch.utils.data.DataLoader(dataset, - batch_size=cfg.batch_size, - shuffle=False, - drop_last=False, - num_workers=cfg.num_workers) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=cfg.batch_size, + shuffle=False, + drop_last=False, + num_workers=cfg.num_workers, + ) if cfg.verbose: - dataloader = tqdm.tqdm(dataloader) - print(f'Testing Elo Range {start}-{end} with MAIA 1 data:', flush=True) + dataloader = cast( + torch.utils.data.DataLoader[MAIA1DatasetItem], tqdm.tqdm( + dataloader) + ) + print(f"Testing Elo Range {start}-{end} with MAIA 1 data:", flush=True) correct_move, counter = evaluate(model, dataloader) - print(f'Accuracy Move Prediction: {round(correct_move / counter, 4)}', flush=True) + print( + f"Accuracy Move Prediction: {round(correct_move / counter, 4)}", flush=True + ) if tiny: break -def train_chunks(cfg, data, model, optimizer, all_moves_dict, criterion_maia, criterion_side_info, criterion_value): - +def train_chunks( + cfg: Config, + data: List[TrainingPositionData], + model: torch.nn.DataParallel[MAIA2Model], + optimizer: torch.optim.Optimizer, + all_moves_dict: MovesDict, + criterion_maia: torch.nn.Module, + criterion_side_info: torch.nn.Module, + criterion_value: torch.nn.Module, 
+) -> Tuple[float, float, float, float]: + """Train model on batch of game chunks. + + Args: + cfg: Configuration with training parameters. + data: List of (fen, move, elos, result) tuples. + model: MAIA2 model. + optimizer: Optimizer. + all_moves_dict: Moves to indices. + criterion_maia: Move prediction loss. + criterion_side_info: Side info loss. + criterion_value: Value prediction loss. + + Returns: + Tuple of (total_loss, move_loss, side_info_loss, value_loss). + """ dataset_train = MAIA2Dataset(data, all_moves_dict, cfg) - dataloader_train = torch.utils.data.DataLoader(dataset_train, - batch_size=cfg.batch_size, - shuffle=True, - drop_last=False, - num_workers=cfg.num_workers) + dataloader_train = torch.utils.data.DataLoader( + dataset_train, + batch_size=cfg.batch_size, + shuffle=True, + drop_last=False, + num_workers=cfg.num_workers, + ) if cfg.verbose: - dataloader_train = tqdm.tqdm(dataloader_train) - - avg_loss = 0 - avg_loss_maia = 0 - avg_loss_side_info = 0 - avg_loss_value = 0 + dataloader_train = cast( + torch.utils.data.DataLoader[MAIA2DatasetItem], tqdm.tqdm( + dataloader_train) + ) + + avg_loss: float = 0 + avg_loss_maia: float = 0 + avg_loss_side_info: float = 0 + avg_loss_value: float = 0 step = 0 - for boards, labels, elos_self, elos_oppo, legal_moves, side_info, wdl in dataloader_train: + for ( + boards, + labels, + elos_self, + elos_oppo, + _legal_moves, + side_info, + wdl, + ) in dataloader_train: model.train() boards = boards.cuda() labels = labels.cuda() @@ -515,26 +918,30 @@ def train_chunks(cfg, data, model, optimizer, all_moves_dict, criterion_maia, cr elos_oppo = elos_oppo.cuda() side_info = side_info.cuda() wdl = wdl.float().cuda() - - logits_maia, logits_side_info, logits_value = model(boards, elos_self, elos_oppo) - - loss = 0 + + logits_maia, logits_side_info, logits_value = model( + boards, elos_self, elos_oppo + ) + loss_maia = criterion_maia(logits_maia, labels) - loss += loss_maia - + loss = 0 + loss_maia + if cfg.side_info: - - 
loss_side_info = criterion_side_info(logits_side_info, side_info) * cfg.side_info_coefficient + loss_side_info = ( + criterion_side_info(logits_side_info, side_info) + * cfg.side_info_coefficient + ) loss += loss_side_info - + if cfg.value: - loss_value = criterion_value(logits_value, wdl) * cfg.value_coefficient + loss_value = criterion_value( + logits_value, wdl) * cfg.value_coefficient loss += loss_value optimizer.zero_grad() loss.backward() optimizer.step() - + avg_loss += loss.item() avg_loss_maia += loss_maia.item() if cfg.side_info: @@ -542,18 +949,45 @@ def train_chunks(cfg, data, model, optimizer, all_moves_dict, criterion_maia, cr if cfg.value: avg_loss_value += loss_value.item() step += 1 - - return round(avg_loss / step, 3), round(avg_loss_maia / step, 3), round(avg_loss_side_info / step, 3), round(avg_loss_value / step, 3) - -def preprocess_thread(queue, cfg, pgn_path, pgn_chunks_sublist, elo_dict): - - data, game_count, chunk_count = process_chunks(cfg, pgn_path, pgn_chunks_sublist, elo_dict) + return ( + round(avg_loss / step, 3), + round(avg_loss_maia / step, 3), + round(avg_loss_side_info / step, 3), + round(avg_loss_value / step, 3), + ) + + +def preprocess_thread( + queue: Queue, + cfg: Config, + pgn_path: str, + pgn_chunks_sublist: List[Chunk], + elo_dict: EloRangeDict, +) -> None: + """Process PGN chunks in separate thread. + + Args: + queue: Queue for storing results. + cfg: Configuration object. + pgn_path: Path to PGN file. + pgn_chunks_sublist: List of chunk positions. + elo_dict: Elo rating binning. + """ + data, game_count, chunk_count = process_chunks( + cfg, pgn_path, pgn_chunks_sublist, elo_dict + ) queue.put([data, game_count, chunk_count]) del data -def worker_wrapper(semaphore, *args, **kwargs): +def worker_wrapper(semaphore: threading.Semaphore, *args: Any, **kwargs: Any) -> None: + """Thread worker with semaphore protection. + + Args: + semaphore: Semaphore for controlling access. 
+ args: Positional arguments for preprocess_thread. + kwargs: Keyword arguments for preprocess_thread. + """ with semaphore: preprocess_thread(*args, **kwargs) - diff --git a/maia2/model.py b/maia2/model.py index 0a65e94..fb6d4ac 100644 --- a/maia2/model.py +++ b/maia2/model.py @@ -1,61 +1,93 @@ -import gdown +"""Model loading utilities for MAIA2. + +Provides functions to load pre-trained MAIA2 models +for blitz and rapid time controls. +""" + import os -from .main import MAIA2Model -from .utils import get_all_possible_moves, create_elo_dict, parse_args +import warnings +from typing import Final, Literal + +import gdown # type: ignore import torch from torch import nn -import warnings + +from .main import MAIA2Model +from .utils import create_elo_dict, get_all_possible_moves, parse_args + warnings.filterwarnings("ignore") -import pdb -def from_pretrained(type, device, save_root = "./maia2_models"): - - if os.path.exists(save_root) == False: +# Constants +DEFAULT_SAVE_ROOT: Final[str] = "./maia2_models" +CONFIG_URL: Final[str] = ( + "https://drive.google.com/uc?id=1GQTskYMVMubNwZH2Bi6AmevI15CS6gk0" +) +MODEL_BLITZ_URL: Final[str] = ( + "https://drive.google.com/uc?id=1X-Z4J3PX3MQFJoa8gRt3aL8CIH0PWoyt" +) +MODEL_RAPID_URL: Final[str] = ( + "https://drive.google.com/uc?id=1gbC1-c7c0EQOPPAVpGWubezeEW8grVwc" +) + + +def from_pretrained( + model_type: Literal["blitz", "rapid"], + device: Literal["gpu", "cpu"], + save_root: str = DEFAULT_SAVE_ROOT, +) -> MAIA2Model: + """Load pre-trained MAIA2 model. + + Args: + model_type: Type of model ("blitz" or "rapid"). + device: Device to load on ("gpu" or "cpu"). + save_root: Directory to save model files. + + Returns: + Loaded MAIA2 model. + + Raises: + ValueError: If model_type is invalid. + OSError: If download or directory creation fails. + RuntimeError: If model loading fails. 
+ """ + if not os.path.exists(save_root): os.makedirs(save_root) - - if type == "blitz": - url = "https://drive.google.com/uc?id=1X-Z4J3PX3MQFJoa8gRt3aL8CIH0PWoyt" + + if model_type == "blitz": + url = MODEL_BLITZ_URL output_path = os.path.join(save_root, "blitz_model.pt") - - elif type == "rapid": - url = "https://drive.google.com/uc?id=1gbC1-c7c0EQOPPAVpGWubezeEW8grVwc" + + elif model_type == "rapid": + url = MODEL_RAPID_URL output_path = os.path.join(save_root, "rapid_model.pt") - + else: raise ValueError("Invalid model type. Choose between 'blitz' and 'rapid'.") if os.path.exists(output_path): - print(f"Model for {type} games already downloaded.") + print(f"Model for {model_type} games already downloaded.") else: - print(f"Downloading model for {type} games.") + print(f"Downloading model for {model_type} games.") gdown.download(url, output_path, quiet=False) - cfg_url = "https://drive.google.com/uc?id=1GQTskYMVMubNwZH2Bi6AmevI15CS6gk0" cfg_path = os.path.join(save_root, "config.yaml") if not os.path.exists(cfg_path): - gdown.download(cfg_url, cfg_path, quiet=False) + gdown.download(CONFIG_URL, cfg_path, quiet=False) cfg = parse_args(cfg_path) - all_moves = get_all_possible_moves() elo_dict = create_elo_dict() - model = MAIA2Model(len(all_moves), elo_dict, cfg) - model = nn.DataParallel(model) - - checkpoint = torch.load(output_path, map_location='cpu') - model.load_state_dict(checkpoint['model_state_dict']) - model = model.module - + maia2_model = MAIA2Model(len(all_moves), elo_dict, cfg) + model = nn.DataParallel(maia2_model) + + checkpoint = torch.load(output_path, map_location="cpu") + model.load_state_dict(checkpoint["model_state_dict"]) + model_module = model.module + if device == "gpu": - model = model.cuda() - - print(f"Model for {type} games loaded to {device}.") - - return model - - - - - - + model_module = model_module.cuda() + + print(f"Model for {model_type} games loaded to {device}.") + + return model_module diff --git a/maia2/requirements-dev.txt 
b/maia2/requirements-dev.txt new file mode 100644 index 0000000..df1f38b --- /dev/null +++ b/maia2/requirements-dev.txt @@ -0,0 +1,5 @@ +types-requests==2.32.4.20250913 +types-PyYAML==6.0.12.20250915 +types-tqdm==4.67.0.20250809 +pandas-stubs==2.3.2.250926 +types-pytz==2025.2.0.20250809 \ No newline at end of file diff --git a/maia2/requirements.txt b/maia2/requirements.txt index a94f2cc..64890c0 100644 --- a/maia2/requirements.txt +++ b/maia2/requirements.txt @@ -1,10 +1,10 @@ -chess==1.10.0 -einops==0.8.0 -gdown==5.2.0 -numpy==2.1.3 -pandas==2.2.3 -pyyaml==6.0.2 -pyzstd==0.15.9 -Requests==2.32.3 -torch==2.4.0 +chess==1.10.0 +einops==0.8.0 +gdown==5.2.0 +numpy==2.1.3 +pandas==2.2.3 +PyYAML==6.0.2 +pyzstd==0.15.9 +requests==2.32.3 +torch==2.4.0 tqdm==4.65.0 \ No newline at end of file diff --git a/maia2/train.py b/maia2/train.py index 570b4f8..f6e55bf 100644 --- a/maia2/train.py +++ b/maia2/train.py @@ -1,109 +1,203 @@ -import argparse +"""Training script for MAIA2. + +Handles complete training pipeline including data preprocessing, +model initialization, training loop, and checkpointing. 
+""" + import os -from multiprocessing import Process, Queue, cpu_count import time -from .utils import seed_everything, readable_time, readable_num, count_parameters -from .utils import get_all_possible_moves, create_elo_dict -from .utils import decompress_zst, read_or_create_chunks -from .main import MAIA2Model, preprocess_thread, train_chunks, read_monthly_data_path +from multiprocessing import Process, Queue, cpu_count +from typing import List + import torch import torch.nn as nn -import pdb +from .main import ( + MAIA2Model, + preprocess_thread, + read_monthly_data_path, + train_chunks, +) +from .utils import ( + Chunk, + Config, + count_parameters, + create_elo_dict, + decompress_zst, + get_all_possible_moves, + read_or_create_chunks, + readable_num, + readable_time, + seed_everything, +) -def run(cfg): - - print('Configurations:', flush=True) + +def run(cfg: Config) -> None: + """Execute complete MAIA2 training pipeline. + + Args: + cfg: Configuration object with training parameters. + Required attributes: seed, num_cpu_left, lr, batch_size, wd, + from_checkpoint, checkpoint_*, max_epochs, queue_length. 
+ """ + # Print configuration + print("Configurations:", flush=True) for arg in vars(cfg): - print(f'\t{arg}: {getattr(cfg, arg)}', flush=True) + print(f"\t{arg}: {getattr(cfg, arg)}", flush=True) seed_everything(cfg.seed) + + # Set up multiprocessing num_processes = cpu_count() - cfg.num_cpu_left - save_root = f'../saves/{cfg.lr}_{cfg.batch_size}_{cfg.wd}/' + # Create save directory + save_root = f"../saves/{cfg.lr}_{cfg.batch_size}_{cfg.wd}/" if not os.path.exists(save_root): os.makedirs(save_root) + # Initialize dictionaries all_moves = get_all_possible_moves() all_moves_dict = {move: i for i, move in enumerate(all_moves)} elo_dict = create_elo_dict() - model = MAIA2Model(len(all_moves), elo_dict, cfg) + # Initialize model + maia2_model = MAIA2Model(len(all_moves), elo_dict, cfg) + + # Setup model for training + print(maia2_model, flush=True) + maia2_model = maia2_model.cuda() + model = nn.DataParallel(maia2_model) - print(model, flush=True) - model = model.cuda() - model = nn.DataParallel(model) + # Initialize loss functions criterion_maia = nn.CrossEntropyLoss() criterion_side_info = nn.BCEWithLogitsLoss() criterion_value = nn.MSELoss() - optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) - N_params = count_parameters(model) - print(f'Trainable Parameters: {N_params}', flush=True) + # Initialize optimizer + optimizer = torch.optim.AdamW( + model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) + n_params = count_parameters(model) + print(f"Trainable Parameters: {n_params}", flush=True) + # Initialize counters accumulated_samples = 0 accumulated_games = 0 + # Load checkpoint if requested if cfg.from_checkpoint: formatted_month = f"{cfg.checkpoint_month:02d}" - checkpoint = torch.load(save_root + f'epoch_{cfg.checkpoint_epoch}_{cfg.checkpoint_year}-{formatted_month}.pgn.pt') - model.load_state_dict(checkpoint['model_state_dict']) - optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - accumulated_samples = 
checkpoint['accumulated_samples'] - accumulated_games = checkpoint['accumulated_games'] + checkpoint_path = ( + save_root + + f"epoch_{cfg.checkpoint_epoch}_{cfg.checkpoint_year}-{formatted_month}.pgn.pt" + ) + checkpoint = torch.load(checkpoint_path) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + accumulated_samples = checkpoint["accumulated_samples"] + accumulated_games = checkpoint["accumulated_games"] + # Training loop for epoch in range(cfg.max_epochs): - - print(f'Epoch {epoch + 1}', flush=True) + print(f"Epoch {epoch + 1}", flush=True) pgn_paths = read_monthly_data_path(cfg) - + num_file = 0 for pgn_path in pgn_paths: - + # Decompress and prepare data start_time = time.time() - decompress_zst(pgn_path + '.zst', pgn_path) - print(f'Decompressing {pgn_path} took {readable_time(time.time() - start_time)}', flush=True) + decompress_zst(pgn_path + ".zst", pgn_path) + print( + f"Decompressing {pgn_path} took {readable_time(time.time() - start_time)}", + flush=True, + ) + # Read/create chunks pgn_chunks = read_or_create_chunks(pgn_path, cfg) - print(f'Training {pgn_path} with {len(pgn_chunks)} chunks.', flush=True) - - queue = Queue(maxsize=cfg.queue_length) - - pgn_chunks_sublists = [] + print( + f"Training {pgn_path} with {len(pgn_chunks)} chunks.", flush=True) + + # Setup multiprocessing queue + queue: Queue = Queue(maxsize=cfg.queue_length) + + # Split chunks for parallel processing + pgn_chunks_sublists: List[List[Chunk]] = [] for i in range(0, len(pgn_chunks), num_processes): - pgn_chunks_sublists.append(pgn_chunks[i:i + num_processes]) - + pgn_chunks_sublists.append(pgn_chunks[i: i + num_processes]) + + # Start first worker pgn_chunks_sublist = pgn_chunks_sublists[0] # For debugging only # process_chunks(cfg, pgn_path, pgn_chunks_sublist, elo_dict) - worker = Process(target=preprocess_thread, args=(queue, cfg, pgn_path, pgn_chunks_sublist, elo_dict)) + worker = Process( + 
target=preprocess_thread, + args=(queue, cfg, pgn_path, pgn_chunks_sublist, elo_dict), + ) worker.start() - + + # Process chunks and train num_chunk = 0 offset = 0 while True: if not queue.empty(): + # Start next worker if available if offset + 1 < len(pgn_chunks_sublists): pgn_chunks_sublist = pgn_chunks_sublists[offset + 1] - worker = Process(target=preprocess_thread, args=(queue, cfg, pgn_path, pgn_chunks_sublist, elo_dict)) + worker = Process( + target=preprocess_thread, + args=(queue, cfg, pgn_path, + pgn_chunks_sublist, elo_dict), + ) worker.start() offset += 1 + + # Get preprocessed data and train data, game_count, chunk_count = queue.get() - loss, loss_maia, loss_side_info, loss_value = train_chunks(cfg, data, model, optimizer, all_moves_dict, criterion_maia, criterion_side_info, criterion_value) + loss, loss_maia, loss_side_info, loss_value = train_chunks( + cfg, + data, + model, + optimizer, + all_moves_dict, + criterion_maia, + criterion_side_info, + criterion_value, + ) + + # Update counters and log num_chunk += chunk_count accumulated_samples += len(data) accumulated_games += game_count - print(f'[{num_chunk}/{len(pgn_chunks)}]', flush=True) - print(f'[# Positions]: {readable_num(accumulated_samples)}', flush=True) - print(f'[# Games]: {readable_num(accumulated_games)}', flush=True) - print(f'[# Loss]: {loss} | [# Loss MAIA]: {loss_maia} | [# Loss Side Info]: {loss_side_info} | [# Loss Value]: {loss_value}', flush=True) + print(f"[{num_chunk}/{len(pgn_chunks)}]", flush=True) + print( + f"[# Positions]: {readable_num(accumulated_samples)}", + flush=True, + ) + print( + f"[# Games]: {readable_num(accumulated_games)}", flush=True) + print( + f"[# Loss]: {loss} | [# Loss MAIA]: {loss_maia} | " + f"[# Loss Side Info]: {loss_side_info} | [# Loss Value]: {loss_value}", + flush=True, + ) if num_chunk == len(pgn_chunks): break + # Log completion and cleanup num_file += 1 - print(f'### ({num_file} / {len(pgn_paths)}) Took {readable_time(time.time() - 
start_time)} to train {pgn_path} with {len(pgn_chunks)} chunks.', flush=True) + print( + f"### ({num_file} / {len(pgn_paths)}) " + f"Took {readable_time(time.time() - start_time)} to train {pgn_path} " + f"with {len(pgn_chunks)} chunks.", + flush=True, + ) os.remove(pgn_path) - - torch.save({'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'accumulated_samples': accumulated_samples, - 'accumulated_games': accumulated_games}, f'{save_root}epoch_{epoch + 1}_{pgn_path[-11:]}.pt') + + # Save checkpoint + torch.save( + { + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "accumulated_samples": accumulated_samples, + "accumulated_games": accumulated_games, + }, + f"{save_root}epoch_{epoch + 1}_{pgn_path[-11:]}.pt", + ) diff --git a/maia2/utils.py b/maia2/utils.py index 7b64eb2..6b3db75 100644 --- a/maia2/utils.py +++ b/maia2/utils.py @@ -1,39 +1,126 @@ -import pdb -import chess -import pickle +"""Utility functions for MAIA2. + +Provides functions for chess position handling, Elo rating mapping, +model configuration, and data processing utilities. 
+""" + import os +import pickle import random -import numpy as np -import torch +import re import time -import requests -import tqdm +from typing import Dict, Final, Generator, List, Optional, Tuple, TypeVar, Union + +import chess +import numpy as np import pyzstd -import re +import torch import yaml +# Type aliases +BoardPosition = str +ChessMove = str +EloRating = int +TimeSeconds = float +FileOffset = int +EloRangeDict = Dict[str, int] +ConfigDict = Dict[str, Union[str, int, float, bool]] +MovesDict = Dict[ChessMove, int] +ReverseMovesDict = Dict[int, ChessMove] +Chunk = Tuple[FileOffset, FileOffset] +SideInfo = Tuple[torch.Tensor, torch.Tensor] + +# Constants +ELO_INTERVAL: Final[EloRating] = 100 +ELO_START: Final[EloRating] = 1100 +ELO_END: Final[EloRating] = 2000 +PIECE_TYPES: Final[List[chess.PieceType]] = [ + chess.PAWN, + chess.KNIGHT, + chess.BISHOP, + chess.ROOK, + chess.QUEEN, + chess.KING, +] + class Config: - - def __init__(self, config_dict): + """Dynamic configuration container for MAIA2.""" + + input_channels: int = 1 + elo_dim: int = 1 + dim_cnn: int = 1 + dim_vit: int = 1 + num_blocks_cnn: int = 1 + num_blocks_vit: int = 1 + vit_length: int = 1 + batch_size: Optional[int] + chunk_size: int = 1 + verbose: Optional[bool] + start_year: int = 1900 + end_year: int = 2027 + start_month: int = 1 + end_month: int = 12 + first_n_moves: int = 0 + clock_threshold: float = 0 + max_ply: Optional[int] + data_root: str = "~" + num_workers: int = 1 + side_info: Optional[int] + side_info_coefficient: Optional[float] + value: Optional[int] + value_coefficient: Optional[float] + seed: int = 123456789 + num_cpu_left: int = 1 + lr: float = 1e-3 + wd: float = 1e-2 + from_checkpoint: Optional[bool] + checkpoint_epoch: Optional[int] + checkpoint_year: Optional[str] + checkpoint_month: Optional[str] + max_epochs: int = 1 + queue_length: int = 1 + max_games_per_elo_range: int = 1 + + def __init__(self, config_dict: ConfigDict) -> None: + """Initialize from dictionary. 
+ + Args: + config_dict: Configuration key-value pairs. + """ for key, value in config_dict.items(): setattr(self, key, value) -def parse_args(cfg_file_path): +def parse_args(cfg_file_path: str) -> Config: + """Parse YAML configuration file. - with open(cfg_file_path, 'r') as f: + Args: + cfg_file_path: Path to YAML config file. + + Returns: + Config object with settings. + + Raises: + OSError: If file cannot be read. + yaml.YAMLError: If YAML is malformed. + """ + with open(cfg_file_path, "r", encoding="utf-8") as f: cfg_dict = yaml.safe_load(f) - + cfg = Config(cfg_dict) return cfg -def seed_everything(seed: int): +def seed_everything(seed: int) -> None: + """Set random seeds for reproducibility. + Args: + seed: Seed value for all RNGs. + """ random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -42,8 +129,12 @@ def seed_everything(seed: int): torch.backends.cudnn.benchmark = False -def delete_file(filename): - +def delete_file(filename: str) -> None: + """Delete file if it exists. + + Args: + filename: Path to file. + """ if os.path.exists(filename): os.remove(filename) print(f"Data {filename} has been deleted.") @@ -51,20 +142,34 @@ def delete_file(filename): print(f"The file '{filename}' does not exist.") -def readable_num(num): - +def readable_num(num: int) -> str: + """Convert large number to readable format. + + Args: + num: Number to format. + + Returns: + Formatted string with suffix (K/M/B). 
+ """ if num >= 1e9: # if parameters are in the billions - return f'{num / 1e9:.2f}B' + return f"{num / 1e9:.2f}B" elif num >= 1e6: # if parameters are in the millions - return f'{num / 1e6:.2f}M' + return f"{num / 1e6:.2f}M" elif num >= 1e3: # if parameters are in the thousands - return f'{num / 1e3:.2f}K' + return f"{num / 1e3:.2f}K" else: return str(num) -def readable_time(elapsed_time): +def readable_time(elapsed_time: TimeSeconds) -> str: + """Format elapsed time in readable format. + + Args: + elapsed_time: Duration in seconds. + Returns: + Formatted time string (e.g., "1h 30m 45.50s"). + """ hours, rem = divmod(elapsed_time, 3600) minutes, seconds = divmod(rem, 60) @@ -76,54 +181,89 @@ def readable_time(elapsed_time): return f"{seconds:.2f}s" -def count_parameters(model): - - total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - +def count_parameters(model: torch.nn.Module) -> str: + """Count trainable parameters in model. + + Args: + model: PyTorch model. + + Returns: + Formatted parameter count string. + """ + if not isinstance(model, torch.nn.Module): + raise TypeError(f"Expected PyTorch Module, got {type(model)}") + + total_params = sum(p.numel() + for p in model.parameters() if p.requires_grad) return readable_num(total_params) -def create_elo_dict(): - - inteval = 100 - start = 1100 - end = 2000 - - range_dict = {f"<{start}": 0} +def create_elo_dict() -> EloRangeDict: + """Create Elo rating ranges to category indices mapping. + + Returns: + Dict mapping Elo range strings to indices. 
+ """ + range_dict: EloRangeDict = {f"<{ELO_START}": 0} range_index = 1 - for lower_bound in range(start, end - 1, inteval): - upper_bound = lower_bound + inteval + for lower_bound in range(ELO_START, ELO_END - 1, ELO_INTERVAL): + upper_bound = lower_bound + ELO_INTERVAL range_dict[f"{lower_bound}-{upper_bound - 1}"] = range_index range_index += 1 - range_dict[f">={end}"] = range_index - + range_dict[f">={ELO_END}"] = range_index + # print(range_dict, flush=True) - + return range_dict -def map_to_category(elo, elo_dict): +def map_to_category(elo: EloRating, elo_dict: EloRangeDict) -> int: + """Map Elo rating to category index. - inteval = 100 - start = 1100 - end = 2000 - - if elo < start: - return elo_dict[f"<{start}"] - elif elo >= end: - return elo_dict[f">={end}"] + Args: + elo: Player's Elo rating. + elo_dict: Elo ranges to indices mapping. + + Returns: + Category index for the rating. + + Raises: + TypeError: If elo is not integer. + ValueError: If elo cannot be categorized. + """ + if not isinstance(elo, int): + raise TypeError(f"Elo rating must be an integer, got {type(elo)}") + + if elo < ELO_START: + return elo_dict[f"<{ELO_START}"] + elif elo >= ELO_END: + return elo_dict[f">={ELO_END}"] else: - for lower_bound in range(start, end - 1, inteval): - upper_bound = lower_bound + inteval + for lower_bound in range(ELO_START, ELO_END - 1, ELO_INTERVAL): + upper_bound = lower_bound + ELO_INTERVAL if lower_bound <= elo < upper_bound: return elo_dict[f"{lower_bound}-{upper_bound - 1}"] + raise ValueError(f"Elo {elo} could not be categorized.") + -def get_side_info(board, move_uci, all_moves_dict): +def get_side_info( + board: chess.Board, move_uci: ChessMove, all_moves_dict: MovesDict +) -> SideInfo: + """Generate feature vectors for chess move. + + Args: + board: Current chess position. + move_uci: Move in UCI format. + all_moves_dict: UCI moves to indices. + + Returns: + Tuple of (legal_moves_mask, side_info_vector). 
+ """ move = chess.Move.from_uci(move_uci) - + moving_piece = board.piece_at(move.from_square) captured_piece = board.piece_at(move.to_square) @@ -132,82 +272,115 @@ def get_side_info(board, move_uci, all_moves_dict): to_square_encoded = torch.zeros(64) to_square_encoded[move.to_square] = 1 - - if move_uci == 'e1g1': - rook_move = chess.Move.from_uci('h1f1') + + if move_uci == "e1g1": + rook_move = chess.Move.from_uci("h1f1") from_square_encoded[rook_move.from_square] = 1 to_square_encoded[rook_move.to_square] = 1 - - if move_uci == 'e1c1': - rook_move = chess.Move.from_uci('a1d1') + + if move_uci == "e1c1": + rook_move = chess.Move.from_uci("a1d1") from_square_encoded[rook_move.from_square] = 1 to_square_encoded[rook_move.to_square] = 1 board.push(move) is_check = board.is_check() board.pop() - + # Order: Pawn, Knight, Bishop, Rook, Queen, King side_info = torch.zeros(6 + 6 + 1) + assert moving_piece is not None side_info[moving_piece.piece_type - 1] = 1 - if move_uci in ['e1g1', 'e1c1']: + if move_uci in ["e1g1", "e1c1"]: side_info[3] = 1 if captured_piece: side_info[6 + captured_piece.piece_type - 1] = 1 if is_check: side_info[-1] = 1 - + legal_moves = torch.zeros(len(all_moves_dict)) - legal_moves_idx = torch.tensor([all_moves_dict[move.uci()] for move in board.legal_moves]) + legal_moves_idx = torch.tensor( + [all_moves_dict[move.uci()] for move in board.legal_moves] + ) legal_moves[legal_moves_idx] = 1 - - side_info = torch.cat([side_info, from_square_encoded, to_square_encoded, legal_moves], dim=0) - + + side_info = torch.cat( + [side_info, from_square_encoded, to_square_encoded, legal_moves], dim=0 + ) + return legal_moves, side_info -def extract_clock_time(comment): - - match = re.search(r'\[%clk (\d+):(\d+):(\d+)\]', comment) +def extract_clock_time(comment: str) -> Optional[int]: + """Extract remaining clock time from PGN comment. + + Args: + comment: PGN comment string. + + Returns: + Remaining time in seconds, or None if not found. 
+ """ + match = re.search(r"\[%clk (\d+):(\d+):(\d+)\]", comment) if match: hours, minutes, seconds = map(int, match.groups()) return hours * 3600 + minutes * 60 + seconds return None - -def read_or_create_chunks(pgn_path, cfg): - cache_file = pgn_path.replace('.pgn', '_chunks.pkl') +def read_or_create_chunks(pgn_path: str, cfg: Config) -> List[Chunk]: + """Load or create file offset chunks for PGN. + + Args: + pgn_path: Path to PGN file. + cfg: Configuration with chunk_size. + + Returns: + List of (start_offset, end_offset) tuples. + + Raises: + OSError: If file access fails. + """ + cache_file = pgn_path.replace(".pgn", "_chunks.pkl") if os.path.exists(cache_file): print(f"Loading cached chunks from {cache_file}") - with open(cache_file, 'rb') as f: + with open(cache_file, "rb") as f: pgn_chunks = pickle.load(f) else: print(f"Cache not found. Creating chunks for {pgn_path}") start_time = time.time() pgn_chunks = get_chunks(pgn_path, cfg.chunk_size) - print(f'Chunking took {readable_time(time.time() - start_time)}', flush=True) - - with open(cache_file, 'wb') as f: + print( + f"Chunking took {readable_time(time.time() - start_time)}", flush=True) + + with open(cache_file, "wb") as f: pickle.dump(pgn_chunks, f) - + return pgn_chunks -def board_to_tensor(board): - - piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING] - num_piece_channels = 12 # 6 piece types * 2 colors - additional_channels = 6 # 1 for player's turn, 4 for castling rights, 1 for en passant - tensor = torch.zeros((num_piece_channels + additional_channels, 8, 8), dtype=torch.float32) +def board_to_tensor(board: chess.Board) -> torch.Tensor: + """Convert chess position to feature tensor. + + Args: + board: Chess position. + + Returns: + Tensor [18, 8, 8] with board representation. 
+ """ + num_piece_channels: Final[int] = 12 # 6 piece types * 2 colors + # 1 for player's turn, 4 for castling rights, 1 for en passant + additional_channels: Final[int] = 6 + tensor = torch.zeros( + (num_piece_channels + additional_channels, 8, 8), dtype=torch.float32 + ) # Precompute indices for each piece type - piece_indices = {piece: i for i, piece in enumerate(piece_types)} + piece_indices = {piece: i for i, piece in enumerate(PIECE_TYPES)} # Fill tensor for each piece type - for piece_type in piece_types: - for color in [True, False]: # True is White, False is Black + for piece_type in PIECE_TYPES: + for color in [True, False]: # True=White, False=Black piece_map = board.pieces(piece_type, color) index = piece_indices[piece_type] + (0 if color else 6) for square in piece_map: @@ -220,10 +393,12 @@ def board_to_tensor(board): tensor[turn_channel, :, :] = 1.0 # Castling rights channels - castling_rights = [board.has_kingside_castling_rights(chess.WHITE), - board.has_queenside_castling_rights(chess.WHITE), - board.has_kingside_castling_rights(chess.BLACK), - board.has_queenside_castling_rights(chess.BLACK)] + castling_rights = [ + board.has_kingside_castling_rights(chess.WHITE), + board.has_queenside_castling_rights(chess.WHITE), + board.has_kingside_castling_rights(chess.BLACK), + board.has_queenside_castling_rights(chess.BLACK), + ] for i, has_right in enumerate(castling_rights): if has_right: tensor[num_piece_channels + 1 + i, :, :] = 1.0 @@ -236,47 +411,71 @@ def board_to_tensor(board): return tensor -def generate_pawn_promotions(): + +def generate_pawn_promotions() -> List[ChessMove]: + """Generate all possible pawn promotion moves. + + Returns: + List of UCI promotion moves. 
+ """ # Define the promotion rows for both colors and the promotion pieces # promotion_rows = {'white': '7', 'black': '2'} - promotion_rows = {'white': '7'} - promotion_pieces = ['q', 'r', 'b', 'n'] - promotions = [] + promotion_rows = {"white": "7"} + promotion_pieces = ["q", "r", "b", "n"] + promotions: List[ChessMove] = [] # Iterate over each color for color, row in promotion_rows.items(): # Target rows for promotion (8 for white, 1 for black) - target_row = '8' if color == 'white' else '1' + target_row = "8" if color == "white" else "1" # Each file from 'a' to 'h' - for file in 'abcdefgh': + for file in "abcdefgh": # Direct move to promotion for piece in promotion_pieces: - promotions.append(f'{file}{row}{file}{target_row}{piece}') + promotions.append(f"{file}{row}{file}{target_row}{piece}") # Capturing moves to the left and right (if not on the edges of the board) - if file != 'a': - left_file = chr(ord(file) - 1) # File to the left + if file != "a": + left_file = chr(ord(file) - 1) for piece in promotion_pieces: - promotions.append(f'{file}{row}{left_file}{target_row}{piece}') + promotions.append( + f"{file}{row}{left_file}{target_row}{piece}") - if file != 'h': - right_file = chr(ord(file) + 1) # File to the right + # Capture right + if file != "h": + right_file = chr(ord(file) + 1) for piece in promotion_pieces: - promotions.append(f'{file}{row}{right_file}{target_row}{piece}') + promotions.append( + f"{file}{row}{right_file}{target_row}{piece}") return promotions -def mirror_square(square): - +def mirror_square(square: str) -> str: + """Mirror chess square vertically. + + Args: + square: Square in algebraic notation. + + Returns: + Mirrored square. + """ file = square[0] rank = str(9 - int(square[1])) - + return file + rank -def mirror_move(move_uci): +def mirror_move(move_uci: ChessMove) -> ChessMove: + """Mirror chess move vertically. + + Args: + move_uci: Move in UCI notation. + + Returns: + Mirrored move in UCI notation. 
+ """ # Check if the move is a promotion (length of UCI string will be more than 4) is_promotion = len(move_uci) > 4 @@ -293,10 +492,22 @@ def mirror_move(move_uci): return mirrored_start + mirrored_end + promotion_piece -def get_chunks(pgn_path, chunk_size): +def get_chunks(pgn_path: str, chunk_size: int) -> List[Chunk]: + """Divide PGN file into chunks by game count. + + Args: + pgn_path: Path to PGN file. + chunk_size: Target games per chunk. + + Returns: + List of (start_offset, end_offset) tuples. - chunks = [] - with open(pgn_path, 'r', encoding='utf-8') as pgn_file: + Raises: + ValueError: If PGN format is invalid. + OSError: If file cannot be read. + """ + chunk_list: List[Chunk] = [] + with open(pgn_path, "r", encoding="utf-8") as pgn_file: while True: start_pos = pgn_file.tell() game_count = 0 @@ -314,45 +525,72 @@ def get_chunks(pgn_path, chunk_size): if line not in ["\n", ""]: raise ValueError end_pos = pgn_file.tell() - chunks.append((start_pos, end_pos)) + chunk_list.append((start_pos, end_pos)) if not line: break - return chunks + return chunk_list -def decompress_zst(file_path, decompressed_path): - """ Decompress a .zst file using pyzstd """ - with open(file_path, 'rb') as compressed_file, open(decompressed_path, 'wb') as decompressed_file: +def decompress_zst(file_path: str, decompressed_path: str) -> None: + """Decompress Zstandard (.zst) file. + + Args: + file_path: Path to .zst file. + decompressed_path: Output path for decompressed file. + + Raises: + OSError: If file access fails. + pyzstd.ZstdError: If decompression fails. + """ + with ( + open(file_path, "rb") as compressed_file, + open(decompressed_path, "wb") as decompressed_file, + ): pyzstd.decompress_stream(compressed_file, decompressed_file) -def get_all_possible_moves(): - - all_moves = [] +def get_all_possible_moves() -> List[ChessMove]: + """Generate all possible legal chess moves. + + Returns: + List of all moves in UCI notation. 
+ """ + all_moves: List[chess.Move] = [] for rank in range(8): - for file in range(8): + for file in range(8): square = chess.square(file, rank) - + board = chess.Board(None) board.set_piece_at(square, chess.Piece(chess.QUEEN, chess.WHITE)) legal_moves = list(board.legal_moves) all_moves.extend(legal_moves) - + board = chess.Board(None) board.set_piece_at(square, chess.Piece(chess.KNIGHT, chess.WHITE)) legal_moves = list(board.legal_moves) all_moves.extend(legal_moves) - - all_moves = [all_moves[i].uci() for i in range(len(all_moves))] - + + all_moves_uci = [all_moves[i].uci() for i in range(len(all_moves))] + pawn_promotions = generate_pawn_promotions() - - return all_moves + pawn_promotions + return all_moves_uci + pawn_promotions + + +T = TypeVar("T") + + +def chunks(lst: List[T], n: int) -> Generator[List[T], None, None]: + """Split list into fixed-size chunks. + + Args: + lst: List to divide. + n: Chunk size. -def chunks(lst, n): - """Yield successive n-sized chunks from lst.""" + Yields: + Sublists of size n (or smaller for last chunk). + """ for i in range(0, len(lst), n): - yield lst[i:i + n] + yield lst[i: i + n] diff --git a/pyproject.toml b/pyproject.toml index 0ac599c..e5efdaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "gdown==5.2.0", "numpy==2.1.3", "pandas==2.2.3", - "pyyaml>=6.0.2", + "PyYAML>=6.0.2", "pyzstd==0.15.9", "requests==2.32.3", "torch==2.4.0", @@ -24,3 +24,8 @@ dependencies = [ [project.urls] Home = "https://github.com/CSSLab/maia2" + +[tool.mypy] +files = "maia2" +ignore_missing_imports = true +strict = true \ No newline at end of file