/*********************************************************************
*
*           Copyright (c) 2021 by Visuality Systems, Ltd.
*
*********************************************************************
* FILE NAME     : $Workfile:$
* ID            : $Header:$
* REVISION      : $Revision:$
*--------------------------------------------------------------------
* DESCRIPTION   : Kerberos authentication module (client and server)
*--------------------------------------------------------------------
* MODULE        : auth - am
* DEPENDENCIES  : None
********************************************************************/

#include "amkerberos.h"

/*
 * Kerberos token identifiers: the 2-byte TOK_ID field that follows the
 * mechanism OID inside the APPLICATION-tagged krb5 blob.
 * They are read with cmBufferReadUint16() and written with cmRpcPackUint16() /
 * cmBufferWriteUint16() in this file.
 * NOTE(review): RFC 1964 gives the on-wire bytes as 01 00 / 02 00 / 03 00,
 * which matches these values only if those helpers use little-endian byte
 * order - confirm.
 */

#define KRB5_TOK_ID_AP_REQ     0x0001   /* client -> server (AP-REQ) */
#define KRB5_TOK_ID_AP_REP     0x0002   /* server -> client (AP-REP) */
#define KRB5_TOK_ID_ERROR      0x0003   /* KRB-ERROR; rejected by the client */

#if defined(UD_NQ_INCLUDECIFSSERVER) && defined(UD_CS_INCLUDEEXTENDEDSECURITY) && defined(UD_NQ_INCLUDEKERBEROS)

/*
* KERBEROS protocol definitions (server side)
*/

/* NOTE(review): not referenced anywhere in this file - possibly used elsewhere or obsolete */
#define MECHANISM_NAME  "MS KRB5"

/* service name passed to sySaslServerContextCreate() */
#define SERVICE_NAME    "cifs"

/* Per-call state of kerberosProcessor() */
typedef struct
{
    NQ_BYTE *ctx;           /* underlying SASL Kerberos context, disposed at the end of the call */
    const CMAsn1Oid *oid;   /* which accepted OID the client used; echoed back in the reply */
    CMBlob out;             /* reply data produced by sySaslServerAuthenticate(), if any */
} KrbContext;


/*
 * Server-side Kerberos mechanism processor, registered with SPNEGO through
 * mechKerberosServerDescriptor: process the incoming blob and compose the
 * response blob.
 *
 * Expected incoming krb5 blob layout:
 *   CM_ASN1_APPLICATION tag, mechanism OID (the descriptor's oid or
 *   oidSecondary), 2-byte token id (must be KRB5_TOK_ID_AP_REQ), AP-REQ data.
 * The AP-REQ data is authenticated by the SASL server layer. When that layer
 * returns reply data, it is wrapped the same way (same OID, token id
 * KRB5_TOK_ID_AP_REP) into `out`.
 *
 * in             - incoming blob. NOTE: in->length is overwritten with the
 *                  length parsed from the APPLICATION tag.
 * out            - outgoing blob; out->length stays 0 when nothing is packed.
 * descr          - isKrbAuthenticated is set to TRUE on success.
 * pOwnDomain, pOwnHostname, clientHostName - not used by this mechanism.
 * userName, pDomain - passed to sySaslServerAuthenticate() (presumably filled
 *                  with the authenticated identity - confirm).
 * pSessionKey    - on success receives a buffer of SMB_SESSIONKEY_LENGTH
 *                  bytes allocated with cmMemoryAllocate() (presumably freed
 *                  by the caller); reset to NULL if the key cannot be read.
 *
 * Returns AM_STATUS_NOT_AUTHENTICATED when authentication is done (this
 * module's convention for "done") or AM_STATUS_GENERIC on any failure.
 */
