0
0
mirror of https://github.com/rbock/sqlpp11.git synced 2024-11-15 20:31:16 +08:00

Address conversion warnings and fix UAF bug in test

This commit is contained in:
Roland Bock 2022-01-14 07:39:11 +01:00
parent ccc75eafc7
commit b8aed2af55
6 changed files with 80 additions and 61 deletions

View File

@ -164,7 +164,7 @@ namespace sqlpp
} }
*is_null = _handle->result.isNull(_handle->count, index); *is_null = _handle->result.isNull(_handle->count, index);
*value = _handle->result.getValue<bool>(_handle->count, index); *value = _handle->result.getBoolValue(_handle->count, index);
} }
inline void bind_result_t::_bind_floating_point_result(size_t _index, double* value, bool* is_null) inline void bind_result_t::_bind_floating_point_result(size_t _index, double* value, bool* is_null)
@ -176,7 +176,7 @@ namespace sqlpp
} }
*is_null = _handle->result.isNull(_handle->count, index); *is_null = _handle->result.isNull(_handle->count, index);
*value = _handle->result.getValue<double>(_handle->count, index); *value = _handle->result.getDoubleValue(_handle->count, index);
} }
inline void bind_result_t::_bind_integral_result(size_t _index, int64_t* value, bool* is_null) inline void bind_result_t::_bind_integral_result(size_t _index, int64_t* value, bool* is_null)
@ -188,7 +188,7 @@ namespace sqlpp
} }
*is_null = _handle->result.isNull(_handle->count, index); *is_null = _handle->result.isNull(_handle->count, index);
*value = _handle->result.getValue<unsigned long long>(_handle->count, index); *value = _handle->result.getInt64Value(_handle->count, index);
} }
inline void bind_result_t::_bind_unsigned_integral_result(size_t _index, uint64_t* value, bool* is_null) inline void bind_result_t::_bind_unsigned_integral_result(size_t _index, uint64_t* value, bool* is_null)
@ -200,7 +200,7 @@ namespace sqlpp
} }
*is_null = _handle->result.isNull(_handle->count, index); *is_null = _handle->result.isNull(_handle->count, index);
*value = _handle->result.getValue<unsigned long long>(_handle->count, index); *value = _handle->result.getUInt64Value(_handle->count, index);
} }
inline void bind_result_t::_bind_text_result(size_t _index, const char** value, size_t* len) inline void bind_result_t::_bind_text_result(size_t _index, const char** value, size_t* len)
@ -218,8 +218,8 @@ namespace sqlpp
} }
else else
{ {
*value = _handle->result.getValue<const char*>(_handle->count, index); *value = _handle->result.getCharPtrValue(_handle->count, index);
*len = _handle->result.length(_handle->count, index); *len = static_cast<size_t>(_handle->result.length(_handle->count, index));
} }
} }
@ -311,7 +311,7 @@ namespace sqlpp
if (!(*is_null)) if (!(*is_null))
{ {
const auto date_string = _handle->result.getValue<const char*>(_handle->count, index); const auto date_string = _handle->result.getCharPtrValue(_handle->count, index);
if (_handle->debug()) if (_handle->debug())
{ {
@ -350,7 +350,7 @@ namespace sqlpp
if (!(*is_null)) if (!(*is_null))
{ {
const auto date_string = _handle->result.getValue(_handle->count, index); const auto date_string = _handle->result.getCharPtrValue(_handle->count, index);
if (_handle->debug()) if (_handle->debug())
{ {
@ -386,8 +386,8 @@ namespace sqlpp
if (std::strlen(time_string) <= 9) if (std::strlen(time_string) <= 9)
return; return;
auto us_string = time_string + 9; // hh:mm:ss. auto us_string = time_string + 9; // hh:mm:ss.
unsigned usec = 0; int usec = 0;
for (int i = 0; i < 6; ++i) for (size_t i = 0u; i < 6u; ++i)
{ {
if (std::isdigit(us_string[0])) if (std::isdigit(us_string[0]))
{ {
@ -417,8 +417,8 @@ namespace sqlpp
} }
else else
{ {
*value = _handle->result.getValue<const uint8_t*>(_handle->count, index); *value = _handle->result.getBlobValue(_handle->count, index);
*len = _handle->result.length(_handle->count, index); *len = static_cast<size_t>(_handle->result.length(_handle->count, index));
} }
} }

View File

@ -532,7 +532,7 @@ namespace sqlpp
throw sqlpp::exception("PostgreSQL error: could not read default_transaction_isolation"); throw sqlpp::exception("PostgreSQL error: could not read default_transaction_isolation");
} }
auto in = res->result.getValue<std::string>(0, 0); auto in = res->result.getStringValue(0, 0);
if (in == "read committed") if (in == "read committed")
{ {
return isolation_level::read_committed; return isolation_level::read_committed;

View File

@ -57,10 +57,10 @@ namespace sqlpp
{ {
detail::connection_handle& connection; detail::connection_handle& connection;
Result result; Result result;
bool valid{false}; bool valid = false;
uint32_t count{0}; int count = 0;
uint32_t totalCount = {0}; int totalCount = 0;
uint32_t fields = {0}; int fields = 0;
// ctor // ctor
statement_handle_t(detail::connection_handle& _connection); statement_handle_t(detail::connection_handle& _connection);

View File

@ -68,21 +68,71 @@ namespace sqlpp
void operator=(PGresult* res); void operator=(PGresult* res);
operator bool() const; operator bool() const;
template <typename T = const char*> inline int64_t getInt64Value(int record, int field) const
inline T getValue(int record, int field) const
{ {
static_assert(std::is_arithmetic<T>::value, "Value must be numeric type");
checkIndex(record, field); checkIndex(record, field);
T t(0); auto t = int64_t{};
auto txt = std::string(getPqValue(m_result, record, field)); const auto txt = std::string(getPqValue(m_result, record, field));
if(txt != "") if(txt != "")
{ {
t = std::stold(txt); t = std::stoll(txt);
} }
return t; return t;
} }
inline uint64_t getUInt64Value(int record, int field) const
{
checkIndex(record, field);
auto t = uint64_t{};
const auto txt = std::string(getPqValue(m_result, record, field));
if(txt != "")
{
t = std::stoull(txt);
}
return t;
}
inline double getDoubleValue(int record, int field) const
{
checkIndex(record, field);
auto t = double{};
auto txt = std::string(getPqValue(m_result, record, field));
if(txt != "")
{
t = std::stod(txt);
}
return t;
}
inline const char* getCharPtrValue(int record, int field) const
{
return const_cast<const char*>(getPqValue(m_result, record, field));
}
inline std::string getStringValue(int record, int field) const
{
return {getCharPtrValue(record, field)};
}
inline const uint8_t* getBlobValue(int record, int field) const
{
return reinterpret_cast<const uint8_t*>(getPqValue(m_result, record, field));
}
inline bool getBoolValue(int record, int field) const
{
checkIndex(record, field);
auto val = getPqValue(m_result, record, field);
if (*val == 't')
return true;
else if (*val == 'f')
return false;
return const_cast<const char*>(val);
}
const std::string& query() const const std::string& query() const
{ {
return m_query; return m_query;
@ -109,36 +159,6 @@ namespace sqlpp
std::string m_query; std::string m_query;
}; };
template <>
inline const char* Result::getValue<const char*>(int record, int field) const
{
return const_cast<const char*>(getPqValue(m_result, record, field));
}
template <>
inline std::string Result::getValue<std::string>(int record, int field) const
{
return {getValue<const char*>(record, field)};
}
template <>
inline bool Result::getValue<bool>(int record, int field) const
{
checkIndex(record, field);
auto val = getPqValue(m_result, record, field);
if (*val == 't')
return true;
else if (*val == 'f')
return false;
return const_cast<const char*>(val);
}
template <>
inline const uint8_t* Result::getValue<const uint8_t*>(int record, int field) const
{
return reinterpret_cast<const uint8_t*>(getValue<const char*>(record, field));
}
inline Result::Result() : m_result(nullptr) inline Result::Result() : m_result(nullptr)
{ {

View File

@ -75,7 +75,7 @@ namespace sqlpp
case 'F': case 'F':
return c + 10 - 'A'; return c + 10 - 'A';
} }
throw sqlpp::exception(std::string("Unexpected hex char: ") += c); throw sqlpp::exception(std::string("Unexpected hex char: ") + static_cast<char>(c));
} }
inline void hex_assign(std::vector<unsigned char>& value, const uint8_t* blob, size_t len) inline void hex_assign(std::vector<unsigned char>& value, const uint8_t* blob, size_t len)
@ -85,7 +85,7 @@ namespace sqlpp
size_t blob_index = 2; size_t blob_index = 2;
while (blob_index < len) while (blob_index < len)
{ {
value[val_index] = (unhex(blob[blob_index]) << 4) + unhex(blob[blob_index + 1]); value[val_index] = static_cast<unsigned char>(unhex(blob[blob_index]) << 4) + unhex(blob[blob_index + 1]);
++val_index; ++val_index;
blob_index += 2; blob_index += 2;
} }

View File

@ -33,8 +33,6 @@
#include "TabFoo.h" #include "TabFoo.h"
#include "make_test_connection.h" #include "make_test_connection.h"
SQLPP_ALIAS_PROVIDER(left);
namespace sql = sqlpp::postgresql; namespace sql = sqlpp::postgresql;
model::TabFoo tab = {}; model::TabFoo tab = {};
@ -126,13 +124,14 @@ int Select(int, char*[])
// remove // remove
db(remove_from(tab).where(tab.alpha == tab.alpha + 3)); db(remove_from(tab).where(tab.alpha == tab.alpha + 3));
auto result = db(select(all_of(tab)).from(tab).unconditionally()); auto result1 = db(select(all_of(tab)).from(tab).unconditionally());
std::cerr << "Accessing a field directly from the result (using the current row): " << result.begin()->alpha std::cerr << "Accessing a field directly from the result (using the current row): " << result1.begin()->alpha
<< std::endl; << std::endl;
std::cerr << "Can do that again, no problem: " << result.begin()->alpha << std::endl; std::cerr << "Can do that again, no problem: " << result1.begin()->alpha << std::endl;
auto tx = start_transaction(db); auto tx = start_transaction(db);
if (const auto& row = *db(select(all_of(tab), select(max(tab.alpha)).from(tab)).from(tab).unconditionally()).begin()) auto result2 = db(select(all_of(tab), select(max(tab.alpha)).from(tab)).from(tab).unconditionally());
if (const auto& row = *result2.begin())
{ {
auto a = row.alpha; auto a = row.alpha;
auto m = row.max; auto m = row.max;