import Base: ^
import Base.Iterators: drop, reverse, take, cycle
import Nemo: ZZ, fmpz, fmpz_mod_ctx_struct, GF, GaloisFmpzField, gfp_fmpz_elem, isprime, RingElement, order
import Primes: primes

export CSIDH_setup, CSURF_setup, CSIDH_sk, CSURF_sk, CSIDH_sk_max, CSURF_sk_max, CSIDH_sk_alt, CSURF_sk_alt, CSIDH_action

"""
    CSIDH_setup{T}

Public parameters of a CSIDH-style class group action.

# Fields
- `p::fmpz`: characteristic of the base field. The constructors in this file
  set `p = prod(fact) - 1` and assert that it is prime.
- `fact::Vector{Int64}`: factors of `p + 1`, one per secret-key exponent.
  In this file `fact[1]` is the power of 2 (4 or 8), `fact[2]` is the power
  of 3 (3 or 9), and the remaining entries are odd primes.
- `E::Montgomery{T}`: default starting curve (in this file `A = 0`, `B = 1`).
"""
struct CSIDH_setup{T<:RingElement}
    p::fmpz
    fact::Vector{Int64}
    E::Montgomery{T}
end

######## Parameter sets

"""
    CSIDH_setup()

Build CSIDH parameters in which `p + 1` is the product of 4, every odd prime
from 3 to 373, and 587.

The starting curve is the Montgomery curve with `A = 0`, `B = 1` over `GF(p)`.
"""
function CSIDH_setup()
    degrees = [4; primes(3, 373); 587]
    p = prod(ZZ, degrees) - 1
    @assert isprime(p)
    F = GF(p)
    E0 = Montgomery(F(0), F(1))
    return CSIDH_setup(p, degrees, E0)
end

"""
    CSURF_setup()

Build CSURF parameters in which `p + 1` is the product of 8, 9, and every
prime from 5 to 389 except 347 and 359.

The starting curve is the Montgomery curve with `A = 0`, `B = 1` over `GF(p)`.
"""
function CSURF_setup()
    skipped = (347, 359)
    degrees = filter(ell -> !(ell in skipped), [8; 9; primes(5, 389)])
    p = prod(ZZ, degrees) - 1
    @assert isprime(p)
    F = GF(p)
    return CSIDH_setup(p, degrees, Montgomery(F(0), F(1)))
end

######## Elligator2, efficient sampling of points

# Filling in missing Nemo definition
function ^(x::gfp_fmpz_elem, y::fmpz)
    # Computes x^y for an arbitrary-precision exponent using FLINT's
    # `fmpz_mod_pow_fmpz`. A negative exponent is handled by raising
    # inv(x) to -y, so `x` must be invertible in that case.
    R = parent(x)
    b, e = y < 0 ? (inv(x), -y) : (x, y)
    res = fmpz()
    # `R.ninv` is passed as the modulus context (Ref{fmpz_mod_ctx_struct}).
    ccall((:fmpz_mod_pow_fmpz, :libflint), Nothing,
          (Ref{fmpz}, Ref{fmpz}, Ref{fmpz}, Ref{fmpz_mod_ctx_struct}),
          res, b.data, e, R.ninv)
    return gfp_fmpz_elem(res, R)
end

"""
    Elligator2(E, u)

Map the field element `u` (not 0, 1 or -1) to the two Elligator 2 candidate
x-only points `(P, Q)` for the Montgomery curve `E`.

The first returned point is the candidate whose x-coordinate `x` makes
`x^3 + A*x^2 + x` a square, as decided by Euler's criterion. The second
returned point is the other candidate. The order of the base field must be
≡ 3 (mod 4).

NOTE(review): `E.B` is not used. This presumably relies on `B = 1`, as in the
setups of this file — confirm.
"""
function Elligator2(E::Montgomery{T}, u::T) where T
    @assert u != 0 && u != 1 && u != -1
    F = base_ring(E)
    p = order(F)
    @assert p % 4 == 3
    A = E.A
    if A == 0
        # Candidates x = u and x = -u; t is the curve right-hand side at u.
        z = one(F)
        xP, xQ = u, -u
        t = u^3 + u
    else
        # Candidates x = A/(u^2-1) and x = -A*u^2/(u^2-1), kept projective (X:Z).
        u2 = u^2
        z = u2 - 1
        Au2 = A*u2
        xP, xQ = A, -Au2
        # Right-hand side at x = A/z, scaled by z^4 (does not change squareness).
        t = (A*Au2 + z^2)*A*z
    end
    P = XZPoint(xP, z, E)
    Q = XZPoint(xQ, z, E)
    # Euler's criterion: t^((p-1)/2) == 1 exactly when t is a nonzero square.
    return t^((p-1) ÷ 2) == 1 ? (P, Q) : (Q, P)