static NQ_UINT32                    /* AM_STATUS_* code */
kerberosProcessor(
    CMRpcPacketDescriptor * in,
    CMRpcPacketDescriptor * out,
    AMNtlmDescriptor * descr,
    const NQ_WCHAR *pOwnDomain,
    const NQ_WCHAR *pOwnHostname,
    NQ_WCHAR * userName,
    NQ_WCHAR * clientHostName,
    NQ_WCHAR * pDomain,
    const NQ_BYTE** pSessionKey
)
{
    NQ_UINT32 status = AM_STATUS_GENERIC;
    KrbContext krbContext = { 0 };
    NQ_COUNT sessionKeyLen = SMB_SESSIONKEY_LENGTH;
    NQ_COUNT length;
    NQ_UINT16 token = 0;            /* defensive: if the read below stores nothing, 0 is rejected */
    CMAsn1Tag tag;
    NQ_BYTE* bufPosition;

    LOGFB(CM_TRC_LEVEL_FUNC_COMMON, "in:%p out:%p descr:%p user:%p domain:%p key:%p", in, out, descr, userName, pDomain, pSessionKey);

    out->length = 0;

    /* krb5_blob parsing */
    tag = cmAsn1ParseTag(in, &in->length);
    if (CM_ASN1_APPLICATION != tag)
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Unexpected tag");
        goto Exit;
    }

    /* accept either Kerberos OID and remember which one so the reply echoes it */
    if (cmAsn1ParseCompareOid(in, amKerberosGetServerDescriptor()->oid, TRUE))
    {
        krbContext.oid = amKerberosGetServerDescriptor()->oid;
    }
    else if (cmAsn1ParseCompareOid(in, amKerberosGetServerDescriptor()->oidSecondary, TRUE))
    {
        krbContext.oid = amKerberosGetServerDescriptor()->oidSecondary;
    }
    else
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Not Kerberos OID");
        goto Exit;
    }

    cmBufferReadUint16(in, &token);
    if (token != KRB5_TOK_ID_AP_REQ)
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Unexpected Kerberos token");
        goto Exit;
    }

    krbContext.ctx = sySaslServerContextCreate(SERVICE_NAME);
    if (NULL == krbContext.ctx)
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Failed to create context");
        goto Exit;
    }

    krbContext.out.len = 0;
    length = in->length - cmRpcGetDataCount(in);
    bufPosition = cmRpcGetPositionBytePtr(in, 0, length);
    if (NULL == bufPosition)
    {
        /* never hand a NULL buffer to the SASL layer */
        LOGERR(CM_TRC_LEVEL_ERROR, "Invalid Kerberos blob length");
        goto Exit;
    }

    /* NOTE(review): `length` (the byte count validated above) is computed, but
       in->length (the whole APPLICATION tag length, which also counts the OID
       and token id already consumed) is what is passed to the SASL layer.
       Confirm which length sySaslServerAuthenticate() expects. */
    if (!sySaslServerAuthenticate(krbContext.ctx, bufPosition, in->length, userName, pDomain, &krbContext.out.data, &krbContext.out.len))
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Failed to authenticate");
        goto Exit;
    }

    *pSessionKey = (const NQ_BYTE *)cmMemoryAllocate(sessionKeyLen);
    if (NULL == *pSessionKey)
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Memory failure");
        goto Exit;
    }

    if (!sySaslGetSessionKey(krbContext.ctx, (NQ_BYTE *)*pSessionKey, &sessionKeyLen))
    {
        cmMemoryFree(*pSessionKey);
        *pSessionKey = NULL;        /* do not return a dangling pointer to the caller */
        LOGERR(CM_TRC_LEVEL_ERROR, "Failed to get session key");
        goto Exit;
    }

    /* pack outgoing krb5_blob: APPLICATION { OID, AP-REP token id, reply data } */
    if (krbContext.out.len > 0)
    {
        /* reply data + 2-byte token id + OID TLV (tag byte, length bytes, OID bytes) */
        NQ_COUNT mechToken = krbContext.out.len + 2 + cmAsn1PackLen(krbContext.oid->size) + 1 + krbContext.oid->size;

        cmAsn1PackTag(out, CM_ASN1_APPLICATION, mechToken);
        cmAsn1PackOid(out, krbContext.oid);
        cmRpcPackUint16(out, KRB5_TOK_ID_AP_REP);
        cmRpcPackBytes(out, (const NQ_BYTE *)krbContext.out.data, krbContext.out.len);
        out->length = cmBufferWriterGetDataCount(out);
    }

    status = AM_STATUS_NOT_AUTHENTICATED; /* means authentication done */
    descr->isKrbAuthenticated = TRUE;

