import Nemo: FiniteField, characteristic, factor

export BSIDH_params, generator, walk

# Parameters

"""
    BSIDH_prime

A prime `p` bundled with the prime-power factor lists used on each side of the walk.

# Fields
- `p::fmpz`: the prime itself.
- `plus_one`: `(ell, e)` pairs, each standing for `ell^e`, used when `go == :plus`
  (in `BSIDH_params` their product is `p + 1`).
- `minus_one`: `(ell, e)` pairs, each standing for `ell^e`, used for the other side
  (in `BSIDH_params` their product is `(p - 1) / 2`).
"""
struct BSIDH_prime
    p::fmpz
    plus_one::Vector{Tuple{Int,Int}}
    minus_one::Vector{Tuple{Int,Int}}
end

"""
    card(p::BSIDH_prime, go::Symbol)

Return `p.p + 1` when `go == :plus`, and `p.p - 1` for any other symbol.
"""
function card(p::BSIDH_prime, go::Symbol)
    if go == :plus
        return p.p + 1
    end
    return p.p - 1
end

"""
    BSIDH_params()

Build the starting curve and the prime description for this B-SIDH instance.

The prime is `p = prod(ell^e for (ell, e) in plus_one) - 1`. The function asserts
that `p` is prime and that `p - 1 == 2 * prod(ell^e for (ell, e) in minus_one)`.
The curve is `Montgomery(0, 1)` over `K = GF(p)[I] / (I^2 + 1)`.

# Returns
`(E, BSIDH_prime(p, plus_one, minus_one))`.
"""
function BSIDH_params()
    plus_one = [(2, 32), (5, 21), (7, 1), (11, 1), (163, 1), (1181, 1),
                (2389, 1), (5233, 1), (8353, 1), (10139, 1), (11939, 1),
                (22003, 1), (25391, 1), (41843, 1), (3726787, 1), (6548911, 1)]

    minus_one = [(3, 56), (31, 1), (43, 1), (59, 1), (271, 1),
                 (311, 1), (353, 1), (461, 1), (593, 1), (607, 1), (647, 1),
                 (691, 1), (743, 1), (769, 1), (877, 1), (1549, 1),
                 (4721, 1), (12433, 1), (26449, 1)]

    # p + 1 is exactly the product of the "plus" prime powers
    p = prod(ZZ(ell)^e for (ell, e) in plus_one) - 1
    @assert isprime(p)
    minus_order = prod(ZZ(ell)^e for (ell, e) in minus_one)
    @assert p - 2 * minus_order == 1

    # plus_one contains (2, 32), so 4 | p + 1, i.e. p ≡ 3 (mod 4),
    # which makes I^2 + 1 irreducible over GF(p)
    Fp = GF(p)
    _, I = PolynomialRing(Fp, "I")
    K, _ = FiniteField(I^2 + 1, "i")
    E = Montgomery(K(0), K(1))

    return E, BSIDH_prime(p, plus_one, minus_one)
end

# Point sampling

"""
    generator(E::Montgomery, p::BSIDH_prime, go::Symbol)

Sample a random x-only point suitable as a kernel generator for `walk`.

Candidates `XZPoint(x, 1, E)` with `x = rand(base_ring(E))` are drawn until one:
- satisfies `isvalid` when `go == :plus`, and fails it otherwise
  (NOTE(review): presumably this selects the curve vs. its twist);
- has `(card(p, go) ÷ ell) * P` not at infinity for every `ell` in `p.plus_one`
  (`:plus`) or `p.minus_one` (otherwise).

For `:plus`, a candidate is also rejected when `((card(p, go) ÷ 2) * P).X == 0`,
because that would make the 2-isogeny formula fail. The accepted point is returned
unchanged. For any other `go`, the accepted point is returned doubled, `2*P`.
"""
function generator(E::Montgomery, p::BSIDH_prime, go::Symbol)
    plus = go == :plus
    field = base_ring(E)
    factors = plus ? p.plus_one : p.minus_one
    n = card(p, go)
    while true
        P = XZPoint(rand(field), field(1), E)
        # wrong side for the requested direction: draw again
        isvalid(P) == plus || continue
        # every listed prime must survive in the order of P
        all(!isinfinity((n ÷ ell) * P) for (ell, _) in factors) || continue
        if !plus
            return 2*P
        end
        # exclude points that make 2-isogeny formula fail
        if ((n ÷ 2) * P).X != 0
            return P
        end
    end
