/*

 RayEngine.c

 Copyright (c) 2006-2007, Lucas Stephen Beeler. All Rights Reserved.

 */

#include "RayEngine.h"
#include "ConfigurableParameters.h"
#include <stdio.h>
#include <limits.h>

/* #define DIAGNOSTICS_ON */

/* Prototypes for functions defined later in this file (BVHIntersectFront( )
   calls BVHNodeIntersect( ) before its definition, and BVHNodeIntersect( )
   recurses into itself). */
boolean  TriMeshIntersectFront(const Mesh*, const Ray3D*, IntersectionRecord*);
boolean  TriMeshIntersectAny(const Mesh*, const Ray3D*);
boolean  BVHNodeIntersect(const BVHNode*, const Ray3D*, IntersectionRecord*);




/*
 TraceScene( ): renders 'scene' into 'image'.

 For every pixel (i, j), pSamplesPerPixel primary rays are shot from the eye
 through jittered points inside the pixel. Each ray is traced with
 pRayRecursionDepth levels of recursion, and the samples' average color is
 written with WriteImagePixel( ).

 Sample placement: the pixel is divided into a dim x dim grid of cells, and
 sample k is jittered inside cell (k % dim, k / dim). This assumes
 SampleUniform( ) returns values in [0, 1) -- TODO confirm. The sample is
 then mapped onto the image plane at z = -eye.imagePlaneDistance in eye
 space and carried into world space by the eye's basis vectors and origin.
*/
void  TraceScene(const Scene* scene, Image* image)
{
    double  aspectRatio =
        ((double)(image->pixelWidth)) / ((double)(image->pixelHeight));
    double  fovyRadians = (scene->eye.fovy) * (kPi / 180.0);
    /* NOTE(review): a conventional pinhole camera uses
       2.0 * tan(fovyRadians / 2.0) * distance for the plane height.
       tan(fovyRadians) is kept so existing scenes render unchanged --
       confirm what eye.fovy is meant to measure. */
    double  imagePlaneHeight =
        tan(fovyRadians) * scene->eye.imagePlaneDistance;
    double  imagePlaneWidth = aspectRatio * imagePlaneHeight;
    double  pixPointZ = -scene->eye.imagePlaneDistance;

    /* The sub-pixel grid depends only on pSamplesPerPixel, so it is computed
       once instead of once per sample. ceil( ) guarantees
       dim * dim >= pSamplesPerPixel, so every sample index k maps to a cell
       inside the pixel. With floor( ), a non-square count (e.g. 2 or 5) put
       its extra samples into the next pixel row. Perfect-square counts get
       exactly the same grid as before. */
    int     subPixelGridDimension = (pSamplesPerPixel > 0) ?
        (int) ceil(sqrt((double)(pSamplesPerPixel))) : 0;
    double  subPixelGridSquareSize = (subPixelGridDimension > 0) ?
        (1.0 / ((double)(subPixelGridDimension))) : 0.0;
    double  sampleWeight = (pSamplesPerPixel > 0) ?
        (1.0 / (double)(pSamplesPerPixel)) : 0.0;

    HomogeneousTransform3D  eyeToWorld;
    HomogeneousVector3D     pixPointEye;
    HomogeneousVector3D     pixPointWorld;
    HomogeneousVector3D     currRayDir;
    Ray3D                   currentRay;
    int                     i, j, k;
    RGBColor                pixColor = kRGB_Black;
    RGBColor                currentSampleColor;

    /* eye-to-world transform: matrix[] is filled row by row, so eyeSpaceX,
       eyeSpaceY, eyeSpaceZ and origin form its four columns */
    eyeToWorld.matrix[0]  = scene->eye.eyeSpaceX.x;
    eyeToWorld.matrix[1]  = scene->eye.eyeSpaceY.x;
    eyeToWorld.matrix[2]  = scene->eye.eyeSpaceZ.x;
    eyeToWorld.matrix[3]  = scene->eye.origin.x;
    eyeToWorld.matrix[4]  = scene->eye.eyeSpaceX.y;
    eyeToWorld.matrix[5]  = scene->eye.eyeSpaceY.y;
    eyeToWorld.matrix[6]  = scene->eye.eyeSpaceZ.y;
    eyeToWorld.matrix[7]  = scene->eye.origin.y;
    eyeToWorld.matrix[8]  = scene->eye.eyeSpaceX.z;
    eyeToWorld.matrix[9]  = scene->eye.eyeSpaceY.z;
    eyeToWorld.matrix[10] = scene->eye.eyeSpaceZ.z;
    eyeToWorld.matrix[11] = scene->eye.origin.z;
    eyeToWorld.matrix[12] = 0.0;
    eyeToWorld.matrix[13] = 0.0;
    eyeToWorld.matrix[14] = 0.0;
    eyeToWorld.matrix[15] = 1.0;

    for (i = 0; i < image->pixelWidth; i++) {

        for (j = 0; j < image->pixelHeight; j++) {

            pixColor = kRGB_Black;

            for (k = 0; k < pSamplesPerPixel; k++) {

                int     currentSampleGridX = k % subPixelGridDimension;
                int     currentSampleGridY = k / subPixelGridDimension;
                double  samplePosX =
                    i + (currentSampleGridX * subPixelGridSquareSize);
                double  samplePosY =
                    j + (currentSampleGridY * subPixelGridSquareSize);

                /* jitter the sample inside its grid cell */
                samplePosX +=
                    SampleUniform( ) * subPixelGridSquareSize;
                samplePosY +=
                    SampleUniform( ) * subPixelGridSquareSize;

                /* pixel coordinates -> eye-space image-plane coordinates.
                   NOTE(review): the + 0.5 pixel-center offset is applied on
                   top of the in-cell jitter. If SampleUniform( ) is in
                   [0, 1), this shifts the image by half a pixel -- confirm. */
                samplePosX = -(imagePlaneWidth / 2.0) +
                    ((imagePlaneWidth * (samplePosX + 0.5)) /
                    ((double) image->pixelWidth));
                samplePosY = -(imagePlaneHeight / 2.0) +
                    ((imagePlaneHeight * (samplePosY + 0.5)) /
                    ((double) image->pixelHeight));

                SetVertex3D(&pixPointEye, samplePosX, samplePosY, pixPointZ);

                pixPointWorld = ApplyTransform(&eyeToWorld, &pixPointEye);

                currRayDir =
                    VectorSubtract(&pixPointWorld, &scene->eye.origin);

                SetRay(&currentRay, &scene->eye.origin, &currRayDir);

                currentSampleColor =
                    Trace(&currentRay, scene, pRayRecursionDepth);

                /* each sample contributes an equal share of the pixel */
                currentSampleColor = RGBColorScale(&currentSampleColor,
                    sampleWeight);

                pixColor = RGBColorAdd(&pixColor, &currentSampleColor);
            }

            WriteImagePixel(image, i, j, &pixColor);
        }
    }
}




