/*********************************************************************
 *
 *           Copyright (c) 2021 by Visuality Systems, Ltd.
 *
 *********************************************************************
 * FILE NAME     : $Workfile:$
 * ID            : $Header:$
 * REVISION      : $Revision:$
 *--------------------------------------------------------------------
 * DESCRIPTION   : SMB2 Transform Header command handler
 *--------------------------------------------------------------------
 * MODULE        : Server
 * DEPENDENCIES  :
 ********************************************************************/
#include "cmsmb2.h"
#include "csdataba.h"
#include "cmcrypt.h"

#if defined(UD_NQ_INCLUDECIFSSERVER) && defined(UD_NQ_INCLUDESMB3)

/*
 * Build the SMB2 TRANSFORM_HEADER for an outgoing message and encrypt the message.
 *
 * user     - user whose session keys and nonce are used. When NULL, the user is
 *            looked up by the session ID found in the SMB2 header located at
 *            response + SMB2_TRANSFORMHEADER_SIZE.
 * response - buffer that starts with SMB2_TRANSFORMHEADER_SIZE bytes reserved for the
 *            transform header, followed by 'length' bytes of the plain SMB2 message.
 *            The transform header is written into the reserved bytes, the message is
 *            handed to the cipher at its current location (presumably encrypted in
 *            place - confirm against the AES helpers) and the authentication tag is
 *            written at offset sizeof(cmSmb2TrnsfrmHdrProtocolId).
 * length   - size of the plain SMB2 message; stored as OriginalMessageSize.
 *
 * Returns TRUE on success, FALSE when no user or no session could be found.
 */
NQ_BOOL cs2TransformHeaderEncrypt(CSUser *user, NQ_BYTE *response, NQ_COUNT length)
{
    CSUser *pUser = NULL;
    CMBufferReader reader;
    CMBufferWriter writer;
    CMSmb2Header smb2Header;
    CMSmb2TransformHeader tranHeader;
    CSSession *pSession = NULL;
    NQ_BOOL result = FALSE;

    LOGFB(CM_TRC_LEVEL_FUNC_TOOL, "user:%p response:%p length:%d", user, response, length);

    syMemset(&tranHeader, 0, sizeof(CMSmb2TransformHeader));

    if (NULL != user)
    {
        pUser = user;
        tranHeader.sid.low = (NQ_UINT32)uidToSessionId(pUser->uid);
    }
    else
    {
        /* no user given: resolve it from the session ID in the plain SMB2 header */
        cmBufferReaderInit(&reader, response + SMB2_TRANSFORMHEADER_SIZE, length);
        cmSmb2HeaderRead(&smb2Header, &reader);
        pUser = csGetUserByUid((CSUid)sessionIdToUid(smb2Header.sid.low));
        if (NULL == pUser)
        {
            LOGERR(CM_TRC_LEVEL_ERROR, "No User found");
            goto Exit;
        }
        tranHeader.sid = smb2Header.sid;
    }

    pSession = csGetSessionById(pUser->session);
    if (NULL == pSession)
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "No session found");
        goto Exit;
    }

    tranHeader.originalMsgSize = length;
    /* starting with 3.1.1 this field is called Flags */
    tranHeader.encryptionArgorithm = (pSession->dialect >= CS_DIALECT_SMB311) ? SMB2_USE_NEGOTIATED_CIPHER : SMB2_ENCRYPTION_AES128_CCM;
    /* use the nonce stored for this user (cs2TransformHeaderDecrypt stores the received one there).
       It is copied with the GCM nonce size, so for CCM this may copy one byte more than the nonce. */
    syMemcpy(tranHeader.nonce, pUser->encryptNonce, SMB2_AES128_GCM_NONCE_SIZE);
    cmBufferWriterInit(&writer, response, SMB2_TRANSFORMHEADER_SIZE);
    cmSmb2TransformHeaderWrite(&tranHeader, &writer);

    /* Encrypted part: the SMB2 message (header + payload) that follows the transform header.
       Authenticated part: the transform header from the nonce to its end, i.e. everything except
       the protocol ID (4 bytes) and the signature (16 bytes), so its first 20 bytes are not authenticated. */
    if (pSession->dialect >= CS_DIALECT_SMB311 && pSession->isAesGcm)
    {
        /* NOTE(review): static scratch buffer shared by every call - confirm encryption never runs concurrently */
        static NQ_BYTE keyBuffer[AES_PRIV_SIZE];
#ifdef UD_CS_INCLUDEEXTERNALNOTIFY
        /* declared at block start (before any statement) to stay C89-compatible */
        NQ_BYTE encryptNonceReadOnly[SMB2_ENCRYPTION_HDR_NONCE_SIZE];
#endif /* UD_CS_INCLUDEEXTERNALNOTIFY */

        LOGMSG(CM_TRC_LEVEL_MESS_SOME, "Encrypt with GCM");

#ifdef UD_CS_INCLUDEEXTERNALNOTIFY
        /* encrypt with a private copy of the nonce so pUser->encryptNonce is never passed to the cipher
           (presumably because the cipher may modify its nonce argument - confirm) */
        syMemcpy(encryptNonceReadOnly, pUser->encryptNonce, SMB2_ENCRYPTION_HDR_NONCE_SIZE);

        aes128GcmEncrypt(pUser->encryptionKey, encryptNonceReadOnly, response + SMB2_TRANSFORMHEADER_SIZE, length, response + SMB2_TRANSFORMHEADER_OFFSET_TO_NONCE,
               SMB2_TRANSFORMHEADER_SIZE - SMB2_TRANSFORMHEADER_OFFSET_TO_NONCE, response + sizeof(cmSmb2TrnsfrmHdrProtocolId), keyBuffer, NULL);
#else /* UD_CS_INCLUDEEXTERNALNOTIFY */
        aes128GcmEncrypt(pUser->encryptionKey, pUser->encryptNonce, response + SMB2_TRANSFORMHEADER_SIZE, length, response + SMB2_TRANSFORMHEADER_OFFSET_TO_NONCE,
            SMB2_TRANSFORMHEADER_SIZE - SMB2_TRANSFORMHEADER_OFFSET_TO_NONCE, response + sizeof(cmSmb2TrnsfrmHdrProtocolId), keyBuffer, NULL);
#endif /* UD_CS_INCLUDEEXTERNALNOTIFY */
    }
    else
    {
        LOGMSG(CM_TRC_LEVEL_MESS_SOME, "Encrypt with CCM");

        AES_128_CCM_Encrypt(pUser->encryptionKey, pUser->encryptNonce, response + SMB2_TRANSFORMHEADER_SIZE, length, response + SMB2_TRANSFORMHEADER_OFFSET_TO_NONCE,
            SMB2_TRANSFORMHEADER_SIZE - SMB2_TRANSFORMHEADER_OFFSET_TO_NONCE, response + sizeof(cmSmb2TrnsfrmHdrProtocolId));
    }

    result = TRUE;

