// Water vertex program.
//
// Transforms the vertex by wvpMat and emits the values the fragment program
// interpolates:
//   oPos          - iPos transformed by wvpMat
//   oScreenCoords - (x, y, w) of the transformed position after a scale/bias
//                   matrix; main_fp divides xy by the third component
//   oUv           - input UV multiplied by a fixed tiling factor of 10
//   oDepth        - z component of the transformed position
//   oEyeVector    - camPosObjSpace - iPos (not normalized here)
void main_vp
(
    in float4 iPos : POSITION
    , in float2 iUv : TEXCOORD0

    , out float4 oPos : POSITION
    , out float3 oScreenCoords : TEXCOORD0
    , out float2 oUv : TEXCOORD1
    , out float oDepth : TEXCOORD2
    , out float4 oEyeVector : TEXCOORD3

    , uniform float4x4 wvpMat
    , uniform float4 camPosObjSpace
)
{
    // Scale/bias matrix: x' = 0.5*x + 0.5*w, y' = -0.5*y + 0.5*w.
    // After the divide by w this maps [-1, 1] to [0, 1], with y flipped.
    float4x4 clipToTexture = float4x4(  0.5,    0,      0,      0.5,
                                        0,      -0.5,   0,      0.5,
                                        0,      0,      0.5,    0.5,
                                        0,      0,      0,      1   );

    // Transform once and reuse the result for every derived output.
    float4 clipPos = mul(wvpMat, iPos);
    oPos = clipPos;
    oDepth = clipPos.z;

    // Keep w alongside x/y so the divide can happen per fragment.
    float4 projected = mul(clipToTexture, clipPos);
    oScreenCoords = float3(projected.x, projected.y, projected.w);

    oUv = iUv * 10; // uv scale

    // Vector from the vertex towards the camera.
    oEyeVector = camPosObjSpace - iPos;
}