end

# 2-isogeny

"""
    two_isog(ker::XZPoint{T}, Ps=XZPoint{T}[]) where T

Compute the 2-isogeny with kernel generated by `ker` and map the points `Ps` through it.

`ker.X` must be nonzero (asserted), because the formula cannot use `(0,0)` as the
kernel. `ker` is presumably a point of order 2; callers are expected to ensure this.

# Returns
`(codomain, images)`:
- `codomain` is `Montgomery(2*(1 - 2*xk^2), 1)`, where `xk = ker.X // ker.Z`.
- `images` holds each `Q` in `Ps` mapped to
  `(Q.X*(Q.X*ker.X - Q.Z*ker.Z) : Q.Z*(Q.X*ker.Z - Q.Z*ker.X))`.
"""
function two_isog(ker::XZPoint{T},
                  Ps::AbstractArray{XZPoint{T}, 1}=Array{XZPoint{T}}(undef, 0)) where T
    @assert ker.X != 0
    Xk, Zk = ker.X, ker.Z
    field = parent(Xk)
    # affine x-coordinate of the kernel point
    xk = Xk // Zk
    codomain = Montgomery(2*(1 - 2*xk^2), field(1))
    # projective form of x -> x*(x*xk - 1) / (x - xk)
    push_forward(Q) = XZPoint(Q.X * (Q.X * Xk - Q.Z * Zk),
                              Q.Z * (Q.X * Zk - Q.Z * Xk), codomain)
    return codomain, [push_forward(Q) for Q in Ps]
end

# Walk

"""
    walk(ker::XZPoint{T}, p::BSIDH_prime, go::Symbol, Ps=XZPoint{T}[], threshold=100) where T

Evaluate the isogeny with kernel generated by `ker` as a chain of prime-degree steps,
and carry the points `Ps` along.

Let `n = card(p, go)`. The `(ell, e)` pairs of `p.plus_one` (`go == :plus`) or
`p.minus_one` (otherwise) are visited in reverse list order, `e` times each. At every
step, `n` is divided by `ell` and `n` times the current image of `ker` becomes the
kernel. That kernel is asserted not to be at infinity. The step is then:
- `ell == 2`: `two_isog` when `go == :plus`; skipped otherwise;
- `ell > threshold`: `VeluResultant`;
- otherwise: `VeluRenes`.

# Returns
`(curve, images)`. `curve` is the last codomain, or `ker.curve` if no step ran.
`images` holds the images of `Ps`.
"""
function walk(ker::XZPoint{T}, p::BSIDH_prime, go::Symbol,
              Ps::AbstractArray{XZPoint{T}, 1}=Array{XZPoint{T}}(undef, 0),
              threshold::Number=100) where T
    curve = ker.curve
    factors = go == :plus ? p.plus_one : p.minus_one
    n = card(p, go)
    # the kernel generator travels in slot 1 so it gets mapped like the other points
    pts = vcat([ker], Ps)
    for (ell, e) in Iterators.reverse(factors), _ in 1:e
        n = n ÷ ell
        kernel = n * pts[1]
        @assert !isinfinity(kernel)
        if ell == 2 && go != :plus
            # no 2-isogeny is evaluated on the non-:plus side
            continue
        elseif ell == 2
            curve, pts = two_isog(kernel, pts)
        elseif ell > threshold
            curve, pts = VeluResultant(kernel, ell, pts)
        else
            curve, pts = VeluRenes(kernel, ell, pts)
        end
    end
    return curve, pts[2:end]
end
