module VeluSqrt

import Base: *, +, inv
import Nemo
import Nemo: PolynomialRing, coeff, RingElem, ResidueRing, setcoeff!
import EllipticCurves: Montgomery, XZPoint, base_ring, xdouble, xadd, normalized, isinfinity

export VeluRenes, VeluResultant

"""
    VeluRenes(ker, ell, Ps) -> (E′, images)

Evaluate the isogeny whose kernel is generated by `ker` using Renes' formula.

- `ker`: generator of the kernel (its curve is the domain of the isogeny),
- `ell`: order of `ker`,
- `Ps`: array of points to evaluate the isogeny on.

The multiples `ker, 2·ker, …, div(ell - 1, 2)·ker` are walked with
differential additions. Each multiple contributes:
- one factor to the X and Z coordinates of every image;
- one factor to the running products Π(X - Z) and Π(X + Z), which determine
  the codomain coefficient.

Returns the codomain Montgomery curve `E′` and a new vector holding the images
of `Ps` on `E′`; the points in `Ps` are not modified.
"""
function VeluRenes(ker::XZPoint{T}, ell::Integer, Ps::Array{XZPoint{T},1}) where T
    dom = ker.curve
    F = base_ring(dom)
    nsteps = div(ell - 1, 2)

    images = [XZPoint(one(F), one(F), dom) for _ in Ps]
    minus_prod = one(F)   # Π (X_k - Z_k) over the visited multiples
    plus_prod = one(F)    # Π (X_k + Z_k) over the visited multiples

    # At the start of step s: Q == s·ker and R == (s + 1)·ker.
    Q, R = ker, xdouble(ker)
    for s in 1:nsteps
        dm = Q.X - Q.Z
        dp = Q.X + Q.Z
        minus_prod *= dm
        plus_prod *= dp
        for (img, P) in zip(images, Ps)
            a = (P.X - P.Z) * dp
            b = (P.X + P.Z) * dm
            img.X *= a + b
            img.Z *= a - b
        end
        if s < nsteps
            # (s + 2)·ker = (s + 1)·ker + ker, with difference s·ker = Q
            Q, R = R, xadd(R, ker, Q)
        end
    end
    # Sanity check on ell: for odd ell with ell·ker = 0, R = -Q, so both
    # points share the same x-coordinate.
    @assert R == Q

    a_num = (dom.A - 2)^ell * minus_prod^8
    a_den = (dom.A + 2)^ell * plus_prod^8
    cod = Montgomery(2*(a_num + a_den) // (a_den - a_num), dom.B)

    for (img, P) in zip(images, Ps)
        img.X = img.X^2 * P.X
        img.Z = img.Z^2 * P.Z
        img.curve = cod
    end
    return cod, images
end

"""
    VeluResultant(ker, ell, Ps; bs=nothing) -> (E′, images)

Evaluate the isogeny whose kernel is generated by `ker` using the resultant
(baby-step giant-step, "√élu") algorithm.

- `ker`: generator of the kernel,
- `ell`: order of `ker`,
- `Ps`: array of points to evaluate the isogeny on,
- `bs` (keyword): baby-step parameter, a positive even integer. The default is
  `isqrt(ell - 1)` rounded up to an even number (and at least 2).

How the kernel multiples are covered:
- Baby steps are the multiples `(2i - 1)·ker` for `i = 1, …, bs/2`.
- Giant steps are the multiples `(2j - 1)·bs·ker` for `j = 1, …, gs`, where
  `gs = div(ell >> 1, bs)`.
- Their combined contribution is obtained from subproduct trees and
  polynomial remainders.
- The even multiples `2·ker, 4·ker, …, (ell - 2·bs·gs - 1)·ker` are then
  multiplied in one at a time.

If `gs == 0` (e.g. `ell == 3`, or `bs` larger than `ell >> 1`) there is no giant
step, and the result is computed by `VeluRenes` instead.

Returns the codomain Montgomery curve `E′` and a new vector with the images of
`Ps` on `E′`.

Throws an `ArgumentError` if `bs` is not a positive even integer.
"""
function VeluResultant(ker::XZPoint{T}, ell::Integer, Ps::Array{XZPoint{T},1}; bs=nothing) where T
    if bs === nothing
        # ≈ √(ell - 1), rounded up to an even number; isqrt is exact for any ell.
        bs = max(2, Int(isqrt(ell - 1)))
        bs += isodd(bs)
    end
    (bs > 0 && iseven(bs)) ||
        throw(ArgumentError("bs must be a positive even integer, got $bs"))
    gs = div(ell >> 1, bs)
    # Without a single giant step the subproduct trees below would be empty
    # (make_tree cannot build a tree with no leaves): use the direct formulas.
    gs == 0 && return VeluRenes(ker, ell, Ps)

    E = ker.curve
    k = base_ring(E)
    Kx, x = PolynomialRing(k, "x")
    nbaby = bs >> 1

    # Baby steps: one monic leaf x - x((2i - 1)·ker) per i = 1, …, bs/2.
    Q, step, diff = ker, xdouble(ker), ker
    leaves = Array{typeof(x)}(undef, nbaby)
    for i in 1:nbaby
        # invariant: Q == (2i - 1)·ker
        leaves[i] = x - Q.X // Q.Z   # x*Q.Z - Q.X
        if i < nbaby
            Q, diff = xadd(Q, step, diff), Q
        end
    end
    tree = make_tree(leaves)

    # Π_i f(x_i) over the baby-step abscissae x_i: reducing f down the tree
    # leaves the constants f(x_i) in the first nbaby entries (the leaves).
    leafprod = f -> prod(r -> coeff(r, 0), view(eval_tree(f, tree), 1:nbaby))

    # Giant steps Q = (2j - 1)·bs·ker. For each evaluation point P, build the
    # quadratic (XX-ZZ)²x² - 2((XX+ZZ)(ZX+XZ) + 2A·XX·ZZ)x + (ZX-XZ)², with
    # XX = X_P·X_Q, ZZ = Z_P·Z_Q, ZX = Z_P·X_Q, XZ = X_P·Z_Q.
    # eq_giant[:, 1] and eq_giant[:, 2] are that same quadratic evaluated
    # for P = (1 : 1) and P = (-1 : 1); they feed the codomain coefficient.
    Q = bs * ker
    step, diff = xdouble(Q), Q
    giants = zeros(Kx, gs, length(Ps))
    eq_giant = zeros(Kx, gs, 2)
    for i in 1:gs
        # invariant: Q == (2i - 1)·bs·ker
        for (j, P) in enumerate(Ps)
            XX = P.X * Q.X
            ZZ = P.Z * Q.Z
            ZX = P.Z * Q.X
            XZ = P.X * Q.Z
            lin = -2 * ((XX + ZZ) * (ZX + XZ) + 2 * E.A * XX * ZZ)
            setcoeff!(giants[i, j], 2, (XX - ZZ)^2)
            setcoeff!(giants[i, j], 1, lin)
            setcoeff!(giants[i, j], 0, (ZX - XZ)^2)
        end
        XmZ = (Q.X - Q.Z)^2
        setcoeff!(eq_giant[i, 1], 2, XmZ)
        setcoeff!(eq_giant[i, 1], 0, XmZ)

        XpZ = (Q.X + Q.Z)^2
        setcoeff!(eq_giant[i, 2], 2, XpZ)
        setcoeff!(eq_giant[i, 2], 0, XpZ)

        AXZ = 2 * E.A * Q.X * Q.Z
        setcoeff!(eq_giant[i, 1], 1, -2 * (XpZ + AXZ))
        setcoeff!(eq_giant[i, 2], 1, 2 * (XmZ + AXZ))

        if i < gs
            Q, diff = xadd(Q, step, diff), Q
        end
    end

    # Resultants: X from the product of the giant quadratics, Z from its
    # coefficient reversal (x^d·f(1/x) with d = deg f).
    res = similar(Ps)
    for j in eachindex(Ps)
        num_poly = make_tree(view(giants, :, j))[end]
        den_poly = Nemo.reverse(num_poly)
        res[j] = XZPoint(leafprod(num_poly), leafprod(den_poly), E)
    end
    num = leafprod(make_tree(view(eq_giant, :, 1))[end])
    den = leafprod(make_tree(view(eq_giant, :, 2))[end])

    # Remaining kernel points: the even multiples i·ker, i = 2, 4, …, stop.
    # Since ell·ker = 0, i·ker = -(ell - i)·ker, so for odd ell these stand in
    # for the odd multiples 2·bs·gs + 1, …, ell - 2 not covered above.
    Q = xdouble(ker)
    step, nxt = Q, xdouble(Q)
    stop = ell - 2 * bs * gs - 1
    for i in 2:2:stop
        # invariant: Q == i·ker, nxt == (i + 2)·ker
        QmZ = Q.X - Q.Z
        QpZ = Q.X + Q.Z
        num *= QmZ
        den *= QpZ
        for (j, P) in enumerate(Ps)
            mp = (P.X - P.Z) * QpZ
            pm = (P.X + P.Z) * QmZ
            res[j].X *= mp + pm
            res[j].Z *= mp - pm
        end
        if i < stop
            Q, nxt = nxt, xadd(nxt, step, Q)
        end
    end

    # Codomain coefficient and final squaring: same finishing step as VeluRenes.
    num = (E.A - 2)^ell * num^8
    den = (E.A + 2)^ell * den^8
    E1 = Montgomery(2*(num + den) // (den - num), E.B)

    for (j, P) in enumerate(Ps)
        res[j].X ^= 2
        res[j].X *= P.X
        res[j].Z ^= 2
        res[j].Z *= P.Z
        res[j].curve = E1
    end

    return E1, res
end

"""
    make_tree(leaves) -> Vector

Build the subproduct tree of `leaves` and return it as one flat vector of
length `2n - 1`, where `n = length(leaves)`. The layout is:

- entries `1:n` are the leaves, in order;
- entry `n + i` is `tree[2i - 1] * tree[2i]`. Scanning from the start, every
  two consecutive entries are combined, and the product goes into the first
  free slot;
- the last entry is the root, the product of all leaves.

Example with n = 7 leaves (number of leaves below each entry):
`1 1 1 1 1 1 1 2 2 2 3 4 7`.
"""
function make_tree(leaves::AbstractArray)
    nleaves = length(leaves)
    nodes = similar(leaves, 2 * nleaves - 1)
    for (slot, leaf) in enumerate(leaves)
        nodes[slot] = leaf
    end
    # Parent nleaves + i merges the pair (2i - 1, 2i). Both children are
    # already filled, because 2i < nleaves + i.
    for parent in (nleaves + 1):(2 * nleaves - 1)
        child = 2 * (parent - nleaves)
        nodes[parent] = nodes[child - 1] * nodes[child]
    end
    return nodes
end

"""
    eval_tree(f, tree) -> Vector

Reduce `f` modulo every node of the subproduct tree `tree`, working from the
root down to the leaves. The tree must use the flat layout returned by
`make_tree`.

Returns a vector `r` with the same length as `tree`. Each node's remainder is
computed from its parent's remainder: `r[i] = r[parent] % tree[i]`. When every
node divides its parent (as for trees built by `make_tree`), `r[i]` is the
remainder of `f` modulo `tree[i]`. In particular, the first `n` entries hold
`f` modulo each leaf.
"""
function eval_tree(f::T, tree::AbstractArray{T}) where T
    nleaves = div(length(tree) + 1, 2)
    rems = similar(tree)
    rems[end] = f % tree[end]
    # Parent nleaves + i feeds children 2i - 1 and 2i. Going downwards from the
    # root guarantees each parent is reduced before its children.
    for i in (nleaves - 1):-1:1
        parent = rems[nleaves + i]
        rems[2i] = parent % tree[2i]
        rems[2i - 1] = parent % tree[2i - 1]
    end
    return rems
end

################ CSIDH ##################

include("CSIDH.jl")
include("BSIDH.jl")

end # module