Exit:
    if (NULL != krbContext.ctx)
    {
        sySaslContextDispose(krbContext.ctx);
    }

    LOGFE(CM_TRC_LEVEL_FUNC_COMMON, "result:0x%x", status);
    return status;
}

/*
 * SPNEGO server mechanism descriptor for Kerberos: primary OID (MS Kerberos),
 * secondary OID (standard Kerberos) and the blob processor.
 * The object is only read, through the const pointer returned by
 * amKerberosGetServerDescriptor(), so it is declared const and may be placed
 * in read-only memory.
 */
static const AMSpnegoServerMechDescriptor mechKerberosServerDescriptor =
{
    &cmGssApiOidMsKerberos,     /* primary OID */
    &cmGssApiOidKerberos,       /* secondary OID */
    kerberosProcessor           /* incoming blob processor */
};

/*
 * Return the Kerberos SPNEGO server mechanism descriptor (file-scope
 * singleton; never NULL).
 * Uses an explicit (void) parameter list: an empty () in C declares a
 * function with unspecified arguments, which disables argument checking.
 */
const AMSpnegoServerMechDescriptor * amKerberosGetServerDescriptor(void)
{
    return &mechKerberosServerDescriptor;
}


/*
 * Read the session key from a server-side SASL Kerberos context.
 *
 * context - SASL context
 * buffer  - destination for the key
 * len     - key length; forwarded as-is to sySaslGetSessionKey()
 *           (presumably buffer size in, key size out - confirm)
 *
 * Returns whatever sySaslGetSessionKey() returns.
 */
NQ_BOOL
amKerberosGetSessionKey(
    NQ_BYTE* context,
    NQ_BYTE* buffer,
    NQ_COUNT* len
)
{
    NQ_BOOL keyRetrieved;

    keyRetrieved = sySaslGetSessionKey(context, buffer, len);
    return keyRetrieved;
}

#endif /* defined(UD_NQ_INCLUDECIFSSERVER) && defined(UD_CS_INCLUDEEXTENDEDSECURITY) && defined(UD_NQ_INCLUDEKERBEROS) */

#if defined(UD_NQ_INCLUDECIFSCLIENT) && defined(UD_CC_INCLUDEEXTENDEDSECURITY) && defined(UD_CC_INCLUDEEXTENDEDSECURITY_KERBEROS)

/* Initialise the client-side SASL Kerberos layer; `p` is forwarded unchanged. */
NQ_BOOL amKerberosClientInit(void *p)
{
    NQ_BOOL initDone = sySaslClientInit(p);

    return initDone;
}

/* Shut down the client-side SASL Kerberos layer. */
NQ_BOOL amKerberosClientStop(void)
{
    NQ_BOOL stopDone = sySaslClientStop();

    return stopDone;
}

/* Select the SASL mechanism `name` on the client context `ctx`. */
NQ_BOOL amKerberosClientSetMechanism(NQ_BYTE* ctx, const NQ_CHAR* name)
{
    NQ_BOOL mechanismSet = sySaslClientSetMechanism(ctx, name);

    return mechanismSet;
}

/*
 * Create a client-side SASL Kerberos context for `name`.
 * The context is always created with the fixed flags (FALSE, TRUE).
 * NOTE(review): `restrictCrypt` is accepted but ignored - confirm whether it
 * was meant to be forwarded to sySaslContextCreate().
 * Returns the new context (NULL presumably means failure).
 */
