Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 77 additions & 50 deletions arithmetic/src/crypto/modmul/Montgomery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,78 +3,89 @@ package crypto.modmul
import chisel3._
import chisel3.experimental.ChiselEnum
import chisel3.util.experimental.decode.TruthTable
import chisel3.util.{Counter, Mux1H}
import chisel3.util.{Cat, Counter, Mux1H}

class Montgomery(pWidth: Int = 4096, addPipe: Int) extends Module {
val p = IO(Input(UInt(pWidth.W)))
val pPrime = IO(Input(Bool()))
val a = IO(Input(UInt(pWidth.W)))
val b = IO(Input(UInt(pWidth.W)))
val b_add_p = IO(Input(UInt((pWidth + 1).W))) // b + p
val indexCountBit = 16
val input_width = IO(Input(UInt(indexCountBit.W)))
val valid = IO(Input(Bool())) // input valid
val out = IO(Output(UInt(pWidth.W)))
val out_valid = IO(Output(Bool())) // output valid

val u = Reg(Bool())
val i = Reg(UInt((pWidth).W))
val b_add_p = Reg(UInt((pWidth + 1).W))
val invP = Reg(UInt((pWidth).W))
val negP = Reg(UInt((pWidth + 2).W))
val loop_u = Reg(Bool())
val index = Reg(UInt(indexCountBit.W))
val nextT = Reg(UInt((pWidth + 2).W))

// multicycle prefixadder
val adder = Module(new DummyAdd(pWidth + 2, addPipe))
val add_stable = RegInit(0.U((pWidth + 2).W))
// Control Path
object StateType extends ChiselEnum {
val s0 = Value("b0000001".U) // nextT = 0, u = a(0)b(0)pPrime
val s0 = Value("b0000001".U) // nextT = 0, loop_u = a(0)b(0)pPrime, b_add_p = b + p
// loop
val s1 = Value("b0000010".U) // nextT + b
val s2 = Value("b0000100".U) // nextT + p
val s3 = Value("b0001000".U) // nextT + b_add_p
// loop done
val s4 = Value("b0010000".U) // i << 1, u = (nextT(0) + a(i)b(0))pPrime, nextT / 2
val s5 = Value("b0100000".U) // if-then
val s4 = Value("b0010000".U) // index += 1, loop_u = (nextT(0) + a(index)b(0))pPrime, nextT / 2
val s5 = Value("b0100000".U) // nextT - p
val s6 = Value("b1000000".U) // done
val s7 = Value("b10000000".U) // nextT + 0
val s8 = Value("b100000000".U) // calculate ~p
val s9 = Value("b1000000000".U) // calculate ~p + 1
}

val state = RegInit(StateType.s0)
val isAdd = (state.asUInt & "b10101110".U).orR
val isAdd = (state.asUInt & "b1010101111".U).orR
adder.valid := isAdd
val addDoneNext = RegInit(false.B)
addDoneNext := addDone
lazy val addDone = if (addPipe != 0) Counter(isAdd && (~addDoneNext), addPipe + 1)._2 else true.B
val addSign = ((add_stable >> 1) < p.asUInt)
lazy val addDone = if (addPipe != 0) Counter(valid && isAdd && (~addDoneNext), addPipe + 1)._2 else true.B
val a_i = Reg(Bool())
val iBreak = (index.asUInt >= input_width.asUInt)
state := chisel3.util.experimental.decode
.decoder(
state.asUInt() ## addDoneNext ## valid ## i.head(1) ## addSign ## u ## a_i, {
state.asUInt() ## addDoneNext ## valid ## iBreak ## loop_u ## a_i, {
val Y = "1"
val N = "0"
val DC = "?"
def to(
stateI: String,
addDone: String = DC,
valid: String = DC,
iHead: String = DC,
addSign: String = DC,
u: String = DC,
iBreak: String = DC,
loop_u: String = DC,
a_i: String = DC
)(stateO: String
) = s"$stateI$addDone$valid$iHead$addSign$u$a_i->$stateO"
val s0 = "00000001"
val s1 = "00000010"
val s2 = "00000100"
val s3 = "00001000"
val s4 = "00010000"
val s5 = "00100000"
val s6 = "01000000"
val s7 = "10000000"
) = s"$stateI$addDone$valid$iBreak$loop_u$a_i->$stateO"
val s0 = "0000000001"
val s1 = "0000000010"
val s2 = "0000000100"
val s3 = "0000001000"
val s4 = "0000010000"
val s5 = "0000100000"
val s6 = "0001000000"
val s7 = "0010000000"
val s8 = "0100000000"
val s9 = "1000000000"
TruthTable.fromString(
Seq(
to(s0, valid = N)(s0),
to(s0, valid = Y, a_i = Y, u = N)(s1),
to(s0, valid = Y, a_i = N, u = Y)(s2),
to(s0, valid = Y, a_i = Y, u = Y)(s3),
to(s0, valid = Y, a_i = N, u = N)(s7),
to(s0, valid = Y, addDone = N)(s0),
to(s0, valid = Y, addDone = Y)(s8),
to(s8)(s9),
to(s9, addDone = N)(s9),
to(s9, addDone = Y, a_i = Y, loop_u = N)(s1),
to(s9, addDone = Y, a_i = N, loop_u = Y)(s2),
to(s9, addDone = Y, a_i = Y, loop_u = Y)(s3),
to(s9, addDone = Y, a_i = N, loop_u = N)(s7),
to(s1, addDone = Y)(s4),
to(s1, addDone = N)(s1),
to(s2, addDone = Y)(s4),
Expand All @@ -83,42 +94,45 @@ class Montgomery(pWidth: Int = 4096, addPipe: Int) extends Module {
to(s3, addDone = N)(s3),
to(s7, addDone = Y)(s4),
to(s7, addDone = N)(s7),
to(s4, iHead = Y, addSign = N)(s5),
to(s4, iHead = Y, addSign = Y)(s6),
to(s4, iHead = N, a_i = Y, u = N)(s1),
to(s4, iHead = N, a_i = N, u = Y)(s2),
to(s4, iHead = N, a_i = Y, u = Y)(s3),
to(s4, iHead = N, a_i = N, u = N)(s7),
to(s4, iBreak = Y)(s5),
to(s4, iBreak = N, a_i = Y, loop_u = N)(s1),
to(s4, iBreak = N, a_i = N, loop_u = Y)(s2),
to(s4, iBreak = N, a_i = Y, loop_u = Y)(s3),
to(s4, iBreak = N, a_i = N, loop_u = N)(s7),
to(s5, addDone = Y)(s6),
to(s5, addDone = N)(s5),
"????????"
to(s6, valid = N)(s0),
to(s6, valid = Y)(s6),
"??????????"
).mkString("\n")
)
}
)
.asTypeOf(StateType.Type())