/*
 Trace( ): follows 'ray' into 'scene' and returns the color it carries back.

 'level' is the remaining recursion budget. When it is zero, no intersection
 test is made and black is returned. Otherwise the nearest hit found by
 SceneIntersectFront( ) is shaded with Shade( ), and a ray that hits nothing
 also yields black.
*/
RGBColor  Trace(const Ray3D* ray, const Scene* scene, unsigned level)
{
    IntersectionRecord  nearestHit;

    /* && short-circuits, so an exhausted budget never touches the scene */
    if ((level != 0) && SceneIntersectFront(scene, ray, &nearestHit)) {

        return Shade(scene, &nearestHit, level);
    }

    return kRGB_Black;
}




/*
 SceneIntersectFront( ): finds the nearest intersection of 'ray' with any Gel
 in 'scene'.

 Only hits with tParameter strictly greater than 0.000001 are considered.
 Rays that start on a surface (e.g. the reflection rays spawned by Shade( ))
 therefore do not re-hit it at t ~ 0. On success, *hitRecord receives the
 record with the smallest such tParameter and TRUE is returned. Otherwise
 *hitRecord is left untouched and FALSE is returned.
*/
boolean   SceneIntersectFront(const Scene* scene, const Ray3D* ray,
    IntersectionRecord* hitRecord)
{
    const double        minimumHitT = 0.000001;
    IntersectionRecord  candidate;
    IntersectionRecord  nearest;
    double              nearestT = kBigDouble;
    boolean             foundHit = FALSE;
    int                 gelIndex;

    for (gelIndex = 0; gelIndex < scene->numGels; ++gelIndex) {

        if (!GelIntersectFront(&scene->gelData[gelIndex], ray, &candidate)) {

            continue;
        }

        if ((candidate.tParameter > minimumHitT) &&
            (candidate.tParameter < nearestT)) {

            nearestT = candidate.tParameter;
            nearest = candidate;
            foundHit = TRUE;
        }
    }

    if (!foundHit) {

        return FALSE;
    }

    (*hitRecord) = nearest;
    return TRUE;
}




