diff --git a/include/sqlpp11/cte.h b/include/sqlpp11/cte.h index 5fa9c177..76aa337e 100644 --- a/include/sqlpp11/cte.h +++ b/include/sqlpp11/cte.h @@ -27,7 +27,9 @@ #ifndef SQLPP_CTE_H #define SQLPP_CTE_H -#include +#include +#include +#include #include #include #include @@ -38,7 +40,7 @@ namespace sqlpp { - template + template struct cte_t; template @@ -49,7 +51,7 @@ namespace sqlpp using _traits = make_traits, tag::must_not_insert, tag::must_not_update, - tag_if::value> + tag_if::value> >; }; @@ -62,14 +64,14 @@ namespace sqlpp template struct make_cte_impl> { - using type = cte_t...>; + using type = cte_t; }; template using make_cte_t = typename make_cte_impl>::type; - template - struct cte_t: public member_t>... + template + struct cte_t: public member_t, column_t>>... { using _traits = make_traits; // FIXME: is table? really? struct _recursive_traits @@ -83,12 +85,55 @@ namespace sqlpp using _parameters = parameters_of; using _tags = detail::type_set<>; }; - - // FIXME: need a union_distinct and union_all here - // unions can depend on the cte itself In that case the cte is recursive. - using _alias_t = typename AliasProvider::_alias_t; + using _column_tuple_t = std::tuple>...>; + + template + using _check = logic::all_t::value...>; + + template + auto union_distinct(Rhs rhs) const + -> typename std::conditional<_check::value, cte_t, FieldSpecs...>, bad_statement>::type + { + static_assert(is_statement_t::value, "argument of union call has to be a statement"); + static_assert(has_policy_t::value, "argument of union call has to be a select"); + static_assert(has_result_row_t::value, "argument of a union has to be a (complete) select statement"); + + using _result_row_t = result_row_t; + static_assert(std::is_same<_result_row_t, get_result_row_t>::value, "both select statements in a union have to have the same result columns (type and name)"); + + return _union_impl(_check{}, rhs); + } + + template + auto union_all(Rhs rhs) const + -> typename std::conditional<_check::value, cte_t, FieldSpecs...>, bad_statement>::type + { + static_assert(is_statement_t::value, "argument of union call has to be a statement"); + static_assert(has_policy_t::value, "argument of union call has to be a select"); + static_assert(has_result_row_t::value, "argument of a union has to be a (complete) select statement"); + + using _result_row_t = result_row_t; + static_assert(std::is_same<_result_row_t, get_result_row_t>::value, "both select statements in a union have to have the same result columns (type and name)"); + + return _union_impl(_check{}, rhs); + } + + private: + template + auto _union_impl(const std::false_type&, Rhs rhs) const + -> bad_statement; + + template + auto _union_impl(const std::true_type&, Rhs rhs) const + -> cte_t, FieldSpecs...> + { + return union_data_t{_statement, rhs}; + } + + public: + cte_t(Statement statement): _statement(statement){} cte_t(const cte_t&) = default; cte_t(cte_t&&) = default; @@ -143,6 +188,7 @@ namespace sqlpp { static_assert(required_tables_of::size::value == 0, "common table expression must not use unknown tables"); static_assert(not detail::is_element_of>::value, "common table expression must not self-reference in the first part, use union_all/union_distinct for recursion"); + static_assert(is_static_result_row_t>::value, "ctes must not have dynamically added columns"); return { statement }; } diff --git a/include/sqlpp11/union.h b/include/sqlpp11/union.h index b45d1930..066e73e3 100644 --- a/include/sqlpp11/union.h +++ b/include/sqlpp11/union.h @@ -28,6 +28,7 @@ #define SQLPP_UNION_H #include +#include #include #include #include @@ -159,9 +160,9 @@ namespace sqlpp static_assert(has_policy_t::value, "argument of union call has to be a select"); static_assert(has_result_row_t::value, "argument of a union has to be a complete select statement"); static_assert(has_result_row_t>::value, "left hand side argument of a union has to be a complete select statement or union"); - static_assert(std::is_same>, get_result_row_t>::value, "both arguments in a union have to have the same result columns (type and name)"); using _result_row_t = get_result_row_t; + static_assert(std::is_same>, _result_row_t>::value, "both arguments in a union have to have the same result columns (type and name)"); static_assert(is_static_result_row_t<_result_row_t>::value, "unions must not have dynamically added columns"); return _union_impl(_check, Rhs>{}, rhs); @@ -175,9 +176,9 @@ namespace sqlpp static_assert(has_policy_t::value, "argument of union call has to be a select"); static_assert(has_result_row_t::value, "argument of a union has to be a (complete) select statement"); static_assert(has_result_row_t>::value, "left hand side argument of a union has to be a (complete) select statement"); - static_assert(std::is_same>, get_result_row_t>::value, "both select statements in a union have to have the same result columns (type and name)"); using _result_row_t = get_result_row_t; + static_assert(std::is_same>, _result_row_t>::value, "both arguments in a union have to have the same result columns (type and name)"); static_assert(is_static_result_row_t<_result_row_t>::value, "unions must not have dynamically added columns"); return _union_impl(_check, Rhs>{}, rhs); diff --git a/tests/WithTest.cpp b/tests/WithTest.cpp index 32b27ed4..f6bd4006 100644 --- a/tests/WithTest.cpp +++ b/tests/WithTest.cpp @@ -40,5 +40,8 @@ int main() db(with(x)(select(x.alpha).from(x).where(true))); + auto y0 = cte(sqlpp::y).as(select(all_of(t)).from(t)); + auto y = y0.union_all(select(all_of(y0)).from(y0).where(false)); + return 0; }