diff --git a/cpp/src/main.cpp b/cpp/src/main.cpp index 9364384..a7c1777 100644 --- a/cpp/src/main.cpp +++ b/cpp/src/main.cpp @@ -11,15 +11,16 @@ #include int main(int argc, char* argv[]) { - // std::string pos = - // "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"; + std::string pos = + "r2qkb1r/2p1pppp/p1n1b3/1p6/B2P4/2P1P3/P4PPP/R1BQK1NR w KQkq - 0 9 "; // std::string pos = // "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 4 3 // "; // pos for ai timing< - std::string pos = - "r3k2r/p1ppqpb1/Bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPB1PPP/R3K2R b KQkq - 0 3 "; + // std::string pos = + // "r3k2r/p1ppqpb1/Bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPB1PPP/R3K2R b KQkq - + // 0 3 "; Board b = Board::setup_fen_position(pos); @@ -28,9 +29,9 @@ int main(int argc, char* argv[]) { // ai::v2_alpha_beta p2(false, std::chrono::milliseconds(20000)); ai::v3_AB_ordering p2(false, std::chrono::milliseconds(20000)); - NoOpView gui; - AIvsAIController manual(b, gui, p1, p2); - // HumanVsAIController manual(b, gui, p2); + GUI gui; + // AIvsAIController manual(b, gui, p1, p2); + HumanVsAIController manual(b, gui, p2); Controller& controller = manual; diff --git a/cpp/src/model/ais/ai.hpp b/cpp/src/model/ais/ai.hpp index 4e2af5b..40ae39a 100644 --- a/cpp/src/model/ais/ai.hpp +++ b/cpp/src/model/ais/ai.hpp @@ -57,7 +57,7 @@ namespace ai { class v3_AB_ordering : public AI { // looks two moves ahead, with alpha-beta pruning, with move ordering - int _search(const Board&, int, int, int); + virtual int _search(const Board&, int, int, int); public: v3_AB_ordering(bool w, std::chrono::milliseconds tt): AI(w, tt) {} @@ -65,4 +65,15 @@ namespace ai { Move _search(const Board&) override; int eval(const Board&) override; }; + + class v4_search_captures : public v3_AB_ordering { + // same as v3, but looking at only at captures when leaf is reached, + // until no captures are left + int _search(const Board&, int, int, int) override; + int _search_captures(const Board&, int, int); + + public: + v4_search_captures(bool w, std::chrono::milliseconds tt) + : v3_AB_ordering(w, tt) {} + }; } // namespace ai diff --git a/cpp/src/model/ais/v4_search_captures.cpp b/cpp/src/model/ais/v4_search_captures.cpp new file mode 100644 index 0000000..d36dc9f --- /dev/null +++ b/cpp/src/model/ais/v4_search_captures.cpp @@ -0,0 +1,61 @@ +#include "../pieces/piece.hpp" +#include "../utils/utils.hpp" +#include "ai.hpp" + +#include + +#define MULTITHREADED 1 + + +static int position_counter; + +int ai::v4_search_captures::_search( + const Board& b, int depth, int alpha, int beta +) { + if (depth == 0 || stop_computation) + return _search_captures(b, alpha, beta); + + if (b.no_legal_moves()) { + if (b.is_check()) + return -INFINITY; + return 0; + } + + std::vector moves = b.all_legal_moves(); + std::sort(moves.begin(), moves.end(), [&](Move& m1, Move& m2) { + return m1.score_guess(b) > m2.score_guess(b); + }); + + Move best_move; + for (const Move& move : moves) { + Board tmp_board = b.make_move(move); + int tmp_eval = -_search(tmp_board, depth - 1, -beta, -alpha); + if (tmp_eval >= beta) + return beta; + alpha = std::max(alpha, tmp_eval); + } + return alpha; +} + +int ai::v4_search_captures::_search_captures( + const Board& b, int alpha, int beta +) { + int evaluation = eval(b); + if (evaluation >= beta) + return beta; + alpha = std::max(evaluation, alpha); + + std::vector moves = b.all_capturing_moves(); + std::sort(moves.begin(), moves.end(), [&](Move& m1, Move& m2) { + return m1.score_guess(b) > m2.score_guess(b); + }); + + for (const Move& move : moves) { + Board tmp_board = b.make_move(move); + int tmp_eval = -_search_captures(tmp_board, -beta, -alpha); + if (tmp_eval >= beta) + return beta; + alpha = std::max(alpha, tmp_eval); + } + return alpha; +} diff --git a/cpp/src/model/board/board.cpp b/cpp/src/model/board/board.cpp index e15bd30..a8733b7 100644 --- a/cpp/src/model/board/board.cpp +++ b/cpp/src/model/board/board.cpp @@ -410,3 +410,14 @@ std::vector Board::all_legal_moves() const { } return ret; } + +std::vector Board::all_capturing_moves() const { + std::vector moves = all_legal_moves(); + std::vector ret; + ret.reserve(moves.size()); + for (const Move& move : moves) + if (piece_at(move.target_square) != Piece::None) + ret.push_back(move); + + return ret; +} diff --git a/cpp/src/model/board/board.hpp b/cpp/src/model/board/board.hpp index 274d4c4..f4ffa59 100644 --- a/cpp/src/model/board/board.hpp +++ b/cpp/src/model/board/board.hpp @@ -38,6 +38,7 @@ struct Board { }; std::vector all_legal_moves() const; + std::vector all_capturing_moves() const; bool is_checkmate() const;