fixed GH #2830: Fix wrong buffer size in client handshake when re-using a SecureSocket

This commit is contained in:
Günter Obiltschnig
2020-01-22 13:11:37 +01:00
parent 0865fcf039
commit 3300467543

View File

@@ -245,7 +245,7 @@ void SecureSocketImpl::close()
serverDisconnect(&_hCreds, &_hContext); serverDisconnect(&_hCreds, &_hContext);
else else
clientDisconnect(&_hCreds, &_hContext); clientDisconnect(&_hCreds, &_hContext);
_pSocket->close(); _pSocket->close();
cleanup(); cleanup();
} }
@@ -276,7 +276,7 @@ void SecureSocketImpl::verifyPeerCertificate()
{ {
if (_peerHostName.empty()) if (_peerHostName.empty())
_peerHostName = _pSocket->peerAddress().host().toString(); _peerHostName = _pSocket->peerAddress().host().toString();
verifyPeerCertificate(_peerHostName); verifyPeerCertificate(_peerHostName);
} }
@@ -362,7 +362,7 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
SecBuffer* pExtraBuffer = 0; SecBuffer* pExtraBuffer = 0;
std::memcpy(_sendBuffer.begin() + _streamSizes.cbHeader, pBuffer + dataSent, dataSize); std::memcpy(_sendBuffer.begin() + _streamSizes.cbHeader, pBuffer + dataSent, dataSize);
msg.setSecBufferStreamHeader(0, _sendBuffer.begin(), _streamSizes.cbHeader); msg.setSecBufferStreamHeader(0, _sendBuffer.begin(), _streamSizes.cbHeader);
msg.setSecBufferData(1, _sendBuffer.begin() + _streamSizes.cbHeader, dataSize); msg.setSecBufferData(1, _sendBuffer.begin() + _streamSizes.cbHeader, dataSize);
msg.setSecBufferStreamTrailer(2, _sendBuffer.begin() + _streamSizes.cbHeader + dataSize, _streamSizes.cbTrailer); msg.setSecBufferStreamTrailer(2, _sendBuffer.begin() + _streamSizes.cbHeader + dataSize, _streamSizes.cbTrailer);
@@ -378,7 +378,7 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
int sent = sendRawBytes(_sendBuffer.begin(), outBufferLen, flags); int sent = sendRawBytes(_sendBuffer.begin(), outBufferLen, flags);
if (_pSocket->getBlocking() && sent == -1) if (_pSocket->getBlocking() && sent == -1)
{ {
if (dataSent == 0) if (dataSent == 0)
return -1; return -1;
else else
return dataSent; return dataSent;
@@ -452,10 +452,10 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
if (needData) if (needData)
{ {
int numBytes = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, _ioBufferSize - _recvBufferOffset); int numBytes = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, _ioBufferSize - _recvBufferOffset);
if (numBytes == -1) if (numBytes == -1)
return -1; return -1;
else if (numBytes == 0) else if (numBytes == 0)
break; break;
else else
_recvBufferOffset += numBytes; _recvBufferOffset += numBytes;
@@ -484,7 +484,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
{ {
// bytesDecoded contains everything including overflow data // bytesDecoded contains everything including overflow data
rc = bytesDecoded; rc = bytesDecoded;
if (rc > length) if (rc > length)
rc = length; rc = length;
return rc; return rc;
} }
@@ -604,7 +604,7 @@ SECURITY_STATUS SecureSocketImpl::decodeBufferFull(BYTE* pBuffer, DWORD bufSize,
overflowBuffer.resize(bufSize); overflowBuffer.resize(bufSize);
if (outLength > 0) if (outLength > 0)
{ {
// make pOutBuffer full // make pOutBuffer full
std::memcpy(pOutBuffer, pDataBuffer->pvBuffer, outLength); std::memcpy(pOutBuffer, pDataBuffer->pvBuffer, outLength);
// no longer valid to write to pOutBuffer // no longer valid to write to pOutBuffer
pOutBuffer = 0; pOutBuffer = 0;
@@ -643,7 +643,7 @@ SECURITY_STATUS SecureSocketImpl::decodeBufferFull(BYTE* pBuffer, DWORD bufSize,
} }
} }
if (securityStatus == SEC_I_RENEGOTIATE) if (securityStatus == SEC_I_RENEGOTIATE)
{ {
_needData = false; _needData = false;
securityStatus = performClientHandshakeLoop(); securityStatus = performClientHandshakeLoop();
@@ -652,7 +652,7 @@ SECURITY_STATUS SecureSocketImpl::decodeBufferFull(BYTE* pBuffer, DWORD bufSize,
} }
} }
while (securityStatus == SEC_E_OK && pBuffer); while (securityStatus == SEC_E_OK && pBuffer);
if (overflowOffset > 0) if (overflowOffset > 0)
{ {
_overflowBuffer.resize(overflowOffset); _overflowBuffer.resize(overflowOffset);
@@ -733,7 +733,7 @@ void SecureSocketImpl::clientConnectVerify()
try try
{ {
SECURITY_STATUS securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (PVOID) &_pPeerCertificate); SECURITY_STATUS securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (PVOID) &_pPeerCertificate);
if (securityStatus != SEC_E_OK) if (securityStatus != SEC_E_OK)
throw SSLException("Failed to obtain peer certificate", Utility::formatError(securityStatus)); throw SSLException("Failed to obtain peer certificate", Utility::formatError(securityStatus));
clientVerifyCertificate(_peerHostName); clientVerifyCertificate(_peerHostName);
@@ -791,7 +791,7 @@ void SecureSocketImpl::performInitialClientHandshake()
0, 0,
0, 0,
0, 0,
&_hContext, &_hContext,
&_outSecBuffer, &_outSecBuffer,
&contextAttributes, &contextAttributes,
&ts); &ts);
@@ -808,21 +808,21 @@ void SecureSocketImpl::performInitialClientHandshake()
throw SSLException("Handshake failed", Utility::formatError(_securityStatus)); throw SSLException("Handshake failed", Utility::formatError(_securityStatus));
} }
} }
// incomplete credentials: more calls to InitializeSecurityContext needed // incomplete credentials: more calls to InitializeSecurityContext needed
// send the token // send the token
sendInitialTokenOutBuffer(); sendInitialTokenOutBuffer();
if (_securityStatus == SEC_E_OK) if (_securityStatus == SEC_E_OK)
{ {
// The security context was successfully initialized. // The security context was successfully initialized.
// There is no need for another InitializeSecurityContext (Schannel) call. // There is no need for another InitializeSecurityContext (Schannel) call.
_state = ST_DONE; _state = ST_DONE;
return; return;
} }
//SEC_I_CONTINUE_NEEDED was returned: //SEC_I_CONTINUE_NEEDED was returned:
// Wait for a return token. The returned token is then passed in // Wait for a return token. The returned token is then passed in
// another call to InitializeSecurityContext (Schannel). The output token can be empty. // another call to InitializeSecurityContext (Schannel). The output token can be empty.
_extraSecBuffer.pvBuffer = 0; _extraSecBuffer.pvBuffer = 0;
@@ -853,7 +853,7 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop()
while (_securityStatus == SEC_I_CONTINUE_NEEDED || _securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) while (_securityStatus == SEC_I_CONTINUE_NEEDED || _securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_INCOMPLETE_CREDENTIALS)
{ {
performClientHandshakeLoopCondReceive(); performClientHandshakeLoopCondReceive();
if (_securityStatus == SEC_E_OK) if (_securityStatus == SEC_E_OK)
{ {
performClientHandshakeLoopOK(); performClientHandshakeLoopOK();
@@ -912,7 +912,7 @@ void SecureSocketImpl::performClientHandshakeLoopError()
void SecureSocketImpl::performClientHandshakeSendOutBuffer() void SecureSocketImpl::performClientHandshakeSendOutBuffer()
{ {
if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer)
{ {
int numBytes = sendRawBytes(static_cast<const void*>(_outSecBuffer[0].pvBuffer), _outSecBuffer[0].cbBuffer); int numBytes = sendRawBytes(static_cast<const void*>(_outSecBuffer[0].pvBuffer), _outSecBuffer[0].cbBuffer);
if (numBytes != _outSecBuffer[0].cbBuffer) if (numBytes != _outSecBuffer[0].cbBuffer)
@@ -966,14 +966,16 @@ void SecureSocketImpl::performClientHandshakeLoopReceive()
void SecureSocketImpl::performClientHandshakeLoopCondReceive() void SecureSocketImpl::performClientHandshakeLoopCondReceive()
{ {
poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED);
performClientHandshakeLoopInit(); performClientHandshakeLoopInit();
if (_needData) if (_needData)
{ {
if (_recvBuffer.capacity() != IO_BUFFER_SIZE)
_recvBuffer.setCapacity(IO_BUFFER_SIZE);
performClientHandshakeLoopReceive(); performClientHandshakeLoopReceive();
} }
else _needData = true; else _needData = true;
_inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); _inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset);
// inbuffer 1 should be empty // inbuffer 1 should be empty
_inSecBuffer.setSecBufferEmpty(1); _inSecBuffer.setSecBufferEmpty(1);
@@ -1046,7 +1048,7 @@ void SecureSocketImpl::performServerHandshake()
serverHandshakeLoop(&_hContext, &_hCreds, _clientAuthRequired, true, true); serverHandshakeLoop(&_hContext, &_hCreds, _clientAuthRequired, true, true);
SECURITY_STATUS securityStatus; SECURITY_STATUS securityStatus;
if (_clientAuthRequired) if (_clientAuthRequired)
{ {
poco_assert_dbg (!_pPeerCertificate); poco_assert_dbg (!_pPeerCertificate);
securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &_pPeerCertificate); securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &_pPeerCertificate);
@@ -1084,7 +1086,7 @@ bool SecureSocketImpl::serverHandshakeLoop(PCtxtHandle phContext, PCredHandle ph
while (securityStatus == SEC_I_CONTINUE_NEEDED || securityStatus == SEC_E_INCOMPLETE_MESSAGE || securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) while (securityStatus == SEC_I_CONTINUE_NEEDED || securityStatus == SEC_E_INCOMPLETE_MESSAGE || securityStatus == SEC_I_INCOMPLETE_CREDENTIALS)
{ {
if (securityStatus == SEC_E_INCOMPLETE_MESSAGE) if (securityStatus == SEC_E_INCOMPLETE_MESSAGE)
{ {
if (doRead) if (doRead)
{ {
@@ -1094,7 +1096,7 @@ bool SecureSocketImpl::serverHandshakeLoop(PCtxtHandle phContext, PCredHandle ph
throw SSLException("Failed to receive data in handshake"); throw SSLException("Failed to receive data in handshake");
else else
_recvBufferOffset += n; _recvBufferOffset += n;
} }
else doRead = true; else doRead = true;
} }
@@ -1132,8 +1134,8 @@ bool SecureSocketImpl::serverHandshakeLoop(PCtxtHandle phContext, PCredHandle ph
{ {
std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - inBuffer[1].cbBuffer), inBuffer[1].cbBuffer); std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - inBuffer[1].cbBuffer), inBuffer[1].cbBuffer);
_recvBufferOffset = inBuffer[1].cbBuffer; _recvBufferOffset = inBuffer[1].cbBuffer;
} }
else else
{ {
_recvBufferOffset = 0; _recvBufferOffset = 0;
} }
@@ -1151,7 +1153,7 @@ bool SecureSocketImpl::serverHandshakeLoop(PCtxtHandle phContext, PCredHandle ph
std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - inBuffer[1].cbBuffer), inBuffer[1].cbBuffer); std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - inBuffer[1].cbBuffer), inBuffer[1].cbBuffer);
_recvBufferOffset = inBuffer[1].cbBuffer; _recvBufferOffset = inBuffer[1].cbBuffer;
} }
else else
{ {
_recvBufferOffset = 0; _recvBufferOffset = 0;
} }
@@ -1164,12 +1166,12 @@ bool SecureSocketImpl::serverHandshakeLoop(PCtxtHandle phContext, PCredHandle ph
void SecureSocketImpl::clientVerifyCertificate(const std::string& hostName) void SecureSocketImpl::clientVerifyCertificate(const std::string& hostName)
{ {
if (_pContext->verificationMode() == Context::VERIFY_NONE) return; if (_pContext->verificationMode() == Context::VERIFY_NONE) return;
if (!_pPeerCertificate) throw SSLException("No Server certificate"); if (!_pPeerCertificate) throw SSLException("No Server certificate");
if (hostName.empty()) throw SSLException("Server name not set"); if (hostName.empty()) throw SSLException("Server name not set");
X509Certificate cert(_pPeerCertificate, true); X509Certificate cert(_pPeerCertificate, true);
if (!cert.verify(hostName)) if (!cert.verify(hostName))
{ {
VerificationErrorArgs args(cert, 0, SEC_E_CERT_EXPIRED, "The certificate host names do not match the server host name"); VerificationErrorArgs args(cert, 0, SEC_E_CERT_EXPIRED, "The certificate host names do not match the server host name");
@@ -1178,7 +1180,7 @@ void SecureSocketImpl::clientVerifyCertificate(const std::string& hostName)
throw InvalidCertificateException("Host name verification failed"); throw InvalidCertificateException("Host name verification failed");
} }
verifyCertificateChainClient(_pPeerCertificate); verifyCertificateChainClient(_pPeerCertificate);
} }
@@ -1204,7 +1206,7 @@ void SecureSocketImpl::verifyCertificateChainClient(PCCERT_CONTEXT pServerCert)
throw SSLException("Cannot get certificate chain", GetLastError()); throw SSLException("Cannot get certificate chain", GetLastError());
} }
HTTPSPolicyCallbackData polHttps; HTTPSPolicyCallbackData polHttps;
std::memset(&polHttps, 0, sizeof(HTTPSPolicyCallbackData)); std::memset(&polHttps, 0, sizeof(HTTPSPolicyCallbackData));
polHttps.cbStruct = sizeof(HTTPSPolicyCallbackData); polHttps.cbStruct = sizeof(HTTPSPolicyCallbackData);
polHttps.dwAuthType = AUTHTYPE_SERVER; polHttps.dwAuthType = AUTHTYPE_SERVER;
@@ -1309,14 +1311,14 @@ void SecureSocketImpl::verifyCertificateChainClient(PCCERT_CONTEXT pServerCert)
void SecureSocketImpl::serverVerifyCertificate() void SecureSocketImpl::serverVerifyCertificate()
{ {
if (_pContext->verificationMode() < Context::VERIFY_STRICT) return; if (_pContext->verificationMode() < Context::VERIFY_STRICT) return;
// we are now in Strict mode // we are now in Strict mode
if (!_pPeerCertificate) throw SSLException("No client certificate"); if (!_pPeerCertificate) throw SSLException("No client certificate");
DWORD status = SEC_E_OK; DWORD status = SEC_E_OK;
X509Certificate cert(_pPeerCertificate, true); X509Certificate cert(_pPeerCertificate, true);
PCCERT_CHAIN_CONTEXT pChainContext = NULL; PCCERT_CHAIN_CONTEXT pChainContext = NULL;
CERT_CHAIN_PARA chainPara; CERT_CHAIN_PARA chainPara;
std::memset(&chainPara, 0, sizeof(chainPara)); std::memset(&chainPara, 0, sizeof(chainPara));
@@ -1330,7 +1332,7 @@ void SecureSocketImpl::serverVerifyCertificate()
&chainPara, &chainPara,
CERT_CHAIN_REVOCATION_CHECK_CHAIN, CERT_CHAIN_REVOCATION_CHECK_CHAIN,
NULL, NULL,
&pChainContext)) &pChainContext))
{ {
throw SSLException("Cannot get certificate chain", GetLastError()); throw SSLException("Cannot get certificate chain", GetLastError());
} }
@@ -1351,37 +1353,37 @@ void SecureSocketImpl::serverVerifyCertificate()
std::memset(&policyStatus, 0, sizeof(policyStatus)); std::memset(&policyStatus, 0, sizeof(policyStatus));
policyStatus.cbSize = sizeof(policyStatus); policyStatus.cbSize = sizeof(policyStatus);
if (!CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, pChainContext, &policyPara, &policyStatus)) if (!CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, pChainContext, &policyPara, &policyStatus))
{ {
VerificationErrorArgs args(cert, 0, GetLastError(), "Failed to verify certificate chain"); VerificationErrorArgs args(cert, 0, GetLastError(), "Failed to verify certificate chain");
SSLManager::instance().ServerVerificationError(this, args); SSLManager::instance().ServerVerificationError(this, args);
CertFreeCertificateChain(pChainContext); CertFreeCertificateChain(pChainContext);
if (!args.getIgnoreError()) if (!args.getIgnoreError())
throw SSLException("Cannot verify certificate chain"); throw SSLException("Cannot verify certificate chain");
else else
return; return;
} }
else if (policyStatus.dwError) else if (policyStatus.dwError)
{ {
VerificationErrorArgs args(cert, policyStatus.lElementIndex, status, Utility::formatError(policyStatus.dwError)); VerificationErrorArgs args(cert, policyStatus.lElementIndex, status, Utility::formatError(policyStatus.dwError));
SSLManager::instance().ServerVerificationError(this, args); SSLManager::instance().ServerVerificationError(this, args);
CertFreeCertificateChain(pChainContext); CertFreeCertificateChain(pChainContext);
if (!args.getIgnoreError()) if (!args.getIgnoreError())
throw SSLException("Failed to verify certificate chain"); throw SSLException("Failed to verify certificate chain");
else else
return; return;
} }
#if !defined(_WIN32_WCE) #if !defined(_WIN32_WCE)
// perform revocation checking // perform revocation checking
for (DWORD i = 0; i < pChainContext->cChain; i++) for (DWORD i = 0; i < pChainContext->cChain; i++)
{ {
std::vector<PCCERT_CONTEXT> certs; std::vector<PCCERT_CONTEXT> certs;
for (DWORD k = 0; k < pChainContext->rgpChain[i]->cElement; k++) for (DWORD k = 0; k < pChainContext->rgpChain[i]->cElement; k++)
{ {
certs.push_back(pChainContext->rgpChain[i]->rgpElement[k]->pCertContext); certs.push_back(pChainContext->rgpChain[i]->rgpElement[k]->pCertContext);
} }
CERT_REVOCATION_STATUS revStat; CERT_REVOCATION_STATUS revStat;
revStat.cbSize = sizeof(CERT_REVOCATION_STATUS); revStat.cbSize = sizeof(CERT_REVOCATION_STATUS);
@@ -1405,7 +1407,7 @@ void SecureSocketImpl::serverVerifyCertificate()
} }
} }
#endif #endif
if (pChainContext) if (pChainContext)
{ {
CertFreeCertificateChain(pChainContext); CertFreeCertificateChain(pChainContext);
} }
@@ -1467,7 +1469,7 @@ LONG SecureSocketImpl::serverDisconnect(PCredHandle phCreds, CtxtHandle* phConte
} }
AutoSecBufferDesc<1> tokBuffer(&_securityFunctions, false); AutoSecBufferDesc<1> tokBuffer(&_securityFunctions, false);
DWORD tokenType = SCHANNEL_SHUTDOWN; DWORD tokenType = SCHANNEL_SHUTDOWN;
tokBuffer.setSecBufferToken(0, &tokenType, sizeof(tokenType)); tokBuffer.setSecBufferToken(0, &tokenType, sizeof(tokenType));
DWORD status = _securityFunctions.ApplyControlToken(phContext, &tokBuffer); DWORD status = _securityFunctions.ApplyControlToken(phContext, &tokBuffer);
@@ -1475,9 +1477,9 @@ LONG SecureSocketImpl::serverDisconnect(PCredHandle phCreds, CtxtHandle* phConte
if (FAILED(status)) return status; if (FAILED(status)) return status;
DWORD sspiFlags = ASC_REQ_SEQUENCE_DETECT DWORD sspiFlags = ASC_REQ_SEQUENCE_DETECT
| ASC_REQ_REPLAY_DETECT | ASC_REQ_REPLAY_DETECT
| ASC_REQ_CONFIDENTIALITY | ASC_REQ_CONFIDENTIALITY
| ASC_REQ_EXTENDED_ERROR | ASC_REQ_EXTENDED_ERROR
| ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_ALLOCATE_MEMORY
| ASC_REQ_STREAM; | ASC_REQ_STREAM;