Exit:
    LOGFE(CM_TRC_LEVEL_FUNC_TOOL, "result:%s", result ? "TRUE" : "FALSE");
    return result;
}

/*
 * Receive the rest of an SMB2 TRANSFORM_HEADER message from the socket and decrypt it.
 *
 * recvDescr - receive descriptor the remaining bytes are read from.
 * request   - buffer whose first sizeof(cmSmb2TrnsfrmHdrProtocolId) bytes (the transform
 *             protocol ID) have already been received. The rest of the transform header is
 *             received right after them and the encrypted message at
 *             request + SMB2_TRANSFORMHEADER_SIZE, where it is passed to the cipher
 *             (presumably decrypted in place - confirm against the AES helpers).
 * length    - the buffer is treated as holding length + sizeof(cmSmb2TrnsfrmHdrProtocolId)
 *             bytes starting at 'request' (NOTE(review): confirm this contract with the caller).
 *
 * The received nonce is copied into pUser->encryptNonce (before authentication), where
 * cs2TransformHeaderEncrypt picks it up for the response.
 *
 * Returns TRUE when the message was received and authenticated, FALSE on socket error,
 * unknown user/session, an invalid OriginalMessageSize or an authentication failure.
 */
NQ_BOOL cs2TransformHeaderDecrypt(NSRecvDescr *recvDescr, NQ_BYTE *request, NQ_COUNT length)
{
    /* size of an SMB2 header (MS-SMB2 2.2.1): the smallest message that can be carried */
    enum { MIN_ORIGINAL_MSG_SIZE = 64 };
    CMSmb2TransformHeader header;
    NQ_BYTE nonce[SMB2_ENCRYPTION_HDR_NONCE_SIZE];
    CMBufferReader reader;
    NQ_BYTE *pBuf = request + sizeof(cmSmb2TrnsfrmHdrProtocolId);
    CSUser *pUser;
    CSSession *pSession;
    NQ_UINT msgLen;
    NQ_UINT nonceSize;
    NQ_COUNT bufferLen = length + (NQ_COUNT)sizeof(cmSmb2TrnsfrmHdrProtocolId);
    NQ_BOOL result = FALSE;

    LOGFB(CM_TRC_LEVEL_FUNC_TOOL, "recvDescr:%p request:%p length:%d", recvDescr, request, length);

    syMemset(nonce, 0, sizeof(nonce));

    /* receive the rest of the transform header */
    if (NQ_FAIL == nsRecvIntoBuffer(recvDescr, pBuf, SMB2_TRANSFORMHEADER_SIZE - sizeof(cmSmb2TrnsfrmHdrProtocolId)))
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Error reading from socket");
        goto Exit;
    }

    cmBufferReaderInit(&reader, request, bufferLen);
    cmSmb2TransformHeaderRead(&header, &reader);

    /* OriginalMessageSize comes from the wire and is used below as the receive count:
       reject values too small to hold an SMB2 header or too large for the buffer. */
    if (bufferLen < SMB2_TRANSFORMHEADER_SIZE
        || header.originalMsgSize < MIN_ORIGINAL_MSG_SIZE
        || header.originalMsgSize > bufferLen - SMB2_TRANSFORMHEADER_SIZE)
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Invalid original message size");
        goto Exit;
    }

    pUser = csGetUserByUid((CSUid)sessionIdToUid(header.sid.low));
    if (NULL == pUser)
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "No User found");
        goto Exit;
    }

    if (NULL == (pSession = csGetSessionById(pUser->session)))
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "No session found");
        goto Exit;
    }

    /* GCM is used only for 3.1.1 sessions that negotiated it and whose request asks for the negotiated cipher */
    if (pSession->dialect >= CS_DIALECT_SMB311 && header.encryptionArgorithm == SMB2_USE_NEGOTIATED_CIPHER && pSession->isAesGcm)
    {
        nonceSize = SMB2_AES128_GCM_NONCE_SIZE;
    }
    else
    {
        nonceSize = SMB2_AES128_CCM_NONCE_SIZE;
    }

    /* keep only the bytes that belong to the nonce; the rest stays zero */
    syMemcpy(nonce, header.nonce, nonceSize);
    syMemset(pUser->encryptNonce, 0, SMB2_ENCRYPTION_HDR_NONCE_SIZE);
    syMemcpy(pUser->encryptNonce, nonce, nonceSize);
    pBuf = request + SMB2_TRANSFORMHEADER_SIZE;
    msgLen = (NQ_UINT)nsRecvIntoBuffer(recvDescr, pBuf, (NQ_COUNT)header.originalMsgSize); /* receive the encrypted packet */
    if ((NQ_UINT)NQ_FAIL == msgLen)
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Error receiving data");
        goto Exit;
    }

    /* authenticated data: the transform header from the nonce to its end */
    if (nonceSize == SMB2_AES128_GCM_NONCE_SIZE)
    {
        /* NOTE(review): static scratch buffer shared by every call - confirm decryption never runs concurrently */
        static NQ_BYTE keyBuffer[AES_PRIV_SIZE];

        LOGMSG(CM_TRC_LEVEL_MESS_SOME, "Decrypt with GCM, nonce size: %d", nonceSize);

        result = aes128GcmDecrypt(pUser->decryptionKey, nonce, pBuf, (NQ_UINT)header.originalMsgSize, request + SMB2_TRANSFORMHEADER_OFFSET_TO_NONCE,
                SMB2_TRANSFORMHEADER_SIZE - SMB2_TRANSFORMHEADER_OFFSET_TO_NONCE, header.signature, keyBuffer, NULL);
    }
    else
    {
        LOGMSG(CM_TRC_LEVEL_MESS_SOME, "Decrypt with CCM, nonce size: %d", nonceSize);

        result = AES_128_CCM_Decrypt(pUser->decryptionKey, nonce, pBuf, (NQ_UINT)header.originalMsgSize, request + SMB2_TRANSFORMHEADER_OFFSET_TO_NONCE,
                SMB2_TRANSFORMHEADER_SIZE - SMB2_TRANSFORMHEADER_OFFSET_TO_NONCE, header.signature);
    }

Exit:
    LOGFE(CM_TRC_LEVEL_FUNC_TOOL, "result:%s", result ? "TRUE" : "FALSE");
    return result;
}

#endif  /* defined(UD_NQ_INCLUDECIFSSERVER) && defined(UD_NQ_INCLUDESMB3) */