end

######## Group action

"""
    CSIDH_advance(params, P, ord, vec, thresh)

Use the single point `P` to take as many of the isogeny steps requested in
`vec` as possible, starting from the curve of `P`.

As set up by `CSIDH_action`, `ord` is the product of `params.fact[i]` over the
indices `i ≥ 2` with `vec[i] > 0`, and `P` has already been multiplied by the
remaining factors of `p + 1`.

For each `i ≥ 3` with `vec[i] > 0` (largest index first), at most one
`params.fact[i]`-isogeny is computed. Degrees below `thresh` use `VeluRenes`;
the others use `VeluResultant`. Degree-3 steps (`vec[2]`) are then repeated
while `vec[2] > 0` and `P` is not at infinity. This allows two 3-isogenies per
point when `params.fact[2] == 9`.

`vec` is updated in place and is decremented once per isogeny taken.
`vec[1]` must be 0.

Returns `(E, cof, go)`:
- `E` is the final curve.
- `cof` is the product of `params.fact[i]` over the entries that reached 0
  during this call.
- `go` tells whether any of `vec[2:end]` is still positive.
"""
function CSIDH_advance(params::CSIDH_setup{T}, P::XZPoint{T}, ord::fmpz, vec::Vector{Int},
                       thresh::Int) where T
    @assert vec[1] == 0

    curve = P.curve
    finished = ZZ(1)

    # All degrees except the first two entries (powers of 2 and 3), largest first.
    for i in length(vec):-1:3
        vec[i] > 0 || continue
        ell = params.fact[i]
        ord ÷= ell
        K = ord * P
        isinfinity(K) && continue    # no ell-part in P: nothing to do for ell
        velu = ell < thresh ? VeluRenes : VeluResultant
        curve, (P,) = velu(K, ell, [P])
        vec[i] -= 1
        if vec[i] == 0
            finished *= ell
        end
    end

    # Degree 3: repeat while steps remain and P is not yet at infinity.
    while vec[2] > 0 && !isinfinity(P)
        ord ÷= 3
        K = ord * P
        if !isinfinity(K)
            curve, (P,) = VeluRenes(K, 3, [P])
            vec[2] -= 1
            if vec[2] == 0
                finished *= params.fact[2]
            end
        end
    end

    remaining = any(e -> e > 0, @view vec[2:end])
    return curve, finished, remaining
end

"""
    CSIDH_action(params, E, sk; thresh=250)

Apply the signed exponent vector `sk` to the Montgomery curve `E` and return
the resulting curve.

`sk[i]` is a signed number of steps of degree `params.fact[i]`.

- The first entry is handled by a closed formula on the coefficient `A`.
  The `CSIDH_sk*` samplers always set it to 0; the `CSURF_sk*` ones do not.
  A negative first entry walks `|sk[1]|` steps in the opposite direction by
  negating `A` before and after the walk.
- Positive remaining entries are processed with the first point returned by
  `Elligator2`.
- Negative remaining entries are processed with the second point returned by
  `Elligator2`.

Each sampled point is handed to `CSIDH_advance`, which uses `VeluRenes` for
degrees `< thresh` and `VeluResultant` otherwise.

`sk` itself is not modified.
"""
function CSIDH_action(params::CSIDH_setup{T}, E::Montgomery{T}, sk::Vector{Int};
                      thresh::Int=250) where T
    @assert length(sk) == length(params.fact)
    gf = base_ring(E)
    # Non-negative work vectors for the two directions (copies of sk's entries).
    left  = map(x -> x > 0 ?  x : 0, sk)
    right = map(x -> x < 0 ? -x : 0, sk)

    # Handle degree 2 (first entry). The negative direction negates A, walks
    # |sk[1]| steps with the same formula, and negates back. The loop therefore
    # runs abs(sk[1]) times; with `1:sk[1]` negative entries would be ignored.
    A = E.A
    if sk[1] < 0
        A = -A
    end
    # x^sq is a square root of a square x because p ≡ 3 (mod 4)
    # (fact[1] ∈ {4, 8} divides p + 1 for the setups in this file).
    sq = (params.p + 1) ÷ 4
    # This formula describes a two-cycle on the floor, parallel to the surface
    for _ in 1:abs(sk[1])
        A = 2 - 4*((2+A)^sq - 2)^4 // (2-A)^2
    end
    if sk[1] < 0
        A = -A
    end
    E = Montgomery(A, E.B)
    left[1], right[1] = 0, 0

    # Handle other degrees.
    # cof(vec) is the product of the degrees with no work left in `vec`.
    # Multiplying a point by it removes those parts of the point's order.
    cof = vec -> prod(ZZ, e == 0 ? ell : 1 for (ell, e) in zip(params.fact, vec))
    left_cof, right_cof = cof.((left, right))
    # p ÷ c + 1 == (p + 1) / c, since c > 1 (it contains fact[1]) divides p + 1.
    left_ord, right_ord = (params.p ÷ left_cof) + 1, (params.p ÷ right_cof) + 1
    go_left, go_right = any(x -> x > 0, left), any(x -> x > 0, right)

    while go_left || go_right
        u = rand(gf)
        if u == 0 || u == 1 || u == -1
            continue    # inputs rejected by Elligator2
        end
        P, Q = Elligator2(E, u)
        # One direction per sample: advancing changes E, and the other point
        # was computed for the previous curve.
        if go_left
            P = left_cof * P
            E, nc, go_left = CSIDH_advance(params, P, left_ord, left, thresh)
            left_cof *= nc
            left_ord ÷= nc
        elseif go_right
            Q = right_cof * Q
            E, nc, go_right = CSIDH_advance(params, Q, right_ord, right, thresh)
            right_cof *= nc
            right_ord ÷= nc
        end
    end

    return E
