use rug::{Complete, Integer};

use crate::{
    CHAL_BYTES, CS, Expr, Openings, Prefix, Proof, QUERIES, SEP_CHALLG, SEP_PREFIX, SEP_PROOFS,
    SEP_ROUND1, SEP_ROUND2, affine, challenge_idx, challenge_int, check_combination_norm,
    merkle::Tree, primes::PRIMES, regular_norm,
};

/// Prover-side state: collects witness values and constraint expressions
/// through the [`CS`] interface and is consumed by [`Prover::prove`].
#[derive(Default)]
pub struct Prover {
    /// Witness values, one pushed per call to `CS::var`, in allocation order.
    wits: Vec<Integer>,
    /// Expressions registered through `CS::zero` (presumably each is required
    /// to evaluate to zero — confirm against the verifier). `prove` hashes
    /// these into the transcript prefix and later appends one more.
    zero: Vec<Expr>,
    /// Variable expressions, one per witness; entry `i` is `Expr::var(i)`.
    vars: Vec<Expr>,
}

/// Builds the commitment tree over `elems`.
///
/// The tree has one leaf per entry of `PRIMES`; leaf `i` holds, in order,
/// the residue of every element of `elems` modulo `PRIMES[i]`.
///
/// # Panics
///
/// Panics with "witness too large norm" when `regular_norm` rejects an
/// element, and panics if a residue cannot be converted to `u32`.
fn commit(elems: &[Integer]) -> Tree<Vec<u32>> {
    // One (initially empty) column of residues per prime.
    let mut columns: Vec<Vec<u32>> = vec![Vec::new(); PRIMES.len()];

    for value in elems {
        regular_norm(value).expect("witness too large norm");

        // Append this element's residue to the column of each prime.
        for (column, prime) in columns.iter_mut().zip(PRIMES.iter()) {
            let residue: u32 = (value % prime).complete().try_into().unwrap();
            column.push(residue);
        }
    }

    Tree::new(&columns[..])
}

impl Prover {
    pub fn prove(mut self, msg: &[u8]) -> Result<Proof, anyhow::Error> {
        // create commitment to witnesses
        let tree = commit(&self.wits);

        // create the transcript
        let mut tx = merlin::Transcript::new(SEP_PROOFS);
        tx.append_message(
            SEP_PREFIX,
            &bincode::serialize(&Prefix {
                msg: msg.to_vec(),
                vars: self.vars.len(),
                zero: self.zero.iter().map(|expr| expr.hash()).collect(),
            })?,
        );

        // send the first round message
        tx.append_message(SEP_ROUND1, &tree.root());

        // squeeze the linear combination
        let challenges: Vec<Integer> = self
            .wits
            .iter()
            .map(|_| challenge_int(&mut tx, SEP_CHALLG, CHAL_BYTES))
            .collect();

        // compute the random combination
        let rand: Integer = challenges
            .iter()
            .cloned()
            .zip(self.wits.iter())
            .map(|(c, w)| c * w)
            .sum();

        // check norm of combination
        check_combination_norm(&rand, self.wits.len())?;

        // send round 2 message
        tx.append_message(SEP_ROUND2, &bincode::serialize(&rand)?);

        // add constraint for combination
        assert_eq!(self.vars.len(), self.wits.len());
        self.zero(affine(
            challenges.into_iter().zip(self.vars.iter().cloned()),
            -rand.clone(),
        ))?;

        // sample opening positions
        let poss = [(); QUERIES].map(|_| challenge_idx(&mut tx, SEP_CHALLG, 0..PRIMES.len()));

        // open required positions
        Ok(Proof {
            root: tree.root(),
            open: Box::new(Openings(poss.map(|pos| tree.open(pos)))),
            rand,
        })
    }
}

impl CS for Prover {
    /// Allocates the next variable and records `value` as its witness.
    ///
    /// The returned expression is `Expr::var(i)`, where `i` is the number of
    /// variables allocated before this call.
    ///
    /// # Errors
    ///
    /// Returns an error when `value` is `None`: the prover needs a concrete
    /// witness for every variable. Nothing is recorded in that case.
    fn var(&mut self, value: Option<Integer>) -> Result<Expr, anyhow::Error> {
        // Validate before touching any state so a failure leaves `wits` and
        // `vars` the same length.
        let value = value.ok_or_else(|| {
            anyhow::anyhow!("prover requires a witness value for every variable")
        })?;

        let var = Expr::var(self.vars.len());
        self.wits.push(value);
        self.vars.push(var.clone());
        Ok(var)
    }

    /// Registers `expr` in the prover's list of constraint expressions.
    ///
    /// Always succeeds.
    fn zero(&mut self, expr: Expr) -> Result<(), anyhow::Error> {
        // `expr` is already owned, so it is moved in without cloning.
        self.zero.push(expr);
        Ok(())
    }
}
