diff --git a/cpp/src/stickfosh.cpp b/cpp/src/stickfosh.cpp index 713cdf4..d126a7f 100644 --- a/cpp/src/stickfosh.cpp +++ b/cpp/src/stickfosh.cpp @@ -2,6 +2,7 @@ #include "board.hpp" #include "move.hpp" +#include "threadpool.hpp" #include #include @@ -102,7 +103,9 @@ static std::map> pos2expected{ static std::stringstream res; -int move_generation_test(Board& b, int depth, int max_depth) { +int move_generation_test( + const Board& b, int depth, int max_depth, ThreadPool& pool +) { if (depth == max_depth) { res.str(""); res.clear(); @@ -118,15 +121,34 @@ int move_generation_test(Board& b, int depth, int max_depth) { return moves.size(); int num_pos = 0; - for (const Move& move : moves) { - Board tmp_board = b.make_move(move); - // std::cout << ">" << move << std::endl; - int n = move_generation_test(tmp_board, depth - 1, max_depth); - // std::cout << "<" << move << std::endl; - if (depth == max_depth) - res << move << ": " << n << std::endl; - num_pos += n; + + if (depth == max_depth) { + // Parallel execution at the top level + std::vector> futures; + for (const Move& move : moves) { + Board tmp_board = b.make_move(move); + futures.push_back(pool.enqueue( + move_generation_test, + tmp_board, + depth - 1, + max_depth, + std::ref(pool) + )); + } + + for (auto& future : futures) + num_pos += future.get(); // Retrieve the result of each task + } else { + // Regular sequential execution + for (const Move& move : moves) { + Board tmp_board = b.make_move(move); + int n = move_generation_test(tmp_board, depth - 1, max_depth, pool); + if (depth == max_depth) + res << move << ": " << n << std::endl; + num_pos += n; + } } + return num_pos; } @@ -139,11 +161,12 @@ void perft(std::string pos) { std::cout << pos << std::endl; std::map expected = pos2expected[pos]; Board b = Board::setup_fen_position(pos); + ThreadPool pool(std::thread::hardware_concurrency()); for (const auto& [depth, expected_n_moves] : expected) { std::cout << "Depth: " << depth << " " << std::flush; auto start = std::chrono::steady_clock::now(); - int moves = move_generation_test(b, depth, depth); + int moves = move_generation_test(b, depth, depth, pool); auto end = std::chrono::steady_clock::now(); auto elapsed = std::chrono::duration_cast(end - start) diff --git a/cpp/src/threadpool.hpp b/cpp/src/threadpool.hpp new file mode 100644 index 0000000..3985602 --- /dev/null +++ b/cpp/src/threadpool.hpp @@ -0,0 +1,73 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + ThreadPool(size_t numThreads) { + for (size_t i = 0; i < numThreads; ++i) { + workers.emplace_back([this] { + while (true) { + std::function task; + { + std::unique_lock lock(queueMutex); + condition.wait(lock, [this] { + return stop || !tasks.empty(); + }); + if (stop && tasks.empty()) + return; + task = std::move(tasks.front()); + tasks.pop(); + } + task(); + } + }); + } + } + + template + auto enqueue(F&& f, Args&&... args) + -> std::future::type> { + using return_type = typename std::invoke_result::type; + + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...) + ); + + std::future res = task->get_future(); + { + std::unique_lock lock(queueMutex); + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; + } + + void waitAll() { + std::unique_lock lock(queueMutex); + condition.wait(lock, [this] { return tasks.empty(); }); + } + + ~ThreadPool() { + { + std::unique_lock lock(queueMutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) + worker.join(); + } + + private: + std::vector workers; + std::queue> tasks; + std::mutex queueMutex; + std::condition_variable condition; + bool stop = false; +};