diff --git a/gkr-graph/examples/series_connection_alt.rs b/gkr-graph/examples/series_connection_alt.rs index 713798a32..453121f0c 100644 --- a/gkr-graph/examples/series_connection_alt.rs +++ b/gkr-graph/examples/series_connection_alt.rs @@ -7,8 +7,8 @@ use gkr::{ use gkr_graph::{ error::GKRGraphError, structs::{ - CircuitGraphAuxInfo, CircuitGraphBuilder, IOPProverState, IOPVerifierState, NodeOutputType, - PredType, TargetEvaluations, + CircuitGraphAuxInfo, CircuitGraphBuilder, IOPProverState, IOPVerifierState, NodeOutputType, PredType, + TargetEvaluations, }, }; use goldilocks::{Goldilocks, GoldilocksExt2}; @@ -16,10 +16,7 @@ use simple_frontend::structs::{ChallengeId, CircuitBuilder, MixedCell}; use std::sync::Arc; use transcript::Transcript; -fn construct_input( - input_size: usize, - challenge: ChallengeId, -) -> Arc> { +fn construct_input(input_size: usize, challenge: ChallengeId) -> Arc> { let mut circuit_builder = CircuitBuilder::::new(); let (_, inputs) = circuit_builder.create_witness_in(input_size); let (_, lookup_inputs) = circuit_builder.create_ext_witness_out(input_size); @@ -34,10 +31,7 @@ fn construct_input( /// Construct a selector for n_instances and each instance contains `num` /// items. `num` must be a power of 2. -pub(crate) fn construct_prefix_selector( - n_instances: usize, - num: usize, -) -> Arc> { +pub(crate) fn construct_prefix_selector(n_instances: usize, num: usize) -> Arc> { assert_eq!(num, num.next_power_of_two()); let mut circuit_builder = CircuitBuilder::::new(); let _ = circuit_builder.create_constant_in(n_instances * num, 1); @@ -59,12 +53,7 @@ pub(crate) fn construct_inv_sum() -> Arc> { let den_mul = circuit_builder.create_ext_cell(); circuit_builder.mul2_ext(&den_mul, &input[0], &input[1], E::BaseField::ONE); let tmp = circuit_builder.create_ext_cell(); - circuit_builder.sel_mixed_and_ext( - &tmp, - &MixedCell::Constant(E::BaseField::ONE), - &input[0], - cond[0], - ); + circuit_builder.sel_mixed_and_ext(&tmp, &MixedCell::Constant(E::BaseField::ONE), &input[0], cond[0]); circuit_builder.sel_ext(&output[0], &tmp, &den_mul, cond[1]); // select the numerator 0 or 1 or input[0] + input[1] @@ -143,11 +132,7 @@ fn main() -> Result<(), GKRGraphError> { let mut prover_graph_builder = CircuitGraphBuilder::::new(); let mut verifier_graph_builder = CircuitGraphBuilder::::new(); let mut prover_transcript = Transcript::::new(b"test"); - let challenge = vec![ - prover_transcript - .get_and_append_challenge(b"lookup challenge") - .elements, - ]; + let challenge = vec![prover_transcript.get_and_append_challenge(b"lookup challenge").elements]; let mut add_node_and_witness = |label: &'static str, circuit: &Arc>, @@ -225,12 +210,8 @@ fn main() -> Result<(), GKRGraphError> { // Proofs generation // ================= let output_point = vec![ - prover_transcript - .get_and_append_challenge(b"output point") - .elements, - prover_transcript - .get_and_append_challenge(b"output point") - .elements, + prover_transcript.get_and_append_challenge(b"output point").elements, + prover_transcript.get_and_append_challenge(b"output point").elements, ]; let output_eval = circuit_witness .node_witnesses @@ -259,12 +240,8 @@ fn main() -> Result<(), GKRGraphError> { .elements]; let output_point = vec![ - verifier_transcript - .get_and_append_challenge(b"output point") - .elements, - verifier_transcript - .get_and_append_challenge(b"output point") - .elements, + verifier_transcript.get_and_append_challenge(b"output point").elements, + verifier_transcript.get_and_append_challenge(b"output point").elements, ]; IOPVerifierState::verify( diff --git a/gkr-graph/src/circuit_builder.rs b/gkr-graph/src/circuit_builder.rs index acf215f9a..96a470071 100644 --- a/gkr-graph/src/circuit_builder.rs +++ b/gkr-graph/src/circuit_builder.rs @@ -8,11 +8,7 @@ use itertools::Itertools; use crate::structs::{CircuitGraph, CircuitGraphWitness, NodeOutputType, TargetEvaluations}; impl CircuitGraph { - pub fn target_evals( - &self, - witness: &CircuitGraphWitness, - point: &Point, - ) -> TargetEvaluations { + pub fn target_evals(&self, witness: &CircuitGraphWitness, point: &Point) -> TargetEvaluations { // println!("targets: {:?}, point: {:?}", self.targets, point); let target_evals = self .targets @@ -24,8 +20,8 @@ impl CircuitGraph { .instances .as_slice() .original_mle(), - NodeOutputType::WireOut(node_id, wit_id) => witness.node_witnesses[*node_id] - .witness_out_ref()[*wit_id as usize] + NodeOutputType::WireOut(node_id, wit_id) => witness.node_witnesses[*node_id].witness_out_ref() + [*wit_id as usize] .instances .as_slice() .original_mle(), diff --git a/gkr-graph/src/circuit_graph_builder.rs b/gkr-graph/src/circuit_graph_builder.rs index 5ae1854cc..6b79ae917 100644 --- a/gkr-graph/src/circuit_graph_builder.rs +++ b/gkr-graph/src/circuit_graph_builder.rs @@ -9,8 +9,7 @@ use simple_frontend::structs::WitnessId; use crate::{ error::GKRGraphError, structs::{ - CircuitGraph, CircuitGraphBuilder, CircuitGraphWitness, CircuitNode, NodeInputType, - NodeOutputType, PredType, + CircuitGraph, CircuitGraphBuilder, CircuitGraphWitness, CircuitNode, NodeInputType, NodeOutputType, PredType, }, }; @@ -45,16 +44,13 @@ impl CircuitGraphBuilder { assert!(num_instances.is_power_of_two()); assert_eq!(sources.len(), circuit.n_witness_in); assert!( - !sources.iter().any( - |source| source.instances.len() != 0 && source.instances.len() != num_instances - ), + !sources + .iter() + .any(|source| source.instances.len() != 0 && source.instances.len() != num_instances), "node_id: {}, num_instances: {}, sources_num_instances: {:?}", id, num_instances, - sources - .iter() - .map(|source| source.instances.len()) - .collect_vec() + sources.iter().map(|source| source.instances.len()).collect_vec() ); let mut witness = CircuitWitness::new(circuit, challenges); @@ -65,14 +61,11 @@ impl CircuitGraphBuilder { let (id, out) = &match out { NodeOutputType::OutputLayer(id) => ( *id, - &self.witness.node_witnesses[*id] - .output_layer_witness_ref() - .instances, + &self.witness.node_witnesses[*id].output_layer_witness_ref().instances, ), NodeOutputType::WireOut(id, wit_id) => ( *id, - &self.witness.node_witnesses[*id].witness_out_ref()[*wit_id as usize] - .instances, + &self.witness.node_witnesses[*id].witness_out_ref()[*wit_id as usize].instances, ), }; let old_num_instances = self.witness.node_witnesses[*id].n_instances(); @@ -94,10 +87,7 @@ impl CircuitGraphBuilder { out.iter() .cloned() .flat_map(|single_instance| { - single_instance - .into_iter() - .cycle() - .take(num_dups * old_size) + single_instance.into_iter().cycle().take(num_dups * old_size) }) .chunks(old_size) .into_iter() @@ -146,9 +136,7 @@ impl CircuitGraphBuilder { } /// Collect the information of `self.sources` and `self.targets`. - pub fn finalize_graph_and_witness( - mut self, - ) -> (CircuitGraph, CircuitGraphWitness) { + pub fn finalize_graph_and_witness(mut self) -> (CircuitGraph, CircuitGraphWitness) { // Generate all possible graph output let outs = self .graph @@ -244,10 +232,7 @@ impl CircuitGraphBuilder { }, ); - assert_eq!( - expected_target, - targets.iter().cloned().collect::>() - ); + assert_eq!(expected_target, targets.iter().cloned().collect::>()); self.graph.sources = sources.into_iter().collect(); self.graph.targets = targets.to_vec(); diff --git a/gkr-graph/src/prover.rs b/gkr-graph/src/prover.rs index 74cbcf0eb..d1ded8e5c 100644 --- a/gkr-graph/src/prover.rs +++ b/gkr-graph/src/prover.rs @@ -7,8 +7,8 @@ use transcript::Transcript; use crate::{ error::GKRGraphError, structs::{ - CircuitGraph, CircuitGraphWitness, GKRProverState, IOPProof, IOPProverState, - NodeOutputType, PredType, TargetEvaluations, + CircuitGraph, CircuitGraphWitness, GKRProverState, IOPProof, IOPProverState, NodeOutputType, PredType, + TargetEvaluations, }, }; diff --git a/gkr-graph/src/verifier.rs b/gkr-graph/src/verifier.rs index 7094dfb85..c6357c3a4 100644 --- a/gkr-graph/src/verifier.rs +++ b/gkr-graph/src/verifier.rs @@ -7,8 +7,8 @@ use transcript::Transcript; use crate::{ error::GKRGraphError, structs::{ - CircuitGraph, CircuitGraphAuxInfo, GKRVerifierState, IOPProof, IOPVerifierState, - NodeOutputType, PredType, TargetEvaluations, + CircuitGraph, CircuitGraphAuxInfo, GKRVerifierState, IOPProof, IOPVerifierState, NodeOutputType, PredType, + TargetEvaluations, }, }; @@ -50,56 +50,50 @@ impl IOPVerifierState { let new_instance_num_vars = aux_info.instance_num_vars[node.id]; - izip!(&node.preds, input_claim.point_and_evals).for_each( - |(pred_type, point_and_eval)| { - match pred_type { - PredType::Source => { - // TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations` - // for later PCS open? - } - PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => { - let point = match pred_type { - PredType::PredWire(_) => point_and_eval.point.clone(), - PredType::PredWireDup(out) => { - let node_id = match out { - NodeOutputType::OutputLayer(id) => *id, - NodeOutputType::WireOut(id, _) => *id, - }; - // Suppose the new point is - // [single_instance_slice || - // new_instance_index_slice]. The old point - // is [single_instance_slices || - // new_instance_index_slices[(new_instance_num_vars - // - old_instance_num_vars)..]] - let old_instance_num_vars = aux_info.instance_num_vars[node_id]; - let num_vars = - point_and_eval.point.len() - new_instance_num_vars; - [ - point_and_eval.point[..num_vars].to_vec(), - point_and_eval.point[num_vars - + (new_instance_num_vars - old_instance_num_vars)..] - .to_vec(), - ] - .concat() - } - _ => unreachable!(), - }; - match pred_out { - NodeOutputType::OutputLayer(id) => output_evals[*id] - .push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)), - NodeOutputType::WireOut(id, wire_id) => { - let evals = &mut wit_out_evals[*id][*wire_id as usize]; - assert!( - evals.point.is_empty() && evals.eval.is_zero_vartime(), - "unimplemented", - ); - *evals = PointAndEval::new(point, point_and_eval.eval); - } + izip!(&node.preds, input_claim.point_and_evals).for_each(|(pred_type, point_and_eval)| { + match pred_type { + PredType::Source => { + // TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations` + // for later PCS open? + } + PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => { + let point = match pred_type { + PredType::PredWire(_) => point_and_eval.point.clone(), + PredType::PredWireDup(out) => { + let node_id = match out { + NodeOutputType::OutputLayer(id) => *id, + NodeOutputType::WireOut(id, _) => *id, + }; + // Suppose the new point is + // [single_instance_slice || + // new_instance_index_slice]. The old point + // is [single_instance_slices || + // new_instance_index_slices[(new_instance_num_vars + // - old_instance_num_vars)..]] + let old_instance_num_vars = aux_info.instance_num_vars[node_id]; + let num_vars = point_and_eval.point.len() - new_instance_num_vars; + [ + point_and_eval.point[..num_vars].to_vec(), + point_and_eval.point[num_vars + (new_instance_num_vars - old_instance_num_vars)..] + .to_vec(), + ] + .concat() + } + _ => unreachable!(), + }; + match pred_out { + NodeOutputType::OutputLayer(id) => { + output_evals[*id].push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)) + } + NodeOutputType::WireOut(id, wire_id) => { + let evals = &mut wit_out_evals[*id][*wire_id as usize]; + assert!(evals.point.is_empty() && evals.eval.is_zero_vartime(), "unimplemented",); + *evals = PointAndEval::new(point, point_and_eval.eval); } } } - }, - ); + } + }); } Ok(()) diff --git a/gkr/benches/keccak256.rs b/gkr/benches/keccak256.rs index b27b37e14..20b92bf02 100644 --- a/gkr/benches/keccak256.rs +++ b/gkr/benches/keccak256.rs @@ -33,10 +33,7 @@ const NUM_SAMPLES: usize = 10; const RAYON_NUM_THREADS: usize = 8; fn bench_keccak256(c: &mut Criterion) { - println!( - "#layers: {}", - keccak256_circuit::().layers.len() - ); + println!("#layers: {}", keccak256_circuit::().layers.len()); let max_thread_id = { if !is_power_of_2(RAYON_NUM_THREADS) { @@ -74,10 +71,7 @@ fn bench_keccak256(c: &mut Criterion) { BenchmarkId::new("prove_keccak256", format!("keccak256_log2_{}", log2_n)), |b| { b.iter(|| { - assert!( - prove_keccak256(log2_n, &circuit, (1 << log2_n).min(max_thread_id),) - .is_some() - ); + assert!(prove_keccak256(log2_n, &circuit, (1 << log2_n).min(max_thread_id),).is_some()); }); }, ); diff --git a/gkr/examples/keccak256.rs b/gkr/examples/keccak256.rs index a105e0930..7756751c7 100644 --- a/gkr/examples/keccak256.rs +++ b/gkr/examples/keccak256.rs @@ -16,10 +16,7 @@ use tracing_flame::FlameLayer; use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry}; fn main() { - println!( - "#layers: {}", - keccak256_circuit::().layers.len() - ); + println!("#layers: {}", keccak256_circuit::().layers.len()); #[allow(unused_mut)] let mut max_thread_id: usize = env::var("RAYON_NUM_THREADS") @@ -29,9 +26,7 @@ fn main() { if !is_power_of_2(max_thread_id) { #[cfg(not(feature = "non_pow2_rayon_thread"))] { - panic!( - "add --features non_pow2_rayon_thread to support non pow of 2 rayon thread pool" - ); + panic!("add --features non_pow2_rayon_thread to support non pow of 2 rayon thread pool"); } #[cfg(feature = "non_pow2_rayon_thread")] @@ -57,17 +52,12 @@ fn main() { witness.add_instance(&circuit, all_zero); witness.add_instance(&circuit, all_one); - izip!( - &witness.witness_out_ref()[0].instances, - [[0; 25], [u64::MAX; 25]] - ) - .for_each(|(wire_out, state)| { + izip!(&witness.witness_out_ref()[0].instances, [[0; 25], [u64::MAX; 25]]).for_each(|(wire_out, state)| { let output = wire_out[..256] .chunks_exact(64) .map(|bits| { bits.iter().fold(0, |acc, bit| { - (acc << 1) - + (*bit == ::BaseField::ONE) as u64 + (acc << 1) + (*bit == ::BaseField::ONE) as u64 }) }) .collect_vec(); @@ -82,12 +72,7 @@ fn main() { let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() - .with( - fmt::layer() - .compact() - .with_thread_ids(false) - .with_thread_names(false), - ) + .with(fmt::layer().compact().with_thread_ids(false).with_thread_names(false)) .with(EnvFilter::from_default_env()) .with(flame_layer.with_threads_collapsed(true)); tracing::subscriber::set_global_default(subscriber).unwrap(); diff --git a/gkr/src/circuit.rs b/gkr/src/circuit.rs index 80d3defab..6d3d003e3 100644 --- a/gkr/src/circuit.rs +++ b/gkr/src/circuit.rs @@ -36,12 +36,7 @@ pub trait EvaluateGate1In where E: ExtensionField, { - fn eval( - &self, - out_eq_vec: &[E], - in_eq_vec: &[E], - challenges: &HashMap>, - ) -> E; + fn eval(&self, out_eq_vec: &[E], in_eq_vec: &[E], challenges: &HashMap>) -> E; fn fix_out_variables( &self, in_size: usize, @@ -54,16 +49,9 @@ impl EvaluateGate1In for &[Gate1In>] where E: ExtensionField, { - fn eval( - &self, - out_eq_vec: &[E], - in_eq_vec: &[E], - challenges: &HashMap>, - ) -> E { + fn eval(&self, out_eq_vec: &[E], in_eq_vec: &[E], challenges: &HashMap>) -> E { self.iter().fold(E::ZERO, |acc, gate| { - acc + out_eq_vec[gate.idx_out] - * in_eq_vec[gate.idx_in[0]] - * (&gate.scalar.eval(challenges)) + acc + out_eq_vec[gate.idx_out] * in_eq_vec[gate.idx_in[0]] * (&gate.scalar.eval(challenges)) }) } fn fix_out_variables( diff --git a/gkr/src/circuit/circuit_layout.rs b/gkr/src/circuit/circuit_layout.rs index 8e71bd4cb..3938bb51b 100644 --- a/gkr/src/circuit/circuit_layout.rs +++ b/gkr/src/circuit/circuit_layout.rs @@ -35,8 +35,7 @@ impl LayerSubsets { return old_wire_id; } if !self.subsets.contains_key(&(old_layer_id, old_wire_id)) { - self.subsets - .insert((old_layer_id, old_wire_id), self.wire_id_assigner); + self.subsets.insert((old_layer_id, old_wire_id), self.wire_id_assigner); self.wire_id_assigner += 1; } self.subsets[&(old_layer_id, old_wire_id)] @@ -47,10 +46,7 @@ impl LayerSubsets { fn update_layer_info(&self, layers: &mut Vec>) { let mut paste_from = BTreeMap::new(); for ((old_layer_id, old_wire_id), new_wire_id) in self.subsets.iter() { - paste_from - .entry(*old_layer_id) - .or_insert(vec![]) - .push(*new_wire_id); + paste_from.entry(*old_layer_id).or_insert(vec![]).push(*new_wire_id); layers[*old_layer_id as usize] .copy_to .entry(self.layer_id) @@ -60,9 +56,8 @@ impl LayerSubsets { layers[self.layer_id as usize].paste_from = paste_from; layers[self.layer_id as usize].num_vars = ceil_log2(self.wire_id_assigner) as usize; - layers[self.layer_id as usize].max_previous_num_vars = layers[self.layer_id as usize] - .max_previous_num_vars - .max(ceil_log2( + layers[self.layer_id as usize].max_previous_num_vars = + layers[self.layer_id as usize].max_previous_num_vars.max(ceil_log2( layers[self.layer_id as usize] .paste_from .iter() @@ -126,9 +121,7 @@ impl Circuit { // Each wire_in should be assigned with a consecutive // input layer segment. Then we can use a special // sumcheck protocol to prove it. - assert!( - i == 0 || wire_ids_in_layer[*cell_id] == wire_ids_in_layer[wire_in[i - 1]] + 1 - ); + assert!(i == 0 || wire_ids_in_layer[*cell_id] == wire_ids_in_layer[wire_in[i - 1]] + 1); }); let segment = ( wire_ids_in_layer[in_cell_ids[0]], @@ -138,17 +131,15 @@ impl Circuit { match ty { InType::Witness(wit_id) => { input_paste_from_wits_in[*wit_id as usize] = segment; - max_in_wit_num_vars = max_in_wit_num_vars - .map_or(Some(ceil_log2(in_cell_ids.len())), |x| { - Some(x.max(ceil_log2(in_cell_ids.len()))) - }); + max_in_wit_num_vars = max_in_wit_num_vars.map_or(Some(ceil_log2(in_cell_ids.len())), |x| { + Some(x.max(ceil_log2(in_cell_ids.len()))) + }); } InType::Counter(num_vars) => { input_paste_from_counter_in.push((*num_vars, segment)); - max_in_wit_num_vars = max_in_wit_num_vars - .map_or(Some(ceil_log2(in_cell_ids.len())), |x| { - Some(x.max(ceil_log2(in_cell_ids.len()))) - }); + max_in_wit_num_vars = max_in_wit_num_vars.map_or(Some(ceil_log2(in_cell_ids.len())), |x| { + Some(x.max(ceil_log2(in_cell_ids.len()))) + }); } InType::Constant(constant) => { input_paste_from_consts_in.push((*constant, segment)); @@ -164,9 +155,7 @@ impl Circuit { let new_layer_id = layer_id + 1; let mut subsets = LayerSubsets::new( new_layer_id, - layers_of_cell_id[new_layer_id as usize] - .len() - .next_power_of_two(), + layers_of_cell_id[new_layer_id as usize].len().next_power_of_two(), ); for (i, cell_id) in layers_of_cell_id[layer_id as usize].iter().enumerate() { @@ -236,8 +225,7 @@ impl Circuit { subsets.update_layer_info(&mut layers); // Initialize the next layer `max_previous_num_vars` equals that of the `self.layer_id`. - layers[layer_id as usize].max_previous_num_vars = - layers[new_layer_id as usize].num_vars; + layers[layer_id as usize].max_previous_num_vars = layers[new_layer_id as usize].num_vars; } // Compute the copy_to from the output layer to the wires_out. Notice @@ -282,10 +270,7 @@ impl Circuit { || circuit_builder.n_witness_out() == 1 && output_copy_to[0] != seg || !output_assert_const.is_empty() { - curr_sc_steps.extend([ - SumcheckStepType::OutputPhase1Step1, - SumcheckStepType::OutputPhase1Step2, - ]); + curr_sc_steps.extend([SumcheckStepType::OutputPhase1Step1, SumcheckStepType::OutputPhase1Step2]); } } else { let last_layer = &layers[(layer_id - 1) as usize]; @@ -296,8 +281,7 @@ impl Circuit { if layer.layer_id == n_layers - 1 { if input_paste_from_wits_in.len() > 1 - || input_paste_from_wits_in.len() == 1 - && input_paste_from_wits_in[0] != (0, 1 << layer.num_vars) + || input_paste_from_wits_in.len() == 1 && input_paste_from_wits_in[0] != (0, 1 << layer.num_vars) || !input_paste_from_counter_in.is_empty() || !input_paste_from_consts_in.is_empty() { @@ -336,10 +320,7 @@ impl Circuit { } } - pub(crate) fn generate_basefield_challenges( - &self, - challenges: &[E], - ) -> HashMap> { + pub(crate) fn generate_basefield_challenges(&self, challenges: &[E]) -> HashMap> { let mut challenge_exps = HashMap::::new(); let mut update_const = |constant| match constant { ConstantType::Challenge(c, _) => { @@ -355,19 +336,10 @@ impl Circuit { _ => {} }; self.layers.iter().for_each(|layer| { - layer - .add_consts - .iter() - .for_each(|gate| update_const(gate.scalar)); + layer.add_consts.iter().for_each(|gate| update_const(gate.scalar)); layer.adds.iter().for_each(|gate| update_const(gate.scalar)); - layer - .mul2s - .iter() - .for_each(|gate| update_const(gate.scalar)); - layer - .mul3s - .iter() - .for_each(|gate| update_const(gate.scalar)); + layer.mul2s.iter().for_each(|gate| update_const(gate.scalar)); + layer.mul3s.iter().for_each(|gate| update_const(gate.scalar)); }); challenge_exps .into_iter() @@ -421,11 +393,7 @@ impl Layer { 1 << self.max_previous_num_vars } - pub fn paste_from_fix_variables_eq( - &self, - old_layer_id: LayerId, - current_point_eq: &[E], - ) -> Vec { + pub fn paste_from_fix_variables_eq(&self, old_layer_id: LayerId, current_point_eq: &[E]) -> Vec { assert_eq!(current_point_eq.len(), self.size()); self.paste_from .get(&old_layer_id) @@ -434,12 +402,7 @@ impl Layer { .fix_row_col_first(current_point_eq, self.max_previous_num_vars) } - pub fn paste_from_eval_eq( - &self, - old_layer_id: LayerId, - current_point_eq: &[E], - subset_point_eq: &[E], - ) -> E { + pub fn paste_from_eval_eq(&self, old_layer_id: LayerId, current_point_eq: &[E], subset_point_eq: &[E]) -> E { assert_eq!(current_point_eq.len(), self.size()); assert_eq!(subset_point_eq.len(), self.max_previous_size()); self.paste_from @@ -456,12 +419,7 @@ impl Layer { .fix_row_row_first(subset_point_eq, self.num_vars) } - pub fn copy_to_eval_eq( - &self, - new_layer_id: LayerId, - subset_point_eq: &[E], - current_point_eq: &[E], - ) -> E { + pub fn copy_to_eval_eq(&self, new_layer_id: LayerId, subset_point_eq: &[E], current_point_eq: &[E]) -> E { self.copy_to .get(&new_layer_id) .unwrap() @@ -507,11 +465,7 @@ impl fmt::Debug for Circuit { } writeln!(f, " n_witness_in: {}", self.n_witness_in)?; writeln!(f, " paste_from_wits_in: {:?}", self.paste_from_wits_in)?; - writeln!( - f, - " paste_from_counter_in: {:?}", - self.paste_from_counter_in - )?; + writeln!(f, " paste_from_counter_in: {:?}", self.paste_from_counter_in)?; writeln!(f, " paste_from_consts_in: {:?}", self.paste_from_consts_in)?; writeln!(f, " copy_to_wits_out: {:?}", self.copy_to_wits_out)?; writeln!(f, " assert_const: {:?}", self.assert_consts)?; @@ -616,10 +570,7 @@ mod tests { let mut expected_paste_from_consts_in = vec![]; expected_paste_from_consts_in.push((1, (11, 13))); assert_eq!(circuit.paste_from_wits_in, expected_paste_from_wits_in); - assert_eq!( - circuit.paste_from_counter_in, - expected_paste_from_counter_in - ); + assert_eq!(circuit.paste_from_counter_in, expected_paste_from_counter_in); assert_eq!(circuit.paste_from_consts_in, expected_paste_from_consts_in); } @@ -689,46 +640,22 @@ mod tests { Gate { idx_in: [], idx_out: 0, - scalar: ConstantType::::Challenge( - ChallengeConst { - challenge: 0, - exp: 2, - }, - 0, - ), + scalar: ConstantType::::Challenge(ChallengeConst { challenge: 0, exp: 2 }, 0), }, Gate { idx_in: [], idx_out: 1, - scalar: ConstantType::Challenge( - ChallengeConst { - challenge: 0, - exp: 2, - }, - 1, - ), + scalar: ConstantType::Challenge(ChallengeConst { challenge: 0, exp: 2 }, 1), }, Gate { idx_in: [], idx_out: 2, - scalar: ConstantType::Challenge( - ChallengeConst { - challenge: 1, - exp: 2, - }, - 0, - ), + scalar: ConstantType::Challenge(ChallengeConst { challenge: 1, exp: 2 }, 0), }, Gate { idx_in: [], idx_out: 3, - scalar: ConstantType::Challenge( - ChallengeConst { - challenge: 1, - exp: 2, - }, - 1, - ), + scalar: ConstantType::Challenge(ChallengeConst { challenge: 1, exp: 2 }, 1), }, ]; @@ -736,90 +663,42 @@ mod tests { Gate { idx_in: [0], idx_out: 0, - scalar: ConstantType::::Challenge( - ChallengeConst { - challenge: 0, - exp: 0, - }, - 0, - ), + scalar: ConstantType::::Challenge(ChallengeConst { challenge: 0, exp: 0 }, 0), }, Gate { idx_in: [1], idx_out: 0, - scalar: ConstantType::Challenge( - ChallengeConst { - challenge: 0, - exp: 1, - }, - 0, - ), + scalar: ConstantType::Challenge(ChallengeConst { challenge: 0, exp: 1 }, 0), }, Gate { idx_in: [0], idx_out: 1, - scalar: ConstantType::Challenge( - ChallengeConst { - challenge: 0, - exp: 0, - }, - 1, - ), + scalar: ConstantType::Challenge(ChallengeConst { challenge: 0, exp: 0 }, 1), }, Gate { idx_in: [1], idx_out: 1, - scalar: ConstantType::Challenge( - ChallengeConst { - challenge: 0, - exp: 1, - }, - 1, - ), + scalar: ConstantType::Challenge(ChallengeConst { challenge: 0, exp: 1 }, 1), }, Gate { idx_in: [2], idx_out: 2, - scalar: ConstantType::Challenge( - ChallengeConst { - challenge: 1, - exp: 0, - }, - 0, - ), + scalar: ConstantType::Challenge(ChallengeConst { challenge: 1, exp: 0 }, 0), }, Gate { idx_in: [3], idx_out: 2, - scalar: ConstantType::Challenge( - ChallengeConst { - challenge: 1, - exp: 1, - }, - 0, - ), + scalar: ConstantType::Challenge(ChallengeConst { challenge: 1, exp: 1 }, 0), }, Gate { idx_in: [2], idx_out: 3, - scalar: ConstantType::Challenge( - ChallengeConst { - challenge: 1, - exp: 0, - }, - 1, - ), + scalar: ConstantType::Challenge(ChallengeConst { challenge: 1, exp: 0 }, 1), }, Gate { idx_in: [3], idx_out: 3, - scalar: ConstantType::Challenge( - ChallengeConst { - challenge: 1, - exp: 1, - }, - 1, - ), + scalar: ConstantType::Challenge(ChallengeConst { challenge: 1, exp: 1 }, 1), }, ]; @@ -899,18 +778,12 @@ mod tests { assert_eq!(circuit.layers.len(), 3); // Single input witness, therefore no input phase 2 steps. - assert_eq!( - circuit.layers[2].sumcheck_steps, - vec![SumcheckStepType::Phase1Step1] - ); + assert_eq!(circuit.layers[2].sumcheck_steps, vec![SumcheckStepType::Phase1Step1]); // There are only one incoming evals since the last layer is linear, and // no subset evals. Therefore, there are no phase1 steps. assert_eq!( circuit.layers[1].sumcheck_steps, - vec![ - SumcheckStepType::Phase2Step1, - SumcheckStepType::Phase2Step2NoStep3, - ] + vec![SumcheckStepType::Phase2Step1, SumcheckStepType::Phase2Step2NoStep3,] ); // Output layer, single output witness, therefore no output phase 1 steps. assert_eq!( @@ -930,17 +803,11 @@ mod tests { assert_eq!(circuit.layers.len(), 2); // Single input witness, therefore no input phase 2 steps. - assert_eq!( - circuit.layers[1].sumcheck_steps, - vec![SumcheckStepType::Phase1Step1] - ); + assert_eq!(circuit.layers[1].sumcheck_steps, vec![SumcheckStepType::Phase1Step1]); // Output layer, single output witness, therefore no output phase 1 steps. assert_eq!( circuit.layers[0].sumcheck_steps, - vec![ - SumcheckStepType::Phase2Step1, - SumcheckStepType::Phase2Step2NoStep3 - ] + vec![SumcheckStepType::Phase2Step1, SumcheckStepType::Phase2Step2NoStep3] ); } } diff --git a/gkr/src/circuit/circuit_witness.rs b/gkr/src/circuit/circuit_witness.rs index f7351e652..193cb8964 100644 --- a/gkr/src/circuit/circuit_witness.rs +++ b/gkr/src/circuit/circuit_witness.rs @@ -50,14 +50,12 @@ impl CircuitWitness { // The first layer. layer_wits[n_layers - 1] = { - let mut layer_wit = - vec![vec![F::ZERO; circuit.layers[n_layers - 1].size()]; n_instances]; + let mut layer_wit = vec![vec![F::ZERO; circuit.layers[n_layers - 1].size()]; n_instances]; for instance_id in 0..n_instances { assert_eq!(wits_in.len(), circuit.paste_from_wits_in.len()); for (wit_id, (l, r)) in circuit.paste_from_wits_in.iter().enumerate() { for i in *l..*r { - layer_wit[instance_id][i] = - wits_in[wit_id as usize].instances[instance_id][i - *l]; + layer_wit[instance_id][i] = wits_in[wit_id as usize].instances[instance_id][i - *l]; } } for (constant, (l, r)) in circuit.paste_from_consts_in.iter() { @@ -67,63 +65,53 @@ impl CircuitWitness { } for (num_vars, (l, r)) in circuit.paste_from_counter_in.iter() { for i in *l..*r { - layer_wit[instance_id][i] = - F::from(((instance_id << num_vars) ^ (i - *l)) as u64); + layer_wit[instance_id][i] = F::from(((instance_id << num_vars) ^ (i - *l)) as u64); } } } - LayerWitness { - instances: layer_wit, - } + LayerWitness { instances: layer_wit } }; for (layer_id, layer) in circuit.layers.iter().enumerate().rev().skip(1) { let size = circuit.layers[layer_id].size(); let mut current_layer_wits = vec![vec![F::ZERO; size]; n_instances]; - izip!((0..n_instances), current_layer_wits.iter_mut()).for_each( - |(instance_id, current_layer_wit)| { - layer - .paste_from + izip!((0..n_instances), current_layer_wits.iter_mut()).for_each(|(instance_id, current_layer_wit)| { + layer.paste_from.iter().for_each(|(old_layer_id, new_wire_ids)| { + new_wire_ids .iter() - .for_each(|(old_layer_id, new_wire_ids)| { - new_wire_ids.iter().enumerate().for_each( - |(subset_wire_id, new_wire_id)| { - let old_wire_id = circuit.layers[*old_layer_id as usize] - .copy_to - .get(&(layer_id as LayerId)) - .unwrap()[subset_wire_id]; - current_layer_wit[*new_wire_id] = layer_wits - [*old_layer_id as usize] - .instances[instance_id][old_wire_id]; - }, - ); + .enumerate() + .for_each(|(subset_wire_id, new_wire_id)| { + let old_wire_id = circuit.layers[*old_layer_id as usize] + .copy_to + .get(&(layer_id as LayerId)) + .unwrap()[subset_wire_id]; + current_layer_wit[*new_wire_id] = + layer_wits[*old_layer_id as usize].instances[instance_id][old_wire_id]; }); + }); - let last_layer_wit = &layer_wits[layer_id + 1].instances[instance_id]; - for add_const in layer.add_consts.iter() { - current_layer_wit[add_const.idx_out] += add_const.scalar.eval(&challenges); - } + let last_layer_wit = &layer_wits[layer_id + 1].instances[instance_id]; + for add_const in layer.add_consts.iter() { + current_layer_wit[add_const.idx_out] += add_const.scalar.eval(&challenges); + } - for add in layer.adds.iter() { - current_layer_wit[add.idx_out] += - last_layer_wit[add.idx_in[0]] * add.scalar.eval(&challenges); - } + for add in layer.adds.iter() { + current_layer_wit[add.idx_out] += last_layer_wit[add.idx_in[0]] * add.scalar.eval(&challenges); + } - for mul2 in layer.mul2s.iter() { - current_layer_wit[mul2.idx_out] += last_layer_wit[mul2.idx_in[0]] - * last_layer_wit[mul2.idx_in[1]] - * mul2.scalar.eval(&challenges); - } + for mul2 in layer.mul2s.iter() { + current_layer_wit[mul2.idx_out] += + last_layer_wit[mul2.idx_in[0]] * last_layer_wit[mul2.idx_in[1]] * mul2.scalar.eval(&challenges); + } - for mul3 in layer.mul3s.iter() { - current_layer_wit[mul3.idx_out] += last_layer_wit[mul3.idx_in[0]] - * last_layer_wit[mul3.idx_in[1]] - * last_layer_wit[mul3.idx_in[2]] - * mul3.scalar.eval(&challenges); - } - }, - ); + for mul3 in layer.mul3s.iter() { + current_layer_wit[mul3.idx_out] += last_layer_wit[mul3.idx_in[0]] + * last_layer_wit[mul3.idx_in[1]] + * last_layer_wit[mul3.idx_in[2]] + * mul3.scalar.eval(&challenges); + } + }); layer_wits[layer_id] = LayerWitness { instances: current_layer_wits, @@ -173,35 +161,23 @@ impl CircuitWitness { self.add_instances(circuit, wits_in, 1); } - pub fn add_instances( - &mut self, - circuit: &Circuit, - new_wits_in: Vec>, - n_instances: usize, - ) where + pub fn add_instances(&mut self, circuit: &Circuit, new_wits_in: Vec>, n_instances: usize) + where E: ExtensionField, { assert_eq!(new_wits_in.len(), circuit.n_witness_in); assert!(n_instances.is_power_of_two()); - assert!(!new_wits_in - .iter() - .any(|wit_in| wit_in.instances.len() != n_instances)); + assert!(!new_wits_in.iter().any(|wit_in| wit_in.instances.len() != n_instances)); let (inferred_layer_wits, inferred_wits_out) = CircuitWitness::new_instances(circuit, &new_wits_in, &self.challenges, n_instances); // Merge self and circuit_witness. - for (layer_wit, inferred_layer_wit) in - self.layers.iter_mut().zip(inferred_layer_wits.into_iter()) - { + for (layer_wit, inferred_layer_wit) in self.layers.iter_mut().zip(inferred_layer_wits.into_iter()) { layer_wit.instances.extend(inferred_layer_wit.instances); } - for (wit_out, inferred_wits_out) in self - .witness_out - .iter_mut() - .zip(inferred_wits_out.into_iter()) - { + for (wit_out, inferred_wits_out) in self.witness_out.iter_mut().zip(inferred_wits_out.into_iter()) { wit_out.instances.extend(inferred_wits_out.instances); } @@ -274,13 +250,8 @@ impl CircuitWitness { } } - for (layer_id, (layer_witnesses, layer)) in self - .layers - .iter() - .zip(circuit.layers.iter()) - .enumerate() - .rev() - .skip(1) + for (layer_id, (layer_witnesses, layer)) in + self.layers.iter().zip(circuit.layers.iter()).enumerate().rev().skip(1) { let prev_layer_wits = &self.layers[layer_id + 1]; for (copy_id, (prev, curr)) in prev_layer_wits @@ -294,13 +265,11 @@ impl CircuitWitness { expected[add_const.idx_out] += add_const.scalar.eval(&self.challenges); } for add in layer.adds.iter() { - expected[add.idx_out] += - prev[add.idx_in[0]] * add.scalar.eval(&self.challenges); + expected[add.idx_out] += prev[add.idx_in[0]] * add.scalar.eval(&self.challenges); } for mul2 in layer.mul2s.iter() { - expected[mul2.idx_out] += prev[mul2.idx_in[0]] - * prev[mul2.idx_in[1]] - * mul2.scalar.eval(&self.challenges); + expected[mul2.idx_out] += + prev[mul2.idx_in[0]] * prev[mul2.idx_in[1]] * mul2.scalar.eval(&self.challenges); } for mul3 in layer.mul3s.iter() { expected[mul3.idx_out] += prev[mul3.idx_in[0]] @@ -317,8 +286,7 @@ impl CircuitWitness { .copy_to .get(&(layer_id as LayerId)) .unwrap()[subset_wire_id]; - expected[*new_wire_id] = - self.layers[*old_layer_id as usize].instances[copy_id][old_wire_id]; + expected[*new_wire_id] = self.layers[*old_layer_id as usize].instances[copy_id][old_wire_id]; } } assert_eq!( @@ -375,10 +343,7 @@ impl CircuitWitness { for gate in circuit.assert_consts.iter() { if let ConstantType::Field(constant) = gate.scalar { for copy_id in 0..self.n_instances { - assert_eq!( - output_layer_witness.instances[copy_id][gate.idx_out], - constant - ); + assert_eq!(output_layer_witness.instances[copy_id][gate.idx_out], constant); } } } @@ -418,14 +383,11 @@ impl CircuitWitness { single_num_vars: usize, multi_threads_meta: (usize, usize), ) -> ArcDenseMultilinearExtension { - self.layers[layer_id as usize] - .instances - .as_slice() - .mle_with_meta( - single_num_vars, - self.instance_num_vars(), - multi_threads_meta, - ) + self.layers[layer_id as usize].instances.as_slice().mle_with_meta( + single_num_vars, + self.instance_num_vars(), + multi_threads_meta, + ) } } @@ -480,13 +442,7 @@ mod test { // Layer 0 let (_, mul_001123) = circuit_builder.create_witness_out(1); - circuit_builder.mul3( - mul_001123[0], - mul_01, - mul_012, - input[3], - Ext::BaseField::ONE, - ); + circuit_builder.mul3(mul_001123[0], mul_01, mul_012, input[3], Ext::BaseField::ONE); circuit_builder.configure(); let circuit = Circuit::new(&circuit_builder); @@ -494,10 +450,8 @@ mod test { circuit } - fn copy_and_paste_witness() -> ( - Vec>, - CircuitWitness, - ) { + fn copy_and_paste_witness( + ) -> (Vec>, CircuitWitness) { // witness_in, single instance let inputs = vec![vec![ i64_to_field(5), @@ -572,10 +526,8 @@ mod test { circuit } - fn paste_from_wit_in_witness() -> ( - Vec>, - CircuitWitness, - ) { + fn paste_from_wit_in_witness( + ) -> (Vec>, CircuitWitness) { // witness_in, single instance let leaves1 = vec![vec![i64_to_field(5), i64_to_field(7), i64_to_field(11)]]; let leaves2 = vec![vec![i64_to_field(13), i64_to_field(17), i64_to_field(19)]]; @@ -623,12 +575,8 @@ mod test { let outputs1 = vec![vec![i64_to_field(35), i64_to_field(143)]]; let outputs2 = vec![vec![i64_to_field(5005)]]; let witness_out = vec![ - LayerWitness { - instances: outputs1, - }, - LayerWitness { - instances: outputs2, - }, + LayerWitness { instances: outputs1 }, + LayerWitness { instances: outputs2 }, ]; ( @@ -664,10 +612,8 @@ mod test { circuit } - fn copy_to_wit_out_witness() -> ( - Vec>, - CircuitWitness, - ) { + fn copy_to_wit_out_witness( + ) -> (Vec>, CircuitWitness) { // witness_in, single instance let leaves = vec![vec![ i64_to_field(5), @@ -714,24 +660,12 @@ mod test { ) } - fn copy_to_wit_out_witness_2() -> ( - Vec>, - CircuitWitness, - ) { + fn copy_to_wit_out_witness_2( + ) -> (Vec>, CircuitWitness) { // witness_in, 2 instances let leaves = vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], + vec![i64_to_field(5), i64_to_field(7), i64_to_field(11), i64_to_field(13)], + vec![i64_to_field(5), i64_to_field(13), i64_to_field(11), i64_to_field(7)], ]; let witness_in = vec![LayerWitness { instances: leaves }]; @@ -760,18 +694,8 @@ mod test { }, LayerWitness { instances: vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], + vec![i64_to_field(5), i64_to_field(7), i64_to_field(11), i64_to_field(13)], + vec![i64_to_field(5), i64_to_field(13), i64_to_field(11), i64_to_field(7)], ], }, ]; @@ -846,18 +770,8 @@ mod test { // witness_in, double instances let leaves = vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], + vec![i64_to_field(5), i64_to_field(7), i64_to_field(11), i64_to_field(13)], + vec![i64_to_field(5), i64_to_field(13), i64_to_field(11), i64_to_field(7)], ]; let witness_in = vec![LayerWitness { instances: leaves.clone(), @@ -877,16 +791,8 @@ mod test { + challenge_pows[1][2].1; let inners = vec![ - [ - inner00.clone().as_bases().to_vec(), - inner01.clone().as_bases().to_vec(), - ] - .concat(), - [ - inner10.clone().as_bases().to_vec(), - inner11.clone().as_bases().to_vec(), - ] - .concat(), + [inner00.clone().as_bases().to_vec(), inner01.clone().as_bases().to_vec()].concat(), + [inner10.clone().as_bases().to_vec(), inner11.clone().as_bases().to_vec()].concat(), ]; let root_tmp0 = vec![ @@ -911,9 +817,7 @@ mod test { LayerWitness { instances: roots.clone(), }, - LayerWitness { - instances: root_tmps, - }, + LayerWitness { instances: root_tmps }, LayerWitness { instances: inners }, LayerWitness { instances: leaves }, ]; @@ -1008,21 +912,13 @@ mod test { let (_, out) = circuit_builder.create_witness_out(2); // like a bypass gate, passing 6 to output out[0] - circuit_builder.add( - out[0], - mul_0_1_res, - ::BaseField::ONE, - ); + circuit_builder.add(out[0], mul_0_1_res, ::BaseField::ONE); // assert const 2 circuit_builder.assert_const(leaves[2], 5); // 5 + -5 = 0, put in out[1] - circuit_builder.add( - out[1], - leaves[2], - ::BaseField::ONE, - ); + circuit_builder.add(out[1], leaves[2], ::BaseField::ONE); circuit_builder.add_const( out[1], ::BaseField::from(5).neg(), // -5 diff --git a/gkr/src/gadgets/keccak256.rs b/gkr/src/gadgets/keccak256.rs index 4d02658fc..64fa0152e 100644 --- a/gkr/src/gadgets/keccak256.rs +++ b/gkr/src/gadgets/keccak256.rs @@ -145,26 +145,18 @@ fn not_lhs_and_rhs(cb: &mut CircuitBuilder, lhs: &Word, rh // (x0 + x1 + x2) - 2x0x2 - 2x1x2 - 2x0x1 + 4x0x1x2 fn xor3<'a, E: ExtensionField>(cb: &mut CircuitBuilder, words: &[Word; 3]) -> Word { let out = Word::new(cb); - izip!(&out.0, &words[0].0, &words[1].0, &words[2].0).for_each( - |(out, wire_0, wire_1, wire_2)| { - // (x0 + x1 + x2) - cb.add(*out, *wire_0, E::BaseField::ONE); - cb.add(*out, *wire_1, E::BaseField::ONE); - cb.add(*out, *wire_2, E::BaseField::ONE); - // - 2x0x2 - 2x1x2 - 2x0x1 - cb.mul2(*out, *wire_0, *wire_1, -E::BaseField::ONE.double()); - cb.mul2(*out, *wire_0, *wire_2, -E::BaseField::ONE.double()); - cb.mul2(*out, *wire_1, *wire_2, -E::BaseField::ONE.double()); - // 4x0x1x2 - cb.mul3( - *out, - *wire_0, - *wire_1, - *wire_2, - E::BaseField::ONE.double().double(), - ); - }, - ); + izip!(&out.0, &words[0].0, &words[1].0, &words[2].0).for_each(|(out, wire_0, wire_1, wire_2)| { + // (x0 + x1 + x2) + cb.add(*out, *wire_0, E::BaseField::ONE); + cb.add(*out, *wire_1, E::BaseField::ONE); + cb.add(*out, *wire_2, E::BaseField::ONE); + // - 2x0x2 - 2x1x2 - 2x0x1 + cb.mul2(*out, *wire_0, *wire_1, -E::BaseField::ONE.double()); + cb.mul2(*out, *wire_0, *wire_2, -E::BaseField::ONE.double()); + cb.mul2(*out, *wire_1, *wire_2, -E::BaseField::ONE.double()); + // 4x0x1x2 + cb.mul3(*out, *wire_0, *wire_1, *wire_2, E::BaseField::ONE.double().double()); + }); out } @@ -184,18 +176,16 @@ fn xor3<'a, E: ExtensionField>(cb: &mut CircuitBuilder, words: &[Word; 3]) -> // = (x0 + x2) - 2x0x2 - x1x2 + 2x0x1x2 fn chi<'a, E: ExtensionField>(cb: &mut CircuitBuilder, words: &[Word; 3]) -> Word { let out = Word::new(cb); - izip!(&out.0, &words[0].0, &words[1].0, &words[2].0).for_each( - |(out, wire_0, wire_1, wire_2)| { - // (x0 + x2) - cb.add(*out, *wire_0, E::BaseField::ONE); - cb.add(*out, *wire_2, E::BaseField::ONE); - // - 2x0x2 - x1x2 - cb.mul2(*out, *wire_0, *wire_2, -E::BaseField::ONE.double()); - cb.mul2(*out, *wire_1, *wire_2, -E::BaseField::ONE); - // 2x0x1x2 - cb.mul3(*out, *wire_0, *wire_1, *wire_2, E::BaseField::ONE.double()); - }, - ); + izip!(&out.0, &words[0].0, &words[1].0, &words[2].0).for_each(|(out, wire_0, wire_1, wire_2)| { + // (x0 + x2) + cb.add(*out, *wire_0, E::BaseField::ONE); + cb.add(*out, *wire_2, E::BaseField::ONE); + // - 2x0x2 - x1x2 + cb.mul2(*out, *wire_0, *wire_2, -E::BaseField::ONE.double()); + cb.mul2(*out, *wire_1, *wire_2, -E::BaseField::ONE); + // 2x0x1x2 + cb.mul3(*out, *wire_0, *wire_1, *wire_2, E::BaseField::ONE.double()); + }); out } @@ -204,20 +194,14 @@ fn chi<'a, E: ExtensionField>(cb: &mut CircuitBuilder, words: &[Word; 3]) -> // = c + (x0 + x2) - 2x0x2 - x1x2 + 2x0x1x2 - 2(c*x0 + c*x2 - 2c*x0*x2 - c*x1*x2 + 2*c*x0*x1*x2) // = x0 + x2 + c - 2*x0*x2 - x1*x2 + 2*x0*x1*x2 - 2*c*x0 - 2*c*x2 + 4*c*x0*x2 + 2*c*x1*x2 - 4*c*x0*x1*x2 // = x0*(1-2c) + x2*(1-2c) + c + x0*x2*(-2 + 4c) + x1*x2(-1 + 2c) + x0*x1*x2(2 - 4c) -fn chi_and_xor_constant<'a, E: ExtensionField>( - cb: &mut CircuitBuilder, - words: &[Word; 3], - constant: u64, -) -> Word { +fn chi_and_xor_constant<'a, E: ExtensionField>(cb: &mut CircuitBuilder, words: &[Word; 3], constant: u64) -> Word { let out = Word::new(cb); izip!( &out.0, &words[0].0, &words[1].0, &words[2].0, - iter::successors(Some(constant.reverse_bits()), |constant| { - Some(constant >> 1) - }) + iter::successors(Some(constant.reverse_bits()), |constant| { Some(constant >> 1) }) ) .for_each(|(out, wire_0, wire_1, wire_2, constant)| { let const_bit = constant & 1; @@ -274,20 +258,14 @@ fn chi_and_xor_constant<'a, E: ExtensionField>( } #[allow(dead_code)] -fn xor2_constant( - cb: &mut CircuitBuilder, - words: &[Word; 2], - constant: u64, -) -> Word { +fn xor2_constant(cb: &mut CircuitBuilder, words: &[Word; 2], constant: u64) -> Word { let out = Word::new(cb); izip!( &out.0, &words[0].0, &words[1].0, - iter::successors(Some(constant.reverse_bits()), |constant| { - Some(constant >> 1) - }) + iter::successors(Some(constant.reverse_bits()), |constant| { Some(constant >> 1) }) ) .for_each(|(out, wire_0, wire_1, constant)| { let const_bit = constant & 1; @@ -334,10 +312,7 @@ pub fn keccak256_circuit() -> Circuit { // Absorption state = izip!( state.iter(), - input - .into_iter() - .map(|input| Some(input)) - .chain(iter::repeat(None)) + input.into_iter().map(|input| Some(input)).chain(iter::repeat(None)) ) .map(|(state, input)| { if let Some(input) = input { @@ -442,8 +417,7 @@ pub fn keccak256_circuit() -> Circuit { // cb.create_wire_out_from_cells(&state.iter().flat_map(|word| word.0).collect_vec()); let (_, out) = cb.create_witness_out(256); - izip!(&out, state.iter().flat_map(|word| &word.0)) - .for_each(|(out, state)| cb.add(*out, *state, E::BaseField::ONE)); + izip!(&out, state.iter().flat_map(|word| &word.0)).for_each(|(out, state)| cb.add(*out, *state, E::BaseField::ONE)); cb.configure(); Circuit::new(cb) @@ -463,29 +437,18 @@ pub fn prove_keccak256( // Sanity-check #[cfg(test)] { - let all_zero = vec![ - vec![E::BaseField::ZERO; 25 * 64], - vec![E::BaseField::ZERO; 17 * 64], - ]; - let all_one = vec![ - vec![E::BaseField::ONE; 25 * 64], - vec![E::BaseField::ZERO; 17 * 64], - ]; + let all_zero = vec![vec![E::BaseField::ZERO; 25 * 64], vec![E::BaseField::ZERO; 17 * 64]]; + let all_one = vec![vec![E::BaseField::ONE; 25 * 64], vec![E::BaseField::ZERO; 17 * 64]]; let mut witness = CircuitWitness::new(&circuit, Vec::new()); witness.add_instance(&circuit, all_zero); witness.add_instance(&circuit, all_one); - izip!( - &witness.witness_out_ref()[0].instances, - [[0; 25], [u64::MAX; 25]] - ) - .for_each(|(wire_out, state)| { + izip!(&witness.witness_out_ref()[0].instances, [[0; 25], [u64::MAX; 25]]).for_each(|(wire_out, state)| { let output = wire_out[..256] .chunks_exact(64) .map(|bits| { - bits.iter().fold(0, |acc, bit| { - (acc << 1) + (*bit == E::BaseField::ONE) as u64 - }) + bits.iter() + .fold(0, |acc, bit| (acc << 1) + (*bit == E::BaseField::ONE) as u64) }) .collect_vec(); let expected = { @@ -519,13 +482,9 @@ pub fn prove_keccak256( .mle(lo_num_vars, instance_num_vars); let mut prover_transcript = Transcript::::new(b"test"); - let output_point = iter::repeat_with(|| { - prover_transcript - .get_and_append_challenge(b"output point") - .elements - }) - .take(output_mle.num_vars) - .collect_vec(); + let output_point = iter::repeat_with(|| prover_transcript.get_and_append_challenge(b"output point").elements) + .take(output_mle.num_vars) + .collect_vec(); let output_eval = output_mle.evaluate(&output_point); let start = std::time::Instant::now(); @@ -548,13 +507,9 @@ pub fn verify_keccak256( circuit: &Circuit, ) -> Result, GKRError> { let mut verifer_transcript = Transcript::::new(b"test"); - let output_point = iter::repeat_with(|| { - verifer_transcript - .get_and_append_challenge(b"output point") - .elements - }) - .take(output_mle.num_vars) - .collect_vec(); + let output_point = iter::repeat_with(|| verifer_transcript.get_and_append_challenge(b"output point").elements) + .take(output_mle.num_vars) + .collect_vec(); let output_eval = output_mle.evaluate(&output_point); crate::structs::IOPVerifierState::verify_parallel( &circuit, diff --git a/gkr/src/prover.rs b/gkr/src/prover.rs index ee1e6890b..a718836b3 100644 --- a/gkr/src/prover.rs +++ b/gkr/src/prover.rs @@ -8,18 +8,15 @@ use multilinear_extensions::{ virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, }; use rayon::iter::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, - IntoParallelRefMutIterator, ParallelIterator, + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, + ParallelIterator, }; use simple_frontend::structs::LayerId; use transcript::Transcript; use crate::{ entered_span, exit_span, - structs::{ - Circuit, CircuitWitness, GKRInputClaims, IOPProof, IOPProverState, PointAndEval, - SumcheckStepType, - }, + structs::{Circuit, CircuitWitness, GKRInputClaims, IOPProof, IOPProverState, PointAndEval, SumcheckStepType}, tracing_span, }; @@ -51,13 +48,7 @@ impl IOPProverState { assert!(circuit_witness.n_instances == 1 << circuit_witness.instance_num_vars()); let mut prover_state = tracing_span!("prover_init_parallel").in_scope(|| { - Self::prover_init_parallel( - circuit, - circuit_witness, - output_evals, - wires_out_evals, - transcript, - ) + Self::prover_init_parallel(circuit, circuit_witness, output_evals, wires_out_evals, transcript) }); let sumcheck_proofs = (0..circuit.layers.len() as LayerId) @@ -69,62 +60,69 @@ impl IOPProverState { let dummy_step = SumcheckStepType::Undefined; let proofs = circuit.layers[layer_id as usize] .sumcheck_steps - .iter().chain(vec![&dummy_step, &dummy_step]) + .iter() + .chain(vec![&dummy_step, &dummy_step]) .tuple_windows() .flat_map(|steps| match steps { - (SumcheckStepType::OutputPhase1Step1, SumcheckStepType::OutputPhase1Step2, _) => { - [prover_state - .prove_and_update_state_output_phase1_step1( - circuit, - circuit_witness, - transcript, - ), - prover_state - .prove_and_update_state_output_phase1_step2( - circuit, - circuit_witness, - transcript, - )].to_vec() - }, + (SumcheckStepType::OutputPhase1Step1, SumcheckStepType::OutputPhase1Step2, _) => [ + prover_state.prove_and_update_state_output_phase1_step1( + circuit, + circuit_witness, + transcript, + ), + prover_state.prove_and_update_state_output_phase1_step2( + circuit, + circuit_witness, + transcript, + ), + ] + .to_vec(), (SumcheckStepType::Phase1Step1, _, _) => { - let alpha = transcript - .get_and_append_challenge(b"combine subset evals") - .elements; + let alpha = transcript.get_and_append_challenge(b"combine subset evals").elements; let hi_num_vars = circuit_witness.instance_num_vars(); - let eq_t = prover_state.to_next_phase_point_and_evals.par_iter().chain(prover_state.subset_point_and_evals[layer_id as usize].par_iter().map(|(_, point_and_eval)| point_and_eval)).map(|point_and_eval|{ - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - build_eq_x_r_vec(&point_and_eval.point[point_lo_num_vars..]) - }).collect::>>(); - let virtual_polys: Vec> = (0..max_thread_id).into_par_iter().map(|thread_id| { - let span = entered_span!("build_poly"); - let virtual_poly = IOPProverState::build_phase1_step1_sumcheck_poly( - &prover_state, - layer_id, - alpha, - &eq_t, - circuit, - circuit_witness, - (thread_id, max_thread_id), - ); - exit_span!(span); - virtual_poly - }).collect(); + let eq_t = prover_state + .to_next_phase_point_and_evals + .par_iter() + .chain( + prover_state.subset_point_and_evals[layer_id as usize] + .par_iter() + .map(|(_, point_and_eval)| point_and_eval), + ) + .map(|point_and_eval| { + let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + build_eq_x_r_vec(&point_and_eval.point[point_lo_num_vars..]) + }) + .collect::>>(); + let virtual_polys: Vec> = (0..max_thread_id) + .into_par_iter() + .map(|thread_id| { + let span = entered_span!("build_poly"); + let virtual_poly = IOPProverState::build_phase1_step1_sumcheck_poly( + &prover_state, + layer_id, + alpha, + &eq_t, + circuit, + circuit_witness, + (thread_id, max_thread_id), + ); + exit_span!(span); + virtual_poly + }) + .collect(); - let (sumcheck_proof, sumcheck_prover_state) = sumcheck::structs::IOPProverState::::prove_batch_polys( - max_thread_id, - virtual_polys.try_into().unwrap(), - transcript, - ); + let (sumcheck_proof, sumcheck_prover_state) = + sumcheck::structs::IOPProverState::::prove_batch_polys( + max_thread_id, + virtual_polys.try_into().unwrap(), + transcript, + ); - let prover_msg = prover_state.combine_phase1_step1_evals( - sumcheck_proof, - sumcheck_prover_state, - ); + let prover_msg = + prover_state.combine_phase1_step1_evals(sumcheck_proof, sumcheck_prover_state); vec![prover_msg] - - } - , + } (SumcheckStepType::Phase2Step1, step2, _) => { let span = entered_span!("phase2_gkr"); let max_steps = match step2 { @@ -134,111 +132,104 @@ impl IOPProverState { }; let mut eqs = vec![]; - let mut layer_polys = (0..max_thread_id).map(|_| ArcDenseMultilinearExtension::default()).collect::>>(); + let mut layer_polys = (0..max_thread_id) + .map(|_| ArcDenseMultilinearExtension::default()) + .collect::>>(); let mut res = vec![]; for step in 0..max_steps { let bounded_eval_point = prover_state.to_next_step_point.clone(); eqs.push(build_eq_x_r_vec(&bounded_eval_point)); // build step round poly - let virtual_polys: Vec> = (0..max_thread_id).into_par_iter().zip(layer_polys.par_iter_mut()).map(|(thread_id, layer_poly)| { - let span = entered_span!("build_poly"); - let (next_layer_poly_step1, virtual_poly) = match step { - 0 => { - let (next_layer_poly, virtual_poly) = IOPProverState::build_phase2_step1_sumcheck_poly( - eqs.as_slice().try_into().unwrap(), - layer_id, - circuit, - circuit_witness, - (thread_id, max_thread_id), - ); - (Some(next_layer_poly), virtual_poly) - }, - 1 => { - let virtual_poly = IOPProverState::build_phase2_step2_sumcheck_poly( - &layer_poly, - layer_id, - eqs.as_slice().try_into().unwrap(), - circuit, - circuit_witness, - (thread_id, max_thread_id), - ); - (None, virtual_poly) - }, - 2 => { - let virtual_poly = IOPProverState::build_phase2_step3_sumcheck_poly( - &layer_poly, - layer_id, - eqs.as_slice().try_into().unwrap(), - circuit, - circuit_witness, - (thread_id, max_thread_id), - ); - (None, virtual_poly) - - }, - _ => unimplemented!(), - }; - if let Some(next_layer_poly_step1) = next_layer_poly_step1 { - let _ = mem::replace(layer_poly, next_layer_poly_step1); - } - exit_span!(span); - virtual_poly - }).collect(); + let virtual_polys: Vec> = (0..max_thread_id) + .into_par_iter() + .zip(layer_polys.par_iter_mut()) + .map(|(thread_id, layer_poly)| { + let span = entered_span!("build_poly"); + let (next_layer_poly_step1, virtual_poly) = match step { + 0 => { + let (next_layer_poly, virtual_poly) = + IOPProverState::build_phase2_step1_sumcheck_poly( + eqs.as_slice().try_into().unwrap(), + layer_id, + circuit, + circuit_witness, + (thread_id, max_thread_id), + ); + (Some(next_layer_poly), virtual_poly) + } + 1 => { + let virtual_poly = IOPProverState::build_phase2_step2_sumcheck_poly( + &layer_poly, + layer_id, + eqs.as_slice().try_into().unwrap(), + circuit, + circuit_witness, + (thread_id, max_thread_id), + ); + (None, virtual_poly) + } + 2 => { + let virtual_poly = IOPProverState::build_phase2_step3_sumcheck_poly( + &layer_poly, + layer_id, + eqs.as_slice().try_into().unwrap(), + circuit, + circuit_witness, + (thread_id, max_thread_id), + ); + (None, virtual_poly) + } + _ => unimplemented!(), + }; + if let Some(next_layer_poly_step1) = next_layer_poly_step1 { + let _ = mem::replace(layer_poly, next_layer_poly_step1); + } + exit_span!(span); + virtual_poly + }) + .collect(); - let (sumcheck_proof, sumcheck_prover_state) = sumcheck::structs::IOPProverState::::prove_batch_polys( - max_thread_id, - virtual_polys.try_into().unwrap(), - transcript, - ); + let (sumcheck_proof, sumcheck_prover_state) = + sumcheck::structs::IOPProverState::::prove_batch_polys( + max_thread_id, + virtual_polys.try_into().unwrap(), + transcript, + ); - let iop_prover_step = - match step { - 0 => { - prover_state.combine_phase2_step1_evals( - circuit, - sumcheck_proof, - sumcheck_prover_state, - ) - }, - 1 => { - let no_step3: bool = max_steps == 2; - prover_state.combine_phase2_step2_evals( - circuit, - sumcheck_proof, - sumcheck_prover_state, - no_step3, - ) - }, - 2 => { - prover_state.combine_phase2_step3_evals( - circuit, - sumcheck_proof, - sumcheck_prover_state, - ) - }, - _ => unimplemented!() - }; + let iop_prover_step = match step { + 0 => prover_state.combine_phase2_step1_evals( + circuit, + sumcheck_proof, + sumcheck_prover_state, + ), + 1 => { + let no_step3: bool = max_steps == 2; + prover_state.combine_phase2_step2_evals( + circuit, + sumcheck_proof, + sumcheck_prover_state, + no_step3, + ) + } + 2 => prover_state.combine_phase2_step3_evals( + circuit, + sumcheck_proof, + sumcheck_prover_state, + ), + _ => unimplemented!(), + }; res.push(iop_prover_step); } exit_span!(span); res - }, - (SumcheckStepType::LinearPhase2Step1, _, _) => - [prover_state - .prove_and_update_state_linear_phase2_step1( - circuit, - circuit_witness, - transcript, - )].to_vec(), - (SumcheckStepType::InputPhase2Step1, _, _) => - [prover_state - .prove_and_update_state_input_phase2_step1( - circuit, - circuit_witness, - transcript, - ) - ].to_vec(), + } + (SumcheckStepType::LinearPhase2Step1, _, _) => [prover_state + .prove_and_update_state_linear_phase2_step1(circuit, circuit_witness, transcript)] + .to_vec(), + (SumcheckStepType::InputPhase2Step1, _, _) => [prover_state + .prove_and_update_state_input_phase2_step1(circuit, circuit_witness, transcript)] + .to_vec(), _ => { vec![] } @@ -278,17 +269,10 @@ impl IOPProverState { wires_out_evals.last().unwrap().point.clone() }; let assert_point = (0..output_wit_num_vars) - .map(|_| { - transcript - .get_and_append_challenge(b"assert_point") - .elements - }) + .map(|_| transcript.get_and_append_challenge(b"assert_point").elements) .collect_vec(); let to_next_phase_point_and_evals = output_evals; - subset_point_and_evals[0] = wires_out_evals - .into_iter() - .map(|p| (0 as LayerId, p)) - .collect(); + subset_point_and_evals[0] = wires_out_evals.into_iter().map(|p| (0 as LayerId, p)).collect(); Self { to_next_phase_point_and_evals, diff --git a/gkr/src/prover/phase1.rs b/gkr/src/prover/phase1.rs index 04dd3170f..81bef3b0b 100644 --- a/gkr/src/prover/phase1.rs +++ b/gkr/src/prover/phase1.rs @@ -12,9 +12,7 @@ use sumcheck::{entered_span, util::ceil_log2}; use crate::{ exit_span, - structs::{ - Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval, SumcheckProof, - }, + structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval, SumcheckProof}, utils::{tensor_product, MatrixMLERowFirst}, }; @@ -37,9 +35,8 @@ impl IOPProverState { let span = entered_span!("preparation"); let timer = start_timer!(|| "Prover sumcheck phase 1 step 1"); - let total_length = self.to_next_phase_point_and_evals.len() - + self.subset_point_and_evals[self.layer_id as usize].len() - + 1; + let total_length = + self.to_next_phase_point_and_evals.len() + self.subset_point_and_evals[self.layer_id as usize].len() + 1; let alpha_pows = { let mut alpha_pows = vec![E::ONE; total_length]; for i in 0..total_length.saturating_sub(1) { @@ -59,11 +56,7 @@ impl IOPProverState { // f1^{(j)}(y) = layers[i](t || y) let f1: Arc> = circuit_witness - .layer_poly::( - (layer_id).try_into().unwrap(), - lo_num_vars, - multi_threads_meta, - ) + .layer_poly::((layer_id).try_into().unwrap(), lo_num_vars, multi_threads_meta) .into(); assert_eq!( @@ -81,22 +74,17 @@ impl IOPProverState { // g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let eq_y = - build_eq_x_r_vec_sequential(&point_and_eval.point[..point_lo_num_vars]) - .into_iter() - .take(1 << lo_num_vars) - .map(|eq| *alpha_pow * eq) - .collect_vec(); + let eq_y = build_eq_x_r_vec_sequential(&point_and_eval.point[..point_lo_num_vars]) + .into_iter() + .take(1 << lo_num_vars) + .map(|eq| *alpha_pow * eq) + .collect_vec(); let eq_t_unit_len = eq_t.len() / max_thread_id; let start_index = thread_id * eq_t_unit_len; - let g1_j = - tensor_product(&eq_t[start_index..(start_index + eq_t_unit_len)], &eq_y); + let g1_j = tensor_product(&eq_t[start_index..(start_index + eq_t_unit_len)], &eq_y); - assert_eq!( - g1_j.len(), - (1 << (hi_num_vars + lo_num_vars - log2_max_thread_id)) - ); + assert_eq!(g1_j.len(), (1 << (hi_num_vars + lo_num_vars - log2_max_thread_id))); g1_j }) @@ -106,44 +94,34 @@ impl IOPProverState { &alpha_pows[self.to_next_phase_point_and_evals.len()..], eq_t.iter().skip(self.to_next_phase_point_and_evals.len()) ) - .map( - |((new_layer_id, point_and_eval), alpha_pow, eq_t)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let copy_to = ©_to_matrices[new_layer_id]; - let lo_eq_w_p = build_eq_x_r_vec_sequential( - &point_and_eval.point[..point_lo_num_vars], - ); - - // g2^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) - let eq_t_unit_len = eq_t.len() / max_thread_id; - let start_index = thread_id * eq_t_unit_len; - let g2_j = tensor_product( - &eq_t[start_index..(start_index + eq_t_unit_len)], - ©_to.as_slice().fix_row_row_first_with_scalar( - &lo_eq_w_p, - lo_num_vars, - alpha_pow, - ), - ); - - assert_eq!( - g2_j.len(), - (1 << (hi_num_vars + lo_num_vars - log2_max_thread_id)) - ); - g2_j - }, - ), + .map(|((new_layer_id, point_and_eval), alpha_pow, eq_t)| { + let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + let copy_to = ©_to_matrices[new_layer_id]; + let lo_eq_w_p = build_eq_x_r_vec_sequential(&point_and_eval.point[..point_lo_num_vars]); + + // g2^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) + let eq_t_unit_len = eq_t.len() / max_thread_id; + let start_index = thread_id * eq_t_unit_len; + let g2_j = tensor_product( + &eq_t[start_index..(start_index + eq_t_unit_len)], + ©_to + .as_slice() + .fix_row_row_first_with_scalar(&lo_eq_w_p, lo_num_vars, alpha_pow), + ); + + assert_eq!(g2_j.len(), (1 << (hi_num_vars + lo_num_vars - log2_max_thread_id))); + g2_j + }), ) .collect::>>(); DenseMultilinearExtension::from_evaluations_ext_vec( hi_num_vars + lo_num_vars - log2_max_thread_id, - gs.into_iter() - .fold(vec![E::ZERO; 1 << f1.num_vars], |mut acc, g| { - assert_eq!(1 << f1.num_vars, g.len()); - acc.iter_mut().enumerate().for_each(|(i, v)| *v += g[i]); - acc - }), + gs.into_iter().fold(vec![E::ZERO; 1 << f1.num_vars], |mut acc, g| { + assert_eq!(1 << f1.num_vars, g.len()); + acc.iter_mut().enumerate().for_each(|(i, v)| *v += g[i]); + acc + }), ) .into() }; @@ -172,10 +150,7 @@ impl IOPProverState { let eval_value_1 = f1.remove(0).1; self.to_next_step_point = sumcheck_proof_1.point.clone(); - self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref( - &self.to_next_step_point, - &eval_value_1, - )]; + self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref(&self.to_next_step_point, &eval_value_1)]; self.subset_point_and_evals[self.layer_id as usize].clear(); IOPProverStepMessage { diff --git a/gkr/src/prover/phase1_output.rs b/gkr/src/prover/phase1_output.rs index 3dbce07d2..be2bef774 100644 --- a/gkr/src/prover/phase1_output.rs +++ b/gkr/src/prover/phase1_output.rs @@ -38,13 +38,10 @@ impl IOPProverState { transcript: &mut Transcript, ) -> IOPProverStepMessage { let timer = start_timer!(|| "Prover sumcheck output phase 1 step 1"); - let alpha = transcript - .get_and_append_challenge(b"combine subset evals") - .elements; + let alpha = transcript.get_and_append_challenge(b"combine subset evals").elements; - let total_length = self.to_next_phase_point_and_evals.len() - + self.subset_point_and_evals[self.layer_id as usize].len() - + 1; + let total_length = + self.to_next_phase_point_and_evals.len() + self.subset_point_and_evals[self.layer_id as usize].len() + 1; let alpha_pows = { let mut alpha_pows = vec![E::ONE; total_length]; for i in 0..total_length.saturating_sub(1) { @@ -75,18 +72,12 @@ impl IOPProverState { let point = &point_and_eval.point; let lo_eq_w_p = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - let f1_j = self - .phase1_layer_poly - .fix_high_variables(&point[point_lo_num_vars..]); + let f1_j = self.phase1_layer_poly.fix_high_variables(&point[point_lo_num_vars..]); - let g1_j = lo_eq_w_p - .into_iter() - .map(|eq| *alpha_pow * eq) - .collect_vec(); + let g1_j = lo_eq_w_p.into_iter().map(|eq| *alpha_pow * eq).collect_vec(); ( f1_j.into(), - DenseMultilinearExtension::::from_evaluations_ext_vec(lo_num_vars, g1_j) - .into(), + DenseMultilinearExtension::::from_evaluations_ext_vec(lo_num_vars, g1_j).into(), ) }) .unzip(); @@ -106,15 +97,11 @@ impl IOPProverState { let lo_eq_w_p = build_eq_x_r_vec(&point[..point_lo_num_vars]); assert!(copy_to.len() <= lo_eq_w_p.len()); - let f1_j = self - .phase1_layer_poly - .fix_high_variables(&point[point_lo_num_vars..]); + let f1_j = self.phase1_layer_poly.fix_high_variables(&point[point_lo_num_vars..]); - let g1_j = copy_to.as_slice().fix_row_row_first_with_scalar( - &lo_eq_w_p, - lo_num_vars, - alpha_pow, - ); + let g1_j = copy_to + .as_slice() + .fix_row_row_first_with_scalar(&lo_eq_w_p, lo_num_vars, alpha_pow); ( f1_j.into(), @@ -149,8 +136,7 @@ impl IOPProverState { virtual_poly_1.merge(&tmp); } - let (sumcheck_proof_1, prover_state) = - SumcheckState::prove_parallel(virtual_poly_1, transcript); + let (sumcheck_proof_1, prover_state) = SumcheckState::prove_parallel(virtual_poly_1, transcript); let (f1, g1): (Vec<_>, Vec<_>) = prover_state .get_mle_final_evaluations() .into_iter() @@ -207,18 +193,14 @@ impl IOPProverState { .collect_vec() }) .fold(vec![E::ZERO; 1 << hi_num_vars], |acc, nxt| { - acc.into_iter() - .zip(nxt.into_iter()) - .map(|(a, b)| a + b) - .collect_vec() + acc.into_iter().zip(nxt.into_iter()).map(|(a, b)| a + b).collect_vec() }); let g2 = DenseMultilinearExtension::from_evaluations_ext_vec(hi_num_vars, g2); // sumcheck: sigma = \sum_t( g2(t) * f2(t) ) let mut virtual_poly_2 = VirtualPolynomial::new_from_mle(f2, E::BaseField::ONE); virtual_poly_2.mul_by_mle(g2.into(), E::BaseField::ONE); - let (sumcheck_proof_2, prover_state) = - SumcheckState::prove_parallel(virtual_poly_2, transcript); + let (sumcheck_proof_2, prover_state) = SumcheckState::prove_parallel(virtual_poly_2, transcript); let (mut f2, _): (Vec<_>, Vec<_>) = prover_state .get_mle_final_evaluations() .into_iter() @@ -226,15 +208,8 @@ impl IOPProverState { .partition(|(i, _)| i % 2 == 0); let eval_value_2 = f2.remove(0).1; - self.to_next_step_point = [ - mem::take(&mut self.to_next_step_point), - sumcheck_proof_2.point.clone(), - ] - .concat(); - self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref( - &self.to_next_step_point, - &eval_value_2, - )]; + self.to_next_step_point = [mem::take(&mut self.to_next_step_point), sumcheck_proof_2.point.clone()].concat(); + self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref(&self.to_next_step_point, &eval_value_2)]; self.subset_point_and_evals[self.layer_id as usize].clear(); diff --git a/gkr/src/prover/phase2.rs b/gkr/src/prover/phase2.rs index a8f786039..95c8d2605 100644 --- a/gkr/src/prover/phase2.rs +++ b/gkr/src/prover/phase2.rs @@ -15,9 +15,7 @@ use crate::structs::Step::{Step1, Step2, Step3}; use crate::{ circuit::EvaluateConstant, - structs::{ - Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval, SumcheckProof, - }, + structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval, SumcheckProof}, }; macro_rules! prepare_stepx_g_fn { @@ -89,16 +87,10 @@ impl IOPProverState { let span = entered_span!("f1_g1"); // merge next_layer_vec with next_layer_poly - let next_layer_vec = circuit_witness.layers[layer_id as usize + 1] - .instances - .as_slice(); + let next_layer_vec = circuit_witness.layers[layer_id as usize + 1].instances.as_slice(); let num_vars = circuit.layers[layer_id as usize].max_previous_num_vars(); let phase2_next_layer_polys_v2: ArcDenseMultilinearExtension = circuit_witness - .layer_poly( - (layer_id + 1).try_into().unwrap(), - num_vars, - multi_threads_meta, - ) + .layer_poly((layer_id + 1).try_into().unwrap(), num_vars, multi_threads_meta) .into(); // f1(s1 || x1) = layers[i + 1](s1 || x1) @@ -132,9 +124,7 @@ impl IOPProverState { * (&gate.scalar.eval(&challenges)) }, adds_fanin_mapping, - |s, gate| { - eq[(s << lo_out_num_vars) ^ gate.idx_out] * (&gate.scalar.eval(&challenges)) - } + |s, gate| { eq[(s << lo_out_num_vars) ^ gate.idx_out] * (&gate.scalar.eval(&challenges)) } ); let g1 = DenseMultilinearExtension::from_evaluations_ext_vec(f1.num_vars, g1).into(); exit_span!(span); @@ -142,33 +132,34 @@ impl IOPProverState { // f1'^{(j)}(s1 || x1) = subset[j][i](s1 || x1) // g1'^{(j)}(s1 || x1) = eq(rt, s1) paste_from[j](ry, x1) let span = entered_span!("f1j_g1j"); - let (f1_j, g1_j)= izip!(&layer.paste_from).map(|(j, paste_from)| { - let paste_from_sources = circuit_witness.layers_ref(); - let old_wire_id = |old_layer_id: usize, subset_wire_id: usize| -> usize { - circuit.layers[old_layer_id].copy_to[&(layer_id as u32)][subset_wire_id] - }; - - let mut f1_j = vec![0.into(); 1 << f1.num_vars]; - let mut g1_j = vec![E::ZERO; 1 << f1.num_vars]; - - paste_from - .iter() - .enumerate() - .for_each(|(subset_wire_id, &new_wire_id)| { - for s in 0..(1 << (hi_num_vars - log2_max_thread_id)) { - let global_s = thread_s + s; - f1_j[(s << lo_in_num_vars) ^ subset_wire_id] = - paste_from_sources[*j as usize].instances[global_s] - [old_wire_id(*j as usize, subset_wire_id)]; - g1_j[(s << lo_in_num_vars) ^ subset_wire_id] += eq[(global_s << lo_out_num_vars) ^ new_wire_id]; - } - }); - ( - DenseMultilinearExtension::from_evaluations_vec(f1.num_vars, f1_j).into(), - DenseMultilinearExtension::from_evaluations_ext_vec(f1.num_vars, g1_j).into() - ) - }) - .unzip::<_, _, Vec>, Vec>>(); + let (f1_j, g1_j) = izip!(&layer.paste_from) + .map(|(j, paste_from)| { + let paste_from_sources = circuit_witness.layers_ref(); + let old_wire_id = |old_layer_id: usize, subset_wire_id: usize| -> usize { + circuit.layers[old_layer_id].copy_to[&(layer_id as u32)][subset_wire_id] + }; + + let mut f1_j = vec![0.into(); 1 << f1.num_vars]; + let mut g1_j = vec![E::ZERO; 1 << f1.num_vars]; + + paste_from + .iter() + .enumerate() + .for_each(|(subset_wire_id, &new_wire_id)| { + for s in 0..(1 << (hi_num_vars - log2_max_thread_id)) { + let global_s = thread_s + s; + f1_j[(s << lo_in_num_vars) ^ subset_wire_id] = paste_from_sources[*j as usize].instances + [global_s][old_wire_id(*j as usize, subset_wire_id)]; + g1_j[(s << lo_in_num_vars) ^ subset_wire_id] += + eq[(global_s << lo_out_num_vars) ^ new_wire_id]; + } + }); + ( + DenseMultilinearExtension::from_evaluations_vec(f1.num_vars, f1_j).into(), + DenseMultilinearExtension::from_evaluations_ext_vec(f1.num_vars, g1_j).into(), + ) + }) + .unzip::<_, _, Vec>, Vec>>(); exit_span!(span); let (f, g): ( @@ -208,18 +199,13 @@ impl IOPProverState { // eval_values_g1[0] eval_values_1.push(g1_vec[0].1); - self.to_next_phase_point_and_evals = - vec![PointAndEval::new_from_ref(&eval_point_1, &eval_values_1[0])]; - izip!( - layer.paste_from.iter(), - eval_values_1[..f1_vec_len].iter().skip(1) - ) - .for_each(|((&old_layer_id, _), &subset_value)| { - self.subset_point_and_evals[old_layer_id as usize].push(( - self.layer_id, - PointAndEval::new_from_ref(&eval_point_1, &subset_value), - )); - }); + self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref(&eval_point_1, &eval_values_1[0])]; + izip!(layer.paste_from.iter(), eval_values_1[..f1_vec_len].iter().skip(1)).for_each( + |((&old_layer_id, _), &subset_value)| { + self.subset_point_and_evals[old_layer_id as usize] + .push((self.layer_id, PointAndEval::new_from_ref(&eval_point_1, &subset_value))); + }, + ); self.to_next_step_point = eval_point_1; IOPProverStepMessage { @@ -256,9 +242,7 @@ impl IOPProverState { let threads_num_vars = hi_num_vars - log2_max_thread_id; let thread_s = thread_id << threads_num_vars; - let phase2_next_layer_vec = circuit_witness.layers[layer_id as usize + 1] - .instances - .as_slice(); + let phase2_next_layer_vec = circuit_witness.layers[layer_id as usize + 1].instances.as_slice(); let challenges = &circuit_witness.challenges; @@ -368,16 +352,11 @@ impl IOPProverState { let g3 = { let mut g3 = vec![E::ZERO; 1 << (f3.num_vars)]; let fanin_mapping = &layer.mul3s_fanin_mapping[Step3 as usize]; - prepare_stepx_g_fn!( - &mut g3, - lo_in_num_vars, - thread_s, - fanin_mapping, - |s, gate| eq0[(s << lo_out_num_vars) ^ gate.idx_out] - * eq1[(s << lo_in_num_vars) ^ gate.idx_in[0]] - * eq2[(s << lo_in_num_vars) ^ gate.idx_in[1]] - * (&gate.scalar.eval(&challenges)) - ); + prepare_stepx_g_fn!(&mut g3, lo_in_num_vars, thread_s, fanin_mapping, |s, gate| eq0 + [(s << lo_out_num_vars) ^ gate.idx_out] + * eq1[(s << lo_in_num_vars) ^ gate.idx_in[0]] + * eq2[(s << lo_in_num_vars) ^ gate.idx_in[1]] + * (&gate.scalar.eval(&challenges))); DenseMultilinearExtension::from_evaluations_ext_vec(f3.num_vars, g3).into() }; diff --git a/gkr/src/prover/phase2_input.rs b/gkr/src/prover/phase2_input.rs index 350e0c644..c20940b2c 100644 --- a/gkr/src/prover/phase2_input.rs +++ b/gkr/src/prover/phase2_input.rs @@ -67,15 +67,12 @@ impl IOPProverState { } ( { - let mut f = DenseMultilinearExtension::from_evaluations_vec( - max_lo_in_num_vars + hi_num_vars, - f, - ); + let mut f = + DenseMultilinearExtension::from_evaluations_vec(max_lo_in_num_vars + hi_num_vars, f); f.fix_high_variables_in_place(hi_point); f.into() }, - DenseMultilinearExtension::from_evaluations_ext_vec(max_lo_in_num_vars, g) - .into(), + DenseMultilinearExtension::from_evaluations_ext_vec(max_lo_in_num_vars, g).into(), ) }) .unzip(); @@ -99,15 +96,12 @@ impl IOPProverState { } ( { - let mut f = DenseMultilinearExtension::from_evaluations_vec( - max_lo_in_num_vars + hi_num_vars, - f, - ); + let mut f = + DenseMultilinearExtension::from_evaluations_vec(max_lo_in_num_vars + hi_num_vars, f); f.fix_high_variables_in_place(&hi_point); f.into() }, - DenseMultilinearExtension::from_evaluations_ext_vec(max_lo_in_num_vars, g) - .into(), + DenseMultilinearExtension::from_evaluations_ext_vec(max_lo_in_num_vars, g).into(), ) }) .unzip(); @@ -122,19 +116,14 @@ impl IOPProverState { virtual_poly.merge(&tmp); } - let (sumcheck_proofs, prover_state) = - SumcheckState::prove_parallel(virtual_poly, transcript); + let (sumcheck_proofs, prover_state) = SumcheckState::prove_parallel(virtual_poly, transcript); let eval_point = sumcheck_proofs.point.clone(); let (f_vec, _): (Vec<_>, Vec<_>) = prover_state .get_mle_final_evaluations() .into_iter() .enumerate() .partition(|(i, _)| i % 2 == 0); - let eval_values_f = f_vec - .into_iter() - .take(wits_in.len()) - .map(|(_, f)| f) - .collect_vec(); + let eval_values_f = f_vec.into_iter().take(wits_in.len()).map(|(_, f)| f).collect_vec(); self.to_next_phase_point_and_evals = izip!(paste_from_wit_in.iter(), eval_values_f.iter()) .map(|((l, r), eval)| { diff --git a/gkr/src/prover/phase2_linear.rs b/gkr/src/prover/phase2_linear.rs index 206d483f0..ce3ab449e 100644 --- a/gkr/src/prover/phase2_linear.rs +++ b/gkr/src/prover/phase2_linear.rs @@ -51,9 +51,7 @@ impl IOPProverState { let f1_g1 = || { // f1(x1) = layers[i + 1](rt || x1) - let layer_in_vec = circuit_witness.layers[self.layer_id as usize + 1] - .instances - .as_slice(); + let layer_in_vec = circuit_witness.layers[self.layer_id as usize + 1].instances.as_slice(); let mut f1 = layer_in_vec.mle(lo_in_num_vars, hi_num_vars); Arc::make_mut(&mut f1).fix_high_variables_in_place(&hi_point); @@ -85,24 +83,18 @@ impl IOPProverState { .enumerate() .for_each(|(subset_wire_id, &new_wire_id)| { for s in 0..(1 << hi_num_vars) { - f1_j[(s << lo_in_num_vars) ^ subset_wire_id] = paste_from_sources - [j as usize] - .instances[s][old_wire_id(j as usize, subset_wire_id)]; + f1_j[(s << lo_in_num_vars) ^ subset_wire_id] = + paste_from_sources[j as usize].instances[s][old_wire_id(j as usize, subset_wire_id)]; } g1_j[subset_wire_id] += eq_y_ry[new_wire_id]; }); f1_vec.push({ - let mut f1_j = DenseMultilinearExtension::from_evaluations_vec( - lo_in_num_vars + hi_num_vars, - f1_j, - ); + let mut f1_j = DenseMultilinearExtension::from_evaluations_vec(lo_in_num_vars + hi_num_vars, f1_j); f1_j.fix_high_variables_in_place(&hi_point); f1_j.into() }); - g1_vec.push( - DenseMultilinearExtension::from_evaluations_ext_vec(lo_in_num_vars, g1_j).into(), - ); + g1_vec.push(DenseMultilinearExtension::from_evaluations_ext_vec(lo_in_num_vars, g1_j).into()); }); // sumcheck: sigma = \sum_{x1} f1(x1) * g1(x1) + \sum_j f1'_j(x1) * g1'_j(x1) @@ -113,8 +105,7 @@ impl IOPProverState { virtual_poly_1.merge(&tmp); } - let (sumcheck_proof_1, prover_state) = - SumcheckState::prove_parallel(virtual_poly_1, transcript); + let (sumcheck_proof_1, prover_state) = SumcheckState::prove_parallel(virtual_poly_1, transcript); let eval_point_1 = sumcheck_proof_1.point.clone(); let (f1_vec, _): (Vec<_>, Vec<_>) = prover_state .get_mle_final_evaluations() @@ -124,14 +115,11 @@ impl IOPProverState { let eval_values_f1 = f1_vec.into_iter().map(|(_, f1_j)| f1_j).collect_vec(); let new_point = [&eval_point_1, hi_point].concat(); - self.to_next_phase_point_and_evals = - vec![PointAndEval::new_from_ref(&new_point, &eval_values_f1[0])]; + self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref(&new_point, &eval_values_f1[0])]; izip!(layer.paste_from.iter(), eval_values_f1.iter().skip(1)).for_each( |((&old_layer_id, _), &subset_value)| { - self.subset_point_and_evals[old_layer_id as usize].push(( - self.layer_id, - PointAndEval::new_from_ref(&new_point, &subset_value), - )); + self.subset_point_and_evals[old_layer_id as usize] + .push((self.layer_id, PointAndEval::new_from_ref(&new_point, &subset_value))); }, ); self.to_next_step_point = new_point; diff --git a/gkr/src/prover/test.rs b/gkr/src/prover/test.rs index fe90f2003..45229a0ce 100644 --- a/gkr/src/prover/test.rs +++ b/gkr/src/prover/test.rs @@ -9,9 +9,7 @@ use simple_frontend::structs::{ChallengeConst, ChallengeId, CircuitBuilder, Mixe use transcript::Transcript; use crate::{ - structs::{ - Circuit, CircuitWitness, IOPProverState, IOPVerifierState, LayerWitness, PointAndEval, - }, + structs::{Circuit, CircuitWitness, IOPProverState, IOPVerifierState, LayerWitness, PointAndEval}, utils::{i64_to_field, MultilinearExtensionFromVectors}, }; @@ -30,13 +28,7 @@ fn copy_and_paste_circuit() -> Circuit { // Layer 0 let (_, mul_001123) = circuit_builder.create_witness_out(1); - circuit_builder.mul3( - mul_001123[0], - mul_01, - mul_012, - input[3], - Ext::BaseField::ONE, - ); + circuit_builder.mul3(mul_001123[0], mul_01, mul_012, input[3], Ext::BaseField::ONE); circuit_builder.configure(); let circuit = Circuit::new(&circuit_builder); @@ -44,10 +36,8 @@ fn copy_and_paste_circuit() -> Circuit { circuit } -fn copy_and_paste_witness() -> ( - Vec>, - CircuitWitness, -) { +fn copy_and_paste_witness() -> (Vec>, CircuitWitness) +{ // witness_in, single instance let inputs = vec![vec![ i64_to_field(5), @@ -122,10 +112,8 @@ fn paste_from_wit_in_circuit() -> Circuit { circuit } -fn paste_from_wit_in_witness() -> ( - Vec>, - CircuitWitness, -) { +fn paste_from_wit_in_witness( +) -> (Vec>, CircuitWitness) { // witness_in, single instance let leaves1 = vec![vec![i64_to_field(5), i64_to_field(7), i64_to_field(11)]]; let leaves2 = vec![vec![i64_to_field(13), i64_to_field(17), i64_to_field(19)]]; @@ -173,12 +161,8 @@ fn paste_from_wit_in_witness() -> ( let outputs1 = vec![vec![i64_to_field(35), i64_to_field(143)]]; let outputs2 = vec![vec![i64_to_field(5005)]]; let witness_out = vec![ - LayerWitness { - instances: outputs1, - }, - LayerWitness { - instances: outputs2, - }, + LayerWitness { instances: outputs1 }, + LayerWitness { instances: outputs2 }, ]; ( @@ -214,10 +198,8 @@ fn copy_to_wit_out_circuit() -> Circuit { circuit } -fn copy_to_wit_out_witness() -> ( - Vec>, - CircuitWitness, -) { +fn copy_to_wit_out_witness() -> (Vec>, CircuitWitness) +{ // witness_in, single instance let leaves = vec![vec![ i64_to_field(5), @@ -264,24 +246,12 @@ fn copy_to_wit_out_witness() -> ( ) } -fn copy_to_wit_out_witness_2() -> ( - Vec>, - CircuitWitness, -) { +fn copy_to_wit_out_witness_2( +) -> (Vec>, CircuitWitness) { // witness_in, 2 instances let leaves = vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], + vec![i64_to_field(5), i64_to_field(7), i64_to_field(11), i64_to_field(13)], + vec![i64_to_field(5), i64_to_field(13), i64_to_field(11), i64_to_field(7)], ]; let witness_in = vec![LayerWitness { instances: leaves }]; @@ -310,18 +280,8 @@ fn copy_to_wit_out_witness_2() -> ( }, LayerWitness { instances: vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], + vec![i64_to_field(5), i64_to_field(7), i64_to_field(11), i64_to_field(13)], + vec![i64_to_field(5), i64_to_field(13), i64_to_field(11), i64_to_field(7)], ], }, ]; @@ -396,35 +356,21 @@ where // witness_in, double instances let leaves = vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], + vec![i64_to_field(5), i64_to_field(7), i64_to_field(11), i64_to_field(13)], + vec![i64_to_field(5), i64_to_field(13), i64_to_field(11), i64_to_field(7)], ]; let witness_in = vec![LayerWitness { instances: leaves.clone(), }]; - let inner00: Ext = challenge_pows[0][0].1 * (&leaves[0][0]) - + challenge_pows[0][1].1 * (&leaves[0][1]) - + challenge_pows[0][2].1; - let inner01: Ext = challenge_pows[1][0].1 * (&leaves[0][2]) - + challenge_pows[1][1].1 * (&leaves[0][3]) - + challenge_pows[1][2].1; - let inner10: Ext = challenge_pows[0][0].1 * (&leaves[1][0]) - + challenge_pows[0][1].1 * (&leaves[1][1]) - + challenge_pows[0][2].1; - let inner11: Ext = challenge_pows[1][0].1 * (&leaves[1][2]) - + challenge_pows[1][1].1 * (&leaves[1][3]) - + challenge_pows[1][2].1; + let inner00: Ext = + challenge_pows[0][0].1 * (&leaves[0][0]) + challenge_pows[0][1].1 * (&leaves[0][1]) + challenge_pows[0][2].1; + let inner01: Ext = + challenge_pows[1][0].1 * (&leaves[0][2]) + challenge_pows[1][1].1 * (&leaves[0][3]) + challenge_pows[1][2].1; + let inner10: Ext = + challenge_pows[0][0].1 * (&leaves[1][0]) + challenge_pows[0][1].1 * (&leaves[1][1]) + challenge_pows[0][2].1; + let inner11: Ext = + challenge_pows[1][0].1 * (&leaves[1][2]) + challenge_pows[1][1].1 * (&leaves[1][3]) + challenge_pows[1][2].1; let inners = vec![ [inner00.clone().as_bases(), inner01.clone().as_bases()].concat(), @@ -456,9 +402,7 @@ where LayerWitness { instances: roots.clone(), }, - LayerWitness { - instances: root_tmps, - }, + LayerWitness { instances: root_tmps }, LayerWitness { instances: inners }, LayerWitness { instances: leaves }, ]; @@ -493,12 +437,7 @@ fn inv_sum_circuit() -> Circuit { let den_mul = circuit_builder.create_ext_cell(); circuit_builder.mul2_ext(&den_mul, &input[0], &input[1], Ext::BaseField::ONE); let tmp = circuit_builder.create_ext_cell(); - circuit_builder.sel_mixed_and_ext( - &tmp, - &MixedCell::Constant(Ext::BaseField::ONE), - &input[0], - cond[0], - ); + circuit_builder.sel_mixed_and_ext(&tmp, &MixedCell::Constant(Ext::BaseField::ONE), &input[0], cond[0]); circuit_builder.sel_ext(&output[0], &tmp, &den_mul, cond[1]); // select the numerator 0 or 1 or input[0] + input[1] @@ -515,30 +454,10 @@ fn inv_sum_witness_4_instances() -> CircuitWitness(); // witness_in, double instances let leaves = vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], - vec![ - i64_to_field(23), - i64_to_field(29), - i64_to_field(17), - i64_to_field(19), - ], - vec![ - i64_to_field(29), - i64_to_field(17), - i64_to_field(19), - i64_to_field(23), - ], + vec![i64_to_field(5), i64_to_field(7), i64_to_field(11), i64_to_field(13)], + vec![i64_to_field(5), i64_to_field(13), i64_to_field(11), i64_to_field(7)], + vec![i64_to_field(23), i64_to_field(29), i64_to_field(17), i64_to_field(19)], + vec![i64_to_field(29), i64_to_field(17), i64_to_field(19), i64_to_field(23)], ]; let cond: Vec::BaseField>> = vec![ vec![i64_to_field(1), i64_to_field(1)], @@ -546,10 +465,7 @@ fn inv_sum_witness_4_instances() -> CircuitWitness() -> CircuitWitness( ) { let mut rng = test_rng(); let out_num_vars = circuit.output_num_vars() + circuit_wits.instance_num_vars(); - let out_point = (0..out_num_vars) - .map(|_| Ext::random(&mut rng)) - .collect_vec(); + let out_point = (0..out_num_vars).map(|_| Ext::random(&mut rng)).collect_vec(); let out_point_and_evals = if circuit.n_witness_out == 0 { vec![PointAndEval::new( diff --git a/gkr/src/structs.rs b/gkr/src/structs.rs index 3f667a7a6..6e5c0818f 100644 --- a/gkr/src/structs.rs +++ b/gkr/src/structs.rs @@ -148,11 +148,11 @@ pub struct Layer { // Gates. Should be all None if it's the input layer. pub(crate) add_consts: Vec>>, pub(crate) adds: Vec>>, - pub(crate) adds_fanin_mapping: [BTreeMap>>>; 1], /* grouping for 1 fanins */ + pub(crate) adds_fanin_mapping: [BTreeMap>>>; 1], // grouping for 1 fanins pub(crate) mul2s: Vec>>, - pub(crate) mul2s_fanin_mapping: [BTreeMap>>>; 2], /* grouping for 2 fanins */ + pub(crate) mul2s_fanin_mapping: [BTreeMap>>>; 2], // grouping for 2 fanins pub(crate) mul3s: Vec>>, - pub(crate) mul3s_fanin_mapping: [BTreeMap>>>; 3], /* grouping for 3 fanins */ + pub(crate) mul3s_fanin_mapping: [BTreeMap>>>; 3], // grouping for 3 fanins /// The corresponding wires copied from this layer to later layers. It is /// (later layer id -> current wire id to be copied). It stores the non-zero diff --git a/gkr/src/test/is_zero_gadget.rs b/gkr/src/test/is_zero_gadget.rs index d4aa436f6..d2ac3ee39 100644 --- a/gkr/src/test/is_zero_gadget.rs +++ b/gkr/src/test/is_zero_gadget.rs @@ -1,12 +1,13 @@ -use crate::structs::{Circuit, CircuitWitness, IOPProverState, IOPVerifierState, PointAndEval}; -use crate::utils::MultilinearExtensionFromVectors; +use crate::{ + structs::{Circuit, CircuitWitness, IOPProverState, IOPVerifierState, PointAndEval}, + utils::MultilinearExtensionFromVectors, +}; use ff::Field; use ff_ext::ExtensionField; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; use simple_frontend::structs::{CellId, CircuitBuilder}; -use std::iter; -use std::time::Duration; +use std::{iter, time::Duration}; use transcript::Transcript; // build an IsZero Gadget @@ -75,7 +76,7 @@ fn test_gkr_circuit_is_zero_gadget_simple() { }; println!("circuit witness: {:?}", circuit_witness); // use of check_correctness will panic - //circuit_witness.check_correctness(&circuit); + // circuit_witness.check_correctness(&circuit); // check the result let layers = circuit_witness.layers_ref(); @@ -96,10 +97,8 @@ fn test_gkr_circuit_is_zero_gadget_simple() { assert_eq!(is_zero_wire_out_ref.instances[0][0], out_is_zero); // add prover-verifier process - let mut prover_transcript = - Transcript::::new(b"test_gkr_circuit_IsZeroGadget_simple"); - let mut verifier_transcript = - Transcript::::new(b"test_gkr_circuit_IsZeroGadget_simple"); + let mut prover_transcript = Transcript::::new(b"test_gkr_circuit_IsZeroGadget_simple"); + let mut verifier_transcript = Transcript::::new(b"test_gkr_circuit_IsZeroGadget_simple"); let mut prover_wires_out_evals = vec![]; let mut verifier_wires_out_evals = vec![]; @@ -130,10 +129,7 @@ fn test_gkr_circuit_is_zero_gadget_simple() { let prover_output_eval = output_mle.evaluate(&prover_output_point); let verifier_output_eval = output_mle.evaluate(&verifier_output_point); prover_wires_out_evals.push(PointAndEval::new(prover_output_point, prover_output_eval)); - verifier_wires_out_evals.push(PointAndEval::new( - verifier_output_point, - verifier_output_eval, - )); + verifier_wires_out_evals.push(PointAndEval::new(verifier_output_point, verifier_output_eval)); } let start = std::time::Instant::now(); @@ -192,23 +188,17 @@ fn test_gkr_circuit_is_zero_gadget_u256() { let mut is_zero_prev_items = circuit_builder.create_cell(); circuit_builder.add_const(is_zero_prev_items, Goldilocks::from(1)); for (value_item, inv_item) in value.into_iter().zip(inv) { - let (is_zero_item, cond1_item, cond2_item) = - is_zero_gadget(&mut circuit_builder, value_item, inv_item); + let (is_zero_item, cond1_item, cond2_item) = is_zero_gadget(&mut circuit_builder, value_item, inv_item); cond1.push(cond1_item); cond2.push(cond2_item); let is_zero = circuit_builder.create_cell(); // TODO: can optimize using mul3 - circuit_builder.mul2( - is_zero, - is_zero_prev_items, - is_zero_item, - Goldilocks::from(1), - ); + circuit_builder.mul2(is_zero, is_zero_prev_items, is_zero_item, Goldilocks::from(1)); is_zero_prev_items = is_zero; } - let cond_wire_out_id = circuit_builder - .create_witness_out_from_cells(&[cond1.as_slice(), cond2.as_slice()].concat()); + let cond_wire_out_id = + circuit_builder.create_witness_out_from_cells(&[cond1.as_slice(), cond2.as_slice()].concat()); let is_zero_wire_out_id = circuit_builder.create_witness_out_from_cells(&[is_zero_prev_items]); circuit_builder.configure(); @@ -232,7 +222,7 @@ fn test_gkr_circuit_is_zero_gadget_u256() { }; println!("circuit witness: {:?}", circuit_witness); // use of check_correctness will panic - //circuit_witness.check_correctness(&circuit); + // circuit_witness.check_correctness(&circuit); // check the result let layers = circuit_witness.layers_ref(); @@ -254,10 +244,8 @@ fn test_gkr_circuit_is_zero_gadget_u256() { assert_eq!(is_zero_wire_out_ref.instances[0][0], out_is_zero); // add prover-verifier process - let mut prover_transcript = - Transcript::::new(b"test_gkr_circuit_IsZeroGadget_simple"); - let mut verifier_transcript = - Transcript::::new(b"test_gkr_circuit_IsZeroGadget_simple"); + let mut prover_transcript = Transcript::::new(b"test_gkr_circuit_IsZeroGadget_simple"); + let mut verifier_transcript = Transcript::::new(b"test_gkr_circuit_IsZeroGadget_simple"); let mut prover_wires_out_evals = vec![]; let mut verifier_wires_out_evals = vec![]; @@ -288,10 +276,7 @@ fn test_gkr_circuit_is_zero_gadget_u256() { let prover_output_eval = output_mle.evaluate(&prover_output_point); let verifier_output_eval = output_mle.evaluate(&verifier_output_point); prover_wires_out_evals.push(PointAndEval::new(prover_output_point, prover_output_eval)); - verifier_wires_out_evals.push(PointAndEval::new( - verifier_output_point, - verifier_output_eval, - )); + verifier_wires_out_evals.push(PointAndEval::new(verifier_output_point, verifier_output_eval)); } let start = std::time::Instant::now(); @@ -305,21 +290,19 @@ fn test_gkr_circuit_is_zero_gadget_u256() { ); let proof_time: Duration = start.elapsed(); - /* // verifier panics due to mismatch of number of variables - let start = std::time::Instant::now(); - let _claim = IOPVerifierState::verify_parallel( - &circuit, - &[], - &[], - &verifier_wires_out_evals, - &proof, - instance_num_vars, - &mut verifier_transcript, - ).unwrap(); - let verification_time: Duration = start.elapsed(); - - println!("proof time: {:?}, verification time: {:?}", proof_time, verification_time); - */ + // let start = std::time::Instant::now(); + // let _claim = IOPVerifierState::verify_parallel( + // &circuit, + // &[], + // &[], + // &verifier_wires_out_evals, + // &proof, + // instance_num_vars, + // &mut verifier_transcript, + // ).unwrap(); + // let verification_time: Duration = start.elapsed(); + // + // println!("proof time: {:?}, verification time: {:?}", proof_time, verification_time); println!("proof time: {:?}", proof_time); } diff --git a/gkr/src/utils.rs b/gkr/src/utils.rs index 2670c68f6..cd79d2732 100644 --- a/gkr/src/utils.rs +++ b/gkr/src/utils.rs @@ -27,8 +27,7 @@ pub(crate) fn segment_eval_greater_than(min_idx: usize, a: &[ running_product[a.len()] = E::ONE; for i in (0..a.len()).rev() { let bit = E::from(((min_idx >> i) & 1) as u64); - running_product[i] = - running_product[i + 1] * (a[i] * bit + (E::ONE - a[i]) * (E::ONE - bit)); + running_product[i] = running_product[i + 1] * (a[i] * bit + (E::ONE - a[i]) * (E::ONE - bit)); } running_product }; @@ -70,8 +69,8 @@ pub(crate) fn eq_eval_greater_than(min_idx: usize, a: &[F], b: &[ running_product[b.len()] = F::ONE; for i in (0..b.len()).rev() { let bit = F::from(((min_idx >> i) & 1) as u64); - running_product[i] = running_product[i + 1] - * (a[i] * b[i] * bit + (F::ONE - a[i]) * (F::ONE - b[i]) * (F::ONE - bit)); + running_product[i] = + running_product[i + 1] * (a[i] * b[i] * bit + (F::ONE - a[i]) * (F::ONE - b[i]) * (F::ONE - bit)); } running_product }; @@ -116,8 +115,8 @@ pub(crate) fn eq_eval_less_or_equal_than(max_idx: usize, a: & running_product[b.len()] = E::ONE; for i in (0..b.len()).rev() { let bit = E::from(((max_idx >> i) & 1) as u64); - running_product[i] = running_product[i + 1] - * (a[i] * b[i] * bit + (E::ONE - a[i]) * (E::ONE - b[i]) * (E::ONE - bit)); + running_product[i] = + running_product[i + 1] * (a[i] * b[i] * bit + (E::ONE - a[i]) * (E::ONE - b[i]) * (E::ONE - bit)); } running_product }; @@ -277,12 +276,7 @@ impl MultilinearExtensionFromVectors for &[Vec { fn fix_row_col_first(&self, row_point_eq: &[E], col_num_vars: usize) -> Vec; - fn fix_row_col_first_with_scalar( - &self, - row_point_eq: &[E], - col_num_vars: usize, - scalar: &E, - ) -> Vec; + fn fix_row_col_first_with_scalar(&self, row_point_eq: &[E], col_num_vars: usize, scalar: &E) -> Vec; fn eval_col_first(&self, row_point_eq: &[E], col_point_eq: &[E]) -> E; } @@ -295,12 +289,7 @@ impl MatrixMLEColumnFirst for &[usize] { ans } - fn fix_row_col_first_with_scalar( - &self, - row_point_eq: &[E], - col_num_vars: usize, - scalar: &E, - ) -> Vec { + fn fix_row_col_first_with_scalar(&self, row_point_eq: &[E], col_num_vars: usize, scalar: &E) -> Vec { let mut ans = vec![E::ZERO; 1 << col_num_vars]; for (col, &non_zero_row) in self.iter().enumerate() { ans[col] = row_point_eq[non_zero_row] * scalar; @@ -309,22 +298,15 @@ impl MatrixMLEColumnFirst for &[usize] { } fn eval_col_first(&self, row_point_eq: &[E], col_point_eq: &[E]) -> E { - self.iter() - .enumerate() - .fold(E::ZERO, |acc, (col, &non_zero_row)| { - acc + row_point_eq[non_zero_row] * col_point_eq[col] - }) + self.iter().enumerate().fold(E::ZERO, |acc, (col, &non_zero_row)| { + acc + row_point_eq[non_zero_row] * col_point_eq[col] + }) } } pub(crate) trait MatrixMLERowFirst { fn fix_row_row_first(&self, row_point_eq: &[E], col_num_vars: usize) -> Vec; - fn fix_row_row_first_with_scalar( - &self, - row_point_eq: &[E], - col_num_vars: usize, - scalar: &E, - ) -> Vec; + fn fix_row_row_first_with_scalar(&self, row_point_eq: &[E], col_num_vars: usize, scalar: &E) -> Vec; fn eval_row_first(&self, row_point_eq: &[E], col_point_eq: &[E]) -> E; } @@ -337,12 +319,7 @@ impl MatrixMLERowFirst for &[usize] { ans } - fn fix_row_row_first_with_scalar( - &self, - row_point_eq: &[E], - col_num_vars: usize, - scalar: &E, - ) -> Vec { + fn fix_row_row_first_with_scalar(&self, row_point_eq: &[E], col_num_vars: usize, scalar: &E) -> Vec { let mut ans = vec![E::ZERO; 1 << col_num_vars]; for (row, &non_zero_col) in self.iter().enumerate() { ans[non_zero_col] = row_point_eq[row] * scalar; @@ -351,11 +328,9 @@ impl MatrixMLERowFirst for &[usize] { } fn eval_row_first(&self, row_point_eq: &[E], col_point_eq: &[E]) -> E { - self.iter() - .enumerate() - .fold(E::ZERO, |acc, (row, &non_zero_col)| { - acc + row_point_eq[row] * col_point_eq[non_zero_col] - }) + self.iter().enumerate().fold(E::ZERO, |acc, (row, &non_zero_col)| { + acc + row_point_eq[row] * col_point_eq[non_zero_col] + }) } } @@ -374,13 +349,11 @@ impl SubsetIndices for &[usize] { ans } fn subset_eq_eval(&self, eq_1: &[F]) -> F { - self.iter() - .fold(F::ZERO, |acc, &non_zero_i| acc + eq_1[non_zero_i]) + self.iter().fold(F::ZERO, |acc, &non_zero_i| acc + eq_1[non_zero_i]) } fn subset_eq2_eval(&self, eq_1: &[F], eq_2: &[F]) -> F { - self.iter().fold(F::ZERO, |acc, &non_zero_i| { - acc + eq_1[non_zero_i] * eq_2[non_zero_i] - }) + self.iter() + .fold(F::ZERO, |acc, &non_zero_i| acc + eq_1[non_zero_i] * eq_2[non_zero_i]) } } @@ -411,12 +384,8 @@ mod test { let mut rng = test_rng(); let n = 5; let pow_n = 1 << n; - let a = (0..n) - .map(|_| GoldilocksExt2::random(&mut rng)) - .collect_vec(); - let b = (0..n) - .map(|_| GoldilocksExt2::random(&mut rng)) - .collect_vec(); + let a = (0..n).map(|_| GoldilocksExt2::random(&mut rng)).collect_vec(); + let b = (0..n).map(|_| GoldilocksExt2::random(&mut rng)).collect_vec(); let eq_vec = build_eq_x_r_vec(&a); @@ -424,8 +393,7 @@ mod test { let max_idx = 0; let mut partial_eq_vec: Vec<_> = eq_vec[0..=max_idx].to_vec(); partial_eq_vec.extend(vec![GoldilocksExt2::ZERO; pow_n - max_idx - 1]); - let expected_ans = - DenseMultilinearExtension::from_evaluations_ext_vec(n, partial_eq_vec).evaluate(&b); + let expected_ans = DenseMultilinearExtension::from_evaluations_ext_vec(n, partial_eq_vec).evaluate(&b); assert_eq!(expected_ans, eq_eval_less_or_equal_than(max_idx, &a, &b)); } @@ -433,8 +401,7 @@ mod test { let max_idx = 1; let mut partial_eq_vec: Vec<_> = eq_vec[0..=max_idx].to_vec(); partial_eq_vec.extend(vec![GoldilocksExt2::ZERO; pow_n - max_idx - 1]); - let expected_ans = - DenseMultilinearExtension::from_evaluations_ext_vec(n, partial_eq_vec).evaluate(&b); + let expected_ans = DenseMultilinearExtension::from_evaluations_ext_vec(n, partial_eq_vec).evaluate(&b); assert_eq!(expected_ans, eq_eval_less_or_equal_than(max_idx, &a, &b)); } @@ -442,8 +409,7 @@ mod test { let max_idx = 12; let mut partial_eq_vec: Vec<_> = eq_vec[0..=max_idx].to_vec(); partial_eq_vec.extend(vec![GoldilocksExt2::ZERO; pow_n - max_idx - 1]); - let expected_ans = - DenseMultilinearExtension::from_evaluations_ext_vec(n, partial_eq_vec).evaluate(&b); + let expected_ans = DenseMultilinearExtension::from_evaluations_ext_vec(n, partial_eq_vec).evaluate(&b); assert_eq!(expected_ans, eq_eval_less_or_equal_than(max_idx, &a, &b)); } @@ -451,8 +417,7 @@ mod test { let max_idx = 1 << (n - 1) - 1; let mut partial_eq_vec: Vec<_> = eq_vec[0..=max_idx].to_vec(); partial_eq_vec.extend(vec![GoldilocksExt2::ZERO; pow_n - max_idx - 1]); - let expected_ans = - DenseMultilinearExtension::from_evaluations_ext_vec(n, partial_eq_vec).evaluate(&b); + let expected_ans = DenseMultilinearExtension::from_evaluations_ext_vec(n, partial_eq_vec).evaluate(&b); assert_eq!(expected_ans, eq_eval_less_or_equal_than(max_idx, &a, &b)); } @@ -460,8 +425,7 @@ mod test { let max_idx = 1 << (n - 1); let mut partial_eq_vec: Vec<_> = eq_vec[0..=max_idx].to_vec(); partial_eq_vec.extend(vec![GoldilocksExt2::ZERO; pow_n - max_idx - 1]); - let expected_ans = - DenseMultilinearExtension::from_evaluations_ext_vec(n, partial_eq_vec).evaluate(&b); + let expected_ans = DenseMultilinearExtension::from_evaluations_ext_vec(n, partial_eq_vec).evaluate(&b); assert_eq!(expected_ans, eq_eval_less_or_equal_than(max_idx, &a, &b)); } } diff --git a/gkr/src/verifier.rs b/gkr/src/verifier.rs index d84ff6ce2..3f5b2342f 100644 --- a/gkr/src/verifier.rs +++ b/gkr/src/verifier.rs @@ -8,8 +8,7 @@ use transcript::Transcript; use crate::{ error::GKRError, structs::{ - Circuit, GKRInputClaims, IOPProof, IOPProverStepMessage, IOPVerifierState, PointAndEval, - SumcheckStepType, + Circuit, GKRInputClaims, IOPProof, IOPProverStepMessage, IOPVerifierState, PointAndEval, SumcheckStepType, }, }; @@ -51,39 +50,35 @@ impl IOPVerifierState { verifier_state.layer_id = layer_id as LayerId; let layer = &circuit.layers[layer_id as usize]; - for (step, step_proof) in izip!(layer.sumcheck_steps.iter(), &mut sumcheck_proofs_iter) - { + for (step, step_proof) in izip!(layer.sumcheck_steps.iter(), &mut sumcheck_proofs_iter) { match step { - SumcheckStepType::OutputPhase1Step1 => verifier_state - .verify_and_update_state_output_phase1_step1( - circuit, step_proof, transcript, - )?, - SumcheckStepType::OutputPhase1Step2 => verifier_state - .verify_and_update_state_output_phase1_step2( - circuit, step_proof, transcript, - )?, - SumcheckStepType::Phase1Step1 => verifier_state - .verify_and_update_state_phase1_step1(circuit, step_proof, transcript)?, - SumcheckStepType::Phase2Step1 => verifier_state - .verify_and_update_state_phase2_step1(circuit, step_proof, transcript)?, - SumcheckStepType::Phase2Step2 => verifier_state - .verify_and_update_state_phase2_step2( - circuit, step_proof, transcript, false, - )?, - SumcheckStepType::Phase2Step2NoStep3 => verifier_state - .verify_and_update_state_phase2_step2( - circuit, step_proof, transcript, true, - )?, - SumcheckStepType::Phase2Step3 => verifier_state - .verify_and_update_state_phase2_step3(circuit, step_proof, transcript)?, - SumcheckStepType::LinearPhase2Step1 => verifier_state - .verify_and_update_state_linear_phase2_step1( - circuit, step_proof, transcript, - )?, - SumcheckStepType::InputPhase2Step1 => verifier_state - .verify_and_update_state_input_phase2_step1( - circuit, step_proof, transcript, - )?, + SumcheckStepType::OutputPhase1Step1 => { + verifier_state.verify_and_update_state_output_phase1_step1(circuit, step_proof, transcript)? + } + SumcheckStepType::OutputPhase1Step2 => { + verifier_state.verify_and_update_state_output_phase1_step2(circuit, step_proof, transcript)? + } + SumcheckStepType::Phase1Step1 => { + verifier_state.verify_and_update_state_phase1_step1(circuit, step_proof, transcript)? + } + SumcheckStepType::Phase2Step1 => { + verifier_state.verify_and_update_state_phase2_step1(circuit, step_proof, transcript)? + } + SumcheckStepType::Phase2Step2 => { + verifier_state.verify_and_update_state_phase2_step2(circuit, step_proof, transcript, false)? + } + SumcheckStepType::Phase2Step2NoStep3 => { + verifier_state.verify_and_update_state_phase2_step2(circuit, step_proof, transcript, true)? + } + SumcheckStepType::Phase2Step3 => { + verifier_state.verify_and_update_state_phase2_step3(circuit, step_proof, transcript)? + } + SumcheckStepType::LinearPhase2Step1 => { + verifier_state.verify_and_update_state_linear_phase2_step1(circuit, step_proof, transcript)? + } + SumcheckStepType::InputPhase2Step1 => { + verifier_state.verify_and_update_state_input_phase2_step1(circuit, step_proof, transcript)? + } _ => unreachable!(), } } @@ -114,11 +109,7 @@ impl IOPVerifierState { wires_out_evals.last().unwrap().clone() }; let assert_point = (0..output_wit_num_vars) - .map(|_| { - transcript - .get_and_append_challenge(b"assert_point") - .elements - }) + .map(|_| transcript.get_and_append_challenge(b"assert_point").elements) .collect_vec(); let to_next_phase_point_and_evals = output_evals; subset_point_and_evals[0] = wires_out_evals.into_iter().map(|p| (0, p)).collect(); diff --git a/gkr/src/verifier/phase1.rs b/gkr/src/verifier/phase1.rs index 3bc6c13a3..f39ff6d0f 100644 --- a/gkr/src/verifier/phase1.rs +++ b/gkr/src/verifier/phase1.rs @@ -21,12 +21,9 @@ impl IOPVerifierState { transcript: &mut Transcript, ) -> Result<(), GKRError> { let timer = start_timer!(|| "Verifier sumcheck phase 1 step 1"); - let alpha = transcript - .get_and_append_challenge(b"combine subset evals") - .elements; - let total_length = self.to_next_phase_point_and_evals.len() - + self.subset_point_and_evals[self.layer_id as usize].len() - + 1; + let alpha = transcript.get_and_append_challenge(b"combine subset evals").elements; + let total_length = + self.to_next_phase_point_and_evals.len() + self.subset_point_and_evals[self.layer_id as usize].len() + 1; let alpha_pows = { let mut alpha_pows = vec![E::ONE; total_length]; for i in 0..total_length.saturating_sub(1) { @@ -45,9 +42,7 @@ impl IOPVerifierState { }); sigma_1 += izip!( self.subset_point_and_evals[self.layer_id as usize].iter(), - alpha_pows - .iter() - .skip(self.to_next_phase_point_and_evals.len()) + alpha_pows.iter().skip(self.to_next_phase_point_and_evals.len()) ) .fold(E::ZERO, |acc, ((_, point_and_eval), alpha_pow)| { acc + point_and_eval.eval * alpha_pow @@ -76,25 +71,21 @@ impl IOPVerifierState { let f_value = step_msg.sumcheck_eval_values[0]; let g_value: E = chain![ - izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()).map( - |(point_and_eval, alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let eq_t = eq_eval( - &point_and_eval.point[point_lo_num_vars..], - &claim1_point[(claim1_point.len() - hi_num_vars)..], - ); - let eq_y = eq_eval( - &point_and_eval.point[..point_lo_num_vars], - &claim1_point[..point_lo_num_vars], - ); - eq_t * eq_y * alpha_pow - } - ), + izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()).map(|(point_and_eval, alpha_pow)| { + let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + let eq_t = eq_eval( + &point_and_eval.point[point_lo_num_vars..], + &claim1_point[(claim1_point.len() - hi_num_vars)..], + ); + let eq_y = eq_eval( + &point_and_eval.point[..point_lo_num_vars], + &claim1_point[..point_lo_num_vars], + ); + eq_t * eq_y * alpha_pow + }), izip!( self.subset_point_and_evals[self.layer_id as usize].iter(), - alpha_pows - .iter() - .skip(self.to_next_phase_point_and_evals.len()) + alpha_pows.iter().skip(self.to_next_phase_point_and_evals.len()) ) .map(|((new_layer_id, point_and_eval), alpha_pow)| { let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; diff --git a/gkr/src/verifier/phase1_output.rs b/gkr/src/verifier/phase1_output.rs index 50aa74d67..bf25f47da 100644 --- a/gkr/src/verifier/phase1_output.rs +++ b/gkr/src/verifier/phase1_output.rs @@ -22,12 +22,9 @@ impl IOPVerifierState { transcript: &mut Transcript, ) -> Result<(), GKRError> { let timer = start_timer!(|| "Verifier sumcheck phase 1 step 1"); - let alpha = transcript - .get_and_append_challenge(b"combine subset evals") - .elements; - let total_length = self.to_next_phase_point_and_evals.len() - + self.subset_point_and_evals[self.layer_id as usize].len() - + 1; + let alpha = transcript.get_and_append_challenge(b"combine subset evals").elements; + let total_length = + self.to_next_phase_point_and_evals.len() + self.subset_point_and_evals[self.layer_id as usize].len() + 1; let alpha_pows = { let mut alpha_pows = vec![E::ONE; total_length]; for i in 0..total_length.saturating_sub(1) { @@ -49,9 +46,7 @@ impl IOPVerifierState { }); sigma_1 += izip!( self.subset_point_and_evals[self.layer_id as usize].iter(), - alpha_pows - .iter() - .skip(self.to_next_phase_point_and_evals.len()) + alpha_pows.iter().skip(self.to_next_phase_point_and_evals.len()) ) .fold(E::ZERO, |acc, ((_, point_and_eval), alpha_pow)| { acc + point_and_eval.eval * alpha_pow @@ -79,18 +74,14 @@ impl IOPVerifierState { let claim1_point = claim_1.point.iter().map(|x| x.elements).collect_vec(); let eq_y_ry = build_eq_x_r_vec(&claim1_point); self.g1_values = chain![ - izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()).map( - |(point_and_eval, alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - eq_eval(&point_and_eval.point[..point_lo_num_vars], &claim1_point) * alpha_pow - } - ), + izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()).map(|(point_and_eval, alpha_pow)| { + let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + eq_eval(&point_and_eval.point[..point_lo_num_vars], &claim1_point) * alpha_pow + }), izip!( circuit.copy_to_wits_out.iter(), self.subset_point_and_evals[self.layer_id as usize].iter(), - alpha_pows - .iter() - .skip(self.to_next_phase_point_and_evals.len()) + alpha_pows.iter().skip(self.to_next_phase_point_and_evals.len()) ) .map(|(copy_to, (_, point_and_eval), alpha_pow)| { let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; @@ -118,8 +109,7 @@ impl IOPVerifierState { return Err(GKRError::VerifyError("output phase1 step1 failed")); } - self.to_next_step_point_and_eval = - PointAndEval::new(claim1_point, claim_1.expected_evaluation); + self.to_next_step_point_and_eval = PointAndEval::new(claim1_point, claim_1.expected_evaluation); Ok(()) } @@ -172,11 +162,7 @@ impl IOPVerifierState { } self.to_next_step_point_and_eval = PointAndEval::new( - [ - mem::take(&mut self.to_next_step_point_and_eval.point), - claim2_point, - ] - .concat(), + [mem::take(&mut self.to_next_step_point_and_eval.point), claim2_point].concat(), f2_value, ); self.subset_point_and_evals[self.layer_id as usize].clear(); diff --git a/gkr/src/verifier/phase2.rs b/gkr/src/verifier/phase2.rs index 095864930..d24ddeedf 100644 --- a/gkr/src/verifier/phase2.rs +++ b/gkr/src/verifier/phase2.rs @@ -35,11 +35,8 @@ impl IOPVerifierState { self.eq_y_ry = build_eq_x_r_vec(lo_point); // sigma = layers[i](rt || ry) - add_const(ry), - let sumcheck_sigma = self.to_next_step_point_and_eval.eval - - layer - .add_consts - .as_slice() - .eval(&self.eq_y_ry, &self.challenges); + let sumcheck_sigma = + self.to_next_step_point_and_eval.eval - layer.add_consts.as_slice().eval(&self.eq_y_ry, &self.challenges); // Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * g1'_j(s1 || x1) // f1(s1 || x1) = layers[i + 1](s1 || x1) @@ -72,30 +69,21 @@ impl IOPVerifierState { let g1_values_iter = chain![ received_g1_values.iter().cloned(), layer.paste_from.iter().map(|(_, paste_from)| { - hi_eq_eval - * paste_from - .as_slice() - .eval_col_first(&self.eq_y_ry, &self.eq_x1_rx1) + hi_eq_eval * paste_from.as_slice().eval_col_first(&self.eq_y_ry, &self.eq_x1_rx1) }) ]; - let got_value_1 = - izip!(f1_values.iter(), g1_values_iter).fold(E::ZERO, |acc, (&f1, g1)| acc + f1 * g1); + let got_value_1 = izip!(f1_values.iter(), g1_values_iter).fold(E::ZERO, |acc, (&f1, g1)| acc + f1 * g1); end_timer!(timer); if claim_1.expected_evaluation != got_value_1 { return Err(GKRError::VerifyError("phase2 step1 failed")); } - self.to_next_phase_point_and_evals = - vec![PointAndEval::new_from_ref(&claim1_point, &f1_values[0])]; - izip!(layer.paste_from.iter(), f1_values.iter().skip(1)).for_each( - |((&old_layer_id, _), &subset_value)| { - self.subset_point_and_evals[old_layer_id as usize].push(( - self.layer_id, - PointAndEval::new_from_ref(&claim1_point, &subset_value), - )); - }, - ); + self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref(&claim1_point, &f1_values[0])]; + izip!(layer.paste_from.iter(), f1_values.iter().skip(1)).for_each(|((&old_layer_id, _), &subset_value)| { + self.subset_point_and_evals[old_layer_id as usize] + .push((self.layer_id, PointAndEval::new_from_ref(&claim1_point, &subset_value))); + }); self.to_next_step_point_and_eval = PointAndEval::new(claim1_point, received_g1_values[0]); Ok(()) @@ -149,12 +137,10 @@ impl IOPVerifierState { &self.out_point[lo_out_num_vars..], &self.to_next_phase_point_and_evals[0].point[lo_in_num_vars..], &claim2_point[lo_in_num_vars..], - ) * layer.mul2s.as_slice().eval( - &self.eq_y_ry, - &self.eq_x1_rx1, - &self.eq_x2_rx2, - &self.challenges, - ) + ) * layer + .mul2s + .as_slice() + .eval(&self.eq_y_ry, &self.eq_x1_rx1, &self.eq_x2_rx2, &self.challenges) } else { step_msg.sumcheck_eval_values[1] }; @@ -190,12 +176,10 @@ impl IOPVerifierState { &self.out_point[lo_out_num_vars..], &self.to_next_phase_point_and_evals[0].point[lo_in_num_vars..], &self.to_next_phase_point_and_evals[1].point[lo_in_num_vars..], - ) * layer.mul2s.as_slice().eval( - &self.eq_y_ry, - &self.eq_x1_rx1, - &self.eq_x2_rx2, - &self.challenges, - ); + ) * layer + .mul2s + .as_slice() + .eval(&self.eq_y_ry, &self.eq_x1_rx1, &self.eq_x2_rx2, &self.challenges); // Sumcheck 3 sigma = \sum_{s3 || x3} f3(s3 || x3) * g3(s3 || x3) // f3(s3 || x3) = layers[i + 1](s3 || x3) diff --git a/gkr/src/verifier/phase2_input.rs b/gkr/src/verifier/phase2_input.rs index 9f63b1913..8eb2b2386 100644 --- a/gkr/src/verifier/phase2_input.rs +++ b/gkr/src/verifier/phase2_input.rs @@ -51,10 +51,7 @@ impl IOPVerifierState { let mut sumcheck_sigma = self.to_next_step_point_and_eval.eval - g_value_const; if !layer.add_consts.is_empty() { - sumcheck_sigma -= layer - .add_consts - .as_slice() - .eval(&self.eq_y_ry, &self.challenges); + sumcheck_sigma -= layer.add_consts.as_slice().eval(&self.eq_y_ry, &self.challenges); } if lo_in_num_vars.is_none() { @@ -81,16 +78,9 @@ impl IOPVerifierState { self.eq_x1_rx1 = build_eq_x_r_vec(&claim_point); let g_values_iter = chain![ circuit.paste_from_wits_in.iter().cloned(), - circuit - .paste_from_counter_in - .iter() - .map(|(_, (l, r))| (*l, *r)) + circuit.paste_from_counter_in.iter().map(|(_, (l, r))| (*l, *r)) ] - .map(|(l, r)| { - (l..r) - .map(|i| self.eq_y_ry[i] * self.eq_x1_rx1[i - l]) - .sum::() - }); + .map(|(l, r)| (l..r).map(|i| self.eq_y_ry[i] * self.eq_x1_rx1[i - l]).sum::()); // TODO: Double check here. let f_counter_values = circuit @@ -99,17 +89,11 @@ impl IOPVerifierState { .map(|(num_vars, _)| { let point = [&claim_point[..*num_vars], hi_point].concat(); counter_eval(num_vars + hi_num_vars, &point) - * claim_point[*num_vars..] - .iter() - .map(|x| E::ONE - *x) - .product::() + * claim_point[*num_vars..].iter().map(|x| E::ONE - *x).product::() }) .collect_vec(); let got_value = izip!( - chain![ - step_msg.sumcheck_eval_values.iter(), - f_counter_values.iter() - ], + chain![step_msg.sumcheck_eval_values.iter(), f_counter_values.iter()], g_values_iter ) .map(|(f, g)| *f * g) @@ -132,8 +116,7 @@ impl IOPVerifierState { PointAndEval::new_from_ref(&point, &wit_in_eval) }) .collect_vec(); - self.to_next_step_point_and_eval = - PointAndEval::new([&claim_point, hi_point].concat(), E::ZERO); + self.to_next_step_point_and_eval = PointAndEval::new([&claim_point, hi_point].concat(), E::ZERO); end_timer!(timer); if claim.expected_evaluation != got_value { diff --git a/gkr/src/verifier/phase2_linear.rs b/gkr/src/verifier/phase2_linear.rs index 6ff507937..b0a1631f1 100644 --- a/gkr/src/verifier/phase2_linear.rs +++ b/gkr/src/verifier/phase2_linear.rs @@ -32,11 +32,8 @@ impl IOPVerifierState { self.eq_y_ry = build_eq_x_r_vec(lo_point); // sigma = layers[i](rt || ry) - add_const(ry), - let sumcheck_sigma = self.to_next_step_point_and_eval.eval - - layer - .add_consts - .as_slice() - .eval(&self.eq_y_ry, &self.challenges); + let sumcheck_sigma = + self.to_next_step_point_and_eval.eval - layer.add_consts.as_slice().eval(&self.eq_y_ry, &self.challenges); // Sumcheck 1: sigma = \sum_{x1} f1(x1) * g1(x1) + \sum_j f1'_j(x1) * g1'_j(x1) // sigma = layers[i](rt || ry) - add_const(ry), @@ -58,21 +55,20 @@ impl IOPVerifierState { self.eq_x1_rx1 = build_eq_x_r_vec(&claim1_point[..lo_in_num_vars]); let g1_values_iter = chain![ - iter::once(layer.adds.as_slice().eval( - &self.eq_y_ry, - &self.eq_x1_rx1, - &self.challenges - )), - layer.paste_from.iter().map(|(_, paste_from)| { - paste_from + iter::once( + layer + .adds .as_slice() - .eval_col_first(&self.eq_y_ry, &self.eq_x1_rx1) - }) + .eval(&self.eq_y_ry, &self.eq_x1_rx1, &self.challenges) + ), + layer + .paste_from + .iter() + .map(|(_, paste_from)| { paste_from.as_slice().eval_col_first(&self.eq_y_ry, &self.eq_x1_rx1) }) ]; let f1_values = &step_msg.sumcheck_eval_values; - let got_value_1 = - izip!(f1_values.iter(), g1_values_iter).fold(E::ZERO, |acc, (&f1, g1)| acc + f1 * g1); + let got_value_1 = izip!(f1_values.iter(), g1_values_iter).fold(E::ZERO, |acc, (&f1, g1)| acc + f1 * g1); end_timer!(timer); if claim_1.expected_evaluation != got_value_1 { @@ -80,16 +76,11 @@ impl IOPVerifierState { } let new_point = [&claim1_point, &self.out_point[lo_out_num_vars..]].concat(); - self.to_next_phase_point_and_evals = - vec![PointAndEval::new_from_ref(&new_point, &f1_values[0])]; - izip!(layer.paste_from.iter(), f1_values.iter().skip(1)).for_each( - |((&old_layer_id, _), &subset_value)| { - self.subset_point_and_evals[old_layer_id as usize].push(( - self.layer_id, - PointAndEval::new_from_ref(&new_point, &subset_value), - )); - }, - ); + self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref(&new_point, &f1_values[0])]; + izip!(layer.paste_from.iter(), f1_values.iter().skip(1)).for_each(|((&old_layer_id, _), &subset_value)| { + self.subset_point_and_evals[old_layer_id as usize] + .push((self.layer_id, PointAndEval::new_from_ref(&new_point, &subset_value))); + }); self.to_next_step_point_and_eval = self.to_next_phase_point_and_evals[0].clone(); Ok(()) diff --git a/mpcs/src/prover.rs b/mpcs/src/prover.rs index b26d9e757..19671a504 100644 --- a/mpcs/src/prover.rs +++ b/mpcs/src/prover.rs @@ -5,17 +5,11 @@ use crate::structs::{Commitment, PCSProof, PCSProverState}; #[allow(unused)] impl PCSProverState { - pub fn prove( - polys: &[(Commitment, ArcDenseMultilinearExtension)], - eval_point: &[E], - ) -> PCSProof { + pub fn prove(polys: &[(Commitment, ArcDenseMultilinearExtension)], eval_point: &[E]) -> PCSProof { todo!() } - pub(crate) fn prover_init( - polys: &[(Commitment, ArcDenseMultilinearExtension)], - eval_point: &[E], - ) -> Self { + pub(crate) fn prover_init(polys: &[(Commitment, ArcDenseMultilinearExtension)], eval_point: &[E]) -> Self { todo!() } } diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 4b8a9baec..618a887a7 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -94,11 +94,7 @@ impl DenseMultilinearExtension { /// Returns an error if the MLE length does not match the point. pub fn evaluate(&self, point: &[E]) -> E { // TODO: return error. - assert_eq!( - self.num_vars, - point.len(), - "MLE size does not match the point" - ); + assert_eq!(self.num_vars, point.len(), "MLE size does not match the point"); let mle = self.fix_variables_parallel(point); op_mle!(mle, |f| f[0], |v| E::from(v)) } @@ -107,10 +103,7 @@ impl DenseMultilinearExtension { /// `partial_point.len()` variables at `partial_point`. pub fn fix_variables(&self, partial_point: &[E]) -> Self { // TODO: return error. - assert!( - partial_point.len() <= self.num_vars, - "invalid size of partial point" - ); + assert!(partial_point.len() <= self.num_vars, "invalid size of partial point"); let mut poly = Cow::Borrowed(self); // evaluate single variable of partial point from left to right @@ -158,8 +151,7 @@ impl DenseMultilinearExtension { } FieldType::Ext(evaluations) => { (0..evaluations.len()).step_by(2).for_each(|b| { - evaluations[b >> 1] = - evaluations[b] + (evaluations[b + 1] - evaluations[b]) * point + evaluations[b >> 1] = evaluations[b] + (evaluations[b + 1] - evaluations[b]) * point }); } FieldType::Unreachable => unreachable!(), @@ -180,10 +172,7 @@ impl DenseMultilinearExtension { /// `partial_point.len()` variables at `partial_point` from high position pub fn fix_high_variables(&self, partial_point: &[E]) -> Self { // TODO: return error. - assert!( - partial_point.len() <= self.num_vars, - "invalid size of partial point" - ); + assert!(partial_point.len() <= self.num_vars, "invalid size of partial point"); let current_eval_size = self.evaluations.len(); let mut poly = Cow::Borrowed(self); // `Cow` type here to skip first evaluation vector copy @@ -213,10 +202,7 @@ impl DenseMultilinearExtension { /// `partial_point.len()` variables at `partial_point` from high position in place pub fn fix_high_variables_in_place(&mut self, partial_point: &[E]) { // TODO: return error. - assert!( - partial_point.len() <= self.num_vars, - "invalid size of partial point" - ); + assert!(partial_point.len() <= self.num_vars, "invalid size of partial point"); let nv = self.num_vars; let mut current_eval_size = self.evaluations.len(); for point in partial_point.iter().rev() { @@ -256,9 +242,7 @@ impl DenseMultilinearExtension { /// Generate a random evaluation of a multilinear poly pub fn random(nv: usize, mut rng: &mut impl RngCore) -> Self { - let eval = (0..1 << nv) - .map(|_| E::BaseField::random(&mut rng)) - .collect(); + let eval = (0..1 << nv).map(|_| E::BaseField::random(&mut rng)).collect(); DenseMultilinearExtension::from_evaluations_vec(nv, eval) } @@ -400,10 +384,7 @@ impl DenseMultilinearExtension { /// `partial_point.len()` variables at `partial_point`. pub fn fix_variables_parallel(&self, partial_point: &[E]) -> Self { // TODO: return error. - assert!( - partial_point.len() <= self.num_vars, - "invalid size of partial point" - ); + assert!(partial_point.len() <= self.num_vars, "invalid size of partial point"); let mut poly = Cow::Borrowed(self); // evaluate single variable of partial point from left to right diff --git a/multilinear_extensions/src/test.rs b/multilinear_extensions/src/test.rs index 10bb599f7..73a38eb0d 100644 --- a/multilinear_extensions/src/test.rs +++ b/multilinear_extensions/src/test.rs @@ -92,8 +92,7 @@ fn test_fix_high_variables() { let result1 = poly.fix_high_variables(&partial_point[1..]); assert_eq!(result1, expected1); - let expected2 = - DenseMultilinearExtension::from_evaluations_ext_vec(1, vec![-E::from(23), E::from(139)]); + let expected2 = DenseMultilinearExtension::from_evaluations_ext_vec(1, vec![-E::from(23), E::from(139)]); let result2 = poly.fix_high_variables(&partial_point); assert_eq!(result2, expected2); } @@ -126,8 +125,7 @@ fn build_eq_x_r_for_test(r: &[E]) -> ArcDenseMultilinearExten let mut current_eval = E::ONE; let bit_sequence = bit_decompose(i, num_var); - for (&bit, (ri, one_minus_ri)) in bit_sequence.iter().zip(r.iter().zip(one_minus_r.iter())) - { + for (&bit, (ri, one_minus_ri)) in bit_sequence.iter().zip(r.iter().zip(one_minus_r.iter())) { current_eval *= if bit { *ri } else { *one_minus_ri }; } eval.push(current_eval); diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index eb3edac92..957f793d0 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -113,11 +113,7 @@ impl VirtualPolynomial { /// /// The MLEs will be multiplied together, and then multiplied by the scalar /// `coefficient`. - pub fn add_mle_list( - &mut self, - mle_list: Vec>, - coefficient: E::BaseField, - ) { + pub fn add_mle_list(&mut self, mle_list: Vec>, coefficient: E::BaseField) { let mle_list: Vec> = mle_list.into_iter().collect(); let mut indexed_product = Vec::with_capacity(mle_list.len()); @@ -211,11 +207,7 @@ impl VirtualPolynomial { point.len() ); - let evals: Vec = self - .flattened_ml_extensions - .iter() - .map(|x| x.evaluate(point)) - .collect(); + let evals: Vec = self.flattened_ml_extensions.iter().map(|x| x.evaluate(point)).collect(); let res = self .products @@ -239,10 +231,8 @@ impl VirtualPolynomial { let mut sum = E::ZERO; let mut poly = VirtualPolynomial::new(nv); for _ in 0..num_products { - let num_multiplicands = - rng.gen_range(num_multiplicands_range.0..num_multiplicands_range.1); - let (product, product_sum) = - DenseMultilinearExtension::random_mle_list(nv, num_multiplicands, rng); + let num_multiplicands = rng.gen_range(num_multiplicands_range.0..num_multiplicands_range.1); + let (product, product_sum) = DenseMultilinearExtension::random_mle_list(nv, num_multiplicands, rng); let coefficient = E::BaseField::random(&mut rng); poly.add_mle_list(product, coefficient); sum += product_sum * coefficient; @@ -262,10 +252,8 @@ impl VirtualPolynomial { ) -> Self { let mut poly = VirtualPolynomial::new(nv); for _ in 0..num_products { - let num_multiplicands = - rng.gen_range(num_multiplicands_range.0..num_multiplicands_range.1); - let product = - DenseMultilinearExtension::random_zero_mle_list(nv, num_multiplicands, rng); + let num_multiplicands = rng.gen_range(num_multiplicands_range.0..num_multiplicands_range.1); + let product = DenseMultilinearExtension::random_zero_mle_list(nv, num_multiplicands, rng); let coefficient = E::BaseField::random(&mut rng); poly.add_mle_list(product, coefficient); } diff --git a/rustfmt.toml b/rustfmt.toml index 835c6b277..af77e46a5 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -3,7 +3,7 @@ edition = "2021" wrap_comments = false comment_width = 300 imports_granularity = "Crate" -max_width = 100 +max_width = 120 newline_style = "Unix" normalize_comments = true reorder_imports = true diff --git a/simple-frontend/examples/poseidon.rs b/simple-frontend/examples/poseidon.rs index b06b73725..09dd2aec0 100644 --- a/simple-frontend/examples/poseidon.rs +++ b/simple-frontend/examples/poseidon.rs @@ -7,9 +7,7 @@ use mock_constant::{poseidon_c, poseidon_m, poseidon_p, poseidon_s}; use simple_frontend::structs::{CellId, CircuitBuilder}; // round constant -const N_ROUNDS_P: [usize; 16] = [ - 56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68, -]; +const N_ROUNDS_P: [usize; 16] = [56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68]; // template Sigma() { // signal input in; @@ -185,15 +183,9 @@ fn poseidon_ex( let c = poseidon_c::(t); let s = poseidon_s::(t); let m = poseidon_m::(t); - let m_slices = m - .iter() - .map(|row| row.as_slice()) - .collect::>(); + let m_slices = m.iter().map(|row| row.as_slice()).collect::>(); let p = poseidon_p::(t); - let p_slices = p - .iter() - .map(|row| row.as_slice()) - .collect::>(); + let p_slices = p.iter().map(|row| row.as_slice()).collect::>(); // component ark[nRoundsF]; // component sigmaF[nRoundsF][t]; @@ -257,11 +249,7 @@ fn poseidon_ex( for r in 0..(n_rounds_f / 2 - 1) { for j in 0..t { - sigma_f_in[r][j] = if r == 0 { - ark_out[0][j] - } else { - mix_out[r - 1][j] - }; + sigma_f_in[r][j] = if r == 0 { ark_out[0][j] } else { mix_out[r - 1][j] }; sigma_f_out[r][j] = sigma(circuit_builder, sigma_f_in[r][j]); } @@ -281,8 +269,7 @@ fn poseidon_ex( for j in 0..t { sigma_f_in[n_rounds_f / 2 - 1][j] = mix_out[n_rounds_f / 2 - 2][j]; - sigma_f_out[n_rounds_f / 2 - 1][j] = - sigma(circuit_builder, sigma_f_in[n_rounds_f / 2 - 1][j]); + sigma_f_out[n_rounds_f / 2 - 1][j] = sigma(circuit_builder, sigma_f_in[n_rounds_f / 2 - 1][j]); } // ark[nRoundsF\2] = Ark(t, C, (nRoundsF\2)*t ); @@ -293,12 +280,7 @@ fn poseidon_ex( for j in 0..t { ark_in[n_rounds_f / 2].push(sigma_f_out[n_rounds_f / 2 - 1][j]); } - ark_out[n_rounds_f / 2] = ark( - circuit_builder, - &ark_in[n_rounds_f / 2], - &c, - n_rounds_f / 2 * t, - ); + ark_out[n_rounds_f / 2] = ark(circuit_builder, &ark_in[n_rounds_f / 2], &c, n_rounds_f / 2 * t); // mix[nRoundsF\2-1] = Mix(t,P); // for (var j=0; j( } else { mix_out[n_rounds_f / 2 + r - 1][j] }; - sigma_f_out[n_rounds_f / 2 + r][j] = - sigma(circuit_builder, sigma_f_in[n_rounds_f / 2 + r][j]); + sigma_f_out[n_rounds_f / 2 + r][j] = sigma(circuit_builder, sigma_f_in[n_rounds_f / 2 + r][j]); } for j in 0..t { diff --git a/simple-frontend/src/circuit_builder.rs b/simple-frontend/src/circuit_builder.rs index 22d5ac6b5..cb06de1a9 100644 --- a/simple-frontend/src/circuit_builder.rs +++ b/simple-frontend/src/circuit_builder.rs @@ -147,11 +147,7 @@ mod tests { let mut circuit_builder = CircuitBuilder::::new(); let (_, input_cells) = circuit_builder.create_witness_in(4); let zero_cells = circuit_builder.create_cells(2); - let leaves = input_cells - .iter() - .chain(zero_cells.iter()) - .cloned() - .collect_vec(); + let leaves = input_cells.iter().chain(zero_cells.iter()).cloned().collect_vec(); let inners = circuit_builder.create_cells(2); circuit_builder.mul3(inners[0], leaves[0], leaves[1], leaves[2], Goldilocks::ONE); circuit_builder.mul3(inners[1], leaves[3], leaves[4], leaves[5], Goldilocks::ONE); @@ -174,11 +170,7 @@ mod tests { circuit_builder.add_const(const_cells[0], Goldilocks::ONE); circuit_builder.add_const(const_cells[1], Goldilocks::ONE); - let leaves = input_cells - .iter() - .chain(const_cells.iter()) - .cloned() - .collect_vec(); + let leaves = input_cells.iter().chain(const_cells.iter()).cloned().collect_vec(); let inners = circuit_builder.create_cells(2); circuit_builder.mul3(inners[0], leaves[0], leaves[1], leaves[2], Goldilocks::ONE); circuit_builder.mul3(inners[1], leaves[3], leaves[4], leaves[5], Goldilocks::ONE); diff --git a/simple-frontend/src/circuit_builder/base_opt.rs b/simple-frontend/src/circuit_builder/base_opt.rs index 2970acf5a..eb10c34c7 100644 --- a/simple-frontend/src/circuit_builder/base_opt.rs +++ b/simple-frontend/src/circuit_builder/base_opt.rs @@ -2,8 +2,7 @@ use ff::Field; use ff_ext::ExtensionField; use crate::structs::{ - Cell, CellId, CellType, CircuitBuilder, ConstantType, GateType, InType, MixedCell, OutType, - WitnessId, + Cell, CellId, CellType, CircuitBuilder, ConstantType, GateType, InType, MixedCell, OutType, WitnessId, }; impl CircuitBuilder { @@ -26,10 +25,7 @@ impl CircuitBuilder { pub fn create_witness_in(&mut self, num: usize) -> (WitnessId, Vec) { let cell = self.create_cells(num); - self.mark_cells( - CellType::In(InType::Witness(self.n_witness_in as WitnessId)), - &cell, - ); + self.mark_cells(CellType::In(InType::Witness(self.n_witness_in as WitnessId)), &cell); self.n_witness_in += 1; ((self.n_witness_in - 1) as WitnessId, cell) } @@ -51,19 +47,13 @@ impl CircuitBuilder { pub fn create_witness_out(&mut self, num: usize) -> (WitnessId, Vec) { let cell = self.create_cells(num); - self.mark_cells( - CellType::Out(OutType::Witness(self.n_witness_out as WitnessId)), - &cell, - ); + self.mark_cells(CellType::Out(OutType::Witness(self.n_witness_out as WitnessId)), &cell); self.n_witness_out += 1; ((self.n_witness_out - 1) as WitnessId, cell) } pub fn create_witness_out_from_cells(&mut self, cells: &[CellId]) -> WitnessId { - self.mark_cells( - CellType::Out(OutType::Witness(self.n_witness_out as WitnessId)), - &cells, - ); + self.mark_cells(CellType::Out(OutType::Witness(self.n_witness_out as WitnessId)), &cells); self.n_witness_out += 1; (self.n_witness_out - 1) as WitnessId } @@ -103,25 +93,12 @@ impl CircuitBuilder { self.mul2_internal(out, in_0, in_1, ConstantType::Field(scalar)); } - pub(crate) fn mul2_internal( - &mut self, - out: CellId, - in_0: CellId, - in_1: CellId, - scalar: ConstantType, - ) { + pub(crate) fn mul2_internal(&mut self, out: CellId, in_0: CellId, in_1: CellId, scalar: ConstantType) { let out_cell = &mut self.cells[out]; out_cell.gates.push(GateType::mul2(in_0, in_1, scalar)); } - pub fn mul3( - &mut self, - out: CellId, - in_0: CellId, - in_1: CellId, - in_2: CellId, - scalar: Ext::BaseField, - ) { + pub fn mul3(&mut self, out: CellId, in_0: CellId, in_1: CellId, in_2: CellId, scalar: Ext::BaseField) { if scalar == Ext::BaseField::ZERO { return; } @@ -137,9 +114,7 @@ impl CircuitBuilder { scalar: ConstantType, ) { let out_cell = &mut self.cells[out]; - out_cell - .gates - .push(GateType::mul3(in_0, in_1, in_2, scalar)); + out_cell.gates.push(GateType::mul3(in_0, in_1, in_2, scalar)); } pub fn assert_const(&mut self, out: CellId, constant: i64) { @@ -180,13 +155,7 @@ impl CircuitBuilder { self.add(out, in_0, Ext::BaseField::ONE); } - pub fn sel_mixed( - &mut self, - out: CellId, - in_0: MixedCell, - in_1: MixedCell, - cond: CellId, - ) { + pub fn sel_mixed(&mut self, out: CellId, in_0: MixedCell, in_1: MixedCell, cond: CellId) { // (1 - cond) * in_0 + cond * in_1 = (in_1 - in_0) * cond + in_0 match (in_0, in_1) { (MixedCell::Constant(in_0), MixedCell::Constant(in_1)) => { diff --git a/simple-frontend/src/circuit_builder/derives.rs b/simple-frontend/src/circuit_builder/derives.rs index a15170d0c..de0da745e 100644 --- a/simple-frontend/src/circuit_builder/derives.rs +++ b/simple-frontend/src/circuit_builder/derives.rs @@ -20,11 +20,7 @@ macro_rules! rlc_base_term { }; ($builder:ident, $n_ext:expr, $out:expr, $in_0:expr; $c:expr, $scalar:expr) => { for j in 0..$n_ext { - $builder.add_internal( - $out[j], - $in_0, - ConstantType::ChallengeScaled($c, j, $scalar), - ); + $builder.add_internal($out[j], $in_0, ConstantType::ChallengeScaled($c, j, $scalar)); } }; } diff --git a/simple-frontend/src/circuit_builder/ext_opt.rs b/simple-frontend/src/circuit_builder/ext_opt.rs index f51012123..dec837be1 100644 --- a/simple-frontend/src/circuit_builder/ext_opt.rs +++ b/simple-frontend/src/circuit_builder/ext_opt.rs @@ -6,8 +6,8 @@ use std::marker::PhantomData; use crate::{ rlc_base_term, rlc_const_term, structs::{ - CellId, CellType, ChallengeConst, ChallengeId, CircuitBuilder, ConstantType, ExtCellId, - InType, MixedCell, OutType, WitnessId, + CellId, CellType, ChallengeConst, ChallengeId, CircuitBuilder, ConstantType, ExtCellId, InType, MixedCell, + OutType, WitnessId, }, }; @@ -70,10 +70,7 @@ impl CircuitBuilder { pub fn create_ext_witness_in(&mut self, num: usize) -> (WitnessId, Vec>) { let cells = self.create_cells(num * ::DEGREE); - self.mark_cells( - CellType::In(InType::Witness(self.n_witness_in as WitnessId)), - &cells, - ); + self.mark_cells(CellType::In(InType::Witness(self.n_witness_in as WitnessId)), &cells); self.n_witness_in += 1; ( (self.n_witness_in - 1) as WitnessId, @@ -89,10 +86,7 @@ impl CircuitBuilder { let cells = self.create_ext_cells(num); cells.iter().for_each(|ext_cell| { // first base field is the constant - self.mark_cells( - CellType::In(InType::Constant(constant)), - &[ext_cell.cells[0]], - ); + self.mark_cells(CellType::In(InType::Constant(constant)), &[ext_cell.cells[0]]); // the rest fields are 0s self.mark_cells(CellType::In(InType::Constant(0)), &ext_cell.cells[1..]); }); @@ -101,10 +95,7 @@ impl CircuitBuilder { pub fn create_ext_witness_out(&mut self, num: usize) -> (WitnessId, Vec>) { let cells = self.create_cells(num * ::DEGREE); - self.mark_cells( - CellType::Out(OutType::Witness(self.n_witness_out as WitnessId)), - &cells, - ); + self.mark_cells(CellType::Out(OutType::Witness(self.n_witness_out as WitnessId)), &cells); self.n_witness_out += 1; ( (self.n_witness_out - 1) as WitnessId, @@ -144,14 +135,9 @@ impl CircuitBuilder { out.cells .iter() - .zip_eq( - in_0.cells.iter().zip_eq( - [*in_1].iter().chain( - std::iter::repeat(&MixedCell::Constant(Ext::BaseField::ZERO)) - .take(::DEGREE - 1), - ), - ), - ) + .zip_eq(in_0.cells.iter().zip_eq([*in_1].iter().chain( + std::iter::repeat(&MixedCell::Constant(Ext::BaseField::ZERO)).take(::DEGREE - 1), + ))) .for_each(|(&out, (&in0, &in1))| self.sel_mixed(out, in0.into(), in1, cond)); } @@ -170,25 +156,14 @@ impl CircuitBuilder { out.cells .iter() - .zip_eq( - in_1.cells.iter().zip_eq( - [*in_0].iter().chain( - std::iter::repeat(&MixedCell::Constant(Ext::BaseField::ZERO)) - .take(::DEGREE - 1), - ), - ), - ) + .zip_eq(in_1.cells.iter().zip_eq([*in_0].iter().chain( + std::iter::repeat(&MixedCell::Constant(Ext::BaseField::ZERO)).take(::DEGREE - 1), + ))) .for_each(|(&out, (&in1, &in0))| self.sel_mixed(out, in0, in1.into(), cond)); } /// Base on the condition, select extension cells in_0 or in_1 - pub fn sel_ext( - &mut self, - out: &ExtCellId, - in_0: &ExtCellId, - in_1: &ExtCellId, - cond: CellId, - ) { + pub fn sel_ext(&mut self, out: &ExtCellId, in_0: &ExtCellId, in_1: &ExtCellId, cond: CellId) { // we only need to check one degree since the rest are // enforced by zip_eq assert_eq!(out.degree(), ::DEGREE); @@ -215,13 +190,7 @@ impl CircuitBuilder { /// Constrain /// - out[i] += in_0[i] * in_1 * scalar for i in 0..DEGREE-1 - pub fn mul_ext_base( - &mut self, - out: &ExtCellId, - in_0: &ExtCellId, - in_1: CellId, - scalar: Ext::BaseField, - ) { + pub fn mul_ext_base(&mut self, out: &ExtCellId, in_0: &ExtCellId, in_1: CellId, scalar: Ext::BaseField) { assert_eq!(out.degree(), ::DEGREE); out.cells .iter() @@ -250,12 +219,7 @@ impl CircuitBuilder { } /// Constrain out += in_0 * c - pub fn add_product_of_ext_and_challenge( - &mut self, - out: &ExtCellId, - in_0: &ExtCellId, - c: ChallengeConst, - ) { + pub fn add_product_of_ext_and_challenge(&mut self, out: &ExtCellId, in_0: &ExtCellId, c: ChallengeConst) { assert_eq!(out.degree(), ::DEGREE); assert_eq!(in_0.degree(), ::DEGREE); match ::DEGREE { @@ -289,12 +253,7 @@ impl CircuitBuilder { /// Compute the random linear combination of `in_array` by challenge. /// out = \sum_{i = 0}^{in_array.len()} challenge^i * in_array[i] + challenge^{in_array.len()}. - pub fn rlc_ext( - &mut self, - out: &ExtCellId, - in_array: &[ExtCellId], - challenge: ChallengeId, - ) { + pub fn rlc_ext(&mut self, out: &ExtCellId, in_array: &[ExtCellId], challenge: ChallengeId) { assert_eq!(out.degree(), ::DEGREE); match ::DEGREE { 2 => self.rlc_ext_2(out, in_array, challenge), @@ -305,12 +264,7 @@ impl CircuitBuilder { /// Compute the random linear combination of `in_array` with mixed types by challenge. /// out = \sum_{i = 0}^{in_array.len()} challenge^i * (\sum_j in_array[i][j]) + challenge^{in_array.len()}. - pub fn rlc_mixed( - &mut self, - out: &ExtCellId, - in_array: &[MixedCell], - challenge: ChallengeId, - ) { + pub fn rlc_mixed(&mut self, out: &ExtCellId, in_array: &[MixedCell], challenge: ChallengeId) { assert_eq!(out.degree(), ::DEGREE); for (i, item) in in_array.iter().enumerate() { let c: ChallengeConst = ChallengeConst { @@ -346,13 +300,7 @@ impl CircuitBuilder { /// let a2b2 = a.0[1] * b.0[1]; /// let c1 = a1b1 + Goldilocks(7) * a2b2; /// let c2 = a2b1 + a1b2; - fn mul2_degree_2_ext_internal( - &mut self, - out: &[CellId], - in_0: &[CellId], - in_1: &[CellId], - scalar: Ext::BaseField, - ) { + fn mul2_degree_2_ext_internal(&mut self, out: &[CellId], in_0: &[CellId], in_1: &[CellId], scalar: Ext::BaseField) { let a0b0 = self.create_cell(); self.mul2(a0b0, in_0[0], in_1[0], Ext::BaseField::ONE); let a0b1 = self.create_cell(); @@ -384,12 +332,7 @@ impl CircuitBuilder { } /// Random linear combinations for extension cells with degree = 2 - fn rlc_ext_2( - &mut self, - out: &ExtCellId, - in_array: &[ExtCellId], - challenge: ChallengeId, - ) { + fn rlc_ext_2(&mut self, out: &ExtCellId, in_array: &[ExtCellId], challenge: ChallengeId) { assert_eq!(out.degree(), ::DEGREE); for (i, item) in in_array.iter().enumerate() { let c = ChallengeConst { @@ -406,12 +349,7 @@ impl CircuitBuilder { } /// Random linear combinations for extension cells with degree = 3 - fn rlc_ext_3( - &mut self, - out: &ExtCellId, - in_array: &[ExtCellId], - challenge: ChallengeId, - ) { + fn rlc_ext_3(&mut self, out: &ExtCellId, in_array: &[ExtCellId], challenge: ChallengeId) { assert_eq!(out.degree(), 3); for (i, item) in in_array.iter().enumerate() { let c = ChallengeConst { @@ -440,13 +378,7 @@ impl CircuitBuilder { /// let c2 = a2b1 + a1b2 + a2b3 + a3b2 + a3b3; /// let c3 = a3b1 + a2b2 + a1b3 + a3b3; /// GoldilocksExt3([c1, c2, c3]) - fn mul2_degree_3_ext_internal( - &mut self, - out: &[CellId], - in_0: &[CellId], - in_1: &[CellId], - scalar: Ext::BaseField, - ) { + fn mul2_degree_3_ext_internal(&mut self, out: &[CellId], in_0: &[CellId], in_1: &[CellId], scalar: Ext::BaseField) { let a0b0 = self.create_cell(); self.mul2(a0b0, in_0[0], in_1[0], Ext::BaseField::ONE); let a0b1 = self.create_cell(); diff --git a/singer-pro/examples/simple.rs b/singer-pro/examples/simple.rs index 58fa4cdb4..2b70b8dc7 100644 --- a/singer-pro/examples/simple.rs +++ b/singer-pro/examples/simple.rs @@ -10,16 +10,12 @@ use transcript::Transcript; fn main() { let chip_challenges = ChipChallenges::default(); - let circuit_builder = SingerInstCircuitBuilder::::new(chip_challenges) - .expect("circuit builder failed"); + let circuit_builder = + SingerInstCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let bytecode = vec![vec![0x60 as u8, 0x01, 0x50]]; - let singer_builder = SingerGraphBuilder::::new( - circuit_builder.clone(), - &bytecode, - chip_challenges, - ) - .expect("graph builder failed"); + let singer_builder = SingerGraphBuilder::::new(circuit_builder.clone(), &bytecode, chip_challenges) + .expect("graph builder failed"); let mut prover_transcript = Transcript::new(b"Singer pro"); @@ -52,9 +48,8 @@ fn main() { // 4. Verify. let mut verifier_transcript = Transcript::new(b"Singer pro"); - let singer_builder = - SingerGraphBuilder::::new(circuit_builder, &bytecode, chip_challenges) - .expect("graph builder failed"); + let singer_builder = SingerGraphBuilder::::new(circuit_builder, &bytecode, chip_challenges) + .expect("graph builder failed"); let circuit = singer_builder .construct_graph(&singer_aux_info) .expect("construct failed"); diff --git a/singer-pro/src/basic_block.rs b/singer-pro/src/basic_block.rs index 5842eec22..dbfaf5cec 100644 --- a/singer-pro/src/basic_block.rs +++ b/singer-pro/src/basic_block.rs @@ -9,9 +9,7 @@ use crate::{ basic_block::bb_ret::{BBReturnRestMemLoad, BBReturnRestMemStore}, component::{AccessoryCircuit, BBFinalCircuit, BBStartCircuit}, error::ZKVMError, - instructions::{ - construct_inst_graph, construct_inst_graph_and_witness, SingerInstCircuitBuilder, - }, + instructions::{construct_inst_graph, construct_inst_graph_and_witness, SingerInstCircuitBuilder}, BasicBlockWiresIn, SingerParams, }; @@ -103,10 +101,7 @@ impl SingerBasicBlockBuilder { } pub(crate) fn basic_block_bytecode(&self) -> Vec> { - self.basic_blocks - .iter() - .map(|bb| bb.bytecode.clone()) - .collect() + self.basic_blocks.iter().map(|bb| bb.bytecode.clone()).collect() } } @@ -129,11 +124,7 @@ pub struct BasicBlock { } impl BasicBlock { - pub(crate) fn new( - bytecode: &[u8], - pc_start: u64, - challenges: ChipChallenges, - ) -> Result { + pub(crate) fn new(bytecode: &[u8], pc_start: u64, challenges: ChipChallenges) -> Result { let mut stack_top = 0 as i64; let mut pc = pc_start; let mut stack_offsets = HashSet::new(); @@ -170,8 +161,7 @@ impl BasicBlock { stack_offsets.sort(); let bb_start_stack_top_offsets = stack_offsets[0..=lower_bound(&stack_offsets, 0)].to_vec(); - let bb_final_stack_top_offsets = - stack_offsets[0..=lower_bound(&stack_offsets, stack_top)].to_vec(); + let bb_final_stack_top_offsets = stack_offsets[0..=lower_bound(&stack_offsets, stack_top)].to_vec(); let info = BasicBlockInfo { pc_start, @@ -181,22 +171,18 @@ impl BasicBlock { }; let bb_start_circuit = BasicBlockStart::construct_circuit(&info, challenges)?; - let (bb_final_circuit, bb_acc_circuits) = - if bytecode.last() == Some(&(OpcodeType::RETURN as u8)) { - ( - BasicBlockReturn::construct_circuit(&info, challenges)?, - vec![ - BBReturnRestMemLoad::construct_circuit(challenges)?, - BBReturnRestMemStore::construct_circuit(challenges)?, - BBReturnRestStackPop::construct_circuit(challenges)?, - ], - ) - } else { - ( - BasicBlockFinal::construct_circuit(&info, challenges)?, - vec![], - ) - }; + let (bb_final_circuit, bb_acc_circuits) = if bytecode.last() == Some(&(OpcodeType::RETURN as u8)) { + ( + BasicBlockReturn::construct_circuit(&info, challenges)?, + vec![ + BBReturnRestMemLoad::construct_circuit(challenges)?, + BBReturnRestMemStore::construct_circuit(challenges)?, + BBReturnRestStackPop::construct_circuit(challenges)?, + ], + ) + } else { + (BasicBlockFinal::construct_circuit(&info, challenges)?, vec![]) + }; Ok(BasicBlock { bytecode: bytecode.to_vec(), @@ -245,8 +231,7 @@ impl BasicBlock { let mut to_succ = &bb_start_circuit.layout.to_succ_inst; let mut next_pc = None; - let mut local_stack = - BasicBlockStack::initialize(self.info.clone(), bb_start_node_id, to_succ); + let mut local_stack = BasicBlockStack::initialize(self.info.clone(), bb_start_node_id, to_succ); let mut pred_node_id = bb_start_node_id; // The return instruction will return the size of the public output. We @@ -260,12 +245,9 @@ impl BasicBlock { let mode = StackOpMode::from(opcode); let stack = local_stack.pop_node_outputs(mode); let memory_ts = NodeOutputType::WireOut(pred_node_id, to_succ.next_memory_ts_id); - let preds = inst_circuit.layout.input( - inst_circuit.circuit.n_witness_in, - opcode, - stack, - memory_ts, - ); + let preds = inst_circuit + .layout + .input(inst_circuit.circuit.n_witness_in, opcode, stack, memory_ts); let (node_id, stack, po) = construct_inst_graph_and_witness( opcode, graph_builder, @@ -290,17 +272,10 @@ impl BasicBlock { } let stack = local_stack.finalize(); - let stack_ts = NodeOutputType::WireOut( - bb_start_node_id, - bb_start_circuit.layout.to_bb_final.stack_ts_id, - ); + let stack_ts = NodeOutputType::WireOut(bb_start_node_id, bb_start_circuit.layout.to_bb_final.stack_ts_id); let memory_ts = NodeOutputType::WireOut(pred_node_id, to_succ.next_memory_ts_id); - let stack_top = NodeOutputType::WireOut( - bb_start_node_id, - bb_start_circuit.layout.to_bb_final.stack_top_id, - ); - let clk = - NodeOutputType::WireOut(bb_start_node_id, bb_start_circuit.layout.to_bb_final.clk_id); + let stack_top = NodeOutputType::WireOut(bb_start_node_id, bb_start_circuit.layout.to_bb_final.stack_top_id); + let clk = NodeOutputType::WireOut(bb_start_node_id, bb_start_circuit.layout.to_bb_final.clk_id); let preds = bb_final_circuit.layout.input( bb_final_circuit.circuit.n_witness_in, stack, @@ -326,11 +301,7 @@ impl BasicBlock { real_n_instances, )?; - let real_n_instances_bb_accs = vec![ - params.n_mem_finalize, - params.n_mem_initialize, - params.n_stack_finalize, - ]; + let real_n_instances_bb_accs = vec![params.n_mem_finalize, params.n_mem_initialize, params.n_stack_finalize]; for ((acc, acc_wires_in), real_n_instances) in bb_acc_circuits .iter() .zip(bb_wires_in.bb_accs.iter_mut()) @@ -384,8 +355,7 @@ impl BasicBlock { let mut to_succ = &bb_start_circuit.layout.to_succ_inst; let mut next_pc = None; - let mut local_stack = - BasicBlockStack::initialize(self.info.clone(), bb_start_node_id, to_succ); + let mut local_stack = BasicBlockStack::initialize(self.info.clone(), bb_start_node_id, to_succ); let mut pred_node_id = bb_start_node_id; // The return instruction will return the size of the public output. We @@ -398,12 +368,9 @@ impl BasicBlock { let mode = StackOpMode::from(*opcode); let stack = local_stack.pop_node_outputs(mode); let memory_ts = NodeOutputType::WireOut(pred_node_id, to_succ.next_memory_ts_id); - let preds = inst_circuit.layout.input( - inst_circuit.circuit.n_witness_in, - *opcode, - stack, - memory_ts, - ); + let preds = inst_circuit + .layout + .input(inst_circuit.circuit.n_witness_in, *opcode, stack, memory_ts); let (node_id, stack, po) = construct_inst_graph( *opcode, graph_builder, @@ -426,17 +393,10 @@ impl BasicBlock { } let stack = local_stack.finalize(); - let stack_ts = NodeOutputType::WireOut( - bb_start_node_id, - bb_start_circuit.layout.to_bb_final.stack_ts_id, - ); + let stack_ts = NodeOutputType::WireOut(bb_start_node_id, bb_start_circuit.layout.to_bb_final.stack_ts_id); let memory_ts = NodeOutputType::WireOut(pred_node_id, to_succ.next_memory_ts_id); - let stack_top = NodeOutputType::WireOut( - bb_start_node_id, - bb_start_circuit.layout.to_bb_final.stack_top_id, - ); - let clk = - NodeOutputType::WireOut(bb_start_node_id, bb_start_circuit.layout.to_bb_final.clk_id); + let stack_top = NodeOutputType::WireOut(bb_start_node_id, bb_start_circuit.layout.to_bb_final.stack_top_id); + let clk = NodeOutputType::WireOut(bb_start_node_id, bb_start_circuit.layout.to_bb_final.clk_id); let preds = bb_final_circuit.layout.input( bb_final_circuit.circuit.n_witness_in, stack, @@ -446,8 +406,7 @@ impl BasicBlock { clk, next_pc, ); - let bb_final_node_id = - graph_builder.add_node("BB final", &bb_final_circuit.circuit, preds)?; + let bb_final_node_id = graph_builder.add_node("BB final", &bb_final_circuit.circuit, preds)?; chip_builder.construct_chip_check_graph( graph_builder, bb_final_node_id, @@ -455,17 +414,10 @@ impl BasicBlock { real_n_instances, )?; - let real_n_instances_bb_accs = vec![ - params.n_mem_finalize, - params.n_mem_initialize, - params.n_stack_finalize, - ]; + let real_n_instances_bb_accs = vec![params.n_mem_finalize, params.n_mem_initialize, params.n_stack_finalize]; for (acc, real_n_instances) in bb_acc_circuits.iter().zip(real_n_instances_bb_accs) { - let acc_node_id = graph_builder.add_node( - "BB acc", - &acc.circuit, - vec![PredType::Source; acc.circuit.n_witness_in], - )?; + let acc_node_id = + graph_builder.add_node("BB acc", &acc.circuit, vec![PredType::Source; acc.circuit.n_witness_in])?; chip_builder.construct_chip_check_graph( graph_builder, acc_node_id, @@ -480,9 +432,7 @@ impl BasicBlock { #[cfg(test)] mod test { use crate::{ - basic_block::{ - bb_final::BasicBlockFinal, bb_start::BasicBlockStart, BasicBlock, BasicBlockInfo, - }, + basic_block::{bb_final::BasicBlockFinal, bb_start::BasicBlockStart, BasicBlock, BasicBlockInfo}, instructions::{add::AddInstruction, SingerInstCircuitBuilder}, scheme::GKRGraphProverState, BasicBlockWiresIn, SingerParams, @@ -506,8 +456,7 @@ mod test { #[cfg(not(debug_assertions))] fn bench_bb_helper(instance_num_vars: usize, n_adds_in_bb: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerInstCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerInstCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let bytecode = vec![vec![OpcodeType::ADD as u8; n_adds_in_bb]]; @@ -524,10 +473,7 @@ mod test { }; let bb_witness = BasicBlockWiresIn { bb_start: vec![LayerWitness { - instances: random_matrix( - n_instances, - BasicBlockStart::phase0_size(n_adds_in_bb + 1), - ), + instances: random_matrix(n_instances, BasicBlockStart::phase0_size(n_adds_in_bb + 1)), }], bb_final: vec![ LayerWitness { @@ -564,10 +510,10 @@ mod test { bb_final_stack_top_offsets: vec![-1], delta_stack_top: -(n_adds_in_bb as i64), }; - let bb_start_circuit = BasicBlockStart::construct_circuit(&info, chip_challenges) - .expect("construct circuit failed"); - let bb_final_circuit = BasicBlockFinal::construct_circuit(&info, chip_challenges) - .expect("construct circuit failed"); + let bb_start_circuit = + BasicBlockStart::construct_circuit(&info, chip_challenges).expect("construct circuit failed"); + let bb_final_circuit = + BasicBlockFinal::construct_circuit(&info, chip_challenges).expect("construct circuit failed"); let bb = BasicBlock { bytecode: bytecode[0].clone(), info, @@ -603,8 +549,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "AddInstruction::prove, instance_num_vars = {}, time = {}s", instance_num_vars, diff --git a/singer-pro/src/basic_block/bb_final.rs b/singer-pro/src/basic_block/bb_final.rs index 3fc0e4531..f27baa34e 100644 --- a/singer-pro/src/basic_block/bb_final.rs +++ b/singer-pro/src/basic_block/bb_final.rs @@ -4,15 +4,12 @@ use gkr::structs::Circuit; use itertools::Itertools; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ - chip_handler::{ - GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, - StackChipOperations, - }, + chip_handler::{GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, StackChipOperations}, chips::IntoEnumIterator, register_witness, structs::{ChipChallenges, InstOutChipType, PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; use std::{collections::BTreeMap, sync::Arc}; @@ -61,12 +58,7 @@ impl BasicBlockFinal { let stack_ts = TSUInt::try_from(stack_ts)?; let stack_ts_add_witness = &phase0[Self::phase0_stack_ts_add()]; - let next_stack_ts = rom_handler.add_ts_with_const( - &mut circuit_builder, - &stack_ts, - 1, - stack_ts_add_witness, - )?; + let next_stack_ts = rom_handler.add_ts_with_const(&mut circuit_builder, &stack_ts, 1, stack_ts_add_witness)?; let (memory_ts_id, memory_ts) = circuit_builder.create_witness_in(TSUInt::N_OPERAND_CELLS); let stack_top_expr = MixedCell::Cell(stack_top[0]); @@ -83,8 +75,7 @@ impl BasicBlockFinal { // Check the of stack_top + offset. let stack_top_l = stack_top_expr.add(i64_to_base_field::(stack_top_offsets[0])); rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_l)?; - let stack_top_r = - stack_top_expr.add(i64_to_base_field::(stack_top_offsets[n_stack_items - 1])); + let stack_top_r = stack_top_expr.add(i64_to_base_field::(stack_top_offsets[n_stack_items - 1])); rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_r)?; // From predesessor instruction diff --git a/singer-pro/src/basic_block/bb_ret.rs b/singer-pro/src/basic_block/bb_ret.rs index e351fca36..f5c0f5347 100644 --- a/singer-pro/src/basic_block/bb_ret.rs +++ b/singer-pro/src/basic_block/bb_ret.rs @@ -13,8 +13,7 @@ use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{ - AccessoryCircuit, AccessoryLayout, BBFinalCircuit, BBFinalLayout, FromBBStart, - FromPredInst, FromWitness, + AccessoryCircuit, AccessoryLayout, BBFinalCircuit, BBFinalLayout, FromBBStart, FromPredInst, FromWitness, }, error::ZKVMError, utils::i64_to_base_field, @@ -57,8 +56,7 @@ impl BasicBlockReturn { let stack_top_expr = MixedCell::Cell(stack_top[0]); let stack_top_l = stack_top_expr.add(i64_to_base_field::(stack_top_offsets[0])); rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_l)?; - let stack_top_r = - stack_top_expr.add(i64_to_base_field::(stack_top_offsets[n_stack_items - 1])); + let stack_top_r = stack_top_expr.add(i64_to_base_field::(stack_top_offsets[n_stack_items - 1])); rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_r)?; // From predesessor instruction @@ -223,12 +221,7 @@ impl BBReturnRestStackPop { let stack_top = circuit_builder.create_counter_in(0); let stack_values = &phase0[Self::phase0_stack_values()]; let old_stack_ts = &phase0[Self::phase0_old_stack_ts()]; - ram_handler.stack_pop( - &mut circuit_builder, - stack_top[0].into(), - old_stack_ts, - stack_values, - ); + ram_handler.stack_pop(&mut circuit_builder, stack_top[0].into(), old_stack_ts, stack_values); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); circuit_builder.configure(); diff --git a/singer-pro/src/basic_block/bb_start.rs b/singer-pro/src/basic_block/bb_start.rs index 2616b9076..cf6966a18 100644 --- a/singer-pro/src/basic_block/bb_start.rs +++ b/singer-pro/src/basic_block/bb_start.rs @@ -3,15 +3,12 @@ use ff_ext::ExtensionField; use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ - chip_handler::{ - GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, - StackChipOperations, - }, + chip_handler::{GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, StackChipOperations}, chips::IntoEnumIterator, register_multi_witness, structs::{ChipChallenges, InstOutChipType, PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; use std::sync::Arc; @@ -49,8 +46,7 @@ impl BasicBlockStart { let n_stack_items = stack_top_offsets.len(); // From witness - let (phase0_wire_id, phase0) = - circuit_builder.create_witness_in(Self::phase0_size(n_stack_items)); + let (phase0_wire_id, phase0) = circuit_builder.create_witness_in(Self::phase0_size(n_stack_items)); let mut ram_handler = RAMHandler::new(&challenges); let mut rom_handler = ROMHandler::new(&challenges); @@ -62,27 +58,18 @@ impl BasicBlockStart { let stack_top = phase0[Self::phase0_stack_top(n_stack_items).start]; let stack_top_expr = MixedCell::Cell(stack_top); let clk = phase0[Self::phase0_clk(n_stack_items).start]; - ram_handler.state_in( - &mut circuit_builder, - pc, - stack_ts, - memory_ts, - stack_top, - clk, - ); + ram_handler.state_in(&mut circuit_builder, pc, stack_ts, memory_ts, stack_top, clk); // Check the of stack_top + offset. let stack_top_l = stack_top_expr.add(i64_to_base_field::(stack_top_offsets[0])); rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_l)?; - let stack_top_r = - stack_top_expr.add(i64_to_base_field::(stack_top_offsets[n_stack_items - 1])); + let stack_top_r = stack_top_expr.add(i64_to_base_field::(stack_top_offsets[n_stack_items - 1])); rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_r)?; // pop all elements from the stack. let stack_ts = TSUInt::try_from(stack_ts)?; for (i, offset) in stack_top_offsets.iter().enumerate() { - let old_stack_ts = - TSUInt::try_from(&phase0[Self::phase0_old_stack_ts(i, n_stack_items)])?; + let old_stack_ts = TSUInt::try_from(&phase0[Self::phase0_old_stack_ts(i, n_stack_items)])?; TSUInt::assert_lt( &mut circuit_builder, &mut rom_handler, @@ -101,8 +88,7 @@ impl BasicBlockStart { // To successor instruction let mut stack_result_ids = Vec::with_capacity(n_stack_items); for i in 0..n_stack_items { - let (stack_operand_id, stack_operand) = - circuit_builder.create_witness_out(StackUInt::N_OPERAND_CELLS); + let (stack_operand_id, stack_operand) = circuit_builder.create_witness_out(StackUInt::N_OPERAND_CELLS); let old_stack = &phase0[Self::phase0_old_stack_values(i, n_stack_items)]; for j in 0..StackUInt::N_OPERAND_CELLS { circuit_builder.add(stack_operand[j], old_stack[j], E::BaseField::ONE); @@ -113,8 +99,7 @@ impl BasicBlockStart { add_assign_each_cell(&mut circuit_builder, &out_memory_ts, &memory_ts); // To BB final - let (out_stack_ts_id, out_stack_ts) = - circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); + let (out_stack_ts_id, out_stack_ts) = circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); add_assign_each_cell(&mut circuit_builder, &out_stack_ts, stack_ts.values()); let (out_stack_top_id, out_stack_top) = circuit_builder.create_witness_out(1); circuit_builder.add(out_stack_top[0], stack_top, E::BaseField::ONE); diff --git a/singer-pro/src/basic_block/utils.rs b/singer-pro/src/basic_block/utils.rs index be29c1ad2..4a8de82b6 100644 --- a/singer-pro/src/basic_block/utils.rs +++ b/singer-pro/src/basic_block/utils.rs @@ -11,21 +11,15 @@ pub(super) struct BasicBlockStack { } impl BasicBlockStack { - pub(super) fn initialize( - info: BasicBlockInfo, - bb_start_node_id: usize, - bb_to_succ: &ToSuccInst, - ) -> Self { - let mut stack = - vec![NodeOutputType::OutputLayer(0); -info.bb_start_stack_top_offsets[0] as usize]; + pub(super) fn initialize(info: BasicBlockInfo, bb_start_node_id: usize, bb_to_succ: &ToSuccInst) -> Self { + let mut stack = vec![NodeOutputType::OutputLayer(0); -info.bb_start_stack_top_offsets[0] as usize]; let stack_top = stack.len() as i64; bb_to_succ .stack_result_ids .iter() .zip(info.bb_start_stack_top_offsets.iter().rev()) .for_each(|(&wire_id, &offset)| { - stack[(stack_top + offset) as usize] = - NodeOutputType::WireOut(bb_start_node_id, wire_id); + stack[(stack_top + offset) as usize] = NodeOutputType::WireOut(bb_start_node_id, wire_id); }); Self { stack, info } } @@ -44,10 +38,7 @@ impl BasicBlockStack { match mode { StackOpMode::PopPush(n, _) => (0..n).map(|_| self.stack.pop().unwrap()).collect_vec(), StackOpMode::Swap(n) => { - vec![ - self.stack[self.stack.len() - 1], - self.stack[self.stack.len() - n - 1], - ] + vec![self.stack[self.stack.len() - 1], self.stack[self.stack.len() - n - 1]] } StackOpMode::Dup(n) => { vec![self.stack[self.stack.len() - n]] diff --git a/singer-pro/src/instructions.rs b/singer-pro/src/instructions.rs index 9a55475a5..8f5788be8 100644 --- a/singer-pro/src/instructions.rs +++ b/singer-pro/src/instructions.rs @@ -12,8 +12,8 @@ use crate::{ }; use self::{ - add::AddInstruction, calldataload::CalldataloadInstruction, gt::GtInstruction, - jump::JumpInstruction, jumpi::JumpiInstruction, mstore::MstoreInstruction, + add::AddInstruction, calldataload::CalldataloadInstruction, gt::GtInstruction, jump::JumpInstruction, + jumpi::JumpiInstruction, mstore::MstoreInstruction, }; // arithmetic @@ -47,10 +47,7 @@ impl SingerInstCircuitBuilder { let mut insts_circuits = HashMap::new(); insts_circuits.insert(0x01, AddInstruction::construct_circuits(challenges)?); insts_circuits.insert(0x11, GtInstruction::construct_circuits(challenges)?); - insts_circuits.insert( - 0x35, - CalldataloadInstruction::construct_circuits(challenges)?, - ); + insts_circuits.insert(0x35, CalldataloadInstruction::construct_circuits(challenges)?); insts_circuits.insert(0x52, MstoreInstruction::construct_circuits(challenges)?); insts_circuits.insert(0x56, JumpInstruction::construct_circuits(challenges)?); insts_circuits.insert(0x57, JumpiInstruction::construct_circuits(challenges)?); @@ -138,9 +135,7 @@ pub(crate) trait InstructionGraph { /// Construct instruction circuits and its extensions. Mostly there is no /// extensions. - fn construct_circuits( - challenges: ChipChallenges, - ) -> Result<(InstCircuit, Vec>), ZKVMError> { + fn construct_circuits(challenges: ChipChallenges) -> Result<(InstCircuit, Vec>), ZKVMError> { Ok((Self::InstType::construct_circuit(challenges)?, vec![])) } @@ -195,11 +190,7 @@ pub(crate) trait InstructionGraph { real_n_instances: usize, _params: &SingerParams, ) -> Result<(Vec, Vec, Option), ZKVMError> { - let node_id = graph_builder.add_node( - >::NAME, - &inst_circuit.circuit, - preds, - )?; + let node_id = graph_builder.add_node(>::NAME, &inst_circuit.circuit, preds)?; let stack = inst_circuit .layout .to_succ_inst diff --git a/singer-pro/src/instructions/add.rs b/singer-pro/src/instructions/add.rs index 14b919cd0..7d6d6de49 100644 --- a/singer-pro/src/instructions/add.rs +++ b/singer-pro/src/instructions/add.rs @@ -2,13 +2,13 @@ use ff_ext::ExtensionField; use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::CircuitBuilder; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::ROMOperations, chips::IntoEnumIterator, constants::OpcodeType, register_witness, structs::{ChipChallenges, InstOutChipType, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; use std::{collections::BTreeMap, sync::Arc}; @@ -62,8 +62,7 @@ impl Instruction for AddInstruction { )?; // To successor instruction let stack_result_id = circuit_builder.create_witness_out_from_cells(result.values()); - let (next_memory_ts_id, next_memory_ts) = - circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); + let (next_memory_ts_id, next_memory_ts) = circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); add_assign_each_cell(&mut circuit_builder, &next_memory_ts, &memory_ts); // To chips diff --git a/singer-pro/src/instructions/calldataload.rs b/singer-pro/src/instructions/calldataload.rs index 152788138..760722c8b 100644 --- a/singer-pro/src/instructions/calldataload.rs +++ b/singer-pro/src/instructions/calldataload.rs @@ -53,8 +53,7 @@ impl Instruction for CalldataloadInstruction { // To successor instruction let (data_copy_id, data_copy) = circuit_builder.create_witness_out(data.len()); add_assign_each_cell(&mut circuit_builder, &data_copy, &data); - let (next_memory_ts_id, next_memory_ts) = - circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); + let (next_memory_ts_id, next_memory_ts) = circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); add_assign_each_cell(&mut circuit_builder, &next_memory_ts, &memory_ts); // To chips diff --git a/singer-pro/src/instructions/gt.rs b/singer-pro/src/instructions/gt.rs index 1bc74c1f6..ab5c551c0 100644 --- a/singer-pro/src/instructions/gt.rs +++ b/singer-pro/src/instructions/gt.rs @@ -2,13 +2,13 @@ use ff_ext::ExtensionField; use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::CircuitBuilder; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::ROMOperations, chips::IntoEnumIterator, constants::OpcodeType, register_witness, structs::{ChipChallenges, InstOutChipType, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; use std::{collections::BTreeMap, sync::Arc}; @@ -43,10 +43,8 @@ impl Instruction for GtInstruction { // From predesessor instruction let (memory_ts_id, memory_ts) = circuit_builder.create_witness_in(TSUInt::N_OPERAND_CELLS); - let (operand_0_id, operand_0) = - circuit_builder.create_witness_in(StackUInt::N_OPERAND_CELLS); - let (operand_1_id, operand_1) = - circuit_builder.create_witness_in(StackUInt::N_OPERAND_CELLS); + let (operand_0_id, operand_0) = circuit_builder.create_witness_in(StackUInt::N_OPERAND_CELLS); + let (operand_1_id, operand_1) = circuit_builder.create_witness_in(StackUInt::N_OPERAND_CELLS); let mut rom_handler = ROMHandler::new(&challenges); @@ -67,8 +65,7 @@ impl Instruction for GtInstruction { .concat(); // To successor instruction let stack_result_id = circuit_builder.create_witness_out_from_cells(&result); - let (next_memory_ts_id, next_memory_ts) = - circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); + let (next_memory_ts_id, next_memory_ts) = circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); add_assign_each_cell(&mut circuit_builder, &next_memory_ts, &memory_ts); // To chips diff --git a/singer-pro/src/instructions/jump.rs b/singer-pro/src/instructions/jump.rs index fd3919de9..80d90b11f 100644 --- a/singer-pro/src/instructions/jump.rs +++ b/singer-pro/src/instructions/jump.rs @@ -35,8 +35,7 @@ impl Instruction for JumpInstruction { add_assign_each_cell(&mut circuit_builder, &next_pc_copy, &next_pc); // To Succesor instruction - let (next_memory_ts_id, next_memory_ts) = - circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); + let (next_memory_ts_id, next_memory_ts) = circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); add_assign_each_cell(&mut circuit_builder, &next_memory_ts, &memory_ts); // To chips diff --git a/singer-pro/src/instructions/jumpi.rs b/singer-pro/src/instructions/jumpi.rs index 34a2a0c55..59c99dfc0 100644 --- a/singer-pro/src/instructions/jumpi.rs +++ b/singer-pro/src/instructions/jumpi.rs @@ -49,8 +49,7 @@ impl Instruction for JumpiInstruction { // From predesessor instruction let (memory_ts_id, memory_ts) = circuit_builder.create_witness_in(TSUInt::N_OPERAND_CELLS); let (dest_id, dest) = circuit_builder.create_witness_in(StackUInt::N_OPERAND_CELLS); - let (cond_values_id, cond_values) = - circuit_builder.create_witness_in(StackUInt::N_OPERAND_CELLS); + let (cond_values_id, cond_values) = circuit_builder.create_witness_in(StackUInt::N_OPERAND_CELLS); let mut rom_handler = ROMHandler::new(&challenges); @@ -66,8 +65,7 @@ impl Instruction for JumpiInstruction { .iter() .for_each(|x| circuit_builder.add(non_zero_or, *x, E::BaseField::ONE)); let cond_non_zero_or_inv = phase0[Self::phase0_cond_non_zero_or_inv().start]; - let cond_non_zero = - rom_handler.non_zero(&mut circuit_builder, non_zero_or, cond_non_zero_or_inv)?; + let cond_non_zero = rom_handler.non_zero(&mut circuit_builder, non_zero_or, cond_non_zero_or_inv)?; // If cond_non_zero, next_pc = dest, otherwise, pc = pc + 1 let pc_plus_1 = &phase0[Self::phase0_pc_plus_1()]; @@ -89,8 +87,7 @@ impl Instruction for JumpiInstruction { rom_handler.bytecode_with_pc_byte(&mut circuit_builder, &next_pc, next_opcode); // To successor instruction - let (next_memory_ts_id, next_memory_ts) = - circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); + let (next_memory_ts_id, next_memory_ts) = circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); add_assign_each_cell(&mut circuit_builder, &next_memory_ts, &memory_ts); let rom_id = rom_handler.finalize(&mut circuit_builder); diff --git a/singer-pro/src/instructions/mstore.rs b/singer-pro/src/instructions/mstore.rs index c8ea30089..10fdf326c 100644 --- a/singer-pro/src/instructions/mstore.rs +++ b/singer-pro/src/instructions/mstore.rs @@ -15,10 +15,7 @@ use singer_utils::{ use std::{collections::BTreeMap, mem, sync::Arc}; use crate::{ - component::{ - AccessoryCircuit, AccessoryLayout, FromPredInst, FromWitness, InstCircuit, InstLayout, - ToSuccInst, - }, + component::{AccessoryCircuit, AccessoryLayout, FromPredInst, FromWitness, InstCircuit, InstLayout, ToSuccInst}, error::ZKVMError, utils::add_assign_each_cell, CircuitWitnessIn, SingerParams, @@ -31,9 +28,7 @@ pub struct MstoreInstruction; impl InstructionGraph for MstoreInstruction { type InstType = Self; - fn construct_circuits( - challenges: ChipChallenges, - ) -> Result<(InstCircuit, Vec>), ZKVMError> { + fn construct_circuits(challenges: ChipChallenges) -> Result<(InstCircuit, Vec>), ZKVMError> { Ok(( Self::InstType::construct_circuit(challenges)?, vec![MstoreAccessory::construct_circuit(challenges)?], @@ -127,8 +122,7 @@ impl Instruction for MstoreInstruction { // From predesessor instruction let (memory_ts_id, memory_ts) = circuit_builder.create_witness_in(TSUInt::N_OPERAND_CELLS); let (offset_id, offset) = circuit_builder.create_witness_in(StackUInt::N_OPERAND_CELLS); - let (mem_value_id, mem_values) = - circuit_builder.create_witness_in(StackUInt::N_OPERAND_CELLS); + let (mem_value_id, mem_values) = circuit_builder.create_witness_in(StackUInt::N_OPERAND_CELLS); let mut rom_handler = ROMHandler::new(&challenges); @@ -141,16 +135,14 @@ impl Instruction for MstoreInstruction { &phase0[Self::phase0_memory_ts_add()], )?; // To successor instruction - let next_memory_ts_id = - circuit_builder.create_witness_out_from_cells(next_memory_ts.values()); + let next_memory_ts_id = circuit_builder.create_witness_out_from_cells(next_memory_ts.values()); // Pop mem_bytes from stack let mem_bytes = &phase0[Self::phase0_mem_bytes()]; rom_handler.range_check_bytes(&mut circuit_builder, mem_bytes)?; let mem_values = StackUInt::try_from(mem_values.as_slice())?; - let mem_values_from_bytes = - StackUInt::from_bytes_big_endian(&mut circuit_builder, &mem_bytes)?; + let mem_values_from_bytes = StackUInt::from_bytes_big_endian(&mut circuit_builder, &mem_bytes)?; StackUInt::assert_eq(&mut circuit_builder, &mem_values_from_bytes, &mem_values)?; // To chips. @@ -161,8 +153,7 @@ impl Instruction for MstoreInstruction { to_chip_ids[InstOutChipType::ROMInput as usize] = rom_id; // To accessory circuits. - let (to_acc_dup_id, to_acc_dup) = - circuit_builder.create_witness_out(MstoreAccessory::pred_dup_size()); + let (to_acc_dup_id, to_acc_dup) = circuit_builder.create_witness_out(MstoreAccessory::pred_dup_size()); add_assign_each_cell( &mut circuit_builder, &to_acc_dup[MstoreAccessory::pred_dup_memory_ts()], @@ -174,8 +165,8 @@ impl Instruction for MstoreInstruction { &offset, ); - let (to_acc_ooo_id, to_acc_ooo) = circuit_builder - .create_witness_out(MstoreAccessory::pred_ooo_size() * EVM_STACK_BYTE_WIDTH); + let (to_acc_ooo_id, to_acc_ooo) = + circuit_builder.create_witness_out(MstoreAccessory::pred_ooo_size() * EVM_STACK_BYTE_WIDTH); add_assign_each_cell(&mut circuit_builder, &to_acc_ooo, mem_bytes); circuit_builder.configure(); @@ -227,9 +218,7 @@ register_witness!( ); impl MstoreAccessory { - pub fn construct_circuit( - challenges: ChipChallenges, - ) -> Result, ZKVMError> { + pub fn construct_circuit(challenges: ChipChallenges) -> Result, ZKVMError> { let mut circuit_builder = CircuitBuilder::new(); // From predesessor circuit. @@ -250,13 +239,8 @@ impl MstoreAccessory { let offset = StackUInt::try_from(&pred_dup[Self::pred_dup_offset()])?; let offset_add_delta = &phase0[Self::phase0_offset_add_delta()]; let delta = circuit_builder.create_counter_in(0)[0]; - let offset_plus_delta = StackUInt::add_cell( - &mut circuit_builder, - &mut rom_handler, - &offset, - delta, - offset_add_delta, - )?; + let offset_plus_delta = + StackUInt::add_cell(&mut circuit_builder, &mut rom_handler, &offset, delta, offset_add_delta)?; TSUInt::assert_lt( &mut circuit_builder, &mut rom_handler, diff --git a/singer-pro/src/instructions/ret.rs b/singer-pro/src/instructions/ret.rs index e319e5055..8bac600bb 100644 --- a/singer-pro/src/instructions/ret.rs +++ b/singer-pro/src/instructions/ret.rs @@ -3,21 +3,18 @@ use gkr::structs::Circuit; use gkr_graph::structs::{CircuitGraphBuilder, NodeOutputType, PredType}; use paste::paste; use simple_frontend::structs::CircuitBuilder; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::{OAMOperations, ROMOperations}, chips::{IntoEnumIterator, SingerChipBuilder}, constants::OpcodeType, register_witness, structs::{ChipChallenges, InstOutChipType, RAMHandler, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; use std::{collections::BTreeMap, mem, sync::Arc}; use crate::{ - component::{ - AccessoryCircuit, FromPredInst, FromPublicIO, FromWitness, InstCircuit, InstLayout, - ToSuccInst, - }, + component::{AccessoryCircuit, FromPredInst, FromPublicIO, FromWitness, InstCircuit, InstLayout, ToSuccInst}, error::ZKVMError, utils::add_assign_each_cell, CircuitWitnessIn, SingerParams, @@ -42,8 +39,7 @@ impl InstructionGraph for ReturnInstruction { _: usize, params: &SingerParams, ) -> Result<(Vec, Vec, Option), ZKVMError> { - let public_output_size = - preds[inst_circuit.layout.from_pred_inst.stack_operand_ids[1] as usize].clone(); + let public_output_size = preds[inst_circuit.layout.from_pred_inst.stack_operand_ids[1] as usize].clone(); // Add the instruction circuit to the graph. let node_id = graph_builder.add_node_with_witness( @@ -79,8 +75,7 @@ impl InstructionGraph for ReturnInstruction { real_n_instances: usize, _: &SingerParams, ) -> Result<(Vec, Vec, Option), ZKVMError> { - let public_output_size = - preds[inst_circuit.layout.from_pred_inst.stack_operand_ids[1] as usize].clone(); + let public_output_size = preds[inst_circuit.layout.from_pred_inst.stack_operand_ids[1] as usize].clone(); // Add the instruction circuit to the graph. let node_id = graph_builder.add_node( @@ -138,13 +133,8 @@ impl Instruction for ReturnInstruction { let delta = circuit_builder.create_counter_in(0)[0]; let offset = StackUInt::try_from(offset.as_slice())?; let offset_add_delta = &phase0[Self::phase0_offset_add()]; - let offset_plus_delta = StackUInt::add_cell( - &mut circuit_builder, - &mut rom_handler, - &offset, - delta, - offset_add_delta, - )?; + let offset_plus_delta = + StackUInt::add_cell(&mut circuit_builder, &mut rom_handler, &offset, delta, offset_add_delta)?; // Load from memory let mem_byte = public_io[Self::public_io_byte().start]; @@ -159,8 +149,7 @@ impl Instruction for ReturnInstruction { ); // To successor instruction - let (next_memory_ts_id, next_memory_ts) = - circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); + let (next_memory_ts_id, next_memory_ts) = circuit_builder.create_witness_out(TSUInt::N_OPERAND_CELLS); add_assign_each_cell(&mut circuit_builder, &next_memory_ts, memory_ts.values()); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); diff --git a/singer-pro/src/instructions/unknown.rs b/singer-pro/src/instructions/unknown.rs index bf09444b6..42c1627ab 100644 --- a/singer-pro/src/instructions/unknown.rs +++ b/singer-pro/src/instructions/unknown.rs @@ -1,6 +1,5 @@ use ff_ext::ExtensionField; -use singer_utils::constants::OpcodeType; -use singer_utils::structs::ChipChallenges; +use singer_utils::{constants::OpcodeType, structs::ChipChallenges}; use crate::{component::InstCircuit, error::ZKVMError}; diff --git a/singer-pro/src/lib.rs b/singer-pro/src/lib.rs index 4191e202a..8f80b4548 100644 --- a/singer-pro/src/lib.rs +++ b/singer-pro/src/lib.rs @@ -5,9 +5,7 @@ use basic_block::SingerBasicBlockBuilder; use error::ZKVMError; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; -use gkr_graph::structs::{ - CircuitGraph, CircuitGraphAuxInfo, CircuitGraphBuilder, CircuitGraphWitness, NodeOutputType, -}; +use gkr_graph::structs::{CircuitGraph, CircuitGraphAuxInfo, CircuitGraphBuilder, CircuitGraphWitness, NodeOutputType}; use goldilocks::SmallField; use instructions::SingerInstCircuitBuilder; use itertools::Itertools; @@ -65,14 +63,7 @@ impl SingerGraphBuilder { program_input: &[u8], real_challenges: &[E], params: &SingerParams, - ) -> Result< - ( - SingerCircuit, - SingerWitness, - SingerWiresOutID, - ), - ZKVMError, - > { + ) -> Result<(SingerCircuit, SingerWitness, SingerWiresOutID), ZKVMError> { let basic_blocks = self.bb_builder.basic_block_bytecode(); // Construct tables for lookup arguments, including bytecode, range and // calldata @@ -118,17 +109,10 @@ impl SingerGraphBuilder { let (graph, graph_witness) = graph_builder.finalize_graph_and_witness_with_targets(&singer_wire_out_id.to_vec()); - Ok(( - SingerCircuit(graph), - SingerWitness(graph_witness), - singer_wire_out_id, - )) + Ok((SingerCircuit(graph), SingerWitness(graph_witness), singer_wire_out_id)) } - pub fn construct_graph( - mut self, - aux_info: &SingerAuxInfo, - ) -> Result, ZKVMError> { + pub fn construct_graph(mut self, aux_info: &SingerAuxInfo) -> Result, ZKVMError> { // Construct tables for lookup arguments, including bytecode, range and // calldata let pub_out_id = self.bb_builder.construct_graph( @@ -200,12 +184,7 @@ pub struct SingerWiresOutID { impl SingerWiresOutID { pub fn to_vec(&self) -> Vec { - let mut res = [ - self.ram_load.clone(), - self.ram_store.clone(), - self.rom_input.clone(), - ] - .concat(); + let mut res = [self.ram_load.clone(), self.ram_store.clone(), self.rom_input.clone()].concat(); if let Some(public_output_size) = self.public_output_size { res.push(public_output_size); } diff --git a/singer-pro/src/scheme/prover.rs b/singer-pro/src/scheme/prover.rs index 82f500aba..11d744481 100644 --- a/singer-pro/src/scheme/prover.rs +++ b/singer-pro/src/scheme/prover.rs @@ -5,9 +5,7 @@ use gkr_graph::structs::{CircuitGraphAuxInfo, NodeOutputType}; use itertools::Itertools; use transcript::Transcript; -use crate::{ - error::ZKVMError, SingerCircuit, SingerWiresOutID, SingerWiresOutValues, SingerWitness, -}; +use crate::{error::ZKVMError, SingerCircuit, SingerWiresOutID, SingerWiresOutValues, SingerWitness}; use super::{GKRGraphProverState, SingerProof}; @@ -19,11 +17,7 @@ pub fn prove( ) -> Result<(SingerProof, CircuitGraphAuxInfo), ZKVMError> { // TODO: Add PCS. let point = (0..2 * ::DEGREE) - .map(|_| { - transcript - .get_and_append_challenge(b"output point") - .elements - }) + .map(|_| transcript.get_and_append_challenge(b"output point").elements) .collect_vec(); let singer_out_evals = { @@ -32,15 +26,13 @@ pub fn prove( .iter() .map(|node| { match node { - NodeOutputType::OutputLayer(node_id) => vm_witness.0.node_witnesses - [*node_id as usize] + NodeOutputType::OutputLayer(node_id) => vm_witness.0.node_witnesses[*node_id as usize] .output_layer_witness_ref() .instances .iter() .cloned() .flatten(), - NodeOutputType::WireOut(node_id, wit_id) => vm_witness.0.node_witnesses - [*node_id as usize] + NodeOutputType::WireOut(node_id, wit_id) => vm_witness.0.node_witnesses[*node_id as usize] .witness_out_ref()[*wit_id as usize] .instances .iter() @@ -76,8 +68,7 @@ pub fn prove( }; let target_evals = vm_circuit.0.target_evals(&vm_witness.0, &point); - let gkr_phase_proof = - GKRGraphProverState::prove(&vm_circuit.0, &vm_witness.0, &target_evals, transcript, 1)?; + let gkr_phase_proof = GKRGraphProverState::prove(&vm_circuit.0, &vm_witness.0, &target_evals, transcript, 1)?; Ok(( SingerProof { gkr_phase_proof, diff --git a/singer-pro/src/scheme/verifier.rs b/singer-pro/src/scheme/verifier.rs index f47183cc3..c7ad58a89 100644 --- a/singer-pro/src/scheme/verifier.rs +++ b/singer-pro/src/scheme/verifier.rs @@ -17,11 +17,7 @@ pub fn verify( ) -> Result<(), ZKVMError> { // TODO: Add PCS. let point = (0..2 * ::DEGREE) - .map(|_| { - transcript - .get_and_append_challenge(b"output point") - .elements - }) + .map(|_| transcript.get_and_append_challenge(b"output point").elements) .collect_vec(); let SingerWiresOutValues { @@ -45,9 +41,7 @@ pub fn verify( let (den, num) = x.split_at(l / 2); (E::from_limbs(den), E::from_limbs(num)) }) - .fold((E::ONE, E::ZERO), |acc, x| { - (acc.0 * x.0, acc.0 * x.1 + acc.1 * x.0) - }); + .fold((E::ONE, E::ZERO), |acc, x| (acc.0 * x.0, acc.0 * x.1 + acc.1 * x.0)); let rom_table_sum = rom_table .iter() .map(|x| { @@ -55,9 +49,7 @@ pub fn verify( let (den, num) = x.split_at(l / 2); (E::from_limbs(den), E::from_limbs(num)) }) - .fold((E::ONE, E::ZERO), |acc, x| { - (acc.0 * x.0, acc.0 * x.1 + acc.1 * x.0) - }); + .fold((E::ONE, E::ZERO), |acc, x| (acc.0 * x.0, acc.0 * x.1 + acc.1 * x.0)); if rom_input_sum.0 * rom_table_sum.1 != rom_input_sum.1 * rom_table_sum.0 { return Err(ZKVMError::VerifyError); } @@ -66,10 +58,7 @@ pub fn verify( chain![ram_load, ram_store, rom_input, rom_table] .map(|x| { let f = vec![x.to_vec()].as_slice().original_mle(); - PointAndEval::new( - point[..f.num_vars].to_vec(), - f.evaluate(&point[..f.num_vars]), - ) + PointAndEval::new(point[..f.num_vars].to_vec(), f.evaluate(&point[..f.num_vars])) }) .collect_vec(), ); @@ -80,10 +69,7 @@ pub fn verify( point[..f.num_vars].to_vec(), f.evaluate(&point[..f.num_vars]), )); - assert_eq!( - output[0], - E::BaseField::from(aux_info.program_output_len as u64) - ) + assert_eq!(output[0], E::BaseField::from(aux_info.program_output_len as u64)) } GKRGraphVerifierState::verify( diff --git a/singer-utils/src/chip_handler.rs b/singer-utils/src/chip_handler.rs index b8cb096ad..05d85014a 100644 --- a/singer-utils/src/chip_handler.rs +++ b/singer-utils/src/chip_handler.rs @@ -9,23 +9,14 @@ pub mod global_state; pub mod memory; pub mod ram_handler; pub mod range; +pub mod register; pub mod rom_handler; pub mod stack; pub trait BytecodeChipOperations: ROMOperations { - fn bytecode_with_pc_opcode( - &mut self, - circuit_builder: &mut CircuitBuilder, - pc: &[CellId], - opcode: OpcodeType, - ); + fn bytecode_with_pc_opcode(&mut self, circuit_builder: &mut CircuitBuilder, pc: &[CellId], opcode: OpcodeType); - fn bytecode_with_pc_byte( - &mut self, - circuit_builder: &mut CircuitBuilder, - pc: &[CellId], - byte: CellId, - ); + fn bytecode_with_pc_byte(&mut self, circuit_builder: &mut CircuitBuilder, pc: &[CellId], byte: CellId); } pub trait StackChipOperations: OAMOperations { @@ -46,6 +37,27 @@ pub trait StackChipOperations: OAMOperations { ); } +pub trait RegisterChipOperations: OAMOperations { + fn register_load( + &mut self, + circuit_builder: &mut CircuitBuilder, + register_id: &[CellId], + prev_timestamp: &[CellId], + timestamp: &[CellId], + values: &[CellId], + ); + + fn register_store( + &mut self, + circuit_builder: &mut CircuitBuilder, + register_id: &[CellId], + prev_timestamp: &[CellId], + timestamp: &[CellId], + prev_values: &[CellId], + values: &[CellId], + ); +} + pub trait RangeChipOperations: ROMOperations { fn range_check_stack_top( &mut self, @@ -91,12 +103,7 @@ pub trait MemoryChipOperations: RAMOperations { } pub trait CalldataChipOperations: ROMOperations { - fn calldataload( - &mut self, - circuit_builder: &mut CircuitBuilder, - offset: &[CellId], - data: &[CellId], - ); + fn calldataload(&mut self, circuit_builder: &mut CircuitBuilder, offset: &[CellId], data: &[CellId]); } pub trait GlobalStateChipOperations { @@ -122,12 +129,7 @@ pub trait GlobalStateChipOperations { } pub trait ROMOperations { - fn rom_load( - &mut self, - circuit_builder: &mut CircuitBuilder, - key: &[CellId], - value: &[CellId], - ); + fn rom_load(&mut self, circuit_builder: &mut CircuitBuilder, key: &[CellId], value: &[CellId]); fn rom_load_mixed( &mut self, @@ -157,13 +159,7 @@ pub trait OAMOperations { value: &[MixedCell], ); - fn oam_store( - &mut self, - circuit_builder: &mut CircuitBuilder, - ts: &[CellId], - key: &[CellId], - value: &[CellId], - ); + fn oam_store(&mut self, circuit_builder: &mut CircuitBuilder, ts: &[CellId], key: &[CellId], value: &[CellId]); fn oam_store_mixed( &mut self, diff --git a/singer-utils/src/chip_handler/bytecode.rs b/singer-utils/src/chip_handler/bytecode.rs index c3afef76e..8ec95f583 100644 --- a/singer-utils/src/chip_handler/bytecode.rs +++ b/singer-utils/src/chip_handler/bytecode.rs @@ -9,37 +9,34 @@ use crate::{ use super::{BytecodeChipOperations, ROMOperations}; -impl BytecodeChipOperations for ROMHandler { - fn bytecode_with_pc_opcode( - &mut self, - circuit_builder: &mut CircuitBuilder, - pc: &[CellId], - opcode: OpcodeType, - ) { +impl ROMHandler { + pub fn bytecode_with_pc(&mut self, circuit_builder: &mut CircuitBuilder, pc: &[CellId], opcode: u64) { let key = [ - vec![MixedCell::Constant(Ext::BaseField::from( - ROMType::Bytecode as u64, - ))], + vec![MixedCell::Constant(Ext::BaseField::from(ROMType::Bytecode as u64))], pc.iter().map(|&x| x.into()).collect_vec(), ] .concat(); self.rom_load_mixed( circuit_builder, &key, - &[MixedCell::Constant(Ext::BaseField::from(opcode as u64))], + &[MixedCell::Constant(Ext::BaseField::from(opcode))], ); } +} - fn bytecode_with_pc_byte( +impl BytecodeChipOperations for ROMHandler { + fn bytecode_with_pc_opcode( &mut self, circuit_builder: &mut CircuitBuilder, pc: &[CellId], - byte: CellId, + opcode: OpcodeType, ) { + self.bytecode_with_pc(circuit_builder, pc, opcode.into()); + } + + fn bytecode_with_pc_byte(&mut self, circuit_builder: &mut CircuitBuilder, pc: &[CellId], byte: CellId) { let key = [ - vec![MixedCell::Constant(Ext::BaseField::from( - ROMType::Bytecode as u64, - ))], + vec![MixedCell::Constant(Ext::BaseField::from(ROMType::Bytecode as u64))], pc.iter().map(|&x| x.into()).collect_vec(), ] .concat(); diff --git a/singer-utils/src/chip_handler/calldata.rs b/singer-utils/src/chip_handler/calldata.rs index 4925482f1..cf5be4ced 100644 --- a/singer-utils/src/chip_handler/calldata.rs +++ b/singer-utils/src/chip_handler/calldata.rs @@ -7,16 +7,9 @@ use crate::structs::{ROMHandler, ROMType}; use super::{CalldataChipOperations, ROMOperations}; impl CalldataChipOperations for ROMHandler { - fn calldataload( - &mut self, - circuit_builder: &mut CircuitBuilder, - offset: &[CellId], - data: &[CellId], - ) { + fn calldataload(&mut self, circuit_builder: &mut CircuitBuilder, offset: &[CellId], data: &[CellId]) { let key = [ - vec![MixedCell::Constant(Ext::BaseField::from( - ROMType::Calldata as u64, - ))], + vec![MixedCell::Constant(Ext::BaseField::from(ROMType::Calldata as u64))], offset.iter().map(|&x| x.into()).collect_vec(), ] .concat(); diff --git a/singer-utils/src/chip_handler/global_state.rs b/singer-utils/src/chip_handler/global_state.rs index 96eb168a9..67b74a448 100644 --- a/singer-utils/src/chip_handler/global_state.rs +++ b/singer-utils/src/chip_handler/global_state.rs @@ -16,9 +16,7 @@ impl GlobalStateChipOperations for RAMHandler { clk: CellId, ) { let key = [ - vec![MixedCell::Constant(Ext::BaseField::from( - RAMType::GlobalState as u64, - ))], + vec![MixedCell::Constant(Ext::BaseField::from(RAMType::GlobalState as u64))], pc.iter().map(|&x| x.into()).collect::>(), stack_ts.iter().map(|&x| x.into()).collect::>(), memory_ts.iter().map(|&x| x.into()).collect::>(), @@ -38,9 +36,7 @@ impl GlobalStateChipOperations for RAMHandler { clk: MixedCell, ) { let key = [ - vec![MixedCell::Constant(Ext::BaseField::from( - RAMType::GlobalState as u64, - ))], + vec![MixedCell::Constant(Ext::BaseField::from(RAMType::GlobalState as u64))], pc.iter().map(|&x| x.into()).collect::>(), stack_ts.iter().map(|&x| x.into()).collect::>(), memory_ts.iter().map(|&x| x.into()).collect::>(), diff --git a/singer-utils/src/chip_handler/memory.rs b/singer-utils/src/chip_handler/memory.rs index 50cb3a701..d3c1512ca 100644 --- a/singer-utils/src/chip_handler/memory.rs +++ b/singer-utils/src/chip_handler/memory.rs @@ -16,9 +16,7 @@ impl MemoryChipOperations for RAMHandler { byte: CellId, ) { let key = [ - vec![MixedCell::Constant(Ext::BaseField::from( - RAMType::Memory as u64, - ))], + vec![MixedCell::Constant(Ext::BaseField::from(RAMType::Memory as u64))], offset.iter().map(|&x| x.into()).collect_vec(), ] .concat(); @@ -37,9 +35,7 @@ impl MemoryChipOperations for RAMHandler { cur_byte: CellId, ) { let key = [ - vec![MixedCell::Constant(Ext::BaseField::from( - RAMType::Memory as u64, - ))], + vec![MixedCell::Constant(Ext::BaseField::from(RAMType::Memory as u64))], offset.iter().map(|&x| x.into()).collect_vec(), ] .concat(); diff --git a/singer-utils/src/chip_handler/ram_handler.rs b/singer-utils/src/chip_handler/ram_handler.rs index a3d8126ef..70d42e999 100644 --- a/singer-utils/src/chip_handler/ram_handler.rs +++ b/singer-utils/src/chip_handler/ram_handler.rs @@ -103,10 +103,7 @@ impl OAMOperations for RAMHandler { circuit_builder.add_const(out.cells[0], Ext::BaseField::ONE); records.push(out); } - Some(( - circuit_builder.create_witness_out_from_exts(&records), - records.len(), - )) + Some((circuit_builder.create_witness_out_from_exts(&records), records.len())) }; let mut rd_records = self.rd_records; diff --git a/singer-utils/src/chip_handler/range.rs b/singer-utils/src/chip_handler/range.rs index a1a68e0ab..fe80e5c74 100644 --- a/singer-utils/src/chip_handler/range.rs +++ b/singer-utils/src/chip_handler/range.rs @@ -92,6 +92,14 @@ impl ROMHandler { } impl ROMHandler { + pub fn increase_pc( + circuit_builder: &mut CircuitBuilder, + pc: &PCUInt, + witness: &[CellId], + ) -> Result { + ROMHandler::add_pc_const(circuit_builder, &pc, 1, witness) + } + pub fn add_pc_const( circuit_builder: &mut CircuitBuilder, pc: &PCUInt, @@ -99,12 +107,16 @@ impl ROMHandler { witness: &[CellId], ) -> Result { let carry = PCUInt::extract_unsafe_carry_add(witness); - PCUInt::add_const_unsafe( - circuit_builder, - &pc, - i64_to_base_field::(constant), - carry, - ) + PCUInt::add_const_unsafe(circuit_builder, &pc, i64_to_base_field::(constant), carry) + } + + pub fn increase_ts( + &mut self, + circuit_builder: &mut CircuitBuilder, + ts: &TSUInt, + witness: &[CellId], + ) -> Result { + self.add_ts_with_const(circuit_builder, &ts, 1, witness) } pub fn add_ts_with_const( @@ -114,13 +126,7 @@ impl ROMHandler { constant: i64, witness: &[CellId], ) -> Result { - TSUInt::add_const( - circuit_builder, - self, - &ts, - i64_to_base_field::(constant), - witness, - ) + TSUInt::add_const(circuit_builder, self, &ts, i64_to_base_field::(constant), witness) } pub fn non_zero( diff --git a/singer-utils/src/chip_handler/register.rs b/singer-utils/src/chip_handler/register.rs new file mode 100644 index 000000000..9da16e42f --- /dev/null +++ b/singer-utils/src/chip_handler/register.rs @@ -0,0 +1,50 @@ +use ark_std::iterable::Iterable; +use ff_ext::ExtensionField; +use itertools::Itertools; +use simple_frontend::structs::{CellId, CircuitBuilder, MixedCell}; + +use crate::structs::{RAMHandler, RAMType}; + +use super::{RAMOperations, RegisterChipOperations}; + +impl RegisterChipOperations for RAMHandler { + fn register_load( + &mut self, + circuit_builder: &mut CircuitBuilder, + register_id: &[CellId], + prev_timestamp: &[CellId], + timestamp: &[CellId], + value: &[CellId], + ) { + let key = [ + vec![MixedCell::Constant(Ext::BaseField::from(RAMType::Register as u64))], + register_id.iter().map(|&x| x.into()).collect_vec(), + ] + .concat(); + let prev_timestamp = prev_timestamp.iter().map(|&x| x.into()).collect_vec(); + let timestamp = timestamp.iter().map(|&x| x.into()).collect_vec(); + let value = value.iter().map(|&x| x.into()).collect_vec(); + self.ram_load_mixed(circuit_builder, &prev_timestamp, ×tamp, &key, &value); + } + + fn register_store( + &mut self, + circuit_builder: &mut CircuitBuilder, + register_id: &[CellId], + prev_timestamp: &[CellId], + timestamp: &[CellId], + prev_value: &[CellId], + value: &[CellId], + ) { + let key = [ + vec![MixedCell::Constant(Ext::BaseField::from(RAMType::Register as u64))], + register_id.iter().map(|&x| x.into()).collect_vec(), + ] + .concat(); + let prev_timestamp = prev_timestamp.iter().map(|&x| x.into()).collect_vec(); + let timestamp = timestamp.iter().map(|&x| x.into()).collect_vec(); + let value = value.iter().map(|&x| x.into()).collect_vec(); + let prev_value = prev_value.iter().map(|&x| x.into()).collect_vec(); + self.ram_store_mixed(circuit_builder, &prev_timestamp, ×tamp, &key, &prev_value, &value); + } +} diff --git a/singer-utils/src/chip_handler/rom_handler.rs b/singer-utils/src/chip_handler/rom_handler.rs index ec138e71a..146a6ca07 100644 --- a/singer-utils/src/chip_handler/rom_handler.rs +++ b/singer-utils/src/chip_handler/rom_handler.rs @@ -16,12 +16,7 @@ impl ROMHandler { } impl ROMOperations for ROMHandler { - fn rom_load( - &mut self, - circuit_builder: &mut CircuitBuilder, - key: &[CellId], - value: &[CellId], - ) { + fn rom_load(&mut self, circuit_builder: &mut CircuitBuilder, key: &[CellId], value: &[CellId]) { let item_rlc = circuit_builder.create_ext_cell(); let items = [key.to_vec(), value.to_vec()].concat(); circuit_builder.rlc(&item_rlc, &items, self.challenge.record_item_rlc()); @@ -58,9 +53,6 @@ impl ROMOperations for ROMHandler { circuit_builder.add_ext(&out, &last, Ext::BaseField::ONE); records.push(out); } - Some(( - circuit_builder.create_witness_out_from_exts(&records), - records.len(), - )) + Some((circuit_builder.create_witness_out_from_exts(&records), records.len())) } } diff --git a/singer-utils/src/chips.rs b/singer-utils/src/chips.rs index f57459c38..0736b05d7 100644 --- a/singer-utils/src/chips.rs +++ b/singer-utils/src/chips.rs @@ -69,10 +69,8 @@ impl SingerChipBuilder { real_n_instances.next_power_of_two(), )?; let mut preds = vec![PredType::Source; 2]; - preds[leaf.input_id as usize] = - PredType::PredWire(NodeOutputType::WireOut(node_id, input_wit_id)); - preds[leaf.cond_id as usize] = - PredType::PredWire(NodeOutputType::OutputLayer(selector_node_id)); + preds[leaf.input_id as usize] = PredType::PredWire(NodeOutputType::WireOut(node_id, input_wit_id)); + preds[leaf.cond_id as usize] = PredType::PredWire(NodeOutputType::OutputLayer(selector_node_id)); let instance_num_vars = ceil_log2(real_n_instances * num) - 1; build_tree_graph_and_witness( @@ -138,22 +136,13 @@ impl SingerChipBuilder { inner: &Arc>| -> Result { let selector = ChipCircuitGadgets::construct_prefix_selector(n_instances, num); - let selector_node_id = - graph_builder.add_node("selector circuit", &selector.circuit, vec![])?; + let selector_node_id = graph_builder.add_node("selector circuit", &selector.circuit, vec![])?; let mut preds = vec![PredType::Source; 2]; - preds[leaf.input_id as usize] = - PredType::PredWire(NodeOutputType::WireOut(node_id, input_wit_id)); - preds[leaf.cond_id as usize] = - PredType::PredWire(NodeOutputType::OutputLayer(selector_node_id)); + preds[leaf.input_id as usize] = PredType::PredWire(NodeOutputType::WireOut(node_id, input_wit_id)); + preds[leaf.cond_id as usize] = PredType::PredWire(NodeOutputType::OutputLayer(selector_node_id)); let instance_num_vars = ceil_log2(real_n_instances) - 1; - build_tree_graph( - graph_builder, - preds, - &leaf.circuit, - inner, - instance_num_vars, - ) + build_tree_graph(graph_builder, preds, &leaf.circuit, inner, instance_num_vars) }; // Set equality argument @@ -213,17 +202,9 @@ impl SingerChipBuilder { (preds, sources) }; - let (input_pred, selector_pred, instance_num_vars) = construct_bytecode_table_and_witness( - graph_builder, - bytecode, - challenges, - real_challenges, - )?; - let (preds, sources) = pred_source( - LookupChipType::BytecodeChip as usize, - input_pred, - selector_pred, - ); + let (input_pred, selector_pred, instance_num_vars) = + construct_bytecode_table_and_witness(graph_builder, bytecode, challenges, real_challenges)?; + let (preds, sources) = pred_source(LookupChipType::BytecodeChip as usize, input_pred, selector_pred); tables_out[LookupChipType::BytecodeChip as usize] = build_tree_graph_and_witness( graph_builder, preds, @@ -234,17 +215,9 @@ impl SingerChipBuilder { instance_num_vars, )?; - let (input_pred, selector_pred, instance_num_vars) = construct_calldata_table_and_witness( - graph_builder, - program_input, - challenges, - real_challenges, - )?; - let (preds, sources) = pred_source( - LookupChipType::CalldataChip as usize, - input_pred, - selector_pred, - ); + let (input_pred, selector_pred, instance_num_vars) = + construct_calldata_table_and_witness(graph_builder, program_input, challenges, real_challenges)?; + let (preds, sources) = pred_source(LookupChipType::CalldataChip as usize, input_pred, selector_pred); tables_out[LookupChipType::CalldataChip as usize] = build_tree_graph_and_witness( graph_builder, preds, @@ -264,12 +237,8 @@ impl SingerChipBuilder { mem::take(&mut table_count_witness[table_type as usize].instances); (preds, sources) }; - let (input_pred, instance_num_vars) = construct_range_table_and_witness( - graph_builder, - RANGE_CHIP_BIT_WIDTH, - challenges, - real_challenges, - )?; + let (input_pred, instance_num_vars) = + construct_range_table_and_witness(graph_builder, RANGE_CHIP_BIT_WIDTH, challenges, real_challenges)?; let (preds, sources) = preds_no_selector(LookupChipType::RangeChip as usize, input_pred); tables_out[LookupChipType::RangeChip as usize] = build_tree_graph_and_witness( graph_builder, @@ -307,24 +276,14 @@ impl SingerChipBuilder { let (input_pred, selector_pred, instance_num_vars) = construct_bytecode_table(graph_builder, byte_code_len, challenges)?; let preds = compute_preds(input_pred, selector_pred); - tables_out[LookupChipType::BytecodeChip as usize] = build_tree_graph( - graph_builder, - preds, - &leaf.circuit, - inner, - instance_num_vars, - )?; + tables_out[LookupChipType::BytecodeChip as usize] = + build_tree_graph(graph_builder, preds, &leaf.circuit, inner, instance_num_vars)?; let (input_pred, selector_pred, instance_num_vars) = construct_calldata_table(graph_builder, program_input_len, challenges)?; let preds = compute_preds(input_pred, selector_pred); - tables_out[LookupChipType::CalldataChip as usize] = build_tree_graph( - graph_builder, - preds, - &leaf.circuit, - inner, - instance_num_vars, - )?; + tables_out[LookupChipType::CalldataChip as usize] = + build_tree_graph(graph_builder, preds, &leaf.circuit, inner, instance_num_vars)?; let leaf = &self.chip_circuit_gadgets.frac_sum_leaf_no_selector; let compute_preds_no_selector = |table_pred| { @@ -332,16 +291,10 @@ impl SingerChipBuilder { preds[leaf.input_den_id as usize] = table_pred; preds }; - let (input_pred, instance_num_vars) = - construct_range_table(graph_builder, RANGE_CHIP_BIT_WIDTH, challenges)?; + let (input_pred, instance_num_vars) = construct_range_table(graph_builder, RANGE_CHIP_BIT_WIDTH, challenges)?; let preds = compute_preds_no_selector(input_pred); - tables_out[LookupChipType::RangeChip as usize] = build_tree_graph( - graph_builder, - preds, - &leaf.circuit, - inner, - instance_num_vars, - )?; + tables_out[LookupChipType::RangeChip as usize] = + build_tree_graph(graph_builder, preds, &leaf.circuit, inner, instance_num_vars)?; Ok(tables_out) } } @@ -374,28 +327,27 @@ fn build_tree_graph_and_witness( real_challenges: &[E], instance_num_vars: usize, ) -> Result { - let (last_pred, _) = - (0..=instance_num_vars).fold(Ok((first_pred, first_source)), |prev, i| { - let circuit = if i == 0 { leaf } else { inner }; - match prev { - Ok((pred, source)) => graph_builder - .add_node_with_witness( - "tree inner node", - circuit, - pred, - real_challenges.to_vec(), - source, - 1 << (instance_num_vars - i), + let (last_pred, _) = (0..=instance_num_vars).fold(Ok((first_pred, first_source)), |prev, i| { + let circuit = if i == 0 { leaf } else { inner }; + match prev { + Ok((pred, source)) => graph_builder + .add_node_with_witness( + "tree inner node", + circuit, + pred, + real_challenges.to_vec(), + source, + 1 << (instance_num_vars - i), + ) + .map(|id| { + ( + vec![PredType::PredWire(NodeOutputType::OutputLayer(id))], + vec![LayerWitness { instances: vec![] }], ) - .map(|id| { - ( - vec![PredType::PredWire(NodeOutputType::OutputLayer(id))], - vec![LayerWitness { instances: vec![] }], - ) - }), - Err(err) => Err(err), - } - })?; + }), + Err(err) => Err(err), + } + })?; match last_pred[0] { PredType::PredWire(out) => Ok(out), _ => unreachable!(), diff --git a/singer-utils/src/chips/bytecode.rs b/singer-utils/src/chips/bytecode.rs index 314b3afa1..d59328674 100644 --- a/singer-utils/src/chips/bytecode.rs +++ b/singer-utils/src/chips/bytecode.rs @@ -81,14 +81,9 @@ pub(crate) fn construct_bytecode_table( let bytecode_circuit = construct_circuit(challenges); let selector = ChipCircuitGadgets::construct_prefix_selector(bytecode_len, 1); - let selector_node_id = - builder.add_node("bytecode selector circuit", &selector.circuit, vec![])?; + let selector_node_id = builder.add_node("bytecode selector circuit", &selector.circuit, vec![])?; - let table_node_id = builder.add_node( - "bytecode table circuit", - &bytecode_circuit, - vec![PredType::Source; 2], - )?; + let table_node_id = builder.add_node("bytecode table circuit", &bytecode_circuit, vec![PredType::Source; 2])?; Ok(( PredType::PredWire(NodeOutputType::OutputLayer(table_node_id)), diff --git a/singer-utils/src/chips/calldata.rs b/singer-utils/src/chips/calldata.rs index 2968ea0f2..261e5818c 100644 --- a/singer-utils/src/chips/calldata.rs +++ b/singer-utils/src/chips/calldata.rs @@ -96,14 +96,9 @@ pub(crate) fn construct_calldata_table( let calldata_circuit = construct_circuit(challenges); let selector = ChipCircuitGadgets::construct_prefix_selector(program_input_len, 1); - let selector_node_id = - builder.add_node("calldata selector circuit", &selector.circuit, vec![])?; + let selector_node_id = builder.add_node("calldata selector circuit", &selector.circuit, vec![])?; - let table_node_id = builder.add_node( - "calldata table circuit", - &calldata_circuit, - vec![PredType::Source; 2], - )?; + let table_node_id = builder.add_node("calldata table circuit", &calldata_circuit, vec![PredType::Source; 2])?; Ok(( PredType::PredWire(NodeOutputType::OutputLayer(table_node_id)), diff --git a/singer-utils/src/chips/circuit_gadgets.rs b/singer-utils/src/chips/circuit_gadgets.rs index a6f881168..ba5e3cf07 100644 --- a/singer-utils/src/chips/circuit_gadgets.rs +++ b/singer-utils/src/chips/circuit_gadgets.rs @@ -48,10 +48,7 @@ impl ChipCircuitGadgets { /// Construct a selector for n_instances and each instance contains `num` /// items. `num` must be a power of 2. - pub(crate) fn construct_prefix_selector( - n_instances: usize, - num: usize, - ) -> PrefixSelectorCircuit { + pub(crate) fn construct_prefix_selector(n_instances: usize, num: usize) -> PrefixSelectorCircuit { assert_eq!(num, num.next_power_of_two()); let mut circuit_builder = CircuitBuilder::::new(); let _ = circuit_builder.create_constant_in(n_instances * num, 1); @@ -75,12 +72,7 @@ impl ChipCircuitGadgets { let den_mul = circuit_builder.create_ext_cell(); circuit_builder.mul2_ext(&den_mul, &input[0], &input[1], E::BaseField::ONE); let tmp = circuit_builder.create_ext_cell(); - circuit_builder.sel_mixed_and_ext( - &tmp, - &MixedCell::Constant(E::BaseField::ONE), - &input[0], - cond[0], - ); + circuit_builder.sel_mixed_and_ext(&tmp, &MixedCell::Constant(E::BaseField::ONE), &input[0], cond[0]); circuit_builder.sel_ext(&output[0], &tmp, &den_mul, cond[1]); // select the numerator 0 or 1 or input[0] + input[1] @@ -114,12 +106,7 @@ impl ChipCircuitGadgets { let den_mul = circuit_builder.create_ext_cell(); circuit_builder.mul2_ext(&den_mul, &input_den[0], &input_den[1], E::BaseField::ONE); let tmp = circuit_builder.create_ext_cell(); - circuit_builder.sel_mixed_and_ext( - &tmp, - &MixedCell::Constant(E::BaseField::ONE), - &input_den[0], - cond[0], - ); + circuit_builder.sel_mixed_and_ext(&tmp, &MixedCell::Constant(E::BaseField::ONE), &input_den[0], cond[0]); circuit_builder.sel_ext(&output[0], &tmp, &den_mul, cond[1]); // select the numerator, 0 or input_num[0] or input_den[0] * input_num[1] + input_num[0] * input_den[1] @@ -231,12 +218,7 @@ impl ChipCircuitGadgets { let mul = circuit_builder.create_ext_cell(); circuit_builder.mul2_ext(&mul, &input[0], &input[1], E::BaseField::ONE); let tmp = circuit_builder.create_ext_cell(); - circuit_builder.sel_mixed_and_ext( - &tmp, - &MixedCell::Constant(E::BaseField::ONE), - &input[0], - sel[0], - ); + circuit_builder.sel_mixed_and_ext(&tmp, &MixedCell::Constant(E::BaseField::ONE), &input[0], sel[0]); circuit_builder.sel_ext(&output[0], &tmp, &mul, sel[1]); circuit_builder.configure(); diff --git a/singer-utils/src/constants.rs b/singer-utils/src/constants.rs index 7427f4a62..1f27b52ad 100644 --- a/singer-utils/src/constants.rs +++ b/singer-utils/src/constants.rs @@ -28,4 +28,12 @@ pub enum OpcodeType { SWAP2 = 0x91, SWAP4 = 0x93, RETURN = 0xf3, + // risc-v + RISCV = 0xFF, +} + +impl From for u64 { + fn from(opcode: OpcodeType) -> Self { + opcode as u64 + } } diff --git a/singer-utils/src/lib.rs b/singer-utils/src/lib.rs index 5e98a75b6..9f1e3c957 100644 --- a/singer-utils/src/lib.rs +++ b/singer-utils/src/lib.rs @@ -4,6 +4,7 @@ pub mod chip_handler; pub mod chips; pub mod constants; pub mod error; +pub mod riscv_constant; pub mod structs; pub mod uint; diff --git a/singer-utils/src/riscv_constant.rs b/singer-utils/src/riscv_constant.rs new file mode 100644 index 000000000..c82318202 --- /dev/null +++ b/singer-utils/src/riscv_constant.rs @@ -0,0 +1,171 @@ +use strum_macros::EnumIter; + +/// This struct is used to define the opcode format for RISC-V instructions, +/// containing three main components: the opcode, funct3, and funct7 fields. +/// These fields are crucial for specifying the +/// exact operation and variants in the RISC-V instruction set architecture. +#[derive(Default, Clone)] +pub struct RvOpcode { + pub opcode: RV64IOpcode, + pub funct3: u8, + pub funct7: u8, +} + +impl From for u64 { + fn from(opcode: RvOpcode) -> Self { + let mut result: u64 = 0; + result |= (opcode.opcode as u64) & 0xFF; + result |= ((opcode.funct3 as u64) & 0xFF) << 8; + result |= ((opcode.funct7 as u64) & 0xFF) << 16; + result + } +} + +/// List all instruction formats in RV64I which contains +/// R-Type, I-Type, S-Type, B-Type, U-Type, J-Type and special type. +#[derive(Debug, Clone)] +pub enum RV64IOpcode { + UNKNOWN = 0x00, + + R = 0x33, + I_LOAD = 0x03, + I_ARITH = 0x13, + S = 0x63, + B = 0x23, + U_LUI = 0x37, + U_AUIPC = 0x7, + J = 0x6F, + JAR = 0x67, + SYS = 0x73, +} + +impl Default for RV64IOpcode { + fn default() -> Self { + RV64IOpcode::UNKNOWN + } +} + +impl From for u8 { + fn from(opcode: RV64IOpcode) -> Self { + opcode as u8 + } +} + +#[derive(Debug, Clone, Copy, EnumIter)] +pub enum RvInstructions { + // Type R + ADD = 0, + SUB, + SLL, + SLTU, + SLT, + XOR, + SRL, + SRA, + OR, + AND, + // Type I-LOAD + LB, + LH, + LW, + LBU, + LHU, + + // a workaround to get number of valid instructions + END, +} + +impl From for RvOpcode { + fn from(ins: RvInstructions) -> Self { + // Find the instruction format here: + // https://fraserinnovations.com/risc-v/risc-v-instruction-set-explanation/ + match ins { + // Type R + RvInstructions::ADD => RvOpcode { + opcode: RV64IOpcode::R, + funct3: 0b000 as u8, + funct7: 0, + }, + RvInstructions::SUB => RvOpcode { + opcode: RV64IOpcode::R, + funct3: 0b000 as u8, + funct7: 0b010_0000, + }, + RvInstructions::SLL => RvOpcode { + opcode: RV64IOpcode::R, + funct3: 0b001 as u8, + funct7: 0, + }, + RvInstructions::SLT => RvOpcode { + opcode: RV64IOpcode::R, + funct3: 0b010 as u8, + funct7: 0, + }, + RvInstructions::SLTU => RvOpcode { + opcode: RV64IOpcode::R, + funct3: 0b011 as u8, + funct7: 0, + }, + RvInstructions::XOR => RvOpcode { + opcode: RV64IOpcode::R, + funct3: 0b100 as u8, + funct7: 0, + }, + RvInstructions::SRL => RvOpcode { + opcode: RV64IOpcode::R, + funct3: 0b101 as u8, + funct7: 0, + }, + RvInstructions::SRA => RvOpcode { + opcode: RV64IOpcode::R, + funct3: 0b101 as u8, + funct7: 0b010_0000, + }, + RvInstructions::OR => RvOpcode { + opcode: RV64IOpcode::R, + funct3: 0b110 as u8, + funct7: 0, + }, + RvInstructions::AND => RvOpcode { + opcode: RV64IOpcode::R, + funct3: 0b111 as u8, + funct7: 0, + }, + // Type I-LOAD + RvInstructions::LB => RvOpcode { + opcode: RV64IOpcode::I_LOAD, + funct3: 0b000 as u8, + funct7: 0, + }, + RvInstructions::LH => RvOpcode { + opcode: RV64IOpcode::I_LOAD, + funct3: 0b001 as u8, + funct7: 0, + }, + RvInstructions::LW => RvOpcode { + opcode: RV64IOpcode::I_LOAD, + funct3: 0b010 as u8, + funct7: 0, + }, + RvInstructions::LBU => RvOpcode { + opcode: RV64IOpcode::I_LOAD, + funct3: 0b100 as u8, + funct7: 0, + }, + RvInstructions::LHU => RvOpcode { + opcode: RV64IOpcode::I_LOAD, + funct3: 0b101 as u8, + funct7: 0, + }, + // TODO add more + _ => RvOpcode::default(), + } + } +} + +impl From for u64 { + fn from(ins: RvInstructions) -> Self { + let opcode: RvOpcode = ins.into(); + opcode.into() + } +} diff --git a/singer-utils/src/structs.rs b/singer-utils/src/structs.rs index 50416aa08..c1c81caa2 100644 --- a/singer-utils/src/structs.rs +++ b/singer-utils/src/structs.rs @@ -13,6 +13,7 @@ pub enum RAMType { Stack, Memory, GlobalState, + Register, } #[derive(Clone, Debug, Copy, EnumIter)] @@ -54,3 +55,4 @@ pub type UInt64 = UInt<64, VALUE_BIT_WIDTH>; pub type PCUInt = UInt64; pub type TSUInt = UInt<48, 48>; pub type StackUInt = UInt<{ EVM_STACK_BIT_WIDTH as usize }, { VALUE_BIT_WIDTH as usize }>; +pub type RegisterUInt = UInt64; diff --git a/singer-utils/src/uint/arithmetic.rs b/singer-utils/src/uint/arithmetic.rs index b98f17dc5..02d131288 100644 --- a/singer-utils/src/uint/arithmetic.rs +++ b/singer-utils/src/uint/arithmetic.rs @@ -32,9 +32,7 @@ impl UInt { addend_1: &UInt, carry: &[CellId], ) -> Result, UtilError> { - let result: UInt = circuit_builder - .create_cells(Self::N_OPERAND_CELLS) - .try_into()?; + let result: UInt = circuit_builder.create_cells(Self::N_OPERAND_CELLS).try_into()?; for i in 0..Self::N_OPERAND_CELLS { let (a, b, result) = (addend_0.values[i], addend_1.values[i], result.values[i]); @@ -70,9 +68,7 @@ impl UInt { constant: E::BaseField, carry: &[CellId], ) -> Result, UtilError> { - let result: UInt = circuit_builder - .create_cells(Self::N_OPERAND_CELLS) - .try_into()?; + let result: UInt = circuit_builder.create_cells(Self::N_OPERAND_CELLS).try_into()?; // add constant to the first limb circuit_builder.add_const(result.values[0], constant); @@ -125,9 +121,7 @@ impl UInt { addend_1: CellId, carry: &[CellId], ) -> Result, UtilError> { - let result: UInt = circuit_builder - .create_cells(Self::N_OPERAND_CELLS) - .try_into()?; + let result: UInt = circuit_builder.create_cells(Self::N_OPERAND_CELLS).try_into()?; // add small_value to the first limb circuit_builder.add(result.values[0], addend_1, E::BaseField::ONE); @@ -180,13 +174,10 @@ impl UInt { subtrahend: &UInt, borrow: &[CellId], ) -> Result, UtilError> { - let result: UInt = circuit_builder - .create_cells(Self::N_OPERAND_CELLS) - .try_into()?; + let result: UInt = circuit_builder.create_cells(Self::N_OPERAND_CELLS).try_into()?; for i in 0..Self::N_OPERAND_CELLS { - let (minuend, subtrahend, result) = - (minuend.values[i], subtrahend.values[i], result.values[i]); + let (minuend, subtrahend, result) = (minuend.values[i], subtrahend.values[i], result.values[i]); circuit_builder.add(result, minuend, E::BaseField::ONE); circuit_builder.add(result, subtrahend, -E::BaseField::ONE); @@ -235,11 +226,7 @@ impl UInt { // handle overflow carry // we need to subtract the carry value from the current result if limb_index < carry.len() { - circuit_builder.add( - result_cell_id, - carry[limb_index], - -E::BaseField::from(1 << C), - ); + circuit_builder.add(result_cell_id, carry[limb_index], -E::BaseField::from(1 << C)); } // handle last operation carry @@ -275,11 +262,7 @@ impl UInt { // handle borrow // we need to add borrow units of C to the result if limb_index < borrow.len() { - circuit_builder.add( - result_cell_id, - borrow[limb_index], - E::BaseField::from(1 << C), - ); + circuit_builder.add(result_cell_id, borrow[limb_index], E::BaseField::from(1 << C)); } // handle last borrow @@ -321,19 +304,15 @@ mod tests { // input wires // addend_0, addend_1, carry - let (addend_0_id, addend_0_cells) = - circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); - let (addend_1_id, addend_1_cells) = - circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); - let (carry_id, carry_cells) = - circuit_builder.create_witness_in(AddSubConstants::::N_CARRY_CELLS); + let (addend_0_id, addend_0_cells) = circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); + let (addend_1_id, addend_1_cells) = circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); + let (carry_id, carry_cells) = circuit_builder.create_witness_in(AddSubConstants::::N_CARRY_CELLS); let addend_0 = UInt20::try_from(addend_0_cells).expect("should build uint"); let addend_1 = UInt20::try_from(addend_1_cells).expect("should build uint"); // update circuit builder with circuit instructions - let _ = - UInt20::add_unsafe(&mut circuit_builder, &addend_0, &addend_1, &carry_cells).unwrap(); + let _ = UInt20::add_unsafe(&mut circuit_builder, &addend_0, &addend_1, &carry_cells).unwrap(); circuit_builder.configure(); let circuit = Circuit::new(&circuit_builder); @@ -373,10 +352,7 @@ mod tests { let result_values = circuit_witness.output_layer_witness_ref().instances[0].to_vec(); assert_eq!( result_values, - [14, 17, 31, 14] - .into_iter() - .map(|v| Goldilocks::from(v)) - .collect_vec() + [14, 17, 31, 14].into_iter().map(|v| Goldilocks::from(v)).collect_vec() ); } @@ -398,21 +374,13 @@ mod tests { // input wires // addend_0, carry, constant - let (addend_0_id, addend_0_cells) = - circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); - let (carry_id, carry_cells) = - circuit_builder.create_witness_in(AddSubConstants::::N_CARRY_CELLS); + let (addend_0_id, addend_0_cells) = circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); + let (carry_id, carry_cells) = circuit_builder.create_witness_in(AddSubConstants::::N_CARRY_CELLS); let addend_0 = UInt20::try_from(addend_0_cells).expect("should build uint"); // update circuit builder - let _ = UInt20::add_const_unsafe( - &mut circuit_builder, - &addend_0, - Goldilocks::from(200), - &carry_cells, - ) - .unwrap(); + let _ = UInt20::add_const_unsafe(&mut circuit_builder, &addend_0, Goldilocks::from(200), &carry_cells).unwrap(); circuit_builder.configure(); let circuit = Circuit::new(&circuit_builder); @@ -446,10 +414,7 @@ mod tests { let result_values = circuit_witness.output_layer_witness_ref().instances[0].to_vec(); assert_eq!( result_values, - [22, 2, 0, 15] - .into_iter() - .map(|v| Goldilocks::from(v)) - .collect_vec() + [22, 2, 0, 15].into_iter().map(|v| Goldilocks::from(v)).collect_vec() ); } @@ -471,22 +436,14 @@ mod tests { // input wires // addend_0, carry, constant - let (addend_0_id, addend_0_cells) = - circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); + let (addend_0_id, addend_0_cells) = circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); let (small_value_id, small_value_cell) = circuit_builder.create_witness_in(1); - let (carry_id, carry_cells) = - circuit_builder.create_witness_in(AddSubConstants::::N_CARRY_CELLS); + let (carry_id, carry_cells) = circuit_builder.create_witness_in(AddSubConstants::::N_CARRY_CELLS); let addend_0 = UInt20::try_from(addend_0_cells).expect("should build uint"); // update circuit builder - let _ = UInt20::add_cell_unsafe( - &mut circuit_builder, - &addend_0, - small_value_cell[0], - &carry_cells, - ) - .unwrap(); + let _ = UInt20::add_cell_unsafe(&mut circuit_builder, &addend_0, small_value_cell[0], &carry_cells).unwrap(); circuit_builder.configure(); let circuit = Circuit::new(&circuit_builder); @@ -497,10 +454,7 @@ mod tests { .rev() .map(|v| Goldilocks::from(v)) .collect_vec(); - let small_value_witness = vec![200] - .into_iter() - .map(|v| Goldilocks::from(v)) - .collect_vec(); + let small_value_witness = vec![200].into_iter().map(|v| Goldilocks::from(v)).collect_vec(); let carry_witness = vec![0, 1, 1, 6] .into_iter() .rev() @@ -525,10 +479,7 @@ mod tests { let result_values = circuit_witness.output_layer_witness_ref().instances[0].to_vec(); assert_eq!( result_values, - [22, 2, 0, 15] - .into_iter() - .map(|v| Goldilocks::from(v)) - .collect_vec() + [22, 2, 0, 15].into_iter().map(|v| Goldilocks::from(v)).collect_vec() ); } @@ -547,20 +498,16 @@ mod tests { // input wires // minuend, subtrahend, borrow - let (minuend_id, minuend_cells) = - circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); - let (subtrahend_id, subtrahend_cells) = - circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); + let (minuend_id, minuend_cells) = circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); + let (subtrahend_id, subtrahend_cells) = circuit_builder.create_witness_in(UInt20::N_OPERAND_CELLS); // |Carry| == |Borrow| - let (borrow_id, borrow_cells) = - circuit_builder.create_witness_in(AddSubConstants::::N_CARRY_CELLS); + let (borrow_id, borrow_cells) = circuit_builder.create_witness_in(AddSubConstants::::N_CARRY_CELLS); let minuend = UInt20::try_from(minuend_cells).expect("should build uint"); let subtrahend = UInt20::try_from(subtrahend_cells).expect("should build uint"); // update the circuit builder - let _ = - UInt20::sub_unsafe(&mut circuit_builder, &minuend, &subtrahend, &borrow_cells).unwrap(); + let _ = UInt20::sub_unsafe(&mut circuit_builder, &minuend, &subtrahend, &borrow_cells).unwrap(); circuit_builder.configure(); let circuit = Circuit::new(&circuit_builder); @@ -600,10 +547,7 @@ mod tests { let result_values = circuit_witness.output_layer_witness_ref().instances[0].to_vec(); assert_eq!( result_values, - [20, 30, 21, 3] - .into_iter() - .map(|v| Goldilocks::from(v)) - .collect_vec() + [20, 30, 21, 3].into_iter().map(|v| Goldilocks::from(v)).collect_vec() ); } } diff --git a/singer-utils/src/uint/cmp.rs b/singer-utils/src/uint/cmp.rs index 48920337b..073f14ebe 100644 --- a/singer-utils/src/uint/cmp.rs +++ b/singer-utils/src/uint/cmp.rs @@ -20,11 +20,7 @@ impl UInt { let range_values = Self::extract_range_values(witness); let computed_diff = Self::sub_unsafe(circuit_builder, operand_0, operand_1, borrow)?; - let diff = range_chip_handler.range_check_uint( - circuit_builder, - &computed_diff, - Some(&range_values), - )?; + let diff = range_chip_handler.range_check_uint(circuit_builder, &computed_diff, Some(&range_values))?; // if operand_0 < operand_1, the last borrow should equal 1 if borrow.len() == AddSubConstants::::N_CARRY_CELLS { @@ -42,13 +38,7 @@ impl UInt { operand_1: &UInt, witness: &[CellId], ) -> Result<(), UtilError> { - let (borrow, _) = Self::lt( - circuit_builder, - range_chip_handler, - operand_0, - operand_1, - witness, - )?; + let (borrow, _) = Self::lt(circuit_builder, range_chip_handler, operand_0, operand_1, witness)?; circuit_builder.assert_const(borrow, 1); Ok(()) } @@ -61,13 +51,7 @@ impl UInt { operand_1: &UInt, witness: &[CellId], ) -> Result<(), UtilError> { - let (borrow, diff) = Self::lt( - circuit_builder, - range_chip_handler, - operand_0, - operand_1, - witness, - )?; + let (borrow, diff) = Self::lt(circuit_builder, range_chip_handler, operand_0, operand_1, witness)?; // we have two scenarios // 1. eq @@ -81,12 +65,7 @@ impl UInt { let diff_values = diff.values(); for d in diff_values.iter() { let s = circuit_builder.create_cell(); - circuit_builder.sel_mixed( - s, - (*d).into(), - MixedCell::Constant(E::BaseField::ZERO), - borrow, - ); + circuit_builder.sel_mixed(s, (*d).into(), MixedCell::Constant(E::BaseField::ZERO), borrow); circuit_builder.assert_const(s, 0); } diff --git a/singer-utils/src/uint/constants.rs b/singer-utils/src/uint/constants.rs index b36405760..15498f167 100644 --- a/singer-utils/src/uint/constants.rs +++ b/singer-utils/src/uint/constants.rs @@ -53,8 +53,7 @@ impl AddSubConstants> { /// The size of the witness assuming carry has no overflow /// |Range_values| + |Carry - 1| - pub const N_WITNESS_CELLS_NO_CARRY_OVERFLOW: usize = - UInt::::N_RANGE_CELLS + Self::N_CARRY_CELLS_NO_OVERFLOW; + pub const N_WITNESS_CELLS_NO_CARRY_OVERFLOW: usize = UInt::::N_RANGE_CELLS + Self::N_CARRY_CELLS_NO_OVERFLOW; pub const N_NO_OVERFLOW_WITNESS_UNSAFE_CELLS: usize = Self::N_CARRY_CELLS_NO_OVERFLOW; diff --git a/singer-utils/src/uint/uint.rs b/singer-utils/src/uint/uint.rs index bfb9ec00b..8d532e7fd 100644 --- a/singer-utils/src/uint/uint.rs +++ b/singer-utils/src/uint/uint.rs @@ -27,12 +27,7 @@ impl UInt { circuit_builder: &mut CircuitBuilder, range_values: &[CellId], ) -> Result { - Self::from_different_sized_cell_values( - circuit_builder, - range_values, - RANGE_CHIP_BIT_WIDTH, - true, - ) + Self::from_different_sized_cell_values(circuit_builder, range_values, RANGE_CHIP_BIT_WIDTH, true) } /// Builds a `UInt` instance from a set of cells that represent big-endian `BYTE_VALUES` @@ -57,12 +52,7 @@ impl UInt { bytes: &[CellId], is_little_endian: bool, ) -> Result { - Self::from_different_sized_cell_values( - circuit_builder, - bytes, - BYTE_BIT_WIDTH, - is_little_endian, - ) + Self::from_different_sized_cell_values(circuit_builder, bytes, BYTE_BIT_WIDTH, is_little_endian) } /// Builds a `UInt` instance from a set of cell values of a certain `CELL_WIDTH` @@ -153,9 +143,7 @@ mod tests { let mut circuit_builder = CircuitBuilder::::new(); let (_, small_values) = circuit_builder.create_witness_in(8); type UInt30 = UInt<30, 6>; - let _ = - UInt30::from_different_sized_cell_values(&mut circuit_builder, &small_values, 2, true) - .unwrap(); + let _ = UInt30::from_different_sized_cell_values(&mut circuit_builder, &small_values, 2, true).unwrap(); circuit_builder.configure(); let circuit = Circuit::new(&circuit_builder); @@ -196,10 +184,7 @@ mod tests { // padding to power of 2 assert_eq!( &output[5..], - vec![0, 0, 0] - .into_iter() - .map(|v| Goldilocks::from(v)) - .collect_vec() + vec![0, 0, 0].into_iter().map(|v| Goldilocks::from(v)).collect_vec() ); } @@ -224,35 +209,15 @@ mod tests { res, vec![ // 0 - vec![ - Goldilocks::from(0), - Goldilocks::from(0), - Goldilocks::from(0) - ], + vec![Goldilocks::from(0), Goldilocks::from(0), Goldilocks::from(0)], // 1 - vec![ - Goldilocks::from(1), - Goldilocks::from(0), - Goldilocks::from(0) - ], + vec![Goldilocks::from(1), Goldilocks::from(0), Goldilocks::from(0)], // 2 - vec![ - Goldilocks::from(0), - Goldilocks::from(1), - Goldilocks::from(0) - ], + vec![Goldilocks::from(0), Goldilocks::from(1), Goldilocks::from(0)], // 3 - vec![ - Goldilocks::from(1), - Goldilocks::from(1), - Goldilocks::from(0) - ], + vec![Goldilocks::from(1), Goldilocks::from(1), Goldilocks::from(0)], // 4 - vec![ - Goldilocks::from(0), - Goldilocks::from(0), - Goldilocks::from(1) - ], + vec![Goldilocks::from(0), Goldilocks::from(0), Goldilocks::from(1)], ] ); } diff --git a/singer-utils/src/uint/util.rs b/singer-utils/src/uint/util.rs index 775419791..b486fdae1 100644 --- a/singer-utils/src/uint/util.rs +++ b/singer-utils/src/uint/util.rs @@ -49,11 +49,7 @@ pub fn convert_decomp( let big_cell = circuit_builder.create_cell(); for (small_chunk_index, small_bit_cell) in values.iter().enumerate() { let shift_size = small_chunk_index * small_cell_bit_width; - circuit_builder.add( - big_cell, - *small_bit_cell, - E::BaseField::from(1 << shift_size), - ); + circuit_builder.add(big_cell, *small_bit_cell, E::BaseField::from(1 << shift_size)); } new_cell_ids.push(big_cell); } @@ -62,11 +58,7 @@ pub fn convert_decomp( } /// Pads a `Vec` with new cells to reach some given size n -pub fn pad_cells( - circuit_builder: &mut CircuitBuilder, - cells: &mut Vec, - size: usize, -) { +pub fn pad_cells(circuit_builder: &mut CircuitBuilder, cells: &mut Vec, size: usize) { if cells.len() < size { cells.extend(circuit_builder.create_cells(size - cells.len())) } @@ -119,14 +111,7 @@ mod tests { let (_, big_values) = circuit_builder.create_witness_in(5); let big_bit_width = 5; let small_bit_width = 2; - let _ = convert_decomp( - &mut circuit_builder, - &big_values, - big_bit_width, - small_bit_width, - true, - ) - .unwrap(); + let _ = convert_decomp(&mut circuit_builder, &big_values, big_bit_width, small_bit_width, true).unwrap(); } #[test] @@ -203,10 +188,7 @@ mod tests { // padding to power of 2 assert_eq!( &output[3..], - vec![0] - .into_iter() - .map(|v| Goldilocks::from(v)) - .collect_vec() + vec![0].into_iter().map(|v| Goldilocks::from(v)).collect_vec() ); } @@ -240,77 +222,49 @@ mod tests { let updated_limbs = add_one_to_big_num(limb_modulo, &initial_limbs); assert_eq!( updated_limbs, - vec![ - Goldilocks::from(1), - Goldilocks::from(0), - Goldilocks::from(0) - ] + vec![Goldilocks::from(1), Goldilocks::from(0), Goldilocks::from(0)] ); // 010 let updated_limbs = add_one_to_big_num(limb_modulo, &updated_limbs); assert_eq!( updated_limbs, - vec![ - Goldilocks::from(0), - Goldilocks::from(1), - Goldilocks::from(0) - ] + vec![Goldilocks::from(0), Goldilocks::from(1), Goldilocks::from(0)] ); // 110 let updated_limbs = add_one_to_big_num(limb_modulo, &updated_limbs); assert_eq!( updated_limbs, - vec![ - Goldilocks::from(1), - Goldilocks::from(1), - Goldilocks::from(0) - ] + vec![Goldilocks::from(1), Goldilocks::from(1), Goldilocks::from(0)] ); // 001 let updated_limbs = add_one_to_big_num(limb_modulo, &updated_limbs); assert_eq!( updated_limbs, - vec![ - Goldilocks::from(0), - Goldilocks::from(0), - Goldilocks::from(1) - ] + vec![Goldilocks::from(0), Goldilocks::from(0), Goldilocks::from(1)] ); // 101 let updated_limbs = add_one_to_big_num(limb_modulo, &updated_limbs); assert_eq!( updated_limbs, - vec![ - Goldilocks::from(1), - Goldilocks::from(0), - Goldilocks::from(1) - ] + vec![Goldilocks::from(1), Goldilocks::from(0), Goldilocks::from(1)] ); // 011 let updated_limbs = add_one_to_big_num(limb_modulo, &updated_limbs); assert_eq!( updated_limbs, - vec![ - Goldilocks::from(0), - Goldilocks::from(1), - Goldilocks::from(1) - ] + vec![Goldilocks::from(0), Goldilocks::from(1), Goldilocks::from(1)] ); // 111 let updated_limbs = add_one_to_big_num(limb_modulo, &updated_limbs); assert_eq!( updated_limbs, - vec![ - Goldilocks::from(1), - Goldilocks::from(1), - Goldilocks::from(1) - ] + vec![Goldilocks::from(1), Goldilocks::from(1), Goldilocks::from(1)] ); // restart cycle @@ -318,11 +272,7 @@ mod tests { let updated_limbs = add_one_to_big_num(limb_modulo, &updated_limbs); assert_eq!( updated_limbs, - vec![ - Goldilocks::from(0), - Goldilocks::from(0), - Goldilocks::from(0) - ] + vec![Goldilocks::from(0), Goldilocks::from(0), Goldilocks::from(0)] ); } } diff --git a/singer-utils/src/uint/witness_extractors.rs b/singer-utils/src/uint/witness_extractors.rs index 32bf8c409..e70b2e8cb 100644 --- a/singer-utils/src/uint/witness_extractors.rs +++ b/singer-utils/src/uint/witness_extractors.rs @@ -1,5 +1,4 @@ -use crate::uint::constants::AddSubConstants; -use crate::uint::uint::UInt; +use crate::uint::{constants::AddSubConstants, uint::UInt}; use simple_frontend::structs::CellId; // TODO: split this into different impls, constrained by specific contexts diff --git a/singer/Cargo.toml b/singer/Cargo.toml index e81d0bb27..d956ed167 100644 --- a/singer/Cargo.toml +++ b/singer/Cargo.toml @@ -39,11 +39,21 @@ const_env = "0.1.2" [features] witness-count = [] test-dbg = [] -dbg-add-opcode = [] +dbg-opcode = [] [[bench]] name = "add" harness = false +[[bench]] +name = "rv_add" +harness = false +path = "benches/riscv/add.rs" + +[[example]] +name = "rv_add" +harness = false +path = "examples/riscv/add.rs" + [profile.bench] opt-level = 0 \ No newline at end of file diff --git a/singer/benches/add.rs b/singer/benches/add.rs index d674f9795..b4477f831 100644 --- a/singer/benches/add.rs +++ b/singer/benches/add.rs @@ -66,8 +66,7 @@ fn bench_add(c: &mut Criterion) { } }; let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); for instance_num_vars in 10..14 { // expand more input size once runtime is acceptable @@ -83,32 +82,27 @@ fn bench_add(c: &mut Criterion) { let mut rng = test_rng(); let singer_builder = SingerGraphBuilder::::new(); let real_challenges = vec![E::random(&mut rng), E::random(&mut rng)]; - (rng, singer_builder, real_challenges) + (rng, singer_builder, real_challenges) }, - |(mut rng,mut singer_builder, real_challenges)| { + |(mut rng, mut singer_builder, real_challenges)| { let size = AddInstruction::phase0_size(); - let phase0: CircuitWiresIn<::BaseField> = vec![LayerWitness { - instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| { - ::BaseField::random( - &mut rng, - ) - }) - .collect_vec() - }) - .collect_vec(), - }]; - + let phase0: CircuitWiresIn<::BaseField> = + vec![LayerWitness { + instances: (0..(1 << instance_num_vars)) + .map(|_| { + (0..size) + .map(|_| ::BaseField::random(&mut rng)) + .collect_vec() + }) + .collect_vec(), + }]; let timer = Instant::now(); let _ = AddInstruction::construct_graph_and_witness( &mut singer_builder.graph_builder, &mut singer_builder.chip_builder, - &circuit_builder.insts_circuits - [>::OPCODE as usize], + &circuit_builder.insts_circuits[>::OPCODE as usize], vec![phase0], &real_challenges, 1 << instance_num_vars, @@ -143,7 +137,8 @@ fn bench_add(c: &mut Criterion) { instance_num_vars, timer.elapsed().as_secs_f64() ); - }); + }, + ); }, ); diff --git a/singer/benches/riscv/add.rs b/singer/benches/riscv/add.rs new file mode 100644 index 000000000..3696024ab --- /dev/null +++ b/singer/benches/riscv/add.rs @@ -0,0 +1,151 @@ +#![allow(clippy::manual_memcpy)] +#![allow(clippy::needless_range_loop)] + +use std::time::{Duration, Instant}; + +use ark_std::test_rng; +use const_env::from_env; +use criterion::*; + +use ff_ext::{ff::Field, ExtensionField}; +use gkr::structs::LayerWitness; +use goldilocks::GoldilocksExt2; +use itertools::Itertools; + +cfg_if::cfg_if! { + if #[cfg(feature = "flamegraph")] { + criterion_group! { + name = op_add; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)).with_profiler(pprof::criterion::PProfProfiler::new(100, pprof::criterion::Output::Flamegraph(None))); + targets = bench_add + } + } else { + criterion_group! { + name = op_add; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_add + } + } +} + +criterion_main!(op_add); + +const NUM_SAMPLES: usize = 10; +#[from_env] +const RAYON_NUM_THREADS: usize = 8; + +use singer::{ + instructions::{self, riscv::add::AddInstruction, Instruction, InstructionGraph, SingerCircuitBuilder}, + scheme::GKRGraphProverState, + CircuitWiresIn, SingerGraphBuilder, SingerParams, +}; +use singer_utils::structs::ChipChallenges; +use transcript::Transcript; + +pub fn is_power_of_2(x: usize) -> bool { + (x != 0) && ((x & (x - 1)) == 0) +} + +fn bench_add(c: &mut Criterion) { + let max_thread_id = { + if !is_power_of_2(RAYON_NUM_THREADS) { + #[cfg(not(feature = "non_pow2_rayon_thread"))] + { + panic!("add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool"); + } + + #[cfg(feature = "non_pow2_rayon_thread")] + { + use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; + let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); + create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); + max_thread_id + } + } else { + RAYON_NUM_THREADS + } + }; + let chip_challenges = ChipChallenges::default(); + let circuit_builder = SingerCircuitBuilder::::new_riscv(chip_challenges).expect("circuit builder failed"); + + for instance_num_vars in 11..12 { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("add_op_{}", instance_num_vars)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new("prove_keccak256", format!("keccak256_log2_{}", instance_num_vars)), + |b| { + b.iter_with_setup( + || { + let mut rng = test_rng(); + let singer_builder = SingerGraphBuilder::::new(); + + let real_challenges = vec![E::random(&mut rng), E::random(&mut rng)]; + (rng, singer_builder, real_challenges) + }, + |(mut rng, mut singer_builder, real_challenges)| { + let size = AddInstruction::phase0_size(); + + let phase0: CircuitWiresIn<::BaseField> = + vec![LayerWitness { + instances: (0..(1 << instance_num_vars)) + .map(|_| { + (0..size) + .map(|_| ::BaseField::random(&mut rng)) + .collect_vec() + }) + .collect_vec(), + }]; + + let timer = Instant::now(); + + let _ = AddInstruction::construct_graph_and_witness( + &mut singer_builder.graph_builder, + &mut singer_builder.chip_builder, + &circuit_builder.insts_circuits[instructions::riscv::add::RV_INSTRUCTION as usize], + vec![phase0], + &real_challenges, + 1 << instance_num_vars, + &SingerParams::default(), + ) + .expect("gkr graph construction failed"); + + let (graph, wit) = singer_builder.graph_builder.finalize_graph_and_witness(); + + println!( + "AddInstruction::construct_graph_and_witness, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + + let point = vec![E::random(&mut rng), E::random(&mut rng)]; + let target_evals = graph.target_evals(&wit, &point); + + let mut prover_transcript = &mut Transcript::new(b"Singer"); + + let timer = Instant::now(); + let _ = GKRGraphProverState::prove( + &graph, + &wit, + &target_evals, + &mut prover_transcript, + (1 << instance_num_vars).min(max_thread_id), + ) + .expect("prove failed"); + println!( + "AddInstruction::prove, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + }, + ); + }, + ); + + group.finish(); + } + + type E = GoldilocksExt2; +} diff --git a/singer/examples/add.rs b/singer/examples/add.rs index 552d0e331..35062c55f 100644 --- a/singer/examples/add.rs +++ b/singer/examples/add.rs @@ -23,26 +23,11 @@ use transcript::Transcript; fn get_single_instance_values_map() -> BTreeMap<&'static str, Vec> { let mut phase0_values_map = BTreeMap::<&'static str, Vec>::new(); - phase0_values_map.insert( - AddInstruction::phase0_pc_str(), - vec![Goldilocks::from(1u64)], - ); - phase0_values_map.insert( - AddInstruction::phase0_stack_ts_str(), - vec![Goldilocks::from(3u64)], - ); - phase0_values_map.insert( - AddInstruction::phase0_memory_ts_str(), - vec![Goldilocks::from(1u64)], - ); - phase0_values_map.insert( - AddInstruction::phase0_stack_top_str(), - vec![Goldilocks::from(100u64)], - ); - phase0_values_map.insert( - AddInstruction::phase0_clk_str(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert(AddInstruction::phase0_pc_str(), vec![Goldilocks::from(1u64)]); + phase0_values_map.insert(AddInstruction::phase0_stack_ts_str(), vec![Goldilocks::from(3u64)]); + phase0_values_map.insert(AddInstruction::phase0_memory_ts_str(), vec![Goldilocks::from(1u64)]); + phase0_values_map.insert(AddInstruction::phase0_stack_top_str(), vec![Goldilocks::from(100u64)]); + phase0_values_map.insert(AddInstruction::phase0_clk_str(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( AddInstruction::phase0_pc_add_str(), vec![], // carry is 0, may test carry using larger values in PCUInt @@ -54,14 +39,10 @@ fn get_single_instance_values_map() -> BTreeMap<&'static str, Vec> { * cells are range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), - Goldilocks::from(0u64), // no place for carry ], ); - phase0_values_map.insert( - AddInstruction::phase0_old_stack_ts0_str(), - vec![Goldilocks::from(2u64)], - ); + phase0_values_map.insert(AddInstruction::phase0_old_stack_ts0_str(), vec![Goldilocks::from(2u64)]); let m: u64 = (1 << TSUInt::C) - 1; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -73,10 +54,7 @@ fn get_single_instance_values_map() -> BTreeMap<&'static str, Vec> { Goldilocks::from(1u64), // borrow ], ); - phase0_values_map.insert( - AddInstruction::phase0_old_stack_ts1_str(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert(AddInstruction::phase0_old_stack_ts1_str(), vec![Goldilocks::from(1u64)]); let m: u64 = (1 << TSUInt::C) - 2; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -88,25 +66,16 @@ fn get_single_instance_values_map() -> BTreeMap<&'static str, Vec> { Goldilocks::from(1u64), // borrow ], ); - let m: u64 = (1 << StackUInt::C) - 1; - phase0_values_map.insert( - AddInstruction::phase0_addend_0_str(), - vec![Goldilocks::from(m)], - ); - phase0_values_map.insert( - AddInstruction::phase0_addend_1_str(), - vec![Goldilocks::from(1u64)], - ); + let m: u64 = (1 << StackUInt::MAX_CELL_BIT_WIDTH) - 1; + phase0_values_map.insert(AddInstruction::phase0_addend_0_str(), vec![Goldilocks::from(m)]); + phase0_values_map.insert(AddInstruction::phase0_addend_1_str(), vec![Goldilocks::from(1u64)]); let range_values = u64vec::<{ StackUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m + 1); let mut wit_phase0_instruction_add: Vec = vec![]; for i in 0..16 { wit_phase0_instruction_add.push(Goldilocks::from(range_values[i])) } wit_phase0_instruction_add.push(Goldilocks::from(1u64)); // carry is [1, 0, ...] - phase0_values_map.insert( - AddInstruction::phase0_instruction_add_str(), - wit_phase0_instruction_add, - ); + phase0_values_map.insert(AddInstruction::phase0_instruction_add_str(), wit_phase0_instruction_add); phase0_values_map } fn main() { @@ -114,8 +83,7 @@ fn main() { let instance_num_vars = 11; type E = GoldilocksExt2; let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); @@ -126,14 +94,8 @@ fn main() { let mut single_witness_in = vec![::BaseField::ZERO; size]; for key in phase0_idx_map.keys() { - let range = phase0_idx_map - .get(key) - .unwrap() - .clone() - .collect::>(); - let values = phase0_values_map - .get(key) - .expect(&("unknown key ".to_owned() + key)); + let range = phase0_idx_map.get(key).unwrap().clone().collect::>(); + let values = phase0_values_map.get(key).expect(&("unknown key ".to_owned() + key)); for (value_idx, cell_idx) in range.into_iter().enumerate() { if value_idx < values.len() { single_witness_in[cell_idx] = values[value_idx]; @@ -141,12 +103,11 @@ fn main() { } } - let phase0: CircuitWiresIn<::BaseField> = - vec![LayerWitness { - instances: (0..(1 << instance_num_vars)) - .map(|_| single_witness_in.clone()) - .collect_vec(), - }]; + let phase0: CircuitWiresIn<::BaseField> = vec![LayerWitness { + instances: (0..(1 << instance_num_vars)) + .map(|_| single_witness_in.clone()) + .collect_vec(), + }]; let real_challenges = vec![E::random(&mut rng), E::random(&mut rng)]; @@ -173,12 +134,7 @@ fn main() { let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() - .with( - fmt::layer() - .compact() - .with_thread_ids(false) - .with_thread_names(false), - ) + .with(fmt::layer().compact().with_thread_ids(false).with_thread_names(false)) .with(EnvFilter::from_default_env()) .with(flame_layer.with_threads_collapsed(true)); tracing::subscriber::set_global_default(subscriber).unwrap(); diff --git a/singer/examples/push_and_pop.rs b/singer/examples/push_and_pop.rs index 9e1c7963f..70bb2dc4d 100644 --- a/singer/examples/push_and_pop.rs +++ b/singer/examples/push_and_pop.rs @@ -10,8 +10,7 @@ use transcript::Transcript; fn main() { let chip_challenges = ChipChallenges::default(); - let circuit_builder = SingerCircuitBuilder::::new(chip_challenges) - .expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let singer_builder = SingerGraphBuilder::::new(); let bytecode = [0x60 as u8, 0x01, 0x50]; diff --git a/singer/examples/riscv/add.rs b/singer/examples/riscv/add.rs new file mode 100644 index 000000000..6849b23e6 --- /dev/null +++ b/singer/examples/riscv/add.rs @@ -0,0 +1,206 @@ +use std::{collections::BTreeMap, time::Instant}; + +use ark_std::test_rng; +use ff_ext::{ff::Field, ExtensionField}; +use gkr::structs::LayerWitness; +use gkr_graph::structs::CircuitGraphAuxInfo; +use goldilocks::{Goldilocks, GoldilocksExt2}; +use itertools::Itertools; + +use simple_frontend::structs::CellId; +use singer::{ + instructions::{ + riscv::add::{AddInstruction, RV_INSTRUCTION}, + InstructionGraph, SingerCircuitBuilder, + }, + scheme::{GKRGraphProverState, GKRGraphVerifierState}, + u64vec, CircuitWiresIn, SingerGraphBuilder, SingerParams, +}; +use singer_utils::{ + constants::RANGE_CHIP_BIT_WIDTH, + structs::{ChipChallenges, TSUInt, UInt64}, +}; +use tracing_flame::FlameLayer; +use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry}; +use transcript::Transcript; + +fn get_single_instance_values_map() -> BTreeMap<&'static str, Vec> { + let mut phase0_values_map = BTreeMap::<&'static str, Vec>::new(); + phase0_values_map.insert(AddInstruction::phase0_pc_str(), vec![Goldilocks::from(1u64)]); + phase0_values_map.insert(AddInstruction::phase0_memory_ts_str(), vec![Goldilocks::from(3u64)]); + phase0_values_map.insert( + AddInstruction::phase0_next_memory_ts_str(), + vec![ + // first TSUInt::N_RANGE_CELLS = 1*(48/16) = 3 cells are range values. + // memory_ts + 1 = 4 + Goldilocks::from(4u64), + Goldilocks::from(0u64), + Goldilocks::from(0u64), + ], + ); + phase0_values_map.insert(AddInstruction::phase0_clk_str(), vec![Goldilocks::from(1u64)]); + phase0_values_map.insert( + AddInstruction::phase0_next_pc_str(), + vec![], // carry is 0, may test carry using larger values in PCUInt + ); + + // register id assigned + phase0_values_map.insert(AddInstruction::phase0_rs1_str(), vec![Goldilocks::from(1u64)]); + phase0_values_map.insert(AddInstruction::phase0_rs2_str(), vec![Goldilocks::from(2u64)]); + phase0_values_map.insert(AddInstruction::phase0_rd_str(), vec![Goldilocks::from(3u64)]); + + let m: u64 = (1 << UInt64::MAX_CELL_BIT_WIDTH) - 1; + phase0_values_map.insert(AddInstruction::phase0_addend_0_str(), vec![Goldilocks::from(m)]); + phase0_values_map.insert(AddInstruction::phase0_addend_1_str(), vec![Goldilocks::from(1u64)]); + let range_values = u64vec::<{ UInt64::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m + 1); + let mut wit_phase0_outcome: Vec = vec![]; + for i in 0..4 { + wit_phase0_outcome.push(Goldilocks::from(range_values[i])) + } + wit_phase0_outcome.push(Goldilocks::from(1u64)); // carry is [1, 0, ...] + phase0_values_map.insert(AddInstruction::phase0_outcome_str(), wit_phase0_outcome); + + phase0_values_map.insert( + AddInstruction::phase0_prev_rd_value_str(), + vec![Goldilocks::from(33u64)], + ); + + phase0_values_map.insert(AddInstruction::phase0_prev_rs1_ts_str(), vec![Goldilocks::from(2u64)]); + let m: u64 = (1 << TSUInt::MAX_CELL_BIT_WIDTH) - 1; + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + phase0_values_map.insert( + AddInstruction::phase0_prev_rs1_ts_lt_str(), + vec![ + Goldilocks::from(range_values[0]), + Goldilocks::from(range_values[1]), + Goldilocks::from(range_values[2]), + Goldilocks::from(1u64), // borrow + ], + ); + phase0_values_map.insert(AddInstruction::phase0_prev_rs2_ts_str(), vec![Goldilocks::from(1u64)]); + let m: u64 = (1 << TSUInt::MAX_CELL_BIT_WIDTH) - 2; + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + phase0_values_map.insert( + AddInstruction::phase0_prev_rs2_ts_lt_str(), + vec![ + Goldilocks::from(range_values[0]), + Goldilocks::from(range_values[1]), + Goldilocks::from(range_values[2]), + Goldilocks::from(1u64), // borrow + ], + ); + phase0_values_map.insert(AddInstruction::phase0_prev_rd_ts_str(), vec![Goldilocks::from(2u64)]); + let m: u64 = (1 << TSUInt::MAX_CELL_BIT_WIDTH) - 1; + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + phase0_values_map.insert( + AddInstruction::phase0_prev_rd_ts_lt_str(), + vec![ + Goldilocks::from(range_values[0]), + Goldilocks::from(range_values[1]), + Goldilocks::from(range_values[2]), + Goldilocks::from(1u64), // borrow + ], + ); + + phase0_values_map +} + +fn main() { + let max_thread_id = 8; + let instance_num_vars = 11; + type E = GoldilocksExt2; + let chip_challenges = ChipChallenges::default(); + let circuit_builder = SingerCircuitBuilder::::new_riscv(chip_challenges).expect("circuit builder failed"); + let mut singer_builder = SingerGraphBuilder::::new(); + + let mut rng = test_rng(); + let size = AddInstruction::phase0_size(); + let phase0_values_map = get_single_instance_values_map(); + let phase0_idx_map = AddInstruction::phase0_idxes_map(); + + let mut single_witness_in = vec![::BaseField::ZERO; size]; + + for key in phase0_idx_map.keys() { + let range = phase0_idx_map.get(key).unwrap().clone().collect::>(); + let values = phase0_values_map.get(key).expect(&("unknown key ".to_owned() + key)); + for (value_idx, cell_idx) in range.into_iter().enumerate() { + if value_idx < values.len() { + single_witness_in[cell_idx] = values[value_idx]; + } + } + } + + let phase0: CircuitWiresIn<::BaseField> = vec![LayerWitness { + instances: (0..(1 << instance_num_vars)) + .map(|_| single_witness_in.clone()) + .collect_vec(), + }]; + + let real_challenges = vec![E::random(&mut rng), E::random(&mut rng)]; + + let timer = Instant::now(); + + let circuit = AddInstruction::construct_circuits(chip_challenges).unwrap(); + let _ = AddInstruction::construct_graph_and_witness( + &mut singer_builder.graph_builder, + &mut singer_builder.chip_builder, + &circuit, // circuit_builder.insts_circuits[RV_INSTRUCTION as usize], + vec![phase0], + &real_challenges, + 1 << instance_num_vars, + &SingerParams::default(), + ) + .expect("gkr graph construction failed"); + + let (graph, wit) = singer_builder.graph_builder.finalize_graph_and_witness(); + + println!( + "AddInstruction::construct_graph_and_witness, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + + let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); + let subscriber = Registry::default() + .with(fmt::layer().compact().with_thread_ids(false).with_thread_names(false)) + .with(EnvFilter::from_default_env()) + .with(flame_layer.with_threads_collapsed(true)); + tracing::subscriber::set_global_default(subscriber).unwrap(); + + let point = vec![E::random(&mut rng), E::random(&mut rng)]; + let target_evals = graph.target_evals(&wit, &point); + + for _ in 0..5 { + let mut prover_transcript = &mut Transcript::new(b"Singer"); + let timer = Instant::now(); + let proof = GKRGraphProverState::prove( + &graph, + &wit, + &target_evals, + &mut prover_transcript, + (1 << instance_num_vars).min(max_thread_id), + ) + .expect("prove failed"); + println!( + "AddInstruction::prove, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + let mut verifier_transcript = Transcript::new(b"Singer"); + let _ = GKRGraphVerifierState::verify( + &graph, + &real_challenges, + &target_evals, + proof, + &CircuitGraphAuxInfo { + instance_num_vars: wit + .node_witnesses + .iter() + .map(|witness| witness.instance_num_vars()) + .collect(), + }, + &mut verifier_transcript, + ) + .expect("verify failed"); + } +} diff --git a/singer/src/instructions.rs b/singer/src/instructions.rs index 772c233f0..e9e06b5b1 100644 --- a/singer/src/instructions.rs +++ b/singer/src/instructions.rs @@ -11,10 +11,9 @@ use strum_macros::EnumIter; use crate::{error::ZKVMError, CircuitWiresIn, SingerParams}; use self::{ - add::AddInstruction, calldataload::CalldataloadInstruction, dup::DupInstruction, - gt::GtInstruction, jump::JumpInstruction, jumpdest::JumpdestInstruction, - jumpi::JumpiInstruction, mstore::MstoreInstruction, pop::PopInstruction, push::PushInstruction, - ret::ReturnInstruction, swap::SwapInstruction, + add::AddInstruction, calldataload::CalldataloadInstruction, dup::DupInstruction, gt::GtInstruction, + jump::JumpInstruction, jumpdest::JumpdestInstruction, jumpi::JumpiInstruction, mstore::MstoreInstruction, + pop::PopInstruction, push::PushInstruction, ret::ReturnInstruction, swap::SwapInstruction, }; // arithmetic @@ -41,6 +40,9 @@ pub mod mstore; // system pub mod calldataload; +// risc-v +pub mod riscv; + #[derive(Clone, Debug)] pub struct SingerCircuitBuilder { /// Opcode circuits @@ -54,9 +56,8 @@ impl SingerCircuitBuilder { for opcode in 0..=255 { insts_circuits.push(construct_instruction_circuits(opcode, challenges)?); } - let insts_circuits: [Vec>; 256] = insts_circuits - .try_into() - .map_err(|_| ZKVMError::CircuitError)?; + let insts_circuits: [Vec>; 256] = + insts_circuits.try_into().map_err(|_| ZKVMError::CircuitError)?; Ok(Self { insts_circuits, challenges, @@ -84,6 +85,7 @@ pub(crate) fn construct_instruction_circuits( 0x91 => SwapInstruction::<2>::construct_circuits(challenges), 0x93 => SwapInstruction::<4>::construct_circuits(challenges), 0xF3 => ReturnInstruction::construct_circuits(challenges), + _ => Ok(vec![]), // TODO: Add more instructions. } } @@ -153,13 +155,7 @@ pub(crate) fn construct_inst_graph( _ => unimplemented!(), }; - construct_graph( - graph_builder, - chip_builder, - inst_circuits, - real_n_instances, - params, - ) + construct_graph(graph_builder, chip_builder, inst_circuits, real_n_instances, params) } #[derive(Clone, Copy, Debug, EnumIter)] diff --git a/singer/src/instructions/add.rs b/singer/src/instructions/add.rs index a78b69a59..3b81a4c0e 100644 --- a/singer/src/instructions/add.rs +++ b/singer/src/instructions/add.rs @@ -5,8 +5,8 @@ use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, - RangeChipOperations, StackChipOperations, + BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, + StackChipOperations, }, constants::OpcodeType, register_witness, @@ -75,14 +75,9 @@ impl Instruction for AddInstruction { clk, ); - let next_pc = - ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; - let next_stack_ts = rom_handler.add_ts_with_const( - &mut circuit_builder, - &stack_ts, - 1, - &phase0[Self::phase0_stack_ts_add()], - )?; + let next_pc = ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; + let next_stack_ts = + rom_handler.add_ts_with_const(&mut circuit_builder, &stack_ts, 1, &phase0[Self::phase0_stack_ts_add()])?; ram_handler.state_out( &mut circuit_builder, @@ -110,10 +105,7 @@ impl Instruction for AddInstruction { )?; // Check the range of stack_top - 2 is within [0, 1 << STACK_TOP_BIT_WIDTH). - rom_handler.range_check_stack_top( - &mut circuit_builder, - stack_top_expr.sub(E::BaseField::from(2)), - )?; + rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::from(2)))?; // Pop two values from stack let old_stack_ts0 = (&phase0[Self::phase0_old_stack_ts0()]).try_into()?; @@ -156,11 +148,7 @@ impl Instruction for AddInstruction { ); // Bytecode check for (pc, add) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); let rom_id = rom_handler.finalize(&mut circuit_builder); @@ -195,9 +183,7 @@ mod test { use transcript::Transcript; use crate::{ - instructions::{ - AddInstruction, ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, - }, + instructions::{AddInstruction, ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder}, scheme::GKRGraphProverState, test::{get_uint_params, test_opcode_circuit_v2}, utils::u64vec, @@ -224,26 +210,11 @@ mod test { println!("{:?}", inst_circuit); let mut phase0_values_map = BTreeMap::<&'static str, Vec>::new(); - phase0_values_map.insert( - AddInstruction::phase0_pc_str(), - vec![Goldilocks::from(1u64)], - ); - phase0_values_map.insert( - AddInstruction::phase0_stack_ts_str(), - vec![Goldilocks::from(3u64)], - ); - phase0_values_map.insert( - AddInstruction::phase0_memory_ts_str(), - vec![Goldilocks::from(1u64)], - ); - phase0_values_map.insert( - AddInstruction::phase0_stack_top_str(), - vec![Goldilocks::from(100u64)], - ); - phase0_values_map.insert( - AddInstruction::phase0_clk_str(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert(AddInstruction::phase0_pc_str(), vec![Goldilocks::from(1u64)]); + phase0_values_map.insert(AddInstruction::phase0_stack_ts_str(), vec![Goldilocks::from(3u64)]); + phase0_values_map.insert(AddInstruction::phase0_memory_ts_str(), vec![Goldilocks::from(1u64)]); + phase0_values_map.insert(AddInstruction::phase0_stack_top_str(), vec![Goldilocks::from(100u64)]); + phase0_values_map.insert(AddInstruction::phase0_clk_str(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( AddInstruction::phase0_pc_add_str(), vec![], // carry is 0, may test carry using larger values in PCUInt @@ -258,10 +229,7 @@ mod test { // no place for carry ], ); - phase0_values_map.insert( - AddInstruction::phase0_old_stack_ts0_str(), - vec![Goldilocks::from(2u64)], - ); + phase0_values_map.insert(AddInstruction::phase0_old_stack_ts0_str(), vec![Goldilocks::from(2u64)]); let m: u64 = (1 << get_uint_params::().1) - 1; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -273,10 +241,7 @@ mod test { Goldilocks::from(1u64), ], ); - phase0_values_map.insert( - AddInstruction::phase0_old_stack_ts1_str(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert(AddInstruction::phase0_old_stack_ts1_str(), vec![Goldilocks::from(1u64)]); let m: u64 = (1 << get_uint_params::().1) - 2; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -289,24 +254,15 @@ mod test { ], ); let m: u64 = (1 << get_uint_params::().1) - 1; - phase0_values_map.insert( - AddInstruction::phase0_addend_0_str(), - vec![Goldilocks::from(m)], - ); - phase0_values_map.insert( - AddInstruction::phase0_addend_1_str(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert(AddInstruction::phase0_addend_0_str(), vec![Goldilocks::from(m)]); + phase0_values_map.insert(AddInstruction::phase0_addend_1_str(), vec![Goldilocks::from(1u64)]); let range_values = u64vec::<{ StackUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m + 1); let mut wit_phase0_instruction_add: Vec = vec![]; for i in 0..16 { wit_phase0_instruction_add.push(Goldilocks::from(range_values[i])) } wit_phase0_instruction_add.push(Goldilocks::from(1u64)); // carry is [1, 0, ...] - phase0_values_map.insert( - AddInstruction::phase0_instruction_add_str(), - wit_phase0_instruction_add, - ); + phase0_values_map.insert(AddInstruction::phase0_instruction_add_str(), wit_phase0_instruction_add); // The actual challenges used is: // challenges @@ -326,19 +282,14 @@ mod test { #[cfg(not(debug_assertions))] fn bench_add_instruction_helper(instance_num_vars: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); let size = AddInstruction::phase0_size(); let phase0: CircuitWiresIn = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| E::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| (0..size).map(|_| E::BaseField::random(&mut rng)).collect_vec()) .collect_vec(), }]; @@ -371,8 +322,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "AddInstruction::prove, instance_num_vars = {}, time = {}", instance_num_vars, diff --git a/singer/src/instructions/calldataload.rs b/singer/src/instructions/calldataload.rs index 10ec4f0d7..aaee8edf7 100644 --- a/singer/src/instructions/calldataload.rs +++ b/singer/src/instructions/calldataload.rs @@ -5,8 +5,8 @@ use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, CalldataChipOperations, GlobalStateChipOperations, OAMOperations, - ROMOperations, RangeChipOperations, StackChipOperations, + BytecodeChipOperations, CalldataChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, + RangeChipOperations, StackChipOperations, }, constants::OpcodeType, register_witness, @@ -71,14 +71,9 @@ impl Instruction for CalldataloadInstruction { clk, ); - let next_pc = - ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; - let next_stack_ts = rom_handler.add_ts_with_const( - &mut circuit_builder, - &stack_ts, - 1, - &phase0[Self::phase0_stack_ts_add()], - )?; + let next_pc = ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; + let next_stack_ts = + rom_handler.add_ts_with_const(&mut circuit_builder, &stack_ts, 1, &phase0[Self::phase0_stack_ts_add()])?; ram_handler.state_out( &mut circuit_builder, @@ -90,10 +85,7 @@ impl Instruction for CalldataloadInstruction { ); // Range check for stack top - rom_handler.range_check_stack_top( - &mut circuit_builder, - stack_top_expr.sub(E::BaseField::from(1)), - )?; + rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::from(1)))?; // Stack pop offset from the stack. let old_stack_ts = TSUInt::try_from(&phase0[Self::phase0_old_stack_ts()])?; @@ -125,11 +117,7 @@ impl Instruction for CalldataloadInstruction { ); // Bytecode table (pc, CalldataLoad) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); let rom_id = rom_handler.finalize(&mut circuit_builder); @@ -161,10 +149,7 @@ mod test { use transcript::Transcript; use crate::{ - instructions::{ - CalldataloadInstruction, ChipChallenges, Instruction, InstructionGraph, - SingerCircuitBuilder, - }, + instructions::{CalldataloadInstruction, ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder}, scheme::GKRGraphProverState, test::{get_uint_params, test_opcode_circuit}, utils::u64vec, @@ -195,10 +180,7 @@ mod test { phase0_values_map.insert("phase0_ts".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(3u64)]); phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert( - "phase0_stack_top".to_string(), - vec![Goldilocks::from(100u64)], - ); + phase0_values_map.insert("phase0_stack_top".to_string(), vec![Goldilocks::from(100u64)]); phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( "phase0_pc_add".to_string(), @@ -215,10 +197,7 @@ mod test { // no place for carry ], ); - phase0_values_map.insert( - "phase0_old_stack_ts".to_string(), - vec![Goldilocks::from(2u64)], - ); + phase0_values_map.insert("phase0_old_stack_ts".to_string(), vec![Goldilocks::from(2u64)]); let m: u64 = (1 << get_uint_params::().1) - 1; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -264,19 +243,14 @@ mod test { #[cfg(not(debug_assertions))] fn bench_calldataload_instruction_helper(instance_num_vars: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); let size = CalldataloadInstruction::phase0_size(); let phase0: CircuitWiresIn = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| E::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| (0..size).map(|_| E::BaseField::random(&mut rng)).collect_vec()) .collect_vec(), }]; @@ -287,8 +261,7 @@ mod test { let _ = CalldataloadInstruction::construct_graph_and_witness( &mut singer_builder.graph_builder, &mut singer_builder.chip_builder, - &circuit_builder.insts_circuits - [>::OPCODE as usize], + &circuit_builder.insts_circuits[>::OPCODE as usize], vec![phase0], &real_challenges, 1 << instance_num_vars, @@ -310,8 +283,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "CalldataloadInstruction::prove, instance_num_vars = {}, time = {}", instance_num_vars, diff --git a/singer/src/instructions/dup.rs b/singer/src/instructions/dup.rs index 7feac3e69..ac56d358c 100644 --- a/singer/src/instructions/dup.rs +++ b/singer/src/instructions/dup.rs @@ -5,8 +5,8 @@ use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, - RangeChipOperations, StackChipOperations, + BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, + StackChipOperations, }, constants::OpcodeType, register_witness, @@ -77,14 +77,9 @@ impl Instruction for DupInstruction { clk, ); - let next_pc = - ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; - let next_stack_ts = rom_handler.add_ts_with_const( - &mut circuit_builder, - &stack_ts, - 1, - &phase0[Self::phase0_stack_ts_add()], - )?; + let next_pc = ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; + let next_stack_ts = + rom_handler.add_ts_with_const(&mut circuit_builder, &stack_ts, 1, &phase0[Self::phase0_stack_ts_add()])?; ram_handler.state_out( &mut circuit_builder, @@ -96,10 +91,7 @@ impl Instruction for DupInstruction { ); // Check the range of stack_top - N is within [0, 1 << STACK_TOP_BIT_WIDTH). - rom_handler.range_check_stack_top( - &mut circuit_builder, - stack_top_expr.sub(E::BaseField::from(N as u64)), - )?; + rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::from(N as u64)))?; // Pop rlc of stack[top - N] from stack let old_stack_ts = (&phase0[Self::phase0_old_stack_ts()]).try_into()?; @@ -127,19 +119,10 @@ impl Instruction for DupInstruction { stack_ts.values(), stack_values, ); - ram_handler.stack_push( - &mut circuit_builder, - stack_top_expr, - stack_ts.values(), - stack_values, - ); + ram_handler.stack_push(&mut circuit_builder, stack_top_expr, stack_ts.values(), stack_values); // Bytecode check for (pc, DUP{N}) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); let rom_id = rom_handler.finalize(&mut circuit_builder); @@ -171,9 +154,7 @@ mod test { use transcript::Transcript; use crate::{ - instructions::{ - ChipChallenges, DupInstruction, Instruction, InstructionGraph, SingerCircuitBuilder, - }, + instructions::{ChipChallenges, DupInstruction, Instruction, InstructionGraph, SingerCircuitBuilder}, scheme::GKRGraphProverState, test::{get_uint_params, test_opcode_circuit}, utils::u64vec, @@ -203,10 +184,7 @@ mod test { phase0_values_map.insert("phase0_pc".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(2u64)]); phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert( - "phase0_stack_top".to_string(), - vec![Goldilocks::from(100u64)], - ); + phase0_values_map.insert("phase0_stack_top".to_string(), vec![Goldilocks::from(100u64)]); phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( "phase0_pc_add".to_string(), @@ -236,10 +214,7 @@ mod test { Goldilocks::from(0u64), ], ); - phase0_values_map.insert( - "phase0_old_stack_ts".to_string(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert("phase0_old_stack_ts".to_string(), vec![Goldilocks::from(1u64)]); let m: u64 = (1 << get_uint_params::().1) - 1; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -271,19 +246,14 @@ mod test { #[cfg(not(debug_assertions))] fn bench_dup_instruction_helper(instance_num_vars: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); let size = DupInstruction::::phase0_size(); let phase0: CircuitWiresIn = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| E::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| (0..size).map(|_| E::BaseField::random(&mut rng)).collect_vec()) .collect_vec(), }]; @@ -317,8 +287,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "Dup{}Instruction::prove, instance_num_vars = {}, time = {}", N, diff --git a/singer/src/instructions/gt.rs b/singer/src/instructions/gt.rs index 71fc4f983..fa2c43daf 100644 --- a/singer/src/instructions/gt.rs +++ b/singer/src/instructions/gt.rs @@ -5,8 +5,8 @@ use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, - RangeChipOperations, StackChipOperations, + BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, + StackChipOperations, }, constants::OpcodeType, register_witness, @@ -74,14 +74,9 @@ impl Instruction for GtInstruction { clk, ); - let next_pc = - ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; - let next_stack_ts = rom_handler.add_ts_with_const( - &mut circuit_builder, - &stack_ts, - 1, - &phase0[Self::phase0_stack_ts_add()], - )?; + let next_pc = ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; + let next_stack_ts = + rom_handler.add_ts_with_const(&mut circuit_builder, &stack_ts, 1, &phase0[Self::phase0_stack_ts_add()])?; ram_handler.state_out( &mut circuit_builder, @@ -104,10 +99,7 @@ impl Instruction for GtInstruction { )?; // Check the range of stack_top - 2 is within [0, 1 << STACK_TOP_BIT_WIDTH). - rom_handler.range_check_stack_top( - &mut circuit_builder, - stack_top_expr.sub(E::BaseField::from(2)), - )?; + rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::from(2)))?; // Pop two values from stack let old_stack_ts0 = (&phase0[Self::phase0_old_stack_ts0()]).try_into()?; @@ -150,11 +142,7 @@ impl Instruction for GtInstruction { ); // Bytecode check for (pc, gt) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); let rom_id = rom_handler.finalize(&mut circuit_builder); @@ -186,9 +174,7 @@ mod test { use transcript::Transcript; use crate::{ - instructions::{ - ChipChallenges, GtInstruction, Instruction, InstructionGraph, SingerCircuitBuilder, - }, + instructions::{ChipChallenges, GtInstruction, Instruction, InstructionGraph, SingerCircuitBuilder}, scheme::GKRGraphProverState, test::{get_uint_params, test_opcode_circuit}, utils::u64vec, @@ -218,10 +204,7 @@ mod test { phase0_values_map.insert("phase0_pc".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(3u64)]); phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert( - "phase0_stack_top".to_string(), - vec![Goldilocks::from(100u64)], - ); + phase0_values_map.insert("phase0_stack_top".to_string(), vec![Goldilocks::from(100u64)]); phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( "phase0_pc_add".to_string(), @@ -238,10 +221,7 @@ mod test { // no place for carry ], ); - phase0_values_map.insert( - "phase0_old_stack_ts0".to_string(), - vec![Goldilocks::from(2u64)], - ); + phase0_values_map.insert("phase0_old_stack_ts0".to_string(), vec![Goldilocks::from(2u64)]); let m: u64 = (1 << get_uint_params::().1) - 1; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -254,10 +234,7 @@ mod test { Goldilocks::from(1u64), ], ); - phase0_values_map.insert( - "phase0_old_stack_ts1".to_string(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert("phase0_old_stack_ts1".to_string(), vec![Goldilocks::from(1u64)]); let m: u64 = (1 << get_uint_params::().1) - 2; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -299,19 +276,14 @@ mod test { #[cfg(not(debug_assertions))] fn bench_gt_instruction_helper(instance_num_vars: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); let size = GtInstruction::phase0_size(); let phase0: CircuitWiresIn = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| E::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| (0..size).map(|_| E::BaseField::random(&mut rng)).collect_vec()) .collect_vec(), }]; @@ -344,8 +316,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "GtInstruction::prove, instance_num_vars = {}, time = {}", instance_num_vars, diff --git a/singer/src/instructions/jump.rs b/singer/src/instructions/jump.rs index 7421542e3..8be5ddd2a 100644 --- a/singer/src/instructions/jump.rs +++ b/singer/src/instructions/jump.rs @@ -7,8 +7,8 @@ use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, - RangeChipOperations, StackChipOperations, + BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, + StackChipOperations, }, constants::OpcodeType, register_witness, @@ -70,8 +70,7 @@ impl Instruction for JumpInstruction { ); // Pop next pc from stack - rom_handler - .range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::ONE))?; + rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::ONE))?; let next_pc = &phase0[Self::phase0_next_pc()]; let old_stack_ts = (&phase0[Self::phase0_old_stack_ts()]).try_into()?; @@ -99,11 +98,7 @@ impl Instruction for JumpInstruction { ); // Bytecode check for (pc, jump) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); // Bytecode check for (next_pc, jumpdest) rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, &next_pc, OpcodeType::JUMPDEST); @@ -170,19 +165,13 @@ mod test { phase0_values_map.insert("phase0_pc".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(2u64)]); phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert( - "phase0_stack_top".to_string(), - vec![Goldilocks::from(100u64)], - ); + phase0_values_map.insert("phase0_stack_top".to_string(), vec![Goldilocks::from(100u64)]); phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( "phase0_next_pc".to_string(), vec![Goldilocks::from(127u64), Goldilocks::from(125u64)], ); - phase0_values_map.insert( - "phase0_old_stack_ts".to_string(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert("phase0_old_stack_ts".to_string(), vec![Goldilocks::from(1u64)]); let m: u64 = (1 << get_uint_params::().1) - 1; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -214,19 +203,14 @@ mod test { #[cfg(not(debug_assertions))] fn bench_jump_instruction_helper(instance_num_vars: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); let size = JumpInstruction::phase0_size(); let phase0: CircuitWiresIn = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| E::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| (0..size).map(|_| E::BaseField::random(&mut rng)).collect_vec()) .collect_vec(), }]; @@ -259,8 +243,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "JumpInstruction::prove, instance_num_vars = {}, time = {}", instance_num_vars, diff --git a/singer/src/instructions/jumpdest.rs b/singer/src/instructions/jumpdest.rs index e1b49e329..38cf44af5 100644 --- a/singer/src/instructions/jumpdest.rs +++ b/singer/src/instructions/jumpdest.rs @@ -4,9 +4,7 @@ use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ - chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, - }, + chip_handler::{BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations}, constants::OpcodeType, register_witness, structs::{PCUInt, RAMHandler, ROMHandler, TSUInt}, @@ -62,8 +60,7 @@ impl Instruction for JumpdestInstruction { clk, ); - let next_pc = - ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; + let next_pc = ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; ram_handler.state_out( &mut circuit_builder, next_pc.values(), @@ -74,11 +71,7 @@ impl Instruction for JumpdestInstruction { ); // Bytecode check for (pc, jump) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); let rom_id = rom_handler.finalize(&mut circuit_builder); @@ -109,10 +102,7 @@ mod test { use transcript::Transcript; use crate::{ - instructions::{ - ChipChallenges, Instruction, InstructionGraph, JumpdestInstruction, - SingerCircuitBuilder, - }, + instructions::{ChipChallenges, Instruction, InstructionGraph, JumpdestInstruction, SingerCircuitBuilder}, scheme::GKRGraphProverState, test::test_opcode_circuit, CircuitWiresIn, SingerGraphBuilder, SingerParams, @@ -141,10 +131,7 @@ mod test { phase0_values_map.insert("phase0_pc".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert( - "phase0_stack_top".to_string(), - vec![Goldilocks::from(100u64)], - ); + phase0_values_map.insert("phase0_stack_top".to_string(), vec![Goldilocks::from(100u64)]); phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( "phase0_pc_add".to_string(), @@ -169,19 +156,14 @@ mod test { #[cfg(not(debug_assertions))] fn bench_jumpdest_instruction_helper(instance_num_vars: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); let size = JumpdestInstruction::phase0_size(); let phase0: CircuitWiresIn = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| E::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| (0..size).map(|_| E::BaseField::random(&mut rng)).collect_vec()) .collect_vec(), }]; @@ -192,8 +174,7 @@ mod test { let _ = JumpdestInstruction::construct_graph_and_witness( &mut singer_builder.graph_builder, &mut singer_builder.chip_builder, - &circuit_builder.insts_circuits - [>::OPCODE as usize], + &circuit_builder.insts_circuits[>::OPCODE as usize], vec![phase0], &real_challenges, 1 << instance_num_vars, @@ -215,8 +196,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "JumpdestInstruction::prove, instance_num_vars = {}, time = {}", instance_num_vars, diff --git a/singer/src/instructions/jumpi.rs b/singer/src/instructions/jumpi.rs index 62e34d21f..ef7c8c8a2 100644 --- a/singer/src/instructions/jumpi.rs +++ b/singer/src/instructions/jumpi.rs @@ -6,8 +6,8 @@ use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, - RangeChipOperations, StackChipOperations, + BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, + StackChipOperations, }, constants::OpcodeType, register_witness, @@ -77,10 +77,7 @@ impl Instruction for JumpiInstruction { ); // Range check stack_top - 2 - rom_handler.range_check_stack_top( - &mut circuit_builder, - stack_top_expr.sub(E::BaseField::from(2)), - )?; + rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::from(2)))?; // Pop the destination pc from stack. let dest_values = &phase0[Self::phase0_dest_values()]; @@ -131,8 +128,7 @@ impl Instruction for JumpiInstruction { .iter() .for_each(|x| circuit_builder.add(non_zero_or, *x, E::BaseField::ONE)); let cond_non_zero_or_inv = phase0[Self::phase0_cond_non_zero_or_inv().start]; - let cond_non_zero = - rom_handler.non_zero(&mut circuit_builder, non_zero_or, cond_non_zero_or_inv)?; + let cond_non_zero = rom_handler.non_zero(&mut circuit_builder, non_zero_or, cond_non_zero_or_inv)?; // If cond_non_zero, next_pc = dest, otherwise, pc = pc + 1 let pc_add_1 = &phase0[Self::phase0_pc_add()]; @@ -154,11 +150,7 @@ impl Instruction for JumpiInstruction { ); // Bytecode check for (pc, jumpi) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); // If cond_non_zero, next_opcode = JUMPDEST, otherwise, opcode = pc + 1 opcode let pc_plus_1_opcode = phase0[Self::phase0_pc_plus_1_opcode().start]; diff --git a/singer/src/instructions/mstore.rs b/singer/src/instructions/mstore.rs index 485203d8b..aa0a55596 100644 --- a/singer/src/instructions/mstore.rs +++ b/singer/src/instructions/mstore.rs @@ -6,8 +6,8 @@ use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, MemoryChipOperations, OAMOperations, - ROMOperations, RangeChipOperations, StackChipOperations, + BytecodeChipOperations, GlobalStateChipOperations, MemoryChipOperations, OAMOperations, ROMOperations, + RangeChipOperations, StackChipOperations, }, chips::SingerChipBuilder, constants::{OpcodeType, EVM_STACK_BYTE_WIDTH}, @@ -122,11 +122,8 @@ impl InstructionGraph for MstoreInstruction { preds[mstore_acc_circuit.layout.pred_ooo_wire_id.unwrap() as usize] = PredType::PredWire( NodeOutputType::WireOut(inst_node_id, inst_circuit.layout.succ_ooo_wires_id[0]), ); - let mstore_acc_node_id = graph_builder.add_node( - stringify!(MstoreAccessory), - &mstore_acc_circuit.circuit, - preds, - )?; + let mstore_acc_node_id = + graph_builder.add_node(stringify!(MstoreAccessory), &mstore_acc_circuit.circuit, preds)?; chip_builder.construct_chip_check_graph( graph_builder, mstore_acc_node_id, @@ -184,8 +181,7 @@ impl Instruction for MstoreInstruction { clk, ); - let next_pc = - ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; + let next_pc = ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; let next_memory_ts = rom_handler.add_ts_with_const( &mut circuit_builder, &memory_ts, @@ -201,10 +197,7 @@ impl Instruction for MstoreInstruction { clk_expr.add(E::BaseField::ONE), ); - rom_handler.range_check_stack_top( - &mut circuit_builder, - stack_top_expr.sub(E::BaseField::from(2)), - )?; + rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::from(2)))?; // Pop offset from stack let offset = StackUInt::try_from(&phase0[Self::phase0_offset()])?; @@ -244,15 +237,10 @@ impl Instruction for MstoreInstruction { ); // Bytecode check for (pc, mstore) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); // To accessory - let (to_acc_dup_id, to_acc_dup) = - circuit_builder.create_witness_out(MstoreAccessory::pred_dup_size()); + let (to_acc_dup_id, to_acc_dup) = circuit_builder.create_witness_out(MstoreAccessory::pred_dup_size()); add_assign_each_cell( &mut circuit_builder, &to_acc_dup[MstoreAccessory::pred_dup_memory_ts()], @@ -264,8 +252,8 @@ impl Instruction for MstoreInstruction { offset.values(), ); - let (to_acc_ooo_id, to_acc_ooo) = circuit_builder - .create_witness_out(MstoreAccessory::pred_ooo_size() * EVM_STACK_BYTE_WIDTH); + let (to_acc_ooo_id, to_acc_ooo) = + circuit_builder.create_witness_out(MstoreAccessory::pred_ooo_size() * EVM_STACK_BYTE_WIDTH); add_assign_each_cell(&mut circuit_builder, &to_acc_ooo, mem_bytes); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); @@ -308,9 +296,7 @@ register_witness!( ); impl MstoreAccessory { - fn construct_circuit( - challenges: ChipChallenges, - ) -> Result, ZKVMError> { + fn construct_circuit(challenges: ChipChallenges) -> Result, ZKVMError> { let mut circuit_builder = CircuitBuilder::new(); // From predesessor circuit. @@ -331,13 +317,8 @@ impl MstoreAccessory { let offset = StackUInt::try_from(&pred_dup[Self::pred_dup_offset()])?; let offset_add_delta = &phase0[Self::phase0_offset_add_delta()]; let delta = circuit_builder.create_counter_in(0)[0]; - let offset_plus_delta = StackUInt::add_cell( - &mut circuit_builder, - &mut rom_handler, - &offset, - delta, - offset_add_delta, - )?; + let offset_plus_delta = + StackUInt::add_cell(&mut circuit_builder, &mut rom_handler, &offset, delta, offset_add_delta)?; TSUInt::assert_lt( &mut circuit_builder, &mut rom_handler, @@ -378,9 +359,7 @@ impl MstoreAccessory { #[cfg(test)] mod test { - use crate::{ - instructions::InstructionGraph, scheme::GKRGraphProverState, utils::u64vec, SingerParams, - }; + use crate::{instructions::InstructionGraph, scheme::GKRGraphProverState, utils::u64vec, SingerParams}; use ark_std::test_rng; use ff::Field; use ff_ext::ExtensionField; @@ -427,10 +406,7 @@ mod test { phase0_values_map.insert("phase0_pc".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(3u64)]); phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(3u64)]); - phase0_values_map.insert( - "phase0_stack_top".to_string(), - vec![Goldilocks::from(100u64)], - ); + phase0_values_map.insert("phase0_stack_top".to_string(), vec![Goldilocks::from(100u64)]); phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( "phase0_pc_add".to_string(), @@ -448,10 +424,7 @@ mod test { ], ); phase0_values_map.insert("phase0_offset".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert( - "phase0_old_stack_ts_offset".to_string(), - vec![Goldilocks::from(2u64)], - ); + phase0_values_map.insert("phase0_old_stack_ts_offset".to_string(), vec![Goldilocks::from(2u64)]); let m: u64 = (1 << get_uint_params::().1) - 1; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -468,10 +441,7 @@ mod test { "phase0_mem_bytes".to_string(), vec![], // use 32-byte 0 for mem_bytes ); - phase0_values_map.insert( - "phase0_old_stack_ts_value".to_string(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert("phase0_old_stack_ts_value".to_string(), vec![Goldilocks::from(1u64)]); let m: u64 = (1 << get_uint_params::().1) - 2; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -503,8 +473,7 @@ mod test { #[cfg(not(debug_assertions))] fn bench_mstore_instruction_helper(instance_num_vars: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); @@ -562,8 +531,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "MstoreInstruction::prove, instance_num_vars = {}, time = {}", instance_num_vars, diff --git a/singer/src/instructions/pop.rs b/singer/src/instructions/pop.rs index e17928afc..613e0f802 100644 --- a/singer/src/instructions/pop.rs +++ b/singer/src/instructions/pop.rs @@ -5,8 +5,8 @@ use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, - RangeChipOperations, StackChipOperations, + BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, + StackChipOperations, }, constants::OpcodeType, register_witness, @@ -67,8 +67,7 @@ impl Instruction for PopInstruction { clk, ); - let next_pc = - ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; + let next_pc = ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; ram_handler.state_out( &mut circuit_builder, next_pc.values(), @@ -79,8 +78,7 @@ impl Instruction for PopInstruction { ); // Check the range of stack_top - 1 is within [0, 1 << STACK_TOP_BIT_WIDTH). - rom_handler - .range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::ONE))?; + rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::ONE))?; // Pop rlc from stack let old_stack_ts = (&phase0[Self::phase0_old_stack_ts()]).try_into()?; @@ -100,11 +98,7 @@ impl Instruction for PopInstruction { ); // Bytecode check for (pc, POP) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); let rom_id = rom_handler.finalize(&mut circuit_builder); @@ -171,19 +165,13 @@ mod test { phase0_values_map.insert("phase0_pc".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(2u64)]); phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert( - "phase0_stack_top".to_string(), - vec![Goldilocks::from(100u64)], - ); + phase0_values_map.insert("phase0_stack_top".to_string(), vec![Goldilocks::from(100u64)]); phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( "phase0_pc_add".to_string(), vec![], // carry is 0, may test carry using larger values in PCUInt ); - phase0_values_map.insert( - "phase0_old_stack_ts".to_string(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert("phase0_old_stack_ts".to_string(), vec![Goldilocks::from(1u64)]); let m: u64 = (1 << get_uint_params::().1) - 1; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -228,19 +216,14 @@ mod test { #[cfg(not(debug_assertions))] fn bench_pop_instruction_helper(instance_num_vars: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); let size = PopInstruction::phase0_size(); let phase0: CircuitWiresIn = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| E::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| (0..size).map(|_| E::BaseField::random(&mut rng)).collect_vec()) .collect_vec(), }]; @@ -273,8 +256,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "PopInstruction::prove, instance_num_vars = {}, time = {}", instance_num_vars, diff --git a/singer/src/instructions/push.rs b/singer/src/instructions/push.rs index a8ba4186a..d94a630bc 100644 --- a/singer/src/instructions/push.rs +++ b/singer/src/instructions/push.rs @@ -5,8 +5,8 @@ use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, - RangeChipOperations, StackChipOperations, + BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, + StackChipOperations, }, constants::OpcodeType, register_witness, @@ -78,12 +78,8 @@ impl Instruction for PushInstruction { N as i64 + 1, &phase0[Self::phase0_pc_add_i_plus_1()], )?; - let next_stack_ts = rom_handler.add_ts_with_const( - &mut circuit_builder, - &stack_ts, - 1, - &phase0[Self::phase0_stack_ts_add()], - )?; + let next_stack_ts = + rom_handler.add_ts_with_const(&mut circuit_builder, &stack_ts, 1, &phase0[Self::phase0_stack_ts_add()])?; ram_handler.state_out( &mut circuit_builder, @@ -108,22 +104,13 @@ impl Instruction for PushInstruction { ); // Bytecode check for (pc, PUSH{N}), (pc + 1, byte[0]), ..., (pc + N, byte[N - 1]) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); for (i, pc_add_i_plus_1) in phase0[Self::phase0_pc_add_i_plus_1()] .chunks(AddSubConstants::::N_NO_OVERFLOW_WITNESS_UNSAFE_CELLS) .enumerate() { - let next_pc = - ROMHandler::add_pc_const(&mut circuit_builder, &pc, i as i64 + 1, pc_add_i_plus_1)?; - rom_handler.bytecode_with_pc_byte( - &mut circuit_builder, - next_pc.values(), - stack_bytes[i], - ); + let next_pc = ROMHandler::add_pc_const(&mut circuit_builder, &pc, i as i64 + 1, pc_add_i_plus_1)?; + rom_handler.bytecode_with_pc_byte(&mut circuit_builder, next_pc.values(), stack_bytes[i]); } let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); @@ -155,9 +142,7 @@ mod test { use transcript::Transcript; use crate::{ - instructions::{ - ChipChallenges, Instruction, InstructionGraph, PushInstruction, SingerCircuitBuilder, - }, + instructions::{ChipChallenges, Instruction, InstructionGraph, PushInstruction, SingerCircuitBuilder}, scheme::GKRGraphProverState, test::test_opcode_circuit, CircuitWiresIn, SingerGraphBuilder, SingerParams, @@ -185,10 +170,7 @@ mod test { phase0_values_map.insert("phase0_pc".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert( - "phase0_stack_top".to_string(), - vec![Goldilocks::from(100u64)], - ); + phase0_values_map.insert("phase0_stack_top".to_string(), vec![Goldilocks::from(100u64)]); phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( "phase0_pc_add_i_plus_1".to_string(), @@ -237,19 +219,14 @@ mod test { #[cfg(not(debug_assertions))] fn bench_push_instruction_helper(instance_num_vars: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); let size = PushInstruction::::phase0_size(); let phase0: CircuitWiresIn = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| E::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| (0..size).map(|_| E::BaseField::random(&mut rng)).collect_vec()) .collect_vec(), }]; @@ -260,8 +237,7 @@ mod test { let _ = PushInstruction::::construct_graph_and_witness( &mut singer_builder.graph_builder, &mut singer_builder.chip_builder, - &circuit_builder.insts_circuits - [ as Instruction>::OPCODE as usize], + &circuit_builder.insts_circuits[ as Instruction>::OPCODE as usize], vec![phase0], &real_challenges, 1 << instance_num_vars, @@ -284,8 +260,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "Push{}Instruction::prove, instance_num_vars = {}, time = {}", N, diff --git a/singer/src/instructions/ret.rs b/singer/src/instructions/ret.rs index 7ba6c1abe..da75a514d 100644 --- a/singer/src/instructions/ret.rs +++ b/singer/src/instructions/ret.rs @@ -4,16 +4,16 @@ use gkr::structs::Circuit; use gkr_graph::structs::{CircuitGraphBuilder, NodeOutputType, PredType}; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, - RangeChipOperations, StackChipOperations, + BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, + StackChipOperations, }, chips::SingerChipBuilder, constants::OpcodeType, register_witness, structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; use std::{collections::BTreeMap, mem, sync::Arc}; @@ -80,11 +80,9 @@ impl InstructionGraph for ReturnInstruction { let pub_out_load_circuit = &inst_circuits[1]; let n_witness_in = pub_out_load_circuit.circuit.n_witness_in; let mut preds = vec![PredType::Source; n_witness_in]; - preds[pub_out_load_circuit.layout.pred_dup_wire_id.unwrap() as usize] = - PredType::PredWireDup(NodeOutputType::WireOut( - inst_node_id, - inst_circuit.layout.succ_dup_wires_id[0], - )); + preds[pub_out_load_circuit.layout.pred_dup_wire_id.unwrap() as usize] = PredType::PredWireDup( + NodeOutputType::WireOut(inst_node_id, inst_circuit.layout.succ_dup_wires_id[0]), + ); let pub_out_load_node_id = graph_builder.add_node_with_witness( stringify!(ReturnPublicOutLoad), &pub_out_load_circuit.circuit, @@ -190,16 +188,11 @@ impl InstructionGraph for ReturnInstruction { let pub_out_load_circuit = &inst_circuits[1]; let n_witness_in = pub_out_load_circuit.circuit.n_witness_in; let mut preds = vec![PredType::Source; n_witness_in]; - preds[pub_out_load_circuit.layout.pred_dup_wire_id.unwrap() as usize] = - PredType::PredWireDup(NodeOutputType::WireOut( - inst_node_id, - inst_circuit.layout.succ_dup_wires_id[0], - )); - let pub_out_load_node_id = graph_builder.add_node( - stringify!(ReturnPublicOutLoad), - &pub_out_load_circuit.circuit, - preds, - )?; + preds[pub_out_load_circuit.layout.pred_dup_wire_id.unwrap() as usize] = PredType::PredWireDup( + NodeOutputType::WireOut(inst_node_id, inst_circuit.layout.succ_dup_wires_id[0]), + ); + let pub_out_load_node_id = + graph_builder.add_node(stringify!(ReturnPublicOutLoad), &pub_out_load_circuit.circuit, preds)?; chip_builder.construct_chip_check_graph( graph_builder, pub_out_load_node_id, @@ -302,10 +295,7 @@ impl Instruction for ReturnInstruction { ); // Check the range of stack_top - 2 is within [0, 1 << STACK_TOP_BIT_WIDTH). - rom_handler.range_check_stack_top( - &mut circuit_builder, - stack_top_expr.sub(E::BaseField::from(2)), - )?; + rom_handler.range_check_stack_top(&mut circuit_builder, stack_top_expr.sub(E::BaseField::from(2)))?; // Pop offset and mem_size from stack let old_stack_ts0 = TSUInt::try_from(&phase0[Self::phase0_old_stack_ts0()])?; @@ -327,11 +317,7 @@ impl Instruction for ReturnInstruction { ); // Bytecode check for (pc, ret) - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); let rom_id = rom_handler.finalize(&mut circuit_builder); @@ -340,8 +326,7 @@ impl Instruction for ReturnInstruction { let outputs_wire_id = [ram_load_id, ram_store_id, rom_id]; // Copy length to the target wire. - let (target_wire_id, target) = - circuit_builder.create_witness_out(StackUInt::N_OPERAND_CELLS); + let (target_wire_id, target) = circuit_builder.create_witness_out(StackUInt::N_OPERAND_CELLS); let length = length.values(); for i in 1..length.len() { circuit_builder.assert_const(length[i], 0); @@ -351,8 +336,7 @@ impl Instruction for ReturnInstruction { // println!("target: {:?}", target); // Copy offset to wires of public output load circuit. - let (pub_out_wire_id, pub_out) = - circuit_builder.create_witness_out(ReturnPublicOutLoad::pred_size()); + let (pub_out_wire_id, pub_out) = circuit_builder.create_witness_out(ReturnPublicOutLoad::pred_size()); let pub_out_offset = &pub_out[ReturnPublicOutLoad::pred_offset()]; let offset = offset.values(); add_assign_each_cell(&mut circuit_builder, pub_out_offset, offset); @@ -390,9 +374,7 @@ register_witness!( ); impl ReturnPublicOutLoad { - fn construct_circuit( - challenges: ChipChallenges, - ) -> Result, ZKVMError> { + fn construct_circuit(challenges: ChipChallenges) -> Result, ZKVMError> { let mut circuit_builder = CircuitBuilder::new(); let (pred_wire_id, pred) = circuit_builder.create_witness_in(Self::pred_size()); let (phase0_wire_id, phase0) = circuit_builder.create_witness_in(Self::phase0_size()); @@ -453,9 +435,7 @@ register_witness!( ); impl ReturnRestMemLoad { - fn construct_circuit( - challenges: ChipChallenges, - ) -> Result, ZKVMError> { + fn construct_circuit(challenges: ChipChallenges) -> Result, ZKVMError> { let mut circuit_builder = CircuitBuilder::new(); let (phase0_wire_id, phase0) = circuit_builder.create_witness_in(Self::phase0_size()); let mut ram_handler = RAMHandler::new(&challenges); @@ -464,12 +444,7 @@ impl ReturnRestMemLoad { let offset = &phase0[Self::phase0_offset()]; let mem_byte = phase0[Self::phase0_mem_byte().start]; let old_memory_ts = TSUInt::try_from(&phase0[Self::phase0_old_memory_ts()])?; - ram_handler.oam_load( - &mut circuit_builder, - &offset, - old_memory_ts.values(), - &[mem_byte], - ); + ram_handler.oam_load(&mut circuit_builder, &offset, old_memory_ts.values(), &[mem_byte]); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); circuit_builder.configure(); @@ -500,9 +475,7 @@ register_witness!( ); impl ReturnRestMemStore { - fn construct_circuit( - challenges: ChipChallenges, - ) -> Result, ZKVMError> { + fn construct_circuit(challenges: ChipChallenges) -> Result, ZKVMError> { let mut circuit_builder = CircuitBuilder::new(); let (phase0_wire_id, phase0) = circuit_builder.create_witness_in(Self::phase0_size()); let mut ram_handler = RAMHandler::new(&challenges); @@ -544,9 +517,7 @@ register_witness!( ); impl ReturnRestStackPop { - fn construct_circuit( - challenges: ChipChallenges, - ) -> Result, ZKVMError> { + fn construct_circuit(challenges: ChipChallenges) -> Result, ZKVMError> { let mut circuit_builder = CircuitBuilder::new(); let (phase0_wire_id, phase0) = circuit_builder.create_witness_in(Self::phase0_size()); let mut ram_handler = RAMHandler::new(&challenges); diff --git a/singer/src/instructions/riscv.rs b/singer/src/instructions/riscv.rs new file mode 100644 index 000000000..cced7b48f --- /dev/null +++ b/singer/src/instructions/riscv.rs @@ -0,0 +1 @@ +pub mod add; diff --git a/singer/src/instructions/riscv/add.rs b/singer/src/instructions/riscv/add.rs new file mode 100644 index 000000000..fe551c478 --- /dev/null +++ b/singer/src/instructions/riscv/add.rs @@ -0,0 +1,384 @@ +use crate::error::ZKVMError; +use ff::Field; +use ff_ext::ExtensionField; +use gkr::structs::Circuit; +use paste::paste; +use simple_frontend::structs::{CircuitBuilder, MixedCell}; +use singer_utils::{ + chip_handler::{GlobalStateChipOperations, OAMOperations, ROMOperations, RegisterChipOperations}, + constants::OpcodeType, + register_witness, + riscv_constant::RvInstructions, + structs::{PCUInt, RAMHandler, ROMHandler, RegisterUInt, TSUInt, UInt64}, + uint::constants::AddSubConstants, +}; +use std::{collections::BTreeMap, sync::Arc}; + +use super::super::{ChipChallenges, InstCircuit, InstCircuitLayout, Instruction, InstructionGraph}; + +pub struct AddInstruction; + +impl InstructionGraph for AddInstruction { + type InstType = Self; +} + +register_witness!( + AddInstruction, + phase0 { + pc => PCUInt::N_OPERAND_CELLS, + memory_ts => TSUInt::N_OPERAND_CELLS, + clk => 1, + + rs1 => RegisterUInt::N_OPERAND_CELLS, + rs2 => RegisterUInt::N_OPERAND_CELLS, + rd => RegisterUInt::N_OPERAND_CELLS, + + next_pc => AddSubConstants::::N_NO_OVERFLOW_WITNESS_UNSAFE_CELLS, + next_memory_ts => AddSubConstants::::N_WITNESS_CELLS_NO_CARRY_OVERFLOW, + + // instruction operation + addend_0 => UInt64::N_OPERAND_CELLS, + addend_1 => UInt64::N_OPERAND_CELLS, + outcome => AddSubConstants::::N_WITNESS_CELLS, + + // the value pointed by `rd` before being written with `outcome` + prev_rd_value => UInt64::N_OPERAND_CELLS, + + // register timestamps and comparison gadgets + prev_rs1_ts => TSUInt::N_OPERAND_CELLS, + prev_rs2_ts => TSUInt::N_OPERAND_CELLS, + prev_rd_ts => TSUInt::N_OPERAND_CELLS, + prev_rs1_ts_lt => AddSubConstants::::N_WITNESS_CELLS, + prev_rs2_ts_lt => AddSubConstants::::N_WITNESS_CELLS, + prev_rd_ts_lt => AddSubConstants::::N_WITNESS_CELLS + } +); + +// TODO a workaround to keep the risc-v instruction +pub const RV_INSTRUCTION: RvInstructions = RvInstructions::ADD; +impl Instruction for AddInstruction { + // OPCODE is not used in RISC-V case, just for compatibility + const OPCODE: OpcodeType = OpcodeType::RISCV; + const NAME: &'static str = "ADD"; + fn construct_circuit(challenges: ChipChallenges) -> Result, ZKVMError> { + let mut circuit_builder = CircuitBuilder::new(); + let (phase0_wire_id, phase0) = circuit_builder.create_witness_in(Self::phase0_size()); + let mut ram_handler = RAMHandler::new(&challenges); + let mut rom_handler = ROMHandler::new(&challenges); + + let pc = PCUInt::try_from(&phase0[Self::phase0_pc()])?; + let memory_ts = &phase0[Self::phase0_memory_ts()]; + let clk = phase0[Self::phase0_clk().start]; + let clk_expr = MixedCell::Cell(clk); + let zero_cell_ids = [0]; + + // Bytecode check for (pc, add) + rom_handler.bytecode_with_pc(&mut circuit_builder, pc.values(), RV_INSTRUCTION.into()); + + // State update + ram_handler.state_in( + &mut circuit_builder, + pc.values(), + &zero_cell_ids, // we don't have stack info here + &memory_ts, + 0, + clk, + ); + + let next_pc = ROMHandler::increase_pc(&mut circuit_builder, &pc, &phase0[Self::phase0_next_pc()])?; + let next_memory_ts = rom_handler.increase_ts( + &mut circuit_builder, + &memory_ts.try_into()?, + &phase0[Self::phase0_next_memory_ts()], + )?; + + ram_handler.state_out( + &mut circuit_builder, + next_pc.values(), + &zero_cell_ids, + &next_memory_ts.values(), + MixedCell::Cell(0), + clk_expr.add(E::BaseField::ONE), + ); + + // Register timestamp range check + let prev_rs1_ts = (&phase0[Self::phase0_prev_rs1_ts()]).try_into()?; + let prev_rs2_ts = (&phase0[Self::phase0_prev_rs2_ts()]).try_into()?; + let prev_rd_ts = (&phase0[Self::phase0_prev_rd_ts()]).try_into()?; + let memory_ts = (&phase0[Self::phase0_memory_ts()]).try_into()?; + TSUInt::assert_lt( + &mut circuit_builder, + &mut rom_handler, + &prev_rs1_ts, + &memory_ts, + &phase0[Self::phase0_prev_rs1_ts_lt()], + )?; + TSUInt::assert_lt( + &mut circuit_builder, + &mut rom_handler, + &prev_rs2_ts, + &memory_ts, + &phase0[Self::phase0_prev_rs2_ts_lt()], + )?; + TSUInt::assert_lt( + &mut circuit_builder, + &mut rom_handler, + &prev_rd_ts, + &memory_ts, + &phase0[Self::phase0_prev_rd_ts_lt()], + )?; + if cfg!(feature = "dbg-opcode") { + println!("addInstCircuit::phase0_outcome: {:?}", Self::phase0_outcome()); + } + + // Execution result = addend0 + addend1, with carry. + let addend_0 = (&phase0[Self::phase0_addend_0()]).try_into()?; + let addend_1 = (&phase0[Self::phase0_addend_1()]).try_into()?; + let result = UInt64::add( + &mut circuit_builder, + &mut rom_handler, + &addend_0, + &addend_1, + &phase0[Self::phase0_outcome()], + )?; + + // Read/Write from registers + let rs1 = &phase0[Self::phase0_rs1()]; + let rs2 = &phase0[Self::phase0_rs2()]; + let rd = &phase0[Self::phase0_rd()]; + let prev_rd_value = &phase0[Self::phase0_prev_rd_value()]; + ram_handler.register_load( + &mut circuit_builder, + rs1, + prev_rs1_ts.values(), + memory_ts.values(), + &phase0[Self::phase0_addend_0()], + ); + ram_handler.register_load( + &mut circuit_builder, + rs2, + prev_rs2_ts.values(), + memory_ts.values(), + &phase0[Self::phase0_addend_1()], + ); + ram_handler.register_store( + &mut circuit_builder, + rd, + prev_rd_ts.values(), + memory_ts.values(), + &prev_rd_value, + result.values(), + ); + + // Ram/Rom finalization + let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); + let rom_id = rom_handler.finalize(&mut circuit_builder); + + circuit_builder.configure(); + + Ok(InstCircuit { + circuit: Arc::new(Circuit::new(&circuit_builder)), + layout: InstCircuitLayout { + chip_check_wire_id: [ram_load_id, ram_store_id, rom_id], + phases_wire_id: vec![phase0_wire_id], + ..Default::default() + }, + }) + } +} + +#[cfg(test)] +mod test { + use ark_std::test_rng; + use ff::Field; + use ff_ext::ExtensionField; + use gkr::structs::LayerWitness; + use goldilocks::{Goldilocks, GoldilocksExt2}; + use itertools::Itertools; + use singer_utils::{ + constants::RANGE_CHIP_BIT_WIDTH, + structs::{TSUInt, UInt64}, + }; + use std::{collections::BTreeMap, time::Instant}; + use transcript::Transcript; + + use crate::{ + instructions::{ + riscv::add::{AddInstruction, RV_INSTRUCTION}, + ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::{get_uint_params, test_opcode_circuit_v2}, + utils::u64vec, + CircuitWiresIn, SingerGraphBuilder, SingerParams, + }; + + #[test] + fn test_add_construct_circuit() { + let challenges = ChipChallenges::default(); + + let phase0_idx_map = AddInstruction::phase0_idxes_map(); + let phase0_witness_size = AddInstruction::phase0_size(); + + if cfg!(feature = "dbg-opcode") { + println!("ADD: {:?}", &phase0_idx_map); + println!("ADD witness_size: {:?}", phase0_witness_size); + } + + // initialize general test inputs associated with push1 + let inst_circuit = AddInstruction::construct_circuit(challenges).unwrap(); + + if cfg!(feature = "dbg-opcode") { + println!("{:?}", inst_circuit.circuit.assert_consts); + } + + let mut phase0_values_map = BTreeMap::<&'static str, Vec>::new(); + phase0_values_map.insert(AddInstruction::phase0_pc_str(), vec![Goldilocks::from(1u64)]); + phase0_values_map.insert(AddInstruction::phase0_memory_ts_str(), vec![Goldilocks::from(3u64)]); + phase0_values_map.insert( + AddInstruction::phase0_next_memory_ts_str(), + vec![ + // first TSUInt::N_RANGE_CELLS = 1*(48/16) = 3 cells are range values. + // memory_ts + 1 = 4 + Goldilocks::from(4u64), + Goldilocks::from(0u64), + Goldilocks::from(0u64), + ], + ); + phase0_values_map.insert(AddInstruction::phase0_clk_str(), vec![Goldilocks::from(1u64)]); + phase0_values_map.insert( + AddInstruction::phase0_next_pc_str(), + vec![], // carry is 0, may test carry using larger values in PCUInt + ); + + // register id assigned + phase0_values_map.insert(AddInstruction::phase0_rs1_str(), vec![Goldilocks::from(1u64)]); + phase0_values_map.insert(AddInstruction::phase0_rs2_str(), vec![Goldilocks::from(2u64)]); + phase0_values_map.insert(AddInstruction::phase0_rd_str(), vec![Goldilocks::from(3u64)]); + + let m: u64 = (1 << get_uint_params::().1) - 1; + phase0_values_map.insert(AddInstruction::phase0_addend_0_str(), vec![Goldilocks::from(m)]); + phase0_values_map.insert(AddInstruction::phase0_addend_1_str(), vec![Goldilocks::from(1u64)]); + let range_values = u64vec::<{ UInt64::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m + 1); + let mut wit_phase0_outcome: Vec = vec![]; + for i in 0..4 { + wit_phase0_outcome.push(Goldilocks::from(range_values[i])) + } + wit_phase0_outcome.push(Goldilocks::from(1u64)); // carry is [1, 0, ...] + phase0_values_map.insert(AddInstruction::phase0_outcome_str(), wit_phase0_outcome); + + phase0_values_map.insert( + AddInstruction::phase0_prev_rd_value_str(), + vec![Goldilocks::from(33u64)], + ); + + phase0_values_map.insert(AddInstruction::phase0_prev_rs1_ts_str(), vec![Goldilocks::from(2u64)]); + let m: u64 = (1 << get_uint_params::().1) - 1; + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + phase0_values_map.insert( + AddInstruction::phase0_prev_rs1_ts_lt_str(), + vec![ + Goldilocks::from(range_values[0]), + Goldilocks::from(range_values[1]), + Goldilocks::from(range_values[2]), + Goldilocks::from(1u64), // borrow + ], + ); + phase0_values_map.insert(AddInstruction::phase0_prev_rs2_ts_str(), vec![Goldilocks::from(1u64)]); + let m: u64 = (1 << get_uint_params::().1) - 2; + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + phase0_values_map.insert( + AddInstruction::phase0_prev_rs2_ts_lt_str(), + vec![ + Goldilocks::from(range_values[0]), + Goldilocks::from(range_values[1]), + Goldilocks::from(range_values[2]), + Goldilocks::from(1u64), // borrow + ], + ); + phase0_values_map.insert(AddInstruction::phase0_prev_rd_ts_str(), vec![Goldilocks::from(2u64)]); + let m: u64 = (1 << get_uint_params::().1) - 1; + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + phase0_values_map.insert( + AddInstruction::phase0_prev_rd_ts_lt_str(), + vec![ + Goldilocks::from(range_values[0]), + Goldilocks::from(range_values[1]), + Goldilocks::from(range_values[2]), + Goldilocks::from(1u64), // borrow + ], + ); + // The actual challenges used is: + // challenges + // { ChallengeConst { challenge: 1, exp: i }: [Goldilocks(c^i)] } + let c = GoldilocksExt2::from(66u64); + let circuit_witness_challenges = vec![c; 3]; + + test_opcode_circuit_v2( + &inst_circuit, + &phase0_idx_map, + phase0_witness_size, + &phase0_values_map, + circuit_witness_challenges, + ); + } + + #[cfg(not(debug_assertions))] + fn bench_add_instruction_helper(instance_num_vars: usize) { + let chip_challenges = ChipChallenges::default(); + let circuit_builder = SingerCircuitBuilder::::new_riscv(chip_challenges).expect("circuit builder failed"); + let mut singer_builder = SingerGraphBuilder::::new(); + + let mut rng = test_rng(); + let size = AddInstruction::phase0_size(); + let phase0: CircuitWiresIn = vec![LayerWitness { + instances: (0..(1 << instance_num_vars)) + .map(|_| (0..size).map(|_| E::BaseField::random(&mut rng)).collect_vec()) + .collect_vec(), + }]; + + let real_challenges = vec![E::random(&mut rng), E::random(&mut rng)]; + + let timer = Instant::now(); + + let _ = AddInstruction::construct_graph_and_witness( + &mut singer_builder.graph_builder, + &mut singer_builder.chip_builder, + &circuit_builder.insts_circuits[RV_INSTRUCTION as usize], + vec![phase0], + &real_challenges, + 1 << instance_num_vars, + &SingerParams::default(), + ) + .expect("gkr graph construction failed"); + + let (graph, wit) = singer_builder.graph_builder.finalize_graph_and_witness(); + + println!( + "AddInstruction::construct_graph_and_witness, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + + let point = vec![E::random(&mut rng), E::random(&mut rng)]; + let target_evals = graph.target_evals(&wit, &point); + + let mut prover_transcript = &mut Transcript::new(b"Singer"); + + let timer = Instant::now(); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); + println!( + "AddInstruction::prove, instance_num_vars = {}, time + = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + } + + #[test] + #[cfg(not(debug_assertions))] + fn bench_add_instruction() { + bench_add_instruction_helper::(10); + } +} diff --git a/singer/src/instructions/swap.rs b/singer/src/instructions/swap.rs index e112e0045..76b30e9ed 100644 --- a/singer/src/instructions/swap.rs +++ b/singer/src/instructions/swap.rs @@ -5,8 +5,8 @@ use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; use singer_utils::{ chip_handler::{ - BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, - RangeChipOperations, StackChipOperations, + BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, RangeChipOperations, + StackChipOperations, }, constants::OpcodeType, register_witness, @@ -81,14 +81,9 @@ impl Instruction for SwapInstruction { clk, ); - let next_pc = - ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; - let next_stack_ts = rom_handler.add_ts_with_const( - &mut circuit_builder, - &stack_ts, - 1, - &phase0[Self::phase0_stack_ts_add()], - )?; + let next_pc = ROMHandler::add_pc_const(&mut circuit_builder, &pc, 1, &phase0[Self::phase0_pc_add()])?; + let next_stack_ts = + rom_handler.add_ts_with_const(&mut circuit_builder, &stack_ts, 1, &phase0[Self::phase0_stack_ts_add()])?; ram_handler.state_out( &mut circuit_builder, @@ -155,11 +150,7 @@ impl Instruction for SwapInstruction { ); // Bytecode check for (pc, SWAP{N}). - rom_handler.bytecode_with_pc_opcode( - &mut circuit_builder, - pc.values(), - >::OPCODE, - ); + rom_handler.bytecode_with_pc_opcode(&mut circuit_builder, pc.values(), >::OPCODE); let (ram_load_id, ram_store_id) = ram_handler.finalize(&mut circuit_builder); let rom_id = rom_handler.finalize(&mut circuit_builder); @@ -191,9 +182,7 @@ mod test { use transcript::Transcript; use crate::{ - instructions::{ - ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, SwapInstruction, - }, + instructions::{ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, SwapInstruction}, scheme::GKRGraphProverState, test::{get_uint_params, test_opcode_circuit}, utils::u64vec, @@ -223,10 +212,7 @@ mod test { phase0_values_map.insert("phase0_pc".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(4u64)]); phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert( - "phase0_stack_top".to_string(), - vec![Goldilocks::from(100u64)], - ); + phase0_values_map.insert("phase0_stack_top".to_string(), vec![Goldilocks::from(100u64)]); phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( "phase0_pc_add".to_string(), @@ -243,10 +229,7 @@ mod test { // no place for carry ], ); - phase0_values_map.insert( - "phase0_old_stack_ts_1".to_string(), - vec![Goldilocks::from(3u64)], - ); + phase0_values_map.insert("phase0_old_stack_ts_1".to_string(), vec![Goldilocks::from(3u64)]); let m: u64 = (1 << get_uint_params::().1) - 1; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -258,10 +241,7 @@ mod test { Goldilocks::from(1u64), // current length has no cells for borrow ], ); - phase0_values_map.insert( - "phase0_old_stack_ts_n_plus_1".to_string(), - vec![Goldilocks::from(1u64)], - ); + phase0_values_map.insert("phase0_old_stack_ts_n_plus_1".to_string(), vec![Goldilocks::from(1u64)]); let m: u64 = (1 << get_uint_params::().1) - 3; let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( @@ -318,19 +298,14 @@ mod test { #[cfg(not(debug_assertions))] fn bench_swap_instruction_helper(instance_num_vars: usize) { let chip_challenges = ChipChallenges::default(); - let circuit_builder = - SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); + let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); let mut singer_builder = SingerGraphBuilder::::new(); let mut rng = test_rng(); let size = SwapInstruction::::phase0_size(); let phase0: CircuitWiresIn = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| E::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| (0..size).map(|_| E::BaseField::random(&mut rng)).collect_vec()) .collect_vec(), }]; @@ -341,8 +316,7 @@ mod test { let _ = SwapInstruction::::construct_graph_and_witness( &mut singer_builder.graph_builder, &mut singer_builder.chip_builder, - &circuit_builder.insts_circuits - [ as Instruction>::OPCODE as usize], + &circuit_builder.insts_circuits[ as Instruction>::OPCODE as usize], vec![phase0], &real_challenges, 1 << instance_num_vars, @@ -365,8 +339,8 @@ mod test { let mut prover_transcript = &mut Transcript::new(b"Singer"); let timer = Instant::now(); - let _ = GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1) - .expect("prove failed"); + let _ = + GKRGraphProverState::prove(&graph, &wit, &target_evals, &mut prover_transcript, 1).expect("prove failed"); println!( "Swap{}Instruction::prove, instance_num_vars = {}, time = {}", N, diff --git a/singer/src/instructions_riscv_ext.rs b/singer/src/instructions_riscv_ext.rs new file mode 100644 index 000000000..778120190 --- /dev/null +++ b/singer/src/instructions_riscv_ext.rs @@ -0,0 +1,55 @@ +use ff_ext::ExtensionField; +use singer_utils::{ + chips::IntoEnumIterator, + riscv_constant::{RV64IOpcode, RvInstructions, RvOpcode}, + structs::ChipChallenges, +}; + +use crate::{ + error::ZKVMError, + instructions::{riscv, InstCircuit, InstructionGraph, SingerCircuitBuilder}, +}; + +impl SingerCircuitBuilder { + pub fn new_riscv(challenges: ChipChallenges) -> Result { + let ins_len = RvInstructions::END as usize; + let mut insts_circuits = Vec::with_capacity(256); + for opcode in RvInstructions::iter() { + insts_circuits.push(construct_instruction_circuits(opcode.into(), challenges)?); + } + for _ in ins_len..255 { + insts_circuits.push(construct_instruction_circuits(RvInstructions::END.into(), challenges)?); + } + let insts_circuits: [Vec>; 256] = + insts_circuits.try_into().map_err(|_| ZKVMError::CircuitError)?; + Ok(Self { + insts_circuits, + challenges, + }) + } +} + +fn process_opcode_r( + instruction: RvOpcode, + challenges: ChipChallenges, +) -> Result>, ZKVMError> { + // Find the instruction format here: + // https://fraserinnovations.com/risc-v/risc-v-instruction-set-explanation/ + match instruction.funct3 { + 0b000 => match instruction.funct7 { + 0b000_0000 => riscv::add::AddInstruction::construct_circuits(challenges), + _ => Ok(vec![]), // TODO: Add more operations. + }, + _ => Ok(vec![]), // TODO: Add more instructions. + } +} + +pub(crate) fn construct_instruction_circuits( + instruction: RvOpcode, + challenges: ChipChallenges, +) -> Result>, ZKVMError> { + match instruction.opcode { + RV64IOpcode::R => process_opcode_r(instruction, challenges), + _ => Ok(vec![]), // TODO: Add more instructions. + } +} diff --git a/singer/src/lib.rs b/singer/src/lib.rs index aa829c07f..ad8a6f0ef 100644 --- a/singer/src/lib.rs +++ b/singer/src/lib.rs @@ -3,18 +3,15 @@ use error::ZKVMError; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; -use gkr_graph::structs::{ - CircuitGraph, CircuitGraphAuxInfo, CircuitGraphBuilder, CircuitGraphWitness, NodeOutputType, -}; +use gkr_graph::structs::{CircuitGraph, CircuitGraphAuxInfo, CircuitGraphBuilder, CircuitGraphWitness, NodeOutputType}; use goldilocks::SmallField; -use instructions::{ - construct_inst_graph, construct_inst_graph_and_witness, InstOutputType, SingerCircuitBuilder, -}; +use instructions::{construct_inst_graph, construct_inst_graph_and_witness, InstOutputType, SingerCircuitBuilder}; use singer_utils::chips::SingerChipBuilder; use std::mem; pub mod error; pub mod instructions; +pub mod instructions_riscv_ext; pub mod scheme; #[cfg(test)] pub mod test; @@ -58,14 +55,7 @@ impl SingerGraphBuilder { program_input: &[u8], real_challenges: &[E], params: &SingerParams, - ) -> Result< - ( - SingerCircuit, - SingerWitness, - SingerWiresOutID, - ), - ZKVMError, - > { + ) -> Result<(SingerCircuit, SingerWitness, SingerWiresOutID), ZKVMError> { // Add instruction and its extension (if any) circuits to the graph. for inst_wires_in in singer_wires_in.instructions.into_iter() { let InstWiresIn { @@ -119,11 +109,7 @@ impl SingerGraphBuilder { let (graph, graph_witness) = graph_builder.finalize_graph_and_witness_with_targets(&singer_wire_out_id.to_vec()); - Ok(( - SingerCircuit(graph), - SingerWitness(graph_witness), - singer_wire_out_id, - )) + Ok((SingerCircuit(graph), SingerWitness(graph_witness), singer_wire_out_id)) } pub fn construct_graph( @@ -217,12 +203,7 @@ pub struct SingerWiresOutValues { impl SingerWiresOutID { pub fn to_vec(&self) -> Vec { - let mut res = [ - self.ram_load.clone(), - self.ram_store.clone(), - self.rom_input.clone(), - ] - .concat(); + let mut res = [self.ram_load.clone(), self.ram_store.clone(), self.rom_input.clone()].concat(); if let Some(public_output_size) = self.public_output_size { res.push(public_output_size); } diff --git a/singer/src/scheme/prover.rs b/singer/src/scheme/prover.rs index 82f500aba..11d744481 100644 --- a/singer/src/scheme/prover.rs +++ b/singer/src/scheme/prover.rs @@ -5,9 +5,7 @@ use gkr_graph::structs::{CircuitGraphAuxInfo, NodeOutputType}; use itertools::Itertools; use transcript::Transcript; -use crate::{ - error::ZKVMError, SingerCircuit, SingerWiresOutID, SingerWiresOutValues, SingerWitness, -}; +use crate::{error::ZKVMError, SingerCircuit, SingerWiresOutID, SingerWiresOutValues, SingerWitness}; use super::{GKRGraphProverState, SingerProof}; @@ -19,11 +17,7 @@ pub fn prove( ) -> Result<(SingerProof, CircuitGraphAuxInfo), ZKVMError> { // TODO: Add PCS. let point = (0..2 * ::DEGREE) - .map(|_| { - transcript - .get_and_append_challenge(b"output point") - .elements - }) + .map(|_| transcript.get_and_append_challenge(b"output point").elements) .collect_vec(); let singer_out_evals = { @@ -32,15 +26,13 @@ pub fn prove( .iter() .map(|node| { match node { - NodeOutputType::OutputLayer(node_id) => vm_witness.0.node_witnesses - [*node_id as usize] + NodeOutputType::OutputLayer(node_id) => vm_witness.0.node_witnesses[*node_id as usize] .output_layer_witness_ref() .instances .iter() .cloned() .flatten(), - NodeOutputType::WireOut(node_id, wit_id) => vm_witness.0.node_witnesses - [*node_id as usize] + NodeOutputType::WireOut(node_id, wit_id) => vm_witness.0.node_witnesses[*node_id as usize] .witness_out_ref()[*wit_id as usize] .instances .iter() @@ -76,8 +68,7 @@ pub fn prove( }; let target_evals = vm_circuit.0.target_evals(&vm_witness.0, &point); - let gkr_phase_proof = - GKRGraphProverState::prove(&vm_circuit.0, &vm_witness.0, &target_evals, transcript, 1)?; + let gkr_phase_proof = GKRGraphProverState::prove(&vm_circuit.0, &vm_witness.0, &target_evals, transcript, 1)?; Ok(( SingerProof { gkr_phase_proof, diff --git a/singer/src/scheme/verifier.rs b/singer/src/scheme/verifier.rs index a949a7598..e9b0dfaf7 100644 --- a/singer/src/scheme/verifier.rs +++ b/singer/src/scheme/verifier.rs @@ -17,11 +17,7 @@ pub fn verify( ) -> Result<(), ZKVMError> { // TODO: Add PCS. let point = (0..2 * ::DEGREE) - .map(|_| { - transcript - .get_and_append_challenge(b"output point") - .elements - }) + .map(|_| transcript.get_and_append_challenge(b"output point").elements) .collect_vec(); let SingerWiresOutValues { @@ -45,9 +41,7 @@ pub fn verify( let (den, num) = x.split_at(l / 2); (E::from_limbs(den), E::from_limbs(num)) }) - .fold((E::ONE, E::ZERO), |acc, x| { - (acc.0 * x.0, acc.0 * x.1 + acc.1 * x.0) - }); + .fold((E::ONE, E::ZERO), |acc, x| (acc.0 * x.0, acc.0 * x.1 + acc.1 * x.0)); let rom_table_sum = rom_table .iter() .map(|x| { @@ -55,9 +49,7 @@ pub fn verify( let (den, num) = x.split_at(l / 2); (E::from_limbs(den), E::from_limbs(num)) }) - .fold((E::ONE, E::ZERO), |acc, x| { - (acc.0 * x.0, acc.0 * x.1 + acc.1 * x.0) - }); + .fold((E::ONE, E::ZERO), |acc, x| (acc.0 * x.0, acc.0 * x.1 + acc.1 * x.0)); if rom_input_sum.0 * rom_table_sum.1 != rom_input_sum.1 * rom_table_sum.0 { return Err(ZKVMError::VerifyError); } @@ -66,10 +58,7 @@ pub fn verify( chain![ram_load, ram_store, rom_input, rom_table,] .map(|x| { let f = vec![x.to_vec()].as_slice().original_mle(); - PointAndEval::new( - point[..f.num_vars].to_vec(), - f.evaluate(&point[..f.num_vars]), - ) + PointAndEval::new(point[..f.num_vars].to_vec(), f.evaluate(&point[..f.num_vars])) }) .collect_vec(), ); @@ -80,10 +69,7 @@ pub fn verify( point[..f.num_vars].to_vec(), f.evaluate(&point[..f.num_vars]), )); - assert_eq!( - output[0], - E::BaseField::from(aux_info.program_output_len as u64) - ) + assert_eq!(output[0], E::BaseField::from(aux_info.program_output_len as u64)) } GKRGraphVerifierState::verify( diff --git a/singer/src/test.rs b/singer/src/test.rs index 563cc2b4d..153b3348f 100644 --- a/singer/src/test.rs +++ b/singer/src/test.rs @@ -46,14 +46,8 @@ pub(crate) fn test_opcode_circuit_v2( witness_in[phase0_input_idx as usize] = vec![Ext::BaseField::ZERO; phase0_witness_size]; for key in phase0_idx_map.keys() { - let range = phase0_idx_map - .get(key) - .unwrap() - .clone() - .collect::>(); - let values = phase0_values_map - .get(key) - .expect(&("unknown key ".to_owned() + key)); + let range = phase0_idx_map.get(key).unwrap().clone().collect::>(); + let values = phase0_values_map.get(key).expect(&("unknown key ".to_owned() + key)); for (value_idx, cell_idx) in range.into_iter().enumerate() { if value_idx < values.len() { witness_in[phase0_input_idx as usize][cell_idx] = values[value_idx]; diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index c4dcfeb3e..155993322 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -29,10 +29,8 @@ fn prepare_input( ) -> (E, VirtualPolynomial, Vec>) { let mut rng = test_rng(); let size_log2 = ceil_log2(max_thread_id); - let f1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - let g1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); + let f1: Arc> = DenseMultilinearExtension::::random(nv, &mut rng).into(); + let g1: Arc> = DenseMultilinearExtension::::random(nv, &mut rng).into(); let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1.clone(), E::BaseField::ONE); virtual_poly_1.mul_by_mle(g1.clone(), ::BaseField::ONE); @@ -40,10 +38,7 @@ fn prepare_input( let mut virtual_poly_f1: Vec> = match &f1.evaluations { multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) + .map(|chunk| DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()).into()) .map(|mle| VirtualPolynomial::new_from_mle(mle, E::BaseField::ONE)) .collect_vec(), _ => unreachable!(), @@ -52,10 +47,7 @@ fn prepare_input( let poly_g1: Vec> = match &g1.evaluations { multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) + .map(|chunk| DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()).into()) .collect_vec(), _ => unreachable!(), }; @@ -71,11 +63,7 @@ fn prepare_input( .iter_mut() .zip(poly_g1.iter()) .for_each(|(f1, g1)| f1.mul_by_mle(g1.clone(), E::BaseField::ONE)); - ( - asserted_sum, - virtual_poly_1, - virtual_poly_f1.try_into().unwrap(), - ) + (asserted_sum, virtual_poly_1, virtual_poly_f1.try_into().unwrap()) } #[from_env] @@ -90,30 +78,19 @@ fn sumcheck_fn(c: &mut Criterion) { group.sample_size(NUM_SAMPLES); // Benchmark the proving time - group.bench_function( - BenchmarkId::new("prove_sumcheck", format!("sumcheck_nv_{}", nv)), - |b| { - b.iter_with_setup( - || { - let prover_transcript = Transcript::::new(b"test"); - let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; - ( - prover_transcript, - asserted_sum, - virtual_poly, - virtual_poly_splitted, - ) - }, - |(mut prover_transcript, asserted_sum, virtual_poly, virtual_poly_splitted)| { - let (sumcheck_proof_v1, _) = IOPProverState::::prove_parallel( - virtual_poly.clone(), - &mut prover_transcript, - ); - }, - ); - }, - ); + group.bench_function(BenchmarkId::new("prove_sumcheck", format!("sumcheck_nv_{}", nv)), |b| { + b.iter_with_setup( + || { + let prover_transcript = Transcript::::new(b"test"); + let (asserted_sum, virtual_poly, virtual_poly_splitted) = { prepare_input(RAYON_NUM_THREADS, nv) }; + (prover_transcript, asserted_sum, virtual_poly, virtual_poly_splitted) + }, + |(mut prover_transcript, asserted_sum, virtual_poly, virtual_poly_splitted)| { + let (sumcheck_proof_v1, _) = + IOPProverState::::prove_parallel(virtual_poly.clone(), &mut prover_transcript); + }, + ); + }); group.finish(); } @@ -128,31 +105,22 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { group.sample_size(NUM_SAMPLES); // Benchmark the proving time - group.bench_function( - BenchmarkId::new("prove_sumcheck", format!("devirgo_nv_{}", nv)), - |b| { - b.iter_with_setup( - || { - let prover_transcript = Transcript::::new(b"test"); - let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; - ( - prover_transcript, - asserted_sum, - virtual_poly, - virtual_poly_splitted, - ) - }, - |(mut prover_transcript, asserted_sum, virtual_poly, virtual_poly_splitted)| { - let (sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, - virtual_poly_splitted, - &mut prover_transcript, - ); - }, - ); - }, - ); + group.bench_function(BenchmarkId::new("prove_sumcheck", format!("devirgo_nv_{}", nv)), |b| { + b.iter_with_setup( + || { + let prover_transcript = Transcript::::new(b"test"); + let (asserted_sum, virtual_poly, virtual_poly_splitted) = { prepare_input(RAYON_NUM_THREADS, nv) }; + (prover_transcript, asserted_sum, virtual_poly, virtual_poly_splitted) + }, + |(mut prover_transcript, asserted_sum, virtual_poly, virtual_poly_splitted)| { + let (sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( + RAYON_NUM_THREADS, + virtual_poly_splitted, + &mut prover_transcript, + ); + }, + ); + }); group.finish(); } diff --git a/sumcheck/examples/devirgo_sumcheck.rs b/sumcheck/examples/devirgo_sumcheck.rs index 3cc7be741..abb8fb36c 100644 --- a/sumcheck/examples/devirgo_sumcheck.rs +++ b/sumcheck/examples/devirgo_sumcheck.rs @@ -18,16 +18,12 @@ use transcript::Transcript; type E = GoldilocksExt2; -fn prepare_input( - max_thread_id: usize, -) -> (E, VirtualPolynomial, Vec>) { +fn prepare_input(max_thread_id: usize) -> (E, VirtualPolynomial, Vec>) { let nv = 10; let mut rng = test_rng(); let size_log2 = ceil_log2(max_thread_id); - let f1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - let g1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); + let f1: Arc> = DenseMultilinearExtension::::random(nv, &mut rng).into(); + let g1: Arc> = DenseMultilinearExtension::::random(nv, &mut rng).into(); let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1.clone(), E::BaseField::ONE); virtual_poly_1.mul_by_mle(g1.clone(), ::BaseField::ONE); @@ -35,10 +31,7 @@ fn prepare_input( let mut virtual_poly_f1: Vec> = match &f1.evaluations { multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) + .map(|chunk| DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()).into()) .map(|mle| VirtualPolynomial::new_from_mle(mle, E::BaseField::ONE)) .collect_vec(), _ => unreachable!(), @@ -47,10 +40,7 @@ fn prepare_input( let poly_g1: Vec> = match &g1.evaluations { multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) + .map(|chunk| DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()).into()) .collect_vec(), _ => unreachable!(), }; @@ -66,11 +56,7 @@ fn prepare_input( .iter_mut() .zip(poly_g1.iter()) .for_each(|(f1, g1)| f1.mul_by_mle(g1.clone(), E::BaseField::ONE)); - ( - asserted_sum, - virtual_poly_1, - virtual_poly_f1.try_into().unwrap(), - ) + (asserted_sum, virtual_poly_1, virtual_poly_f1.try_into().unwrap()) } #[from_env] @@ -90,26 +76,14 @@ fn main() { let mut transcript = Transcript::new(b"test"); let poly_info = virtual_poly.aux_info.clone(); - let subclaim = IOPVerifierState::::verify( - asserted_sum, - &sumcheck_proof_v2, - &poly_info, - &mut transcript, - ); + let subclaim = IOPVerifierState::::verify(asserted_sum, &sumcheck_proof_v2, &poly_info, &mut transcript); assert!( - virtual_poly.evaluate( - &subclaim - .point - .iter() - .map(|c| c.elements) - .collect::>() - .as_ref() - ) == subclaim.expected_evaluation, + virtual_poly.evaluate(&subclaim.point.iter().map(|c| c.elements).collect::>().as_ref()) + == subclaim.expected_evaluation, "wrong subclaim" ); - let (sumcheck_proof_v1, _) = - IOPProverState::::prove_parallel(virtual_poly.clone(), &mut prover_transcript_v1); + let (sumcheck_proof_v1, _) = IOPProverState::::prove_parallel(virtual_poly.clone(), &mut prover_transcript_v1); println!("v1 finish"); assert!(sumcheck_proof_v2 == sumcheck_proof_v1); diff --git a/sumcheck/src/local_thread_pool.rs b/sumcheck/src/local_thread_pool.rs index 8df2f620e..1dbb513ed 100644 --- a/sumcheck/src/local_thread_pool.rs +++ b/sumcheck/src/local_thread_pool.rs @@ -8,10 +8,7 @@ static LOCAL_THREAD_POOL_SET: Once = Once::new(); pub fn create_local_pool_once(size: usize, in_place: bool) { unsafe { let size = if in_place { size - 1 } else { size }; - let pool_size = LOCAL_THREAD_POOL - .as_ref() - .map(|a| a.current_num_threads()) - .unwrap_or(0); + let pool_size = LOCAL_THREAD_POOL.as_ref().map(|a| a.current_num_threads()).unwrap_or(0); if pool_size > 0 && pool_size != size { panic!( "calling prove_batch_polys with different polys size. prev size {} vs now size {}", @@ -19,14 +16,10 @@ pub fn create_local_pool_once(size: usize, in_place: bool) { ); } LOCAL_THREAD_POOL_SET.call_once(|| { - let _ = Some(&*LOCAL_THREAD_POOL.get_or_insert_with(|| { - Arc::new( - rayon::ThreadPoolBuilder::new() - .num_threads(size) - .build() - .unwrap(), - ) - })); + let _ = + Some(&*LOCAL_THREAD_POOL.get_or_insert_with(|| { + Arc::new(rayon::ThreadPoolBuilder::new().num_threads(size).build().unwrap()) + })); }); } } diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index d32423ee7..7bde8a89e 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -16,10 +16,7 @@ use crate::local_thread_pool::{create_local_pool_once, LOCAL_THREAD_POOL}; use crate::{ entered_span, exit_span, structs::{IOPProof, IOPProverMessage, IOPProverState}, - util::{ - barycentric_weights, ceil_log2, extrapolate, merge_sumcheck_polys, AdditiveArray, - AdditiveVec, - }, + util::{barycentric_weights, ceil_log2, extrapolate, merge_sumcheck_polys, AdditiveArray, AdditiveVec}, }; impl IOPProverState { @@ -37,10 +34,7 @@ impl IOPProverState { assert_eq!(polys.len(), max_thread_id); let log2_max_thread_id = ceil_log2(max_thread_id); // do not support SIZE not power of 2 - let (num_variables, max_degree) = ( - polys[0].aux_info.num_variables, - polys[0].aux_info.max_degree, - ); + let (num_variables, max_degree) = (polys[0].aux_info.num_variables, polys[0].aux_info.max_degree); for poly in polys[1..].iter() { assert!(poly.aux_info.num_variables == num_variables); assert!(poly.aux_info.max_degree == max_degree); @@ -75,10 +69,8 @@ impl IOPProverState { // spawn extra #(max_thread_id - 1) work threads, whereas the main-thread be the last work // thread for thread_id in 0..(max_thread_id - 1) { - let mut prover_state = Self::prover_init_with_extrapolation_aux( - mem::take(&mut polys[thread_id]), - extrapolation_aux.clone(), - ); + let mut prover_state = + Self::prover_init_with_extrapolation_aux(mem::take(&mut polys[thread_id]), extrapolation_aux.clone()); let tx_prover_state = tx_prover_state.clone(); let mut thread_based_transcript = thread_based_transcript.clone(); @@ -86,12 +78,10 @@ impl IOPProverState { let mut challenge = None; let span = entered_span!("prove_rounds"); for _ in 0..num_variables { - let prover_msg = - IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); + let prover_msg = IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); thread_based_transcript.append_field_element_exts(&prover_msg.evaluations); - challenge = - Some(thread_based_transcript.get_and_append_challenge(b"Internal round")); + challenge = Some(thread_based_transcript.get_and_append_challenge(b"Internal round")); thread_based_transcript.commit_rolling(); } exit_span!(span); @@ -99,17 +89,11 @@ impl IOPProverState { if let Some(p) = challenge { prover_state.challenges.push(p); // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each(|mle| { - let mle = Arc::make_mut(mle); - mle.fix_variables_in_place(&[p.elements]); - }); - tx_prover_state - .send(Some((thread_id, prover_state))) - .unwrap(); + prover_state.poly.flattened_ml_extensions.iter_mut().for_each(|mle| { + let mle = Arc::make_mut(mle); + mle.fix_variables_in_place(&[p.elements]); + }); + tx_prover_state.send(Some((thread_id, prover_state))).unwrap(); } else { tx_prover_state.send(None).unwrap(); } @@ -140,10 +124,8 @@ impl IOPProverState { let mut prover_msgs = Vec::with_capacity(num_variables); let thread_id = max_thread_id - 1; - let mut prover_state = Self::prover_init_with_extrapolation_aux( - mem::take(&mut polys[thread_id]), - extrapolation_aux.clone(), - ); + let mut prover_state = + Self::prover_init_with_extrapolation_aux(mem::take(&mut polys[thread_id]), extrapolation_aux.clone()); let tx_prover_state = tx_prover_state.clone(); let mut thread_based_transcript = thread_based_transcript.clone(); @@ -153,8 +135,7 @@ impl IOPProverState { // refactor to shared closure cause to 5% throuput drop let mut challenge = None; for _ in 0..num_variables { - let prover_msg = - IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); + let prover_msg = IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); thread_based_transcript.append_field_element_exts(&prover_msg.evaluations); // for each round, we must collect #SIZE prover message @@ -188,17 +169,11 @@ impl IOPProverState { if let Some(p) = challenge { prover_state.challenges.push(p); // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each(|mle| { - let mle = Arc::make_mut(mle); - mle.fix_variables_in_place(&[p.elements]); - }); - tx_prover_state - .send(Some((thread_id, prover_state))) - .unwrap(); + prover_state.poly.flattened_ml_extensions.iter_mut().for_each(|mle| { + let mle = Arc::make_mut(mle); + mle.fix_variables_in_place(&[p.elements]); + }); + tx_prover_state.send(Some((thread_id, prover_state))).unwrap(); } else { tx_prover_state.send(None).unwrap(); } @@ -232,14 +207,12 @@ impl IOPProverState { // second stage sumcheck let poly = merge_sumcheck_polys(&prover_states, max_thread_id); - let mut prover_state = - Self::prover_init_with_extrapolation_aux(poly, extrapolation_aux.clone()); + let mut prover_state = Self::prover_init_with_extrapolation_aux(poly, extrapolation_aux.clone()); let mut challenge = None; let span = entered_span!("prove_rounds_stage2"); for _ in 0..log2_max_thread_id { - let prover_msg = - IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); + let prover_msg = IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); prover_msg .evaluations @@ -290,10 +263,7 @@ impl IOPProverState { extrapolation_aux: Vec<(Vec, Vec)>, ) -> Self { let start = start_timer!(|| "sum check prover init"); - assert_ne!( - polynomial.aux_info.num_variables, 0, - "Attempt to prove a constant." - ); + assert_ne!(polynomial.aux_info.num_variables, 0, "Attempt to prove a constant."); end_timer!(start); let max_degree = polynomial.aux_info.max_degree; @@ -311,17 +281,10 @@ impl IOPProverState { /// /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). #[tracing::instrument(skip_all, name = "sumcheck::prove_round_and_update_state")] - pub(crate) fn prove_round_and_update_state( - &mut self, - challenge: &Option>, - ) -> IOPProverMessage { - let start = - start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); + pub(crate) fn prove_round_and_update_state(&mut self, challenge: &Option>) -> IOPProverMessage { + let start = start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); - assert!( - self.round < self.poly.aux_info.num_variables, - "Prover is not active" - ); + assert!(self.round < self.poly.aux_info.num_variables, "Prover is not active"); // let fix_argument = start_timer!(|| "fix argument"); @@ -340,11 +303,7 @@ impl IOPProverState { if self.round == 0 { assert!(challenge.is_none(), "first round should be prover first."); } else { - assert!( - challenge.is_some(), - "verifier message is empty in round {}", - self.round - ); + assert!(challenge.is_some(), "verifier message is empty in round {}", self.round); let chal = challenge.unwrap(); self.challenges.push(chal); let r = self.challenges[self.round - 1]; @@ -409,8 +368,7 @@ impl IOPProverState { |mut acc, b| { acc.0[0] += f[b] * g[b]; acc.0[1] += f[b + 1] * g[b + 1]; - acc.0[2] += - (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); + acc.0[2] += (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); acc } ), @@ -501,10 +459,7 @@ impl IOPProverState { let mut prover_msgs = Vec::with_capacity(num_variables); let span = entered_span!("prove_rounds"); for _ in 0..num_variables { - let prover_msg = IOPProverState::prove_round_and_update_state_parallel( - &mut prover_state, - &challenge, - ); + let prover_msg = IOPProverState::prove_round_and_update_state_parallel(&mut prover_state, &challenge); prover_msg .evaluations @@ -553,10 +508,7 @@ impl IOPProverState { /// over {0,1}^`num_vars`. pub(crate) fn prover_init_parallel(polynomial: VirtualPolynomial) -> Self { let start = start_timer!(|| "sum check prover init"); - assert_ne!( - polynomial.aux_info.num_variables, 0, - "Attempt to prove a constant." - ); + assert_ne!(polynomial.aux_info.num_variables, 0, "Attempt to prove a constant."); let max_degree = polynomial.aux_info.max_degree; let prover_state = Self { @@ -585,13 +537,9 @@ impl IOPProverState { &mut self, challenge: &Option>, ) -> IOPProverMessage { - let start = - start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); + let start = start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); - assert!( - self.round < self.poly.aux_info.num_variables, - "Prover is not active" - ); + assert!(self.round < self.poly.aux_info.num_variables, "Prover is not active"); // let fix_argument = start_timer!(|| "fix argument"); @@ -683,8 +631,7 @@ impl IOPProverState { AdditiveArray([ f[b] * g[b], f[b + 1] * g[b + 1], - (f[b + 1] + f[b + 1] - f[b]) - * (g[b + 1] + g[b + 1] - g[b]), + (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]), ]) }) .sum::>(), diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index 89ad863b2..88635eca0 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -13,42 +13,26 @@ use crate::{ util::interpolate_uni_poly, }; -fn test_sumcheck( - nv: usize, - num_multiplicands_range: (usize, usize), - num_products: usize, -) { +fn test_sumcheck(nv: usize, num_multiplicands_range: (usize, usize), num_products: usize) { let mut rng = test_rng(); let mut transcript = Transcript::new(b"test"); - let (poly, asserted_sum) = - VirtualPolynomial::::random(nv, num_multiplicands_range, num_products, &mut rng); + let (poly, asserted_sum) = VirtualPolynomial::::random(nv, num_multiplicands_range, num_products, &mut rng); let poly_info = poly.aux_info.clone(); let (proof, _) = IOPProverState::::prove_parallel(poly.clone(), &mut transcript); let mut transcript = Transcript::new(b"test"); let subclaim = IOPVerifierState::::verify(asserted_sum, &proof, &poly_info, &mut transcript); assert!( - poly.evaluate( - &subclaim - .point - .iter() - .map(|c| c.elements) - .collect::>() - .as_ref() - ) == subclaim.expected_evaluation, + poly.evaluate(&subclaim.point.iter().map(|c| c.elements).collect::>().as_ref()) + == subclaim.expected_evaluation, "wrong subclaim" ); } -fn test_sumcheck_internal( - nv: usize, - num_multiplicands_range: (usize, usize), - num_products: usize, -) { +fn test_sumcheck_internal(nv: usize, num_multiplicands_range: (usize, usize), num_products: usize) { let mut rng = test_rng(); - let (poly, asserted_sum) = - VirtualPolynomial::::random(nv, num_multiplicands_range, num_products, &mut rng); + let (poly, asserted_sum) = VirtualPolynomial::::random(nv, num_multiplicands_range, num_products, &mut rng); let (poly_info, num_variables) = (poly.aux_info.clone(), poly.aux_info.num_variables); let mut prover_state = IOPProverState::prover_init_parallel(poly.clone()); let mut verifier_state = IOPVerifierState::verifier_init(&poly_info); @@ -59,8 +43,7 @@ fn test_sumcheck_internal( transcript.append_message(b"initializing transcript for testing"); for _ in 0..num_variables { - let prover_message = - IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); + let prover_message = IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); challenge = Some(IOPVerifierState::verify_round_and_update_state( &mut verifier_state, @@ -82,14 +65,8 @@ fn test_sumcheck_internal( }; let subclaim = IOPVerifierState::check_and_generate_subclaim(&verifier_state, &asserted_sum); assert!( - poly.evaluate( - &subclaim - .point - .iter() - .map(|c| c.elements) - .collect::>() - .as_ref() - ) == subclaim.expected_evaluation, + poly.evaluate(&subclaim.point.iter().map(|c| c.elements).collect::>().as_ref()) + == subclaim.expected_evaluation, "wrong subclaim" ); } @@ -152,11 +129,7 @@ struct DensePolynomial(Vec); impl DensePolynomial { fn rand(degree: usize, mut rng: &mut impl RngCore) -> Self { - Self( - (0..degree) - .map(|_| GoldilocksExt2::random(&mut rng)) - .collect(), - ) + Self((0..degree).map(|_| GoldilocksExt2::random(&mut rng)).collect()) } fn evaluate(&self, p: &GoldilocksExt2) -> GoldilocksExt2 { diff --git a/sumcheck/src/util.rs b/sumcheck/src/util.rs index 1098fd240..2135e40f3 100644 --- a/sumcheck/src/util.rs +++ b/sumcheck/src/util.rs @@ -103,12 +103,7 @@ pub(crate) fn extrapolate(points: &[F], weights: &[F], evals: &[F let sum_inv = sum.invert().unwrap_or(F::ZERO); (coeffs, sum_inv) }; - coeffs - .iter() - .zip(evals) - .map(|(coeff, eval)| *coeff * eval) - .sum::() - * sum_inv + coeffs.iter().zip(evals).map(|(coeff, eval)| *coeff * eval).sum::() * sum_inv } /// Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this @@ -206,9 +201,7 @@ pub(crate) fn merge_sumcheck_polys( .iter() .enumerate() .map(|(_, prover_state)| { - if let FieldType::Ext(evaluations) = - &prover_state.poly.flattened_ml_extensions[i].evaluations - { + if let FieldType::Ext(evaluations) = &prover_state.poly.flattened_ml_extensions[i].evaluations { assert!(evaluations.len() == 1); evaluations[0] } else { diff --git a/sumcheck/src/verifier.rs b/sumcheck/src/verifier.rs index b4e119d3d..3485bca3f 100644 --- a/sumcheck/src/verifier.rs +++ b/sumcheck/src/verifier.rs @@ -70,8 +70,7 @@ impl IOPVerifierState { prover_msg: &IOPProverMessage, transcript: &mut Transcript, ) -> Challenge { - let start = - start_timer!(|| format!("sum check verify {}-th round and update state", self.round)); + let start = start_timer!(|| format!("sum check verify {}-th round and update state", self.round)); assert!( !self.finished, @@ -88,8 +87,7 @@ impl IOPVerifierState { let challenge = transcript.get_and_append_challenge(b"Internal round"); self.challenges.push(challenge); - self.polynomials_received - .push(prover_msg.evaluations.to_vec()); + self.polynomials_received.push(prover_msg.evaluations.to_vec()); if self.round == self.num_vars { // accept and close diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index 5a71602c2..eb26d7b2c 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -20,9 +20,7 @@ impl Transcript { let mut hasher = Poseidon::::new(8, 22); let label_f = E::BaseField::bytes_to_field_elements(label); hasher.update(label_f.as_slice()); - Self { - sponge_hasher: hasher, - } + Self { sponge_hasher: hasher } } } diff --git a/transcript/src/syncronized.rs b/transcript/src/syncronized.rs index 7a6e753e4..d56dc727f 100644 --- a/transcript/src/syncronized.rs +++ b/transcript/src/syncronized.rs @@ -50,21 +50,15 @@ impl TranscriptSyncronized { } pub fn append_field_element_ext(&mut self, element: &E) { - self.ef_append_tx[self.rolling_index] - .send(vec![*element]) - .unwrap(); + self.ef_append_tx[self.rolling_index].send(vec![*element]).unwrap(); } pub fn append_field_element_exts(&mut self, element: &[E]) { - self.ef_append_tx[self.rolling_index] - .send(element.to_vec()) - .unwrap(); + self.ef_append_tx[self.rolling_index].send(element.to_vec()).unwrap(); } pub fn append_field_element(&mut self, element: &E::BaseField) { - self.bf_append_tx[self.rolling_index] - .send(vec![*element]) - .unwrap(); + self.bf_append_tx[self.rolling_index].send(vec![*element]).unwrap(); } pub fn append_challenge(&mut self, _challenge: Challenge) { @@ -108,9 +102,7 @@ impl TranscriptSyncronized { } pub fn send_challenge(&self, challenge: E) { - self.challenge_tx[self.rolling_index] - .send(challenge) - .unwrap(); + self.challenge_tx[self.rolling_index].send(challenge).unwrap(); } pub fn commit_rolling(&mut self) {