####################################################################################
### This code computes a branching matrix program with l matrices of dimension m
### composed of only identity matrices. Then, it randomizes the branching program
### by multiplying the matrices with random bundling scalars and random matrices
### with coefficients in R = Z[X]/(X^n+1).
### It then randomly multiplies the matrices of the branching program (in the right
### order but choosing A_{i,0} or A_{i,1} at random) and divides the resulting
### element by g. After obtaining many such post-zero-test values, it checks whether
### they generate the full ring R or not.
### The result res contains a list of different values of g, together with the
### empirical probability for each g that the post-zero-test values generate R.
####################################################################################


## Global parameters of the experiment and construction of the ring R = Z[X]/(X^n+1).
## reset() is the Sage built-in that clears the user-defined variables of the session.
reset()
## n: degree of the defining polynomial X^n+1, i.e. dimension of the ring R
n = 32
## m: dimension of the matrices of the branching program
m = 5
## l: number of matrices (steps) of the branching program
l = 20

## ZX = Z[y]. The name y stays bound to the generator of ZX (not of R); the
## functions below build polynomials in y and then convert them into R.
ZX.<y> = PolynomialRing(ZZ)
## R = ZX / (y^n+1)
I = ZX.ideal([y^n+1])
R = ZX.quotient_ring(I)

def matrix_from_vect(g,n):
    ## Compute the n x n matrix representing the multiplication by g in the ring
    ## R = Z[X]/(X^n+1) when R is seen as a lattice: row i holds the coefficients
    ## of g*X^i, so that vector(a)*M is the coefficient vector of a*g.
    ##   g: element of R (iterating over g is expected to yield its n coefficients)
    ##   n: dimension of R
    ## returns the n x n matrix built from matrix(n) and filled with these coefficients
    coeffs = list(g)  ## extracted once: g does not change inside the loops
    M = matrix(n)
    for i in range(n):
        for j in range(n):
            ## X^n = -1 in R, so the coefficients that wrap around (j < i) change sign
            if i > j:
                M[i,j] = -coeffs[(j-i)%n]
            else:
                M[i,j] = coeffs[(j-i)%n]
    return M

def reduce_mod_g(Mg,Mg_inv,a):
    ## Reduce a modulo g in R with Babai's rounding technique.
    ##   Mg:     matrix of the multiplication by g (as built by matrix_from_vect)
    ##   Mg_inv: inverse of Mg
    ##   a:      element of R to reduce
    ## returns a minus the multiple of g obtained by rounding the coordinates
    ## of a in the basis formed by the rows of Mg
    coords = vector(a)*Mg_inv
    for k in range(len(coords)):
        coords[k] = round(coords[k])
    ## rebuild, coefficient by coefficient, the ring element equal to coords*Mg
    multiple_of_g = R(0)
    for k, coeff in enumerate(coords*Mg):
        multiple_of_g += R(Integer(floor(coeff))*y^k)
    return a-multiple_of_g


def div_by_g(Mg,Mg_inv,a):
    ## Divide a by g in the ring R.
    ##   Mg:     matrix of the multiplication by g (as built by matrix_from_vect)
    ##   Mg_inv: inverse of Mg
    ##   a:      element of R to divide
    ## returns the element of R whose coefficients are the rounded coordinates of
    ## a*Mg_inv; when a is not divisible by g a warning is printed but the
    ## (rounded) quotient is still returned
    quotient_vect = vector(a)*Mg_inv
    for k in range(len(quotient_vect)):
        quotient_vect[k] = round(quotient_vect[k])
    ## multiply the rounded quotient back by g to check that the division is exact
    product = R(0)
    for k, coeff in enumerate(quotient_vect*Mg):
        product += R(Integer(floor(coeff))*y^k)
    if a-product != R(0):
        print "division non exact"
    ## convert the rounded coordinate vector into an element of R
    quotient = R(0)
    for k, coeff in enumerate(quotient_vect):
        quotient += R(Integer(floor(coeff))*y^k)
    return quotient


def generate_R(list_elms,n):
    ## Test whether the ideal generated by the elements of list_elms is the full ring R.
    ## The rows of the multiplication matrices of all the elements are stacked in
    ## one matrix; the test succeeds exactly when the image of that matrix contains
    ## the coefficient vector (1,0,...,0) of the constant polynomial 1.
    ##   list_elms: list of elements of R
    ##   n:         dimension of R
    stacked_rows = [row for elm in list_elms for row in matrix_from_vect(elm,n)]
    ideal_lattice = image(Matrix(stacked_rows))
    unit_vect = vector([1]+[0]*(n-1))
    return unit_vect in ideal_lattice


