diff --git a/arithmetic/src/crypto/modmul/Montgomery.scala b/arithmetic/src/crypto/modmul/Montgomery.scala index beb3794..354a283 100644 --- a/arithmetic/src/crypto/modmul/Montgomery.scala +++ b/arithmetic/src/crypto/modmul/Montgomery.scala @@ -3,20 +3,24 @@ package crypto.modmul import chisel3._ import chisel3.experimental.ChiselEnum import chisel3.util.experimental.decode.TruthTable -import chisel3.util.{Counter, Mux1H} +import chisel3.util.{Cat, Counter, Mux1H} class Montgomery(pWidth: Int = 4096, addPipe: Int) extends Module { val p = IO(Input(UInt(pWidth.W))) val pPrime = IO(Input(Bool())) val a = IO(Input(UInt(pWidth.W))) val b = IO(Input(UInt(pWidth.W))) - val b_add_p = IO(Input(UInt((pWidth + 1).W))) // b + p + val indexCountBit = 16 + val input_width = IO(Input(UInt(indexCountBit.W))) val valid = IO(Input(Bool())) // input valid val out = IO(Output(UInt(pWidth.W))) val out_valid = IO(Output(Bool())) // output valid - val u = Reg(Bool()) - val i = Reg(UInt((pWidth).W)) + val b_add_p = Reg(UInt((pWidth + 1).W)) + val invP = Reg(UInt((pWidth).W)) + val negP = Reg(UInt((pWidth + 2).W)) + val loop_u = Reg(Bool()) + val index = Reg(UInt(indexCountBit.W)) val nextT = Reg(UInt((pWidth + 2).W)) // multicycle prefixadder @@ -24,29 +28,31 @@ class Montgomery(pWidth: Int = 4096, addPipe: Int) extends Module { val add_stable = RegInit(0.U((pWidth + 2).W)) // Control Path object StateType extends ChiselEnum { - val s0 = Value("b0000001".U) // nextT = 0, u = a(0)b(0)pPrime + val s0 = Value("b0000001".U) // nextT = 0, loop_u = a(0)b(0)pPrime, b_add_p = b + p // loop val s1 = Value("b0000010".U) // nextT + b val s2 = Value("b0000100".U) // nextT + p val s3 = Value("b0001000".U) // nextT + b_add_p // loop done - val s4 = Value("b0010000".U) // i << 1, u = (nextT(0) + a(i)b(0))pPrime, nextT / 2 - val s5 = Value("b0100000".U) // if-then + val s4 = Value("b0010000".U) // index += 1, loop_u = (nextT(0) + a(index)b(0))pPrime, nextT / 2 + val s5 = Value("b0100000".U) // nextT - p val s6 = Value("b1000000".U) // done val s7 = Value("b10000000".U) // nextT + 0 + val s8 = Value("b100000000".U) // calculate ~p + val s9 = Value("b1000000000".U) // calculate ~p + 1 } val state = RegInit(StateType.s0) - val isAdd = (state.asUInt & "b10101110".U).orR + val isAdd = (state.asUInt & "b1010101111".U).orR adder.valid := isAdd val addDoneNext = RegInit(false.B) addDoneNext := addDone - lazy val addDone = if (addPipe != 0) Counter(isAdd && (~addDoneNext), addPipe + 1)._2 else true.B - val addSign = ((add_stable >> 1) < p.asUInt) + lazy val addDone = if (addPipe != 0) Counter(valid && isAdd && (~addDoneNext), addPipe + 1)._2 else true.B val a_i = Reg(Bool()) + val iBreak = (index.asUInt >= input_width.asUInt) state := chisel3.util.experimental.decode .decoder( - state.asUInt() ## addDoneNext ## valid ## i.head(1) ## addSign ## u ## a_i, { + state.asUInt() ## addDoneNext ## valid ## iBreak ## loop_u ## a_i, { val Y = "1" val N = "0" val DC = "?" @@ -54,27 +60,32 @@ class Montgomery(pWidth: Int = 4096, addPipe: Int) extends Module { stateI: String, addDone: String = DC, valid: String = DC, - iHead: String = DC, - addSign: String = DC, - u: String = DC, + iBreak: String = DC, + loop_u: String = DC, a_i: String = DC )(stateO: String - ) = s"$stateI$addDone$valid$iHead$addSign$u$a_i->$stateO" - val s0 = "00000001" - val s1 = "00000010" - val s2 = "00000100" - val s3 = "00001000" - val s4 = "00010000" - val s5 = "00100000" - val s6 = "01000000" - val s7 = "10000000" + ) = s"$stateI$addDone$valid$iBreak$loop_u$a_i->$stateO" + val s0 = "0000000001" + val s1 = "0000000010" + val s2 = "0000000100" + val s3 = "0000001000" + val s4 = "0000010000" + val s5 = "0000100000" + val s6 = "0001000000" + val s7 = "0010000000" + val s8 = "0100000000" + val s9 = "1000000000" TruthTable.fromString( Seq( to(s0, valid = N)(s0), - to(s0, valid = Y, a_i = Y, u = N)(s1), - to(s0, valid = Y, a_i = N, u = Y)(s2), - to(s0, valid = Y, a_i = Y, u = Y)(s3), - to(s0, valid = Y, a_i = N, u = N)(s7), + to(s0, valid = Y, addDone = N)(s0), + to(s0, valid = Y, addDone = Y)(s8), + to(s8)(s9), + to(s9, addDone = N)(s9), + to(s9, addDone = Y, a_i = Y, loop_u = N)(s1), + to(s9, addDone = Y, a_i = N, loop_u = Y)(s2), + to(s9, addDone = Y, a_i = Y, loop_u = Y)(s3), + to(s9, addDone = Y, a_i = N, loop_u = N)(s7), to(s1, addDone = Y)(s4), to(s1, addDone = N)(s1), to(s2, addDone = Y)(s4), @@ -83,42 +94,45 @@ class Montgomery(pWidth: Int = 4096, addPipe: Int) extends Module { to(s3, addDone = N)(s3), to(s7, addDone = Y)(s4), to(s7, addDone = N)(s7), - to(s4, iHead = Y, addSign = N)(s5), - to(s4, iHead = Y, addSign = Y)(s6), - to(s4, iHead = N, a_i = Y, u = N)(s1), - to(s4, iHead = N, a_i = N, u = Y)(s2), - to(s4, iHead = N, a_i = Y, u = Y)(s3), - to(s4, iHead = N, a_i = N, u = N)(s7), + to(s4, iBreak = Y)(s5), + to(s4, iBreak = N, a_i = Y, loop_u = N)(s1), + to(s4, iBreak = N, a_i = N, loop_u = Y)(s2), + to(s4, iBreak = N, a_i = Y, loop_u = Y)(s3), + to(s4, iBreak = N, a_i = N, loop_u = N)(s7), to(s5, addDone = Y)(s6), to(s5, addDone = N)(s5), - "????????" + to(s6, valid = N)(s0), + to(s6, valid = Y)(s6), + "??????????" ).mkString("\n") ) } ) .asTypeOf(StateType.Type()) - i := Mux1H( + index := Mux1H( Map( - state.asUInt()(0) -> 1.U, - state.asUInt()(4) -> i.rotateLeft(1), - (state.asUInt & "b11101110".U).orR -> i + state.asUInt()(0) -> 0.U, + state.asUInt()(4) -> (index + 1.U), + (state.asUInt & "b1111101110".U).orR -> index ) ) - u := Mux1H( + b_add_p := Mux(addDone & state.asUInt()(0), debounceAdd, b_add_p) + + loop_u := Mux1H( Map( state.asUInt()(0) -> (a(0).asUInt & b(0).asUInt & pPrime.asUInt), - (state.asUInt & "b10001110".U).orR -> ((add_stable(1) + (((a & (i.rotateLeft(1))).orR) & b(0))) & pPrime.asUInt), - (state.asUInt & "b01110000".U).orR -> u + (state.asUInt & "b0010001110".U).orR -> ((add_stable(1) + (a(index + 1.U) & b(0))) & pPrime.asUInt), + (state.asUInt & "b1101110000".U).orR -> loop_u ) ) a_i := Mux1H( Map( state.asUInt()(0) -> a(0), - (state.asUInt & "b10001110".U).orR -> (a & (i.rotateLeft(1))).orR, - (state.asUInt & "b01110000".U).orR -> a_i + (state.asUInt & "b0010001110".U).orR -> a(index + 1.U), + (state.asUInt & "b1101110000".U).orR -> a_i ) ) @@ -127,24 +141,37 @@ class Montgomery(pWidth: Int = 4096, addPipe: Int) extends Module { state.asUInt()(0) -> 0.U, state.asUInt()(4) -> (add_stable >> 1), state.asUInt()(5) -> add_stable, - (state.asUInt & "b11001110".U).orR -> nextT + (state.asUInt & "b1111001110".U).orR -> nextT ) ) + val TWithoutSubControl = Reg(UInt(1.W)) + val TWithoutSub = Reg(UInt((pWidth + 2).W)) + TWithoutSubControl := Mux(state.asUInt()(5), 0.U, 1.U) + TWithoutSub := Mux(state.asUInt()(5) && (TWithoutSubControl === 1.U), nextT, TWithoutSub) + invP := Mux(state.asUInt()(8), ~p, invP) + negP := Mux(state.asUInt()(9), add_stable, negP) - adder.a := nextT + adder.a := Mux1H( + Map( + state.asUInt()(0) -> p, + state.asUInt()(9) -> 1.U, + (state.asUInt & "b0111111110".U).orR -> nextT + ) + ) adder.b := Mux1H( Map( - state.asUInt()(1) -> b, + (state.asUInt & "b0100000011".U).orR -> b, + state.asUInt()(9) -> Cat(3.U, invP), state.asUInt()(2) -> p, state.asUInt()(3) -> b_add_p, state.asUInt()(7) -> 0.U, - state.asUInt()(5) -> -p + state.asUInt()(5) -> negP ) ) - val debounceAdd = Mux(addDone, adder.z, 0.U) - when(addDone)(add_stable := debounceAdd) + lazy val debounceAdd = Mux(addDone, adder.z, 0.U) + add_stable := Mux(addDone, debounceAdd, add_stable) // output - out := nextT + out := Mux(nextT.head(1).asBool, TWithoutSub, nextT) out_valid := state.asUInt()(6) } diff --git a/arithmetic/tests/src/crypto/MontgomerySpec.scala b/arithmetic/tests/src/crypto/MontgomerySpec.scala index c28632b..e95e78e 100644 --- a/arithmetic/tests/src/crypto/MontgomerySpec.scala +++ b/arithmetic/tests/src/crypto/MontgomerySpec.scala @@ -7,40 +7,77 @@ import utest._ object MontgomerySpec extends TestSuite with ChiselUtestTester { def tests: Tests = Tests { test("Montgomery should pass") { - val u = new Utility() - val length = scala.util.Random.nextInt(20) + 10 // (10, 30) - var p = u.randPrime(length) - - var width = p.toBinaryString.length - var R_inv = u.modinv((scala.math.pow(2, width)).toInt, p) - var addPipe = scala.util.Random.nextInt(10) + 1 - var a = scala.util.Random.nextInt(p) - var b = scala.util.Random.nextInt(p) - val res = BigInt(a) * BigInt(b) * BigInt(R_inv) % BigInt(p) - - testCircuit(new Montgomery(width, addPipe), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){dut: Montgomery => - dut.clock.setTimeout(0) - dut.p.poke(p.U) - dut.pPrime.poke(true.B) - dut.a.poke(a.U) - dut.b.poke(b.U) - dut.b_add_p.poke((p+b).U) - dut.clock.step() - dut.clock.step() - // delay two cycles then set valid = true - dut.valid.poke(true.B) - dut.clock.step() - var flag = false - for(a <- 1 to 1000) { + var specfic_test = false + if (specfic_test) { + var a = BigInt("643") + var b = BigInt("3249") + var p = BigInt("3323") + var width = 64 + var R = BigInt("4096") + var R_inv = BigInt("374") + var addPipe = 10 + var res = (a) * (b) * (R_inv) % (p) + var random_width = 0 + testCircuit(new Montgomery(64, addPipe), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){dut: Montgomery => + dut.clock.setTimeout(0) + dut.p.poke(p.U) + dut.pPrime.poke(true.B) + dut.a.poke(a.U) + dut.b.poke(b.U) + dut.input_width.poke((11).U) + dut.clock.step() + dut.clock.step() + // delay two cycles then set valid = true + dut.valid.poke(true.B) + dut.clock.step() + var flag = false + for(a <- 1 to 1000) { + dut.clock.step() + if(dut.out_valid.peek().litValue == 1) { + flag = true + // need to wait a cycle because there is s5 -> s6 or s4 -> s6 in the state machine + dut.clock.step() + utest.assert(dut.out.peek().litValue == res) + } + } + utest.assert(flag) + } + } else { + var u = new Utility() + var length = scala.util.Random.nextInt(20) + 10 // (10, 30) + var p = u.randPrime(length) + var a = scala.util.Random.nextInt(p) + var b = scala.util.Random.nextInt(p) + var width = p.toBinaryString.length + var R = (scala.math.pow(2, width)).toInt + var R_inv = u.modinv(R, p) + var addPipe = scala.util.Random.nextInt(10) + 1 + var res = BigInt(a) * BigInt(b) * BigInt(R_inv) % BigInt(p) + var random_width = scala.util.Random.nextInt(20) + 10 // to test if bigger length hardware can support smaller length number + testCircuit(new Montgomery(width, addPipe), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){dut: Montgomery => + dut.clock.setTimeout(0) + dut.p.poke(p.U) + dut.pPrime.poke(true.B) + dut.a.poke(a.U) + dut.b.poke(b.U) + dut.input_width.poke((width-1).U) + dut.clock.step() + dut.clock.step() + // delay two cycles then set valid = true + dut.valid.poke(true.B) dut.clock.step() - if(dut.out_valid.peek().litValue == 1) { - flag = true - // need to wait a cycle because there is s5 -> s6 or s4 -> s6 in the state machine + var flag = false + for(a <- 1 to 1000) { dut.clock.step() - utest.assert(dut.out.peek().litValue == res) + if(dut.out_valid.peek().litValue == 1) { + flag = true + // need to wait a cycle because there is s5 -> s6 or s4 -> s6 in the state machine + dut.clock.step() + utest.assert(dut.out.peek().litValue == res) + } } + utest.assert(flag) } - utest.assert(flag) } } }