/*
 SceneIntersectAny( ): reports whether 'ray' hits at least one Gel in
 'scene', stopping at the first Gel that GelIntersectAny( ) reports as hit.

 NOTE(review): no hit-distance window is applied here -- neither the
 minimum t that SceneIntersectFront( ) uses nor a maximum t. Used as a
 shadow test, this means geometry beyond the light, or the originating
 surface itself at t ~ 0, may count as an occluder -- confirm intended.
*/
boolean   SceneIntersectAny(const Scene* scene, const Ray3D* ray)
{
    int  gelIndex = 0;

    while (gelIndex < scene->numGels) {

        if (GelIntersectAny(&scene->gelData[gelIndex], ray)) {

            return TRUE;
        }

        gelIndex++;
    }

    return FALSE;
}




/*
 TriangleIntersect( ): Moller-Trumbore test of 'ray' against face 'faceID'
 of 'mesh'.

 On TRUE:
   *alpha, *beta, *gamma - barycentric weights of the hit for vertices 0, 1
                           and 2 of the face (each in [0, 1])
   *tParameter           - hit parameter in units of ray->direction, i.e.
                           hit = origin + t * direction, with t > 0

 Returns FALSE in any of these cases:
   - the determinant is nearly zero (ray almost parallel to the triangle's
     plane, or a degenerate triangle);
   - the hit lies outside the triangle;
   - t <= 0.
 On FALSE, some outputs may already hold intermediate values.
*/
boolean  TriangleIntersect(const Mesh* mesh, const Ray3D* ray,
    const unsigned faceID, double* alpha, double* beta, double* gamma,
    double*  tParameter)
{
    HomogeneousVector3D         unitDir = VectorNormalize(&ray->direction);
    const HomogeneousVector3D*  corner0 =
        &mesh->vertices[mesh->faces[faceID].vindexes[0]];
    const HomogeneousVector3D*  corner1 =
        &mesh->vertices[mesh->faces[faceID].vindexes[1]];
    const HomogeneousVector3D*  corner2 =
        &mesh->vertices[mesh->faces[faceID].vindexes[2]];
    HomogeneousVector3D         edgeA = VectorSubtract(corner1, corner0);
    HomogeneousVector3D         edgeB = VectorSubtract(corner2, corner0);
    HomogeneousVector3D         dirCrossB = VectorCross(&unitDir, &edgeB);
    double                      determinant = VectorDot(&edgeA, &dirCrossB);
    HomogeneousVector3D         originOffset;
    HomogeneousVector3D         offsetCrossA;
    double                      invDeterminant;

    if (fabs(determinant) < 0.00000001) {

        return FALSE;
    }

    invDeterminant = 1.0 / determinant;

    /* beta: weight of vertex 1 */
    originOffset = VectorSubtract(&ray->origin, corner0);
    *beta = invDeterminant * VectorDot(&originOffset, &dirCrossB);
    if ((*beta < 0.0) || (*beta > 1.0)) {

        return FALSE;
    }

    /* gamma: weight of vertex 2 */
    offsetCrossA = VectorCross(&originOffset, &edgeA);
    *gamma = invDeterminant * VectorDot(&unitDir, &offsetCrossA);
    if ((*gamma < 0.0) || (*gamma > 1.0)) {

        return FALSE;
    }

    /* distance along unitDir, rescaled to the parameterization of the
       (possibly non-unit) ray->direction */
    *tParameter = invDeterminant * VectorDot(&edgeB, &offsetCrossA);
    *tParameter *= (1.0 / VectorLength(&ray->direction));
    if (*tParameter <= 0.0) {

        return FALSE;
    }

    /* alpha: weight of vertex 0 */
    *alpha = 1.0 - *beta - *gamma;

    return ((*alpha < 0.0) || (*alpha > 1.0)) ? FALSE : TRUE;
}




