import {ExtendedShaderMaterial, Shader, Vector2, Vector4, WebGLRenderer} from 'threepipe'
import {glsl} from 'ts-browser-helpers'

/**
 * 2D function plotter material.
 *
 * The fragment shader evaluates either `fx(x) - y` (when `funcDefineX` is set) or
 * `fxy(x, y)` (otherwise) at every fragment. It draws the zero set of that value over
 * a unit grid, with the x-axis in red and the y-axis in blue. Plot coordinates come
 * from `gl_FragCoord` relative to the `vSize` uniform, scaled by `zoom` and shifted
 * by `center`.
 */
export class FunctionPlotMaterial extends ExtendedShaderMaterial {
    constructor() {
        super({
            uniforms: {
                vSize: {value: new Vector2(100, 100)},
                color: {value: new Vector4(1, 1, 1, 1)},
                gridSize: {value: 1},
                center: {value: new Vector2(0, 0)},
                zoom: {value: 1},
            },
            defines: {
                // 0 - anti-aliased line plot, 1 - direct |value|, 2 - gradient magnitude (other values fall back to |value|)
                PLOT_MODE: 0,
            },
            vertexShader: glsl`
                varying vec2 vUv;
                void main() {
                    vUv = uv;
                    gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
                }
            `,
            fragmentShader: glsl`
                uniform vec4 color;
                uniform float gridSize;
                uniform vec2 vSize;
                uniform vec2 center;
                uniform float zoom;
                varying vec2 vUv;

                //                float lineJitter = 0.5;
                //                float lineWidth = 7.0;
                //                float gridWidth = 1.7;
                //                float zoom = 2.5;
                //                vec2 offset = vec2(0.5);

//                #define fx(x) sin(x)
//                #define fxy(x, y) (sin(x)-y)
                #define_func

                float max2(vec2 v)
                {
                    return max(v.x, v.y);
                }

                // https://www.shadertoy.com/view/mlXBRr
                void main() {
                    //                    vec2 p = (vUv.xy - vec2(0.5)) * vSize;
                    vec2 p = (gl_FragCoord.xy - vSize.xy * 0.5);
                    float scale = 2. / max2(abs(-vSize.xy * 0.5));
                    // zoom
                    p *= zoom * scale;
                    // center
                    p += center;

                    #ifdef fx
                    float value = fx(p.x)-p.y;
                    #else
                    float value = fxy(p.x, p.y);
                    #endif

                    const float lineWidth = 2.5;
                    const float lineWidth2 = 2.;

                    #if PLOT_MODE == 0
                    vec2 grad = vec2(dFdx(value), dFdy(value));
                    float plot_alpha = smoothstep(lineWidth*length(grad), 0.0, abs(value));
                    #elif PLOT_MODE == 1
                    float plot_alpha = abs(value);
                    #elif PLOT_MODE == 2
                    vec2 grad = vec2(dFdx(value), dFdy(value));
                    float plot_alpha = length(grad);
                    #else
                    float plot_alpha = abs(value);
                    #endif

                    vec2 pm = mod(p, vec2(1.0));
                    float grid_alpha = smoothstep(lineWidth2 * scale * zoom, 0.0, min(min(pm.x, 1.0 - pm.x), min(pm.y, 1.0 - pm.y)));
                    float x_axis_alpha = smoothstep(lineWidth2 * scale * zoom, 0.0, abs(p.y));
                    float y_axis_alpha = smoothstep(lineWidth2 * scale * zoom, 0.0, abs(p.x));

                    vec3 color = vec3(1.0);
                    color = mix(color, vec3(0.6), grid_alpha);
                    color = mix(color, vec3(1.0, 0.0, 0.0), x_axis_alpha);
                    color = mix(color, vec3(0.0, 0.0, 1.0), y_axis_alpha);
                    color = mix(color, vec3(0.0), plot_alpha);

                    gl_FragColor = vec4(color, 1.0);
//                    gl_FragColor = vec4(p, 0., 1.0);
                    //                    gl_FragColor = vec4(1.0);
                }
            `,
        }, [])
    }