NQ_BYTE* amKerberosClientContextCreate(const NQ_CHAR* name, NQ_BOOL restrictCrypt)
{
    NQ_BYTE *newContext;

    (void)restrictCrypt;    /* intentionally unused, see note above */
    newContext = sySaslContextCreate(name, FALSE, TRUE);
    return newContext;
}

/*
 * Read the session key from a client-side SASL Kerberos context.
 * `blob` is forwarded as (data, len) to sySaslGetSessionKey().
 * NOTE(review): the server-side code in this file calls sySaslGetSessionKey()
 * with three arguments, while this call uses five - confirm against the
 * sySaslGetSessionKey() prototype for the build configuration in use.
 */
NQ_BOOL amKerberosClientGetSessionKey(NQ_BYTE* ctx, NQ_BYTE* buffer, NQ_COUNT* len, CMBlob blob)
{
    NQ_BOOL keyRetrieved = sySaslGetSessionKey(ctx, buffer, len, blob.data, blob.len);

    return keyRetrieved;
}

/* Report whether `ctx` is a valid client-side SASL Kerberos context. */
NQ_BOOL amKerberosClientContextIsValid(const NQ_BYTE* ctx)
{
    NQ_BOOL isValid = sySaslContextIsValid(ctx);

    return isValid;
}

/* Release a client-side SASL Kerberos context created by amKerberosClientContextCreate(). */
NQ_BOOL amKerberosClientContextDispose(NQ_BYTE* ctx)
{
    NQ_BOOL disposed = sySaslContextDispose(ctx);

    return disposed;
}

/*
 * Produce the first client request blob through the SASL layer.
 * `blob` / `blobLen` receive the generated request (set by
 * sySaslClientGenerateFirstRequest()).
 */
NQ_BOOL amKerberosClientGenerateFirstRequest(NQ_BYTE * ctx, const NQ_CHAR * mechList, NQ_BYTE ** blob, NQ_COUNT * blobLen)
{
    NQ_BOOL generated = sySaslClientGenerateFirstRequest(ctx, mechList, blob, blobLen);

    return generated;
}

/*
 * Validate the krb5 wrapper of a server response, then let the SASL layer
 * build the next client request from it.
 *
 * ctx        - SASL client context
 * inBlob     - server response: CM_ASN1_APPLICATION tag, mechanism OID
 *              (mechanism->oid or mechanism->reqOid), 2-byte token id, data.
 *              The whole blob, not just the data, is forwarded to the SASL layer.
 * inBlobLen  - length of inBlob
 * outBlob    - receives the next request (set by the SASL layer)
 * outBlobLen - receives its length
 * con        - SPNEGO client security context (AMSpnegoClientSecurityContext)
 *
 * Returns FALSE on an unexpected tag or OID, or on a KRB5 error token;
 * otherwise returns the result of sySaslClientGenerateNextRequest().
 */