/*
 TriMeshIntersectFront( ): brute-force nearest-hit test of 'ray' against
 every face of the triangle mesh 'mesh'. No acceleration structure is used.

 On a hit, this function fills hitRecord->normal, ->position, ->tParameter
 and ->inbound, all in the mesh's own frame, and returns TRUE. The surface
 material is left for the caller to fill in. Returns FALSE if no face is
 hit, or if 'mesh' is not a triangle mesh (after reporting a LogicError).

 The normal follows mesh->normalMode, matching the BVH code path so a Gel
 shades the same with or without a BVH:
   - kFaceNormalMode: the face's geometric normal (edge1 x edge2);
   - otherwise: the barycentric blend of the face's per-vertex normals.
 Neither normal is normalized here.
*/
boolean  TriMeshIntersectFront(const Mesh* mesh, const Ray3D* ray,
    IntersectionRecord* hitRecord)
{
    double    hitTParameter = kBigDouble;
    double    hitAlpha = 0.0;
    double    hitBeta = 0.0;
    double    hitGamma = 0.0;
    unsigned  hitFaceIndex = 0;
    boolean   foundHit = FALSE;
    unsigned  i;

    const HomogeneousVector3D*  v0;
    const HomogeneousVector3D*  v1;
    const HomogeneousVector3D*  v2;
    HomogeneousVector3D         tempHitPoint;
    HomogeneousVector3D         tempHitNormal;
    HomogeneousVector3D         temp;

    if (!mesh->isTriangleMesh) {

        LogicError("TriMeshIntersectFront( )", "input mesh is not triangular");

        /* don't go on to interpret a non-triangular face list as triangles
           should LogicError( ) return control to us */
        return FALSE;
    }

    /* find the face with the smallest hit parameter */
    for (i = 0; i < mesh->numFaces; i++) {

        double  alphaTemp;
        double  betaTemp;
        double  gammaTemp;
        double  tParameterTemp;

        if (TriangleIntersect(mesh, ray, i, &alphaTemp, &betaTemp,
                &gammaTemp, &tParameterTemp) &&
            (tParameterTemp < hitTParameter)) {

            hitTParameter = tParameterTemp;
            hitAlpha = alphaTemp;
            hitBeta = betaTemp;
            hitGamma = gammaTemp;
            hitFaceIndex = i;
            foundHit = TRUE;
        }
    }

    if (!foundHit) {

        return FALSE;
    }

    v0 = &mesh->vertices[mesh->faces[hitFaceIndex].vindexes[0]];
    v1 = &mesh->vertices[mesh->faces[hitFaceIndex].vindexes[1]];
    v2 = &mesh->vertices[mesh->faces[hitFaceIndex].vindexes[2]];

    /* surface normal (not normalized) */
    if (mesh->normalMode == kFaceNormalMode) {

        HomogeneousVector3D  triEdge1 = VectorSubtract(v1, v0);
        HomogeneousVector3D  triEdge2 = VectorSubtract(v2, v0);

        tempHitNormal = VectorCross(&triEdge1, &triEdge2);
    }
    else {

        tempHitNormal = ScalarMultiply(
            &mesh->normals[mesh->faces[hitFaceIndex].vindexes[0]], hitAlpha);

        temp = ScalarMultiply(
            &mesh->normals[mesh->faces[hitFaceIndex].vindexes[1]], hitBeta);
        tempHitNormal = VectorAdd(&tempHitNormal, &temp);

        temp = ScalarMultiply(
            &mesh->normals[mesh->faces[hitFaceIndex].vindexes[2]], hitGamma);
        tempHitNormal = VectorAdd(&tempHitNormal, &temp);
    }

    /* hit position = alpha * v0 + beta * v1 + gamma * v2. The w of the beta
       and gamma terms is cleared before adding, presumably so the point's w
       comes only from the alpha term -- TODO confirm against VectorAdd( ). */
    tempHitPoint = ScalarMultiply(v0, hitAlpha);

    temp = ScalarMultiply(v1, hitBeta);
    temp.w = 0.0;
    tempHitPoint = VectorAdd(&tempHitPoint, &temp);

    temp = ScalarMultiply(v2, hitGamma);
    temp.w = 0.0;
    tempHitPoint = VectorAdd(&tempHitPoint, &temp);


#ifdef DIAGNOSTICS_ON
    printf("\n\n");
    printf("------------------------------------------------------------\n");
    printf(" HIT DIAGNOSTICS\n");
    printf("------------------------------------------------------------\n");
    printf("hit position =       "); PrintVector(&tempHitPoint); printf("\n");
    printf("hit t parameter =    %f\n", hitTParameter);
    printf("hit b/cent. alpha =  %f\n", hitAlpha);
    printf("hit b/cent. beta  =  %f\n", hitBeta);
    printf("hit b/cent. gamma =  %f\n", hitGamma);
    printf("hit face index    =  %u\n", hitFaceIndex);
    printf("------------------------------------------------------------\n");
#endif

    hitRecord->normal = tempHitNormal;
    hitRecord->position = tempHitPoint;
    hitRecord->tParameter = hitTParameter;
    hitRecord->inbound = ray->direction;

    return TRUE;
}