    /**
     * Macro parameter list and body for a single-variable function, appended verbatim
     * to `#define fx`, e.g. `'(x) sin(x)'`. When non-empty it takes precedence over
     * `funcDefineXY` and the curve `y = fx(x)` is plotted.
     * NOTE(review): changing this after the first compile presumably also requires
     * `needsUpdate = true` on the material to trigger a recompile — confirm.
     */
    funcDefineX = ''
    /**
     * Macro parameter list and body for an implicit function, appended verbatim to
     * `#define fxy`, e.g. `'(x, y) (sin(x)-y)'`. Used only when `funcDefineX` is empty;
     * if both are empty the default `fxy(x, y) = sin(x) - y` is plotted.
     */
    funcDefineXY = ''

    /**
     * Substitutes the `#define_func` placeholder in the fragment shader with the
     * function macro selected from `funcDefineX` / `funcDefineXY` (or the default).
     */
    onBeforeCompile(s: Shader, renderer: WebGLRenderer) {
        super.onBeforeCompile(s, renderer);
        const define = this.buildFuncDefine()
        // A replacer function inserts the text literally; a plain replacement string
        // would interpret `$&`, `$1`, ... sequences inside the user expression.
        s.fragmentShader = s.fragmentShader.replace('#define_func', () => define)
    }

    /**
     * Program cache key. It embeds the exact define line injected by
     * `onBeforeCompile`, so two materials share a compiled program only when their
     * generated shaders are identical. Plain concatenation of the two raw fields
     * was ambiguous and could collide.
     */
    customProgramCacheKey(): string {
        return (super.customProgramCacheKey() || '') + '|' + this.buildFuncDefine();
    }

    /** Builds the preprocessor line that replaces the `#define_func` placeholder. */
    private buildFuncDefine(): string {
        if (this.funcDefineX) return '#define fx' + this.funcDefineX
        if (this.funcDefineXY) return '#define fxy' + this.funcDefineXY
        return '#define fxy(x, y) (sin(x)-y)'
    }
}

// Anti-aliased (AA) function plotting experiment, based on https://www.shadertoy.com/view/4sB3zz
/**
 * Experimental function plotter adapted from Shadertoy / The Book of Shaders code.
 *
 * The plotted function is hard-coded in the shader via the `function(x)` macro
 * (currently `cos(x)`). `main` uses the single-sample `plot2D1` and subtracts the
 * `grid2D` grid. `plot2D` (jittered supersampling) and `grid2` are defined in the
 * shader but not currently called.
 */