NQ_BOOL amKerberosClientGenerateNextRequest(NQ_BYTE * ctx, const NQ_BYTE * inBlob, NQ_COUNT inBlobLen, NQ_BYTE ** outBlob, NQ_COUNT* outBlobLen, NQ_BYTE* con)
{
    AMSpnegoClientSecurityContext *spnegoCtx = (AMSpnegoClientSecurityContext *)con;
    CMBufferReader blobReader;
    CMAsn1Len outerLen;
    CMAsn1Tag outerTag;
    NQ_UINT16 tokenId;
    NQ_BOOL generated = FALSE;

    LOGFB(CM_TRC_LEVEL_FUNC_PROTOCOL, "context:%p inBlob:%p inLen:%d outBlob:%p outLen:%p con:%p", ctx, inBlob, inBlobLen, outBlob, outBlobLen, con);

    cmBufferReaderInit(&blobReader, (NQ_BYTE *)inBlob, inBlobLen);

    /* outer wrapper of the supported mechanism token */
    outerTag = cmAsn1ParseTag(&blobReader, &outerLen);
    if (outerTag != CM_ASN1_APPLICATION)
    {
        logBadTag(CM_ASN1_APPLICATION, outerTag);
    }
    else if (!cmAsn1ParseCompareOid(&blobReader, spnegoCtx->mechanism->oid, TRUE) &&
             !cmAsn1ParseCompareOid(&blobReader, spnegoCtx->mechanism->reqOid, TRUE))
    {
        /* neither of the two accepted OIDs */
        LOGERR(CM_TRC_LEVEL_ERROR, "Unexpected mechanism in KRB5 server response");
    }
    else
    {
        cmBufferReadUint16(&blobReader, &tokenId);
        if (tokenId == KRB5_TOK_ID_ERROR)
        {
            LOGERR(CM_TRC_LEVEL_ERROR, "KRB5 error occured");
        }
        else
        {
            generated = sySaslClientGenerateNextRequest(ctx, inBlob, inBlobLen, outBlob, outBlobLen, con);
        }
    }

    LOGFE(CM_TRC_LEVEL_FUNC_PROTOCOL, "result: %d", generated);
    return generated;
}

/*
 * Pack the header of the initial client SPNEGO blob for Kerberos.
 *
 * Written, in order (tag names as used by this module):
 *   APPLICATION {
 *     SPNEGO OID
 *     CONTEXT { SEQUENCE {
 *       CONTEXT { SEQUENCE { mechanism->oid, mechanism->reqOid, NTLMSSP OID } }
 *       CONTEXT + 2 { BINARY { APPLICATION {
 *         mechanism->reqOid, KRB5_TOK_ID_AP_REQ, <mechtokenBlobLen bytes> } } }
 *     } }
 *   }
 * The trailing mechtokenBlobLen bytes are NOT written here (presumably the
 * caller appends them - confirm), but every length below accounts for them.
 *
 * context          - SPNEGO client security context (AMSpnegoClientSecurityContext)
 * writer           - destination writer
 * mechtokenBlobLen - length of the mechanism data following the token id
 * blobLen          - in: space available; out: total blob length required
 *
 * Returns TRUE when the header was packed, FALSE (nothing written) when
 * *blobLen is smaller than the required total.
 */
