diff --git a/test/parser/expression.cpp b/test/parser/expression.cpp index ce7fe0d..d7db95f 100644 --- a/test/parser/expression.cpp +++ b/test/parser/expression.cpp @@ -1,5 +1,6 @@ #include "ast/expressions/identifier.hpp" #include "ast/expressions/integer.hpp" +#include "ast/expressions/prefix.hpp" #include "utils.hpp" #include @@ -11,24 +12,11 @@ TEST_SUITE("Parser: expression") { ) { setup("foobar;"); REQUIRE(program->statements.size() == 1); - ast::expression_stmt* expression_stmt; - REQUIRE_NOTHROW( - expression_stmt = - dynamic_cast(program->statements[0]) - ); - REQUIRE_MESSAGE( - expression_stmt != nullptr, - "Couldn't cast statement to an expression statement" - ); + ast::expression_stmt* expression_stmt = + cast(program->statements[0]); - ast::identifier* ident; - REQUIRE_NOTHROW( - ident = dynamic_cast(expression_stmt->expression) - ); - REQUIRE_MESSAGE( - ident != nullptr, - "Couldn't cast expression to an identifier" - ); + ast::identifier* ident = + cast(expression_stmt->expression); REQUIRE(ident->value == "foobar"); REQUIRE(ident->token_literal() == "foobar"); @@ -40,27 +28,61 @@ TEST_SUITE("Parser: expression") { ) { setup("5;"); REQUIRE(program->statements.size() == 1); - ast::expression_stmt* expression_stmt; - REQUIRE_NOTHROW( - expression_stmt = - dynamic_cast(program->statements[0]) - ); - REQUIRE_MESSAGE( - expression_stmt != nullptr, - "Couldn't cast statement to an expression statement" - ); - ast::integer_literal* int_lit; - REQUIRE_NOTHROW( - int_lit = - dynamic_cast(expression_stmt->expression) - ); - REQUIRE_MESSAGE( - int_lit != nullptr, - "Couldn't cast expression to an identifier" - ); + ast::expression_stmt* expression_stmt = + cast(program->statements[0]); + + ast::integer_literal* int_lit = + cast(expression_stmt->expression); REQUIRE(int_lit->value == 5); REQUIRE(int_lit->token_literal() == "5"); }; + + // TEST_CASE_FIXTURE( + // ParserFixture, + // "Simple expression statement with prefix before integer" + // ) { + // SUBCASE("Prefix: '!'") { + // setup("!5;"); + // + // REQUIRE(program->statements.size() == 1); + // ast::expression_stmt* expression_stmt = + // cast(program->statements[0]); + // + // ast::prefix_expr* prefix_expr; + // REQUIRE_NOTHROW( + // prefix_expr = + // dynamic_cast(expression_stmt->expression) + // ); + // REQUIRE_MESSAGE( + // prefix_expr != nullptr, + // "Couldn't cast expression to an identifier" + // ); + // + // REQUIRE(prefix_expr->value == 5); + // REQUIRE(prefix_expr->token_literal() == "5"); + // } + // SUBCASE("Prefix: '-'") { + // setup("-15;"); + // + // REQUIRE(program->statements.size() == 1); + // ast::expression_stmt* expression_stmt = + // get_expression_stmt(program->statements[0]); + // + // ast::integer_literal* int_lit; + // REQUIRE_NOTHROW( + // int_lit = dynamic_cast( + // expression_stmt->expression + // ) + // ); + // REQUIRE_MESSAGE( + // int_lit != nullptr, + // "Couldn't cast expression to an identifier" + // ); + // + // REQUIRE(int_lit->value == 5); + // REQUIRE(int_lit->token_literal() == "5"); + // } + // } } diff --git a/test/parser/let.cpp b/test/parser/let.cpp index c3f94b4..ca4c0c3 100644 --- a/test/parser/let.cpp +++ b/test/parser/let.cpp @@ -1,6 +1,7 @@ #include "ast/statements/let.hpp" #include "ast/ast.hpp" +#include "ast/errors/error.hpp" #include "lexer/lexer.hpp" #include "parser/parser.hpp" #include "utils.hpp" @@ -10,12 +11,7 @@ void test_let_statement(ast::statement* stmt, const std::string name) { REQUIRE(stmt->token_literal() == "let"); - ast::let_stmt* let_stmt; - REQUIRE_NOTHROW(let_stmt = dynamic_cast(stmt)); - REQUIRE_MESSAGE( - let_stmt != nullptr, - "Couldn't cast statement to a let statement" - ); + ast::let_stmt* let_stmt = cast(stmt); REQUIRE(let_stmt->name->value == name); REQUIRE(let_stmt->name->token_literal() == name); @@ -38,13 +34,8 @@ void test_failing_let_parsing( int i = 0; for (auto& e : p.errors) { - ast::error::expected_next* en; + ast::error::expected_next* en = cast(e); - REQUIRE_NOTHROW(en = dynamic_cast(e)); - REQUIRE_MESSAGE( - en != nullptr, - "Couldn't cast the error to an 'expected_next'" - ); REQUIRE(en->expected_type == expected_types[i++]); } diff --git a/test/parser/return.cpp b/test/parser/return.cpp index e904a17..3875a54 100644 --- a/test/parser/return.cpp +++ b/test/parser/return.cpp @@ -14,12 +14,7 @@ return 103213;\ for (const auto stmt : program->statements) { REQUIRE(stmt->token_literal() == "return"); - ast::return_stmt* let_stmt; - REQUIRE_NOTHROW(let_stmt = dynamic_cast(stmt)); - REQUIRE_MESSAGE( - let_stmt != nullptr, - "Couldn't cast statement to a return statement" - ); + ast::return_stmt* return_stmt = cast(stmt); } } } diff --git a/test/parser/utils.hpp b/test/parser/utils.hpp index bbabb20..540b4f8 100644 --- a/test/parser/utils.hpp +++ b/test/parser/utils.hpp @@ -1,9 +1,45 @@ +#include "ast/ast.hpp" #include "lexer/lexer.hpp" #include "parser/parser.hpp" +#include #include -void check_parser_errors(const std::vector& errors); +void check_parser_errors(const std::vector&); + +namespace { + template + T* cast_impl(Base* base) { + static_assert( + std::is_base_of_v, + "T must be derived from Base" + ); + + T* t; + REQUIRE_NOTHROW(t = dynamic_cast(base)); + REQUIRE_MESSAGE( + t != nullptr, + "Couldn't cast expression to a " * std::string(typeid(T).name()) + ); + return t; + } +} // namespace + +// Overloads for your known base types +template +T* cast(ast::expression* expr) { + return cast_impl(expr); +} + +template +T* cast(ast::statement* stmt) { + return cast_impl(stmt); +} + +template +T* cast(ast::error::error* err) { + return cast_impl(err); +} struct ParserFixture { std::stringstream input;