def test_fixed_g(g,l,m,n,nb_tests,nb_zt):
    ## Estimate, for a fixed g, how often the post-zero-test values generate the ring R.
    ##   g:        element of R (for the GGH map) used to reduce the encodings
    ##             and to divide the zero-tested products
    ##   l:        number of matrix pairs in the branching program (BP)
    ##   m:        dimension of the matrices of the BP
    ##   n:        dimension of the ring R
    ##   nb_tests: number of times the obfuscator is run (the random matrices and
    ##             the bundling scalars are drawn afresh each time)
    ##   nb_zt:    number of post-zero-test elements created per run
    ## returns the empirical probability, over the nb_tests runs, that the
    ## post-zero-test elements generate R

    Mg = Matrix(QQ,matrix_from_vect(g,n))
    Mg_inv = Mg^(-1)

    def identity_over_R():
        ## fresh m x m identity matrix with entries in R
        ident = Matrix(R,m)
        for d in range(m):
            ident[d,d] = R(1)
        return ident

    def shorten(x):
        ## representative of x modulo g given by reduce_mod_g
        return reduce_mod_g(Mg,Mg_inv,x)

    res = 0
    for run in range(nb_tests):
        ## branching program made only of identity matrices, enclosed between
        ## the first and the last unit vectors
        first_vect = vector(R,[R(1)]+[R(0)]*(m-1))
        last_vect = vector(R,[R(0)]*(m-1)+[R(1)])
        list_M = [first_vect]
        list_M += [[identity_over_R(),identity_over_R()] for step in range(l)]
        list_M += [last_vect]

        ## random matrices R_0,...,R_l (entries drawn row by row) and their adjugates
        list_R = [Matrix(R,m,m,[R.random_element() for entry in range(m*m)])
                  for step in range(l+1)]
        list_R_adj = [rand_mat.adjoint() for rand_mat in list_R]

        ## bundling scalars alpha: one pair per position of the program
        list_alpha = [[R.random_element(),R.random_element()] for pos in range(l+2)]

        ## randomization of the BP with the scalars alpha and the matrices R
        list_matrix_encoded = [list_alpha[0][0]*list_M[0]*list_R[0]]
        for k in range(l):
            left = list_R_adj[k]
            right = list_R[k+1]
            list_matrix_encoded.append(
                [list_alpha[k+1][bit]*left*list_M[k+1][bit]*right for bit in range(2)])
        list_matrix_encoded.append(list_alpha[l+1][0]*list_R_adj[l]*list_M[l+1])

        ## reduction of every coefficient mod g (to obtain short elements):
        ## first vector, then the l pairs of matrices, then the last vector
        for pos in range(m):
            list_matrix_encoded[0][pos] = shorten(list_matrix_encoded[0][pos])
        for k in range(l):
            for bit in range(2):
                encoded = list_matrix_encoded[k+1][bit]
                for row in range(m):
                    for col in range(m):
                        encoded[row,col] = shorten(encoded[row,col])
        for pos in range(m):
            list_matrix_encoded[l+1][pos] = shorten(list_matrix_encoded[l+1][pos])

        ## nb_zt random evaluations of the program; each product is divided by g
        ## (zero-test without the multiplication by h)
        zero_tested_values = []
        for zt in range(nb_zt):
            c = list_matrix_encoded[0]
            for k in range(l):
                ## pick the matrix of index 0 or 1 at random for step k
                choice = ZZ.random_element(2)
                c = c*list_matrix_encoded[k+1][choice]
            c = c*list_matrix_encoded[l+1]
            zero_tested_values.append(div_by_g(Mg,Mg_inv,c))

        ## count the run when the zero-tested values generate the whole ring R
        if generate_R(zero_tested_values,n):
            res += 1

    return RR(res/nb_tests)


def test_different_g(l,m,n,nb_g,nb_tests,nb_zt):
    ## Repeat test_fixed_g for nb_g values of g chosen at random.
    ##   l, m, n:  parameters of the branching program and of the ring
    ##             (passed unchanged to test_fixed_g)
    ##   nb_g:     number of values of g to try
    ##   nb_tests: number of obfuscations per g (see test_fixed_g)
    ##   nb_zt:    number of post-zero-test values per obfuscation (see test_fixed_g)
    ## returns a list of nb_g pairs [g, value of test_fixed_g for this g]
    res = []
    for i in range(nb_g):
        ## sample g until its algebraic norm (the determinant of its
        ## multiplication matrix) is prime
        while True:
            g = R.random_element()
            if matrix_from_vect(g,n).det() in Primes():
                break
        ## print only once g is final, so that the g displayed is the one tested
        print "g = ", g
        res += [[g,test_fixed_g(g,l,m,n,nb_tests,nb_zt)]]
    return res



## run the experiment with the global parameters l, m, n:
## 5 values of g, 5 obfuscations per g, 10 post-zero-test values per obfuscation
time res = test_different_g(l,m,n,5,5,10)
## takes around 110 seconds on a personal laptop
