use std::{
    ops::{Add, Mul, Range, Sub},
    rc::Rc,
};

use merkle::Path;
use merlin::Transcript;
use primes::{LOG_PRIMES, NUM_PRIMES};
use rug::{Integer, integer::Order};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};

use serde_big_array::BigArray;

mod merkle;
mod primes;
mod prover;
mod verifier;

/// A 32-byte digest; the `hash` helper in this file produces these with SHA-256.
type Hash = [u8; 32];

/// Fixed-size array of `QUERIES` values of type `merkle::Path<Vec<u32>, LOG_PRIMES>`.
///
/// The array field is (de)serialized through `serde_big_array::BigArray`.
/// NOTE(review): presumably one opened path per verifier query — confirm
/// against the prover/verifier modules.
#[derive(Serialize, Deserialize, Clone)]
struct Openings(#[serde(with = "BigArray")] [Path<Vec<u32>, LOG_PRIMES>; QUERIES]);

/// Proof object exported by this crate; it can be serialized, deserialized and cloned.
#[derive(Serialize, Deserialize, Clone)]
pub struct Proof {
    // 32-byte hash. NOTE(review): presumably the root that the paths in `open`
    // are checked against — confirm in the prover/verifier.
    root: Hash,
    // Arbitrary-precision integer. NOTE(review): its role is not visible in
    // this part of the file — confirm in the prover/verifier.
    rand: Integer,
    // The `QUERIES` paths, boxed so that the large array is stored on the heap.
    open: Box<Openings>,
}

pub use prover::Prover;
pub use verifier::Verifier;

// Challenge size in bytes; `check_combination_norm` budgets `CHAL_BYTES * 8`
// bits for it.
const CHAL_BYTES: usize = 32;

// Byte-string labels. NOTE(review): presumably transcript domain separators —
// their use sites are outside this part of the file.
const SEP_CHALLG: &[u8] = b"chal";
const SEP_PREFIX: &[u8] = b"prefix";
const SEP_ROUND1: &[u8] = b"round1";
const SEP_ROUND2: &[u8] = b"round2";
const SEP_PROOFS: &[u8] = b"integer-proof";

// NOTE(review): presumably the overall security target in bits; `QUERIES` is
// derived from it below.
pub const SEC_TOTAL: u32 = 256;
// 31 bits are counted for each of the `NUM_PRIMES` primes.
// NOTE(review): presumably every prime in the table is a 31-bit modulus —
// confirm in the `primes` module.
pub const BITS_TOTAL: usize = NUM_PRIMES * 31;
pub const MAX_BITS: usize = 50_000; // witness can be at most MAX_BITS bits large
// Integer quotient of the total bit count by `MAX_BITS + 1`.
pub const BLOWUP: usize = BITS_TOTAL / (MAX_BITS + 1);
// floor(log2(BLOWUP)).
pub const SEC_PER_QUERY: u32 = BLOWUP.ilog2();
// ceil(SEC_TOTAL / SEC_PER_QUERY); also the length of the array in `Openings`.
pub const QUERIES: usize = SEC_TOTAL.div_ceil(SEC_PER_QUERY) as usize;

fn check_combination_norm(n: &Integer, wits: usize) -> Result<(), anyhow::Error> {
    norm(n, wits.ilog2() as usize + 1 + CHAL_BYTES * 8 + MAX_BITS)
}

/// Checks `n` against the plain `MAX_BITS` bit budget by delegating to `norm`.
///
/// # Errors
/// Propagates the error from `norm` when `n` is too large.
fn regular_norm(n: &Integer) -> Result<(), anyhow::Error> {
    let budget = MAX_BITS;
    norm(n, budget)
}

fn norm(n: &Integer, bits: usize) -> Result<(), anyhow::Error> {
    anyhow::ensure!(n.clone().abs() < (Integer::from(2) << bits));
    Ok(())
}

/// Serializes `value` with bincode and returns the SHA-256 digest of the bytes.
///
/// # Panics
/// Panics if bincode fails to serialize `value`.
fn hash<T: Serialize>(value: &T) -> Hash {
    let encoded = bincode::serialize(value).unwrap();
    let mut hasher = Sha256::new();
    hasher.update(&encoded);
    hasher.finalize().into()
}

/// Interface for building a system of constraints over `Expr` values.
pub trait CS {
    /// Introduces a variable and returns an expression for it.
    ///
    /// NOTE(review): `value` is optional — presumably `None` when the caller
    /// has no concrete assignment; confirm against the implementors.
    fn var(&mut self, value: Option<Integer>) -> Result<Expr, anyhow::Error>;

    /// Introduces a public constant and returns it as an expression.
    ///
    /// The value must pass `regular_norm`; the constant is then handed to
    /// [`CS::eq`] on both sides before being returned.
    ///
    /// # Errors
    /// Fails when `value` is too large or when [`CS::eq`] fails.
    fn public(&mut self, value: Integer) -> Result<Expr, anyhow::Error> {
        regular_norm(&value)?;
        let constant = Expr::cnst(value);
        let (lhs, rhs) = (constant.clone(), constant.clone());
        self.eq(lhs, rhs)?;
        Ok(constant)
    }

    /// Constrains `expr` to be zero.
    fn zero(&mut self, expr: Expr) -> Result<(), anyhow::Error>;

    /// Constrains `lhs` and `rhs` to be equal by requiring `lhs - rhs` to be zero.
    fn eq(&mut self, lhs: Expr, rhs: Expr) -> Result<(), anyhow::Error> {
        let difference = &lhs - &rhs;
        self.zero(difference)
    }
}

/// One node of an arithmetic expression over integers.
#[derive(Debug, Clone)]
enum ExprInner {
    /// Variable, identified by its index into the assignment given to `Expr::eval`.
    Var(usize),
    /// Sum of two sub-expressions.
    Add(Expr, Expr),
    /// Product of two sub-expressions.
    Mul(Expr, Expr),
    /// Integer constant.
    Cst(Integer),
}

/// Reference-counted handle to an expression node.
///
/// Cloning an `Expr` clones the `Rc`, so sub-expressions are shared rather
/// than deep-copied.
#[derive(Debug, Clone)]
pub struct Expr(Rc<ExprInner>);

/// Builds the expression `cnst + Σ coeff * term` over the `(coeff, term)`
/// pairs yielded by `terms`, adding the pairs in iteration order.
fn affine<T: Into<Integer>, I: Iterator<Item = (T, Expr)>>(terms: I, cnst: T) -> Expr {
    let mut total = Expr::cnst(cnst);
    for (coeff, term) in terms {
        let scaled = Expr::cnst(coeff) * term;
        total = total + scaled;
    }
    total
}

impl Expr {
    pub fn cnst<T: Into<Integer>>(value: T) -> Self {
        Expr(Rc::new(ExprInner::Cst(value.into())))
    }

    pub fn var(idx: usize) -> Self {
        Expr(Rc::new(ExprInner::Var(idx)))
    }

    pub fn eval(&self, assign: &[u32], modulus: u32) -> Result<u32, anyhow::Error> {
        match self.0.as_ref() {
            ExprInner::Cst(c) => {
                let r = c.clone().modulo(&Integer::from(modulus));
                Ok(r.try_into().unwrap())
            }
            ExprInner::Var(i) => {
                anyhow::ensure!(assign[*i] < modulus);
                Ok(assign[*i])
            }
            ExprInner::Add(l, r) => {
                let lhs = l.eval(assign, modulus)? as u64;
                let rhs = r.eval(assign, modulus)? as u64;
                let res = (lhs + rhs) % modulus as u64;
                Ok(res as u32)
            }
            ExprInner::Mul(l, r) => {
                let lhs = l.eval(assign, modulus)? as u64;
                let rhs = r.eval(assign, modulus)? as u64;
                let res = (lhs * rhs) % modulus as u64;
                Ok(res as u32)
            }
        }
    }
}

impl<T: Into<Integer>> From<T> for Expr {
    fn from(value: T) -> Self {
        Expr(Rc::new(ExprInner::Cst(value.into())))
    }
}

impl Expr {
    /// Computes a structural hash of the expression.
    ///
    /// Leaves are hashed from their index or constant; an inner node is
    /// hashed from the hashes of its two children, left first.
    fn hash(&self) -> Hash {
        // Serializable mirror of `ExprInner` with children replaced by hashes.
        // NOTE(review): the encoding presumably depends on the variant order
        // here — keep it unchanged.
        #[derive(Serialize)]
        enum Enc {
            Var(usize),
            Cst(Integer),
            Add(Hash, Hash),
            Mul(Hash, Hash),
        }

        let node = match &*self.0 {
            ExprInner::Var(idx) => Enc::Var(*idx),
            ExprInner::Cst(value) => Enc::Cst(value.clone()),
            ExprInner::Add(left, right) => {
                let left_hash = left.hash();
                let right_hash = right.hash();
                Enc::Add(left_hash, right_hash)
            }
            ExprInner::Mul(left, right) => {
                let left_hash = left.hash();
                let right_hash = right.hash();
                Enc::Mul(left_hash, right_hash)
            }
        };
        hash(&node)
    }
}

impl Add for &Expr {
    type Output = Expr;
    fn add(self, other: Self) -> Self::Output {
        Expr(Rc::new(ExprInner::Add(self.clone(), other.clone())))
    }
}

impl Add for Expr {
    type Output = Expr;
    fn add(self, other: Self) -> Self::Output {
        &self + &other
    }
}

impl Mul for &Expr {
    type Output = Expr;
    fn mul(self, other: Self) -> Self::Output {
        Expr(Rc::new(ExprInner::Mul(self.clone(), other.clone())))
    }
}

impl Mul for Expr {
    type Output = Expr;
    fn mul(self, other: Self) -> Self::Output {
        &self * &other
    }
}

impl Sub for &Expr {
    type Output = Expr;
    fn sub(self, other: &Expr) -> Self::Output {
        let m1: Expr = (-1).into();
        self + &(&m1 * other)
    }
}

impl Sub for Expr {
    type Output = Expr;
    fn sub(self, other: Expr) -> Self::Output {
        &self - &other
    }
}

/// Draws `bytes` challenge bytes from the transcript under `label` and
/// interprets them as a non-negative integer, least significant byte first.
fn challenge_int(tx: &mut Transcript, label: &'static [u8], bytes: usize) -> Integer {
    let mut buf = vec![0u8; bytes];
    tx.challenge_bytes(label, buf.as_mut_slice());
    Integer::from_digits(buf.as_slice(), Order::LsfLe)
}

/// Draws an index inside `range` from the transcript under `label`.
///
/// Sixteen challenge bytes are read as a big-endian `u128` and reduced modulo
/// the length of the range; the result is offset by `range.start`.
///
/// # Panics
/// Panics with "empty set" when `range` is empty.
fn challenge_idx(tx: &mut Transcript, label: &'static [u8], range: Range<usize>) -> usize {
    assert!(range.end > range.start, "empty set");
    let span = (range.end - range.start) as u128;

    let mut raw = [0u8; std::mem::size_of::<u128>()];
    tx.challenge_bytes(label, &mut raw);

    // The remainder is below `span`, which itself fits in `usize`.
    let offset = (u128::from_be_bytes(raw) % span) as usize;
    range.start.checked_add(offset).unwrap()
}

/// Serializable record of a message, a variable count and a list of hashes.
///
/// NOTE(review): presumably the statement header bound into the transcript
/// (cf. `SEP_PREFIX`) — its use site is outside this part of the file.
#[derive(Serialize)]
struct Prefix {
    // Raw message bytes.
    msg: Vec<u8>,
    // NOTE(review): presumably the number of variables in the constraint system.
    vars: usize,
    // NOTE(review): presumably the `Expr::hash` of every expression constrained to zero.
    zero: Vec<Hash>,
}