/*
 TriMeshIntersectAny( ): returns TRUE as soon as 'ray' hits any face of
 'mesh' (in the mesh's own frame), and FALSE if it hits none.

 Only TriangleIntersect( )'s own t > 0 requirement applies; no minimum or
 maximum hit distance is enforced, and the barycentric/t outputs are
 discarded.
*/
boolean  TriMeshIntersectAny(const Mesh* mesh, const Ray3D* ray)
{
    double   unusedAlpha;
    double   unusedBeta;
    double   unusedGamma;
    double   unusedT;
    int      faceIndex;

    for (faceIndex = 0; faceIndex < mesh->numFaces; ++faceIndex) {

        if (TriangleIntersect(mesh, ray, faceIndex, &unusedAlpha,
                &unusedBeta, &unusedGamma, &unusedT)) {

            return TRUE;
        }
    }

    return FALSE;
}




boolean   GelIntersectFront(const Gel* targetGel, const Ray3D* incomingRay,
    IntersectionRecord* hitRecord)
{
    boolean wasMeshHit = FALSE;

    if (!BoundingBoxIntersect(&targetGel->boundingBox, incomingRay)) {

        return FALSE;
    }

    if (targetGel->gelMesh->isTriangleMesh) {

        Ray3D  meshSpaceRay =
            TransformRay(&targetGel->invGelTrans, incomingRay);

        /* if this Gel has a BVH, then run BVHIntersect( ) to accelerate
           hit testing */
        if (targetGel->gelBVH) {

            wasMeshHit = BVHIntersectFront(targetGel->gelBVH, &meshSpaceRay,
                hitRecord);
        }
        else {

            wasMeshHit = TriMeshIntersectFront(targetGel->gelMesh, &meshSpaceRay,
                hitRecord);
        }

        if (wasMeshHit) {

            hitRecord->normal = ApplyTransform(&targetGel->gelTrans,
                &hitRecord->normal);

            hitRecord->position = ApplyTransform(&targetGel->gelTrans,
                &hitRecord->position);

            hitRecord->surfaceMaterial = targetGel->gelProps;

            hitRecord->inbound = ApplyTransform(&targetGel->gelTrans,
                &hitRecord->inbound);

            return TRUE;
        }
    }

    return FALSE;
}




/*
 GelIntersectAny( ): reports whether 'ray' (world space) hits 'targetGel' at
 all.

 The bounding box is tested first. The ray is then carried into mesh space
 via invGelTrans and tested against the BVH (using the nearest-hit search
 and discarding the record) or against every face. A non-triangular mesh
 logs a RuntimeWarning and counts as a miss.
*/
boolean   GelIntersectAny(const Gel* targetGel, const Ray3D* ray)
{
    Ray3D               meshSpaceRay;
    IntersectionRecord  discardedHit;

    if (!BoundingBoxIntersect(&targetGel->boundingBox, ray)) {

        return FALSE;
    }

    if (!targetGel->gelMesh->isTriangleMesh) {

        RuntimeWarning("GelIntersectAny( )", "encountered a non-triangluar mesh, "
            "but non-triangular mesh intersect handling isn't yet implemented..."
            "ignoring this Gel");

        return FALSE;
    }

    meshSpaceRay = TransformRay(&targetGel->invGelTrans, ray);

    if (!targetGel->gelBVH) {

        return TriMeshIntersectAny(targetGel->gelMesh, &meshSpaceRay);
    }

    return BVHIntersectFront(targetGel->gelBVH, &meshSpaceRay, &discardedHit);
}