// Water fragment program.
//
// Blends a reflection map and a refraction map, both sampled at screen-space
// coordinates perturbed by an animated normal map. The blend weight combines a
// fresnel-style term, a specular highlight and the water depth derived from a
// screen-space depth map.
//
// Inputs (see main_vp for how the varyings are produced):
//   iScreenCoords        - projective screen coords; xy is divided by z here
//   iUv                  - tiled UV used for the normal map lookups
//   iDepth               - depth of this water fragment
//   iEyeVector           - unnormalized vector towards the camera
//   renderTargetFlipping - 1 leaves screen y unchanged, -1 yields 1 - y
//   lightPosObjSpace0    - treated as the direction of a directional light
//   lightSpecularColour0 - colour added to the reflection by the highlight
//   time                 - scrolls the four normal map layers
//   far                  - scales the depth map sample before comparing with
//                          iDepth (presumably the far clip distance - confirm)
//   isUnderwater         - 0 above water; 1 disables the reflection blend and
//                          enables the underwater tint/fog
//   fogParams, fogColour - only referenced by the commented-out fog code
void main_fp
(
    out float4 oColor : COLOR

    , in float3 iScreenCoords : TEXCOORD0
    , in float2 iUv : TEXCOORD1
    , in float iDepth : TEXCOORD2
    , in float4 iEyeVector : TEXCOORD3
    , uniform float renderTargetFlipping
    , uniform float4 lightPosObjSpace0
    , uniform float4 lightSpecularColour0

    , uniform sampler2D reflectionMap : register(s0)
    , uniform sampler2D refractionMap : register(s1)
    , uniform sampler2D depthMap : register(s2)
    , uniform sampler2D normalMap : register(s3)
    , uniform float time
    , uniform float far
    , uniform float4 fogParams
    , uniform float4 fogColour
    , uniform float isUnderwater
)
{

    // Perspective divide (the third component carries w from the vertex program).
    float2 screenCoords = iScreenCoords.xy / iScreenCoords.z;
    // renderTargetFlipping == 1 -> y ; renderTargetFlipping == -1 -> 1 - y
    screenCoords.y = (1-saturate(renderTargetFlipping))+renderTargetFlipping*screenCoords.y;

    // No need for transparency since we are using a refraction map
    oColor.a = 1;

    // Sample screen-space depth map and subtract pixel depth to get the real water depth
    float depthTex = tex2D(depthMap, screenCoords).r;
    float depth1 = depthTex * far - iDepth;
    depth1 = saturate(depth1 / 500.f);

    // Simple wave effect. to be replaced by something better
    // Four copies of the normal map scrolled in +x, +y, -x and -y, averaged.
    float2 uv1 = iUv + time * float2(0.5, 0);
    float2 uv2 = iUv + time * float2(0, 0.5);
    float2 uv3 = iUv + time * float2(-0.5, 0);
    float2 uv4 = iUv + time * float2(0, -0.5);
    float4 normal = tex2D(normalMap, uv1) + tex2D(normalMap, uv2) + tex2D(normalMap, uv3) + tex2D(normalMap, uv4);
    normal = normal / 4.f;
    // Expand from [0, 1] texture range to [-1, 1]. The average is NOT renormalized.
    normal = 2*normal - 1;

    // Distort the lookups by the normal; the refraction offset shrinks with water depth.
    float2 screenCoords_reflect = screenCoords + normal.yx * 0.05;
    float2 screenCoords_refract = screenCoords + normal.yx * 0.05 * depth1;

    // Sample depth again with the refracted coordinates
    depthTex = tex2D(depthMap, screenCoords_refract).r;
    float depth2 = (depthTex * far - iDepth) / 500.f;
    // A depth sample of exactly 0 is treated as fully deep
    // (presumably nothing was written to the depth map there - confirm).
    depth2 = (depthTex == 0 ? 1 : depth2);
    // A small or negative depth2 means the offset lookup would refract something
    // which is at or above the water surface, which we don't want - so in that
    // case, don't refract. The threshold is 0.25 rather than 0 to allow for
    // inaccuracies in the depth comparison.
    if (depth2 < 0.25)
    {
        screenCoords_refract = screenCoords;
        depth2 = depth1;
    }
    depth2 = saturate(depth2);

    float4 reflection = tex2D(reflectionMap, screenCoords_reflect);
    float4 refraction = tex2D(refractionMap, screenCoords_refract);

    // tangent to object space (swap y and z)
    float3 surfaceNormal = normal.xzy;

    // Work on a float3 local instead of overwriting the interpolated input.
    float3 eyeDir = normalize(iEyeVector.xyz);

    // fresnel
    // surfaceNormal is an unnormalized average, so |dot| is not guaranteed to stay
    // within [0, 1]; clamp so pow() never receives a negative base.
    float facing = saturate(1.0 - abs(dot(eyeDir, surfaceNormal)));
    float reflectionFactor = saturate(0.35 + 0.65 * pow(facing, 2));

    // specular
    float3 lightDir = normalize(lightPosObjSpace0.xyz); // assumes that light 0 is a directional light
    // Both operands are float3; no implicit float4 -> float3 truncation.
    float3 halfVector = normalize(eyeDir + lightDir);
    float specular = pow(max(dot(surfaceNormal, halfVector), 0), 64);

    // Shallow water shows more refraction; underwater the reflection is disabled.
    float opacity = depth2 * saturate(reflectionFactor + specular);
    opacity *= (1-isUnderwater);

    reflection.xyz += lightSpecularColour0.xyz * specular;

    oColor.xyz = lerp(refraction.xyz, reflection.xyz, opacity);

    oColor.xyz += isUnderwater * float3(0, 0.35, 0.35); // underwater tint color
    oColor.xyz = lerp(oColor.xyz, float3(0, 0.65, 0.65), saturate(isUnderwater * (iDepth / 2000.f))); // underwater fog

    // add fog
    //float fogValue = saturate((iDepth - fogParams.y) * fogParams.w);
    //oColor.xyz = lerp(oColor.xyz, fogColour, fogValue);
}