diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 59fda7ed..d49198c5 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1,17 +1,26 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be +// INFO|TODO - Note that is file is Windows specific right now. Making it arch +// agnostic will be // taken up in beta release #include "ddbc_bindings.h" -#include "connection/connection.h" -#include "connection/connection_pool.h" +#include #include +#include +#include // NOLINT(build/c++17) #include // std::setw, std::setfill #include +#include +#include +#include +#include #include // std::forward -#include +#include + +#include "connection/connection.h" +#include "connection/connection_pool.h" //------------------------------------------------------------------------------------------------- // Macro definitions //------------------------------------------------------------------------------------------------- @@ -35,7 +44,8 @@ // Architecture-specific defines #ifndef ARCHITECTURE -#define ARCHITECTURE "win64" // Default to win64 if not defined during compilation +#define ARCHITECTURE \ + "win64" // Default to win64 if not defined during compilation #endif #define DAE_CHUNK_SIZE 8192 #define SQL_MAX_LOB_SIZE 8000 @@ -53,7 +63,7 @@ struct ParamInfo { SQLSMALLINT decimalDigits; SQLLEN strLenOrInd = 0; // Required for DAE bool isDAE = false; // Indicates if we need to stream - py::object dataPtr; + py::object dataPtr; }; // Mirrors the SQL_NUMERIC_STRUCT. 
But redefined to replace val char array @@ -62,17 +72,24 @@ struct ParamInfo { struct NumericData { SQLCHAR precision; SQLSCHAR scale; - SQLCHAR sign; // 1=pos, 0=neg - std::string val; // 123.45 -> 12345 - - NumericData() : precision(0), scale(0), sign(0), val(SQL_MAX_NUMERIC_LEN, '\0') {} - - NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, const std::string& valueBytes) - : precision(precision), scale(scale), sign(sign), val(SQL_MAX_NUMERIC_LEN, '\0') { + SQLCHAR sign; // 1=pos, 0=neg + std::string val; // 123.45 -> 12345 + + NumericData() + : precision(0), scale(0), sign(0), val(SQL_MAX_NUMERIC_LEN, '\0') {} + + NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, + const std::string& valueBytes) + : precision(precision), + scale(scale), + sign(sign), + val(SQL_MAX_NUMERIC_LEN, '\0') { if (valueBytes.size() > SQL_MAX_NUMERIC_LEN) { - throw std::runtime_error("NumericData valueBytes size exceeds SQL_MAX_NUMERIC_LEN (16)"); + throw std::runtime_error( + "NumericData valueBytes size exceeds SQL_MAX_NUMERIC_LEN (16)"); } - // Secure copy: bounds already validated, but using std::copy_n for safety + // Secure copy: bounds already validated, but using std::copy_n for + // safety if (valueBytes.size() > 0) { std::copy_n(valueBytes.data(), valueBytes.size(), &val[0]); } @@ -80,17 +97,16 @@ struct NumericData { }; // Struct to hold the DateTimeOffset structure -struct DateTimeOffset -{ - SQLSMALLINT year; - SQLUSMALLINT month; - SQLUSMALLINT day; - SQLUSMALLINT hour; - SQLUSMALLINT minute; - SQLUSMALLINT second; - SQLUINTEGER fraction; // Nanoseconds - SQLSMALLINT timezone_hour; // Offset hours from UTC - SQLSMALLINT timezone_minute; // Offset minutes from UTC +struct DateTimeOffset { + SQLSMALLINT year; + SQLUSMALLINT month; + SQLUSMALLINT day; + SQLUSMALLINT hour; + SQLUSMALLINT minute; + SQLUSMALLINT second; + SQLUINTEGER fraction; // Nanoseconds + SQLSMALLINT timezone_hour; // Offset hours from UTC + SQLSMALLINT timezone_minute; // Offset minutes 
from UTC }; // Struct to hold data buffers and indicators for each column @@ -182,65 +198,62 @@ SQLTablesFunc SQLTables_ptr = nullptr; SQLDescribeParamFunc SQLDescribeParam_ptr = nullptr; - -// Encoding String -static py::bytes EncodingString(const std::string& text, - const std::string& encoding, +// Encoding function with fallback strategy +static py::bytes EncodingString(const std::string& text, + const std::string& encoding, const std::string& errors = "strict") { try { py::gil_scoped_acquire gil; py::str unicode_str = py::str(text); - + // Direct encoding - let Python handle errors strictly py::bytes encoded = unicode_str.attr("encode")(encoding, errors); return encoded; - } catch (const py::error_already_set& e) { // Re-raise Python exceptions (UnicodeEncodeError, etc.) throw std::runtime_error("Encoding failed: " + std::string(e.what())); } } -// Decoding String static py::str DecodingString(const char* data, size_t length, - const std::string& encoding, + const std::string& encoding, const std::string& errors = "strict") { try { py::gil_scoped_acquire gil; py::bytes byte_data = py::bytes(data, length); - + // Direct decoding - let Python handle errors strictly py::str decoded = byte_data.attr("decode")(encoding, errors); return decoded; - } catch (const py::error_already_set& e) { // Re-raise Python exceptions (UnicodeDecodeError, etc.) 
throw std::runtime_error("Decoding failed: " + std::string(e.what())); } } -// Helper function to validate that an encoding string is a legitimate Python codec -// This prevents injection attacks while allowing all valid encodings +// Helper function to validate that an encoding string is a legitimate Python +// codec This prevents injection attacks while allowing all valid encodings static bool is_valid_encoding(const std::string& enc) { if (enc.empty() || enc.length() > 100) { // Reasonable length limit return false; } - - // Check for potentially dangerous characters that shouldn't be in codec names + + // Check for potentially dangerous characters that shouldn't be in codec + // names for (char c : enc) { if (!std::isalnum(c) && c != '-' && c != '_' && c != '.') { return false; // Reject suspicious characters } } - + // Verify it's a valid Python codec by attempting a test lookup try { py::gil_scoped_acquire gil; py::module_ codecs = py::module_::import("codecs"); - + // This will raise LookupError if the codec doesn't exist codecs.attr("lookup")(enc); - + return true; // Codec exists and is valid } catch (const py::error_already_set& e) { // Expected: LookupError for invalid codec names @@ -260,60 +273,68 @@ static bool is_valid_encoding(const std::string& enc) { // Helper function to validate error handling mode against an allowlist static bool is_valid_error_mode(const std::string& mode) { static const std::unordered_set allowed = { - "strict", - "ignore", - "replace", - "xmlcharrefreplace", - "backslashreplace" - }; + "strict", "ignore", "replace", "xmlcharrefreplace", "backslashreplace"}; return allowed.find(mode) != allowed.end(); } // Helper function to safely extract encoding settings from Python dict -static std::pair extract_encoding_settings(const py::dict& settings) { +static std::pair extract_encoding_settings( + const py::dict& settings) { try { std::string encoding = "utf-8"; // Default std::string errors = "strict"; // Default - + if 
(settings.contains("encoding") && !settings["encoding"].is_none()) { - std::string proposed_encoding = settings["encoding"].cast(); - + std::string proposed_encoding = + settings["encoding"].cast(); + // SECURITY: Validate encoding to prevent injection attacks - // Allows any valid Python codec (including SQL Server-supported encodings) + // Allows any valid Python codec (including SQL Server-supported + // encodings) if (is_valid_encoding(proposed_encoding)) { encoding = proposed_encoding; } else { - LOG("Invalid or unsafe encoding '{}' rejected, using default 'utf-8'", proposed_encoding); + LOG("Invalid or unsafe encoding '{}' rejected, using default " + "'utf-8'", + proposed_encoding); // Fall back to safe default encoding = "utf-8"; } } - + if (settings.contains("errors") && !settings["errors"].is_none()) { - std::string proposed_errors = settings["errors"].cast(); - + std::string proposed_errors = + settings["errors"].cast(); + // SECURITY: Validate error mode against allowlist if (is_valid_error_mode(proposed_errors)) { errors = proposed_errors; } else { - LOG("Invalid error mode '{}' rejected, using default 'strict'", proposed_errors); + LOG("Invalid error mode '{}' rejected, using default 'strict'", + proposed_errors); // Fall back to safe default errors = "strict"; } } - + return std::make_pair(encoding, errors); } catch (const py::error_already_set& e) { // Log Python exceptions (KeyError, TypeError, etc.) - LOG("Python exception while extracting encoding settings: {}. Using defaults (utf-8, strict)", e.what()); + LOG("Python exception while extracting encoding settings: {}. Using " + "defaults (utf-8, " + "strict)", + e.what()); return std::make_pair("utf-8", "strict"); } catch (const std::exception& e) { // Log C++ standard exceptions - LOG("Exception while extracting encoding settings: {}. Using defaults (utf-8, strict)", e.what()); + LOG("Exception while extracting encoding settings: {}. 
Using defaults " + "(utf-8, strict)", + e.what()); return std::make_pair("utf-8", "strict"); } catch (...) { // Last resort: unknown exception type - LOG("Unknown exception while extracting encoding settings. Using defaults (utf-8, strict)"); + LOG("Unknown exception while extracting encoding settings. Using " + "defaults (utf-8, strict)"); return std::make_pair("utf-8", "strict"); } } @@ -350,28 +371,33 @@ const char* GetSqlCTypeAsString(const SQLSMALLINT cType) { } } -std::string MakeParamMismatchErrorStr(const SQLSMALLINT cType, const int paramIndex) { +std::string MakeParamMismatchErrorStr(const SQLSMALLINT cType, + const int paramIndex) { std::string errorString = - "Parameter's object type does not match parameter's C type. paramIndex - " + + "Parameter's object type does not match parameter's C type. paramIndex " + "- " + std::to_string(paramIndex) + ", C type - " + GetSqlCTypeAsString(cType); return errorString; } -// This function allocates a buffer of ParamType, stores it as a void* in paramBuffers for -// book-keeping and then returns a ParamType* to the allocated memory. -// ctorArgs are the arguments to ParamType's constructor used while creating/allocating ParamType +// This function allocates a buffer of ParamType, stores it as a void* in +// paramBuffers for book-keeping and then returns a ParamType* to the allocated +// memory. ctorArgs are the arguments to ParamType's constructor used while +// creating/allocating ParamType template ParamType* AllocateParamBuffer(std::vector>& paramBuffers, CtorArgs&&... 
ctorArgs) { - paramBuffers.emplace_back(new ParamType(std::forward(ctorArgs)...), - std::default_delete()); + paramBuffers.emplace_back( + new ParamType(std::forward(ctorArgs)...), + std::default_delete()); return static_cast(paramBuffers.back().get()); } template -ParamType* AllocateParamBufferArray(std::vector>& paramBuffers, - size_t count) { - std::shared_ptr buffer(new ParamType[count], std::default_delete()); +ParamType* AllocateParamBufferArray( + std::vector>& paramBuffers, size_t count) { + std::shared_ptr buffer(new ParamType[count], + std::default_delete()); ParamType* raw = buffer.get(); paramBuffers.push_back(buffer); return raw; @@ -387,8 +413,8 @@ std::string DescribeChar(unsigned char ch) { } } -// Given a list of parameters and their ParamInfo, calls SQLBindParameter on each of them with -// appropriate arguments +// Given a list of parameters and their ParamInfo, calls SQLBindParameter on +// each of them with appropriate arguments SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, std::vector& paramInfos, std::vector>& paramBuffers, @@ -397,7 +423,8 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, for (int paramIndex = 0; paramIndex < params.size(); paramIndex++) { const auto& param = params[paramIndex]; ParamInfo& paramInfo = paramInfos[paramIndex]; - LOG("Binding parameter {} - C Type: {}, SQL Type: {}", paramIndex, paramInfo.paramCType, paramInfo.paramSQLType); + LOG("Binding parameter {} - C Type: {}, SQL Type: {}", paramIndex, + paramInfo.paramCType, paramInfo.paramSQLType); void* dataPtr = nullptr; SQLLEN bufferLength = 0; SQLLEN* strLenOrIndPtr = nullptr; @@ -406,31 +433,41 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, switch (paramInfo.paramCType) { case SQL_C_CHAR: { if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } - + std::string 
strValue; - - // Check if we have encoding settings and this is SQL_C_CHAR (not SQL_C_WCHAR) + + // Check if we have encoding settings and this is SQL_C_CHAR + // (not SQL_C_WCHAR) if (encoding_settings && !encoding_settings.is_none()) { try { - // SECURITY: Use extract_encoding_settings for full validation - // This validates encoding against allowlist and error mode - py::dict settings_dict = encoding_settings.cast(); - auto [encoding, errors] = extract_encoding_settings(settings_dict); - + // SECURITY: Use extract_encoding_settings for full + // validation This validates encoding against allowlist + // and error mode + py::dict settings_dict = + encoding_settings.cast(); + auto [encoding, errors] = + extract_encoding_settings(settings_dict); + // Validate ctype against allowlist if (settings_dict.contains("ctype")) { - SQLSMALLINT ctype = settings_dict["ctype"].cast(); - + SQLSMALLINT ctype = + settings_dict["ctype"].cast(); + // Only SQL_C_CHAR and SQL_C_WCHAR are allowed if (ctype != SQL_C_CHAR && ctype != SQL_C_WCHAR) { - LOG("Invalid ctype {} for parameter {}, using default", ctype, paramIndex); + LOG("Invalid ctype {} for parameter {}, using " + "default", + ctype, paramIndex); // Fall through to default behavior strValue = param.cast(); } else if (ctype == SQL_C_CHAR) { // Only use dynamic encoding for SQL_C_CHAR - py::bytes encoded_bytes = EncodingString(param.cast(), encoding, errors); + py::bytes encoded_bytes = + EncodingString(param.cast(), + encoding, errors); strValue = encoded_bytes.cast(); } else { // SQL_C_WCHAR - use default behavior @@ -441,7 +478,10 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, strValue = param.cast(); } } catch (const std::exception& e) { - LOG("Encoding settings processing failed for parameter {}: {}. Using default.", paramIndex, e.what()); + LOG("Encoding settings processing failed for parameter " + "{}: {}. 
Using " + "default.", + paramIndex, e.what()); // Fall back to safe default behavior strValue = param.cast(); } @@ -451,49 +491,70 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, } // Allocate buffer and copy string data - size_t bufferSize = strValue.length() + 1; // +1 for null terminator - char* buffer = AllocateParamBufferArray(paramBuffers, bufferSize); - + size_t bufferSize = + strValue.length() + 1; // +1 for null terminator + char* buffer = + AllocateParamBufferArray(paramBuffers, bufferSize); + if (!buffer) { - ThrowStdException("Failed to allocate buffer for SQL_C_CHAR parameter at index " + std::to_string(paramIndex)); + ThrowStdException( + "Failed to allocate buffer for SQL_C_CHAR parameter at " + "index " + + std::to_string(paramIndex)); } - - // SECURITY: Validate size before copying to prevent buffer overflow + + // SECURITY: Validate size before copying to prevent buffer + // overflow size_t copyLength = strValue.length(); if (copyLength >= bufferSize) { - ThrowStdException("Buffer overflow prevented: string length exceeds allocated buffer at index " + std::to_string(paramIndex)); + ThrowStdException( + "Buffer overflow prevented: string length exceeds " + "allocated buffer at " + "index " + + std::to_string(paramIndex)); } - - // Use secure copy with bounds checking - #ifdef _WIN32 - // Windows: Use memcpy_s for secure copy - errno_t err = memcpy_s(buffer, bufferSize, strValue.data(), copyLength); - if (err != 0) { - ThrowStdException("Secure memory copy failed with error code " + std::to_string(err) + " at index " + std::to_string(paramIndex)); - } - #else - // POSIX: Use std::copy_n with explicit bounds checking - if (copyLength > 0) { - std::copy_n(strValue.data(), copyLength, buffer); - } - #endif - + +// Use secure copy with bounds checking +#ifdef _WIN32 + // Windows: Use memcpy_s for secure copy + errno_t err = + memcpy_s(buffer, bufferSize, strValue.data(), copyLength); + if (err != 0) { + ThrowStdException( + "Secure 
memory copy failed with error code " + + std::to_string(err) + " at index " + + std::to_string(paramIndex)); + } +#else + // POSIX: Use std::copy_n with explicit bounds checking + if (copyLength > 0) { + std::copy_n(strValue.data(), copyLength, buffer); + } +#endif + buffer[copyLength] = '\0'; // Ensure null termination - + paramInfo.strLenOrInd = copyLength; - - LOG("Binding SQL_C_CHAR parameter at index {} with encoded length {}", paramIndex, strValue.length()); + + LOG("Binding SQL_C_CHAR parameter at index {} with encoded " + "length {}", + paramIndex, strValue.length()); break; } case SQL_C_BINARY: { - if (!py::isinstance(param) && !py::isinstance(param) && + if (!py::isinstance(param) && + !py::isinstance(param) && !py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } if (paramInfo.isDAE) { // Deferred execution for VARBINARY(MAX) - LOG("Parameter[{}] is marked for DAE streaming (VARBINARY(MAX))", paramIndex); - dataPtr = const_cast(reinterpret_cast(¶mInfos[paramIndex])); + LOG("Parameter[{}] is marked for DAE streaming " + "(VARBINARY(MAX))", + paramIndex); + dataPtr = const_cast( + reinterpret_cast(¶mInfos[paramIndex])); strLenOrIndPtr = AllocateParamBuffer(paramBuffers); *strLenOrIndPtr = SQL_LEN_DATA_AT_EXEC(0); bufferLength = 0; @@ -504,11 +565,15 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, binData = param.cast(); } else { // bytearray - binData = std::string(reinterpret_cast(PyByteArray_AsString(param.ptr())), - PyByteArray_Size(param.ptr())); + binData = + std::string(reinterpret_cast( + PyByteArray_AsString(param.ptr())), + PyByteArray_Size(param.ptr())); } - std::string* binBuffer = AllocateParamBuffer(paramBuffers, binData); - dataPtr = const_cast(static_cast(binBuffer->data())); + std::string* binBuffer = + AllocateParamBuffer(paramBuffers, binData); + dataPtr = const_cast( + 
static_cast(binBuffer->data())); bufferLength = static_cast(binBuffer->size()); strLenOrIndPtr = AllocateParamBuffer(paramBuffers); *strLenOrIndPtr = bufferLength; @@ -516,75 +581,80 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_WCHAR: { - if (!py::isinstance(param) && !py::isinstance(param) && + if (!py::isinstance(param) && + !py::isinstance(param) && !py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } if (paramInfo.isDAE) { // deferred execution - LOG("Parameter[{}] is marked for DAE streaming", paramIndex); - dataPtr = const_cast(reinterpret_cast(¶mInfos[paramIndex])); + LOG("Parameter[{}] is marked for DAE streaming", + paramIndex); + dataPtr = const_cast( + reinterpret_cast(¶mInfos[paramIndex])); strLenOrIndPtr = AllocateParamBuffer(paramBuffers); *strLenOrIndPtr = SQL_LEN_DATA_AT_EXEC(0); bufferLength = 0; } else { // Normal small-string case - std::wstring* strParam = - AllocateParamBuffer(paramBuffers, param.cast()); - LOG("SQL_C_WCHAR Parameter[{}]: Length={}, isDAE={}", paramIndex, strParam->size(), paramInfo.isDAE); + std::wstring* strParam = AllocateParamBuffer( + paramBuffers, param.cast()); + LOG("SQL_C_WCHAR Parameter[{}]: Length={}, isDAE={}", + paramIndex, strParam->size(), paramInfo.isDAE); std::vector* sqlwcharBuffer = - AllocateParamBuffer>(paramBuffers, WStringToSQLWCHAR(*strParam)); + AllocateParamBuffer>( + paramBuffers, WStringToSQLWCHAR(*strParam)); dataPtr = sqlwcharBuffer->data(); bufferLength = sqlwcharBuffer->size() * sizeof(SQLWCHAR); strLenOrIndPtr = AllocateParamBuffer(paramBuffers); *strLenOrIndPtr = SQL_NTS; - } break; } case SQL_C_BIT: { if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } - dataPtr = - 
static_cast(AllocateParamBuffer(paramBuffers, param.cast())); + dataPtr = static_cast(AllocateParamBuffer( + paramBuffers, param.cast())); break; } case SQL_C_DEFAULT: { if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } - SQLSMALLINT sqlType = paramInfo.paramSQLType; - SQLULEN columnSize = paramInfo.columnSize; + SQLSMALLINT sqlType = paramInfo.paramSQLType; + SQLULEN columnSize = paramInfo.columnSize; SQLSMALLINT decimalDigits = paramInfo.decimalDigits; if (sqlType == SQL_UNKNOWN_TYPE) { SQLSMALLINT describedType; - SQLULEN describedSize; + SQLULEN describedSize; SQLSMALLINT describedDigits; SQLSMALLINT nullable; RETCODE rc = SQLDescribeParam_ptr( - hStmt, - static_cast(paramIndex + 1), - &describedType, - &describedSize, - &describedDigits, - &nullable - ); + hStmt, static_cast(paramIndex + 1), + &describedType, &describedSize, &describedDigits, + &nullable); if (!SQL_SUCCEEDED(rc)) { - LOG("SQLDescribeParam failed for parameter {} with error code {}", paramIndex, rc); + LOG("SQLDescribeParam failed for parameter {} with " + "error code {}", + paramIndex, rc); return rc; } - sqlType = describedType; - columnSize = describedSize; + sqlType = describedType; + columnSize = describedSize; decimalDigits = describedDigits; } dataPtr = nullptr; strLenOrIndPtr = AllocateParamBuffer(paramBuffers); *strLenOrIndPtr = SQL_NULL_DATA; bufferLength = 0; - paramInfo.paramSQLType = sqlType; - paramInfo.columnSize = columnSize; - paramInfo.decimalDigits = decimalDigits; + paramInfo.paramSQLType = sqlType; + paramInfo.columnSize = columnSize; + paramInfo.decimalDigits = decimalDigits; break; } case SQL_C_STINYINT: @@ -592,143 +662,202 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, case SQL_C_SSHORT: case SQL_C_SHORT: { if (!py::isinstance(param)) { - 
ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } int value = param.cast(); // Range validation for signed 16-bit integer - if (value < std::numeric_limits::min() || value > std::numeric_limits::max()) { - ThrowStdException("Signed short integer parameter out of range at paramIndex " + std::to_string(paramIndex)); + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) { + ThrowStdException( + "Signed short integer parameter out of range at " + "paramIndex " + + std::to_string(paramIndex)); } - dataPtr = - static_cast(AllocateParamBuffer(paramBuffers, param.cast())); + dataPtr = static_cast( + AllocateParamBuffer(paramBuffers, param.cast())); break; } case SQL_C_UTINYINT: case SQL_C_USHORT: { if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } unsigned int value = param.cast(); - if (value > std::numeric_limits::max()) { - ThrowStdException("Unsigned short integer parameter out of range at paramIndex " + std::to_string(paramIndex)); + if (value > std::numeric_limits::max()) { + ThrowStdException( + "Unsigned short integer parameter out of range at " + "paramIndex " + + std::to_string(paramIndex)); } - dataPtr = static_cast( - AllocateParamBuffer(paramBuffers, param.cast())); + dataPtr = static_cast(AllocateParamBuffer( + paramBuffers, param.cast())); break; } case SQL_C_SBIGINT: case SQL_C_SLONG: case SQL_C_LONG: { if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } int64_t value = param.cast(); // Range validation for signed 64-bit integer - if (value < std::numeric_limits::min() || value > std::numeric_limits::max()) { - 
ThrowStdException("Signed 64-bit integer parameter out of range at paramIndex " + std::to_string(paramIndex)); + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) { + ThrowStdException( + "Signed 64-bit integer parameter out of range at " + "paramIndex " + + std::to_string(paramIndex)); } - dataPtr = static_cast( - AllocateParamBuffer(paramBuffers, param.cast())); + dataPtr = static_cast(AllocateParamBuffer( + paramBuffers, param.cast())); break; } case SQL_C_UBIGINT: case SQL_C_ULONG: { if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } uint64_t value = param.cast(); // Range validation for unsigned 64-bit integer if (value > std::numeric_limits::max()) { - ThrowStdException("Unsigned 64-bit integer parameter out of range at paramIndex " + std::to_string(paramIndex)); + ThrowStdException( + "Unsigned 64-bit integer parameter out of range at " + "paramIndex " + + std::to_string(paramIndex)); } - dataPtr = static_cast( - AllocateParamBuffer(paramBuffers, param.cast())); + dataPtr = static_cast(AllocateParamBuffer( + paramBuffers, param.cast())); break; } case SQL_C_FLOAT: { if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } - dataPtr = static_cast( - AllocateParamBuffer(paramBuffers, param.cast())); + dataPtr = static_cast(AllocateParamBuffer( + paramBuffers, param.cast())); break; } case SQL_C_DOUBLE: { if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } - dataPtr = static_cast( - AllocateParamBuffer(paramBuffers, param.cast())); + dataPtr = static_cast(AllocateParamBuffer( + paramBuffers, 
param.cast())); break; } case SQL_C_TYPE_DATE: { - py::object dateType = py::module_::import("datetime").attr("date"); + py::object dateType = + py::module_::import("datetime").attr("date"); if (!py::isinstance(param, dateType)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } int year = param.attr("year").cast(); if (year < 1753 || year > 9999) { - ThrowStdException("Date out of range for SQL Server (1753-9999) at paramIndex " + std::to_string(paramIndex)); + ThrowStdException( + "Date out of range for SQL Server (1753-9999) at " + "paramIndex " + + std::to_string(paramIndex)); } - // TODO: can be moved to python by registering SQL_DATE_STRUCT in pybind - SQL_DATE_STRUCT* sqlDatePtr = AllocateParamBuffer(paramBuffers); - sqlDatePtr->year = static_cast(param.attr("year").cast()); - sqlDatePtr->month = static_cast(param.attr("month").cast()); - sqlDatePtr->day = static_cast(param.attr("day").cast()); + // TODO: can be moved to python by registering SQL_DATE_STRUCT + // in pybind + SQL_DATE_STRUCT* sqlDatePtr = + AllocateParamBuffer(paramBuffers); + sqlDatePtr->year = + static_cast(param.attr("year").cast()); + sqlDatePtr->month = + static_cast(param.attr("month").cast()); + sqlDatePtr->day = + static_cast(param.attr("day").cast()); dataPtr = static_cast(sqlDatePtr); break; } case SQL_C_TYPE_TIME: { - py::object timeType = py::module_::import("datetime").attr("time"); + py::object timeType = + py::module_::import("datetime").attr("time"); if (!py::isinstance(param, timeType)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } - // TODO: can be moved to python by registering SQL_TIME_STRUCT in pybind - SQL_TIME_STRUCT* sqlTimePtr = AllocateParamBuffer(paramBuffers); - sqlTimePtr->hour = static_cast(param.attr("hour").cast()); - 
sqlTimePtr->minute = static_cast(param.attr("minute").cast()); - sqlTimePtr->second = static_cast(param.attr("second").cast()); + // TODO: can be moved to python by registering SQL_TIME_STRUCT + // in pybind + SQL_TIME_STRUCT* sqlTimePtr = + AllocateParamBuffer(paramBuffers); + sqlTimePtr->hour = + static_cast(param.attr("hour").cast()); + sqlTimePtr->minute = + static_cast(param.attr("minute").cast()); + sqlTimePtr->second = + static_cast(param.attr("second").cast()); dataPtr = static_cast(sqlTimePtr); break; } case SQL_C_SS_TIMESTAMPOFFSET: { - py::object datetimeType = py::module_::import("datetime").attr("datetime"); + py::object datetimeType = + py::module_::import("datetime").attr("datetime"); if (!py::isinstance(param, datetimeType)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } // Checking if the object has a timezone py::object tzinfo = param.attr("tzinfo"); if (tzinfo.is_none()) { - ThrowStdException("Datetime object must have tzinfo for SQL_C_SS_TIMESTAMPOFFSET at paramIndex " + std::to_string(paramIndex)); + ThrowStdException( + "Datetime object must have tzinfo for " + "SQL_C_SS_TIMESTAMPOFFSET at paramIndex " + + std::to_string(paramIndex)); } - DateTimeOffset* dtoPtr = AllocateParamBuffer(paramBuffers); - - dtoPtr->year = static_cast(param.attr("year").cast()); - dtoPtr->month = static_cast(param.attr("month").cast()); - dtoPtr->day = static_cast(param.attr("day").cast()); - dtoPtr->hour = static_cast(param.attr("hour").cast()); - dtoPtr->minute = static_cast(param.attr("minute").cast()); - dtoPtr->second = static_cast(param.attr("second").cast()); + DateTimeOffset* dtoPtr = + AllocateParamBuffer(paramBuffers); + + dtoPtr->year = + static_cast(param.attr("year").cast()); + dtoPtr->month = + static_cast(param.attr("month").cast()); + dtoPtr->day = + static_cast(param.attr("day").cast()); + dtoPtr->hour = + 
static_cast(param.attr("hour").cast()); + dtoPtr->minute = + static_cast(param.attr("minute").cast()); + dtoPtr->second = + static_cast(param.attr("second").cast()); // SQL server supports in ns, but python datetime supports in µs - dtoPtr->fraction = static_cast(param.attr("microsecond").cast() * 1000); + dtoPtr->fraction = static_cast( + param.attr("microsecond").cast() * 1000); py::object utcoffset = tzinfo.attr("utcoffset")(param); if (utcoffset.is_none()) { - ThrowStdException("Datetime object's tzinfo.utcoffset() returned None at paramIndex " + std::to_string(paramIndex)); + ThrowStdException( + "Datetime object's tzinfo.utcoffset() returned None at " + "paramIndex " + + std::to_string(paramIndex)); } - int total_seconds = static_cast(utcoffset.attr("total_seconds")().cast()); + int total_seconds = static_cast( + utcoffset.attr("total_seconds")().cast()); const int MAX_OFFSET = 14 * 3600; const int MIN_OFFSET = -14 * 3600; if (total_seconds > MAX_OFFSET || total_seconds < MIN_OFFSET) { - ThrowStdException("Datetimeoffset tz offset out of SQL Server range (-14h to +14h) at paramIndex " + std::to_string(paramIndex)); + ThrowStdException( + "Datetimeoffset tz offset out of SQL Server range " + "(-14h to +14h) at paramIndex " + + std::to_string(paramIndex)); } std::div_t div_result = std::div(total_seconds, 3600); - dtoPtr->timezone_hour = static_cast(div_result.quot); - dtoPtr->timezone_minute = static_cast(div(div_result.rem, 60).quot); - + dtoPtr->timezone_hour = + static_cast(div_result.quot); + dtoPtr->timezone_minute = + static_cast(div(div_result.rem, 60).quot); + dataPtr = static_cast(dtoPtr); bufferLength = sizeof(DateTimeOffset); strLenOrIndPtr = AllocateParamBuffer(paramBuffers); @@ -736,62 +865,84 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_TYPE_TIMESTAMP: { - py::object datetimeType = py::module_::import("datetime").attr("datetime"); + py::object datetimeType = + 
py::module_::import("datetime").attr("datetime"); if (!py::isinstance(param, datetimeType)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } SQL_TIMESTAMP_STRUCT* sqlTimestampPtr = AllocateParamBuffer(paramBuffers); - sqlTimestampPtr->year = static_cast(param.attr("year").cast()); - sqlTimestampPtr->month = static_cast(param.attr("month").cast()); - sqlTimestampPtr->day = static_cast(param.attr("day").cast()); - sqlTimestampPtr->hour = static_cast(param.attr("hour").cast()); - sqlTimestampPtr->minute = static_cast(param.attr("minute").cast()); - sqlTimestampPtr->second = static_cast(param.attr("second").cast()); + sqlTimestampPtr->year = + static_cast(param.attr("year").cast()); + sqlTimestampPtr->month = + static_cast(param.attr("month").cast()); + sqlTimestampPtr->day = + static_cast(param.attr("day").cast()); + sqlTimestampPtr->hour = + static_cast(param.attr("hour").cast()); + sqlTimestampPtr->minute = + static_cast(param.attr("minute").cast()); + sqlTimestampPtr->second = + static_cast(param.attr("second").cast()); // SQL server supports in ns, but python datetime supports in µs sqlTimestampPtr->fraction = static_cast( - param.attr("microsecond").cast() * 1000); // Convert µs to ns + param.attr("microsecond").cast() * + 1000); // Convert µs to ns dataPtr = static_cast(sqlTimestampPtr); break; } case SQL_C_NUMERIC: { if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } NumericData decimalParam = param.cast(); - LOG("Received numeric parameter: precision - {}, scale- {}, sign - {}, value - {}", - decimalParam.precision, decimalParam.scale, decimalParam.sign, - decimalParam.val); + LOG("Received numeric parameter: precision - {}, scale- {}, " + "sign - {}, value - {}", + decimalParam.precision, 
decimalParam.scale, + decimalParam.sign, decimalParam.val); SQL_NUMERIC_STRUCT* decimalPtr = AllocateParamBuffer(paramBuffers); decimalPtr->precision = decimalParam.precision; decimalPtr->scale = decimalParam.scale; decimalPtr->sign = decimalParam.sign; // Convert the integer decimalParam.val to char array - std::memset(static_cast(decimalPtr->val), 0, sizeof(decimalPtr->val)); - size_t copyLen = std::min(decimalParam.val.size(), sizeof(decimalPtr->val)); + std::memset(static_cast(decimalPtr->val), 0, + sizeof(decimalPtr->val)); + size_t copyLen = + std::min(decimalParam.val.size(), sizeof(decimalPtr->val)); // Secure copy: bounds already validated with std::min if (copyLen > 0) { - std::copy_n(decimalParam.val.data(), copyLen, decimalPtr->val); + std::copy_n(decimalParam.val.data(), copyLen, + decimalPtr->val); } dataPtr = static_cast(decimalPtr); break; } case SQL_C_GUID: { if (!py::isinstance(param)) { - ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + paramInfo.paramCType, paramIndex)); } py::bytes uuid_bytes = param.cast(); - const unsigned char* uuid_data = reinterpret_cast(PyBytes_AS_STRING(uuid_bytes.ptr())); + const unsigned char* uuid_data = + reinterpret_cast( + PyBytes_AS_STRING(uuid_bytes.ptr())); if (PyBytes_GET_SIZE(uuid_bytes.ptr()) != 16) { - LOG("Invalid UUID parameter at index {}: expected 16 bytes, got {} bytes, type {}", paramIndex, PyBytes_GET_SIZE(uuid_bytes.ptr()), paramInfo.paramCType); - ThrowStdException("UUID binary data must be exactly 16 bytes long."); + LOG("Invalid UUID parameter at index {}: expected 16 " + "bytes, got {} bytes, type {}", + paramIndex, PyBytes_GET_SIZE(uuid_bytes.ptr()), + paramInfo.paramCType); + ThrowStdException( + "UUID binary data must be exactly 16 bytes long."); } - SQLGUID* guid_data_ptr = AllocateParamBuffer(paramBuffers); + SQLGUID* guid_data_ptr = + AllocateParamBuffer(paramBuffers); guid_data_ptr->Data1 = 
(static_cast(uuid_data[3]) << 24) | (static_cast(uuid_data[2]) << 16) | - (static_cast(uuid_data[1]) << 8) | + (static_cast(uuid_data[1]) << 8) | (static_cast(uuid_data[0])); guid_data_ptr->Data2 = (static_cast(uuid_data[5]) << 8) | @@ -809,55 +960,68 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, } default: { std::ostringstream errorString; - errorString << "Unsupported parameter type - " << paramInfo.paramCType - << " for parameter - " << paramIndex; + errorString << "Unsupported parameter type - " + << paramInfo.paramCType << " for parameter - " + << paramIndex; ThrowStdException(errorString.str()); } } - assert(SQLBindParameter_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr); + assert(SQLBindParameter_ptr && SQLGetStmtAttr_ptr && + SQLSetDescField_ptr); RETCODE rc = SQLBindParameter_ptr( hStmt, - static_cast(paramIndex + 1), /* 1-based indexing */ + static_cast(paramIndex + 1), /* 1-based indexing */ static_cast(paramInfo.inputOutputType), static_cast(paramInfo.paramCType), - static_cast(paramInfo.paramSQLType), paramInfo.columnSize, - paramInfo.decimalDigits, dataPtr, bufferLength, strLenOrIndPtr); + static_cast(paramInfo.paramSQLType), + paramInfo.columnSize, paramInfo.decimalDigits, dataPtr, + bufferLength, strLenOrIndPtr); if (!SQL_SUCCEEDED(rc)) { LOG("Error when binding parameter - {}", paramIndex); return rc; } - // Special handling for Numeric type - - // https://learn.microsoft.com/en-us/sql/odbc/reference/appendixes/retrieve-numeric-data-sql-numeric-struct-kb222831?view=sql-server-ver16#sql_c_numeric-overview + // Special handling for Numeric type - + // https://learn.microsoft.com/en-us/sql/odbc/reference/appendixes/retrieve-numeric-data-sql-numeric-struct-kb222831?view=sql-server-ver16#sql_c_numeric-overview if (paramInfo.paramCType == SQL_C_NUMERIC) { SQLHDESC hDesc = nullptr; - rc = SQLGetStmtAttr_ptr(hStmt, SQL_ATTR_APP_PARAM_DESC, &hDesc, 0, NULL); - if(!SQL_SUCCEEDED(rc)) { + rc = SQLGetStmtAttr_ptr(hStmt, 
SQL_ATTR_APP_PARAM_DESC, &hDesc, 0, + NULL); + if (!SQL_SUCCEEDED(rc)) { LOG("Error when getting statement attribute - {}", paramIndex); return rc; } - rc = SQLSetDescField_ptr(hDesc, 1, SQL_DESC_TYPE, (SQLPOINTER) SQL_C_NUMERIC, 0); - if(!SQL_SUCCEEDED(rc)) { - LOG("Error when setting descriptor field SQL_DESC_TYPE - {}", paramIndex); + rc = SQLSetDescField_ptr(hDesc, 1, SQL_DESC_TYPE, + (SQLPOINTER)SQL_C_NUMERIC, 0); + if (!SQL_SUCCEEDED(rc)) { + LOG("Error when setting descriptor field SQL_DESC_TYPE - {}", + paramIndex); return rc; } - SQL_NUMERIC_STRUCT* numericPtr = reinterpret_cast(dataPtr); + SQL_NUMERIC_STRUCT* numericPtr = + reinterpret_cast(dataPtr); rc = SQLSetDescField_ptr(hDesc, 1, SQL_DESC_PRECISION, - (SQLPOINTER) numericPtr->precision, 0); - if(!SQL_SUCCEEDED(rc)) { - LOG("Error when setting descriptor field SQL_DESC_PRECISION - {}", paramIndex); + (SQLPOINTER)numericPtr->precision, 0); + if (!SQL_SUCCEEDED(rc)) { + LOG("Error when setting descriptor field SQL_DESC_PRECISION - " + "{}", + paramIndex); return rc; } rc = SQLSetDescField_ptr(hDesc, 1, SQL_DESC_SCALE, - (SQLPOINTER) numericPtr->scale, 0); - if(!SQL_SUCCEEDED(rc)) { - LOG("Error when setting descriptor field SQL_DESC_SCALE - {}", paramIndex); + (SQLPOINTER)numericPtr->scale, 0); + if (!SQL_SUCCEEDED(rc)) { + LOG("Error when setting descriptor field SQL_DESC_SCALE - {}", + paramIndex); return rc; } - rc = SQLSetDescField_ptr(hDesc, 1, SQL_DESC_DATA_PTR, (SQLPOINTER) numericPtr, 0); - if(!SQL_SUCCEEDED(rc)) { - LOG("Error when setting descriptor field SQL_DESC_DATA_PTR - {}", paramIndex); + rc = SQLSetDescField_ptr(hDesc, 1, SQL_DESC_DATA_PTR, + (SQLPOINTER)numericPtr, 0); + if (!SQL_SUCCEEDED(rc)) { + LOG("Error when setting descriptor field SQL_DESC_DATA_PTR - " + "{}", + paramIndex); return rc; } } @@ -866,12 +1030,13 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, return SQL_SUCCESS; } -// This is temporary hack to avoid crash when SQLDescribeCol returns 0 as 
columnSize -// for NVARCHAR(MAX) & similar types. Variable length data needs more nuanced handling. +// This is temporary hack to avoid crash when SQLDescribeCol returns 0 as +// columnSize for NVARCHAR(MAX) & similar types. Variable length data needs more +// nuanced handling. // TODO: Fix this in beta -// This function sets the buffer allocated to fetch NVARCHAR(MAX) & similar types to -// 4096 chars. So we'll retrieve data upto 4096. Anything greater then that will throw -// error +// This function sets the buffer allocated to fetch NVARCHAR(MAX) & similar +// types to 4096 chars. So we'll retrieve data upto 4096. Anything greater then +// that will throw error void HandleZeroColumnSizeAtFetch(SQLULEN& columnSize) { if (columnSize == 0) { columnSize = 4096; @@ -885,23 +1050,26 @@ void HandleZeroColumnSizeAtFetch(SQLULEN& columnSize) { static bool is_python_finalizing() { try { if (Py_IsInitialized() == 0) { - return true; // Python is already shut down + return true; // Python is already shut down } - + py::gil_scoped_acquire gil; py::object sys_module = py::module_::import("sys"); if (!sys_module.is_none()) { - // Check if the attribute exists before accessing it (for Python version compatibility) + // Check if the attribute exists before accessing it (for Python + // version compatibility) if (py::hasattr(sys_module, "_is_finalizing")) { py::object finalizing_func = sys_module.attr("_is_finalizing"); - if (!finalizing_func.is_none() && finalizing_func().cast()) { - return true; // Python is finalizing + if (!finalizing_func.is_none() && + finalizing_func().cast()) { + return true; // Python is finalizing } } } return false; } catch (...) { - std::cerr << "Error occurred while checking Python finalization state." << std::endl; + std::cerr << "Error occurred while checking Python finalization state." 
+ << std::endl; // Be conservative - don't assume shutdown on any exception // Only return true if we're absolutely certain Python is shutting down return false; @@ -913,21 +1081,24 @@ template void LOG(const std::string& formatString, Args&&... args) { // Check if Python is shutting down to avoid crash during cleanup if (is_python_finalizing()) { - return; // Python is shutting down or finalizing, don't log + return; // Python is shutting down or finalizing, don't log } - + try { py::gil_scoped_acquire gil; // <---- this ensures safe Python API usage - py::object logger = py::module_::import("mssql_python.logging_config").attr("get_logger")(); + py::object logger = py::module_::import("mssql_python.logging_config") + .attr("get_logger")(); if (py::isinstance(logger)) return; try { - std::string ddbcFormatString = "[DDBC Bindings log] " + formatString; + std::string ddbcFormatString = + "[DDBC Bindings log] " + formatString; if constexpr (sizeof...(args) == 0) { logger.attr("debug")(py::str(ddbcFormatString)); } else { - py::str message = py::str(ddbcFormatString).format(std::forward(args)...); + py::str message = py::str(ddbcFormatString) + .format(std::forward(args)...); logger.attr("debug")(message); } } catch (const std::exception& e) { @@ -935,17 +1106,19 @@ void LOG(const std::string& formatString, Args&&... 
args) { } } catch (const py::error_already_set& e) { // Python is shutting down or in an inconsistent state, silently ignore - (void)e; // Suppress unused variable warning + (void)e; // Suppress unused variable warning return; } catch (const std::exception& e) { // Any other error, ignore to prevent crash during cleanup - (void)e; // Suppress unused variable warning + (void)e; // Suppress unused variable warning return; } } // TODO: Add more nuanced exception classes -void ThrowStdException(const std::string& message) { throw std::runtime_error(message); } +void ThrowStdException(const std::string& message) { + throw std::runtime_error(message); +} std::string GetLastErrorMessage(); // TODO: Move this to Python @@ -953,11 +1126,12 @@ std::string GetModuleDirectory() { py::object module = py::module::import("mssql_python"); py::object module_path = module.attr("__file__"); std::string module_file = module_path.cast(); - + #ifdef _WIN32 // Windows-specific path handling char path[MAX_PATH]; - errno_t err = strncpy_s(path, MAX_PATH, module_file.c_str(), module_file.length()); + errno_t err = + strncpy_s(path, MAX_PATH, module_file.c_str(), module_file.length()); if (err != 0) { LOG("strncpy_s failed with error code: {}", err); return {}; @@ -979,13 +1153,14 @@ std::string GetModuleDirectory() { // Platform-agnostic function to load the driver dynamic library DriverHandle LoadDriverLibrary(const std::string& driverPath) { LOG("Loading driver from path: {}", driverPath); - + #ifdef _WIN32 // Windows: Convert string to wide string for LoadLibraryW std::wstring widePath(driverPath.begin(), driverPath.end()); HMODULE handle = LoadLibraryW(widePath.c_str()); if (!handle) { - LOG("Failed to load library: {}. Error: {}", driverPath, GetLastErrorMessage()); + LOG("Failed to load library: {}. 
Error: {}", driverPath, + GetLastErrorMessage()); ThrowStdException("Failed to load library: " + driverPath); } return handle; @@ -1006,15 +1181,12 @@ std::string GetLastErrorMessage() { DWORD error = GetLastError(); char* messageBuffer = nullptr; size_t size = FormatMessageA( - FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, - error, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPSTR)&messageBuffer, - 0, - NULL - ); - std::string errorMessage = messageBuffer ? std::string(messageBuffer, size) : "Unknown error"; + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&messageBuffer, 0, NULL); + std::string errorMessage = + messageBuffer ? std::string(messageBuffer, size) : "Unknown error"; LocalFree(messageBuffer); return "Error code: " + std::to_string(error) + " - " + errorMessage; #else @@ -1024,20 +1196,20 @@ std::string GetLastErrorMessage() { #endif } - /* * Resolve ODBC driver path in C++ to avoid circular import issues on Alpine. * * Background: - * On Alpine Linux, calling into Python during module initialization (via pybind11) - * causes a circular import due to musl's stricter dynamic loader behavior. + * On Alpine Linux, calling into Python during module initialization (via + * pybind11) causes a circular import due to musl's stricter dynamic loader + * behavior. * - * Specifically, importing Python helpers from C++ triggered a re-import of the - * partially-initialized native module, which works on glibc (Ubuntu/macOS) but + * Specifically, importing Python helpers from C++ triggered a re-import of the + * partially-initialized native module, which works on glibc (Ubuntu/macOS) but * fails on musl-based systems like Alpine. 
* - * By moving driver path resolution entirely into C++, we avoid any Python-layer - * dependencies during critical initialization, ensuring compatibility across + * By moving driver path resolution entirely into C++, we avoid any Python-layer + * dependencies during critical initialization, ensuring compatibility across * all supported platforms. */ std::string GetDriverPathCpp(const std::string& moduleDir) { @@ -1047,45 +1219,51 @@ std::string GetDriverPathCpp(const std::string& moduleDir) { std::string platform; std::string arch; - // Detect architecture - #if defined(__aarch64__) || defined(_M_ARM64) - arch = "arm64"; - #elif defined(__x86_64__) || defined(_M_X64) || defined(_M_AMD64) - arch = "x86_64"; // maps to "x64" on Windows - #else - throw std::runtime_error("Unsupported architecture"); - #endif - - // Detect platform and set path - #ifdef __linux__ - if (fs::exists("/etc/alpine-release")) { - platform = "alpine"; - } else if (fs::exists("/etc/redhat-release") || fs::exists("/etc/centos-release")) { - platform = "rhel"; - } else if (fs::exists("/etc/SuSE-release") || fs::exists("/etc/SUSE-brand")) { - platform = "suse"; - } else { - platform = "debian_ubuntu"; // Default to debian_ubuntu for other distros - } +// Detect architecture +#if defined(__aarch64__) || defined(_M_ARM64) + arch = "arm64"; +#elif defined(__x86_64__) || defined(_M_X64) || defined(_M_AMD64) + arch = "x86_64"; // maps to "x64" on Windows +#else + throw std::runtime_error("Unsupported architecture"); +#endif + +// Detect platform and set path +#ifdef __linux__ + if (fs::exists("/etc/alpine-release")) { + platform = "alpine"; + } else if (fs::exists("/etc/redhat-release") || + fs::exists("/etc/centos-release")) { + platform = "rhel"; + } else if (fs::exists("/etc/SuSE-release") || + fs::exists("/etc/SUSE-brand")) { + platform = "suse"; + } else { + platform = + "debian_ubuntu"; // Default to debian_ubuntu for other distros + } - fs::path driverPath = basePath / "libs" / "linux" / 
platform / arch / "lib" / "libmsodbcsql-18.5.so.1.1"; - return driverPath.string(); + fs::path driverPath = basePath / "libs" / "linux" / platform / arch / + "lib" / "libmsodbcsql-18.5.so.1.1"; + return driverPath.string(); - #elif defined(__APPLE__) - platform = "macos"; - fs::path driverPath = basePath / "libs" / platform / arch / "lib" / "libmsodbcsql.18.dylib"; - return driverPath.string(); +#elif defined(__APPLE__) + platform = "macos"; + fs::path driverPath = + basePath / "libs" / platform / arch / "lib" / "libmsodbcsql.18.dylib"; + return driverPath.string(); - #elif defined(_WIN32) - platform = "windows"; - // Normalize x86_64 to x64 for Windows naming - if (arch == "x86_64") arch = "x64"; - fs::path driverPath = basePath / "libs" / platform / arch / "msodbcsql18.dll"; - return driverPath.string(); +#elif defined(_WIN32) + platform = "windows"; + // Normalize x86_64 to x64 for Windows naming + if (arch == "x86_64") arch = "x64"; + fs::path driverPath = + basePath / "libs" / platform / arch / "msodbcsql18.dll"; + return driverPath.string(); - #else - throw std::runtime_error("Unsupported platform"); - #endif +#else + throw std::runtime_error("Unsupported platform"); +#endif } DriverHandle LoadDriverOrThrowException() { @@ -1098,36 +1276,43 @@ DriverHandle LoadDriverOrThrowException() { LOG("Architecture: {}", archStr); // Use only C++ function for driver path resolution - // Not using Python function since it causes circular import issues on Alpine Linux - // and other platforms with strict module loading rules. + // Not using Python function since it causes circular import issues on + // Alpine Linux and other platforms with strict module loading rules. 
std::string driverPathStr = GetDriverPathCpp(moduleDir); - + fs::path driverPath(driverPathStr); - + LOG("Driver path determined: {}", driverPath.string()); - #ifdef _WIN32 - // On Windows, optionally load mssql-auth.dll if it exists - std::string archDir = - (archStr == "win64" || archStr == "amd64" || archStr == "x64") ? "x64" : - (archStr == "arm64") ? "arm64" : - "x86"; - - fs::path dllDir = fs::path(moduleDir) / "libs" / "windows" / archDir; - fs::path authDllPath = dllDir / "mssql-auth.dll"; - if (fs::exists(authDllPath)) { - HMODULE hAuth = LoadLibraryW(std::wstring(authDllPath.native().begin(), authDllPath.native().end()).c_str()); - if (hAuth) { - LOG("mssql-auth.dll loaded: {}", authDllPath.string()); - } else { - LOG("Failed to load mssql-auth.dll: {}", GetLastErrorMessage()); - ThrowStdException("Failed to load mssql-auth.dll. Please ensure it is present in the expected directory."); - } +#ifdef _WIN32 + // On Windows, optionally load mssql-auth.dll if it exists + std::string archDir = + (archStr == "win64" || archStr == "amd64" || archStr == "x64") ? "x64" + : (archStr == "arm64") ? "arm64" + : "x86"; + + fs::path dllDir = fs::path(moduleDir) / "libs" / "windows" / archDir; + fs::path authDllPath = dllDir / "mssql-auth.dll"; + if (fs::exists(authDllPath)) { + HMODULE hAuth = LoadLibraryW(std::wstring(authDllPath.native().begin(), + authDllPath.native().end()) + .c_str()); + if (hAuth) { + LOG("mssql-auth.dll loaded: {}", authDllPath.string()); } else { - LOG("Note: mssql-auth.dll not found. This is OK if Entra ID is not in use."); - ThrowStdException("mssql-auth.dll not found. If you are using Entra ID, please ensure it is present."); + LOG("Failed to load mssql-auth.dll: {}", GetLastErrorMessage()); + ThrowStdException( + "Failed to load mssql-auth.dll. Please ensure it is present in " + "the expected directory."); } - #endif + } else { + LOG("Note: mssql-auth.dll not found. 
This is OK if Entra ID is not in " + "use."); + ThrowStdException( + "mssql-auth.dll not found. If you are using Entra ID, please " + "ensure it is present."); + } +#endif if (!fs::exists(driverPath)) { ThrowStdException("ODBC driver not found at: " + driverPath.string()); @@ -1136,55 +1321,86 @@ DriverHandle LoadDriverOrThrowException() { DriverHandle handle = LoadDriverLibrary(driverPath.string()); if (!handle) { LOG("Failed to load driver: {}", GetLastErrorMessage()); - ThrowStdException("Failed to load the driver. Please read the documentation (https://github.com/microsoft/mssql-python#installation) to install the required dependencies."); + ThrowStdException( + "Failed to load the driver. Please read the documentation " + "(https://github.com/microsoft/mssql-python#installation) to " + "install the required dependencies."); } LOG("Driver library successfully loaded."); // Load function pointers using helper - SQLAllocHandle_ptr = GetFunctionPointer(handle, "SQLAllocHandle"); - SQLSetEnvAttr_ptr = GetFunctionPointer(handle, "SQLSetEnvAttr"); - SQLSetConnectAttr_ptr = GetFunctionPointer(handle, "SQLSetConnectAttrW"); - SQLSetStmtAttr_ptr = GetFunctionPointer(handle, "SQLSetStmtAttrW"); - SQLGetConnectAttr_ptr = GetFunctionPointer(handle, "SQLGetConnectAttrW"); - - SQLDriverConnect_ptr = GetFunctionPointer(handle, "SQLDriverConnectW"); - SQLExecDirect_ptr = GetFunctionPointer(handle, "SQLExecDirectW"); + SQLAllocHandle_ptr = + GetFunctionPointer(handle, "SQLAllocHandle"); + SQLSetEnvAttr_ptr = + GetFunctionPointer(handle, "SQLSetEnvAttr"); + SQLSetConnectAttr_ptr = + GetFunctionPointer(handle, "SQLSetConnectAttrW"); + SQLSetStmtAttr_ptr = + GetFunctionPointer(handle, "SQLSetStmtAttrW"); + SQLGetConnectAttr_ptr = + GetFunctionPointer(handle, "SQLGetConnectAttrW"); + + SQLDriverConnect_ptr = + GetFunctionPointer(handle, "SQLDriverConnectW"); + SQLExecDirect_ptr = + GetFunctionPointer(handle, "SQLExecDirectW"); SQLPrepare_ptr = GetFunctionPointer(handle, 
"SQLPrepareW"); - SQLBindParameter_ptr = GetFunctionPointer(handle, "SQLBindParameter"); + SQLBindParameter_ptr = + GetFunctionPointer(handle, "SQLBindParameter"); SQLExecute_ptr = GetFunctionPointer(handle, "SQLExecute"); - SQLRowCount_ptr = GetFunctionPointer(handle, "SQLRowCount"); - SQLGetStmtAttr_ptr = GetFunctionPointer(handle, "SQLGetStmtAttrW"); - SQLSetDescField_ptr = GetFunctionPointer(handle, "SQLSetDescFieldW"); + SQLRowCount_ptr = + GetFunctionPointer(handle, "SQLRowCount"); + SQLGetStmtAttr_ptr = + GetFunctionPointer(handle, "SQLGetStmtAttrW"); + SQLSetDescField_ptr = + GetFunctionPointer(handle, "SQLSetDescFieldW"); SQLFetch_ptr = GetFunctionPointer(handle, "SQLFetch"); - SQLFetchScroll_ptr = GetFunctionPointer(handle, "SQLFetchScroll"); + SQLFetchScroll_ptr = + GetFunctionPointer(handle, "SQLFetchScroll"); SQLGetData_ptr = GetFunctionPointer(handle, "SQLGetData"); - SQLNumResultCols_ptr = GetFunctionPointer(handle, "SQLNumResultCols"); + SQLNumResultCols_ptr = + GetFunctionPointer(handle, "SQLNumResultCols"); SQLBindCol_ptr = GetFunctionPointer(handle, "SQLBindCol"); - SQLDescribeCol_ptr = GetFunctionPointer(handle, "SQLDescribeColW"); - SQLMoreResults_ptr = GetFunctionPointer(handle, "SQLMoreResults"); - SQLColAttribute_ptr = GetFunctionPointer(handle, "SQLColAttributeW"); - SQLGetTypeInfo_ptr = GetFunctionPointer(handle, "SQLGetTypeInfoW"); - SQLProcedures_ptr = GetFunctionPointer(handle, "SQLProceduresW"); - SQLForeignKeys_ptr = GetFunctionPointer(handle, "SQLForeignKeysW"); - SQLPrimaryKeys_ptr = GetFunctionPointer(handle, "SQLPrimaryKeysW"); - SQLSpecialColumns_ptr = GetFunctionPointer(handle, "SQLSpecialColumnsW"); - SQLStatistics_ptr = GetFunctionPointer(handle, "SQLStatisticsW"); + SQLDescribeCol_ptr = + GetFunctionPointer(handle, "SQLDescribeColW"); + SQLMoreResults_ptr = + GetFunctionPointer(handle, "SQLMoreResults"); + SQLColAttribute_ptr = + GetFunctionPointer(handle, "SQLColAttributeW"); + SQLGetTypeInfo_ptr = + 
GetFunctionPointer(handle, "SQLGetTypeInfoW"); + SQLProcedures_ptr = + GetFunctionPointer(handle, "SQLProceduresW"); + SQLForeignKeys_ptr = + GetFunctionPointer(handle, "SQLForeignKeysW"); + SQLPrimaryKeys_ptr = + GetFunctionPointer(handle, "SQLPrimaryKeysW"); + SQLSpecialColumns_ptr = + GetFunctionPointer(handle, "SQLSpecialColumnsW"); + SQLStatistics_ptr = + GetFunctionPointer(handle, "SQLStatisticsW"); SQLColumns_ptr = GetFunctionPointer(handle, "SQLColumnsW"); SQLGetInfo_ptr = GetFunctionPointer(handle, "SQLGetInfoW"); SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran"); - SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect"); - SQLFreeHandle_ptr = GetFunctionPointer(handle, "SQLFreeHandle"); - SQLFreeStmt_ptr = GetFunctionPointer(handle, "SQLFreeStmt"); - - SQLGetDiagRec_ptr = GetFunctionPointer(handle, "SQLGetDiagRecW"); - - SQLParamData_ptr = GetFunctionPointer(handle, "SQLParamData"); + SQLDisconnect_ptr = + GetFunctionPointer(handle, "SQLDisconnect"); + SQLFreeHandle_ptr = + GetFunctionPointer(handle, "SQLFreeHandle"); + SQLFreeStmt_ptr = + GetFunctionPointer(handle, "SQLFreeStmt"); + + SQLGetDiagRec_ptr = + GetFunctionPointer(handle, "SQLGetDiagRecW"); + + SQLParamData_ptr = + GetFunctionPointer(handle, "SQLParamData"); SQLPutData_ptr = GetFunctionPointer(handle, "SQLPutData"); SQLTables_ptr = GetFunctionPointer(handle, "SQLTablesW"); - SQLDescribeParam_ptr = GetFunctionPointer(handle, "SQLDescribeParam"); + SQLDescribeParam_ptr = + GetFunctionPointer(handle, "SQLDescribeParam"); bool success = SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && @@ -1195,21 +1411,21 @@ DriverHandle LoadDriverOrThrowException() { SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && - SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLGetInfo_ptr && SQLParamData_ptr && - SQLPutData_ptr && SQLTables_ptr && - 
SQLDescribeParam_ptr && - SQLGetTypeInfo_ptr && SQLProcedures_ptr && SQLForeignKeys_ptr && - SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr && - SQLColumns_ptr; + SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLGetInfo_ptr && + SQLParamData_ptr && SQLPutData_ptr && SQLTables_ptr && + SQLDescribeParam_ptr && SQLGetTypeInfo_ptr && SQLProcedures_ptr && + SQLForeignKeys_ptr && SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && + SQLStatistics_ptr && SQLColumns_ptr; if (!success) { - ThrowStdException("Failed to load required function pointers from driver."); + ThrowStdException( + "Failed to load required function pointers from driver."); } LOG("All driver function pointers successfully loaded."); return handle; } -// DriverLoader definition +// DriverLoader definition DriverLoader::DriverLoader() : m_driverLoaded(false) {} DriverLoader& DriverLoader::getInstance() { @@ -1234,13 +1450,9 @@ SqlHandle::~SqlHandle() { } } -SQLHANDLE SqlHandle::get() const { - return _handle; -} +SQLHANDLE SqlHandle::get() const { return _handle; } -SQLSMALLINT SqlHandle::type() const { - return _type; -} +SQLSMALLINT SqlHandle::type() const { return _type; } /* * IMPORTANT: Never log in destructors - it causes segfaults. 
@@ -1253,28 +1465,31 @@ void SqlHandle::free() { if (_handle && SQLFreeHandle_ptr) { // Check if Python is shutting down using centralized helper function bool pythonShuttingDown = is_python_finalizing(); - - // CRITICAL FIX: During Python shutdown, don't free STMT handles as their parent DBC may already be freed - // This prevents segfault when handles are freed in wrong order during interpreter shutdown - // Type 3 = SQL_HANDLE_STMT, Type 2 = SQL_HANDLE_DBC, Type 1 = SQL_HANDLE_ENV + + // CRITICAL FIX: During Python shutdown, don't free STMT handles as + // their parent DBC may already be freed This prevents segfault when + // handles are freed in wrong order during interpreter shutdown Type 3 = + // SQL_HANDLE_STMT, Type 2 = SQL_HANDLE_DBC, Type 1 = SQL_HANDLE_ENV if (pythonShuttingDown && _type == 3) { - _handle = nullptr; // Mark as freed to prevent double-free attempts + _handle = nullptr; // Mark as freed to prevent double-free attempts return; } - + // Always clean up ODBC resources, regardless of Python state SQLFreeHandle_ptr(_type, _handle); _handle = nullptr; - + // Only log if Python is not shutting down (to avoid segfault) if (!pythonShuttingDown) { - // Don't log during destruction - even in normal cases it can be problematic - // If logging is needed, use explicit close() methods instead + // Don't log during destruction - even in normal cases it can be + // problematic If logging is needed, use explicit close() methods + // instead } } } -SQLRETURN SQLGetTypeInfo_Wrapper(SqlHandlePtr StatementHandle, SQLSMALLINT DataType) { +SQLRETURN SQLGetTypeInfo_Wrapper(SqlHandlePtr StatementHandle, + SQLSMALLINT DataType) { if (!SQLGetTypeInfo_ptr) { ThrowStdException("SQLGetTypeInfo function not loaded"); } @@ -1282,62 +1497,85 @@ SQLRETURN SQLGetTypeInfo_Wrapper(SqlHandlePtr StatementHandle, SQLSMALLINT DataT return SQLGetTypeInfo_ptr(StatementHandle->get(), DataType); } -SQLRETURN SQLProcedures_wrap(SqlHandlePtr StatementHandle, - const py::object& 
catalogObj, - const py::object& schemaObj, - const py::object& procedureObj) { +SQLRETURN SQLProcedures_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const py::object& procedureObj) { if (!SQLProcedures_ptr) { ThrowStdException("SQLProcedures function not loaded"); } - std::wstring catalog = py::isinstance(catalogObj) ? L"" : catalogObj.cast(); - std::wstring schema = py::isinstance(schemaObj) ? L"" : schemaObj.cast(); - std::wstring procedure = py::isinstance(procedureObj) ? L"" : procedureObj.cast(); + std::wstring catalog = py::isinstance(catalogObj) + ? L"" + : catalogObj.cast(); + std::wstring schema = py::isinstance(schemaObj) + ? L"" + : schemaObj.cast(); + std::wstring procedure = py::isinstance(procedureObj) + ? L"" + : procedureObj.cast(); #if defined(__APPLE__) || defined(__linux__) // Unix implementation std::vector catalogBuf = WStringToSQLWCHAR(catalog); std::vector schemaBuf = WStringToSQLWCHAR(schema); std::vector procedureBuf = WStringToSQLWCHAR(procedure); - - return SQLProcedures_ptr( - StatementHandle->get(), - catalog.empty() ? nullptr : catalogBuf.data(), - catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : schemaBuf.data(), - schema.empty() ? 0 : SQL_NTS, - procedure.empty() ? nullptr : procedureBuf.data(), - procedure.empty() ? 0 : SQL_NTS); + + return SQLProcedures_ptr(StatementHandle->get(), + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + procedure.empty() ? nullptr : procedureBuf.data(), + procedure.empty() ? 0 : SQL_NTS); #else // Windows implementation return SQLProcedures_ptr( StatementHandle->get(), - catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? nullptr + : reinterpret_cast( + const_cast(catalog.c_str())), catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() + ? 
nullptr + : reinterpret_cast(const_cast(schema.c_str())), schema.empty() ? 0 : SQL_NTS, - procedure.empty() ? nullptr : (SQLWCHAR*)procedure.c_str(), + procedure.empty() ? nullptr + : reinterpret_cast( + const_cast(procedure.c_str())), procedure.empty() ? 0 : SQL_NTS); #endif } -SQLRETURN SQLForeignKeys_wrap(SqlHandlePtr StatementHandle, - const py::object& pkCatalogObj, - const py::object& pkSchemaObj, - const py::object& pkTableObj, - const py::object& fkCatalogObj, - const py::object& fkSchemaObj, - const py::object& fkTableObj) { +SQLRETURN SQLForeignKeys_wrap(SqlHandlePtr StatementHandle, + const py::object& pkCatalogObj, + const py::object& pkSchemaObj, + const py::object& pkTableObj, + const py::object& fkCatalogObj, + const py::object& fkSchemaObj, + const py::object& fkTableObj) { if (!SQLForeignKeys_ptr) { ThrowStdException("SQLForeignKeys function not loaded"); } - std::wstring pkCatalog = py::isinstance(pkCatalogObj) ? L"" : pkCatalogObj.cast(); - std::wstring pkSchema = py::isinstance(pkSchemaObj) ? L"" : pkSchemaObj.cast(); - std::wstring pkTable = py::isinstance(pkTableObj) ? L"" : pkTableObj.cast(); - std::wstring fkCatalog = py::isinstance(fkCatalogObj) ? L"" : fkCatalogObj.cast(); - std::wstring fkSchema = py::isinstance(fkSchemaObj) ? L"" : fkSchemaObj.cast(); - std::wstring fkTable = py::isinstance(fkTableObj) ? L"" : fkTableObj.cast(); + std::wstring pkCatalog = py::isinstance(pkCatalogObj) + ? L"" + : pkCatalogObj.cast(); + std::wstring pkSchema = py::isinstance(pkSchemaObj) + ? L"" + : pkSchemaObj.cast(); + std::wstring pkTable = py::isinstance(pkTableObj) + ? L"" + : pkTableObj.cast(); + std::wstring fkCatalog = py::isinstance(fkCatalogObj) + ? L"" + : fkCatalogObj.cast(); + std::wstring fkSchema = py::isinstance(fkSchemaObj) + ? L"" + : fkSchemaObj.cast(); + std::wstring fkTable = py::isinstance(fkTableObj) + ? 
L"" + : fkTableObj.cast(); #if defined(__APPLE__) || defined(__linux__) // Unix implementation @@ -1347,125 +1585,143 @@ SQLRETURN SQLForeignKeys_wrap(SqlHandlePtr StatementHandle, std::vector fkCatalogBuf = WStringToSQLWCHAR(fkCatalog); std::vector fkSchemaBuf = WStringToSQLWCHAR(fkSchema); std::vector fkTableBuf = WStringToSQLWCHAR(fkTable); - - return SQLForeignKeys_ptr( - StatementHandle->get(), - pkCatalog.empty() ? nullptr : pkCatalogBuf.data(), - pkCatalog.empty() ? 0 : SQL_NTS, - pkSchema.empty() ? nullptr : pkSchemaBuf.data(), - pkSchema.empty() ? 0 : SQL_NTS, - pkTable.empty() ? nullptr : pkTableBuf.data(), - pkTable.empty() ? 0 : SQL_NTS, - fkCatalog.empty() ? nullptr : fkCatalogBuf.data(), - fkCatalog.empty() ? 0 : SQL_NTS, - fkSchema.empty() ? nullptr : fkSchemaBuf.data(), - fkSchema.empty() ? 0 : SQL_NTS, - fkTable.empty() ? nullptr : fkTableBuf.data(), - fkTable.empty() ? 0 : SQL_NTS); + + return SQLForeignKeys_ptr(StatementHandle->get(), + pkCatalog.empty() ? nullptr : pkCatalogBuf.data(), + pkCatalog.empty() ? 0 : SQL_NTS, + pkSchema.empty() ? nullptr : pkSchemaBuf.data(), + pkSchema.empty() ? 0 : SQL_NTS, + pkTable.empty() ? nullptr : pkTableBuf.data(), + pkTable.empty() ? 0 : SQL_NTS, + fkCatalog.empty() ? nullptr : fkCatalogBuf.data(), + fkCatalog.empty() ? 0 : SQL_NTS, + fkSchema.empty() ? nullptr : fkSchemaBuf.data(), + fkSchema.empty() ? 0 : SQL_NTS, + fkTable.empty() ? nullptr : fkTableBuf.data(), + fkTable.empty() ? 0 : SQL_NTS); #else // Windows implementation return SQLForeignKeys_ptr( StatementHandle->get(), - pkCatalog.empty() ? nullptr : (SQLWCHAR*)pkCatalog.c_str(), + pkCatalog.empty() ? nullptr + : reinterpret_cast( + const_cast(pkCatalog.c_str())), pkCatalog.empty() ? 0 : SQL_NTS, - pkSchema.empty() ? nullptr : (SQLWCHAR*)pkSchema.c_str(), + pkSchema.empty() ? nullptr + : reinterpret_cast( + const_cast(pkSchema.c_str())), pkSchema.empty() ? 0 : SQL_NTS, - pkTable.empty() ? nullptr : (SQLWCHAR*)pkTable.c_str(), + pkTable.empty() ? 
nullptr + : reinterpret_cast( + const_cast(pkTable.c_str())), pkTable.empty() ? 0 : SQL_NTS, - fkCatalog.empty() ? nullptr : (SQLWCHAR*)fkCatalog.c_str(), + fkCatalog.empty() ? nullptr + : reinterpret_cast( + const_cast(fkCatalog.c_str())), fkCatalog.empty() ? 0 : SQL_NTS, - fkSchema.empty() ? nullptr : (SQLWCHAR*)fkSchema.c_str(), + fkSchema.empty() ? nullptr + : reinterpret_cast( + const_cast(fkSchema.c_str())), fkSchema.empty() ? 0 : SQL_NTS, - fkTable.empty() ? nullptr : (SQLWCHAR*)fkTable.c_str(), + fkTable.empty() ? nullptr + : reinterpret_cast( + const_cast(fkTable.c_str())), fkTable.empty() ? 0 : SQL_NTS); #endif } -SQLRETURN SQLPrimaryKeys_wrap(SqlHandlePtr StatementHandle, - const py::object& catalogObj, - const py::object& schemaObj, - const std::wstring& table) { +SQLRETURN SQLPrimaryKeys_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table) { if (!SQLPrimaryKeys_ptr) { ThrowStdException("SQLPrimaryKeys function not loaded"); } // Convert py::object to std::wstring, treating None as empty string - std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); - std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + std::wstring catalog = + catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = + schemaObj.is_none() ? L"" : schemaObj.cast(); #if defined(__APPLE__) || defined(__linux__) // Unix implementation std::vector catalogBuf = WStringToSQLWCHAR(catalog); std::vector schemaBuf = WStringToSQLWCHAR(schema); std::vector tableBuf = WStringToSQLWCHAR(table); - + return SQLPrimaryKeys_ptr( - StatementHandle->get(), - catalog.empty() ? nullptr : catalogBuf.data(), + StatementHandle->get(), catalog.empty() ? nullptr : catalogBuf.data(), catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : schemaBuf.data(), - schema.empty() ? 0 : SQL_NTS, - table.empty() ? nullptr : tableBuf.data(), + schema.empty() ? 
nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, table.empty() ? nullptr : tableBuf.data(), table.empty() ? 0 : SQL_NTS); #else // Windows implementation return SQLPrimaryKeys_ptr( StatementHandle->get(), - catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? nullptr + : reinterpret_cast( + const_cast(catalog.c_str())), catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() + ? nullptr + : reinterpret_cast(const_cast(schema.c_str())), schema.empty() ? 0 : SQL_NTS, - table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() + ? nullptr + : reinterpret_cast(const_cast(table.c_str())), table.empty() ? 0 : SQL_NTS); #endif } -SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle, - const py::object& catalogObj, - const py::object& schemaObj, - const std::wstring& table, - SQLUSMALLINT unique, - SQLUSMALLINT reserved) { +SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table, SQLUSMALLINT unique, + SQLUSMALLINT reserved) { if (!SQLStatistics_ptr) { ThrowStdException("SQLStatistics function not loaded"); } - // Convert py::object to std::wstring, treating None as empty string - std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); - std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalog = + catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = + schemaObj.is_none() ? L"" : schemaObj.cast(); #if defined(__APPLE__) || defined(__linux__) // Unix implementation std::vector catalogBuf = WStringToSQLWCHAR(catalog); std::vector schemaBuf = WStringToSQLWCHAR(schema); std::vector tableBuf = WStringToSQLWCHAR(table); - + return SQLStatistics_ptr( - StatementHandle->get(), - catalog.empty() ? 
nullptr : catalogBuf.data(), + StatementHandle->get(), catalog.empty() ? nullptr : catalogBuf.data(), catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : schemaBuf.data(), - schema.empty() ? 0 : SQL_NTS, - table.empty() ? nullptr : tableBuf.data(), - table.empty() ? 0 : SQL_NTS, - unique, - reserved); + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS, unique, reserved); #else // Windows implementation return SQLStatistics_ptr( StatementHandle->get(), - catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? nullptr + : reinterpret_cast( + const_cast(catalog.c_str())), catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() + ? nullptr + : reinterpret_cast(const_cast(schema.c_str())), schema.empty() ? 0 : SQL_NTS, - table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), - table.empty() ? 0 : SQL_NTS, - unique, - reserved); + table.empty() + ? nullptr + : reinterpret_cast(const_cast(table.c_str())), + table.empty() ? 0 : SQL_NTS, unique, reserved); #endif } -SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, +SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, const py::object& catalogObj, const py::object& schemaObj, const py::object& tableObj, @@ -1475,10 +1731,14 @@ SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, } // Convert py::object to std::wstring, treating None as empty string - std::wstring catalogStr = catalogObj.is_none() ? L"" : catalogObj.cast(); - std::wstring schemaStr = schemaObj.is_none() ? L"" : schemaObj.cast(); - std::wstring tableStr = tableObj.is_none() ? L"" : tableObj.cast(); - std::wstring columnStr = columnObj.is_none() ? L"" : columnObj.cast(); + std::wstring catalogStr = + catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schemaStr = + schemaObj.is_none() ? L"" : schemaObj.cast(); + std::wstring tableStr = + tableObj.is_none() ? 
L"" : tableObj.cast(); + std::wstring columnStr = + columnObj.is_none() ? L"" : columnObj.cast(); #if defined(__APPLE__) || defined(__linux__) // Unix implementation @@ -1486,39 +1746,47 @@ SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, std::vector schemaBuf = WStringToSQLWCHAR(schemaStr); std::vector tableBuf = WStringToSQLWCHAR(tableStr); std::vector columnBuf = WStringToSQLWCHAR(columnStr); - - return SQLColumns_ptr( - StatementHandle->get(), - catalogStr.empty() ? nullptr : catalogBuf.data(), - catalogStr.empty() ? 0 : SQL_NTS, - schemaStr.empty() ? nullptr : schemaBuf.data(), - schemaStr.empty() ? 0 : SQL_NTS, - tableStr.empty() ? nullptr : tableBuf.data(), - tableStr.empty() ? 0 : SQL_NTS, - columnStr.empty() ? nullptr : columnBuf.data(), - columnStr.empty() ? 0 : SQL_NTS); + + return SQLColumns_ptr(StatementHandle->get(), + catalogStr.empty() ? nullptr : catalogBuf.data(), + catalogStr.empty() ? 0 : SQL_NTS, + schemaStr.empty() ? nullptr : schemaBuf.data(), + schemaStr.empty() ? 0 : SQL_NTS, + tableStr.empty() ? nullptr : tableBuf.data(), + tableStr.empty() ? 0 : SQL_NTS, + columnStr.empty() ? nullptr : columnBuf.data(), + columnStr.empty() ? 0 : SQL_NTS); #else // Windows implementation return SQLColumns_ptr( StatementHandle->get(), - catalogStr.empty() ? nullptr : (SQLWCHAR*)catalogStr.c_str(), + catalogStr.empty() ? nullptr + : reinterpret_cast( + const_cast(catalogStr.c_str())), catalogStr.empty() ? 0 : SQL_NTS, - schemaStr.empty() ? nullptr : (SQLWCHAR*)schemaStr.c_str(), + schemaStr.empty() ? nullptr + : reinterpret_cast( + const_cast(schemaStr.c_str())), schemaStr.empty() ? 0 : SQL_NTS, - tableStr.empty() ? nullptr : (SQLWCHAR*)tableStr.c_str(), + tableStr.empty() ? nullptr + : reinterpret_cast( + const_cast(tableStr.c_str())), tableStr.empty() ? 0 : SQL_NTS, - columnStr.empty() ? nullptr : (SQLWCHAR*)columnStr.c_str(), + columnStr.empty() ? nullptr + : reinterpret_cast( + const_cast(columnStr.c_str())), columnStr.empty() ? 
0 : SQL_NTS); #endif } // Helper function to check for driver errors -ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { - LOG("Checking errors for retcode - {}" , retcode); +ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, + SQLRETURN retcode) { + LOG("Checking errors for retcode - {}", retcode); ErrorInfo errorInfo; if (retcode == SQL_INVALID_HANDLE) { LOG("Invalid handle received"); - errorInfo.ddbcErrorMsg = std::wstring( L"Invalid handle!"); + errorInfo.ddbcErrorMsg = std::wstring(L"Invalid handle!"); return errorInfo; } assert(handle != 0); @@ -1534,8 +1802,8 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET SQLSMALLINT messageLen; SQLRETURN diagReturn = - SQLGetDiagRec_ptr(handleType, rawHandle, 1, sqlState, - &nativeError, message, SQL_MAX_MESSAGE_LENGTH, &messageLen); + SQLGetDiagRec_ptr(handleType, rawHandle, 1, sqlState, &nativeError, + message, SQL_MAX_MESSAGE_LENGTH, &messageLen); if (SQL_SUCCEEDED(diagReturn)) { #if defined(_WIN32) @@ -1543,7 +1811,8 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET errorInfo.sqlState = std::wstring(sqlState); errorInfo.ddbcErrorMsg = std::wstring(message); #else - // On macOS/Linux, need to convert SQLWCHAR (usually unsigned short) to wchar_t + // On macOS/Linux, need to convert SQLWCHAR (usually unsigned short) + // to wchar_t errorInfo.sqlState = SQLWCHARToWString(sqlState); errorInfo.ddbcErrorMsg = SQLWCHARToWString(message, messageLen); #endif @@ -1558,67 +1827,69 @@ py::list SQLGetAllDiagRecords(SqlHandlePtr handle) { LOG("Function pointer not initialized. 
Loading the driver."); DriverLoader::getInstance().loadDriver(); } - + py::list records; SQLHANDLE rawHandle = handle->get(); SQLSMALLINT handleType = handle->type(); - + // Iterate through all available diagnostic records - for (SQLSMALLINT recNumber = 1; ; recNumber++) { + for (SQLSMALLINT recNumber = 1;; recNumber++) { SQLWCHAR sqlState[6] = {0}; SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0}; SQLINTEGER nativeError = 0; SQLSMALLINT messageLen = 0; - + SQLRETURN diagReturn = SQLGetDiagRec_ptr( - handleType, rawHandle, recNumber, sqlState, &nativeError, - message, SQL_MAX_MESSAGE_LENGTH, &messageLen); - - if (diagReturn == SQL_NO_DATA || !SQL_SUCCEEDED(diagReturn)) - break; - + handleType, rawHandle, recNumber, sqlState, &nativeError, message, + SQL_MAX_MESSAGE_LENGTH, &messageLen); + + if (diagReturn == SQL_NO_DATA || !SQL_SUCCEEDED(diagReturn)) break; + #if defined(_WIN32) // On Windows, create a formatted UTF-8 string for state+error - + // Convert SQLWCHAR sqlState to UTF-8 - int stateSize = WideCharToMultiByte(CP_UTF8, 0, sqlState, -1, NULL, 0, NULL, NULL); + int stateSize = + WideCharToMultiByte(CP_UTF8, 0, sqlState, -1, NULL, 0, NULL, NULL); std::vector stateBuffer(stateSize); - WideCharToMultiByte(CP_UTF8, 0, sqlState, -1, stateBuffer.data(), stateSize, NULL, NULL); - + WideCharToMultiByte(CP_UTF8, 0, sqlState, -1, stateBuffer.data(), + stateSize, NULL, NULL); + // Format the state with error code - std::string stateWithError = "[" + std::string(stateBuffer.data()) + "] (" + std::to_string(nativeError) + ")"; - + std::string stateWithError = "[" + std::string(stateBuffer.data()) + + "] (" + std::to_string(nativeError) + ")"; + // Convert wide string message to UTF-8 - int msgSize = WideCharToMultiByte(CP_UTF8, 0, message, -1, NULL, 0, NULL, NULL); + int msgSize = + WideCharToMultiByte(CP_UTF8, 0, message, -1, NULL, 0, NULL, NULL); std::vector msgBuffer(msgSize); - WideCharToMultiByte(CP_UTF8, 0, message, -1, msgBuffer.data(), msgSize, NULL, NULL); - + 
WideCharToMultiByte(CP_UTF8, 0, message, -1, msgBuffer.data(), msgSize, + NULL, NULL); + // Create the tuple with converted strings - records.append(py::make_tuple( - py::str(stateWithError), - py::str(msgBuffer.data()) - )); + records.append( + py::make_tuple(py::str(stateWithError), py::str(msgBuffer.data()))); #else // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 std::string stateStr = WideToUTF8(SQLWCHARToWString(sqlState)); std::string msgStr = WideToUTF8(SQLWCHARToWString(message, messageLen)); - + // Format the state string - std::string stateWithError = "[" + stateStr + "] (" + std::to_string(nativeError) + ")"; - + std::string stateWithError = + "[" + stateStr + "] (" + std::to_string(nativeError) + ")"; + // Create the tuple with converted strings - records.append(py::make_tuple( - py::str(stateWithError), - py::str(msgStr) - )); + records.append( + py::make_tuple(py::str(stateWithError), py::str(msgStr))); #endif } - + return records; } // Wrap SQLExecDirect -SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Query) { +SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, + const std::wstring& Query) { LOG("Execute SQL query directly - {}", Query.c_str()); if (!SQLExecDirect_ptr) { LOG("Function pointer not initialized. 
Loading the driver."); @@ -1627,14 +1898,10 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q // Ensure statement is scrollable BEFORE executing if (SQLSetStmtAttr_ptr && StatementHandle && StatementHandle->get()) { - SQLSetStmtAttr_ptr(StatementHandle->get(), - SQL_ATTR_CURSOR_TYPE, - (SQLPOINTER)SQL_CURSOR_STATIC, - 0); - SQLSetStmtAttr_ptr(StatementHandle->get(), - SQL_ATTR_CONCURRENCY, - (SQLPOINTER)SQL_CONCUR_READ_ONLY, - 0); + SQLSetStmtAttr_ptr(StatementHandle->get(), SQL_ATTR_CURSOR_TYPE, + (SQLPOINTER)SQL_CURSOR_STATIC, 0); + SQLSetStmtAttr_ptr(StatementHandle->get(), SQL_ATTR_CONCURRENCY, + (SQLPOINTER)SQL_CONCUR_READ_ONLY, 0); } SQLWCHAR* queryPtr; @@ -1644,7 +1911,8 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q #else queryPtr = const_cast(Query.c_str()); #endif - SQLRETURN ret = SQLExecDirect_ptr(StatementHandle->get(), queryPtr, SQL_NTS); + SQLRETURN ret = + SQLExecDirect_ptr(StatementHandle->get(), queryPtr, SQL_NTS); if (!SQL_SUCCEEDED(ret)) { LOG("Failed to execute query directly"); } @@ -1652,12 +1920,10 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q } // Wrapper for SQLTables -SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, +SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, const std::wstring& catalog, - const std::wstring& schema, - const std::wstring& table, + const std::wstring& schema, const std::wstring& table, const std::wstring& tableType) { - if (!SQLTables_ptr) { LOG("Function pointer not initialized. 
Loading the driver."); DriverLoader::getInstance().loadDriver(); @@ -1719,13 +1985,9 @@ SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, } #endif - SQLRETURN ret = SQLTables_ptr( - StatementHandle->get(), - catalogPtr, catalogLen, - schemaPtr, schemaLen, - tablePtr, tableLen, - tableTypePtr, tableTypeLen - ); + SQLRETURN ret = SQLTables_ptr(StatementHandle->get(), catalogPtr, + catalogLen, schemaPtr, schemaLen, tablePtr, + tableLen, tableTypePtr, tableTypeLen); if (!SQL_SUCCEEDED(ret)) { LOG("SQLTables failed with return code: {}", ret); @@ -1736,24 +1998,28 @@ SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, return ret; } -// Executes the provided query. If the query is parametrized, it prepares the statement and -// binds the parameters. Otherwise, it executes the query directly. -// 'usePrepare' parameter can be used to disable the prepare step for queries that might already -// be prepared in a previous call. +// Executes the provided query. If the query is parametrized, it prepares the +// statement and binds the parameters. Otherwise, it executes the query +// directly. 'usePrepare' parameter can be used to disable the prepare step for +// queries that might already be prepared in a previous call. SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, const std::wstring& query /* TODO: Use SQLTCHAR? */, - const py::list& params, std::vector& paramInfos, - py::list& isStmtPrepared, const bool usePrepare = true, + const py::list& params, + std::vector& paramInfos, + py::list& isStmtPrepared, + const bool usePrepare = true, const py::object& encoding_settings = py::none()) { LOG("Execute SQL Query - {}", query.c_str()); if (!SQLPrepare_ptr) { LOG("Function pointer not initialized. 
Loading the driver."); DriverLoader::getInstance().loadDriver(); // Load the driver } - assert(SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr && SQLExecDirect_ptr); + assert(SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr && + SQLExecDirect_ptr); if (params.size() != paramInfos.size()) { - // TODO: This should be a special internal exception, that python wont relay to users as is + // TODO: This should be a special internal exception, that python wont + // relay to users as is ThrowStdException("Number of parameters and paramInfos do not match"); } @@ -1765,14 +2031,10 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, // Ensure statement is scrollable BEFORE executing if (SQLSetStmtAttr_ptr && hStmt) { - SQLSetStmtAttr_ptr(hStmt, - SQL_ATTR_CURSOR_TYPE, - (SQLPOINTER)SQL_CURSOR_STATIC, - 0); - SQLSetStmtAttr_ptr(hStmt, - SQL_ATTR_CONCURRENCY, - (SQLPOINTER)SQL_CONCUR_READ_ONLY, - 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_CURSOR_TYPE, + (SQLPOINTER)SQL_CURSOR_STATIC, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_CONCURRENCY, + (SQLPOINTER)SQL_CONCUR_READ_ONLY, 0); } SQLWCHAR* queryPtr; @@ -1783,9 +2045,9 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, queryPtr = const_cast(query.c_str()); #endif if (params.size() == 0) { - // Execute statement directly if the statement is not parametrized. This is the - // fastest way to submit a SQL statement for one-time execution according to - // DDBC documentation - + // Execute statement directly if the statement is not parametrized. 
This + // is the fastest way to submit a SQL statement for one-time execution + // according to DDBC documentation - // https://learn.microsoft.com/en-us/sql/odbc/reference/syntax/sqlexecdirect-function?view=sql-server-ver16 rc = SQLExecDirect_ptr(hStmt, queryPtr, SQL_NTS); if (!SQL_SUCCEEDED(rc) && rc != SQL_NO_DATA) { @@ -1793,9 +2055,10 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, } return rc; } else { - // isStmtPrepared is a list instead of a bool coz bools in Python are immutable. - // Hence, we can't pass around bools by reference & modify them. Therefore, isStmtPrepared - // must be a list with exactly one bool element + // isStmtPrepared is a list instead of a bool coz bools in Python are + // immutable. Hence, we can't pass around bools by reference & modify + // them. Therefore, isStmtPrepared must be a list with exactly one bool + // element assert(isStmtPrepared.size() == 1); if (usePrepare) { rc = SQLPrepare_ptr(hStmt, queryPtr, SQL_NTS); @@ -1805,7 +2068,8 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, } isStmtPrepared[0] = py::cast(true); } else { - // Make sure the statement has been prepared earlier if we're not preparing now + // Make sure the statement has been prepared earlier if we're not + // preparing now bool isStmtPreparedAsBool = isStmtPrepared[0].cast(); if (!isStmtPreparedAsBool) { // TODO: Print the query @@ -1816,7 +2080,8 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, // This vector manages the heap memory allocated for parameter buffers. // It must be in scope until SQLExecute is done. 
std::vector> paramBuffers; - rc = BindParameters(hStmt, params, paramInfos, paramBuffers, encoding_settings); + rc = BindParameters(hStmt, params, paramInfos, paramBuffers, + encoding_settings); if (!SQL_SUCCEEDED(rc)) { return rc; } @@ -1824,18 +2089,21 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, rc = SQLExecute_ptr(hStmt); if (rc == SQL_NEED_DATA) { LOG("Beginning SQLParamData/SQLPutData loop for DAE."); - SQLPOINTER paramToken = nullptr; - while ((rc = SQLParamData_ptr(hStmt, ¶mToken)) == SQL_NEED_DATA) { + SQLPOINTER paramToken = nullptr; + while ((rc = SQLParamData_ptr(hStmt, ¶mToken)) == + SQL_NEED_DATA) { // Finding the paramInfo that matches the returned token const ParamInfo* matchedInfo = nullptr; for (auto& info : paramInfos) { - if (reinterpret_cast(const_cast(&info)) == paramToken) { + if (reinterpret_cast( + const_cast(&info)) == paramToken) { matchedInfo = &info; break; } } if (!matchedInfo) { - ThrowStdException("Unrecognized paramToken returned by SQLParamData"); + ThrowStdException( + "Unrecognized paramToken returned by SQLParamData"); } const py::object& pyObj = matchedInfo->dataPtr; if (pyObj.is_none()) { @@ -1858,14 +2126,22 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, size_t offset = 0; size_t chunkChars = DAE_CHUNK_SIZE / sizeof(SQLWCHAR); while (offset < totalChars) { - size_t len = std::min(chunkChars, totalChars - offset); + size_t len = + std::min(chunkChars, totalChars - offset); size_t lenBytes = len * sizeof(SQLWCHAR); - if (lenBytes > static_cast(std::numeric_limits::max())) { - ThrowStdException("Chunk size exceeds maximum allowed by SQLLEN"); + if (lenBytes > + static_cast( + std::numeric_limits::max())) { + ThrowStdException( + "Chunk size exceeds maximum allowed by " + "SQLLEN"); } - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast(lenBytes)); + rc = SQLPutData_ptr(hStmt, + (SQLPOINTER)(dataPtr + offset), + static_cast(lenBytes)); if (!SQL_SUCCEEDED(rc)) { - 
LOG("SQLPutData failed at offset {} of {}", offset, totalChars); + LOG("SQLPutData failed at offset {} of {}", + offset, totalChars); return rc; } offset += len; @@ -1877,11 +2153,15 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, size_t offset = 0; size_t chunkBytes = DAE_CHUNK_SIZE; while (offset < totalBytes) { - size_t len = std::min(chunkBytes, totalBytes - offset); + size_t len = + std::min(chunkBytes, totalBytes - offset); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast(len)); + rc = SQLPutData_ptr(hStmt, + (SQLPOINTER)(dataPtr + offset), + static_cast(len)); if (!SQL_SUCCEEDED(rc)) { - LOG("SQLPutData failed at offset {} of {}", offset, totalBytes); + LOG("SQLPutData failed at offset {} of {}", + offset, totalBytes); return rc; } offset += len; @@ -1889,17 +2169,22 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, } else { ThrowStdException("Unsupported C type for str in DAE"); } - } else if (py::isinstance(pyObj) || py::isinstance(pyObj)) { + } else if (py::isinstance(pyObj) || + py::isinstance(pyObj)) { py::bytes b = pyObj.cast(); std::string s = b; const char* dataPtr = s.data(); size_t totalBytes = s.size(); const size_t chunkSize = DAE_CHUNK_SIZE; - for (size_t offset = 0; offset < totalBytes; offset += chunkSize) { + for (size_t offset = 0; offset < totalBytes; + offset += chunkSize) { size_t len = std::min(chunkSize, totalBytes - offset); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast(len)); + rc = SQLPutData_ptr(hStmt, + (SQLPOINTER)(dataPtr + offset), + static_cast(len)); if (!SQL_SUCCEEDED(rc)) { - LOG("SQLPutData failed at offset {} of {}", offset, totalBytes); + LOG("SQLPutData failed at offset {} of {}", offset, + totalBytes); return rc; } } @@ -1918,40 +2203,48 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, return rc; } - // Unbind the bound buffers for all parameters coz the buffers' memory will - // be freed when this function exits 
(parambuffers goes out of scope) + // Unbind the bound buffers for all parameters coz the buffers' memory + // will be freed when this function exits (parambuffers goes out of + // scope) rc = SQLFreeStmt_ptr(hStmt, SQL_RESET_PARAMS); return rc; } } -SQLRETURN BindParameterArray(SQLHANDLE hStmt, - const py::list& columnwise_params, +SQLRETURN BindParameterArray(SQLHANDLE hStmt, const py::list& columnwise_params, const std::vector& paramInfos, size_t paramSetSize, std::vector>& paramBuffers, const py::object& encoding_settings) { - LOG("Starting column-wise parameter array binding. paramSetSize: {}, paramCount: {}", paramSetSize, columnwise_params.size()); + LOG("Starting column-wise parameter array binding. paramSetSize: {}, " + "paramCount: {}", + paramSetSize, columnwise_params.size()); std::vector> tempBuffers; try { - for (int paramIndex = 0; paramIndex < columnwise_params.size(); ++paramIndex) { - const py::list& columnValues = columnwise_params[paramIndex].cast(); + for (int paramIndex = 0; paramIndex < columnwise_params.size(); + ++paramIndex) { + const py::list& columnValues = + columnwise_params[paramIndex].cast(); const ParamInfo& info = paramInfos[paramIndex]; if (columnValues.size() != paramSetSize) { - ThrowStdException("Column " + std::to_string(paramIndex) + " has mismatched size."); + ThrowStdException("Column " + std::to_string(paramIndex) + + " has mismatched size."); } void* dataPtr = nullptr; SQLLEN* strLenOrIndArray = nullptr; SQLLEN bufferLength = 0; switch (info.paramCType) { case SQL_C_LONG: { - int* dataArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + int* dataArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { if (!strLenOrIndArray) - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + strLenOrIndArray = + AllocateParamBufferArray( + tempBuffers, paramSetSize); dataArray[i] = 0; strLenOrIndArray[i] = SQL_NULL_DATA; } 
else { @@ -1963,11 +2256,14 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_DOUBLE: { - double* dataArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + double* dataArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { if (!strLenOrIndArray) - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + strLenOrIndArray = + AllocateParamBufferArray( + tempBuffers, paramSetSize); dataArray[i] = 0; strLenOrIndArray[i] = SQL_NULL_DATA; } else { @@ -1979,50 +2275,88 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_WCHAR: { - SQLWCHAR* wcharArray = AllocateParamBufferArray(tempBuffers, paramSetSize * (info.columnSize + 1)); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + SQLWCHAR* wcharArray = AllocateParamBufferArray( + tempBuffers, paramSetSize * (info.columnSize + 1)); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { strLenOrIndArray[i] = SQL_NULL_DATA; - std::memset(wcharArray + i * (info.columnSize + 1), 0, (info.columnSize + 1) * sizeof(SQLWCHAR)); + std::memset( + wcharArray + i * (info.columnSize + 1), 0, + (info.columnSize + 1) * sizeof(SQLWCHAR)); } else { - std::wstring wstr = columnValues[i].cast(); + std::wstring wstr = + columnValues[i].cast(); #if defined(__APPLE__) || defined(__linux__) - // Convert to UTF-16 first, then check the actual UTF-16 length + // Convert to UTF-16 first, then check the actual + // UTF-16 length auto utf16Buf = WStringToSQLWCHAR(wstr); - // Check UTF-16 length (excluding null terminator) against column size - if (utf16Buf.size() > 0 && (utf16Buf.size() - 1) > info.columnSize) { + // Check UTF-16 length (excluding null terminator) + // against column size + if (utf16Buf.size() > 0 && + (utf16Buf.size() - 1) > info.columnSize) { std::string offending = 
WideToUTF8(wstr); - ThrowStdException("Input string UTF-16 length exceeds allowed column size at parameter index " + std::to_string(paramIndex) + - ". UTF-16 length: " + std::to_string(utf16Buf.size() - 1) + ", Column size: " + std::to_string(info.columnSize)); + ThrowStdException( + "Input string UTF-16 length exceeds " + "allowed column size at parameter index " + + std::to_string(paramIndex) + + ". UTF-16 length: " + + std::to_string(utf16Buf.size() - 1) + + ", Column size: " + + std::to_string(info.columnSize)); } - // Secure copy: use validated bounds for defense-in-depth - size_t copyBytes = utf16Buf.size() * sizeof(SQLWCHAR); - size_t bufferBytes = (info.columnSize + 1) * sizeof(SQLWCHAR); - SQLWCHAR* destPtr = wcharArray + i * (info.columnSize + 1); - + // Secure copy: use validated bounds for + // defense-in-depth + size_t copyBytes = + utf16Buf.size() * sizeof(SQLWCHAR); + size_t bufferBytes = + (info.columnSize + 1) * sizeof(SQLWCHAR); + SQLWCHAR* destPtr = + wcharArray + i * (info.columnSize + 1); + if (copyBytes > bufferBytes) { - ThrowStdException("Buffer overflow prevented in WCHAR array binding at parameter index " + std::to_string(paramIndex) + + ThrowStdException( + "Buffer overflow prevented in WCHAR array " + "binding at parameter " + "index " + + std::to_string(paramIndex) + ", array element " + std::to_string(i)); } if (copyBytes > 0) { - std::copy_n(reinterpret_cast(utf16Buf.data()), copyBytes, reinterpret_cast(destPtr)); + std::copy_n(reinterpret_cast( + utf16Buf.data()), + copyBytes, + reinterpret_cast(destPtr)); } #else - // On Windows, wchar_t is already UTF-16, so the original check is sufficient + // On Windows, wchar_t is already UTF-16, so the + // original check is sufficient if (wstr.length() > info.columnSize) { std::string offending = WideToUTF8(wstr); - ThrowStdException("Input string exceeds allowed column size at parameter index " + std::to_string(paramIndex)); + ThrowStdException( + "Input string exceeds allowed column size " 
+ "at parameter index " + + std::to_string(paramIndex)); } // Secure copy with bounds checking - size_t copyBytes = (wstr.length() + 1) * sizeof(SQLWCHAR); - size_t bufferBytes = (info.columnSize + 1) * sizeof(SQLWCHAR); - SQLWCHAR* destPtr = wcharArray + i * (info.columnSize + 1); - - errno_t err = memcpy_s(destPtr, bufferBytes, wstr.c_str(), copyBytes); + size_t copyBytes = + (wstr.length() + 1) * sizeof(SQLWCHAR); + size_t bufferBytes = + (info.columnSize + 1) * sizeof(SQLWCHAR); + SQLWCHAR* destPtr = + wcharArray + i * (info.columnSize + 1); + + errno_t err = memcpy_s(destPtr, bufferBytes, + wstr.c_str(), copyBytes); if (err != 0) { - ThrowStdException("Secure memory copy failed in WCHAR array binding at parameter index " + std::to_string(paramIndex) + - ", array element " + std::to_string(i) + ", error code: " + std::to_string(err)); + ThrowStdException( + "Secure memory copy failed in WCHAR array " + "binding at parameter " + "index " + + std::to_string(paramIndex) + + ", array element " + std::to_string(i) + + ", error code: " + std::to_string(err)); } #endif strLenOrIndArray[i] = SQL_NTS; @@ -2034,17 +2368,23 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, } case SQL_C_TINYINT: case SQL_C_UTINYINT: { - unsigned char* dataArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + unsigned char* dataArray = + AllocateParamBufferArray(tempBuffers, + paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { if (!strLenOrIndArray) - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + strLenOrIndArray = + AllocateParamBufferArray( + tempBuffers, paramSetSize); dataArray[i] = 0; strLenOrIndArray[i] = SQL_NULL_DATA; } else { int intVal = columnValues[i].cast(); if (intVal < 0 || intVal > 255) { - ThrowStdException("UTINYINT value out of range at rowIndex " + std::to_string(i)); + ThrowStdException( + "UTINYINT value out of range at rowIndex " + + std::to_string(i)); } dataArray[i] = static_cast(intVal); 
if (strLenOrIndArray) strLenOrIndArray[i] = 0; @@ -2055,95 +2395,127 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_SHORT: { - short* dataArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + int16_t* dataArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { if (!strLenOrIndArray) - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + strLenOrIndArray = + AllocateParamBufferArray( + tempBuffers, paramSetSize); dataArray[i] = 0; strLenOrIndArray[i] = SQL_NULL_DATA; } else { int intVal = columnValues[i].cast(); - if (intVal < std::numeric_limits::min() || - intVal > std::numeric_limits::max()) { - ThrowStdException("SHORT value out of range at rowIndex " + std::to_string(i)); + if (intVal < std::numeric_limits::min() || + intVal > std::numeric_limits::max()) { + ThrowStdException( + "SHORT value out of range at rowIndex " + + std::to_string(i)); } - dataArray[i] = static_cast(intVal); + dataArray[i] = static_cast(intVal); if (strLenOrIndArray) strLenOrIndArray[i] = 0; } } dataPtr = dataArray; - bufferLength = sizeof(short); + bufferLength = sizeof(int16_t); break; } case SQL_C_CHAR: case SQL_C_BINARY: { - char* charArray = AllocateParamBufferArray(tempBuffers, paramSetSize * (info.columnSize + 1)); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + char* charArray = AllocateParamBufferArray( + tempBuffers, paramSetSize * (info.columnSize + 1)); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { strLenOrIndArray[i] = SQL_NULL_DATA; - std::memset(charArray + i * (info.columnSize + 1), 0, info.columnSize + 1); + std::memset(charArray + i * (info.columnSize + 1), + 0, info.columnSize + 1); } else { std::string str; - - // Apply dynamic encoding only for SQL_C_CHAR (not SQL_C_BINARY) - if (info.paramCType == 
SQL_C_CHAR && encoding_settings && - !encoding_settings.is_none() && - encoding_settings.contains("ctype") && + + // Apply dynamic encoding only for SQL_C_CHAR + // (not SQL_C_BINARY) + if (info.paramCType == SQL_C_CHAR && + encoding_settings && + !encoding_settings.is_none() && + encoding_settings.contains("ctype") && encoding_settings.contains("encoding")) { - - SQLSMALLINT ctype = encoding_settings["ctype"].cast(); - + SQLSMALLINT ctype = encoding_settings["ctype"] + .cast(); if (ctype == SQL_C_CHAR) { try { - py::dict settings_dict = encoding_settings.cast(); - auto [encoding, errors] = extract_encoding_settings(settings_dict); - + py::dict settings_dict = + encoding_settings.cast(); + auto [encoding, errors] = + extract_encoding_settings( + settings_dict); // Use our safe encoding function - py::bytes encoded_bytes = EncodingString(columnValues[i].cast(), encoding, errors); + py::bytes encoded_bytes = + EncodingString( + columnValues[i] + .cast(), + encoding, errors); str = encoded_bytes.cast(); - } catch (const std::exception& e) { - ThrowStdException("Failed to encode parameter array element " + std::to_string(i) + ": " + e.what()); + ThrowStdException( + "Failed to encode " + "parameter array element " + + std::to_string(i) + ": " + + e.what()); } } else { // Default behavior str = columnValues[i].cast(); } } else { - // No encoding settings or SQL_C_BINARY - use default behavior + // No encoding settings or SQL_C_BINARY - use + // default behavior str = columnValues[i].cast(); } - if (str.size() > info.columnSize) { - ThrowStdException("Input exceeds column size at index " + std::to_string(i)); + ThrowStdException( + "Input exceeds column size at index " + + std::to_string(i)); } - + // SECURITY: Use secure copy with bounds checking size_t destOffset = i * (info.columnSize + 1); size_t destBufferSize = info.columnSize + 1; size_t copyLength = str.size(); - + // Validate bounds to prevent buffer overflow if (copyLength >= destBufferSize) { - 
ThrowStdException("Buffer overflow prevented at parameter array index " + std::to_string(i)); + ThrowStdException( + "Buffer overflow prevented at parameter " + "array index " + + std::to_string(i)); } - - #ifdef _WIN32 - // Windows: Use memcpy_s for secure copy - errno_t err = memcpy_s(charArray + destOffset, destBufferSize, str.data(), copyLength); - if (err != 0) { - ThrowStdException("Secure memory copy failed with error code " + std::to_string(err) + " at array index " + std::to_string(i)); - } - #else - // POSIX: Use std::copy_n with explicit bounds checking - if (copyLength > 0) { - std::copy_n(str.data(), copyLength, charArray + destOffset); - } - #endif - - strLenOrIndArray[i] = static_cast(copyLength); + +#ifdef _WIN32 + // Windows: Use memcpy_s for secure copy + errno_t err = + memcpy_s(charArray + destOffset, destBufferSize, + str.data(), copyLength); + if (err != 0) { + ThrowStdException( + "Secure memory copy failed with error " + "code " + + std::to_string(err) + " at array index " + + std::to_string(i)); + } +#else + // POSIX: Use std::copy_n with explicit bounds + // checking + if (copyLength > 0) { + std::copy_n(str.data(), copyLength, + charArray + destOffset); + } +#endif + + strLenOrIndArray[i] = + static_cast(copyLength); } } dataPtr = charArray; @@ -2151,8 +2523,10 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_BIT: { - char* boolArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + char* boolArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { boolArray[i] = 0; @@ -2168,27 +2542,31 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, } case SQL_C_STINYINT: case SQL_C_USHORT: { - unsigned short* dataArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - strLenOrIndArray = 
AllocateParamBufferArray(tempBuffers, paramSetSize); + uint16_t* dataArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { strLenOrIndArray[i] = SQL_NULL_DATA; dataArray[i] = 0; } else { - dataArray[i] = columnValues[i].cast(); + dataArray[i] = columnValues[i].cast(); strLenOrIndArray[i] = 0; } } dataPtr = dataArray; - bufferLength = sizeof(unsigned short); + bufferLength = sizeof(uint16_t); break; } case SQL_C_SBIGINT: case SQL_C_SLONG: case SQL_C_UBIGINT: case SQL_C_ULONG: { - int64_t* dataArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + int64_t* dataArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { strLenOrIndArray[i] = SQL_NULL_DATA; @@ -2203,8 +2581,10 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_FLOAT: { - float* dataArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + float* dataArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { strLenOrIndArray[i] = SQL_NULL_DATA; @@ -2219,17 +2599,24 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_TYPE_DATE: { - SQL_DATE_STRUCT* dateArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + SQL_DATE_STRUCT* dateArray = + AllocateParamBufferArray(tempBuffers, + paramSetSize); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, 
paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { strLenOrIndArray[i] = SQL_NULL_DATA; - std::memset(&dateArray[i], 0, sizeof(SQL_DATE_STRUCT)); + std::memset(&dateArray[i], 0, + sizeof(SQL_DATE_STRUCT)); } else { py::object dateObj = columnValues[i]; - dateArray[i].year = dateObj.attr("year").cast(); - dateArray[i].month = dateObj.attr("month").cast(); - dateArray[i].day = dateObj.attr("day").cast(); + dateArray[i].year = + dateObj.attr("year").cast(); + dateArray[i].month = + dateObj.attr("month").cast(); + dateArray[i].day = + dateObj.attr("day").cast(); strLenOrIndArray[i] = 0; } } @@ -2238,17 +2625,24 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_TYPE_TIME: { - SQL_TIME_STRUCT* timeArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + SQL_TIME_STRUCT* timeArray = + AllocateParamBufferArray(tempBuffers, + paramSetSize); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { strLenOrIndArray[i] = SQL_NULL_DATA; - std::memset(&timeArray[i], 0, sizeof(SQL_TIME_STRUCT)); + std::memset(&timeArray[i], 0, + sizeof(SQL_TIME_STRUCT)); } else { py::object timeObj = columnValues[i]; - timeArray[i].hour = timeObj.attr("hour").cast(); - timeArray[i].minute = timeObj.attr("minute").cast(); - timeArray[i].second = timeObj.attr("second").cast(); + timeArray[i].hour = + timeObj.attr("hour").cast(); + timeArray[i].minute = + timeObj.attr("minute").cast(); + timeArray[i].second = + timeObj.attr("second").cast(); strLenOrIndArray[i] = 0; } } @@ -2257,21 +2651,33 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_TYPE_TIMESTAMP: { - SQL_TIMESTAMP_STRUCT* tsArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + SQL_TIMESTAMP_STRUCT* tsArray = 
+ AllocateParamBufferArray( + tempBuffers, paramSetSize); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { if (columnValues[i].is_none()) { strLenOrIndArray[i] = SQL_NULL_DATA; - std::memset(&tsArray[i], 0, sizeof(SQL_TIMESTAMP_STRUCT)); + std::memset(&tsArray[i], 0, + sizeof(SQL_TIMESTAMP_STRUCT)); } else { py::object dtObj = columnValues[i]; - tsArray[i].year = dtObj.attr("year").cast(); - tsArray[i].month = dtObj.attr("month").cast(); - tsArray[i].day = dtObj.attr("day").cast(); - tsArray[i].hour = dtObj.attr("hour").cast(); - tsArray[i].minute = dtObj.attr("minute").cast(); - tsArray[i].second = dtObj.attr("second").cast(); - tsArray[i].fraction = static_cast(dtObj.attr("microsecond").cast() * 1000); // µs to ns + tsArray[i].year = + dtObj.attr("year").cast(); + tsArray[i].month = + dtObj.attr("month").cast(); + tsArray[i].day = + dtObj.attr("day").cast(); + tsArray[i].hour = + dtObj.attr("hour").cast(); + tsArray[i].minute = + dtObj.attr("minute").cast(); + tsArray[i].second = + dtObj.attr("second").cast(); + tsArray[i].fraction = static_cast( + dtObj.attr("microsecond").cast() * + 1000); // µs to ns strLenOrIndArray[i] = 0; } } @@ -2280,44 +2686,69 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_SS_TIMESTAMPOFFSET: { - DateTimeOffset* dtoArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + DateTimeOffset* dtoArray = + AllocateParamBufferArray(tempBuffers, + paramSetSize); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); - py::object datetimeType = py::module_::import("datetime").attr("datetime"); + py::object datetimeType = + py::module_::import("datetime").attr("datetime"); for (size_t i = 0; i < paramSetSize; ++i) { const py::handle& param = columnValues[i]; if (param.is_none()) { - std::memset(&dtoArray[i], 0, sizeof(DateTimeOffset)); + 
std::memset(&dtoArray[i], 0, + sizeof(DateTimeOffset)); strLenOrIndArray[i] = SQL_NULL_DATA; } else { if (!py::isinstance(param, datetimeType)) { - ThrowStdException(MakeParamMismatchErrorStr(info.paramCType, paramIndex)); + ThrowStdException(MakeParamMismatchErrorStr( + info.paramCType, paramIndex)); } py::object tzinfo = param.attr("tzinfo"); if (tzinfo.is_none()) { - ThrowStdException("Datetime object must have tzinfo for SQL_C_SS_TIMESTAMPOFFSET at paramIndex " + + ThrowStdException( + "Datetime object must have " + "tzinfo for SQL_C_SS_TIMESTAMPOFFSET at " + "paramIndex " + std::to_string(paramIndex)); } - // Populate the C++ struct directly from the Python datetime object. - dtoArray[i].year = static_cast(param.attr("year").cast()); - dtoArray[i].month = static_cast(param.attr("month").cast()); - dtoArray[i].day = static_cast(param.attr("day").cast()); - dtoArray[i].hour = static_cast(param.attr("hour").cast()); - dtoArray[i].minute = static_cast(param.attr("minute").cast()); - dtoArray[i].second = static_cast(param.attr("second").cast()); - // SQL server supports in ns, but python datetime supports in µs - dtoArray[i].fraction = static_cast(param.attr("microsecond").cast() * 1000); + // Populate the C++ struct directly from the Python + // datetime object. + dtoArray[i].year = static_cast( + param.attr("year").cast()); + dtoArray[i].month = static_cast( + param.attr("month").cast()); + dtoArray[i].day = static_cast( + param.attr("day").cast()); + dtoArray[i].hour = static_cast( + param.attr("hour").cast()); + dtoArray[i].minute = static_cast( + param.attr("minute").cast()); + dtoArray[i].second = static_cast( + param.attr("second").cast()); + // SQL server supports in ns, but python datetime + // supports in µs + dtoArray[i].fraction = static_cast( + param.attr("microsecond").cast() * 1000); // Compute and preserve the original UTC offset. 
- py::object utcoffset = tzinfo.attr("utcoffset")(param); - int total_seconds = static_cast(utcoffset.attr("total_seconds")().cast()); - std::div_t div_result = std::div(total_seconds, 3600); - dtoArray[i].timezone_hour = static_cast(div_result.quot); - dtoArray[i].timezone_minute = static_cast(div(div_result.rem, 60).quot); + py::object utcoffset = + tzinfo.attr("utcoffset")(param); + int total_seconds = static_cast( + utcoffset.attr("total_seconds")() + .cast()); + std::div_t div_result = + std::div(total_seconds, 3600); + dtoArray[i].timezone_hour = + static_cast(div_result.quot); + dtoArray[i].timezone_minute = + static_cast( + div(div_result.rem, 60).quot); strLenOrIndArray[i] = sizeof(DateTimeOffset); } @@ -2327,30 +2758,39 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_NUMERIC: { - SQL_NUMERIC_STRUCT* numericArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + SQL_NUMERIC_STRUCT* numericArray = + AllocateParamBufferArray( + tempBuffers, paramSetSize); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); for (size_t i = 0; i < paramSetSize; ++i) { const py::handle& element = columnValues[i]; if (element.is_none()) { strLenOrIndArray[i] = SQL_NULL_DATA; - std::memset(&numericArray[i], 0, sizeof(SQL_NUMERIC_STRUCT)); + std::memset(&numericArray[i], 0, + sizeof(SQL_NUMERIC_STRUCT)); continue; } if (!py::isinstance(element)) { - throw std::runtime_error(MakeParamMismatchErrorStr(info.paramCType, paramIndex)); + throw std::runtime_error(MakeParamMismatchErrorStr( + info.paramCType, paramIndex)); } NumericData decimalParam = element.cast(); - LOG("Received numeric parameter at [%zu]: precision=%d, scale=%d, sign=%d, val=%s", - i, decimalParam.precision, decimalParam.scale, decimalParam.sign, decimalParam.val.c_str()); + LOG("Received numeric parameter at [%zu]: " + "precision=%d, scale=%d, sign=%d, val=%s", + i, decimalParam.precision, 
decimalParam.scale, + decimalParam.sign, decimalParam.val.c_str()); SQL_NUMERIC_STRUCT& target = numericArray[i]; std::memset(&target, 0, sizeof(SQL_NUMERIC_STRUCT)); target.precision = decimalParam.precision; target.scale = decimalParam.scale; target.sign = decimalParam.sign; - size_t copyLen = std::min(decimalParam.val.size(), sizeof(target.val)); + size_t copyLen = std::min(decimalParam.val.size(), + sizeof(target.val)); // Secure copy: bounds already validated with std::min if (copyLen > 0) { - std::copy_n(decimalParam.val.data(), copyLen, target.val); + std::copy_n(decimalParam.val.data(), copyLen, + target.val); } strLenOrIndArray[i] = sizeof(SQL_NUMERIC_STRUCT); } @@ -2359,13 +2799,17 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } case SQL_C_GUID: { - SQLGUID* guidArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); + SQLGUID* guidArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); + strLenOrIndArray = AllocateParamBufferArray( + tempBuffers, paramSetSize); // Get cached UUID class from module-level helper - // This avoids static object destruction issues during Python finalization - py::object uuid_class = py::module_::import("mssql_python.ddbc_bindings").attr("_get_uuid_class")(); - + // This avoids static object destruction issues during + // Python finalization + py::object uuid_class = + py::module_::import("mssql_python.ddbc_bindings") + .attr("_get_uuid_class")(); for (size_t i = 0; i < paramSetSize; ++i) { const py::handle& element = columnValues[i]; std::array uuid_bytes; @@ -2373,33 +2817,44 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, std::memset(&guidArray[i], 0, sizeof(SQLGUID)); strLenOrIndArray[i] = SQL_NULL_DATA; continue; - } - else if (py::isinstance(element)) { + } else if (py::isinstance(element)) { py::bytes b = element.cast(); if (PyBytes_GET_SIZE(b.ptr()) != 16) { - ThrowStdException("UUID binary data must be exactly 16 bytes 
long."); + ThrowStdException( + "UUID binary data must be exactly " + "16 bytes long."); } - // Secure copy: Fixed 16-byte copy, size validated above - std::copy_n(reinterpret_cast(PyBytes_AS_STRING(b.ptr())), 16, uuid_bytes.data()); - } - else if (py::isinstance(element, uuid_class)) { - py::bytes b = element.attr("bytes_le").cast(); - // Secure copy: Fixed 16-byte copy from UUID bytes_le attribute - std::copy_n(reinterpret_cast(PyBytes_AS_STRING(b.ptr())), 16, uuid_bytes.data()); - } - else { - ThrowStdException(MakeParamMismatchErrorStr(info.paramCType, paramIndex)); + // Secure copy: Fixed 16-byte copy, size validated + // above + std::copy_n(reinterpret_cast( + PyBytes_AS_STRING(b.ptr())), + 16, uuid_bytes.data()); + } else if (py::isinstance(element, uuid_class)) { + py::bytes b = + element.attr("bytes_le").cast(); + // Secure copy: Fixed 16-byte copy from UUID + // bytes_le attribute + std::copy_n(reinterpret_cast( + PyBytes_AS_STRING(b.ptr())), + 16, uuid_bytes.data()); + } else { + ThrowStdException(MakeParamMismatchErrorStr( + info.paramCType, paramIndex)); } - guidArray[i].Data1 = (static_cast(uuid_bytes[3]) << 24) | - (static_cast(uuid_bytes[2]) << 16) | - (static_cast(uuid_bytes[1]) << 8) | - (static_cast(uuid_bytes[0])); - guidArray[i].Data2 = (static_cast(uuid_bytes[5]) << 8) | - (static_cast(uuid_bytes[4])); - guidArray[i].Data3 = (static_cast(uuid_bytes[7]) << 8) | - (static_cast(uuid_bytes[6])); + guidArray[i].Data1 = + (static_cast(uuid_bytes[3]) << 24) | + (static_cast(uuid_bytes[2]) << 16) | + (static_cast(uuid_bytes[1]) << 8) | + (static_cast(uuid_bytes[0])); + guidArray[i].Data2 = + (static_cast(uuid_bytes[5]) << 8) | + (static_cast(uuid_bytes[4])); + guidArray[i].Data3 = + (static_cast(uuid_bytes[7]) << 8) | + (static_cast(uuid_bytes[6])); // Secure copy: Fixed 8-byte copy for GUID Data4 field - std::copy_n(uuid_bytes.data() + 8, 8, guidArray[i].Data4); + std::copy_n(uuid_bytes.data() + 8, 8, + guidArray[i].Data4); strLenOrIndArray[i] = 
sizeof(SQLGUID); } dataPtr = guidArray; @@ -2407,21 +2862,17 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, break; } default: { - ThrowStdException("BindParameterArray: Unsupported C type: " + std::to_string(info.paramCType)); + ThrowStdException( + "BindParameterArray: Unsupported C type: " + + std::to_string(info.paramCType)); } } RETCODE rc = SQLBindParameter_ptr( - hStmt, - static_cast(paramIndex + 1), + hStmt, static_cast(paramIndex + 1), static_cast(info.inputOutputType), static_cast(info.paramCType), - static_cast(info.paramSQLType), - info.columnSize, - info.decimalDigits, - dataPtr, - bufferLength, - strLenOrIndArray - ); + static_cast(info.paramSQLType), info.columnSize, + info.decimalDigits, dataPtr, bufferLength, strLenOrIndArray); if (!SQL_SUCCEEDED(rc)) { LOG("Failed to bind array param {}", paramIndex); return rc; @@ -2431,17 +2882,16 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, LOG("Exception occurred during parameter array binding. Cleaning up."); throw; } - paramBuffers.insert(paramBuffers.end(), tempBuffers.begin(), tempBuffers.end()); + paramBuffers.insert(paramBuffers.end(), tempBuffers.begin(), + tempBuffers.end()); LOG("Finished column-wise parameter array binding."); return SQL_SUCCESS; } -SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, - const std::wstring& query, - const py::list& columnwise_params, - const std::vector& paramInfos, - size_t paramSetSize, - const py::object& encoding_settings = py::none()) { +SQLRETURN SQLExecuteMany_wrap( + const SqlHandlePtr statementHandle, const std::wstring& query, + const py::list& columnwise_params, const std::vector& paramInfos, + size_t paramSetSize, const py::object& encoding_settings = py::none()) { SQLHANDLE hStmt = statementHandle->get(); SQLWCHAR* queryPtr; @@ -2463,10 +2913,12 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, } if (!hasDAE) { std::vector> paramBuffers; - rc = BindParameterArray(hStmt, columnwise_params, paramInfos, paramSetSize, 
paramBuffers, encoding_settings); + rc = BindParameterArray(hStmt, columnwise_params, paramInfos, + paramSetSize, paramBuffers, encoding_settings); if (!SQL_SUCCEEDED(rc)) return rc; - rc = SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_PARAMSET_SIZE, (SQLPOINTER)paramSetSize, 0); + rc = SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_PARAMSET_SIZE, + (SQLPOINTER)paramSetSize, 0); if (!SQL_SUCCEEDED(rc)) return rc; rc = SQLExecute_ptr(hStmt); @@ -2477,7 +2929,9 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, py::list rowParams = columnwise_params[rowIndex]; std::vector> paramBuffers; - rc = BindParameters(hStmt, rowParams, const_cast&>(paramInfos), paramBuffers, encoding_settings); + rc = BindParameters(hStmt, rowParams, + const_cast&>(paramInfos), + paramBuffers, encoding_settings); if (!SQL_SUCCEEDED(rc)) return rc; rc = SQLExecute_ptr(hStmt); @@ -2492,11 +2946,14 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, if (py::isinstance(*py_obj_ptr)) { std::string data = py_obj_ptr->cast(); SQLLEN data_len = static_cast(data.size()); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)data.c_str(), data_len); - } else if (py::isinstance(*py_obj_ptr) || py::isinstance(*py_obj_ptr)) { + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)data.c_str(), + data_len); + } else if (py::isinstance(*py_obj_ptr) || + py::isinstance(*py_obj_ptr)) { std::string data = py_obj_ptr->cast(); SQLLEN data_len = static_cast(data.size()); - rc = SQLPutData_ptr(hStmt, (SQLPOINTER)data.c_str(), data_len); + rc = SQLPutData_ptr(hStmt, (SQLPOINTER)data.c_str(), + data_len); } else { LOG("Unsupported DAE parameter type in row {}", rowIndex); return SQL_ERROR; @@ -2509,7 +2966,6 @@ SQLRETURN SQLExecuteMany_wrap(const SqlHandlePtr statementHandle, } } - // Wrap SQLNumResultCols SQLSMALLINT SQLNumResultCols_wrap(SqlHandlePtr statementHandle) { LOG("Get number of columns in result set"); @@ -2525,7 +2981,8 @@ SQLSMALLINT SQLNumResultCols_wrap(SqlHandlePtr statementHandle) { } // Wrap SQLDescribeCol 
-SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMetadata) { +SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, + py::list& ColumnMetadata) { LOG("Get column description"); if (!SQLDescribeCol_ptr) { LOG("Function pointer not initialized. Loading the driver."); @@ -2549,20 +3006,22 @@ SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMeta SQLSMALLINT Nullable; retcode = SQLDescribeCol_ptr(StatementHandle->get(), i, ColumnName, - sizeof(ColumnName) / sizeof(SQLWCHAR), &NameLength, &DataType, - &ColumnSize, &DecimalDigits, &Nullable); + sizeof(ColumnName) / sizeof(SQLWCHAR), + &NameLength, &DataType, &ColumnSize, + &DecimalDigits, &Nullable); if (SQL_SUCCEEDED(retcode)) { // Append a named py::dict to ColumnMetadata // TODO: Should we define a struct for this task instead of dict? #if defined(__APPLE__) || defined(__linux__) - ColumnMetadata.append(py::dict("ColumnName"_a = SQLWCHARToWString(ColumnName, SQL_NTS), + ColumnMetadata.append(py::dict( + "ColumnName"_a = SQLWCHARToWString(ColumnName, SQL_NTS), #else - ColumnMetadata.append(py::dict("ColumnName"_a = std::wstring(ColumnName), + ColumnMetadata.append(py::dict( + "ColumnName"_a = std::wstring(ColumnName), #endif - "DataType"_a = DataType, "ColumnSize"_a = ColumnSize, - "DecimalDigits"_a = DecimalDigits, - "Nullable"_a = Nullable)); + "DataType"_a = DataType, "ColumnSize"_a = ColumnSize, + "DecimalDigits"_a = DecimalDigits, "Nullable"_a = Nullable)); } else { return retcode; } @@ -2570,51 +3029,52 @@ SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMeta return SQL_SUCCESS; } -SQLRETURN SQLSpecialColumns_wrap(SqlHandlePtr StatementHandle, - SQLSMALLINT identifierType, - const py::object& catalogObj, - const py::object& schemaObj, - const std::wstring& table, - SQLSMALLINT scope, - SQLSMALLINT nullable) { +SQLRETURN SQLSpecialColumns_wrap(SqlHandlePtr StatementHandle, + SQLSMALLINT identifierType, + const py::object& 
catalogObj, + const py::object& schemaObj, + const std::wstring& table, SQLSMALLINT scope, + SQLSMALLINT nullable) { if (!SQLSpecialColumns_ptr) { ThrowStdException("SQLSpecialColumns function not loaded"); } // Convert py::object to std::wstring, treating None as empty string - std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); - std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + std::wstring catalog = + catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = + schemaObj.is_none() ? L"" : schemaObj.cast(); #if defined(__APPLE__) || defined(__linux__) // Unix implementation std::vector catalogBuf = WStringToSQLWCHAR(catalog); std::vector schemaBuf = WStringToSQLWCHAR(schema); std::vector tableBuf = WStringToSQLWCHAR(table); - - return SQLSpecialColumns_ptr( - StatementHandle->get(), - identifierType, - catalog.empty() ? nullptr : catalogBuf.data(), - catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : schemaBuf.data(), - schema.empty() ? 0 : SQL_NTS, - table.empty() ? nullptr : tableBuf.data(), - table.empty() ? 0 : SQL_NTS, - scope, - nullable); + + return SQLSpecialColumns_ptr(StatementHandle->get(), identifierType, + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS, scope, nullable); #else // Windows implementation return SQLSpecialColumns_ptr( - StatementHandle->get(), - identifierType, - catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + StatementHandle->get(), identifierType, + catalog.empty() + ? nullptr + : const_cast( + reinterpret_cast(catalog.c_str())), catalog.empty() ? 0 : SQL_NTS, - schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? nullptr + : const_cast( + reinterpret_cast(schema.c_str())), schema.empty() ? 0 : SQL_NTS, - table.empty() ? 
nullptr : (SQLWCHAR*)table.c_str(), - table.empty() ? 0 : SQL_NTS, - scope, - nullable); + table.empty() ? nullptr + : const_cast( + reinterpret_cast(table.c_str())), + table.empty() ? 0 : SQL_NTS, scope, nullable); #endif } @@ -2629,12 +3089,9 @@ SQLRETURN SQLFetch_wrap(SqlHandlePtr StatementHandle) { return SQLFetch_ptr(StatementHandle->get()); } -static py::object FetchLobColumnData(SQLHSTMT hStmt, - SQLUSMALLINT colIndex, - SQLSMALLINT cType, - bool isWideChar, - bool isBinary, - const std::string& char_encoding = "utf-8") { +static py::object FetchLobColumnData( + SQLHSTMT hStmt, SQLUSMALLINT colIndex, SQLSMALLINT cType, bool isWideChar, + bool isBinary, const std::string& char_encoding = "utf-8") { std::vector buffer; SQLRETURN ret = SQL_SUCCESS_WITH_INFO; int loopCount = 0; @@ -2643,18 +3100,14 @@ static py::object FetchLobColumnData(SQLHSTMT hStmt, ++loopCount; std::vector chunk(DAE_CHUNK_SIZE, 0); SQLLEN actualRead = 0; - ret = SQLGetData_ptr(hStmt, - colIndex, - cType, - chunk.data(), - DAE_CHUNK_SIZE, - &actualRead); - - if (ret == SQL_ERROR || !SQL_SUCCEEDED(ret) && ret != SQL_SUCCESS_WITH_INFO) { + ret = SQLGetData_ptr(hStmt, colIndex, cType, chunk.data(), + DAE_CHUNK_SIZE, &actualRead); + + if (ret == SQL_ERROR || + (!SQL_SUCCEEDED(ret) && ret != SQL_SUCCESS_WITH_INFO)) { std::ostringstream oss; oss << "Error fetching LOB for column " << colIndex - << ", cType=" << cType - << ", loop=" << loopCount + << ", cType=" << cType << ", loop=" << loopCount << ", SQLGetData return=" << ret; LOG(oss.str()); ThrowStdException(oss.str()); @@ -2689,20 +3142,23 @@ static py::object FetchLobColumnData(SQLHSTMT hStmt, // Wide characters size_t wcharSize = sizeof(SQLWCHAR); if (bytesRead >= wcharSize) { - auto sqlwBuf = reinterpret_cast(chunk.data()); + auto sqlwBuf = + reinterpret_cast(chunk.data()); size_t wcharCount = bytesRead / wcharSize; while (wcharCount > 0 && sqlwBuf[wcharCount - 1] == 0) { --wcharCount; bytesRead -= wcharSize; } if (bytesRead < 
DAE_CHUNK_SIZE) { - LOG("Loop {}: Trimmed null terminator (wide)", loopCount); + LOG("Loop {}: Trimmed null terminator (wide)", + loopCount); } } } } if (bytesRead > 0) { - buffer.insert(buffer.end(), chunk.begin(), chunk.begin() + bytesRead); + buffer.insert(buffer.end(), chunk.begin(), + chunk.begin() + bytesRead); LOG("Loop {}: Appended {} bytes", loopCount, bytesRead); } if (ret == SQL_SUCCESS) { @@ -2720,13 +3176,15 @@ static py::object FetchLobColumnData(SQLHSTMT hStmt, } if (isWideChar) { #if defined(_WIN32) - std::wstring wstr(reinterpret_cast(buffer.data()), buffer.size() / sizeof(wchar_t)); + std::wstring wstr(reinterpret_cast(buffer.data()), + buffer.size() / sizeof(wchar_t)); std::string utf8str = WideToUTF8(wstr); return py::str(utf8str); #else // Linux/macOS handling size_t wcharCount = buffer.size() / sizeof(SQLWCHAR); - const SQLWCHAR* sqlwBuf = reinterpret_cast(buffer.data()); + const SQLWCHAR* sqlwBuf = + reinterpret_cast(buffer.data()); std::wstring wstr = SQLWCHARToWString(sqlwBuf, wcharCount); std::string utf8str = WideToUTF8(wstr); return py::str(utf8str); @@ -2736,35 +3194,39 @@ static py::object FetchLobColumnData(SQLHSTMT hStmt, LOG("FetchLobColumnData: Returning binary of {} bytes", buffer.size()); return py::bytes(buffer.data(), buffer.size()); } - + // SQL_C_CHAR handling with dynamic encoding if (cType == SQL_C_CHAR && !char_encoding.empty()) { try { - py::str decoded_str = DecodingString( - buffer.data(), - buffer.size(), - char_encoding, - "strict" - ); - LOG("FetchLobColumnData: Applied dynamic decoding for LOB using encoding '{}'", char_encoding); + py::str decoded_str = DecodingString(buffer.data(), buffer.size(), + char_encoding, "strict"); + LOG("FetchLobColumnData: Applied dynamic decoding for LOB " + "using encoding '{}'", + char_encoding); return decoded_str; } catch (const std::exception& e) { - LOG("FetchLobColumnData: Dynamic decoding failed: {}. 
Using fallback.", e.what()); + LOG("FetchLobColumnData: Dynamic decoding failed: {}. " + "Using fallback.", + e.what()); // Fallback to original logic } } - + // Fallback: original behavior for SQL_C_CHAR std::string str(buffer.data(), buffer.size()); - LOG("FetchLobColumnData: Returning narrow string of length {}", str.length()); + LOG("FetchLobColumnData: Returning narrow string of length {}", + str.length()); return py::str(str); } // Helper function to retrieve column data -SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, py::list& row, +SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, + py::list& row, const std::string& char_encoding = "utf-8", const std::string& wchar_encoding = "utf-16le") { - UNREFERENCED_PARAMETER(wchar_encoding); // SQL_WCHAR behavior unchanged, keeping parameter for API consistency + UNREFERENCED_PARAMETER(wchar_encoding); // SQL_WCHAR behavior unchanged, + // keeping parameter for API + // consistency LOG("Get data from columns"); if (!SQLGetData_ptr) { LOG("Function pointer not initialized. 
Loading the driver."); @@ -2781,10 +3243,13 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p SQLSMALLINT decimalDigits; SQLSMALLINT nullable; - ret = SQLDescribeCol_ptr(hStmt, i, columnName, sizeof(columnName) / sizeof(SQLWCHAR), - &columnNameLen, &dataType, &columnSize, &decimalDigits, &nullable); + ret = SQLDescribeCol_ptr( + hStmt, i, columnName, sizeof(columnName) / sizeof(SQLWCHAR), + &columnNameLen, &dataType, &columnSize, &decimalDigits, &nullable); if (!SQL_SUCCEEDED(ret)) { - LOG("Error retrieving data for column - {}, SQLDescribeCol return code - {}", i, ret); + LOG("Error retrieving data for column - {}, SQLDescribeCol " + "return code - {}", + i, ret); row.append(py::none()); continue; } @@ -2793,15 +3258,19 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > SQL_MAX_LOB_SIZE) { + if (columnSize == SQL_NO_TOTAL || columnSize == 0 || + columnSize > SQL_MAX_LOB_SIZE) { LOG("Streaming LOB for column {}", i); - row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false, char_encoding)); + row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, + false, char_encoding)); } else { - uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; + uint64_t fetchBufferSize = + columnSize + 1 /* null-termination */; std::vector dataBuffer(fetchBufferSize); SQLLEN dataLen; - ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, dataBuffer.data(), dataBuffer.size(), - &dataLen); + ret = + SQLGetData_ptr(hStmt, i, SQL_C_CHAR, dataBuffer.data(), + dataBuffer.size(), &dataLen); if (SQL_SUCCEEDED(ret)) { // columnSize is in chars, dataLen is in bytes if (dataLen > 0) { @@ -2809,28 +3278,38 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p if (numCharsInData < dataBuffer.size()) { // Use dynamic decoding for SQL_CHAR types try { - py::str 
decoded_str = DecodingString( - reinterpret_cast(dataBuffer.data()), - numCharsInData, - char_encoding, - "strict" - ); + py::str decoded_str = + DecodingString(reinterpret_cast( + dataBuffer.data()), + numCharsInData, + char_encoding, "strict"); row.append(decoded_str); - LOG("Applied dynamic decoding for CHAR column {} using encoding '{}'", i, char_encoding); + LOG("Applied dynamic decoding for CHAR " + "column {} using encoding '{}'", + i, char_encoding); } catch (const std::exception& e) { - LOG("Dynamic decoding failed for column {}: {}. Using fallback.", i, e.what()); + LOG("Dynamic decoding failed for column " + "{}: {}. Using fallback.", + i, e.what()); // Fallback to platform-specific handling - #if defined(__APPLE__) || defined(__linux__) - std::string fullStr(reinterpret_cast(dataBuffer.data())); +#if defined(__APPLE__) || defined(__linux__) + std::string fullStr(reinterpret_cast( + dataBuffer.data())); row.append(fullStr); - #else - row.append(std::string(reinterpret_cast(dataBuffer.data()))); - #endif +#else + row.append( + std::string(reinterpret_cast( + dataBuffer.data()))); +#endif } } else { // Buffer too small, fallback to streaming - LOG("CHAR column {} data truncated, using streaming LOB", i); - row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false, char_encoding)); + LOG("CHAR column {} data truncated, " + "using streaming LOB", + i); + row.append(FetchLobColumnData( + hStmt, i, SQL_C_CHAR, false, false, + char_encoding)); } } else if (dataLen == SQL_NULL_DATA) { LOG("Column {} is NULL (CHAR)", i); @@ -2838,28 +3317,38 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p } else if (dataLen == 0) { row.append(py::str("")); } else if (dataLen == SQL_NO_TOTAL) { - LOG("SQLGetData couldn't determine the length of the data. " - "Returning NULL value instead. Column ID - {}, Data Type - {}", i, dataType); + LOG("SQLGetData couldn't determine the length of " + "the " + "data. Returning NULL value instead. 
Column ID " + "- {}, " + "Data Type - {}", + i, dataType); row.append(py::none()); } else if (dataLen < 0) { - LOG("SQLGetData returned an unexpected negative data length. " - "Raising exception. Column ID - {}, Data Type - {}, Data Length - {}", + LOG("SQLGetData returned an unexpected negative " + "data " + "length. Raising exception. Column ID - {}, " + "Data Type - {}, Data Length - {}", i, dataType, dataLen); - ThrowStdException("SQLGetData returned an unexpected negative data length"); + ThrowStdException( + "SQLGetData returned an unexpected " + "negative data length"); } } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", + LOG("Error retrieving data for column - {}, data type " + "- " + "{}, SQLGetData return code - {}. Returning NULL " + "value instead", i, dataType, ret); row.append(py::none()); } } break; } - case SQL_SS_XML: - { + case SQL_SS_XML: { LOG("Streaming XML for column {}", i); - row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false)); + row.append( + FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false)); break; } case SQL_WCHAR: @@ -2867,30 +3356,45 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_WLONGVARCHAR: { if (columnSize == SQL_NO_TOTAL || columnSize > 4000) { LOG("Streaming LOB for column {} (NVARCHAR)", i); - row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false)); + row.append( + FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false)); } else { - uint64_t fetchBufferSize = (columnSize + 1) * sizeof(SQLWCHAR); // +1 for null terminator + uint64_t fetchBufferSize = + (columnSize + 1) * + sizeof(SQLWCHAR); // +1 for null terminator std::vector dataBuffer(columnSize + 1); SQLLEN dataLen; - ret = SQLGetData_ptr(hStmt, i, SQL_C_WCHAR, dataBuffer.data(), fetchBufferSize, &dataLen); + ret = + SQLGetData_ptr(hStmt, i, SQL_C_WCHAR, dataBuffer.data(), + fetchBufferSize, &dataLen); if 
(SQL_SUCCEEDED(ret)) { if (dataLen > 0) { - uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); + uint64_t numCharsInData = + dataLen / sizeof(SQLWCHAR); if (numCharsInData < dataBuffer.size()) { #if defined(__APPLE__) || defined(__linux__) - const SQLWCHAR* sqlwBuf = reinterpret_cast(dataBuffer.data()); - std::wstring wstr = SQLWCHARToWString(sqlwBuf, numCharsInData); + const SQLWCHAR* sqlwBuf = + reinterpret_cast( + dataBuffer.data()); + std::wstring wstr = + SQLWCHARToWString(sqlwBuf, numCharsInData); std::string utf8str = WideToUTF8(wstr); row.append(py::str(utf8str)); #else - std::wstring wstr(reinterpret_cast(dataBuffer.data())); + std::wstring wstr(reinterpret_cast( + dataBuffer.data())); row.append(py::cast(wstr)); #endif - LOG("Appended NVARCHAR string of length {} to result row", numCharsInData); - } else { + LOG("Appended NVARCHAR string of length {} " + "to result row", + numCharsInData); + } else { // Buffer too small, fallback to streaming - LOG("NVARCHAR column {} data truncated, using streaming LOB", i); - row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false)); + LOG("NVARCHAR column {} data truncated, " + "using streaming LOB", + i); + row.append(FetchLobColumnData( + hStmt, i, SQL_C_WCHAR, true, false)); } } else if (dataLen == SQL_NULL_DATA) { LOG("Column {} is NULL (CHAR)", i); @@ -2898,16 +3402,25 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p } else if (dataLen == 0) { row.append(py::str("")); } else if (dataLen == SQL_NO_TOTAL) { - LOG("SQLGetData couldn't determine the length of the NVARCHAR data. Returning NULL. Column ID - {}", i); + LOG("SQLGetData couldn't determine the length of " + "the NVARCHAR data. Returning NULL. " + "Column ID - {}", + i); row.append(py::none()); } else if (dataLen < 0) { - LOG("SQLGetData returned an unexpected negative data length. " - "Raising exception. 
Column ID - {}, Data Type - {}, Data Length - {}", + LOG("SQLGetData returned an unexpected negative " + "data " + "length. Raising exception. Column ID - {}, " + "Data Type - {}, Data Length - {}", i, dataType, dataLen); - ThrowStdException("SQLGetData returned an unexpected negative data length"); + ThrowStdException( + "SQLGetData returned an unexpected " + "negative data length"); } } else { - LOG("Error retrieving data for column {} (NVARCHAR), SQLGetData return code {}", i, ret); + LOG("Error retrieving data for column {} (NVARCHAR), " + "SQLGetData return code {}", + i, ret); row.append(py::none()); } } @@ -2925,12 +3438,14 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p } case SQL_SMALLINT: { SQLSMALLINT smallIntValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_SHORT, &smallIntValue, 0, NULL); + ret = SQLGetData_ptr(hStmt, i, SQL_C_SHORT, &smallIntValue, 0, + NULL); if (SQL_SUCCEEDED(ret)) { row.append(static_cast(smallIntValue)); } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", + LOG("Error retrieving data for column - {}, " + "data type - {}, SQLGetData return code - {}. " + "Returning NULL value instead", i, dataType, ret); row.append(py::none()); } @@ -2938,12 +3453,14 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p } case SQL_REAL: { SQLREAL realValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_FLOAT, &realValue, 0, NULL); + ret = + SQLGetData_ptr(hStmt, i, SQL_C_FLOAT, &realValue, 0, NULL); if (SQL_SUCCEEDED(ret)) { row.append(realValue); } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", + LOG("Error retrieving data for column - {}, " + "data type - {}, SQLGetData return code - {}. 
" + "Returning NULL value instead", i, dataType, ret); row.append(py::none()); } @@ -2954,39 +3471,50 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p SQLCHAR numericStr[MAX_DIGITS_IN_NUMERIC] = {0}; SQLLEN indicator = 0; - ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, numericStr, sizeof(numericStr), &indicator); + ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, numericStr, + sizeof(numericStr), &indicator); if (SQL_SUCCEEDED(ret)) { try { - // Validate 'indicator' to avoid buffer overflow and fallback to a safe - // null-terminated read when length is unknown or out-of-range. - const char* cnum = reinterpret_cast(numericStr); + // Validate 'indicator' to avoid buffer overflow and + // fallback to a safe null-terminated read when length + // is unknown or out-of-range. + const char* cnum = + reinterpret_cast(numericStr); size_t bufSize = sizeof(numericStr); size_t safeLen = 0; - if (indicator > 0 && indicator <= static_cast(bufSize)) { - // indicator appears valid and within the buffer size + if (indicator > 0 && + indicator <= static_cast(bufSize)) { + // indicator appears valid and within the buffer + // size safeLen = static_cast(indicator); } else { - // indicator is unknown, zero, negative, or too large; determine length - // by searching for a terminating null (safe bounded scan) + // indicator is unknown, zero, negative, or too + // large; determine length by searching for a + // terminating null (safe bounded scan) for (size_t j = 0; j < bufSize; ++j) { if (cnum[j] == '\0') { safeLen = j; break; } } - // if no null found, use the full buffer size as a conservative fallback - if (safeLen == 0 && bufSize > 0 && cnum[0] != '\0') { + // if no null found, use the full buffer size as a + // conservative fallback + if (safeLen == 0 && bufSize > 0 && + cnum[0] != '\0') { safeLen = bufSize; } } - // Use the validated length to construct the string for Decimal + // Use the validated length to construct the string for + // Decimal 
std::string numStr(cnum, safeLen); // Create Python Decimal object - py::object decimalObj = py::module_::import("decimal").attr("Decimal")(numStr); + py::object decimalObj = + py::module_::import("decimal").attr("Decimal")( + numStr); // Add to row row.append(decimalObj); @@ -2995,9 +3523,9 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p LOG("Error converting to decimal: {}", e.what()); row.append(py::none()); } - } - else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " + } else { + LOG("Error retrieving data for column - {}, data type - " + "{}, SQLGetData return " "code - {}. Returning NULL value instead", i, dataType, ret); row.append(py::none()); @@ -3008,11 +3536,13 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_DOUBLE: case SQL_FLOAT: { SQLDOUBLE doubleValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_DOUBLE, &doubleValue, 0, NULL); + ret = SQLGetData_ptr(hStmt, i, SQL_C_DOUBLE, &doubleValue, 0, + NULL); if (SQL_SUCCEEDED(ret)) { row.append(doubleValue); } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " + LOG("Error retrieving data for column - {}, data type - " + "{}, SQLGetData return " "code - {}. Returning NULL value instead", i, dataType, ret); row.append(py::none()); @@ -3021,11 +3551,13 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p } case SQL_BIGINT: { SQLBIGINT bigintValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_SBIGINT, &bigintValue, 0, NULL); + ret = SQLGetData_ptr(hStmt, i, SQL_C_SBIGINT, &bigintValue, 0, + NULL); if (SQL_SUCCEEDED(ret)) { - row.append(static_cast(bigintValue)); + row.append(static_cast(bigintValue)); } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " + LOG("Error retrieving data for column - {}, data type - " + "{}, SQLGetData return " "code - {}. 
Returning NULL value instead", i, dataType, ret); row.append(py::none()); @@ -3034,18 +3566,16 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p } case SQL_TYPE_DATE: { SQL_DATE_STRUCT dateValue; - ret = - SQLGetData_ptr(hStmt, i, SQL_C_TYPE_DATE, &dateValue, sizeof(dateValue), NULL); + ret = SQLGetData_ptr(hStmt, i, SQL_C_TYPE_DATE, &dateValue, + sizeof(dateValue), NULL); if (SQL_SUCCEEDED(ret)) { - row.append( - py::module_::import("datetime").attr("date")( - dateValue.year, - dateValue.month, - dateValue.day - ) - ); + row.append(py::module_::import("datetime") + .attr("date")(dateValue.year, + dateValue.month, + dateValue.day)); } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " + LOG("Error retrieving data for column - {}, data type - " + "{}, SQLGetData return " "code - {}. Returning NULL value instead", i, dataType, ret); row.append(py::none()); @@ -3056,18 +3586,16 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_TYPE_TIME: case SQL_SS_TIME2: { SQL_TIME_STRUCT timeValue; - ret = - SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIME, &timeValue, sizeof(timeValue), NULL); + ret = SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIME, &timeValue, + sizeof(timeValue), NULL); if (SQL_SUCCEEDED(ret)) { - row.append( - py::module_::import("datetime").attr("time")( - timeValue.hour, - timeValue.minute, - timeValue.second - ) - ); + row.append(py::module_::import("datetime") + .attr("time")(timeValue.hour, + timeValue.minute, + timeValue.second)); } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " + LOG("Error retrieving data for column - {}, data type - " + "{}, SQLGetData return " "code - {}. 
Returning NULL value instead", i, dataType, ret); row.append(py::none()); @@ -3078,22 +3606,21 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { SQL_TIMESTAMP_STRUCT timestampValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIMESTAMP, &timestampValue, - sizeof(timestampValue), NULL); + ret = SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIMESTAMP, + &timestampValue, sizeof(timestampValue), + NULL); if (SQL_SUCCEEDED(ret)) { row.append( - py::module_::import("datetime").attr("datetime")( - timestampValue.year, - timestampValue.month, - timestampValue.day, - timestampValue.hour, - timestampValue.minute, - timestampValue.second, - timestampValue.fraction / 1000 // Convert back ns to µs - ) - ); + py::module_::import("datetime") + .attr("datetime")( + timestampValue.year, timestampValue.month, + timestampValue.day, timestampValue.hour, + timestampValue.minute, timestampValue.second, + timestampValue.fraction / + 1000)); // Convert back ns to µs } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " + LOG("Error retrieving data for column - {}, data type - " + "{}, SQLGetData return " "code - {}. 
Returning NULL value instead", i, dataType, ret); row.append(py::none()); @@ -3103,48 +3630,39 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_SS_TIMESTAMPOFFSET: { DateTimeOffset dtoValue; SQLLEN indicator; - ret = SQLGetData_ptr( - hStmt, - i, SQL_C_SS_TIMESTAMPOFFSET, - &dtoValue, - sizeof(dtoValue), - &indicator - ); + ret = SQLGetData_ptr(hStmt, i, SQL_C_SS_TIMESTAMPOFFSET, + &dtoValue, sizeof(dtoValue), &indicator); if (SQL_SUCCEEDED(ret) && indicator != SQL_NULL_DATA) { - LOG("[Fetch] Retrieved DTO: {}-{}-{} {}:{}:{}, fraction(ns)={}, tz_hour={}, tz_minute={}", + LOG("[Fetch] Retrieved DTO: {}-{}-{} {}:{}:{}, " + "fraction(ns)={}, tz_hour={}, tz_minute={}", dtoValue.year, dtoValue.month, dtoValue.day, dtoValue.hour, dtoValue.minute, dtoValue.second, - dtoValue.fraction, - dtoValue.timezone_hour, dtoValue.timezone_minute - ); + dtoValue.fraction, dtoValue.timezone_hour, + dtoValue.timezone_minute); - int totalMinutes = dtoValue.timezone_hour * 60 + dtoValue.timezone_minute; + int totalMinutes = + dtoValue.timezone_hour * 60 + dtoValue.timezone_minute; // Validating offset if (totalMinutes < -24 * 60 || totalMinutes > 24 * 60) { std::ostringstream oss; - oss << "Invalid timezone offset from SQL_SS_TIMESTAMPOFFSET_STRUCT: " + oss << "Invalid timezone offset from " + "SQL_SS_TIMESTAMPOFFSET_STRUCT: " << totalMinutes << " minutes for column " << i; ThrowStdException(oss.str()); } // Convert fraction from ns to µs int microseconds = dtoValue.fraction / 1000; py::object datetime = py::module_::import("datetime"); - py::object tzinfo = datetime.attr("timezone")( - datetime.attr("timedelta")(py::arg("minutes") = totalMinutes) - ); + py::object tzinfo = datetime.attr("timezone")(datetime.attr( + "timedelta")(py::arg("minutes") = totalMinutes)); py::object py_dt = datetime.attr("datetime")( - dtoValue.year, - dtoValue.month, - dtoValue.day, - dtoValue.hour, - dtoValue.minute, - dtoValue.second, - microseconds, - tzinfo - ); 
+ dtoValue.year, dtoValue.month, dtoValue.day, + dtoValue.hour, dtoValue.minute, dtoValue.second, + microseconds, tzinfo); row.append(py_dt); } else { - LOG("Error fetching DATETIMEOFFSET for column {}, ret={}", i, ret); + LOG("Error fetching DATETIMEOFFSET for column {}, ret={}", + i, ret); row.append(py::none()); } break; @@ -3152,23 +3670,34 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: { - // Use streaming for large VARBINARY (columnSize unknown or > 8000) - if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > 8000) { + // Use streaming for large VARBINARY (columnSize unknown or + // > 8000) + if (columnSize == SQL_NO_TOTAL || columnSize == 0 || + columnSize > 8000) { LOG("Streaming LOB for column {} (VARBINARY)", i); - row.append(FetchLobColumnData(hStmt, i, SQL_C_BINARY, false, true)); + row.append(FetchLobColumnData(hStmt, i, SQL_C_BINARY, false, + true)); } else { // Small VARBINARY, fetch directly std::vector dataBuffer(columnSize); SQLLEN dataLen; - ret = SQLGetData_ptr(hStmt, i, SQL_C_BINARY, dataBuffer.data(), columnSize, &dataLen); + ret = + SQLGetData_ptr(hStmt, i, SQL_C_BINARY, + dataBuffer.data(), columnSize, &dataLen); if (SQL_SUCCEEDED(ret)) { if (dataLen > 0) { if (static_cast(dataLen) <= columnSize) { - row.append(py::bytes(reinterpret_cast(dataBuffer.data()), dataLen)); + row.append( + py::bytes(reinterpret_cast( + dataBuffer.data()), + dataLen)); } else { - LOG("VARBINARY column {} data truncated, using streaming LOB", i); - row.append(FetchLobColumnData(hStmt, i, SQL_C_BINARY, false, true)); + LOG("VARBINARY column {} data truncated, " + "using streaming LOB", + i); + row.append(FetchLobColumnData( + hStmt, i, SQL_C_BINARY, false, true)); } } else if (dataLen == SQL_NULL_DATA) { row.append(py::none()); @@ -3176,13 +3705,17 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p row.append(py::bytes("")); 
} else { std::ostringstream oss; - oss << "Unexpected negative length (" << dataLen << ") returned by SQLGetData. ColumnID=" - << i << ", dataType=" << dataType << ", bufferSize=" << columnSize; + oss << "Unexpected negative length (" << dataLen + << ") returned by SQLGetData. ColumnID=" << i + << ", dataType=" << dataType + << ", bufferSize=" << columnSize; LOG("Error: {}", oss.str()); ThrowStdException(oss.str()); } } else { - LOG("Error retrieving VARBINARY data for column {}. SQLGetData rc = {}", i, ret); + LOG("Error retrieving VARBINARY data for column {}. " + "SQLGetData rc = {}", + i, ret); row.append(py::none()); } } @@ -3190,12 +3723,14 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p } case SQL_TINYINT: { SQLCHAR tinyIntValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_TINYINT, &tinyIntValue, 0, NULL); + ret = SQLGetData_ptr(hStmt, i, SQL_C_TINYINT, &tinyIntValue, 0, + NULL); if (SQL_SUCCEEDED(ret)) { row.append(static_cast(tinyIntValue)); } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", + LOG("Error retrieving data for column - {}, data type - " + "{}, SQLGetData return code - {}. Returning NULL " + "value instead", i, dataType, ret); row.append(py::none()); } @@ -3207,8 +3742,9 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p if (SQL_SUCCEEDED(ret)) { row.append(static_cast(bitValue)); } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", + LOG("Error retrieving data for column - {}, data type - " + "{}, SQLGetData return code - {}. 
Returning NULL " + "value instead", i, dataType, ret); row.append(py::none()); } @@ -3218,30 +3754,43 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_GUID: { SQLGUID guidValue; SQLLEN indicator; - ret = SQLGetData_ptr(hStmt, i, SQL_C_GUID, &guidValue, sizeof(guidValue), &indicator); + ret = SQLGetData_ptr(hStmt, i, SQL_C_GUID, &guidValue, + sizeof(guidValue), &indicator); if (SQL_SUCCEEDED(ret) && indicator != SQL_NULL_DATA) { std::vector guid_bytes(16); - guid_bytes[0] = ((char*)&guidValue.Data1)[3]; - guid_bytes[1] = ((char*)&guidValue.Data1)[2]; - guid_bytes[2] = ((char*)&guidValue.Data1)[1]; - guid_bytes[3] = ((char*)&guidValue.Data1)[0]; - guid_bytes[4] = ((char*)&guidValue.Data2)[1]; - guid_bytes[5] = ((char*)&guidValue.Data2)[0]; - guid_bytes[6] = ((char*)&guidValue.Data3)[1]; - guid_bytes[7] = ((char*)&guidValue.Data3)[0]; + guid_bytes[0] = + reinterpret_cast(&guidValue.Data1)[3]; + guid_bytes[1] = + reinterpret_cast(&guidValue.Data1)[2]; + guid_bytes[2] = + reinterpret_cast(&guidValue.Data1)[1]; + guid_bytes[3] = + reinterpret_cast(&guidValue.Data1)[0]; + guid_bytes[4] = + reinterpret_cast(&guidValue.Data2)[1]; + guid_bytes[5] = + reinterpret_cast(&guidValue.Data2)[0]; + guid_bytes[6] = + reinterpret_cast(&guidValue.Data3)[1]; + guid_bytes[7] = + reinterpret_cast(&guidValue.Data3)[0]; // Secure copy: Fixed 8-byte copy for GUID Data4 field - std::copy_n(guidValue.Data4, sizeof(guidValue.Data4), &guid_bytes[8]); + std::copy_n(guidValue.Data4, sizeof(guidValue.Data4), + &guid_bytes[8]); - py::bytes py_guid_bytes(guid_bytes.data(), guid_bytes.size()); + py::bytes py_guid_bytes(guid_bytes.data(), + guid_bytes.size()); py::object uuid_module = py::module_::import("uuid"); - py::object uuid_obj = uuid_module.attr("UUID")(py::arg("bytes")=py_guid_bytes); + py::object uuid_obj = uuid_module.attr("UUID")( + py::arg("bytes") = py_guid_bytes); row.append(uuid_obj); } else if (indicator == SQL_NULL_DATA) { 
row.append(py::none()); } else { - LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " - "code - {}. Returning NULL value instead", + LOG("Error retrieving data for column - {}, data type - " + "{}, SQLGetData return code - {}. Returning NULL " + "value instead", i, dataType, ret); row.append(py::none()); } @@ -3250,8 +3799,9 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p #endif default: std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName << ", Type - " - << dataType << ", column ID - " << i; + errorString << "Unsupported data type for column - " + << columnName << ", Type - " << dataType + << ", column ID - " << i; LOG(errorString.str()); ThrowStdException(errorString.str()); break; @@ -3260,36 +3810,41 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p return ret; } -SQLRETURN SQLFetchScroll_wrap(SqlHandlePtr StatementHandle, SQLSMALLINT FetchOrientation, SQLLEN FetchOffset, py::list& row_data) { - LOG("Fetching with scroll: orientation={}, offset={}", FetchOrientation, FetchOffset); +SQLRETURN SQLFetchScroll_wrap(SqlHandlePtr StatementHandle, + SQLSMALLINT FetchOrientation, SQLLEN FetchOffset, + py::list& row_data) { + LOG("Fetching with scroll: orientation={}, offset={}", FetchOrientation, + FetchOffset); if (!SQLFetchScroll_ptr) { LOG("Function pointer not initialized. 
Loading the driver."); DriverLoader::getInstance().loadDriver(); // Load the driver } - // Unbind any columns from previous fetch operations to avoid memory corruption + // Unbind any columns from previous fetch operations to avoid memory + // corruption SQLFreeStmt_ptr(StatementHandle->get(), SQL_UNBIND); - + // Perform scroll operation - SQLRETURN ret = SQLFetchScroll_ptr(StatementHandle->get(), FetchOrientation, FetchOffset); - + SQLRETURN ret = SQLFetchScroll_ptr(StatementHandle->get(), FetchOrientation, + FetchOffset); + // If successful and caller wants data, retrieve it if (SQL_SUCCEEDED(ret) && row_data.size() == 0) { // Get column count SQLSMALLINT colCount = SQLNumResultCols_wrap(StatementHandle); - + // Get the data in a consistent way with other fetch methods ret = SQLGetData_wrap(StatementHandle, colCount, row_data); } - + return ret; } - // For column in the result set, binds a buffer to retrieve column data // TODO: Move to anonymous namespace, since it is not used outside this file -SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames, - SQLUSMALLINT numCols, int fetchSize) { +SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, + py::list& columnNames, SQLUSMALLINT numCols, + int fetchSize) { SQLRETURN ret = SQL_SUCCESS; // Bind columns based on their data types for (SQLUSMALLINT col = 1; col <= numCols; col++) { @@ -3301,20 +3856,25 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - // TODO: handle variable length data correctly. This logic wont suffice + // TODO: handle variable length data correctly. This logic + // wont suffice HandleZeroColumnSizeAtFetch(columnSize); uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; - // TODO: For LONGVARCHAR/BINARY types, columnSize is returned as 2GB-1 by - // SQLDescribeCol. So fetchBufferSize = 2GB. fetchSize=1 if columnSize>1GB. 
- // So we'll allocate a vector of size 2GB. If a query fetches multiple (say N) - // LONG... columns, we will have allocated multiple (N) 2GB sized vectors. This - // will make driver very slow. And if the N is high enough, we could hit the OS - // limit for heap memory that we can allocate, & hence get a std::bad_alloc. The - // process could also be killed by OS for consuming too much memory. - // Hence this will be revisited in beta to not allocate 2GB+ memory, - // & use streaming instead - buffers.charBuffers[col - 1].resize(fetchSize * fetchBufferSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(), + // TODO: For LONGVARCHAR/BINARY types, columnSize is returned + // as 2GB-1 by SQLDescribeCol. So fetchBufferSize = 2GB. + // fetchSize=1 if columnSize>1GB. So we'll allocate a vector + // of size 2GB. If a query fetches multiple (say N) LONG... + // columns, we will have allocated multiple (N) 2GB sized + // vectors. This will make driver very slow. And if the N is + // high enough, we could hit the OS limit for heap memory that + // we can allocate, & hence get a std::bad_alloc. The process + // could also be killed by OS for consuming too much memory. + // Hence this will be revisited in beta to not allocate 2GB+ + // memory, & use streaming instead + buffers.charBuffers[col - 1].resize(fetchSize * + fetchBufferSize); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, + buffers.charBuffers[col - 1].data(), fetchBufferSize * sizeof(SQLCHAR), buffers.indicators[col - 1].data()); break; @@ -3322,118 +3882,143 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR: { - // TODO: handle variable length data correctly. This logic wont suffice + // TODO: handle variable length data correctly. 
This logic + // wont suffice HandleZeroColumnSizeAtFetch(columnSize); uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; - buffers.wcharBuffers[col - 1].resize(fetchSize * fetchBufferSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_WCHAR, buffers.wcharBuffers[col - 1].data(), + buffers.wcharBuffers[col - 1].resize(fetchSize * + fetchBufferSize); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_WCHAR, + buffers.wcharBuffers[col - 1].data(), fetchBufferSize * sizeof(SQLWCHAR), buffers.indicators[col - 1].data()); break; } case SQL_INTEGER: buffers.intBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_SLONG, buffers.intBuffers[col - 1].data(), - sizeof(SQLINTEGER), buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr( + hStmt, col, SQL_C_SLONG, buffers.intBuffers[col - 1].data(), + sizeof(SQLINTEGER), buffers.indicators[col - 1].data()); break; case SQL_SMALLINT: buffers.smallIntBuffers[col - 1].resize(fetchSize); ret = SQLBindCol_ptr(hStmt, col, SQL_C_SSHORT, - buffers.smallIntBuffers[col - 1].data(), sizeof(SQLSMALLINT), + buffers.smallIntBuffers[col - 1].data(), + sizeof(SQLSMALLINT), buffers.indicators[col - 1].data()); break; case SQL_TINYINT: buffers.charBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_TINYINT, buffers.charBuffers[col - 1].data(), - sizeof(SQLCHAR), buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_TINYINT, + buffers.charBuffers[col - 1].data(), + sizeof(SQLCHAR), + buffers.indicators[col - 1].data()); break; case SQL_BIT: buffers.charBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_BIT, buffers.charBuffers[col - 1].data(), - sizeof(SQLCHAR), buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr( + hStmt, col, SQL_C_BIT, buffers.charBuffers[col - 1].data(), + sizeof(SQLCHAR), buffers.indicators[col - 1].data()); break; case SQL_REAL: buffers.realBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, 
SQL_C_FLOAT, buffers.realBuffers[col - 1].data(), - sizeof(SQLREAL), buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_FLOAT, + buffers.realBuffers[col - 1].data(), + sizeof(SQLREAL), + buffers.indicators[col - 1].data()); break; case SQL_DECIMAL: case SQL_NUMERIC: - buffers.charBuffers[col - 1].resize(fetchSize * MAX_DIGITS_IN_NUMERIC); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(), + buffers.charBuffers[col - 1].resize(fetchSize * + MAX_DIGITS_IN_NUMERIC); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, + buffers.charBuffers[col - 1].data(), MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), buffers.indicators[col - 1].data()); break; case SQL_DOUBLE: case SQL_FLOAT: buffers.doubleBuffers[col - 1].resize(fetchSize); - ret = - SQLBindCol_ptr(hStmt, col, SQL_C_DOUBLE, buffers.doubleBuffers[col - 1].data(), - sizeof(SQLDOUBLE), buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_DOUBLE, + buffers.doubleBuffers[col - 1].data(), + sizeof(SQLDOUBLE), + buffers.indicators[col - 1].data()); break; case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: buffers.timestampBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr( - hStmt, col, SQL_C_TYPE_TIMESTAMP, buffers.timestampBuffers[col - 1].data(), - sizeof(SQL_TIMESTAMP_STRUCT), buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_TYPE_TIMESTAMP, + buffers.timestampBuffers[col - 1].data(), + sizeof(SQL_TIMESTAMP_STRUCT), + buffers.indicators[col - 1].data()); break; case SQL_BIGINT: buffers.bigIntBuffers[col - 1].resize(fetchSize); - ret = - SQLBindCol_ptr(hStmt, col, SQL_C_SBIGINT, buffers.bigIntBuffers[col - 1].data(), - sizeof(SQLBIGINT), buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_SBIGINT, + buffers.bigIntBuffers[col - 1].data(), + sizeof(SQLBIGINT), + buffers.indicators[col - 1].data()); break; case SQL_TYPE_DATE: buffers.dateBuffers[col - 1].resize(fetchSize); - 
ret = - SQLBindCol_ptr(hStmt, col, SQL_C_TYPE_DATE, buffers.dateBuffers[col - 1].data(), - sizeof(SQL_DATE_STRUCT), buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_TYPE_DATE, + buffers.dateBuffers[col - 1].data(), + sizeof(SQL_DATE_STRUCT), + buffers.indicators[col - 1].data()); break; case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: buffers.timeBuffers[col - 1].resize(fetchSize); - ret = - SQLBindCol_ptr(hStmt, col, SQL_C_TYPE_TIME, buffers.timeBuffers[col - 1].data(), - sizeof(SQL_TIME_STRUCT), buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_TYPE_TIME, + buffers.timeBuffers[col - 1].data(), + sizeof(SQL_TIME_STRUCT), + buffers.indicators[col - 1].data()); break; case SQL_GUID: buffers.guidBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_GUID, buffers.guidBuffers[col - 1].data(), - sizeof(SQLGUID), buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr( + hStmt, col, SQL_C_GUID, buffers.guidBuffers[col - 1].data(), + sizeof(SQLGUID), buffers.indicators[col - 1].data()); break; case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: - // TODO: handle variable length data correctly. This logic wont suffice + // TODO: handle variable length data correctly. 
This logic + // wont suffice HandleZeroColumnSizeAtFetch(columnSize); buffers.charBuffers[col - 1].resize(fetchSize * columnSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_BINARY, buffers.charBuffers[col - 1].data(), - columnSize, buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr(hStmt, col, SQL_C_BINARY, + buffers.charBuffers[col - 1].data(), + columnSize, + buffers.indicators[col - 1].data()); break; case SQL_SS_TIMESTAMPOFFSET: buffers.datetimeoffsetBuffers[col - 1].resize(fetchSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_SS_TIMESTAMPOFFSET, - buffers.datetimeoffsetBuffers[col - 1].data(), - sizeof(DateTimeOffset) * fetchSize, - buffers.indicators[col - 1].data()); + ret = SQLBindCol_ptr( + hStmt, col, SQL_C_SS_TIMESTAMPOFFSET, + buffers.datetimeoffsetBuffers[col - 1].data(), + sizeof(DateTimeOffset) * fetchSize, + buffers.indicators[col - 1].data()); break; default: - std::wstring columnName = columnMeta["ColumnName"].cast(); + std::wstring columnName = + columnMeta["ColumnName"].cast(); std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << col; + errorString << "Unsupported data type for column - " + << columnName.c_str() << ", Type - " << dataType + << ", column ID - " << col; LOG(errorString.str()); ThrowStdException(errorString.str()); break; } if (!SQL_SUCCEEDED(ret)) { - std::wstring columnName = columnMeta["ColumnName"].cast(); + std::wstring columnName = + columnMeta["ColumnName"].cast(); std::ostringstream errorString; - errorString << "Failed to bind column - " << columnName.c_str() << ", Type - " - << dataType << ", column ID - " << col; + errorString << "Failed to bind column - " << columnName.c_str() + << ", Type - " << dataType << ", column ID - " << col; LOG(errorString.str()); ThrowStdException(errorString.str()); return ret; @@ -3444,12 +4029,15 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& 
column // Fetch rows in batches // TODO: Move to anonymous namespace, since it is not used outside this file -SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames, - py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched, +SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, + py::list& columnNames, py::list& rows, + SQLUSMALLINT numCols, SQLULEN& numRowsFetched, const std::vector& lobColumns, const std::string& char_encoding = "utf-8", const std::string& wchar_encoding = "utf-16le") { - UNREFERENCED_PARAMETER(wchar_encoding); // SQL_WCHAR behavior unchanged, keeping parameter for API consistency + UNREFERENCED_PARAMETER(wchar_encoding); // SQL_WCHAR behavior unchanged, + // keeping parameter for API + // consistency LOG("Fetching data in batches"); SQLRETURN ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0); if (ret == SQL_NO_DATA) { @@ -3460,8 +4048,8 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum LOG("Error while fetching rows in batches"); return ret; } - // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. It'll be populated by - // SQLFetchScroll + // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. It'll be + // populated by SQLFetchScroll for (SQLULEN i = 0; i < numRowsFetched; i++) { py::list row; for (SQLUSMALLINT col = 1; col <= numCols; col++) { @@ -3473,46 +4061,66 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum row.append(py::none()); continue; } - // TODO: variable length data needs special handling, this logic wont suffice - // This value indicates that the driver cannot determine the length of the data + // TODO: variable length data needs special handling, this logic + // wont suffice + // This value indicates that the driver cannot determine the + // length of the data if (dataLen == SQL_NO_TOTAL) { - LOG("Cannot determine the length of the data. Returning NULL value instead." 
- "Column ID - {}", col); + LOG("Cannot determine the length of the data. Returning " + "NULL value instead. Column ID - {}", + col); row.append(py::none()); continue; } else if (dataLen == SQL_NULL_DATA) { - LOG("Column data is NULL. Appending None to the result row. Column ID - {}", col); + LOG("Column data is NULL. Appending None to the result " + "row. Column ID - {}", + col); row.append(py::none()); continue; } else if (dataLen == 0) { // Handle zero-length (non-NULL) data - if (dataType == SQL_CHAR || dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR) { + if (dataType == SQL_CHAR || dataType == SQL_VARCHAR || + dataType == SQL_LONGVARCHAR) { // Apply dynamic encoding for SQL_CHAR types if (!char_encoding.empty()) { try { - py::str decoded_str = DecodingString("", 0, char_encoding, "strict"); + py::str decoded_str = + DecodingString("", 0, char_encoding, "strict"); row.append(decoded_str); } catch (const std::exception& e) { - LOG("Decoding failed for empty SQL_CHAR data: {}", e.what()); + LOG("Decoding failed for empty SQL_CHAR data: {}", + e.what()); row.append(std::string("")); } } else { row.append(std::string("")); } - } else if (dataType == SQL_WCHAR || dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR) { + } else if (dataType == SQL_WCHAR || dataType == SQL_WVARCHAR || + dataType == SQL_WLONGVARCHAR) { row.append(std::wstring(L"")); - } else if (dataType == SQL_BINARY || dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) { + } else if (dataType == SQL_BINARY || + dataType == SQL_VARBINARY || + dataType == SQL_LONGVARBINARY) { row.append(py::bytes("")); } else { - // For other datatypes, 0 length is unexpected. Log & append None - LOG("Column data length is 0 for non-string/binary datatype. Appending None to the result row. Column ID - {}", col); + // For other datatypes, 0 length is unexpected. Log & + // append None + LOG("Column data length is 0 for non-string/binary " + "datatype. Appending None to the result row. 
Column " + "ID - {}", + col); row.append(py::none()); } continue; } else if (dataLen < 0) { - // Negative value is unexpected, log column index, SQL type & raise exception - LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); - ThrowStdException("Unexpected negative data length, check logs for details"); + // Negative value is unexpected, log column index, SQL type & + // raise exception + LOG("Unexpected negative data length. Column ID - {}, " + "SQL Type - {}, Data Length - {}", + col, dataType, dataLen); + ThrowStdException( + "Unexpected negative data length, check " + "logs for details"); } assert(dataLen > 0 && "Data length must be > 0"); @@ -3520,60 +4128,83 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + SQLULEN columnSize = + columnMeta["ColumnSize"].cast(); HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; + uint64_t fetchBufferSize = + columnSize + 1 /*null-terminator*/; uint64_t numCharsInData = dataLen / sizeof(SQLCHAR); - bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); - // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' + bool isLob = std::find(lobColumns.begin(), lobColumns.end(), + col) != lobColumns.end(); + // fetchBufferSize includes null-terminator, numCharsInData + // doesn't. 
Hence '<' if (!isLob && numCharsInData < fetchBufferSize) { // Apply dynamic decoding for SQL_CHAR types try { py::str decoded_str = DecodingString( - reinterpret_cast(&buffers.charBuffers[col - 1][i * fetchBufferSize]), - numCharsInData, - char_encoding, - "strict" - ); + reinterpret_cast( + &buffers.charBuffers[col - 1] + [i * fetchBufferSize]), + numCharsInData, char_encoding, "strict"); row.append(decoded_str); - LOG("Applied dynamic decoding for batch CHAR column {} using encoding '{}'", col, char_encoding); + LOG("Applied dynamic decoding for batch CHAR " + "column {} using encoding '{}'", + col, char_encoding); } catch (const std::exception& e) { - LOG("Dynamic decoding failed for batch column {}: {}. Using fallback.", col, e.what()); + LOG("Dynamic decoding failed for batch column " + "{}: {}. Using fallback.", + col, e.what()); // Fallback to original logic row.append(std::string( - reinterpret_cast(&buffers.charBuffers[col - 1][i * fetchBufferSize]), + reinterpret_cast( + &buffers.charBuffers[col - 1] + [i * fetchBufferSize]), numCharsInData)); } } else { - row.append(FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false, char_encoding)); + row.append(FetchLobColumnData(hStmt, col, SQL_C_CHAR, + false, false, + char_encoding)); } break; } case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR: { - // TODO: variable length data needs special handling, this logic wont suffice - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + // TODO: variable length data needs special handling, this + // logic wont suffice + SQLULEN columnSize = + columnMeta["ColumnSize"].cast(); HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; - uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); - bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); - // fetchBufferSize includes null-terminator, numCharsInData doesn't. 
Hence '<' + uint64_t fetchBufferSize = + columnSize + 1 /*null-terminator*/; + uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); + bool isLob = std::find(lobColumns.begin(), lobColumns.end(), + col) != lobColumns.end(); + // fetchBufferSize includes null-terminator, numCharsInData + // doesn't. Hence '<' if (!isLob && numCharsInData < fetchBufferSize) { // SQLFetch will nullterminate the data #if defined(__APPLE__) || defined(__linux__) - // Use unix-specific conversion to handle the wchar_t/SQLWCHAR size difference - SQLWCHAR* wcharData = &buffers.wcharBuffers[col - 1][i * fetchBufferSize]; - std::wstring wstr = SQLWCHARToWString(wcharData, numCharsInData); + // Use unix-specific conversion to handle the + // wchar_t/SQLWCHAR size difference + SQLWCHAR* wcharData = + &buffers.wcharBuffers[col - 1][i * fetchBufferSize]; + std::wstring wstr = + SQLWCHARToWString(wcharData, numCharsInData); row.append(wstr); #else - // On Windows, wchar_t and SQLWCHAR are both 2 bytes, so direct cast works + // On Windows, wchar_t and SQLWCHAR are both 2 bytes, so + // direct cast works row.append(std::wstring( - reinterpret_cast(&buffers.wcharBuffers[col - 1][i * fetchBufferSize]), + reinterpret_cast( + &buffers.wcharBuffers[col - 1] + [i * fetchBufferSize]), numCharsInData)); #endif } else { - row.append(FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false)); + row.append(FetchLobColumnData(hStmt, col, SQL_C_WCHAR, + true, false)); } break; } @@ -3590,7 +4221,8 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum break; } case SQL_BIT: { - row.append(static_cast(buffers.charBuffers[col - 1][i])); + row.append( + static_cast(buffers.charBuffers[col - 1][i])); break; } case SQL_REAL: { @@ -3600,26 +4232,33 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum case SQL_DECIMAL: case SQL_NUMERIC: { try { - // Convert the string to use the current decimal separator - std::string numStr(reinterpret_cast( - 
&buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]), + // Convert the string to use the current decimal + // separator + std::string numStr( + reinterpret_cast( + &buffers + .charBuffers[col - 1] + [i * MAX_DIGITS_IN_NUMERIC]), buffers.indicators[col - 1][i]); - + // Get the current separator in a thread-safe way std::string separator = GetDecimalSeparator(); - + if (separator != ".") { - // Replace the driver's decimal point with our configured separator + // Replace the driver's decimal point with our + // configured separator size_t pos = numStr.find('.'); if (pos != std::string::npos) { numStr.replace(pos, 1, separator); } } - + // Convert to Python decimal - row.append(py::module_::import("decimal").attr("Decimal")(numStr)); + row.append(py::module_::import("decimal").attr( + "Decimal")(numStr)); } catch (const py::error_already_set& e) { - // Handle the exception, e.g., log the error and append py::none() + // Handle the exception, e.g., log the error and append + // py::none() LOG("Error converting to decimal: {}", e.what()); row.append(py::none()); } @@ -3633,14 +4272,17 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { - row.append(py::module_::import("datetime") - .attr("datetime")(buffers.timestampBuffers[col - 1][i].year, - buffers.timestampBuffers[col - 1][i].month, - buffers.timestampBuffers[col - 1][i].day, - buffers.timestampBuffers[col - 1][i].hour, - buffers.timestampBuffers[col - 1][i].minute, - buffers.timestampBuffers[col - 1][i].second, - buffers.timestampBuffers[col - 1][i].fraction / 1000 /* Convert back ns to µs */)); + row.append( + py::module_::import("datetime") + .attr("datetime")( + buffers.timestampBuffers[col - 1][i].year, + buffers.timestampBuffers[col - 1][i].month, + buffers.timestampBuffers[col - 1][i].day, + buffers.timestampBuffers[col - 1][i].hour, + buffers.timestampBuffers[col - 1][i].minute, + 
buffers.timestampBuffers[col - 1][i].second, + buffers.timestampBuffers[col - 1][i].fraction / + 1000 /* Convert back ns to µs */)); break; } case SQL_BIGINT: { @@ -3648,41 +4290,40 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum break; } case SQL_TYPE_DATE: { - row.append(py::module_::import("datetime") - .attr("date")(buffers.dateBuffers[col - 1][i].year, - buffers.dateBuffers[col - 1][i].month, - buffers.dateBuffers[col - 1][i].day)); + row.append( + py::module_::import("datetime") + .attr("date")(buffers.dateBuffers[col - 1][i].year, + buffers.dateBuffers[col - 1][i].month, + buffers.dateBuffers[col - 1][i].day)); break; } case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: { row.append(py::module_::import("datetime") - .attr("time")(buffers.timeBuffers[col - 1][i].hour, - buffers.timeBuffers[col - 1][i].minute, - buffers.timeBuffers[col - 1][i].second)); + .attr("time")( + buffers.timeBuffers[col - 1][i].hour, + buffers.timeBuffers[col - 1][i].minute, + buffers.timeBuffers[col - 1][i].second)); break; } case SQL_SS_TIMESTAMPOFFSET: { SQLULEN rowIdx = i; - const DateTimeOffset& dtoValue = buffers.datetimeoffsetBuffers[col - 1][rowIdx]; + const DateTimeOffset& dtoValue = + buffers.datetimeoffsetBuffers[col - 1][rowIdx]; SQLLEN indicator = buffers.indicators[col - 1][rowIdx]; if (indicator != SQL_NULL_DATA) { - int totalMinutes = dtoValue.timezone_hour * 60 + dtoValue.timezone_minute; + int totalMinutes = dtoValue.timezone_hour * 60 + + dtoValue.timezone_minute; py::object datetime = py::module_::import("datetime"); py::object tzinfo = datetime.attr("timezone")( - datetime.attr("timedelta")(py::arg("minutes") = totalMinutes) - ); + datetime.attr("timedelta")(py::arg("minutes") = + totalMinutes)); py::object py_dt = datetime.attr("datetime")( - dtoValue.year, - dtoValue.month, - dtoValue.day, - dtoValue.hour, - dtoValue.minute, - dtoValue.second, + dtoValue.year, dtoValue.month, dtoValue.day, + dtoValue.hour, dtoValue.minute, 
dtoValue.second, dtoValue.fraction / 1000, // ns → µs - tzinfo - ); + tzinfo); row.append(py_dt); } else { row.append(py::none()); @@ -3697,44 +4338,60 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum } SQLGUID* guidValue = &buffers.guidBuffers[col - 1][i]; uint8_t reordered[16]; - reordered[0] = ((char*)&guidValue->Data1)[3]; - reordered[1] = ((char*)&guidValue->Data1)[2]; - reordered[2] = ((char*)&guidValue->Data1)[1]; - reordered[3] = ((char*)&guidValue->Data1)[0]; - reordered[4] = ((char*)&guidValue->Data2)[1]; - reordered[5] = ((char*)&guidValue->Data2)[0]; - reordered[6] = ((char*)&guidValue->Data3)[1]; - reordered[7] = ((char*)&guidValue->Data3)[0]; + reordered[0] = + reinterpret_cast(&guidValue->Data1)[3]; + reordered[1] = + reinterpret_cast(&guidValue->Data1)[2]; + reordered[2] = + reinterpret_cast(&guidValue->Data1)[1]; + reordered[3] = + reinterpret_cast(&guidValue->Data1)[0]; + reordered[4] = + reinterpret_cast(&guidValue->Data2)[1]; + reordered[5] = + reinterpret_cast(&guidValue->Data2)[0]; + reordered[6] = + reinterpret_cast(&guidValue->Data3)[1]; + reordered[7] = + reinterpret_cast(&guidValue->Data3)[0]; // Secure copy: Fixed 8-byte copy for GUID Data4 field std::copy_n(guidValue->Data4, 8, reordered + 8); - py::bytes py_guid_bytes(reinterpret_cast(reordered), 16); + py::bytes py_guid_bytes(reinterpret_cast(reordered), + 16); py::dict kwargs; kwargs["bytes"] = py_guid_bytes; - py::object uuid_obj = py::module_::import("uuid").attr("UUID")(**kwargs); + py::object uuid_obj = + py::module_::import("uuid").attr("UUID")(**kwargs); row.append(uuid_obj); break; } case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: { - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + SQLULEN columnSize = + columnMeta["ColumnSize"].cast(); HandleZeroColumnSizeAtFetch(columnSize); - bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); + bool isLob = std::find(lobColumns.begin(), 
lobColumns.end(), + col) != lobColumns.end(); if (!isLob && static_cast(dataLen) <= columnSize) { - row.append(py::bytes(reinterpret_cast( - &buffers.charBuffers[col - 1][i * columnSize]), - dataLen)); + row.append(py::bytes( + reinterpret_cast( + &buffers.charBuffers[col - 1][i * columnSize]), + dataLen)); } else { - row.append(FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true)); + row.append(FetchLobColumnData(hStmt, col, SQL_C_BINARY, + false, true)); } break; } default: { - std::wstring columnName = columnMeta["ColumnName"].cast(); + std::wstring columnName = + columnMeta["ColumnName"].cast(); std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << col; + errorString << "Unsupported data type for column - " + << columnName.c_str() << ", Type - " << dataType + << ", column ID - " << col; LOG(errorString.str()); ThrowStdException(errorString.str()); break; @@ -3746,8 +4403,8 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum return ret; } -// Given a list of columns that are a part of single row in the result set, calculates -// the max size of the row +// Given a list of columns that are a part of single row in the result set, +// calculates the max size of the row // TODO: Move to anonymous namespace, since it is not used outside this file size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) { size_t rowSize = 0; @@ -3819,10 +4476,12 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) { rowSize += sizeof(DateTimeOffset); break; default: - std::wstring columnName = columnMeta["ColumnName"].cast(); + std::wstring columnName = + columnMeta["ColumnName"].cast(); std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << col; + errorString << "Unsupported data type for column - " + << 
columnName.c_str() << ", Type - " << dataType + << ", column ID - " << col; LOG(errorString.str()); ThrowStdException(errorString.str()); break; @@ -3833,22 +4492,29 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) { // FetchMany_wrap - Fetches multiple rows of data from the result set. // -// @param StatementHandle: Handle to the statement from which data is to be fetched. -// @param rows: A Python list that will be populated with the fetched rows of data. +// @param StatementHandle: Handle to the statement from which data is to be +// fetched. +// @param rows: A Python list that will be populated with the fetched rows of +// data. // @param fetchSize: The number of rows to fetch. Default value is 1. // // @return SQLRETURN: SQL_SUCCESS if data is fetched successfully, -// SQL_NO_DATA if there are no more rows to fetch, -// throws a runtime error if there is an error fetching data. +// SQL_NO_DATA if there are no more rows to fetch, throws +// a runtime error if there is an error fetching data. // -// This function assumes that the statement handle (hStmt) is already allocated and a query has been -// executed. It fetches the specified number of rows from the result set and populates the provided -// Python list with the row data. If there are no more rows to fetch, it returns SQL_NO_DATA. If an -// error occurs during fetching, it throws a runtime error. -SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetchSize = 1, +// This function assumes that the statement handle (hStmt) is already allocated +// and a query has been executed. It fetches the specified number of rows from +// the result set and populates the provided Python list with the row data. If +// there are no more rows to fetch, it returns SQL_NO_DATA. If an error occurs +// during fetching, it throws a runtime error. 
+SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, + int fetchSize = 1, const std::string& char_encoding = "utf-8", const std::string& wchar_encoding = "utf-16le") { - UNREFERENCED_PARAMETER(wchar_encoding); // SQL_WCHAR behavior unchanged, keeping parameter for API consistency + UNREFERENCED_PARAMETER( + wchar_encoding); // SQL_WCHAR behavior + // unchanged, + // keeping parameter for API consistency SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); // Retrieve column count @@ -3868,11 +4534,13 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch SQLSMALLINT dataType = colMeta["DataType"].cast(); SQLULEN columnSize = colMeta["ColumnSize"].cast(); - if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || + if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || - dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) && - (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { - lobColumns.push_back(i + 1); // 1-based + dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || + dataType == SQL_SS_XML) && + (columnSize == 0 || columnSize == SQL_NO_TOTAL || + columnSize > SQL_MAX_LOB_SIZE)) { + lobColumns.push_back(i + 1); // 1-based } } @@ -3885,7 +4553,9 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch if (!SQL_SUCCEEDED(ret)) return ret; py::list row; - SQLGetData_wrap(StatementHandle, numCols, row, char_encoding, wchar_encoding); // <-- streams LOBs correctly + // streams LOBs correctly + SQLGetData_wrap(StatementHandle, numCols, row, char_encoding, + wchar_encoding); rows.append(row); } return SQL_SUCCESS; @@ -3902,10 +4572,13 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch } SQLULEN numRowsFetched; - SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); + 
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, + (SQLPOINTER)(intptr_t)fetchSize, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); - ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns, char_encoding, wchar_encoding); + ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, + numRowsFetched, lobColumns, char_encoding, + wchar_encoding); if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) { LOG("Error when fetching data"); return ret; @@ -3920,21 +4593,27 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch // FetchAll_wrap - Fetches all rows of data from the result set. // -// @param StatementHandle: Handle to the statement from which data is to be fetched. -// @param rows: A Python list that will be populated with the fetched rows of data. +// @param StatementHandle: Handle to the statement from which data is to be +// fetched. +// @param rows: A Python list that will be populated with the fetched rows of +// data. // // @return SQLRETURN: SQL_SUCCESS if data is fetched successfully, // SQL_NO_DATA if there are no more rows to fetch, -// throws a runtime error if there is an error fetching data. +// throws a runtime error if there is an error fetching +// data. // -// This function assumes that the statement handle (hStmt) is already allocated and a query has been -// executed. It fetches all rows from the result set and populates the provided Python list with the -// row data. If there are no more rows to fetch, it returns SQL_NO_DATA. If an error occurs during -// fetching, it throws a runtime error. +// This function assumes that the statement handle (hStmt) is already allocated +// and a query has been executed. It fetches all rows from the result set and +// populates the provided Python list with the row data. If there are no more +// rows to fetch, it returns SQL_NO_DATA. If an error occurs during fetching, +// it throws a runtime error. 
SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, const std::string& char_encoding = "utf-8", const std::string& wchar_encoding = "utf-16le") { - UNREFERENCED_PARAMETER(wchar_encoding); // SQL_WCHAR behavior unchanged, keeping parameter for API consistency + UNREFERENCED_PARAMETER(wchar_encoding); // SQL_WCHAR behavior unchanged, + // keeping parameter for API + // consistency SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); // Retrieve column count @@ -3963,24 +4642,28 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, // TODO: Find why NVARCHAR(MAX) returns columnsize 0 // TODO: What if a row has 2 cols, an int & NVARCHAR(MAX)? // totalRowSize will be 4+0 = 4. It wont take NVARCHAR(MAX) - // into account. So, we will end up fetching 1000 rows at a time. + // into account. So, we will end up fetching 1000 rows at + // time. numRowsInMemLimit = 1; // fetchsize will be 10 } - // TODO: Revisit this logic. Eventhough we're fetching fetchSize rows at a time, - // fetchall will keep all rows in memory anyway. So what are we gaining by fetching - // fetchSize rows at a time? - // Also, say the table has only 10 rows, each row size if 100 bytes. Here, we'll have - // fetchSize = 1000, so we'll allocate memory for 1000 rows inside SQLBindCol_wrap, while - // actually only need to retrieve 10 rows + // TODO: Revisit this logic. Eventhough we're fetching fetchSize rows at a + // time, fetchall will keep all rows in memory anyway. So what are we + // gaining by fetching fetchSize rows at a time? + // Also, say the table has only 10 rows, each row size if 100 bytes. 
+ // Here, we'll have fetchSize = 1000, so we'll allocate memory for 1000 + // rows inside SQLBindCol_wrap, while actually only need to retrieve 10 + // rows int fetchSize; if (numRowsInMemLimit == 0) { - // If the row size is larger than the memory limit, fetch one row at a time + // If the row size is larger than the memory limit, fetch one row + // at a time fetchSize = 1; } else if (numRowsInMemLimit > 0 && numRowsInMemLimit <= 100) { // If between 1-100 rows fit in memoryLimit, fetch 10 rows at a time fetchSize = 10; } else if (numRowsInMemLimit > 100 && numRowsInMemLimit <= 1000) { - // If between 100-1000 rows fit in memoryLimit, fetch 100 rows at a time + // If between 100-1000 rows fit in memoryLimit, fetch 100 rows at a + // time fetchSize = 100; } else { fetchSize = 1000; @@ -3993,11 +4676,13 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, SQLSMALLINT dataType = colMeta["DataType"].cast(); SQLULEN columnSize = colMeta["ColumnSize"].cast(); - if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || + if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || - dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) && - (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { - lobColumns.push_back(i + 1); // 1-based + dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || + dataType == SQL_SS_XML) && + (columnSize == 0 || columnSize == SQL_NO_TOTAL || + columnSize > SQL_MAX_LOB_SIZE)) { + lobColumns.push_back(i + 1); // 1-based } } @@ -4010,7 +4695,8 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, if (!SQL_SUCCEEDED(ret)) return ret; py::list row; - SQLGetData_wrap(StatementHandle, numCols, row, char_encoding, wchar_encoding); // <-- streams LOBs correctly + SQLGetData_wrap(StatementHandle, numCols, row, char_encoding, + wchar_encoding); // streams LOBs correctly 
rows.append(row); } return SQL_SUCCESS; @@ -4026,17 +4712,20 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, } SQLULEN numRowsFetched; - SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, + (SQLPOINTER)(intptr_t)fetchSize, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); while (ret != SQL_NO_DATA) { - ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns, char_encoding, wchar_encoding); + ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, + numRowsFetched, lobColumns, char_encoding, + wchar_encoding); if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) { LOG("Error when fetching data"); return ret; } } - + // Reset attributes before returning to avoid using stack pointers later SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); @@ -4046,21 +4735,26 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, // FetchOne_wrap - Fetches a single row of data from the result set. // -// @param StatementHandle: Handle to the statement from which data is to be fetched. +// @param StatementHandle: Handle to the statement from which data is to be +// fetched. // @param row: A Python list that will be populated with the fetched row data. // -// @return SQLRETURN: SQL_SUCCESS or SQL_SUCCESS_WITH_INFO if data is fetched successfully, -// SQL_NO_DATA if there are no more rows to fetch, -// throws a runtime error if there is an error fetching data. +// @return SQLRETURN: SQL_SUCCESS or SQL_SUCCESS_WITH_INFO if data is fetched +// successfully, SQL_NO_DATA if there are no more rows to +// fetch, throws a runtime error if there is an error +// fetching data. // -// This function assumes that the statement handle (hStmt) is already allocated and a query has been -// executed. 
It fetches the next row of data from the result set and populates the provided Python -// list with the row data. If there are no more rows to fetch, it returns SQL_NO_DATA. If an error +// This function assumes that the statement handle (hStmt) is already allocated +// and a query has been executed. It fetches the next row of data from the +// result set and populates the provided Python list with the row data. If +// there are no more rows to fetch, it returns SQL_NO_DATA. If an error // occurs during fetching, it throws a runtime error. SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row, const std::string& char_encoding = "utf-8", const std::string& wchar_encoding = "utf-16le") { - UNREFERENCED_PARAMETER(wchar_encoding); // SQL_WCHAR behavior unchanged, keeping parameter for API consistency + UNREFERENCED_PARAMETER(wchar_encoding); // SQL_WCHAR behavior unchanged, + // keeping parameter for API + // consistency SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); @@ -4069,7 +4763,8 @@ SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row, if (SQL_SUCCEEDED(ret)) { // Retrieve column count SQLSMALLINT colCount = SQLNumResultCols_wrap(StatementHandle); - ret = SQLGetData_wrap(StatementHandle, colCount, row, char_encoding, wchar_encoding); + ret = SQLGetData_wrap(StatementHandle, colCount, row, char_encoding, + wchar_encoding); } else if (ret != SQL_NO_DATA) { LOG("Error when fetching data"); } @@ -4137,7 +4832,9 @@ void DDBCSetDecimalSeparator(const std::string& separator) { // Architecture-specific defines #ifndef ARCHITECTURE -#define ARCHITECTURE "win64" // Default to win64 if not defined during compilation +#define ARCHITECTURE \ + "win64" // Default to win64 if not defined during + // compilation #endif // Functions/data to be exposed to Python as a part of ddbc_bindings module @@ -4149,10 +4846,11 @@ PYBIND11_MODULE(ddbc_bindings, m) { // Expose architecture-specific constants m.attr("ARCHITECTURE") = ARCHITECTURE; - + // 
Expose the C++ functions to Python m.def("ThrowStdException", &ThrowStdException); - m.def("GetDriverPathCpp", &GetDriverPathCpp, "Get the path to the ODBC driver"); + m.def("GetDriverPathCpp", &GetDriverPathCpp, + "Get the path to the ODBC driver"); // Define parameter info class py::class_(m, "ParamInfo") @@ -4179,127 +4877,150 @@ PYBIND11_MODULE(ddbc_bindings, m) { py::class_(m, "ErrorInfo") .def_readwrite("sqlState", &ErrorInfo::sqlState) .def_readwrite("ddbcErrorMsg", &ErrorInfo::ddbcErrorMsg); - + py::class_(m, "SqlHandle") .def("free", &SqlHandle::free, "Free the handle"); - + py::class_(m, "Connection") - .def(py::init(), py::arg("conn_str"), py::arg("use_pool"), py::arg("attrs_before") = py::dict()) + .def(py::init(), + py::arg("conn_str"), py::arg("use_pool"), + py::arg("attrs_before") = py::dict()) .def("close", &ConnectionHandle::close, "Close the connection") - .def("commit", &ConnectionHandle::commit, "Commit the current transaction") - .def("rollback", &ConnectionHandle::rollback, "Rollback the current transaction") + .def("commit", &ConnectionHandle::commit, + "Commit the current transaction") + .def("rollback", &ConnectionHandle::rollback, + "Rollback the current transaction") .def("set_autocommit", &ConnectionHandle::setAutocommit) .def("get_autocommit", &ConnectionHandle::getAutocommit) - .def("set_attr", &ConnectionHandle::setAttr, py::arg("attribute"), py::arg("value"), "Set connection attribute") + .def("set_attr", &ConnectionHandle::setAttr, py::arg("attribute"), + py::arg("value"), "Set connection attribute") .def("alloc_statement_handle", &ConnectionHandle::allocStatementHandle) .def("get_info", &ConnectionHandle::getInfo, py::arg("info_type")); - m.def("enable_pooling", &enable_pooling, "Enable global connection pooling"); - m.def("close_pooling", []() {ConnectionPoolManager::getInstance().closePools();}); - m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, "Execute a SQL query directly"); - m.def("DDBCSQLExecute", &SQLExecute_wrap, "Prepare 
and execute T-SQL statements"); - m.def("SQLExecuteMany", &SQLExecuteMany_wrap, "Execute statement with multiple parameter sets"); + m.def("enable_pooling", &enable_pooling, + "Enable global connection pooling"); + m.def("close_pooling", + []() { ConnectionPoolManager::getInstance().closePools(); }); + m.def("DDBCSQLExecDirect", &SQLExecDirect_wrap, + "Execute a SQL query directly"); + m.def("DDBCSQLExecute", &SQLExecute_wrap, + "Prepare and execute T-SQL statements"); + m.def("SQLExecuteMany", &SQLExecuteMany_wrap, + "Execute statement with multiple parameter sets"); m.def("DDBCSQLRowCount", &SQLRowCount_wrap, "Get the number of rows affected by the last statement"); - m.def("DDBCSQLFetch", &SQLFetch_wrap, "Fetch the next row from the result set"); + m.def("DDBCSQLFetch", &SQLFetch_wrap, + "Fetch the next row from the result set"); m.def("DDBCSQLNumResultCols", &SQLNumResultCols_wrap, "Get the number of columns in the result set"); m.def("DDBCSQLDescribeCol", &SQLDescribeCol_wrap, "Get information about a column in the result set"); - m.def("DDBCSQLGetData", &SQLGetData_wrap, "Retrieve data from the result set"); - m.def("DDBCSQLMoreResults", &SQLMoreResults_wrap, "Check for more results in the result set"); - m.def("DDBCSQLFetchOne", &FetchOne_wrap, "Fetch one row from the result set"); - m.def("DDBCSQLFetchMany", &FetchMany_wrap, py::arg("StatementHandle"), py::arg("rows"), - py::arg("fetchSize") = 1, - py::arg("char_encoding") = "utf-8", py::arg("wchar_encoding") = "utf-16le", + m.def("DDBCSQLGetData", &SQLGetData_wrap, + "Retrieve data from the result set"); + m.def("DDBCSQLMoreResults", &SQLMoreResults_wrap, + "Check for more results in the result set"); + m.def("DDBCSQLFetchOne", &FetchOne_wrap, + "Fetch one row from the result set"); + m.def("DDBCSQLFetchMany", &FetchMany_wrap, py::arg("StatementHandle"), + py::arg("rows"), py::arg("fetchSize") = 1, + py::arg("char_encoding") = "utf-8", + py::arg("wchar_encoding") = "utf-16le", "Fetch many rows from the 
result set"); - m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); + m.def("DDBCSQLFetchAll", &FetchAll_wrap, + "Fetch all rows from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, - "Get all diagnostic records for a handle", - py::arg("handle")); - m.def("DDBCSQLTables", &SQLTables_wrap, + "Get all diagnostic records for a handle", py::arg("handle")); + m.def("DDBCSQLTables", &SQLTables_wrap, "Get table information using ODBC SQLTables", - py::arg("StatementHandle"), py::arg("catalog") = std::wstring(), - py::arg("schema") = std::wstring(), py::arg("table") = std::wstring(), + py::arg("StatementHandle"), py::arg("catalog") = std::wstring(), + py::arg("schema") = std::wstring(), py::arg("table") = std::wstring(), py::arg("tableType") = std::wstring()); m.def("DDBCSQLFetchScroll", &SQLFetchScroll_wrap, - "Scroll to a specific position in the result set and optionally fetch data"); - m.def("DDBCSetDecimalSeparator", &DDBCSetDecimalSeparator, "Set the decimal separator character"); - m.def("DDBCSQLSetStmtAttr", [](SqlHandlePtr stmt, SQLINTEGER attr, SQLPOINTER value) { - return SQLSetStmtAttr_ptr(stmt->get(), attr, value, 0); - }, "Set statement attributes"); - m.def("DDBCSQLGetTypeInfo", &SQLGetTypeInfo_Wrapper, "Returns information about the data types that are supported by the data source", - py::arg("StatementHandle"), py::arg("DataType")); - m.def("DDBCSQLProcedures", [](SqlHandlePtr StatementHandle, - const py::object& catalog, - const py::object& schema, - const py::object& procedure) { - return SQLProcedures_wrap(StatementHandle, catalog, schema, procedure); - }); - - m.def("DDBCSQLForeignKeys", [](SqlHandlePtr StatementHandle, - const py::object& pkCatalog, - const py::object& pkSchema, - const py::object& pkTable, - const py::object& fkCatalog, - const py::object& 
fkSchema, - const py::object& fkTable) { - return SQLForeignKeys_wrap(StatementHandle, - pkCatalog, pkSchema, pkTable, - fkCatalog, fkSchema, fkTable); - }); - m.def("DDBCSQLPrimaryKeys", [](SqlHandlePtr StatementHandle, - const py::object& catalog, - const py::object& schema, - const std::wstring& table) { - return SQLPrimaryKeys_wrap(StatementHandle, catalog, schema, table); - }); - m.def("DDBCSQLSpecialColumns", [](SqlHandlePtr StatementHandle, - SQLSMALLINT identifierType, - const py::object& catalog, - const py::object& schema, - const std::wstring& table, - SQLSMALLINT scope, - SQLSMALLINT nullable) { - return SQLSpecialColumns_wrap(StatementHandle, - identifierType, catalog, schema, table, - scope, nullable); - }); - m.def("DDBCSQLStatistics", [](SqlHandlePtr StatementHandle, - const py::object& catalog, - const py::object& schema, - const std::wstring& table, - SQLUSMALLINT unique, - SQLUSMALLINT reserved) { - return SQLStatistics_wrap(StatementHandle, catalog, schema, table, unique, reserved); - }); - m.def("DDBCSQLColumns", [](SqlHandlePtr StatementHandle, - const py::object& catalog, - const py::object& schema, - const py::object& table, - const py::object& column) { - return SQLColumns_wrap(StatementHandle, catalog, schema, table, column); - }); - + "Scroll to a specific position in the result set and optionally " + "fetch data"); + m.def("DDBCSetDecimalSeparator", &DDBCSetDecimalSeparator, + "Set the decimal separator character"); + m.def( + "DDBCSQLSetStmtAttr", + [](SqlHandlePtr stmt, SQLINTEGER attr, SQLPOINTER value) { + return SQLSetStmtAttr_ptr(stmt->get(), attr, value, 0); + }, + "Set statement attributes"); + m.def("DDBCSQLGetTypeInfo", &SQLGetTypeInfo_Wrapper, + "Returns information about the data types that are supported by " + "the data source", + py::arg("StatementHandle"), py::arg("DataType")); + m.def("DDBCSQLProcedures", + [](SqlHandlePtr StatementHandle, const py::object& catalog, + const py::object& schema, const py::object& procedure) 
{ + return SQLProcedures_wrap(StatementHandle, catalog, schema, + procedure); + }); + + m.def("DDBCSQLForeignKeys", + [](SqlHandlePtr StatementHandle, const py::object& pkCatalog, + const py::object& pkSchema, const py::object& pkTable, + const py::object& fkCatalog, const py::object& fkSchema, + const py::object& fkTable) { + return SQLForeignKeys_wrap(StatementHandle, pkCatalog, pkSchema, + pkTable, fkCatalog, fkSchema, fkTable); + }); + m.def("DDBCSQLPrimaryKeys", + [](SqlHandlePtr StatementHandle, const py::object& catalog, + const py::object& schema, const std::wstring& table) { + return SQLPrimaryKeys_wrap(StatementHandle, catalog, schema, + table); + }); + m.def( + "DDBCSQLSpecialColumns", + [](SqlHandlePtr StatementHandle, SQLSMALLINT identifierType, + const py::object& catalog, const py::object& schema, + const std::wstring& table, SQLSMALLINT scope, SQLSMALLINT nullable) { + return SQLSpecialColumns_wrap(StatementHandle, identifierType, + catalog, schema, table, scope, + nullable); + }); + m.def("DDBCSQLStatistics", + [](SqlHandlePtr StatementHandle, const py::object& catalog, + const py::object& schema, const std::wstring& table, + SQLUSMALLINT unique, SQLUSMALLINT reserved) { + return SQLStatistics_wrap(StatementHandle, catalog, schema, table, + unique, reserved); + }); + m.def("DDBCSQLColumns", + [](SqlHandlePtr StatementHandle, const py::object& catalog, + const py::object& schema, const py::object& table, + const py::object& column) { + return SQLColumns_wrap(StatementHandle, catalog, schema, table, + column); + }); // Module-level UUID class cache - // This caches the uuid.UUID class at module initialization time and keeps it alive - // for the entire module lifetime, avoiding static destructor issues during Python finalization - m.def("_get_uuid_class", []() -> py::object { - static py::object uuid_class = py::module_::import("uuid").attr("UUID"); - return uuid_class; - }, "Internal helper to get cached UUID class"); + // This caches the uuid.UUID 
class at module initialization + // time and keeps it alive + // for the entire module lifetime, avoiding static + // destructor issues during Python finalization + m.def( + "_get_uuid_class", + []() -> py::object { + static py::object uuid_class = + py::module_::import("uuid").attr("UUID"); + return uuid_class; + }, + "Internal helper to get cached UUID class"); // Add a version attribute m.attr("__version__") = "1.0.0"; - + try { // Try loading the ODBC driver when the module is imported LOG("Loading ODBC driver"); DriverLoader::getInstance().loadDriver(); // Load the driver } catch (const std::exception& e) { - // Log the error but don't throw - let the error happen when functions are called - LOG("Failed to load ODBC driver during module initialization: {}", e.what()); + // Log the error but don't throw - + // let the error happen when functions are called + LOG("Failed to load ODBC driver during module initialization: {}", + e.what()); } }