/*
 Shade( ): computes the color leaving the surface point described by
 'hitRecord' back along the incoming ray.

 The result is the sum of:
   - ambient: scene->ambientIntensity * material.ambient;
   - for every light that SceneIntersectAny( ) reports as unblocked
     (Blinn-Phong, with L toward a point sampled on the light, V toward the
     viewer and H = normalize(L + V)):
       diffuse  = light.color * material.diffuse  * max(N.L, 0)
       specular = light.color * material.specular * max(N.H, 0)^shinyness
   - one mirror-reflection ray, perturbed by material.reflectionBlur and
     traced with level - 1 bounces left. Its color is tinted by
     material.specular and scaled by
     level / ((1 / pBounceScaleFactor) * pRayRecursionDepth).

 'level' is expected to be >= 1 (Trace( ) only calls Shade( ) with a
 non-zero level); it is unsigned, so level - 1 would wrap at 0.
*/
RGBColor  Shade(const Scene* scene, const IntersectionRecord* hitRecord, unsigned level)
{
    RGBColor             result;
    RGBColor             tempColor;
    unsigned             i;
    /* unit direction of the ray that produced this hit */
    HomogeneousVector3D  incomingDirection = VectorNormalize(&hitRecord->inbound);
    /* unit vector from the surface back toward the viewer */
    HomogeneousVector3D  eyeVector = ScalarMultiply(&incomingDirection, -1.0);
    HomogeneousVector3D  surfaceNormal = VectorNormalize(&hitRecord->normal);
    HomogeneousVector3D  reflectionOutDirection;
    double               eyeDotNorm;
    Ray3D                reflectionOutRay;
    double               bounceScale;

    /* accumulate the ambient contribution into the result */
    result = RGBColorMultiply(&scene->ambientIntensity,
        &hitRecord->surfaceMaterial.ambient);

    /* loop over all the lights in the scene and accumulate their diffuse and
       specular contributions into the result */
    for (i = 0; i < scene->numLights; i++) {

        HomogeneousVector3D   remoteLightPosition;
        HomogeneousVector3D   lightingVector;
        HomogeneousVector3D   halfwayVector;
        HomogeneousVector3D   lightingVectorOut;
        double                normDotHalfway;
        double                normDotLight;
        double                specularMultiplier;
        Ray3D                 occlusionTestRay;
        boolean               occlusionHit = FALSE;

        remoteLightPosition = SampleLightPosition(&scene->lightData[i],
            &hitRecord->position);

        lightingVector = VectorSubtract(&remoteLightPosition,
            &hitRecord->position);

        /* keep the un-normalized surface-to-light vector for the shadow ray */
        lightingVectorOut = lightingVector;
        lightingVector = VectorNormalize(&lightingVector);
        normDotLight = VectorDot(&lightingVector, &surfaceNormal);

        /* clamp normDotLight into the range [0, +INF) */
        normDotLight = (normDotLight < 0.0) ? 0.0 : normDotLight;

        /* Occlusion test: skip this light's contribution when anything is
           in the way. NOTE(review): the shadow ray starts exactly on the
           surface, and SceneIntersectAny( ) applies no minimum/maximum t,
           so self-shadowing and occluders beyond the light are possible --
           confirm. */
        SetRay(&occlusionTestRay, &hitRecord->position, &lightingVectorOut);

        occlusionHit = SceneIntersectAny(scene, &occlusionTestRay);

        if (!occlusionHit) {

            /* compute diffuse contribution from the current light */
            tempColor = RGBColorMultiply(&scene->lightData[i].color,
                &hitRecord->surfaceMaterial.diffuse);

            tempColor = RGBColorScale(&tempColor, normDotLight);

            result = RGBColorAdd(&result, &tempColor);

            /* compute specular contribution from the current light */
            halfwayVector = VectorAdd(&lightingVector, &eyeVector);
            halfwayVector = VectorNormalize(&halfwayVector);

            normDotHalfway = VectorDot(&halfwayVector, &surfaceNormal);

            /* Clamp normDotHalfway into [0, +INF). pow( ) of a negative base
               with a non-integer exponent is NaN, which would poison the
               pixel. With an even exponent, it would instead light surfaces
               facing away from H. */
            normDotHalfway = (normDotHalfway < 0.0) ? 0.0 : normDotHalfway;

            specularMultiplier = pow(normDotHalfway,
                hitRecord->surfaceMaterial.shinyness);

            tempColor = RGBColorMultiply(&scene->lightData[i].color,
                &hitRecord->surfaceMaterial.specular);

            tempColor = RGBColorScale(&tempColor, specularMultiplier);

            result = RGBColorAdd(&result, &tempColor);
        }

    } /* for all lights */

    /* compute reflection: R = D - 2 (D.N) N, with D the incoming direction
       (written here as 2 (N.V) N + D, since V = -D) */
    eyeDotNorm = VectorDot(&surfaceNormal, &eyeVector);

    reflectionOutDirection =
        ScalarMultiply(&surfaceNormal, (2.0 * eyeDotNorm));

    reflectionOutDirection =
        VectorAdd(&reflectionOutDirection, &incomingDirection);

    reflectionOutDirection = VectorNormalize(&reflectionOutDirection);

    reflectionOutDirection =
        VectorPerturb(&reflectionOutDirection,
            hitRecord->surfaceMaterial.reflectionBlur);

    SetRay(&reflectionOutRay, &hitRecord->position, &reflectionOutDirection);

    tempColor = Trace(&reflectionOutRay, scene, level - 1);

    tempColor =
        RGBColorMultiply(&tempColor, &hitRecord->surfaceMaterial.specular);

    /* deeper bounces (smaller remaining 'level') contribute less */
    bounceScale = ((double)(level)) /
        ((1.0 / pBounceScaleFactor) * (double)(pRayRecursionDepth));

    tempColor = RGBColorScale(&tempColor, bounceScale);

    result = RGBColorAdd(&result, &tempColor);

    return result;
}