NQ_BOOL amKerberosClientPackNegotBlob(void * context, CMBufferWriter * writer, NQ_COUNT mechtokenBlobLen, NQ_COUNT * blobLen)
{
    NQ_COUNT gssapiLen;         /* content length of the outer APPLICATION tag */
    NQ_COUNT spnegoLen;         /* content length of the CONTEXT tag after the SPNEGO OID */
    NQ_COUNT negtokenLen;       /* content length of the SEQUENCE inside it */
    NQ_COUNT mechtypesLen;      /* content length of the mechTypes CONTEXT tag */
    NQ_COUNT mechtypesSeqLen;   /* content length of the mechTypes SEQUENCE (3 OIDs) */
    NQ_COUNT mechtokenBinLen;   /* content length of the BINARY tag */
    NQ_COUNT mechtokenAppLen;   /* content length of the inner APPLICATION tag */
    NQ_COUNT mechtokenLen;      /* content length of the CONTEXT + 2 tag */
    NQ_COUNT len1, len2;        /* sizes of mechanism->oid and mechanism->reqOid */
    NQ_COUNT totalLen;          /* total packed blob length */
    NQ_BOOL res = FALSE;        /* return value */
    AMSpnegoClientSecurityContext * pContext = (AMSpnegoClientSecurityContext *)context;

    /* calculate field lengths - backwards (innermost first); every TLV takes
       1 tag byte + cmAsn1PackLen(content) length bytes + content */
    len1 = pContext->mechanism->oid->size;
    len2 = pContext->mechanism->reqOid->size;
    mechtypesSeqLen = 1 + cmAsn1PackLen(len1) + len1 + 1 + cmAsn1PackLen(cmGssApiOidNtlmSsp.size) + cmGssApiOidNtlmSsp.size + 1 + cmAsn1PackLen(len2) + len2;
    mechtypesLen = 1 + cmAsn1PackLen(mechtypesSeqLen) + mechtypesSeqLen;
    /* the inner APPLICATION tag carries mechanism->reqOid (packed below), so its
       size must come from len2, not from mechanism->oid; then 2 bytes of token id */
    mechtokenAppLen = 1 + cmAsn1PackLen(len2) + len2 + 2 + mechtokenBlobLen;
    mechtokenBinLen = 1 + cmAsn1PackLen(mechtokenAppLen) + mechtokenAppLen;
    mechtokenLen = 1 + cmAsn1PackLen(mechtokenBinLen) + mechtokenBinLen;
    negtokenLen = 1 + cmAsn1PackLen(mechtypesLen) + mechtypesLen + 1 + cmAsn1PackLen(mechtokenLen) + mechtokenLen;
    spnegoLen = 1 + cmAsn1PackLen(negtokenLen) + negtokenLen;
    gssapiLen = 1 + cmAsn1PackLen(spnegoLen) + spnegoLen + 1 + cmAsn1PackLen(cmGssApiOidSpnego.size) + cmGssApiOidSpnego.size;
    totalLen = 1 + cmAsn1PackLen(gssapiLen) + gssapiLen;
    if (*blobLen < totalLen)
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "supplied %d, required %d", *blobLen, totalLen);
        goto Exit;
    }

    cmAsn1PackTag(writer, CM_ASN1_APPLICATION, gssapiLen);
    cmAsn1PackOid(writer, &cmGssApiOidSpnego);
    cmAsn1PackTag(writer, CM_ASN1_CONTEXT, spnegoLen);
    cmAsn1PackTag(writer, CM_ASN1_SEQUENCE, negtokenLen);
    cmAsn1PackTag(writer, CM_ASN1_CONTEXT, mechtypesLen);
    cmAsn1PackTag(writer, CM_ASN1_SEQUENCE, mechtypesSeqLen);
    cmAsn1PackOid(writer, pContext->mechanism->oid);
    cmAsn1PackOid(writer, pContext->mechanism->reqOid);
    cmAsn1PackOid(writer, &cmGssApiOidNtlmSsp);
    cmAsn1PackTag(writer, CM_ASN1_CONTEXT + 2, mechtokenLen);
    cmAsn1PackTag(writer, CM_ASN1_BINARY, mechtokenBinLen);
    cmAsn1PackTag(writer, CM_ASN1_APPLICATION, mechtokenAppLen);
    cmAsn1PackOid(writer, pContext->mechanism->reqOid);
    cmBufferWriteUint16(writer, KRB5_TOK_ID_AP_REQ);
    res = TRUE;

Exit:
    /* report the required size on both success and failure */
    *blobLen = totalLen;
    return res;
}

/*
 * Parse the optional Kerberos part of the final server response and keep
 * the server's AP-REP data in the SPNEGO context.
 *
 * The reader is expected to be positioned at the CONTEXT + 1 element
 * ("supported mech"; presumably from the SPNEGO response token). Parsed in
 * order, each of these may be missing at the points marked below:
 *   CONTEXT + 1 { mechanism OID }
 *   CONTEXT + 2 { BINARY { APPLICATION { mechanism OID, token id, data } } }
 *
 * context - SPNEGO client security context (AMSpnegoClientSecurityContext)
 * reader  - reader over the rest of the server response
 *
 * On an AP-REP token, pContext->dataBlob is set to the bytes following the
 * token id (pointer obtained from the reader, not a copy).
 *
 * Returns TRUE on success or when the blob ends at one of the tolerated
 * points; FALSE on an unexpected tag, OID or token, or a truncated token id.
 */
