0
0
mirror of https://github.com/rbock/sqlpp11.git synced 2024-11-15 20:31:16 +08:00

Address conversion warnings and fix UAF bug in test

This commit is contained in:
Roland Bock 2022-01-14 07:39:11 +01:00
parent ccc75eafc7
commit b8aed2af55
6 changed files with 80 additions and 61 deletions

View File

@ -164,7 +164,7 @@ namespace sqlpp
}
*is_null = _handle->result.isNull(_handle->count, index);
*value = _handle->result.getValue<bool>(_handle->count, index);
*value = _handle->result.getBoolValue(_handle->count, index);
}
inline void bind_result_t::_bind_floating_point_result(size_t _index, double* value, bool* is_null)
@ -176,7 +176,7 @@ namespace sqlpp
}
*is_null = _handle->result.isNull(_handle->count, index);
*value = _handle->result.getValue<double>(_handle->count, index);
*value = _handle->result.getDoubleValue(_handle->count, index);
}
inline void bind_result_t::_bind_integral_result(size_t _index, int64_t* value, bool* is_null)
@ -188,7 +188,7 @@ namespace sqlpp
}
*is_null = _handle->result.isNull(_handle->count, index);
*value = _handle->result.getValue<unsigned long long>(_handle->count, index);
*value = _handle->result.getInt64Value(_handle->count, index);
}
inline void bind_result_t::_bind_unsigned_integral_result(size_t _index, uint64_t* value, bool* is_null)
@ -200,7 +200,7 @@ namespace sqlpp
}
*is_null = _handle->result.isNull(_handle->count, index);
*value = _handle->result.getValue<unsigned long long>(_handle->count, index);
*value = _handle->result.getUInt64Value(_handle->count, index);
}
inline void bind_result_t::_bind_text_result(size_t _index, const char** value, size_t* len)
@ -218,8 +218,8 @@ namespace sqlpp
}
else
{
*value = _handle->result.getValue<const char*>(_handle->count, index);
*len = _handle->result.length(_handle->count, index);
*value = _handle->result.getCharPtrValue(_handle->count, index);
*len = static_cast<size_t>(_handle->result.length(_handle->count, index));
}
}
@ -311,7 +311,7 @@ namespace sqlpp
if (!(*is_null))
{
const auto date_string = _handle->result.getValue<const char*>(_handle->count, index);
const auto date_string = _handle->result.getCharPtrValue(_handle->count, index);
if (_handle->debug())
{
@ -350,7 +350,7 @@ namespace sqlpp
if (!(*is_null))
{
const auto date_string = _handle->result.getValue(_handle->count, index);
const auto date_string = _handle->result.getCharPtrValue(_handle->count, index);
if (_handle->debug())
{
@ -386,8 +386,8 @@ namespace sqlpp
if (std::strlen(time_string) <= 9)
return;
auto us_string = time_string + 9; // hh:mm:ss.
unsigned usec = 0;
for (int i = 0; i < 6; ++i)
int usec = 0;
for (size_t i = 0u; i < 6u; ++i)
{
if (std::isdigit(us_string[0]))
{
@ -417,8 +417,8 @@ namespace sqlpp
}
else
{
*value = _handle->result.getValue<const uint8_t*>(_handle->count, index);
*len = _handle->result.length(_handle->count, index);
*value = _handle->result.getBlobValue(_handle->count, index);
*len = static_cast<size_t>(_handle->result.length(_handle->count, index));
}
}

View File

@ -532,7 +532,7 @@ namespace sqlpp
throw sqlpp::exception("PostgreSQL error: could not read default_transaction_isolation");
}
auto in = res->result.getValue<std::string>(0, 0);
auto in = res->result.getStringValue(0, 0);
if (in == "read committed")
{
return isolation_level::read_committed;

View File

@ -57,10 +57,10 @@ namespace sqlpp
{
detail::connection_handle& connection;
Result result;
bool valid{false};
uint32_t count{0};
uint32_t totalCount = {0};
uint32_t fields = {0};
bool valid = false;
int count = 0;
int totalCount = 0;
int fields = 0;
// ctor
statement_handle_t(detail::connection_handle& _connection);

View File

@ -68,21 +68,71 @@ namespace sqlpp
void operator=(PGresult* res);
operator bool() const;
template <typename T = const char*>
inline T getValue(int record, int field) const
inline int64_t getInt64Value(int record, int field) const
{
static_assert(std::is_arithmetic<T>::value, "Value must be numeric type");
checkIndex(record, field);
T t(0);
auto txt = std::string(getPqValue(m_result, record, field));
auto t = int64_t{};
const auto txt = std::string(getPqValue(m_result, record, field));
if(txt != "")
{
t = std::stold(txt);
t = std::stoll(txt);
}
return t;
}
inline uint64_t getUInt64Value(int record, int field) const
{
checkIndex(record, field);
auto t = uint64_t{};
const auto txt = std::string(getPqValue(m_result, record, field));
if(txt != "")
{
t = std::stoull(txt);
}
return t;
}
inline double getDoubleValue(int record, int field) const
{
checkIndex(record, field);
auto t = double{};
auto txt = std::string(getPqValue(m_result, record, field));
if(txt != "")
{
t = std::stod(txt);
}
return t;
}
inline const char* getCharPtrValue(int record, int field) const
{
return const_cast<const char*>(getPqValue(m_result, record, field));
}
inline std::string getStringValue(int record, int field) const
{
return {getCharPtrValue(record, field)};
}
inline const uint8_t* getBlobValue(int record, int field) const
{
return reinterpret_cast<const uint8_t*>(getPqValue(m_result, record, field));
}
inline bool getBoolValue(int record, int field) const
{
checkIndex(record, field);
auto val = getPqValue(m_result, record, field);
if (*val == 't')
return true;
else if (*val == 'f')
return false;
return const_cast<const char*>(val);
}
const std::string& query() const
{
return m_query;
@ -109,36 +159,6 @@ namespace sqlpp
std::string m_query;
};
template <>
inline const char* Result::getValue<const char*>(int record, int field) const
{
return const_cast<const char*>(getPqValue(m_result, record, field));
}
template <>
inline std::string Result::getValue<std::string>(int record, int field) const
{
return {getValue<const char*>(record, field)};
}
template <>
inline bool Result::getValue<bool>(int record, int field) const
{
checkIndex(record, field);
auto val = getPqValue(m_result, record, field);
if (*val == 't')
return true;
else if (*val == 'f')
return false;
return const_cast<const char*>(val);
}
template <>
inline const uint8_t* Result::getValue<const uint8_t*>(int record, int field) const
{
return reinterpret_cast<const uint8_t*>(getValue<const char*>(record, field));
}
inline Result::Result() : m_result(nullptr)
{

View File

@ -75,7 +75,7 @@ namespace sqlpp
case 'F':
return c + 10 - 'A';
}
throw sqlpp::exception(std::string("Unexpected hex char: ") += c);
throw sqlpp::exception(std::string("Unexpected hex char: ") + static_cast<char>(c));
}
inline void hex_assign(std::vector<unsigned char>& value, const uint8_t* blob, size_t len)
@ -85,7 +85,7 @@ namespace sqlpp
size_t blob_index = 2;
while (blob_index < len)
{
value[val_index] = (unhex(blob[blob_index]) << 4) + unhex(blob[blob_index + 1]);
value[val_index] = static_cast<unsigned char>(unhex(blob[blob_index]) << 4) + unhex(blob[blob_index + 1]);
++val_index;
blob_index += 2;
}

View File

@ -33,8 +33,6 @@
#include "TabFoo.h"
#include "make_test_connection.h"
SQLPP_ALIAS_PROVIDER(left);
namespace sql = sqlpp::postgresql;
model::TabFoo tab = {};
@ -126,13 +124,14 @@ int Select(int, char*[])
// remove
db(remove_from(tab).where(tab.alpha == tab.alpha + 3));
auto result = db(select(all_of(tab)).from(tab).unconditionally());
std::cerr << "Accessing a field directly from the result (using the current row): " << result.begin()->alpha
auto result1 = db(select(all_of(tab)).from(tab).unconditionally());
std::cerr << "Accessing a field directly from the result (using the current row): " << result1.begin()->alpha
<< std::endl;
std::cerr << "Can do that again, no problem: " << result.begin()->alpha << std::endl;
std::cerr << "Can do that again, no problem: " << result1.begin()->alpha << std::endl;
auto tx = start_transaction(db);
if (const auto& row = *db(select(all_of(tab), select(max(tab.alpha)).from(tab)).from(tab).unconditionally()).begin())
auto result2 = db(select(all_of(tab), select(max(tab.alpha)).from(tab)).from(tab).unconditionally());
if (const auto& row = *result2.begin())
{
auto a = row.alpha;
auto m = row.max;