i := Mux1H(
index := Mux1H(
Map(
state.asUInt()(0) -> 1.U,
state.asUInt()(4) -> i.rotateLeft(1),
(state.asUInt & "b11101110".U).orR -> i
state.asUInt()(0) -> 0.U,
state.asUInt()(4) -> (index + 1.U),
(state.asUInt & "b1111101110".U).orR -> index
)
)

u := Mux1H(
b_add_p := Mux(addDone & state.asUInt()(0), debounceAdd, b_add_p)

loop_u := Mux1H(
Map(
state.asUInt()(0) -> (a(0).asUInt & b(0).asUInt & pPrime.asUInt),
(state.asUInt & "b10001110".U).orR -> ((add_stable(1) + (((a & (i.rotateLeft(1))).orR) & b(0))) & pPrime.asUInt),
(state.asUInt & "b01110000".U).orR -> u
(state.asUInt & "b0010001110".U).orR -> ((add_stable(1) + (a(index + 1.U) & b(0))) & pPrime.asUInt),
(state.asUInt & "b1101110000".U).orR -> loop_u
)
)

a_i := Mux1H(
Map(
state.asUInt()(0) -> a(0),
(state.asUInt & "b10001110".U).orR -> (a & (i.rotateLeft(1))).orR,
(state.asUInt & "b01110000".U).orR -> a_i
(state.asUInt & "b0010001110".U).orR -> a(index + 1.U),
(state.asUInt & "b1101110000".U).orR -> a_i
)
)

Expand All @@ -127,24 +141,37 @@ class Montgomery(pWidth: Int = 4096, addPipe: Int) extends Module {
state.asUInt()(0) -> 0.U,
state.asUInt()(4) -> (add_stable >> 1),
state.asUInt()(5) -> add_stable,
(state.asUInt & "b11001110".U).orR -> nextT
(state.asUInt & "b1111001110".U).orR -> nextT
)
)
val TWithoutSubControl = Reg(UInt(1.W))
val TWithoutSub = Reg(UInt((pWidth + 2).W))
TWithoutSubControl := Mux(state.asUInt()(5), 0.U, 1.U)
TWithoutSub := Mux(state.asUInt()(5) && (TWithoutSubControl === 1.U), nextT, TWithoutSub)
invP := Mux(state.asUInt()(8), ~p, invP)
negP := Mux(state.asUInt()(9), add_stable, negP)

adder.a := nextT
adder.a := Mux1H(
Map(
state.asUInt()(0) -> p,
state.asUInt()(9) -> 1.U,
(state.asUInt & "b0111111110".U).orR -> nextT
)
)
adder.b := Mux1H(
Map(
state.asUInt()(1) -> b,
(state.asUInt & "b0100000011".U).orR -> b,
state.asUInt()(9) -> Cat(3.U, invP),
state.asUInt()(2) -> p,
state.asUInt()(3) -> b_add_p,
state.asUInt()(7) -> 0.U,
state.asUInt()(5) -> -p
state.asUInt()(5) -> negP
)
)
val debounceAdd = Mux(addDone, adder.z, 0.U)
when(addDone)(add_stable := debounceAdd)
lazy val debounceAdd = Mux(addDone, adder.z, 0.U)
add_stable := Mux(addDone, debounceAdd, add_stable)

// output
out := nextT
out := Mux(nextT.head(1).asBool, TWithoutSub, nextT)
out_valid := state.asUInt()(6)
}
97 changes: 67 additions & 30 deletions arithmetic/tests/src/crypto/MontgomerySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,77 @@ import utest._
object MontgomerySpec extends TestSuite with ChiselUtestTester {
def tests: Tests = Tests {
test("Montgomery should pass") {
val u = new Utility()
val length = scala.util.Random.nextInt(20) + 10 // (10, 30)
var p = u.randPrime(length)

var width = p.toBinaryString.length
var R_inv = u.modinv((scala.math.pow(2, width)).toInt, p)
var addPipe = scala.util.Random.nextInt(10) + 1
var a = scala.util.Random.nextInt(p)
var b = scala.util.Random.nextInt(p)
val res = BigInt(a) * BigInt(b) * BigInt(R_inv) % BigInt(p)

testCircuit(new Montgomery(width, addPipe), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){dut: Montgomery =>
dut.clock.setTimeout(0)
dut.p.poke(p.U)
dut.pPrime.poke(true.B)
dut.a.poke(a.U)
dut.b.poke(b.U)
dut.b_add_p.poke((p+b).U)
dut.clock.step()
dut.clock.step()
// delay two cycles then set valid = true
dut.valid.poke(true.B)
dut.clock.step()
var flag = false
for(a <- 1 to 1000) {
var specfic_test = false
if (specfic_test) {
var a = BigInt("643")
var b = BigInt("3249")
var p = BigInt("3323")
var width = 64
var R = BigInt("4096")
var R_inv = BigInt("374")
var addPipe = 10
var res = (a) * (b) * (R_inv) % (p)
var random_width = 0
testCircuit(new Montgomery(64, addPipe), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){dut: Montgomery =>
dut.clock.setTimeout(0)
dut.p.poke(p.U)
dut.pPrime.poke(true.B)
dut.a.poke(a.U)
dut.b.poke(b.U)
dut.input_width.poke((11).U)
dut.clock.step()
dut.clock.step()
// delay two cycles then set valid = true
dut.valid.poke(true.B)
dut.clock.step()
var flag = false
for(a <- 1 to 1000) {
dut.clock.step()
if(dut.out_valid.peek().litValue == 1) {
flag = true
// need to wait a cycle because there is s5 -> s6 or s4 -> s6 in the state machine
dut.clock.step()
utest.assert(dut.out.peek().litValue == res)
}
}
utest.assert(flag)
}
} else {
var u = new Utility()
var length = scala.util.Random.nextInt(20) + 10 // (10, 30)
var p = u.randPrime(length)
var a = scala.util.Random.nextInt(p)
var b = scala.util.Random.nextInt(p)
var width = p.toBinaryString.length
var R = (scala.math.pow(2, width)).toInt
var R_inv = u.modinv(R, p)
var addPipe = scala.util.Random.nextInt(10) + 1
var res = BigInt(a) * BigInt(b) * BigInt(R_inv) % BigInt(p)
var random_width = scala.util.Random.nextInt(20) + 10 // to test if bigger length hardware can support smaller length number
testCircuit(new Montgomery(width, addPipe), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){dut: Montgomery =>
dut.clock.setTimeout(0)
dut.p.poke(p.U)
dut.pPrime.poke(true.B)
dut.a.poke(a.U)
dut.b.poke(b.U)
dut.input_width.poke((width-1).U)
dut.clock.step()
dut.clock.step()
// delay two cycles then set valid = true
dut.valid.poke(true.B)
dut.clock.step()
if(dut.out_valid.peek().litValue == 1) {
flag = true
// need to wait a cycle because there is s5 -> s6 or s4 -> s6 in the state machine
var flag = false
for(a <- 1 to 1000) {
dut.clock.step()
utest.assert(dut.out.peek().litValue == res)
if(dut.out_valid.peek().litValue == 1) {
flag = true
// need to wait a cycle because there is s5 -> s6 or s4 -> s6 in the state machine
dut.clock.step()
utest.assert(dut.out.peek().litValue == res)
}
}
utest.assert(flag)
}
utest.assert(flag)
}
}
}
Expand Down