NQ_BOOL amKerberosClientOnComplete(NQ_BYTE *context, CMBufferReader *reader)
{
    CMAsn1Len len;                      /* length of the current field */
    CMAsn1Tag tag;                      /* next tag */
    NQ_UINT16 token = 0;                /* KRB5 token id */
    NQ_COUNT remaining;                 /* bytes left after the token id */
    NQ_BOOL result = FALSE;
    AMSpnegoClientSecurityContext *pContext = (AMSpnegoClientSecurityContext *)context;

    LOGFB(CM_TRC_LEVEL_FUNC_COMMON, "context:%p reader:%p", context, reader);

    /* KRB blob might not be present in final successful response */
    if (0 == cmBufferReaderGetRemaining(reader))
    {
        result = TRUE;
        goto Exit;
    }

    tag = cmAsn1ParseTag(reader, &len);            /* supported mech */
    if (CM_ASN1_CONTEXT + 1 != tag)
    {
        logBadTag(CM_ASN1_CONTEXT + 1, tag);
        goto Exit;
    }
    if (!cmAsn1ParseCompareOid(reader, pContext->mechanism->oid, TRUE) && !cmAsn1ParseCompareOid(reader, pContext->mechanism->reqOid, TRUE))
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Unexpected mechanism in KRB5 server response");
        goto Exit;
    }
    /* KRB blob might be incomplete in final successful response */
    if (0 == cmBufferReaderGetRemaining(reader))
    {
        result = TRUE;
        goto Exit;
    }
    /* authentication blob */
    tag = cmAsn1ParseTag(reader, &len);            /* response token */
    if (CM_ASN1_CONTEXT + 2 != tag)
    {
        logBadTag(CM_ASN1_CONTEXT + 2, tag);
        goto Exit;
    }
    tag = cmAsn1ParseTag(reader, &len);            /* response token */
    if (CM_ASN1_BINARY != tag)
    {
        logBadTag(CM_ASN1_BINARY, tag);
        goto Exit;
    }
    /* KRB blob might be incomplete in final successful response */
    if (0 == cmBufferReaderGetRemaining(reader))
    {
        result = TRUE;
        goto Exit;
    }
    tag = cmAsn1ParseTag(reader, &len);            /* supported mech */
    if (CM_ASN1_APPLICATION != tag)
    {
        logBadTag(CM_ASN1_APPLICATION, tag);
        goto Exit;
    }
    if (!cmAsn1ParseCompareOid(reader, pContext->mechanism->oid, TRUE) && !cmAsn1ParseCompareOid(reader, pContext->mechanism->reqOid, TRUE))
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Unexpected mechanism in KRB5 server response");
        goto Exit;
    }

    /* the 2-byte token id must be present before it is read: the blob comes
       from the server and must not be read past its end */
    if (cmBufferReaderGetRemaining(reader) < sizeof(NQ_UINT16))
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Truncated KRB5 token");
        goto Exit;
    }
    cmBufferReadUint16(reader, &token);
    if (KRB5_TOK_ID_AP_REP != token)
    {
        LOGERR(CM_TRC_LEVEL_ERROR, "Unexpected KRB5 token");
        goto Exit;
    }

    /* Update context with the new data accepted from server: everything after
       the token id. `len` is the length of the whole APPLICATION element (OID +
       token id + data), most of which was already consumed above, so it
       overstates the remaining data; request exactly the bytes that are left,
       the same count that is stored as the blob length. */
    remaining = (NQ_COUNT)cmBufferReaderGetRemaining(reader);
    pContext->dataBlob.data = cmBufferReaderGetPositionBytePtr(reader, remaining);
    pContext->dataBlob.len = remaining;
    result = TRUE;

Exit:
    LOGFE(CM_TRC_LEVEL_FUNC_COMMON, "result:%s", result ? "TRUE" : "FALSE");
    return result;
}

#endif /* defined(UD_NQ_INCLUDECIFSCLIENT) && defined(UD_CC_INCLUDEEXTENDEDSECURITY) && defined(UD_CC_INCLUDEEXTENDEDSECURITY_KERBEROS) */