/*
 BVHIntersectFront( ): nearest-hit test of 'ray' (in the host mesh's frame)
 against the bounding volume hierarchy 'bvh'.

 Returns TRUE and fills *hitRecord (everything except surfaceMaterial) on a
 hit. Returns FALSE on a miss, and also when 'bvh' is NULL or has no root
 node.
*/
boolean  BVHIntersectFront(const BoundingVolumeHierarchy* bvh,
    const Ray3D* ray, IntersectionRecord* hitRecord)
{
    /* an absent or empty hierarchy can't be hit; without this check
       BVHNodeIntersect( ) would dereference a NULL node */
    if ((!bvh) || (!bvh->rootNode)) {

        return FALSE;
    }

    return BVHNodeIntersect(bvh->rootNode, ray, hitRecord);
}




/*
 BVHNodeIntersect( ): nearest-hit test of 'ray' (in the host mesh's frame)
 against the BVH subtree rooted at 'node'.

 A subtree whose bounding box the ray misses is rejected outright. Then:
   - leaf node (faceID != kNullFace): tests its single triangle;
   - interior node: recurses into whichever children exist (left, right or
     both) and keeps the hit with the smaller tParameter; ties go to the
     left child.

 On a hit, hitRecord->normal, ->position, ->tParameter and ->inbound are
 filled in (mesh frame; normal not normalized) and TRUE is returned.
 hitRecord->surfaceMaterial is set later by the calling GelIntersect...( ).
 On a miss, FALSE is returned and *hitRecord is not written.
*/
boolean  BVHNodeIntersect(const BVHNode* node, const Ray3D* ray,
    IntersectionRecord* hitRecord)
{
    if ((!node) || (!BoundingBoxIntersect(&node->bounds, ray))) {

        return FALSE;
    }

    /* leaf node: test its one triangle */
    if (node->faceID != kNullFace) {

        Mesh*                mesh = node->hostMesh;
        double               barycentricAlpha;
        double               barycentricBeta;
        double               barycentricGamma;
        double               tParameter;
        HomogeneousVector3D  tempHitPoint;
        HomogeneousVector3D  temp;

        if (!mesh) {

            LogicError("BVHNodeIntersect( )", "node mesh pointer invalid");

            /* never dereference a NULL mesh, even if LogicError( ) returns */
            return FALSE;
        }

        if (!TriangleIntersect(mesh, ray, node->faceID, &barycentricAlpha,
                &barycentricBeta, &barycentricGamma, &tParameter)) {

            return FALSE;
        }

        /* compute hitRecord.normal (not normalized) */
        if (mesh->normalMode == kFaceNormalMode) {

            HomogeneousVector3D  triEdge1 =
                VectorSubtract(
                    &mesh->vertices[mesh->faces[node->faceID].vindexes[1]],
                    &mesh->vertices[mesh->faces[node->faceID].vindexes[0]]);
            HomogeneousVector3D  triEdge2 =
                VectorSubtract(
                    &mesh->vertices[mesh->faces[node->faceID].vindexes[2]],
                    &mesh->vertices[mesh->faces[node->faceID].vindexes[0]]);

            hitRecord->normal = VectorCross(&triEdge1, &triEdge2);
        }
        else {

            /* barycentric blend of the face's per-vertex normals */
            hitRecord->normal =
                ScalarMultiply(
                    &mesh->normals[mesh->faces[node->faceID].vindexes[0]],
                    barycentricAlpha);

            temp =
                ScalarMultiply(
                    &mesh->normals[mesh->faces[node->faceID].vindexes[1]],
                    barycentricBeta);

            hitRecord->normal = VectorAdd(&hitRecord->normal, &temp);

            temp =
                ScalarMultiply(
                    &mesh->normals[mesh->faces[node->faceID].vindexes[2]],
                    barycentricGamma);

            hitRecord->normal = VectorAdd(&hitRecord->normal, &temp);
        }

        /* compute hitRecord.position = alpha*v0 + beta*v1 + gamma*v2; the w
           of the beta and gamma terms is cleared before adding */
        tempHitPoint =
            ScalarMultiply(
                &mesh->vertices[mesh->faces[node->faceID].vindexes[0]],
                barycentricAlpha);

        temp =
            ScalarMultiply(
                &mesh->vertices[mesh->faces[node->faceID].vindexes[1]],
                barycentricBeta);
        temp.w = 0.0;
        tempHitPoint = VectorAdd(&tempHitPoint, &temp);

        temp =
            ScalarMultiply(
                &mesh->vertices[mesh->faces[node->faceID].vindexes[2]],
                barycentricGamma);
        temp.w = 0.0;
        tempHitPoint = VectorAdd(&tempHitPoint, &temp);

        hitRecord->position = tempHitPoint;
        hitRecord->tParameter = tParameter;
        hitRecord->inbound = ray->direction;

        return TRUE;

    } /* if a valid leaf node */

    /* interior node with two children: keep the closer of the two hits */
    if ((node->leftChild) && (node->rightChild)) {

        IntersectionRecord  leftHitRecord;
        IntersectionRecord  rightHitRecord;
        boolean             wasLeftHit =
            BVHNodeIntersect(node->leftChild, ray, &leftHitRecord);
        boolean             wasRightHit =
            BVHNodeIntersect(node->rightChild, ray, &rightHitRecord);

        if (wasLeftHit && wasRightHit) {

            if (rightHitRecord.tParameter < leftHitRecord.tParameter) {

                (*hitRecord) = rightHitRecord;
            }
            else {

                (*hitRecord) = leftHitRecord;
            }

            return TRUE;
        }

        if (wasLeftHit) {

            (*hitRecord) = leftHitRecord;
            return TRUE;
        }

        if (wasRightHit) {

            (*hitRecord) = rightHitRecord;
            return TRUE;
        }

        return FALSE;
    }

    /* interior node with a single child: descend into whichever one exists.
       A right-only node used to fall through to the warning below and
       report a miss. */
    if (node->leftChild) {

        return BVHNodeIntersect(node->leftChild, ray, hitRecord);
    }

    if (node->rightChild) {

        return BVHNodeIntersect(node->rightChild, ray, hitRecord);
    }

    RuntimeWarning("BVHNodeIntersect( )", "all cases should've been handled");
    return FALSE;
}