end

"""
    CSIDH_action(params, sk; thresh=250)

Apply `sk` to the starting curve `params.E`. This is equivalent to
`CSIDH_action(params, params.E, sk; thresh=thresh)`.
"""
CSIDH_action(params::CSIDH_setup{T}, sk::Vector{Int}; thresh::Int=250) where T =
    CSIDH_action(params, params.E, sk; thresh=thresh)

######## Sampling secret keys

"""
    CSIDH_sk(params)

Sample a random CSIDH secret key. The first entry is 0; every other entry is
drawn uniformly from `-5:5`.
"""
function CSIDH_sk(params::CSIDH_setup)
    n = length(params.fact) - 1
    return [0; rand(-5:5, n)]
end

"""
    CSIDH_sk_max(params)

Return the CSIDH key with the first entry 0 and every other entry equal to 5.
"""
function CSIDH_sk_max(params::CSIDH_setup)
    n = length(params.fact) - 1
    return [0; fill(5, n)]
end

"""
    CSIDH_sk_alt(params)

Return the CSIDH key `[0, 5, -5, 5, -5, ...]` of length `length(params.fact)`.
"""
function CSIDH_sk_alt(params::CSIDH_setup)
    n = length(params.fact) - 1
    alternating = take(cycle([5, -5]), n)
    return [0; collect(alternating)]
end

"""
    CSURF_sk(params)

Sample a random CSURF secret key of 75 entries. The entries are drawn
uniformly from:
- `-137:137` for the first entry,
- `-4:4` for the next 3,
- `-5:5` for the next 46,
- `-3:3` for the last 25.
"""
function CSURF_sk(params::CSIDH_setup)
    # (bound, count) for each consecutive group of entries, in order.
    groups = ((137, 1), (4, 3), (5, 46), (3, 25))
    sk = reduce(vcat, [rand(-b:b, n) for (b, n) in groups])
    @assert length(sk) == length(params.fact)
    return sk
end

"""
    CSURF_sk_max(params)

Return the CSURF key with every entry at the upper bound used by `CSURF_sk`:
137, then 4 (×3), 5 (×46) and 3 (×25).
"""
function CSURF_sk_max(params::CSIDH_setup)
    sk = [137; fill(4, 3); fill(5, 46); fill(3, 25)]
    @assert length(sk) == length(params.fact)
    return sk
end

"""
    CSURF_sk_alt(params)

Return the CSURF key built from:
- 137,
- then 4, -4, 4,
- then 46 entries alternating -5, 5, ...,
- then 25 entries alternating -3, 3, ...
"""
function CSURF_sk_alt(params::CSIDH_setup)
    # n entries -b, b, -b, ... (starting negative)
    alternate(b, n) = [iseven(i) ? b : -b for i in 1:n]
    sk = [137; 4; -4; 4; alternate(5, 46); alternate(3, 25)]
    @assert length(sk) == length(params.fact)
    return sk
end