export class FunctionPlotMaterialA extends ExtendedShaderMaterial {
    constructor() {
        super({
            uniforms: {
                vSize: {value: new Vector2(100, 100)},
                color: {value: new Vector4(1, 1, 1, 1)},
                gridSize: {value: 1},
            },
            vertexShader: glsl`
                varying vec2 vUv;
                void main() {
                    vUv = uv;
                    gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
                }
            `,
            fragmentShader: glsl`
                uniform vec4 color;
                uniform float gridSize;
                uniform vec2 vSize;
                varying vec2 vUv;

                float lineJitter = 0.5;
                float lineWidth = 7.0;
                float gridWidth = 1.7;
                float scale = 0.0013;
                float zoom = 2.5;
                vec2 offset = vec2(0.5);

                // from https://thebookofshaders.com/07/
                
                #define function(x) cos(x) 

                float rand (in vec2 co) {
                    return fract(sin(dot(co.xy,vec2(12.9898,78.233)))*43758.5453);
                }
                vec3 plot2D(in vec2 _st, in float _width ) {
                    const float samples = 3.0;
//                    const float samples = 1.0;

                    vec2 steping = _width*vec2(scale)/samples;

                    float count = 0.0;
                    float mySamples = 0.0;
                    for (float i = 0.0; i < samples; i++) {
                        for (float j = 0.0;j < samples; j++) {
                            if (i*i+j*j>samples*samples)
                            continue;
                            mySamples++;
                            float ii = i + lineJitter*rand(vec2(_st.x+ i*steping.x,_st.y+ j*steping.y));
                            float jj = j + lineJitter*rand(vec2(_st.y + i*steping.x,_st.x+ j*steping.y));
                            float f = function(_st.x+ ii*steping.x)-(_st.y+ jj*steping.y);
                            count += (f>0.) ? 1.0 : -1.0;
                        }
                    }
                    vec3 color = vec3(1.0);
                    if (abs(count)!=mySamples)
                    color = vec3(abs(float(count))/float(mySamples));
                    return color;
                }
                vec3 plot2D1(in vec2 _st, in float _width ) {
                    float f = function(_st.x)-(_st.y);

                    float count = 0.0;
                    count += (f>0.) ? 1.0 : -1.0;
                    float mySamples = 1.0;
                    vec3 color = vec3(f>0.?1.0:0.0);
//                    if (abs(count)!=mySamples)
//                    color = vec3(abs(float(count))/float(mySamples));
                    return color;
                }

                vec3 grid2D( in vec2 _st, in float _width ) {
                    float axisDetail = _width*scale;
                    if (abs(_st.x)<axisDetail || abs(_st.y)<axisDetail)
                    return 1.0-vec3(0.65,0.65,1.0);
                    if (abs(mod(_st.x,1.0))<axisDetail || abs(mod(_st.y,1.0))<axisDetail)
                    return 1.0-vec3(0.80,0.80,1.0);
                    if (abs(mod(_st.x,0.25))<axisDetail || abs(mod(_st.y,0.25))<axisDetail)
                    return 1.0-vec3(0.95,0.95,1.0);
                    return vec3(0.0);
                }

                // grid2 and the GRAD/PLOT macros were ported from Shadertoy and use its
                // iResolution / ZOOM names. Map them onto this shader's uniform and
                // global: GLSL type-checks every function body even when it is never
                // called, so undeclared names would fail compilation.
                #define iResolution vSize
                #define ZOOM zoom

                #define GRAD_OFFS vec2(0.0001 * ZOOM, 0.0)
                #define GRAD(f, p) (vec2(f(p) - f(p + GRAD_OFFS.xy), f(p) - f(p + GRAD_OFFS.yx)) / GRAD_OFFS.xx)
                #define PLOT(f, c, d, p) d = mix(c, d, smoothstep(0.0, (LINE_SIZE / iResolution.y * ZOOM), abs(f(p) / length(GRAD(f,p)))))

                #define LINE_SIZE 2.0
                #define GRID_LINE_SIZE 1.0
                #define GRID_AXIS_SIZE 3.0
                #define GRID_LINES 1.0
                
                float grid2(vec2 p)
                { 
                    vec2 uv = mod(p, 1.0 / GRID_LINES);
                    float halfScale = 1.0 / GRID_LINES / 2.0;
                    float gridRad = (GRID_LINE_SIZE / iResolution.y) * ZOOM;
                    float grid = halfScale - max(abs(uv.x - halfScale), abs(uv.y - halfScale));
                    grid = smoothstep(0.0, gridRad, grid);

                    float axisRad = (GRID_AXIS_SIZE / iResolution.y) * ZOOM;
                    float axis = min(abs(p.x), abs(p.y));
                    axis = smoothstep(axisRad - 0.05, axisRad, axis);

                    return min(grid, axis);
                }

                void main() {
                    vec2 st = vUv.xy-offset;
                    st.x *= vSize.x/vSize.y;

                    scale *= zoom;
                    st *= zoom;

                    vec3 color = plot2D1(st,lineWidth);
                    color -= grid2D(st,gridWidth);
                    
                    gl_FragColor = vec4(color, 1.0);
                }
            `,
        }, [])
    }
}

