From c83d329d43b77799a36dbe6935113e477b99636a Mon Sep 17 00:00:00 2001 From: Jiuyang Liu Date: Fri, 16 Apr 2021 08:15:42 +0000 Subject: [PATCH 01/31] Implement SRT class 0. Start to implement SRT table generator. 1. Can draw P-D and Robertson graph now. --- arithmetic/src/division/srt/SRT.scala | 99 +++++++++++++++++++++++++++ build.sc | 8 ++- 2 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 arithmetic/src/division/srt/SRT.scala diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala new file mode 100644 index 0000000..d336c38 --- /dev/null +++ b/arithmetic/src/division/srt/SRT.scala @@ -0,0 +1,99 @@ +package division.srt + +import breeze.linalg._ +import breeze.plot._ + +/** Base SRT class. + * + * @param radix is the radix of SRT. + * @note 5.2 + * @param quotientSet is the min and max of quotient-digit set + * @note 5.6 + * @param ulpN ulp is the unit in the last position, defined by `pow(r, -uplN)` + * @note 5.2 + * @param normD normalized divider range from `pow(2, _._1)` to `pow(2, _._2)` + * @param normX normalized dividend range from `pow(2, _._1)` to `pow(2, _._2)` + */ +case class SRT( + radix: Int, + quotientSet: (Int, Int), + ulpN: Int = 0, + normD: (Int, Int) = (-1, 0), + normX: (Int, Int) = (-1, 0)) { + val a: Int = max(math.abs(quotientSet._1), math.abs(quotientSet._2)) + require(quotientSet._1 < 0, quotientSet._2 > 0) + // @note 5.7 + require(a >= (radix + 1) / 2) + // @note 5.3 + require(normD._1 < normD._2) + require(normX._1 < normX._2) + val xMin: Double = math.pow(2, normX._1) + val xMax: Double = math.pow(2, normX._2) + val dMin: Double = math.pow(2, normD._1) + val dMax: Double = math.pow(2, normD._2) + + /** redundancy factor + * @note 5.8 + */ + val rou: Double = a.toDouble / (radix - 1) + + override def toString: String = s"SRT$radix with quotient set: ${quotientSet._1 to quotientSet._2}" + // @note 5.8s + assert((rou > 1.0 / 2) && (rou <= 1)) + + /** P-D Diagram + * @note Graph 5.17(b) + */ + def pdDiagram(): Unit = { + val fig: Figure = Figure() + val p: Plot = fig.subplot(0) + val x: DenseVector[Double] = linspace(dMin, dMax) + + val (uk, lk) = (quotientSet._1 to quotientSet._2).map { k: Int => + (plot(x, x * uRate(k), name = s"U($k)"), plot(x, x * lRate(k), name = s"L($k)")) + }.unzip + + p ++= uk ++= lk + + p.xlabel = "d" + p.ylabel = "rω[j]" + val scale = 1.1 + p.xlim(0, xMax * scale) + p.ylim((quotientSet._1 - rou) * xMax * scale, (quotientSet._2 + rou) * xMax * scale) + p.title = s"P-D Graph of $this" + p.legend = true + fig.saveas("pd.pdf") + } + + /** slope factor of U_k + * @note 5.56 + */ + def uRate(k: Int): Double = k + rou + + /** slope factor of L_k + * @note 5.56 + */ + def lRate(k: Int): Double = k - rou + + /** Robertson Diagram + * @note Graph 5.17(a) + */ + def robertsonDiagram(d: Double): Unit = { + require(d > dMin && d < dMax) + val fig: Figure = Figure() + val p: Plot = fig.subplot(0) + + p ++= (quotientSet._1 to quotientSet._2).map { k: Int => + val xrange: DenseVector[Double] = linspace((k - rou) * d, (k + rou) * d) + plot(xrange, xrange - k * d, name = s"$k") + } + + p.xlabel = "rω[j]" + p.ylabel = "ω[j+1]" + p.xlim(-radix * rou * dMax, radix * rou * dMax) + p.ylim(-rou * d, rou * d) + p.title = s"Robertson Graph of $this divisor: $d" + p.legend = true + fig.saveas("robertson.pdf") + } +} diff --git a/build.sc b/build.sc index 202124d..529f9a9 100644 --- a/build.sc +++ b/build.sc @@ -15,7 +15,10 @@ object v { val utest = ivy"com.lihaoyi::utest:latest.integration" val upickle = ivy"com.lihaoyi::upickle:latest.integration" val osLib = ivy"com.lihaoyi::os-lib:latest.integration" -// val prime = ivy"org.apache.commons:commons-math3:3.6.1" + val breeze = ivy"com.github.ktakagaki.breeze::breeze:2.0" + val breezeNatives = ivy"com.github.ktakagaki.breeze::breeze-natives:2.0" + val breezeViz = ivy"org.scalanlp::breeze-viz:2.0" + // val prime = ivy"org.apache.commons:commons-math3:3.6.1" } object arithmetic extends arithmetic @@ -38,6 +41,9 @@ class arithmetic extends ScalaModule with ScalafmtModule with PublishModule { m v.chiseltest, v.upickle, v.osLib, + v.breeze, + v.breezeViz, + v.breezeNatives ) object tests extends Tests with Utest { From d570098a4a9254db82220a07694f16e06fe4de6a Mon Sep 17 00:00:00 2001 From: Jiuyang Liu Date: Fri, 16 Apr 2021 12:18:56 +0000 Subject: [PATCH 02/31] Use spire to increase precision to arbitrary precision. --- arithmetic/src/division/srt/SRT.scala | 61 ++++++++++++++------------- build.sc | 4 +- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index d336c38..59a6f0d 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -2,6 +2,8 @@ package division.srt import breeze.linalg._ import breeze.plot._ +import spire.implicits._ +import spire.math._ /** Base SRT class. * @@ -11,35 +13,34 @@ import breeze.plot._ * @note 5.6 * @param ulpN ulp is the unit in the last position, defined by `pow(r, -uplN)` * @note 5.2 - * @param normD normalized divider range from `pow(2, _._1)` to `pow(2, _._2)` - * @param normX normalized dividend range from `pow(2, _._1)` to `pow(2, _._2)` + * @param normD normalized divider range + * @param normX normalized dividend range */ case class SRT( - radix: Int, - quotientSet: (Int, Int), - ulpN: Int = 0, - normD: (Int, Int) = (-1, 0), - normX: (Int, Int) = (-1, 0)) { - val a: Int = max(math.abs(quotientSet._1), math.abs(quotientSet._2)) - require(quotientSet._1 < 0, quotientSet._2 > 0) + radix: Algebraic, + quotientSet: (Algebraic, Algebraic), + ulpN: Algebraic = 0, + normD: (Algebraic, Algebraic) = (-1, 0), + normX: (Algebraic, Algebraic) = (-1, 0)) { + val a: Algebraic = (-quotientSet._1).max(quotientSet._2) // @note 5.7 - require(a >= (radix + 1) / 2) + require(a >= radix / 2) // @note 5.3 require(normD._1 < normD._2) require(normX._1 < normX._2) - val xMin: Double = math.pow(2, normX._1) - val xMax: Double = math.pow(2, normX._2) - val dMin: Double = math.pow(2, normD._1) - val dMax: Double = math.pow(2, normD._2) + val xMin: Algebraic = Algebraic(2).pow(normX._1.toInt) + val xMax: Algebraic = Algebraic(2).pow(normX._2.toInt) + val dMin: Algebraic = Algebraic(2).pow(normD._1.toInt) + val dMax: Algebraic = Algebraic(2).pow(normD._2.toInt) /** redundancy factor * @note 5.8 */ - val rou: Double = a.toDouble / (radix - 1) + val rou: Algebraic = a / (radix - 1) - override def toString: String = s"SRT$radix with quotient set: ${quotientSet._1 to quotientSet._2}" + override def toString: String = s"SRT$radix with quotient set: from ${-quotientSet._1} to ${quotientSet._2}" // @note 5.8s - assert((rou > 1.0 / 2) && (rou <= 1)) + assert((rou > 1 / 2) && (rou <= 1)) /** P-D Diagram * @note Graph 5.17(b) @@ -47,10 +48,10 @@ case class SRT( def pdDiagram(): Unit = { val fig: Figure = Figure() val p: Plot = fig.subplot(0) - val x: DenseVector[Double] = linspace(dMin, dMax) + val x: DenseVector[Double] = linspace(dMin.toDouble, dMax.toDouble) - val (uk, lk) = (quotientSet._1 to quotientSet._2).map { k: Int => - (plot(x, x * uRate(k), name = s"U($k)"), plot(x, x * lRate(k), name = s"L($k)")) + val (uk, lk) = (quotientSet._1.toBigInt to quotientSet._2.toBigInt).map { k: BigInt => + (plot(x, x * uRate(k.toInt).toDouble, name = s"U($k)"), plot(x, x * lRate(k.toInt).toDouble, name = s"L($k)")) }.unzip p ++= uk ++= lk @@ -58,8 +59,8 @@ case class SRT( p.xlabel = "d" p.ylabel = "rω[j]" val scale = 1.1 - p.xlim(0, xMax * scale) - p.ylim((quotientSet._1 - rou) * xMax * scale, (quotientSet._2 + rou) * xMax * scale) + p.xlim(0, (xMax * scale).toDouble) + p.ylim(((quotientSet._1 - rou) * xMax * scale).toDouble, ((quotientSet._2 + rou) * xMax * scale).toDouble) p.title = s"P-D Graph of $this" p.legend = true fig.saveas("pd.pdf") @@ -68,30 +69,30 @@ case class SRT( /** slope factor of U_k * @note 5.56 */ - def uRate(k: Int): Double = k + rou + def uRate(k: Algebraic): Algebraic = k + rou /** slope factor of L_k * @note 5.56 */ - def lRate(k: Int): Double = k - rou + def lRate(k: Algebraic): Algebraic = k - rou /** Robertson Diagram * @note Graph 5.17(a) */ - def robertsonDiagram(d: Double): Unit = { + def robertsonDiagram(d: Algebraic): Unit = { require(d > dMin && d < dMax) val fig: Figure = Figure() val p: Plot = fig.subplot(0) - p ++= (quotientSet._1 to quotientSet._2).map { k: Int => - val xrange: DenseVector[Double] = linspace((k - rou) * d, (k + rou) * d) - plot(xrange, xrange - k * d, name = s"$k") + p ++= (quotientSet._1.toInt to quotientSet._2.toInt).map { k: Int => + val xrange: DenseVector[Double] = linspace(((Algebraic(k) - rou) * d).toDouble, ((Algebraic(k) + rou) * d).toDouble) + plot(xrange, xrange - k * d.toDouble, name = s"$k") } p.xlabel = "rω[j]" p.ylabel = "ω[j+1]" - p.xlim(-radix * rou * dMax, radix * rou * dMax) - p.ylim(-rou * d, rou * d) + p.xlim((-radix * rou * dMax).toDouble, (radix * rou * dMax).toDouble) + p.ylim((-rou * d).toDouble, (rou * d).toDouble) p.title = s"Robertson Graph of $this divisor: $d" p.legend = true fig.saveas("robertson.pdf") diff --git a/build.sc b/build.sc index 529f9a9..0ccd94e 100644 --- a/build.sc +++ b/build.sc @@ -18,6 +18,7 @@ object v { val breeze = ivy"com.github.ktakagaki.breeze::breeze:2.0" val breezeNatives = ivy"com.github.ktakagaki.breeze::breeze-natives:2.0" val breezeViz = ivy"org.scalanlp::breeze-viz:2.0" + val spire = ivy"org.typelevel::spire:0.17.0" // val prime = ivy"org.apache.commons:commons-math3:3.6.1" } @@ -43,7 +44,8 @@ class arithmetic extends ScalaModule with ScalafmtModule with PublishModule { m v.osLib, v.breeze, v.breezeViz, - v.breezeNatives + v.breezeNatives, + v.spire ) object tests extends Tests with Utest { From a9960a96294fe49a3838df01df101fd5eca1c7f4 Mon Sep 17 00:00:00 2001 From: Jiuyang Liu Date: Thu, 22 Apr 2021 18:21:37 +0000 Subject: [PATCH 03/31] rewrite SRT. --- .../prefixadder/graph/PrefixGraph.scala | 2 +- arithmetic/src/division/srt/SRT.scala | 213 ++++++++++++------ build.sc | 5 +- 3 files changed, 146 insertions(+), 74 deletions(-) diff --git a/arithmetic/src/addition/prefixadder/graph/PrefixGraph.scala b/arithmetic/src/addition/prefixadder/graph/PrefixGraph.scala index 38e63bd..ae72d72 100644 --- a/arithmetic/src/addition/prefixadder/graph/PrefixGraph.scala +++ b/arithmetic/src/addition/prefixadder/graph/PrefixGraph.scala @@ -53,7 +53,7 @@ object PrefixGraph { } object CommonSumByConsole extends HasPrefixSumWithGraphImp with CommonPrefixSum { - val filePath = Path(io.StdIn.readLine("Import your graph generated by `dot -Txdot_json`: "), pwd) + val filePath = Path(scala.io.StdIn.readLine("Import your graph generated by `dot -Txdot_json`: "), pwd) val fileName = filePath.baseName val prefixGraph: PrefixGraph = PrefixGraph(filePath) } diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 59a6f0d..48bbe0a 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -1,100 +1,169 @@ package division.srt -import breeze.linalg._ -import breeze.plot._ +import com.cibo.evilplot.colors.HTMLNamedColors +import com.cibo.evilplot.numeric.Bounds +import com.cibo.evilplot.plot._ +import com.cibo.evilplot.plot.aesthetics.DefaultTheme._ +import com.cibo.evilplot.plot.renderers.PointRenderer +import os.Path import spire.implicits._ import spire.math._ /** Base SRT class. * * @param radix is the radix of SRT. + * It defined how many rounds can be calculate in one cycle. * @note 5.2 - * @param quotientSet is the min and max of quotient-digit set - * @note 5.6 - * @param ulpN ulp is the unit in the last position, defined by `pow(r, -uplN)` - * @note 5.2 - * @param normD normalized divider range - * @param normX normalized dividend range */ case class SRT( - radix: Algebraic, - quotientSet: (Algebraic, Algebraic), - ulpN: Algebraic = 0, - normD: (Algebraic, Algebraic) = (-1, 0), - normX: (Algebraic, Algebraic) = (-1, 0)) { - val a: Algebraic = (-quotientSet._1).max(quotientSet._2) - // @note 5.7 - require(a >= radix / 2) - // @note 5.3 - require(normD._1 < normD._2) - require(normX._1 < normX._2) - val xMin: Algebraic = Algebraic(2).pow(normX._1.toInt) - val xMax: Algebraic = Algebraic(2).pow(normX._2.toInt) - val dMin: Algebraic = Algebraic(2).pow(normD._1.toInt) - val dMax: Algebraic = Algebraic(2).pow(normD._2.toInt) + radix: Algebraic, + a: Algebraic, + dTruncateWidth: Algebraic, + xTruncateWidth: Algebraic, + dMin: Algebraic = 0.5, + dMax: Algebraic = 1) { + require(a > 0) + lazy val xMin: Algebraic = -rho * dMax + lazy val xMax: Algebraic = rho * dMax + + /** P-D Diagram + * + * @note Graph 5.17(b) + */ + lazy val pd: Plot = Overlay((aMin.toBigInt to aMax.toBigInt).flatMap { k: BigInt => + Seq( + FunctionPlot.series( + _ * uRate(k.toInt).toDouble, + s"U($k)", + HTMLNamedColors.blue, + Some(Bounds(dMin.toDouble, dMax.toDouble)), + strokeWidth = Some(1) + ), + FunctionPlot.series( + _ * lRate(k.toInt).toDouble, + s"L($k)", + HTMLNamedColors.red, + Some(Bounds(dMin.toDouble, dMax.toDouble)), + strokeWidth = Some(1) + ) + ) ++ qdsPoints :+ mesh + }: _*) + .title(s"P-D Graph of $this") + .xLabel("d") + .yLabel("rω[j]") + .rightLegend() + .standard() + lazy val aMax: Algebraic = a + lazy val aMin: Algebraic = -a + lazy val deltaD: Algebraic = pow(2, -dTruncateWidth.toDouble) + lazy val deltaX: Algebraic = pow(2, -xTruncateWidth.toDouble) /** redundancy factor * @note 5.8 */ - val rou: Algebraic = a / (radix - 1) - - override def toString: String = s"SRT$radix with quotient set: from ${-quotientSet._1} to ${quotientSet._2}" - // @note 5.8s - assert((rou > 1 / 2) && (rou <= 1)) + lazy val rho: Algebraic = a / (radix - 1) + lazy val tables: Seq[(Int, Seq[(Algebraic, Seq[Algebraic])])] = { + (aMin.toInt to aMax.toInt).drop(1).map { k => + k -> dSet.dropRight(1).map { d => + val (floor, ceil) = xRange(k, d, d + deltaD) + val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= ceil && x >= floor } + (d, m) + } + } + } + lazy val qdsPoints: Seq[Plot] = { + tables.map { + case (i, ps) => + ScatterPlot( + ps.flatMap { case (d, xs) => xs.map(x => com.cibo.evilplot.numeric.Point(d.toDouble, x.toDouble)) }, + Some( + PointRenderer + .default[com.cibo.evilplot.numeric.Point](pointSize = Some(1), color = Some(HTMLNamedColors.gold)) + ) + ) + } + } - /** P-D Diagram - * @note Graph 5.17(b) - */ - def pdDiagram(): Unit = { - val fig: Figure = Figure() - val p: Plot = fig.subplot(0) - val x: DenseVector[Double] = linspace(dMin.toDouble, dMax.toDouble) + private val xStep = (xMax - xMin) / deltaX + // @note 5.7 + require(a >= radix / 2) + private val xSet = Seq.tabulate((xStep + 1).toInt) { n => xMin + deltaX * n } + private val dStep: Algebraic = (dMax - dMin) / deltaD + assert((rho > 1 / 2) && (rho <= 1)) + private val dSet = Seq.tabulate((dStep + 1).toInt) { n => dMin + deltaD * n } + private val mesh = + ScatterPlot( + xSet.flatMap { y => + dSet.map { x => + com.cibo.evilplot.numeric.Point(x.toDouble, y.toDouble) + } + }, + Some( + PointRenderer + .default[com.cibo.evilplot.numeric.Point](pointSize = Some(0.5), color = Some(HTMLNamedColors.gray)) + ) + ) - val (uk, lk) = (quotientSet._1.toBigInt to quotientSet._2.toBigInt).map { k: BigInt => - (plot(x, x * uRate(k.toInt).toDouble, name = s"U($k)"), plot(x, x * lRate(k.toInt).toDouble, name = s"L($k)")) - }.unzip + override def toString: String = + s"SRT${radix.toInt} with quotient set: from ${aMin.toInt} to ${aMax.toInt}" - p ++= uk ++= lk + /** Robertson Diagram + * + * @note Graph 5.17(a) + */ + def robertson(d: Algebraic): Plot = { + require(d > dMin && d < dMax) + Overlay((aMin.toBigInt to aMax.toBigInt).map { k: BigInt => + FunctionPlot.series( + _ - (Algebraic(k) * d).toDouble, + s"$k", + HTMLNamedColors.black, + xbounds = Some(Bounds(((Algebraic(k) - rho) * d).toDouble, ((Algebraic(k) + rho) * d).toDouble)) + ) + }: _*) + .title(s"Robertson Graph of $this divisor: $d") + .xLabel("rω[j]") + .yLabel("ω[j+1]") + .xbounds((-radix * rho * dMax).toDouble, (radix * rho * dMax).toDouble) + .ybounds((-rho * d).toDouble, (rho * d).toDouble) + .rightLegend() + .standard() + } - p.xlabel = "d" - p.ylabel = "rω[j]" - val scale = 1.1 - p.xlim(0, (xMax * scale).toDouble) - p.ylim(((quotientSet._1 - rou) * xMax * scale).toDouble, ((quotientSet._2 + rou) * xMax * scale).toDouble) - p.title = s"P-D Graph of $this" - p.legend = true - fig.saveas("pd.pdf") + def dumpGraph(plot: Plot, path: Path) = { + javax.imageio.ImageIO.write( + plot.render().asBufferedImage, + "png", + path.wrapped.toFile + ) } - /** slope factor of U_k - * @note 5.56 + /** for range `dLeft` to `dRight`, return the `rOmegaCeil` and `rOmegaFloor` + * this is used for constructing the rectangle where m_k(i) is located. */ - def uRate(k: Algebraic): Algebraic = k + rou + private def xRange(k: Algebraic, dLeft: Algebraic, dRight: Algebraic): (Algebraic, Algebraic) = { + Seq(L(k, dLeft), L(k, dRight), U(k - 1, dLeft), U(k - 1, dRight)) + // not safe + .sortBy(_.toDouble) + .drop(1) + .dropRight(1) match { case Seq(l, r) => (l, r) } + } + + /** find the intersection point between L`k` and `d` */ + private def L(k: Algebraic, d: Algebraic): Algebraic = lRate(k) * d /** slope factor of L_k + * * @note 5.56 */ - def lRate(k: Algebraic): Algebraic = k - rou + private def lRate(k: Algebraic): Algebraic = k - rho - /** Robertson Diagram - * @note Graph 5.17(a) - */ - def robertsonDiagram(d: Algebraic): Unit = { - require(d > dMin && d < dMax) - val fig: Figure = Figure() - val p: Plot = fig.subplot(0) - - p ++= (quotientSet._1.toInt to quotientSet._2.toInt).map { k: Int => - val xrange: DenseVector[Double] = linspace(((Algebraic(k) - rou) * d).toDouble, ((Algebraic(k) + rou) * d).toDouble) - plot(xrange, xrange - k * d.toDouble, name = s"$k") - } + /** find the intersection point between U`k` and `d` */ + private def U(k: Algebraic, d: Algebraic): Algebraic = uRate(k) * d - p.xlabel = "rω[j]" - p.ylabel = "ω[j+1]" - p.xlim((-radix * rou * dMax).toDouble, (radix * rou * dMax).toDouble) - p.ylim((-rou * d).toDouble, (rou * d).toDouble) - p.title = s"Robertson Graph of $this divisor: $d" - p.legend = true - fig.saveas("robertson.pdf") - } -} + /** slope factor of U_k + * + * @note 5.56 + */ + private def uRate(k: Algebraic): Algebraic = k + rho +} \ No newline at end of file diff --git a/build.sc b/build.sc index 0ccd94e..df4fcba 100644 --- a/build.sc +++ b/build.sc @@ -19,6 +19,7 @@ object v { val breezeNatives = ivy"com.github.ktakagaki.breeze::breeze-natives:2.0" val breezeViz = ivy"org.scalanlp::breeze-viz:2.0" val spire = ivy"org.typelevel::spire:0.17.0" + val evilplot = ivy"io.github.cibotech::evilplot:0.8.1" // val prime = ivy"org.apache.commons:commons-math3:3.6.1" } @@ -38,6 +39,7 @@ class arithmetic extends ScalaModule with ScalafmtModule with PublishModule { m override def scalacPluginIvyDeps = Agg(v.chisel3Plugin) override def ivyDeps = super.ivyDeps() ++ Agg( +<<<<<<< HEAD v.chisel3, v.chiseltest, v.upickle, @@ -45,7 +47,8 @@ class arithmetic extends ScalaModule with ScalafmtModule with PublishModule { m v.breeze, v.breezeViz, v.breezeNatives, - v.spire + v.spire, + v.evilplot ) object tests extends Tests with Utest { From cb2caed9cdc49662f8fa68e8e8c12ffeb971ace8 Mon Sep 17 00:00:00 2001 From: Jiuyang Liu Date: Fri, 23 Apr 2021 15:50:38 +0000 Subject: [PATCH 04/31] bug fix and add test. --- arithmetic/src/division/srt/SRT.scala | 6 +++--- arithmetic/tests/src/division/srt/SRTSpec.scala | 13 +++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 arithmetic/tests/src/division/srt/SRTSpec.scala diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 48bbe0a..cca0f88 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -23,8 +23,8 @@ case class SRT( dMin: Algebraic = 0.5, dMax: Algebraic = 1) { require(a > 0) - lazy val xMin: Algebraic = -rho * dMax - lazy val xMax: Algebraic = rho * dMax + lazy val xMin: Algebraic = -rho * dMax * radix + lazy val xMax: Algebraic = rho * dMax * radix /** P-D Diagram * @@ -50,7 +50,7 @@ case class SRT( }: _*) .title(s"P-D Graph of $this") .xLabel("d") - .yLabel("rω[j]") + .yLabel(s"${radix.toInt}ω[j]") .rightLegend() .standard() lazy val aMax: Algebraic = a diff --git a/arithmetic/tests/src/division/srt/SRTSpec.scala b/arithmetic/tests/src/division/srt/SRTSpec.scala new file mode 100644 index 0000000..b2831a9 --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRTSpec.scala @@ -0,0 +1,13 @@ +package division.srt + +import utest._ + + +object SRTSpec extends TestSuite{ + override def tests: Tests = Tests { + test("SRT should draw PD") { + val srt = SRT(4, 2, 5, 5) + srt.dumpGraph(srt.pd, os.root / "tmp" / "srt4-2-5-5.png") + } + } +} From 0ef1d57777772a2403f05df14d54598c14cd413e Mon Sep 17 00:00:00 2001 From: Jiuyang Liu Date: Fri, 1 Apr 2022 16:09:37 +0800 Subject: [PATCH 05/31] fix --- build.sc | 1 - 1 file changed, 1 deletion(-) diff --git a/build.sc b/build.sc index df4fcba..0cde056 100644 --- a/build.sc +++ b/build.sc @@ -39,7 +39,6 @@ class arithmetic extends ScalaModule with ScalafmtModule with PublishModule { m override def scalacPluginIvyDeps = Agg(v.chisel3Plugin) override def ivyDeps = super.ivyDeps() ++ Agg( -<<<<<<< HEAD v.chisel3, v.chiseltest, v.upickle, From 79edb608abd0a2e71174626e0cd20d63f21c5303 Mon Sep 17 00:00:00 2001 From: Jiuyang Liu Date: Fri, 1 Apr 2022 18:18:35 +0800 Subject: [PATCH 06/31] wip SRT --- arithmetic/src/division/srt/QDS.scala | 27 +++ arithmetic/src/division/srt/SRT.scala | 211 +++++------------- arithmetic/src/division/srt/SRTTable.scala | 169 ++++++++++++++ .../tests/src/division/srt/SRTSpec.scala | 2 +- build.sc | 12 +- 5 files changed, 257 insertions(+), 164 deletions(-) create mode 100644 arithmetic/src/division/srt/QDS.scala create mode 100644 arithmetic/src/division/srt/SRTTable.scala diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala new file mode 100644 index 0000000..9336ebe --- /dev/null +++ b/arithmetic/src/division/srt/QDS.scala @@ -0,0 +1,27 @@ +package division.srt +import chisel3._ +import chisel3.util.{RegEnable, Valid} + +class QDSInput extends Bundle { + val partialReminderCarry: UInt = ??? + val partialReminderSum: UInt = ??? +} + +class QDSOutput extends Bundle { + val selectedQuotient: UInt = ??? +} + +class QDS extends Module { + val input = IO(Input(new QDSInput)) + val output = IO(Output(new QDSOutput)) + // used to select a column of SRT Table + val partialDivider = IO(Flipped(Valid(UInt()))) + val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) + // for the first cycle: use partialDivider on the IO + // for the reset of cycles: use partialDividerReg + // for synthesis: the constraint should be IO -> Output is a multi-cycle design + // Reg -> Output is single-cycle + // to avoid glitch, valid should be larger than raise time of partialDividerReg + val partialDividerLatch = Mux(partialDivider.valid, partialDivider.bits, partialDividerReg) + +} diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index cca0f88..0b2a7cb 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -1,169 +1,66 @@ package division.srt -import com.cibo.evilplot.colors.HTMLNamedColors -import com.cibo.evilplot.numeric.Bounds -import com.cibo.evilplot.plot._ -import com.cibo.evilplot.plot.aesthetics.DefaultTheme._ -import com.cibo.evilplot.plot.renderers.PointRenderer -import os.Path -import spire.implicits._ -import spire.math._ +import addition.csa.CarrySaveAdder +import addition.csa.common.CSACompressor3_2 +import chisel3._ +import chisel3.util.{Decoupled, DecoupledIO, Mux1H, log2Ceil} -/** Base SRT class. - * - * @param radix is the radix of SRT. - * It defined how many rounds can be calculate in one cycle. - * @note 5.2 - */ -case class SRT( - radix: Algebraic, - a: Algebraic, - dTruncateWidth: Algebraic, - xTruncateWidth: Algebraic, - dMin: Algebraic = 0.5, - dMax: Algebraic = 1) { - require(a > 0) - lazy val xMin: Algebraic = -rho * dMax * radix - lazy val xMax: Algebraic = rho * dMax * radix +class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { + val dividend = UInt(dividendWidth.W) + val divider = UInt(dividerWidth.W) + val counter = UInt(log2Ceil(???).W) +} - /** P-D Diagram - * - * @note Graph 5.17(b) - */ - lazy val pd: Plot = Overlay((aMin.toBigInt to aMax.toBigInt).flatMap { k: BigInt => - Seq( - FunctionPlot.series( - _ * uRate(k.toInt).toDouble, - s"U($k)", - HTMLNamedColors.blue, - Some(Bounds(dMin.toDouble, dMax.toDouble)), - strokeWidth = Some(1) - ), - FunctionPlot.series( - _ * lRate(k.toInt).toDouble, - s"L($k)", - HTMLNamedColors.red, - Some(Bounds(dMin.toDouble, dMax.toDouble)), - strokeWidth = Some(1) - ) - ) ++ qdsPoints :+ mesh - }: _*) - .title(s"P-D Graph of $this") - .xLabel("d") - .yLabel(s"${radix.toInt}ω[j]") - .rightLegend() - .standard() - lazy val aMax: Algebraic = a - lazy val aMin: Algebraic = -a - lazy val deltaD: Algebraic = pow(2, -dTruncateWidth.toDouble) - lazy val deltaX: Algebraic = pow(2, -xTruncateWidth.toDouble) +class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { + val reminder = UInt(reminderWidth.W) + val quotient = UInt(quotientWidth.W) +} - /** redundancy factor - * @note 5.8 - */ - lazy val rho: Algebraic = a / (radix - 1) - lazy val tables: Seq[(Int, Seq[(Algebraic, Seq[Algebraic])])] = { - (aMin.toInt to aMax.toInt).drop(1).map { k => - k -> dSet.dropRight(1).map { d => - val (floor, ceil) = xRange(k, d, d + deltaD) - val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= ceil && x >= floor } - (d, m) - } - } - } - lazy val qdsPoints: Seq[Plot] = { - tables.map { - case (i, ps) => - ScatterPlot( - ps.flatMap { case (d, xs) => xs.map(x => com.cibo.evilplot.numeric.Point(d.toDouble, x.toDouble)) }, - Some( - PointRenderer - .default[com.cibo.evilplot.numeric.Point](pointSize = Some(1), color = Some(HTMLNamedColors.gold)) - ) - ) - } - } +// only SRT4 currently +class SRT( + dividendWidth: Int, + dividerWidth: Int, + n: Int) + extends Module { + // IO + val input: DecoupledIO[SRTInput] = Flipped(Decoupled(new SRTInput(dividendWidth, dividerWidth, n))) + val output: DecoupledIO[SRTOutput] = Decoupled(new SRTOutput(dividerWidth, dividendWidth)) - private val xStep = (xMax - xMin) / deltaX - // @note 5.7 - require(a >= radix / 2) - private val xSet = Seq.tabulate((xStep + 1).toInt) { n => xMin + deltaX * n } - private val dStep: Algebraic = (dMax - dMin) / deltaD - assert((rho > 1 / 2) && (rho <= 1)) - private val dSet = Seq.tabulate((dStep + 1).toInt) { n => dMin + deltaD * n } - private val mesh = - ScatterPlot( - xSet.flatMap { y => - dSet.map { x => - com.cibo.evilplot.numeric.Point(x.toDouble, y.toDouble) - } - }, - Some( - PointRenderer - .default[com.cibo.evilplot.numeric.Point](pointSize = Some(0.5), color = Some(HTMLNamedColors.gray)) - ) - ) - override def toString: String = - s"SRT${radix.toInt} with quotient set: from ${aMin.toInt} to ${aMax.toInt}" + // State + // because we need a CSA to minimize the critical path + val partialReminderCarry = Reg(UInt()) + val partialReminderSum = Reg(UInt()) + val divider = Reg(UInt()) - /** Robertson Diagram - * - * @note Graph 5.17(a) - */ - def robertson(d: Algebraic): Plot = { - require(d > dMin && d < dMax) - Overlay((aMin.toBigInt to aMax.toBigInt).map { k: BigInt => - FunctionPlot.series( - _ - (Algebraic(k) * d).toDouble, - s"$k", - HTMLNamedColors.black, - xbounds = Some(Bounds(((Algebraic(k) - rho) * d).toDouble, ((Algebraic(k) + rho) * d).toDouble)) - ) - }: _*) - .title(s"Robertson Graph of $this divisor: $d") - .xLabel("rω[j]") - .yLabel("ω[j+1]") - .xbounds((-radix * rho * dMax).toDouble, (radix * rho * dMax).toDouble) - .ybounds((-rho * d).toDouble, (rho * d).toDouble) - .rightLegend() - .standard() - } + val quotient = Reg(UInt()) + val quotientMinusOne = Reg(UInt()) - def dumpGraph(plot: Plot, path: Path) = { - javax.imageio.ImageIO.write( - plot.render().asBufferedImage, - "png", - path.wrapped.toFile - ) - } + val state = Reg(UInt()) + val counter = Reg(UInt()) - /** for range `dLeft` to `dRight`, return the `rOmegaCeil` and `rOmegaFloor` - * this is used for constructing the rectangle where m_k(i) is located. - */ - private def xRange(k: Algebraic, dLeft: Algebraic, dRight: Algebraic): (Algebraic, Algebraic) = { - Seq(L(k, dLeft), L(k, dRight), U(k - 1, dLeft), U(k - 1, dRight)) - // not safe - .sortBy(_.toDouble) - .drop(1) - .dropRight(1) match { case Seq(l, r) => (l, r) } - } + // Control + // sign of select quotient, true -> negative, false -> positive + val qdsSign: Bool = Wire(Bool()) - /** find the intersection point between L`k` and `d` */ - private def L(k: Algebraic, d: Algebraic): Algebraic = lRate(k) * d + // Datapath + val qds = new QDS() - /** slope factor of L_k - * - * @note 5.56 - */ - private def lRate(k: Algebraic): Algebraic = k - rho - - /** find the intersection point between U`k` and `d` */ - private def U(k: Algebraic, d: Algebraic): Algebraic = uRate(k) * d - - /** slope factor of U_k - * - * @note 5.56 - */ - private def uRate(k: Algebraic): Algebraic = k + rho -} \ No newline at end of file + val csa = new CarrySaveAdder(CSACompressor3_2, ???) + csa.in(0) := partialReminderSum + csa.in(1) := (partialReminderCarry ## !qdsSign) + csa.in(2) := Mux1H(Map( + ??? -> , + ??? -> + )) + partialReminderSum := Mux1H(Map( + ??? -> input.bits.dividend, + ??? -> (csa.out(0) << log2Ceil(n)), + ??? -> partialReminderSum + )) + partialReminderCarry := Mux1H(Map( + ??? -> 0.U, + ??? -> (csa.out(1) << log2Ceil(n)), + ??? -> partialReminderCarry + )) +} diff --git a/arithmetic/src/division/srt/SRTTable.scala b/arithmetic/src/division/srt/SRTTable.scala new file mode 100644 index 0000000..6567e38 --- /dev/null +++ b/arithmetic/src/division/srt/SRTTable.scala @@ -0,0 +1,169 @@ +package division.srt + +import com.cibo.evilplot.colors.HTMLNamedColors +import com.cibo.evilplot.numeric.Bounds +import com.cibo.evilplot.plot._ +import com.cibo.evilplot.plot.aesthetics.DefaultTheme._ +import com.cibo.evilplot.plot.renderers.PointRenderer +import os.Path +import spire.implicits._ +import spire.math._ + +/** Base SRT class. + * + * @param radix is the radix of SRT. + * It defined how many rounds can be calculate in one cycle. + * @note 5.2 + */ +case class SRTTable( + radix: Algebraic, + a: Algebraic, + dTruncateWidth: Algebraic, + xTruncateWidth: Algebraic, + dMin: Algebraic = 0.5, + dMax: Algebraic = 1) { + require(a > 0) + lazy val xMin: Algebraic = -rho * dMax * radix + lazy val xMax: Algebraic = rho * dMax * radix + + /** P-D Diagram + * + * @note Graph 5.17(b) + */ + lazy val pd: Plot = Overlay((aMin.toBigInt to aMax.toBigInt).flatMap { k: BigInt => + Seq( + FunctionPlot.series( + _ * uRate(k.toInt).toDouble, + s"U($k)", + HTMLNamedColors.blue, + Some(Bounds(dMin.toDouble, dMax.toDouble)), + strokeWidth = Some(1) + ), + FunctionPlot.series( + _ * lRate(k.toInt).toDouble, + s"L($k)", + HTMLNamedColors.red, + Some(Bounds(dMin.toDouble, dMax.toDouble)), + strokeWidth = Some(1) + ) + ) ++ qdsPoints :+ mesh + }: _*) + .title(s"P-D Graph of $this") + .xLabel("d") + .yLabel(s"${radix.toInt}ω[j]") + .rightLegend() + .standard() + lazy val aMax: Algebraic = a + lazy val aMin: Algebraic = -a + lazy val deltaD: Algebraic = pow(2, -dTruncateWidth.toDouble) + lazy val deltaX: Algebraic = pow(2, -xTruncateWidth.toDouble) + + /** redundancy factor + * @note 5.8 + */ + lazy val rho: Algebraic = a / (radix - 1) + lazy val tables: Seq[(Int, Seq[(Algebraic, Seq[Algebraic])])] = { + (aMin.toInt to aMax.toInt).drop(1).map { k => + k -> dSet.dropRight(1).map { d => + val (floor, ceil) = xRange(k, d, d + deltaD) + val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= ceil && x >= floor } + (d, m) + } + } + } + lazy val qdsPoints: Seq[Plot] = { + tables.map { + case (i, ps) => + ScatterPlot( + ps.flatMap { case (d, xs) => xs.map(x => com.cibo.evilplot.numeric.Point(d.toDouble, x.toDouble)) }, + Some( + PointRenderer + .default[com.cibo.evilplot.numeric.Point](pointSize = Some(1), color = Some(HTMLNamedColors.gold)) + ) + ) + } + } + + private val xStep = (xMax - xMin) / deltaX + // @note 5.7 + require(a >= radix / 2) + private val xSet = Seq.tabulate((xStep + 1).toInt) { n => xMin + deltaX * n } + private val dStep: Algebraic = (dMax - dMin) / deltaD + assert((rho > 1 / 2) && (rho <= 1)) + private val dSet = Seq.tabulate((dStep + 1).toInt) { n => dMin + deltaD * n } + private val mesh = + ScatterPlot( + xSet.flatMap { y => + dSet.map { x => + com.cibo.evilplot.numeric.Point(x.toDouble, y.toDouble) + } + }, + Some( + PointRenderer + .default[com.cibo.evilplot.numeric.Point](pointSize = Some(0.5), color = Some(HTMLNamedColors.gray)) + ) + ) + + override def toString: String = + s"SRT${radix.toInt} with quotient set: from ${aMin.toInt} to ${aMax.toInt}" + + /** Robertson Diagram + * + * @note Graph 5.17(a) + */ + def robertson(d: Algebraic): Plot = { + require(d > dMin && d < dMax) + Overlay((aMin.toBigInt to aMax.toBigInt).map { k: BigInt => + FunctionPlot.series( + _ - (Algebraic(k) * d).toDouble, + s"$k", + HTMLNamedColors.black, + xbounds = Some(Bounds(((Algebraic(k) - rho) * d).toDouble, ((Algebraic(k) + rho) * d).toDouble)) + ) + }: _*) + .title(s"Robertson Graph of $this divisor: $d") + .xLabel("rω[j]") + .yLabel("ω[j+1]") + .xbounds((-radix * rho * dMax).toDouble, (radix * rho * dMax).toDouble) + .ybounds((-rho * d).toDouble, (rho * d).toDouble) + .rightLegend() + .standard() + } + + def dumpGraph(plot: Plot, path: Path) = { + javax.imageio.ImageIO.write( + plot.render().asBufferedImage, + "png", + path.wrapped.toFile + ) + } + + /** for range `dLeft` to `dRight`, return the `rOmegaCeil` and `rOmegaFloor` + * this is used for constructing the rectangle where m_k(i) is located. + */ + private def xRange(k: Algebraic, dLeft: Algebraic, dRight: Algebraic): (Algebraic, Algebraic) = { + Seq(L(k, dLeft), L(k, dRight), U(k - 1, dLeft), U(k - 1, dRight)) + // not safe + .sortBy(_.toDouble) + .drop(1) + .dropRight(1) match { case Seq(l, r) => (l, r) } + } + + /** find the intersection point between L`k` and `d` */ + private def L(k: Algebraic, d: Algebraic): Algebraic = lRate(k) * d + + /** slope factor of L_k + * + * @note 5.56 + */ + private def lRate(k: Algebraic): Algebraic = k - rho + + /** find the intersection point between U`k` and `d` */ + private def U(k: Algebraic, d: Algebraic): Algebraic = uRate(k) * d + + /** slope factor of U_k + * + * @note 5.56 + */ + private def uRate(k: Algebraic): Algebraic = k + rho +} \ No newline at end of file diff --git a/arithmetic/tests/src/division/srt/SRTSpec.scala b/arithmetic/tests/src/division/srt/SRTSpec.scala index b2831a9..85d143d 100644 --- a/arithmetic/tests/src/division/srt/SRTSpec.scala +++ b/arithmetic/tests/src/division/srt/SRTSpec.scala @@ -6,7 +6,7 @@ import utest._ object SRTSpec extends TestSuite{ override def tests: Tests = Tests { test("SRT should draw PD") { - val srt = SRT(4, 2, 5, 5) + val srt = SRTTable(4, 2, 5, 5) srt.dumpGraph(srt.pd, os.root / "tmp" / "srt4-2-5-5.png") } } diff --git a/build.sc b/build.sc index 0cde056..5dfe61e 100644 --- a/build.sc +++ b/build.sc @@ -15,9 +15,9 @@ object v { val utest = ivy"com.lihaoyi::utest:latest.integration" val upickle = ivy"com.lihaoyi::upickle:latest.integration" val osLib = ivy"com.lihaoyi::os-lib:latest.integration" - val breeze = ivy"com.github.ktakagaki.breeze::breeze:2.0" - val breezeNatives = ivy"com.github.ktakagaki.breeze::breeze-natives:2.0" - val breezeViz = ivy"org.scalanlp::breeze-viz:2.0" +// val breeze = ivy"com.github.ktakagaki.breeze::breeze:2.0" +// val breezeNatives = ivy"com.github.ktakagaki.breeze::breeze-natives:2.0" +// val breezeViz = ivy"org.scalanlp::breeze-viz:2.0" val spire = ivy"org.typelevel::spire:0.17.0" val evilplot = ivy"io.github.cibotech::evilplot:0.8.1" // val prime = ivy"org.apache.commons:commons-math3:3.6.1" @@ -43,9 +43,9 @@ class arithmetic extends ScalaModule with ScalafmtModule with PublishModule { m v.chiseltest, v.upickle, v.osLib, - v.breeze, - v.breezeViz, - v.breezeNatives, +// v.breeze, +// v.breezeViz, +// v.breezeNatives, v.spire, v.evilplot ) From 11088452274269802254edc40a8d5945f02861fc Mon Sep 17 00:00:00 2001 From: GH Cheng <1536771081@qq.com> Date: Mon, 4 Apr 2022 21:00:17 +0800 Subject: [PATCH 07/31] srt fix --- arithmetic/src/division/srt/QDS.scala | 8 ++++---- arithmetic/src/division/srt/SRT.scala | 6 ++++-- build.sc | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index 9336ebe..47a4c7b 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -1,6 +1,6 @@ package division.srt import chisel3._ -import chisel3.util.{RegEnable, Valid} +import chisel3.util.{RegEnable, Valid, log2Ceil} class QDSInput extends Bundle { val partialReminderCarry: UInt = ??? @@ -8,15 +8,15 @@ class QDSInput extends Bundle { } class QDSOutput extends Bundle { - val selectedQuotient: UInt = ??? + val selectedQuotient: UInt = UInt((log2Ceil(n)+1).W) } class QDS extends Module { val input = IO(Input(new QDSInput)) val output = IO(Output(new QDSOutput)) // used to select a column of SRT Table - val partialDivider = IO(Flipped(Valid(UInt()))) - val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) + val partialDivider = IO(Flipped(Valid(UInt()))) //它表达的是什么意思? + val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) // for the first cycle: use partialDivider on the IO // for the reset of cycles: use partialDividerReg // for synthesis: the constraint should be IO -> Output is a multi-cycle design diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 0b2a7cb..0e38375 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -31,6 +31,8 @@ class SRT( // because we need a CSA to minimize the critical path val partialReminderCarry = Reg(UInt()) val partialReminderSum = Reg(UInt()) + + // dMultiplier val divider = Reg(UInt()) val quotient = Reg(UInt()) @@ -48,7 +50,7 @@ class SRT( val csa = new CarrySaveAdder(CSACompressor3_2, ???) csa.in(0) := partialReminderSum - csa.in(1) := (partialReminderCarry ## !qdsSign) + csa.in(1) := (partialReminderCarry ## !qdsSign) //?这里有点问题 csa.in(2) := Mux1H(Map( ??? -> , ??? -> @@ -59,7 +61,7 @@ class SRT( ??? -> partialReminderSum )) partialReminderCarry := Mux1H(Map( - ??? -> 0.U, + ??? -> 0.U, ??? -> (csa.out(1) << log2Ceil(n)), ??? -> partialReminderCarry )) diff --git a/build.sc b/build.sc index 5dfe61e..46a78e1 100644 --- a/build.sc +++ b/build.sc @@ -20,7 +20,7 @@ object v { // val breezeViz = ivy"org.scalanlp::breeze-viz:2.0" val spire = ivy"org.typelevel::spire:0.17.0" val evilplot = ivy"io.github.cibotech::evilplot:0.8.1" - // val prime = ivy"org.apache.commons:commons-math3:3.6.1" +// val prime = ivy"org.apache.commons:commons-math3:3.6.1" } object arithmetic extends arithmetic From c502ab52a359529a461e03a5f227e604d54ad1e4 Mon Sep 17 00:00:00 2001 From: Jiuyang Liu Date: Tue, 5 Apr 2022 15:15:33 +0800 Subject: [PATCH 08/31] wip --- arithmetic/src/division/srt/QDS.scala | 22 ++++++++-- arithmetic/src/division/srt/SRT.scala | 61 ++++++++++++++++++--------- 2 files changed, 59 insertions(+), 24 deletions(-) diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index 9336ebe..a503e27 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -4,18 +4,23 @@ import chisel3.util.{RegEnable, Valid} class QDSInput extends Bundle { val partialReminderCarry: UInt = ??? - val partialReminderSum: UInt = ??? + val partialReminderSum: UInt = ??? } class QDSOutput extends Bundle { - val selectedQuotient: UInt = ??? + val selectedQuotientOH: UInt = ??? } -class QDS extends Module { +/** + */ +class QDS(table: String) extends Module { + // IO val input = IO(Input(new QDSInput)) val output = IO(Output(new QDSOutput)) // used to select a column of SRT Table val partialDivider = IO(Flipped(Valid(UInt()))) + + // State val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) // for the first cycle: use partialDivider on the IO // for the reset of cycles: use partialDividerReg @@ -24,4 +29,15 @@ class QDS extends Module { // to avoid glitch, valid should be larger than raise time of partialDividerReg val partialDividerLatch = Mux(partialDivider.valid, partialDivider.bits, partialDividerReg) + // Datapath + val columnSelect = partialDividerLatch + val rowSelect = input.partialReminderCarry + input.partialReminderSum + val selectRom: Vec[Vec[UInt]] = ??? + val mkVec = selectRom(columnSelect) + val selectPoints = mkVec.map{mk => + // get the select point + input.partialReminderCarry + input.partialReminderSum - mk + } + // decoder or findFirstOne here, prefer decoder + } diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 0b2a7cb..68c140c 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -3,7 +3,7 @@ package division.srt import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 import chisel3._ -import chisel3.util.{Decoupled, DecoupledIO, Mux1H, log2Ceil} +import chisel3.util.{log2Ceil, Decoupled, DecoupledIO, Mux1H} class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { val dividend = UInt(dividendWidth.W) @@ -20,13 +20,12 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { class SRT( dividendWidth: Int, dividerWidth: Int, - n: Int) + n: Int) extends Module { // IO val input: DecoupledIO[SRTInput] = Flipped(Decoupled(new SRTInput(dividendWidth, dividerWidth, n))) val output: DecoupledIO[SRTOutput] = Decoupled(new SRTOutput(dividerWidth, dividendWidth)) - // State // because we need a CSA to minimize the critical path val partialReminderCarry = Reg(UInt()) @@ -44,23 +43,43 @@ class SRT( val qdsSign: Bool = Wire(Bool()) // Datapath - val qds = new QDS() + val qds = new QDS("???") + // TODO: bit select here + qds.input.partialReminderSum := partialReminderSum + qds.input.partialReminderCarry := partialReminderCarry - val csa = new CarrySaveAdder(CSACompressor3_2, ???) - csa.in(0) := partialReminderSum - csa.in(1) := (partialReminderCarry ## !qdsSign) - csa.in(2) := Mux1H(Map( - ??? -> , - ??? -> - )) - partialReminderSum := Mux1H(Map( - ??? -> input.bits.dividend, - ??? -> (csa.out(0) << log2Ceil(n)), - ??? -> partialReminderSum - )) - partialReminderCarry := Mux1H(Map( - ??? -> 0.U, - ??? -> (csa.out(1) << log2Ceil(n)), - ??? -> partialReminderCarry - )) + // for SRT4 -> CSA32 + // for SRT8 -> CSA32+CSA32 + // for SRT16 -> CSA53+CSA32 + // SRT16 <- SRT4 + SRT4*5 + // { + val csa = new CarrySaveAdder(CSACompressor3_2, ???) + csa.in(0) := partialReminderSum + csa.in(1) := (partialReminderCarry ## !qdsSign) + csa.in(2) := Mux1H( + qds.output.selectedQuotientOH, + // TODO: this is for SRT4, for SRT8 or SRT16, this should be changed + VecInit((-2 to 2).map { + case -2 => divider << 1 + case -1 => divider + case 0 => 0.U + case 1 => ~divider + case 2 => (~divider) << 1 + }) + ) + // } + partialReminderSum := Mux1H( + Map( + ??? -> input.bits.dividend, + ??? -> (csa.out(0) << log2Ceil(n)), + ??? -> partialReminderSum + ) + ) + partialReminderCarry := Mux1H( + Map( + ??? -> 0.U, + ??? -> (csa.out(1) << log2Ceil(n)), + ??? -> partialReminderCarry + ) + ) } From c3c3ed1c3e17df129da9bcdef965b8ae86ccc1cb Mon Sep 17 00:00:00 2001 From: Jiuyang Liu Date: Thu, 7 Apr 2022 21:33:25 +0800 Subject: [PATCH 09/31] wip --- arithmetic/src/division/srt/OTF.scala | 23 +++++++++++++++++++++++ arithmetic/src/division/srt/SRT.scala | 23 +++++++++++++++++------ 2 files changed, 40 insertions(+), 6 deletions(-) create mode 100644 arithmetic/src/division/srt/OTF.scala diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala new file mode 100644 index 0000000..61d0ded --- /dev/null +++ b/arithmetic/src/division/srt/OTF.scala @@ -0,0 +1,23 @@ +package division.srt + +import chisel3._ + +class OTFInput extends Bundle { + val quotient = UInt() + val quotientMinusOne = UInt() + val selectedQuotientOH = UInt() +} + +class OTFOutput extends Bundle { + val quotient = UInt() + val quotientMinusOne = UInt() +} + +class OTF extends Module { + val input = IO(Input(new OTFInput)) + val output = IO(Output(new OTFOutput)) + // control + + // datapath + +} diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 68c140c..bc0e3af 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -3,7 +3,7 @@ package division.srt import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 import chisel3._ -import chisel3.util.{log2Ceil, Decoupled, DecoupledIO, Mux1H} +import chisel3.util.{DecoupledIO, Mux1H, ValidIO, log2Ceil} class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { val dividend = UInt(dividendWidth.W) @@ -23,8 +23,8 @@ class SRT( n: Int) extends Module { // IO - val input: DecoupledIO[SRTInput] = Flipped(Decoupled(new SRTInput(dividendWidth, dividerWidth, n))) - val output: DecoupledIO[SRTOutput] = Decoupled(new SRTOutput(dividerWidth, dividendWidth)) + val input = Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))) + val output = ValidIO(new SRTOutput(dividerWidth, dividendWidth)) // State // because we need a CSA to minimize the critical path @@ -35,7 +35,7 @@ class SRT( val quotient = Reg(UInt()) val quotientMinusOne = Reg(UInt()) - val state = Reg(UInt()) + // counter = 0 quotientToFix = 0 -> val counter = Reg(UInt()) // Control @@ -43,17 +43,19 @@ class SRT( val qdsSign: Bool = Wire(Bool()) // Datapath - val qds = new QDS("???") + val qds = Module(new QDS("???")) // TODO: bit select here qds.input.partialReminderSum := partialReminderSum qds.input.partialReminderCarry := partialReminderCarry + counter := counter - 1.U + // for SRT4 -> CSA32 // for SRT8 -> CSA32+CSA32 // for SRT16 -> CSA53+CSA32 // SRT16 <- SRT4 + SRT4*5 // { - val csa = new CarrySaveAdder(CSACompressor3_2, ???) + val csa = Module(new CarrySaveAdder(CSACompressor3_2, ???)) csa.in(0) := partialReminderSum csa.in(1) := (partialReminderCarry ## !qdsSign) csa.in(2) := Mux1H( @@ -82,4 +84,13 @@ class SRT( ??? -> partialReminderCarry ) ) + + // On-The-Fly conversion + val otf = Module(new OTF) + otf.input.quotient := quotient + otf.input.quotientMinusOne := quotientMinusOne + otf.input.selectedQuotientOH := qds.output.selectedQuotientOH + quotient := otf.output.quotient + quotientMinusOne := otf.output.quotientMinusOne + output.bits.quotient := quotient } From 65808e24d2653c43a6aa3c26ac23f5f590e979af Mon Sep 17 00:00:00 2001 From: GH Cheng <1536771081@qq.com> Date: Fri, 8 Apr 2022 14:57:20 +0800 Subject: [PATCH 10/31] srt fetch --- arithmetic/src/division/srt/OTF.scala | 23 ++++++++++ arithmetic/src/division/srt/QDS.scala | 27 +++++++++-- arithmetic/src/division/srt/SRT.scala | 65 +++++++++++++++++++++++---- 3 files changed, 104 insertions(+), 11 deletions(-) create mode 100644 arithmetic/src/division/srt/OTF.scala diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala new file mode 100644 index 0000000..61d0ded --- /dev/null +++ b/arithmetic/src/division/srt/OTF.scala @@ -0,0 +1,23 @@ +package division.srt + +import chisel3._ + +class OTFInput extends Bundle { + val quotient = UInt() + val quotientMinusOne = UInt() + val selectedQuotientOH = UInt() +} + +class OTFOutput extends Bundle { + val quotient = UInt() + val quotientMinusOne = UInt() +} + +class OTF extends Module { + val input = IO(Input(new OTFInput)) + val output = IO(Output(new OTFOutput)) + // control + + // datapath + +} diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index 47a4c7b..eee9eb6 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -4,19 +4,29 @@ import chisel3.util.{RegEnable, Valid, log2Ceil} class QDSInput extends Bundle { val partialReminderCarry: UInt = ??? - val partialReminderSum: UInt = ??? + val partialReminderSum: UInt = ??? } class QDSOutput extends Bundle { val selectedQuotient: UInt = UInt((log2Ceil(n)+1).W) + val selectedQuotientOH: UInt = ??? } -class QDS extends Module { +/** + */ +class QDS(table: String) extends Module { + // IO val input = IO(Input(new QDSInput)) val output = IO(Output(new QDSOutput)) // used to select a column of SRT Table - val partialDivider = IO(Flipped(Valid(UInt()))) //它表达的是什么意思? + + val partialDivider = IO(Flipped(Valid(UInt()))) val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) + + val partialDivider = IO(Flipped(Valid(UInt()))) + + // State + val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) // for the first cycle: use partialDivider on the IO // for the reset of cycles: use partialDividerReg // for synthesis: the constraint should be IO -> Output is a multi-cycle design @@ -24,4 +34,15 @@ class QDS extends Module { // to avoid glitch, valid should be larger than raise time of partialDividerReg val partialDividerLatch = Mux(partialDivider.valid, partialDivider.bits, partialDividerReg) + // Datapath + val columnSelect = partialDividerLatch + val rowSelect = input.partialReminderCarry + input.partialReminderSum + val selectRom: Vec[Vec[UInt]] = ??? + val mkVec = selectRom(columnSelect) + val selectPoints = mkVec.map{mk => + // get the select point + input.partialReminderCarry + input.partialReminderSum - mk + } + // decoder or findFirstOne here, prefer decoder + } diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 0e38375..513c746 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -3,12 +3,12 @@ package division.srt import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 import chisel3._ -import chisel3.util.{Decoupled, DecoupledIO, Mux1H, log2Ceil} +import chisel3.util.{DecoupledIO, Mux1H, ValidIO, log2Ceil} class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { val dividend = UInt(dividendWidth.W) val divider = UInt(dividerWidth.W) - val counter = UInt(log2Ceil(???).W) + val counter = UInt(log2Ceil(n).W) } class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { @@ -20,12 +20,11 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { class SRT( dividendWidth: Int, dividerWidth: Int, - n: Int) + n: Int) extends Module { // IO - val input: DecoupledIO[SRTInput] = Flipped(Decoupled(new SRTInput(dividendWidth, dividerWidth, n))) - val output: DecoupledIO[SRTOutput] = Decoupled(new SRTOutput(dividerWidth, dividendWidth)) - + val input = Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))) + val output = ValidIO(new SRTOutput(dividerWidth, dividendWidth)) // State // because we need a CSA to minimize the critical path @@ -38,7 +37,7 @@ class SRT( val quotient = Reg(UInt()) val quotientMinusOne = Reg(UInt()) - val state = Reg(UInt()) + // counter = 0 quotientToFix = 0 -> val counter = Reg(UInt()) // Control @@ -46,7 +45,47 @@ class SRT( val qdsSign: Bool = Wire(Bool()) // Datapath - val qds = new QDS() + val qds = Module(new QDS("???")) + // TODO: bit select here + qds.input.partialReminderSum := partialReminderSum + qds.input.partialReminderCarry := partialReminderCarry + + counter := counter - 1.U + + // for SRT4 -> CSA32 + // for SRT8 -> CSA32+CSA32 + // for SRT16 -> CSA53+CSA32 + // SRT16 <- SRT4 + SRT4*5 + // { + val csa = Module(new CarrySaveAdder(CSACompressor3_2, ???)) + csa.in(0) := partialReminderSum + csa.in(1) := (partialReminderCarry ## !qdsSign) + csa.in(2) := Mux1H( + qds.output.selectedQuotientOH, + // TODO: this is for SRT4, for SRT8 or SRT16, this should be changed + VecInit((-2 to 2).map { + case -2 => divider << 1 + case -1 => divider + case 0 => 0.U + case 1 => ~divider + case 2 => (~divider) << 1 + }) + ) + // } + partialReminderSum := Mux1H( + Map( + ??? -> input.bits.dividend, + ??? -> (csa.out(0) << log2Ceil(n)), + ??? -> partialReminderSum + ) + ) + partialReminderCarry := Mux1H( + Map( + ??? -> 0.U, + ??? -> (csa.out(1) << log2Ceil(n)), + ??? -> partialReminderCarry + ) + ) val csa = new CarrySaveAdder(CSACompressor3_2, ???) csa.in(0) := partialReminderSum @@ -65,4 +104,14 @@ class SRT( ??? -> (csa.out(1) << log2Ceil(n)), ??? -> partialReminderCarry )) + + // On-The-Fly conversion + val otf = Module(new OTF) + otf.input.quotient := quotient + otf.input.quotientMinusOne := quotientMinusOne + otf.input.selectedQuotientOH := qds.output.selectedQuotientOH + quotient := otf.output.quotient + quotientMinusOne := otf.output.quotientMinusOne + output.bits.quotient := quotient + } From 26888fe05659d0fa07513d912b2812a6561aeea2 Mon Sep 17 00:00:00 2001 From: GH Cheng <1536771081@qq.com> Date: Sun, 10 Apr 2022 15:46:06 +0800 Subject: [PATCH 11/31] Coding OTF --- arithmetic/src/division/srt/OTF.scala | 38 ++++++++++---- arithmetic/src/division/srt/QDS.scala | 18 +++---- arithmetic/src/division/srt/SRT.scala | 76 +++++++++++++-------------- arithmetic/src/division/srt/SZ.scala | 34 ++++++++++++ 4 files changed, 107 insertions(+), 59 deletions(-) create mode 100644 arithmetic/src/division/srt/SZ.scala diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala index 61d0ded..fdf507a 100644 --- a/arithmetic/src/division/srt/OTF.scala +++ b/arithmetic/src/division/srt/OTF.scala @@ -2,22 +2,40 @@ package division.srt import chisel3._ -class OTFInput extends Bundle { - val quotient = UInt() - val quotientMinusOne = UInt() - val selectedQuotientOH = UInt() +class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle { + val quotient = UInt(qWidth.W) + val quotientMinusOne = UInt(qWidth.W) + val selectedQuotientOH = UInt(ohWidth.W) } -class OTFOutput extends Bundle { - val quotient = UInt() - val quotientMinusOne = UInt() +class OTFOutput(qWidth: Int) extends Bundle { + val quotient = UInt(qWidth.W) + val quotientMinusOne = UInt(qWidth.W) } -class OTF extends Module { - val input = IO(Input(new OTFInput)) - val output = IO(Output(new OTFOutput)) +class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { + val input = IO(Input(new OTFInput(qWidth, ohWidth))) + val output = IO(Output(new OTFOutput(qWidth))) // control // datapath + // q_j+1 in this circle + val qNext: UInt = Mux1H(Seq( + input.selectedQuotient(0) -> -2.S, + input.selectedQuotient(1) -> -1.S, + input.selectedQuotient(2) -> 0.S, + input.selectedQuotient(3) -> 1.S, + input.selectedQuotient(4) -> 2.S + )) + // val cShiftQ: Bool = qNext >= 0.U + // val cShiftQM: Bool = qNext <= 0.U + val cShiftQ: Bool = input.selectedQuotient(ohWidth/2, 0).orR + val cShiftQM: Bool = input.selectedQuotient(ohWidth-1, ohWidth/2).orR + + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext) + val qmIn: UInt = Mux(~cShiftQM, qNext -1.U, (radix-1).U + qNext) + + output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne) ## qIn + output.quotientMinusOne := Mux(cShiftQM, input.quotientMinusOne, input.quotient) ## qmIn } diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index eee9eb6..e101325 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -3,28 +3,26 @@ import chisel3._ import chisel3.util.{RegEnable, Valid, log2Ceil} class QDSInput extends Bundle { - val partialReminderCarry: UInt = ??? - val partialReminderSum: UInt = ??? + val partialReminderCarry: UInt = UInt(rWidth.W) + val partialReminderSum: UInt = UInt(rWidth.W) } class QDSOutput extends Bundle { - val selectedQuotient: UInt = UInt((log2Ceil(n)+1).W) - val selectedQuotientOH: UInt = ??? + // val selectedQuotient: UInt = UInt((log2Ceil(n)+1).W) + val selectedQuotientOH: UInt = UInt(ohWidth.W) } /** */ -class QDS(table: String) extends Module { +class QDS(table: String, rWidth: Int, ohWidth: Int) extends Module { // IO - val input = IO(Input(new QDSInput)) - val output = IO(Output(new QDSOutput)) + val input = IO(Input(new QDSInput(rWidth))) + val output = IO(Output(new QDSOutput(ohWidth))) + // used to select a column of SRT Table - val partialDivider = IO(Flipped(Valid(UInt()))) val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) - val partialDivider = IO(Flipped(Valid(UInt()))) - // State val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) // for the first cycle: use partialDivider on the IO diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 513c746..e788c3f 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -5,10 +5,10 @@ import addition.csa.common.CSACompressor3_2 import chisel3._ import chisel3.util.{DecoupledIO, Mux1H, ValidIO, log2Ceil} -class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { - val dividend = UInt(dividendWidth.W) - val divider = UInt(dividerWidth.W) - val counter = UInt(log2Ceil(n).W) +class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int, radix: Int) extends Bundle { + val dividend = UInt(dividendWidth.W) //0.1********** + val divider = UInt(dividerWidth.W) //0.1********** + val counter = UInt((log2Ceil(n/log2Ceil(radix))).W) // n为需要计算的二进制位数 } class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { @@ -17,10 +17,11 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { } // only SRT4 currently -class SRT( +class SRT( dividendWidth: Int, dividerWidth: Int, - n: Int) + n: Int, + radix: Int = 4) extends Module { // IO val input = Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))) @@ -28,29 +29,44 @@ class SRT( // State // because we need a CSA to minimize the critical path - val partialReminderCarry = Reg(UInt()) - val partialReminderSum = Reg(UInt()) + val partialReminderCarry = Reg(UInt((dividendWidth + log2Ceil(radix)).W)) + val partialReminderSum = Reg(UInt((dividendWidth + log2Ceil(radix)).W)) // dMultiplier - val divider = Reg(UInt()) + val divider = RegInit(input.divider) - val quotient = Reg(UInt()) - val quotientMinusOne = Reg(UInt()) + val quotient = Reg(UInt(n.W)) //? + val quotientMinusOne = Reg(UInt(n.W)) //? // counter = 0 quotientToFix = 0 -> - val counter = Reg(UInt()) + val counter = Reg(UInt((log2Ceil(n/log2Ceil(radix))).W)) // Control // sign of select quotient, true -> negative, false -> positive val qdsSign: Bool = Wire(Bool()) // Datapath - val qds = Module(new QDS("???")) + val qds = Module(new QDS()) // TODO: bit select here qds.input.partialReminderSum := partialReminderSum qds.input.partialReminderCarry := partialReminderCarry counter := counter - 1.U + //整个srt的最终输出 + when(counter == 0.U){ + val sz = Module(new SZ(dividendWidth)) + sz.input.partialReminderSum := partialReminderSum + sz.input.partialReminderCarry := partialReminderCarry + when(sz.output.sign){ + ??? //修正,多减的给加回去,上的商给还原 + } + //拉高valid,输出商和余数 + output.valid := true.B + output.remainder := remainder + output.quotient := quotient + + } + // for SRT4 -> CSA32 // for SRT8 -> CSA32+CSA32 @@ -74,44 +90,26 @@ class SRT( // } partialReminderSum := Mux1H( Map( - ??? -> input.bits.dividend, - ??? -> (csa.out(0) << log2Ceil(n)), - ??? -> partialReminderSum + (counter === n/log2Ceil(radix)) -> input.bits.dividend, + (counter > 0.U) -> (csa.out(0) << log2Ceil(n)), + (counter === 0.U) -> partialReminderSum ) ) partialReminderCarry := Mux1H( Map( - ??? -> 0.U, - ??? -> (csa.out(1) << log2Ceil(n)), - ??? -> partialReminderCarry + (counter === n/log2Ceil(radix)) -> 0.U, + (counter > 0.U) -> (csa.out(1) << log2Ceil(n)-1), + (counter === 0.U) -> partialReminderCarry ) ) - val csa = new CarrySaveAdder(CSACompressor3_2, ???) - csa.in(0) := partialReminderSum - csa.in(1) := (partialReminderCarry ## !qdsSign) //?这里有点问题 - csa.in(2) := Mux1H(Map( - ??? -> , - ??? -> - )) - partialReminderSum := Mux1H(Map( - ??? -> input.bits.dividend, - ??? -> (csa.out(0) << log2Ceil(n)), - ??? -> partialReminderSum - )) - partialReminderCarry := Mux1H(Map( - ??? -> 0.U, - ??? -> (csa.out(1) << log2Ceil(n)), - ??? -> partialReminderCarry - )) - // On-The-Fly conversion - val otf = Module(new OTF) + val otf = Module(new OTF(radix, quotient.getWidth, qds.output.selectedQuotientOH.getWidth)) otf.input.quotient := quotient otf.input.quotientMinusOne := quotientMinusOne otf.input.selectedQuotientOH := qds.output.selectedQuotientOH + quotient := otf.output.quotient quotientMinusOne := otf.output.quotientMinusOne output.bits.quotient := quotient - } diff --git a/arithmetic/src/division/srt/SZ.scala b/arithmetic/src/division/srt/SZ.scala new file mode 100644 index 0000000..3380cba --- /dev/null +++ b/arithmetic/src/division/srt/SZ.scala @@ -0,0 +1,34 @@ +package division.srt + +import chisel3._ + +class SZInput extends Bundle { + val partialReminderCarry: UInt = UInt(rWidth.W) + val partialReminderSum: UInt = UInt(rWidth.W) +} + +class SZOutput extends Bundle { + // val selectedQuotient: UInt = UInt((log2Ceil(n)+1).W) + val sign: Bool = Bool() + val zero: Bool = Bool() +} + +class SZ(rWidth: Int) extends Module{ + val input = IO(Input(new SZInput(rWidth))) + val output= IO(Output(new SZOutput())) + + //controlpath + + //datapath + val ws = input.partialReminderCarry.asBools + val wc = input.partialReminderSum.asBools + + val psc: Seq[(Bool, Bool)]= ws.zip(wc).map{case(s,c) =>(~(s ^ c), (s | c))} + val ps: Seq[Bool] = psc.map(_._1) +: false.B + val pc: Seq[Bool] = false.B +: psc.map(_._2) + val p: Seq[Bool] = ps.zip(pc){case(s, c) => s ^ c} + + output.zero := p.andR + output.sign := (p.asUInt.head(1) ^ ???) & (~output.zero) + +} \ No newline at end of file From 0b2a14e66585437d96aa074cb240fd2309e7e2a8 Mon Sep 17 00:00:00 2001 From: GH Cheng <1536771081@qq.com> Date: Mon, 11 Apr 2022 21:31:55 +0800 Subject: [PATCH 12/31] using table from SRTTable --- arithmetic/src/division/srt/OTF.scala | 12 +-- arithmetic/src/division/srt/QDS.scala | 54 +++++++++--- arithmetic/src/division/srt/SRT.scala | 99 ++++++++++++---------- arithmetic/src/division/srt/SRTTable.scala | 20 ++++- arithmetic/src/division/srt/SZ.scala | 34 ++++---- 5 files changed, 138 insertions(+), 81 deletions(-) diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala index fdf507a..2c8881d 100644 --- a/arithmetic/src/division/srt/OTF.scala +++ b/arithmetic/src/division/srt/OTF.scala @@ -21,11 +21,11 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { // datapath // q_j+1 in this circle val qNext: UInt = Mux1H(Seq( - input.selectedQuotient(0) -> -2.S, - input.selectedQuotient(1) -> -1.S, - input.selectedQuotient(2) -> 0.S, - input.selectedQuotient(3) -> 1.S, - input.selectedQuotient(4) -> 2.S + input.selectedQuotient(0) -> "b110", + input.selectedQuotient(1) -> "b101", + input.selectedQuotient(2) -> "b000", + input.selectedQuotient(3) -> "b001", + input.selectedQuotient(4) -> "b010" )) // val cShiftQ: Bool = qNext >= 0.U @@ -33,7 +33,7 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { val cShiftQ: Bool = input.selectedQuotient(ohWidth/2, 0).orR val cShiftQM: Bool = input.selectedQuotient(ohWidth-1, ohWidth/2).orR - val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext) + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext) val qmIn: UInt = Mux(~cShiftQM, qNext -1.U, (radix-1).U + qNext) output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne) ## qIn diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index e101325..cb33a5e 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -1,6 +1,7 @@ package division.srt import chisel3._ -import chisel3.util.{RegEnable, Valid, log2Ceil} +import chisel3.util.{log2Ceil, RegEnable, Valid} +import chisle3.util.experimental.decode class QDSInput extends Bundle { val partialReminderCarry: UInt = UInt(rWidth.W) @@ -14,16 +15,15 @@ class QDSOutput extends Bundle { /** */ -class QDS(table: String, rWidth: Int, ohWidth: Int) extends Module { +class QDS(table: Seq[(Int, Int)], rWidth: Int, ohWidth: Int) extends Module { // IO val input = IO(Input(new QDSInput(rWidth))) val output = IO(Output(new QDSOutput(ohWidth))) - + // used to select a column of SRT Table - val partialDivider = IO(Flipped(Valid(UInt()))) - val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) + val partialDivider = IO(Flipped(Valid(UInt(???)))) //这个从哪里来 - // State + // State, in order to keep divider's value val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) // for the first cycle: use partialDivider on the IO // for the reset of cycles: use partialDividerReg @@ -35,12 +35,42 @@ class QDS(table: String, rWidth: Int, ohWidth: Int) extends Module { // Datapath val columnSelect = partialDividerLatch val rowSelect = input.partialReminderCarry + input.partialReminderSum - val selectRom: Vec[Vec[UInt]] = ??? - val mkVec = selectRom(columnSelect) - val selectPoints = mkVec.map{mk => + + // // use the table from XiangShan + // // from XiangShan: /16 + // val qSelTable = Array( + // Array(12, 4, -4, -13), + // Array(14, 4, -6, -15), + // Array(15, 4, -6, -16), + // Array(16, 4, -6, -18), + // Array(18, 6, -8, -20), + // Array(20, 6, -8, -20), + // Array(20, 8, -8, -22), + // Array(24, 8, -8, -24) + // ) + + // TODO: complete select_table algorithm + val selectRom: Vec[(UInt, SInt)] = table.map{case (d, x) => (d.U, x.S)} + val mkVec = selectRom.filter{ case (d, x) => d === columnSelect }.map{ case (d, x) => x } + + val selectPoints = mkVec.map { mk => // get the select point - input.partialReminderCarry + input.partialReminderSum - mk - } - // decoder or findFirstOne here, prefer decoder + // TODO: find the sign + (input.partialReminderCarry + input.partialReminderSum - mk).head(1) + }.flatten.asUInt + // decoder or findFirstOne here, prefer decoder + // the decoder only for srt4 + io.output := chisel3.util.experimental.decode.decoder( + selectPoints, + TruthTable( + Seq( + BitPat("b1???") -> BitPat("b10000"), + BitPat("b01??") -> BitPat("b01000"), + BitPat("b001?") -> BitPat("b00100"), + BitPat("b0001") -> BitPat("b00010"), + ), + BitPat("b00001") + ) + ) } diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index e788c3f..04ea8f0 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -3,12 +3,14 @@ package division.srt import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 import chisel3._ -import chisel3.util.{DecoupledIO, Mux1H, ValidIO, log2Ceil} +import chisel3.util.{log2Ceil, DecoupledIO, Mux1H, ValidIO} + +// TODO: width class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int, radix: Int) extends Bundle { val dividend = UInt(dividendWidth.W) //0.1********** - val divider = UInt(dividerWidth.W) //0.1********** - val counter = UInt((log2Ceil(n/log2Ceil(radix))).W) // n为需要计算的二进制位数 + val divider = UInt(dividerWidth.W) //0.1********** + val counter = UInt((log2Ceil(n / log2Ceil(radix))).W) // n为需要计算的二进制位数 } class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { @@ -17,11 +19,15 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { } // only SRT4 currently -class SRT( - dividendWidth: Int, - dividerWidth: Int, - n: Int, - radix: Int = 4) +class SRT( + dividendWidth: Int, + dividerWidth: Int, + n: Int, + radix: Int = 4, + a: Int = 2, + dTruncateWidth: Int = 4, + xTruncateWidth: Int = 3 + ) extends Module { // IO val input = Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))) @@ -30,75 +36,78 @@ class SRT( // State // because we need a CSA to minimize the critical path val partialReminderCarry = Reg(UInt((dividendWidth + log2Ceil(radix)).W)) - val partialReminderSum = Reg(UInt((dividendWidth + log2Ceil(radix)).W)) + val partialReminderSum = Reg(UInt((dividendWidth + log2Ceil(radix)).W)) // dMultiplier val divider = RegInit(input.divider) - val quotient = Reg(UInt(n.W)) //? + val quotient = Reg(UInt(n.W)) //? val quotientMinusOne = Reg(UInt(n.W)) //? // counter = 0 quotientToFix = 0 -> - val counter = Reg(UInt((log2Ceil(n/log2Ceil(radix))).W)) + val counter = Reg(UInt((log2Ceil(n / log2Ceil(radix))).W)) // Control // sign of select quotient, true -> negative, false -> positive val qdsSign: Bool = Wire(Bool()) // Datapath - val qds = Module(new QDS()) + // from software get quotient select Constant tables,how to convert double to UInt and store in ROM + val table: Seq[(Int, Int)] = SRTTable(radix, a, dTruncateWidth, xTruncateWidth).qdsTables + val rTruncateWidth: Int = ??? + val selectedQuotientOHWidth: Int = ??? + val qds = Module(new QDS(table, rTruncateWidth, selectedQuotientOHWidth)) // TODO: bit select here - qds.input.partialReminderSum := partialReminderSum - qds.input.partialReminderCarry := partialReminderCarry + qds.input.partialReminderSum := partialReminderSum.head(rTruncateWidth) + qds.input.partialReminderCarry := partialReminderCarry.head(rTruncateWidth) counter := counter - 1.U - //整个srt的最终输出 - when(counter == 0.U){ - val sz = Module(new SZ(dividendWidth)) - sz.input.partialReminderSum := partialReminderSum - sz.input.partialReminderCarry := partialReminderCarry - when(sz.output.sign){ - ??? //修正,多减的给加回去,上的商给还原 - } - //拉高valid,输出商和余数 - output.valid := true.B - output.remainder := remainder - output.quotient := quotient - - } + val sz = Module(new SZ(dividendWidth)) + sz.input.partialReminderSum := partialReminderSum + sz.input.partialReminderCarry := partialReminderCarry + + // the output of srt + output.remainder := remainder + output.quotient := quotient + + // if counter === 0.U && sz.output.sign, correct the quotient and remainder. valid = 1 + // TODO: correct + quotient := Mux(counter === 0.U && sz.output.sign, ???, quotient) + remainder := Mux(counter === 0.U && sz.output.sign, ???, remainder) + output.valid := Mux(counter === 0.U, true.B, false.B) // for SRT4 -> CSA32 // for SRT8 -> CSA32+CSA32 // for SRT16 -> CSA53+CSA32 // SRT16 <- SRT4 + SRT4*5 // { - val csa = Module(new CarrySaveAdder(CSACompressor3_2, ???)) - csa.in(0) := partialReminderSum - csa.in(1) := (partialReminderCarry ## !qdsSign) - csa.in(2) := Mux1H( - qds.output.selectedQuotientOH, - // TODO: this is for SRT4, for SRT8 or SRT16, this should be changed - VecInit((-2 to 2).map { - case -2 => divider << 1 - case -1 => divider - case 0 => 0.U - case 1 => ~divider - case 2 => (~divider) << 1 - }) - ) + val csa = Module(new CarrySaveAdder(CSACompressor3_2, ???)) + csa.in(0) := partialReminderSum + csa.in(1) := (partialReminderCarry ## !qdsSign) + csa.in(2) := Mux1H( + qds.output.selectedQuotientOH, + // TODO: this is for SRT4, for SRT8 or SRT16, this should be changed + VecInit((-2 to 2).map { + case -2 => divider << 1 + case -1 => divider + case 0 => 0.U + case 1 => ~divider + case 2 => (~divider) << 1 + }) + ) // } partialReminderSum := Mux1H( Map( - (counter === n/log2Ceil(radix)) -> input.bits.dividend, + (counter === n / log2Ceil(radix)) -> input.bits.dividend, (counter > 0.U) -> (csa.out(0) << log2Ceil(n)), (counter === 0.U) -> partialReminderSum ) ) partialReminderCarry := Mux1H( Map( - (counter === n/log2Ceil(radix)) -> 0.U, - (counter > 0.U) -> (csa.out(1) << log2Ceil(n)-1), + (counter === n / log2Ceil(radix)) -> 0.U, + (counter > 0.U) -> (csa.out(1) << log2Ceil(n) - 1), (counter === 0.U) -> partialReminderCarry ) ) diff --git a/arithmetic/src/division/srt/SRTTable.scala b/arithmetic/src/division/srt/SRTTable.scala index 6567e38..de9a4f0 100644 --- a/arithmetic/src/division/srt/SRTTable.scala +++ b/arithmetic/src/division/srt/SRTTable.scala @@ -53,15 +53,17 @@ case class SRTTable( .yLabel(s"${radix.toInt}ω[j]") .rightLegend() .standard() + lazy val aMax: Algebraic = a lazy val aMin: Algebraic = -a - lazy val deltaD: Algebraic = pow(2, -dTruncateWidth.toDouble) - lazy val deltaX: Algebraic = pow(2, -xTruncateWidth.toDouble) + lazy val deltaD: Algebraic = pow(2, -dTruncateWidth.toDouble) // length of dStep + lazy val deltaX: Algebraic = pow(2, -xTruncateWidth.toDouble) // length of romegeStep /** redundancy factor * @note 5.8 */ lazy val rho: Algebraic = a / (radix - 1) + // k d m xSet lazy val tables: Seq[(Int, Seq[(Algebraic, Seq[Algebraic])])] = { (aMin.toInt to aMax.toInt).drop(1).map { k => k -> dSet.dropRight(1).map { d => @@ -84,13 +86,23 @@ case class SRTTable( } } + // // from each m select a Constant, select rule: symmetry, how define the rule + // lazy val qdsTables: Seq[(Algebraic, Algebraic)] = { + // tables.map { + // case (i, ps) => + // ps.flatMap { case (d, xs) => xs.filter{ x: Algebraic => ??? }.map(x => ((d<= radix / 2) private val xSet = Seq.tabulate((xStep + 1).toInt) { n => xMin + deltaX * n } + private val dStep: Algebraic = (dMax - dMin) / deltaD assert((rho > 1 / 2) && (rho <= 1)) private val dSet = Seq.tabulate((dStep + 1).toInt) { n => dMin + deltaD * n } + private val mesh = ScatterPlot( xSet.flatMap { y => @@ -138,6 +150,7 @@ case class SRTTable( ) } + //select four points, then drop first and last points /** for range `dLeft` to `dRight`, return the `rOmegaCeil` and `rOmegaFloor` * this is used for constructing the rectangle where m_k(i) is located. */ @@ -149,6 +162,7 @@ case class SRTTable( .dropRight(1) match { case Seq(l, r) => (l, r) } } + // U_k = (k + rho) * d, L_k = (k - rho) * d /** find the intersection point between L`k` and `d` */ private def L(k: Algebraic, d: Algebraic): Algebraic = lRate(k) * d @@ -166,4 +180,4 @@ case class SRTTable( * @note 5.56 */ private def uRate(k: Algebraic): Algebraic = k + rho -} \ No newline at end of file +} diff --git a/arithmetic/src/division/srt/SZ.scala b/arithmetic/src/division/srt/SZ.scala index 3380cba..de38ab7 100644 --- a/arithmetic/src/division/srt/SZ.scala +++ b/arithmetic/src/division/srt/SZ.scala @@ -1,6 +1,7 @@ package division.srt import chisel3._ +import addition.prefixadder._ class SZInput extends Bundle { val partialReminderCarry: UInt = UInt(rWidth.W) @@ -13,22 +14,25 @@ class SZOutput extends Bundle { val zero: Bool = Bool() } -class SZ(rWidth: Int) extends Module{ - val input = IO(Input(new SZInput(rWidth))) - val output= IO(Output(new SZOutput())) +class SZ(rWidth: Int, prefixSum: PrefixSum = BrentKungSum) extends Module { + val input = IO(Input(new SZInput(rWidth))) + val output = IO(Output(new SZOutput())) - //controlpath + //controlpath - //datapath - val ws = input.partialReminderCarry.asBools - val wc = input.partialReminderSum.asBools + //datapath + // csa(ws,wc,-2^-b) => Seq[(Bool,Bool)] + val ws = input.partialReminderCarry.asBools + val wc = input.partialReminderSum.asBools + val psc: Seq[(Bool, Bool)] = ws.zip(wc).map { case (s, c) => (~(s ^ c), (s | c)) } - val psc: Seq[(Bool, Bool)]= ws.zip(wc).map{case(s,c) =>(~(s ^ c), (s | c))} - val ps: Seq[Bool] = psc.map(_._1) +: false.B - val pc: Seq[Bool] = false.B +: psc.map(_._2) - val p: Seq[Bool] = ps.zip(pc){case(s, c) => s ^ c} + // call the prefixtree to associativeOp + val pairs: Seq[(Bool, Bool)] = prefixSum.zeroLayer(psc.map(_._1) +: false.B, false.B +: psc.map(_._2)) + val pgs: Vector[(Bool, Bool)] = prefixSum(pairs) + val gs: Vector[Bool] = pgs.map(_._2) + val ps: Vector[Bool] = pgs.map(_._1) - output.zero := p.andR - output.sign := (p.asUInt.head(1) ^ ???) & (~output.zero) - -} \ No newline at end of file + // maybe have a problem. + output.zero := ps.asUInt.head(1) + output.sign := (ps.asUInt.head(1) ^ gs.asUInt.head(1)) & (~output.zero) +} From 70c19e7c1488ef81dc2dcf745e96057fe07123cb Mon Sep 17 00:00:00 2001 From: GH Cheng <1536771081@qq.com> Date: Wed, 13 Apr 2022 20:35:56 +0800 Subject: [PATCH 13/31] using table from XS --- arithmetic/src/division/srt/OTF.scala | 21 ++-- arithmetic/src/division/srt/QDS.scala | 54 +++++----- arithmetic/src/division/srt/SRT.scala | 100 +++++++++--------- arithmetic/src/division/srt/SRTTable.scala | 9 +- arithmetic/src/division/srt/SZ.scala | 9 +- .../src/division/srt/SRT4SpecTester.scala | 17 +++ 6 files changed, 114 insertions(+), 96 deletions(-) create mode 100644 arithmetic/tests/src/division/srt/SRT4SpecTester.scala diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala index 2c8881d..f3a8d62 100644 --- a/arithmetic/src/division/srt/OTF.scala +++ b/arithmetic/src/division/srt/OTF.scala @@ -1,6 +1,7 @@ package division.srt import chisel3._ +import chisel3.util.{Mux1H} class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle { val quotient = UInt(qWidth.W) @@ -19,22 +20,22 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { // control // datapath - // q_j+1 in this circle + // q_j+1 in this circle, only for srt4 val qNext: UInt = Mux1H(Seq( - input.selectedQuotient(0) -> "b110", - input.selectedQuotient(1) -> "b101", - input.selectedQuotient(2) -> "b000", - input.selectedQuotient(3) -> "b001", - input.selectedQuotient(4) -> "b010" + input.selectedQuotientOH(0) -> "b110".U, + input.selectedQuotientOH(1) -> "b101".U, + input.selectedQuotientOH(2) -> "b000".U, + input.selectedQuotientOH(3) -> "b001".U, + input.selectedQuotientOH(4) -> "b010".U )) // val cShiftQ: Bool = qNext >= 0.U // val cShiftQM: Bool = qNext <= 0.U - val cShiftQ: Bool = input.selectedQuotient(ohWidth/2, 0).orR - val cShiftQM: Bool = input.selectedQuotient(ohWidth-1, ohWidth/2).orR + val cShiftQ: Bool = input.selectedQuotientOH(ohWidth/2, 0).orR + val cShiftQM: Bool = input.selectedQuotientOH(ohWidth-1, ohWidth/2).orR - val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext) - val qmIn: UInt = Mux(~cShiftQM, qNext -1.U, (radix-1).U + qNext) + val qIn: UInt = (Mux(cShiftQ, qNext, radix.U + qNext))(1, 0) + val qmIn: UInt = (Mux(~cShiftQM, qNext - 1.U, (radix-1).U + qNext))(1, 0) output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne) ## qIn output.quotientMinusOne := Mux(cShiftQM, input.quotientMinusOne, input.quotient) ## qmIn diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index cb33a5e..7003a90 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -1,27 +1,24 @@ package division.srt import chisel3._ -import chisel3.util.{log2Ceil, RegEnable, Valid} -import chisle3.util.experimental.decode +import chisel3.util.{log2Ceil, RegEnable, Valid, BitPat} +import chisel3.util.experimental.decode._ -class QDSInput extends Bundle { +class QDSInput(rWidth: Int) extends Bundle { val partialReminderCarry: UInt = UInt(rWidth.W) val partialReminderSum: UInt = UInt(rWidth.W) } -class QDSOutput extends Bundle { - // val selectedQuotient: UInt = UInt((log2Ceil(n)+1).W) +class QDSOutput(ohWidth: Int) extends Bundle { val selectedQuotientOH: UInt = UInt(ohWidth.W) } -/** - */ -class QDS(table: Seq[(Int, Int)], rWidth: Int, ohWidth: Int) extends Module { +class QDS(rWidth: Int, ohWidth: Int) extends Module { // IO val input = IO(Input(new QDSInput(rWidth))) val output = IO(Output(new QDSOutput(ohWidth))) - // used to select a column of SRT Table - val partialDivider = IO(Flipped(Valid(UInt(???)))) //这个从哪里来 + // used to select a column of SRT Table, 这个需要手动连接吗 + val partialDivider = IO(Flipped(Valid(UInt(3.W)))) //这个从哪里来 // State, in order to keep divider's value val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) @@ -33,11 +30,7 @@ class QDS(table: Seq[(Int, Int)], rWidth: Int, ohWidth: Int) extends Module { val partialDividerLatch = Mux(partialDivider.valid, partialDivider.bits, partialDividerReg) // Datapath - val columnSelect = partialDividerLatch - val rowSelect = input.partialReminderCarry + input.partialReminderSum - - // // use the table from XiangShan - // // from XiangShan: /16 + // from XiangShan/P269 in : /16, should have got from SRTTable. // val qSelTable = Array( // Array(12, 4, -4, -13), // Array(14, 4, -6, -15), @@ -48,20 +41,25 @@ class QDS(table: Seq[(Int, Int)], rWidth: Int, ohWidth: Int) extends Module { // Array(20, 8, -8, -22), // Array(24, 8, -8, -24) // ) + val columnSelect = partialDividerLatch + val selectRom: Vec[Vec[UInt]] = VecInit( + VecInit("b111_0100".U, "b111_1100".U, "b000_0100".U, "b000_1101".U), + VecInit("b111_0010".U, "b111_1100".U, "b000_0110".U, "b000_1111".U), + VecInit("b111_0001".U, "b111_1100".U, "b000_0110".U, "b001_0000".U), + VecInit("b111_0000".U, "b111_1100".U, "b000_0110".U, "b001_0010".U), + VecInit("b110_1101".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), + VecInit("b110_1100".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), + VecInit("b110_1100".U, "b111_1000".U, "b000_1000".U, "b001_0110".U), + VecInit("b110_1000".U, "b111_1000".U, "b000_1000".U, "b001_1000".U) + ) + val mkVec = selectRom(columnSelect) + val selectPoints = VecInit(mkVec.map{ mk => + // maybe have a problem. + (input.partialReminderCarry + input.partialReminderSum + mk).head(1) + }).asUInt() - // TODO: complete select_table algorithm - val selectRom: Vec[(UInt, SInt)] = table.map{case (d, x) => (d.U, x.S)} - val mkVec = selectRom.filter{ case (d, x) => d === columnSelect }.map{ case (d, x) => x } - - val selectPoints = mkVec.map { mk => - // get the select point - // TODO: find the sign - (input.partialReminderCarry + input.partialReminderSum - mk).head(1) - }.flatten.asUInt - - // decoder or findFirstOne here, prefer decoder - // the decoder only for srt4 - io.output := chisel3.util.experimental.decode.decoder( + // decoder or findFirstOne here, prefer decoder, the decoder only for srt4 + output.selectedQuotientOH := chisel3.util.experimental.decode.decoder( selectPoints, TruthTable( Seq( diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 04ea8f0..035fb62 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -2,15 +2,24 @@ package division.srt import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 +import utils.{extend} import chisel3._ import chisel3.util.{log2Ceil, DecoupledIO, Mux1H, ValidIO} +import scala.math.{ceil} +// 带csa的srt除法器 1/2<= d < 1, 1/2 < rho <=1, 0 < q < 2 +// 0, radix = 4 +// 1,商数范围 :a = 2, {-2, -1, 0, 1, -2}, +// 2, 冗余因子rho = 2/(4-1) =2/3 +// 3,估值(截断位宽):3位整数,4位小数 t = 4 +// 4,选商函数 通过输入的y^(xxx.xxxx)和截断的d(0.1xxx)来进行选商,计算出选商查找表,来进行查找选商[d_i,d_i+1) +// 5, -44/16 < y^ < 42/16 // TODO: width -class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int, radix: Int) extends Bundle { +class SRTInput(dividendWidth: Int, dividerWidth: Int) extends Bundle { val dividend = UInt(dividendWidth.W) //0.1********** - val divider = UInt(dividerWidth.W) //0.1********** - val counter = UInt((log2Ceil(n / log2Ceil(radix))).W) // n为需要计算的二进制位数 + val divider = UInt(dividerWidth.W) //0.1********** + val counter = UInt(log2Ceil(cycle).W) } class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { @@ -22,67 +31,57 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { class SRT( dividendWidth: Int, dividerWidth: Int, - n: Int, - radix: Int = 4, + n: Int, // the number of quotient + radixLog2: Int = 2, a: Int = 2, - dTruncateWidth: Int = 4, - xTruncateWidth: Int = 3 + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4 ) extends Module { + // the numbers of cycle + val cycle: UInt = ceil(n.toDouble/radixLog2).toInt + 1 + val ohWidth: Int = 2 * a + 1 + // IO - val input = Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))) + val input = Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth))) val output = ValidIO(new SRTOutput(dividerWidth, dividendWidth)) // State // because we need a CSA to minimize the critical path - val partialReminderCarry = Reg(UInt((dividendWidth + log2Ceil(radix)).W)) - val partialReminderSum = Reg(UInt((dividendWidth + log2Ceil(radix)).W)) - - // dMultiplier - val divider = RegInit(input.divider) - - val quotient = Reg(UInt(n.W)) //? - val quotientMinusOne = Reg(UInt(n.W)) //? - - // counter = 0 quotientToFix = 0 -> - val counter = Reg(UInt((log2Ceil(n / log2Ceil(radix))).W)) - + val partialReminderCarry = Reg(UInt((dividendWidth + radixLog2).W)) + val partialReminderSum = Reg(UInt((dividendWidth + radixLog2).W)) + val divider = RegInit(input.bits.divider) + val quotient = Reg(UInt(n.W)) + val quotientMinusOne = Reg(UInt(n.W)) + val counter = RegInit(input.bits.counter) // Control // sign of select quotient, true -> negative, false -> positive val qdsSign: Bool = Wire(Bool()) + qdsSign := qds.output.selectedQuotientOH(ohWidth-1, ohWidth/2).orR // Datapath - // from software get quotient select Constant tables,how to convert double to UInt and store in ROM - val table: Seq[(Int, Int)] = SRTTable(radix, a, dTruncateWidth, xTruncateWidth).qdsTables - val rTruncateWidth: Int = ??? - val selectedQuotientOHWidth: Int = ??? - val qds = Module(new QDS(table, rTruncateWidth, selectedQuotientOHWidth)) - // TODO: bit select here - qds.input.partialReminderSum := partialReminderSum.head(rTruncateWidth) - qds.input.partialReminderCarry := partialReminderCarry.head(rTruncateWidth) - + val qds = Module(new QDS(rTruncateWidth, ohWidth)) + qds.input.partialReminderSum := partialReminderSum.head(1 + radixLog2 + rTruncateWidth) + qds.input.partialReminderCarry := partialReminderCarry.head(1 + radixLog2 + rTruncateWidth) + qds.partialDivider.bits := (input.bits.divider.head(1 + radixLog2 + rTruncateWidth))(dTruncateWidth-2, 0) + counter := counter - 1.U - + // if counter === 0.U && sz.output.sign, correct the quotient and reminder. valid = 1 + // the output of srt val sz = Module(new SZ(dividendWidth)) sz.input.partialReminderSum := partialReminderSum sz.input.partialReminderCarry := partialReminderCarry - - // the output of srt - output.remainder := remainder - output.quotient := quotient - - // if counter === 0.U && sz.output.sign, correct the quotient and remainder. valid = 1 - // TODO: correct - quotient := Mux(counter === 0.U && sz.output.sign, ???, quotient) - remainder := Mux(counter === 0.U && sz.output.sign, ???, remainder) - output.valid := Mux(counter === 0.U, true.B, false.B) + output.valid := Mux(counter === 0.U, true.B, false.B) + // correcting maybe have problem + quotient := Mux(counter === 0.U && sz.output.sign, quotient - 1.U, quotient) + output.bits.reminder := partialReminderCarry + partialReminderSum + Mux(counter === 0.U && sz.output.sign, divider, 0.U) + output.bits.quotient := quotient // for SRT4 -> CSA32 // for SRT8 -> CSA32+CSA32 // for SRT16 -> CSA53+CSA32 // SRT16 <- SRT4 + SRT4*5 - // { - val csa = Module(new CarrySaveAdder(CSACompressor3_2, ???)) + val csa = Module(new CarrySaveAdder(CSACompressor3_2, dividendWidth + radixLog2)) csa.in(0) := partialReminderSum csa.in(1) := (partialReminderCarry ## !qdsSign) csa.in(2) := Mux1H( @@ -92,28 +91,29 @@ class SRT( case -2 => divider << 1 case -1 => divider case 0 => 0.U - case 1 => ~divider - case 2 => (~divider) << 1 + case 1 => extend(~divider, dividendWidth + radixLog2) + case 2 => extend((~divider) << 1, dividendWidth + radixLog2) }) ) - // } + partialReminderSum := Mux1H( Map( - (counter === n / log2Ceil(radix)) -> input.bits.dividend, - (counter > 0.U) -> (csa.out(0) << log2Ceil(n)), + (counter === cycle.U) -> input.bits.dividend, + (counter > 0.U) -> (csa.out(0) << radixLog2), (counter === 0.U) -> partialReminderSum ) ) + partialReminderCarry := Mux1H( Map( - (counter === n / log2Ceil(radix)) -> 0.U, - (counter > 0.U) -> (csa.out(1) << log2Ceil(n) - 1), + (counter === cycle.U) -> 0.U, + (counter > 0.U) -> (csa.out(1) << (radixLog2 - 1)), (counter === 0.U) -> partialReminderCarry ) ) // On-The-Fly conversion - val otf = Module(new OTF(radix, quotient.getWidth, qds.output.selectedQuotientOH.getWidth)) + val otf = Module(new OTF(1< @@ -150,7 +151,7 @@ case class SRTTable( ) } - //select four points, then drop first and last points + // select four points, then drop the first and the last one. /** for range `dLeft` to `dRight`, return the `rOmegaCeil` and `rOmegaFloor` * this is used for constructing the rectangle where m_k(i) is located. */ diff --git a/arithmetic/src/division/srt/SZ.scala b/arithmetic/src/division/srt/SZ.scala index de38ab7..adee3b0 100644 --- a/arithmetic/src/division/srt/SZ.scala +++ b/arithmetic/src/division/srt/SZ.scala @@ -2,8 +2,9 @@ package division.srt import chisel3._ import addition.prefixadder._ +import addition.prefixadder.common.{BrentKungSum} -class SZInput extends Bundle { +class SZInput(rWidth: Int) extends Bundle { val partialReminderCarry: UInt = UInt(rWidth.W) val partialReminderSum: UInt = UInt(rWidth.W) } @@ -27,12 +28,12 @@ class SZ(rWidth: Int, prefixSum: PrefixSum = BrentKungSum) extends Module { val psc: Seq[(Bool, Bool)] = ws.zip(wc).map { case (s, c) => (~(s ^ c), (s | c)) } // call the prefixtree to associativeOp - val pairs: Seq[(Bool, Bool)] = prefixSum.zeroLayer(psc.map(_._1) +: false.B, false.B +: psc.map(_._2)) + val pairs: Seq[(Bool, Bool)] = prefixSum.zeroLayer(psc.map(_._1) :+ false.B, false.B +: psc.map(_._2)) val pgs: Vector[(Bool, Bool)] = prefixSum(pairs) val gs: Vector[Bool] = pgs.map(_._2) val ps: Vector[Bool] = pgs.map(_._1) // maybe have a problem. - output.zero := ps.asUInt.head(1) - output.sign := (ps.asUInt.head(1) ^ gs.asUInt.head(1)) & (~output.zero) + output.zero := VecInit(ps).asUInt.head(1) + output.sign := (output.zero ^ VecInit(gs).asUInt.head(1)) & (~output.zero) } diff --git a/arithmetic/tests/src/division/srt/SRT4SpecTester.scala b/arithmetic/tests/src/division/srt/SRT4SpecTester.scala new file mode 100644 index 0000000..5fe86ec --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRT4SpecTester.scala @@ -0,0 +1,17 @@ +package division.srt + +import chisel3._ +import chisel3.tester.{ChiselUtestTester} + +object SRT4SpecTester extends TestSuite with ChiselUtestTester{ + def tests: Tests = Tests{ + test("SRT4 should pass"){ + val u = ??? + + testCircuit(new SRT(32, 32, 32), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){ dut: SRT => + + } + } + } + +} \ No newline at end of file From a3930e2e7bcd5f06761a7f409085175d9ece9b52 Mon Sep 17 00:00:00 2001 From: GH Cheng <1536771081@qq.com> Date: Thu, 14 Apr 2022 22:13:05 +0800 Subject: [PATCH 14/31] SZ fix --- arithmetic/src/division/srt/OTF.scala | 24 ++++--- arithmetic/src/division/srt/QDS.scala | 18 +++-- arithmetic/src/division/srt/SRT.scala | 68 ++++++++++--------- arithmetic/src/division/srt/SRTTable.scala | 6 +- arithmetic/src/division/srt/SZ.scala | 20 ++++-- .../src/multiplier/WallaceMultiplier.scala | 1 - 6 files changed, 74 insertions(+), 63 deletions(-) diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala index f3a8d62..5236ee7 100644 --- a/arithmetic/src/division/srt/OTF.scala +++ b/arithmetic/src/division/srt/OTF.scala @@ -21,21 +21,23 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { // datapath // q_j+1 in this circle, only for srt4 - val qNext: UInt = Mux1H(Seq( - input.selectedQuotientOH(0) -> "b110".U, - input.selectedQuotientOH(1) -> "b101".U, - input.selectedQuotientOH(2) -> "b000".U, - input.selectedQuotientOH(3) -> "b001".U, - input.selectedQuotientOH(4) -> "b010".U - )) + val qNext: UInt = Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b110".U, + input.selectedQuotientOH(1) -> "b101".U, + input.selectedQuotientOH(2) -> "b000".U, + input.selectedQuotientOH(3) -> "b001".U, + input.selectedQuotientOH(4) -> "b010".U + ) + ) // val cShiftQ: Bool = qNext >= 0.U // val cShiftQM: Bool = qNext <= 0.U - val cShiftQ: Bool = input.selectedQuotientOH(ohWidth/2, 0).orR - val cShiftQM: Bool = input.selectedQuotientOH(ohWidth-1, ohWidth/2).orR + val cShiftQ: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR + val cShiftQM: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR - val qIn: UInt = (Mux(cShiftQ, qNext, radix.U + qNext))(1, 0) - val qmIn: UInt = (Mux(~cShiftQM, qNext - 1.U, (radix-1).U + qNext))(1, 0) + val qIn: UInt = (Mux(cShiftQ, qNext, radix.U + qNext))(1, 0) + val qmIn: UInt = (Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext))(1, 0) output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne) ## qIn output.quotientMinusOne := Mux(cShiftQM, input.quotientMinusOne, input.quotient) ## qmIn diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index 7003a90..9454c94 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -1,6 +1,6 @@ package division.srt import chisel3._ -import chisel3.util.{log2Ceil, RegEnable, Valid, BitPat} +import chisel3.util.{log2Ceil, BitPat, RegEnable, Valid} import chisel3.util.experimental.decode._ class QDSInput(rWidth: Int) extends Bundle { @@ -16,9 +16,7 @@ class QDS(rWidth: Int, ohWidth: Int) extends Module { // IO val input = IO(Input(new QDSInput(rWidth))) val output = IO(Output(new QDSOutput(ohWidth))) - - // used to select a column of SRT Table, 这个需要手动连接吗 - val partialDivider = IO(Flipped(Valid(UInt(3.W)))) //这个从哪里来 + val partialDivider = IO(Flipped(Valid(UInt(3.W)))) // State, in order to keep divider's value val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) @@ -52,11 +50,11 @@ class QDS(rWidth: Int, ohWidth: Int) extends Module { VecInit("b110_1100".U, "b111_1000".U, "b000_1000".U, "b001_0110".U), VecInit("b110_1000".U, "b111_1000".U, "b000_1000".U, "b001_1000".U) ) - val mkVec = selectRom(columnSelect) - val selectPoints = VecInit(mkVec.map{ mk => - // maybe have a problem. - (input.partialReminderCarry + input.partialReminderSum + mk).head(1) - }).asUInt() + val mkVec = selectRom(columnSelect) + val selectPoints = VecInit(mkVec.map { mk => + // maybe have a problem."+&" extend signed to avoid overflow. only for srt4, because -44/16 < y^ < 42/16. + (input.partialReminderCarry +& input.partialReminderSum + mk).head(1) + }).asUInt // decoder or findFirstOne here, prefer decoder, the decoder only for srt4 output.selectedQuotientOH := chisel3.util.experimental.decode.decoder( @@ -66,7 +64,7 @@ class QDS(rWidth: Int, ohWidth: Int) extends Module { BitPat("b1???") -> BitPat("b10000"), BitPat("b01??") -> BitPat("b01000"), BitPat("b001?") -> BitPat("b00100"), - BitPat("b0001") -> BitPat("b00010"), + BitPat("b0001") -> BitPat("b00010") ), BitPat("b00001") ) diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 035fb62..b60a55a 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -2,10 +2,11 @@ package division.srt import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 -import utils.{extend} +import utils.extend import chisel3._ -import chisel3.util.{log2Ceil, DecoupledIO, Mux1H, ValidIO} -import scala.math.{ceil} +import chisel3.util.{log2Ceil, Counter, DecoupledIO, Mux1H, ValidIO} + +import scala.math.ceil // 带csa的srt除法器 1/2<= d < 1, 1/2 < rho <=1, 0 < q < 2 // 0, radix = 4 @@ -16,10 +17,10 @@ import scala.math.{ceil} // 5, -44/16 < y^ < 42/16 // TODO: width -class SRTInput(dividendWidth: Int, dividerWidth: Int) extends Bundle { +class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { val dividend = UInt(dividendWidth.W) //0.1********** - val divider = UInt(dividerWidth.W) //0.1********** - val counter = UInt(log2Ceil(cycle).W) + val divider = UInt(dividerWidth.W) //0.1********** + val counter = UInt(log2Ceil(n).W) //the width of quotient. } class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { @@ -31,50 +32,55 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { class SRT( dividendWidth: Int, dividerWidth: Int, - n: Int, // the number of quotient + n: Int, // the longest width radixLog2: Int = 2, - a: Int = 2, + a: Int = 2, dTruncateWidth: Int = 4, - rTruncateWidth: Int = 4 - ) + rTruncateWidth: Int = 4) extends Module { // the numbers of cycle - val cycle: UInt = ceil(n.toDouble/radixLog2).toInt + 1 - val ohWidth: Int = 2 * a + 1 + val ohWidth: Int = 2 * a + 1 // IO - val input = Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth))) + val input = Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))) val output = ValidIO(new SRTOutput(dividerWidth, dividendWidth)) // State // because we need a CSA to minimize the critical path - val partialReminderCarry = Reg(UInt((dividendWidth + radixLog2).W)) - val partialReminderSum = Reg(UInt((dividendWidth + radixLog2).W)) + val partialReminderCarry = Reg(UInt((dividendWidth + radixLog2).W)) + val partialReminderSum = Reg(UInt((dividendWidth + radixLog2).W)) val divider = RegInit(input.bits.divider) - val quotient = Reg(UInt(n.W)) + val quotient = Reg(UInt(n.W)) val quotientMinusOne = Reg(UInt(n.W)) val counter = RegInit(input.bits.counter) // Control // sign of select quotient, true -> negative, false -> positive val qdsSign: Bool = Wire(Bool()) - qdsSign := qds.output.selectedQuotientOH(ohWidth-1, ohWidth/2).orR + qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR // Datapath val qds = Module(new QDS(rTruncateWidth, ohWidth)) - qds.input.partialReminderSum := partialReminderSum.head(1 + radixLog2 + rTruncateWidth) + qds.input.partialReminderSum := partialReminderSum.head(1 + radixLog2 + rTruncateWidth) qds.input.partialReminderCarry := partialReminderCarry.head(1 + radixLog2 + rTruncateWidth) - qds.partialDivider.bits := (input.bits.divider.head(1 + radixLog2 + rTruncateWidth))(dTruncateWidth-2, 0) - - counter := counter - 1.U + qds.partialDivider.bits := input.bits.divider.head(1 + radixLog2 + rTruncateWidth)(dTruncateWidth - 2, 0) + + counter := counter - radixLog2.U // if counter === 0.U && sz.output.sign, correct the quotient and reminder. valid = 1 // the output of srt - val sz = Module(new SZ(dividendWidth)) - sz.input.partialReminderSum := partialReminderSum - sz.input.partialReminderCarry := partialReminderCarry - output.valid := Mux(counter === 0.U, true.B, false.B) + val sz = Module(new SZ(dividendWidth - 2)) + sz.input.partialReminderSum := partialReminderSum(partialReminderSum.getWidth-3, 0) + sz.input.partialReminderCarry := partialReminderCarry(partialReminderSum.getWidth-3, 0) + output.valid := Mux(counter === 0.U, true.B, false.B) + // correcting maybe have problem - quotient := Mux(counter === 0.U && sz.output.sign, quotient - 1.U, quotient) - output.bits.reminder := partialReminderCarry + partialReminderSum + Mux(counter === 0.U && sz.output.sign, divider, 0.U) + quotient := Mux(counter === 0.U && sz.output.sign, quotient - 1.U, quotient) + output.bits.reminder := Mux1H( + Map( + (counter === 0.U && sz.output.zero) -> 0.U, + (counter === 0.U && sz.output.sign) -> (sz.output.remainder + 1.U + divider), + (counter === 0.U && !sz.output.sign) -> (sz.output.remainder + 1.U) + ) + ) output.bits.quotient := quotient // for SRT4 -> CSA32 @@ -98,22 +104,22 @@ class SRT( partialReminderSum := Mux1H( Map( - (counter === cycle.U) -> input.bits.dividend, + (counter === input.bits.counter) -> input.bits.dividend, (counter > 0.U) -> (csa.out(0) << radixLog2), (counter === 0.U) -> partialReminderSum ) ) - + partialReminderCarry := Mux1H( Map( - (counter === cycle.U) -> 0.U, + (counter === input.bits.counter) -> 0.U, (counter > 0.U) -> (csa.out(1) << (radixLog2 - 1)), (counter === 0.U) -> partialReminderCarry ) ) // On-The-Fly conversion - val otf = Module(new OTF(1< Seq[(Bool,Bool)] + // drop signed bits val ws = input.partialReminderCarry.asBools val wc = input.partialReminderSum.asBools - val psc: Seq[(Bool, Bool)] = ws.zip(wc).map { case (s, c) => (~(s ^ c), (s | c)) } + val psc: Seq[(Bool, Bool)] = ws.zip(wc).map { case (s, c) => (!(s ^ c), (s | c)) } - // call the prefixtree to associativeOp + // call the prefixtree to associativeOp and compute last remainder val pairs: Seq[(Bool, Bool)] = prefixSum.zeroLayer(psc.map(_._1) :+ false.B, false.B +: psc.map(_._2)) val pgs: Vector[(Bool, Bool)] = prefixSum(pairs) - val gs: Vector[Bool] = pgs.map(_._2) val ps: Vector[Bool] = pgs.map(_._1) + val gs: Vector[Bool] = pgs.map(_._2) + + val a: Vector[Bool] = false.B +: gs + val b: Seq[Bool] = pairs.map(_._1) :+ false.B + val sum: Seq[Bool] = a.zip(b).map { case (p, c) => p ^ c } // maybe have a problem. output.zero := VecInit(ps).asUInt.head(1) - output.sign := (output.zero ^ VecInit(gs).asUInt.head(1)) & (~output.zero) + output.sign := (pairs(pairs.length - 1)._1 ^ gs(gs.length - 2)) & (!output.zero) + output.remainder := VecInit(sum).asUInt } diff --git a/arithmetic/src/multiplier/WallaceMultiplier.scala b/arithmetic/src/multiplier/WallaceMultiplier.scala index 84a47c6..4bc6881 100644 --- a/arithmetic/src/multiplier/WallaceMultiplier.scala +++ b/arithmetic/src/multiplier/WallaceMultiplier.scala @@ -72,7 +72,6 @@ class WallaceMultiplierImpl( addAll(toNextLayer, depth + 1) } } - // produce Seq(b, 2 * b, ..., 2^digits * b), output width = width + radixLog2 - 1 val bMultipleWidth = (width + radixLog2 - 1).W def prepareBMultiples(digits: Int): Seq[SInt] = { From 0cd81d8803be7c6817773a84f4ec598ebed54b3e Mon Sep 17 00:00:00 2001 From: GH Cheng <1536771081@qq.com> Date: Mon, 18 Apr 2022 14:43:36 +0800 Subject: [PATCH 15/31] rm srt4test --- arithmetic/src/division/srt/SRT.scala | 16 ++++++++-------- .../src/multiplier/WallaceMultiplier.scala | 1 + .../tests/src/division/srt/SRT4SpecTester.scala | 17 ----------------- 3 files changed, 9 insertions(+), 25 deletions(-) delete mode 100644 arithmetic/tests/src/division/srt/SRT4SpecTester.scala diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index b60a55a..ce6c603 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -8,13 +8,14 @@ import chisel3.util.{log2Ceil, Counter, DecoupledIO, Mux1H, ValidIO} import scala.math.ceil -// 带csa的srt除法器 1/2<= d < 1, 1/2 < rho <=1, 0 < q < 2 -// 0, radix = 4 -// 1,商数范围 :a = 2, {-2, -1, 0, 1, -2}, -// 2, 冗余因子rho = 2/(4-1) =2/3 -// 3,估值(截断位宽):3位整数,4位小数 t = 4 -// 4,选商函数 通过输入的y^(xxx.xxxx)和截断的d(0.1xxx)来进行选商,计算出选商查找表,来进行查找选商[d_i,d_i+1) -// 5, -44/16 < y^ < 42/16 +/** SRT4 + * 1/2<= d < 1, 1/2 < rho <=1, 0 < q < 2 + * 0, radix = 4 + * a = 2, {-2, -1, 0, 1, -2}, + * t = 4 + * y^(xxx.xxxx), d^(0.1xxx) + * -44/16 < y^ < 42/16 + */ // TODO: width class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { @@ -38,7 +39,6 @@ class SRT( dTruncateWidth: Int = 4, rTruncateWidth: Int = 4) extends Module { - // the numbers of cycle val ohWidth: Int = 2 * a + 1 // IO diff --git a/arithmetic/src/multiplier/WallaceMultiplier.scala b/arithmetic/src/multiplier/WallaceMultiplier.scala index 4bc6881..84a47c6 100644 --- a/arithmetic/src/multiplier/WallaceMultiplier.scala +++ b/arithmetic/src/multiplier/WallaceMultiplier.scala @@ -72,6 +72,7 @@ class WallaceMultiplierImpl( addAll(toNextLayer, depth + 1) } } + // produce Seq(b, 2 * b, ..., 2^digits * b), output width = width + radixLog2 - 1 val bMultipleWidth = (width + radixLog2 - 1).W def prepareBMultiples(digits: Int): Seq[SInt] = { diff --git a/arithmetic/tests/src/division/srt/SRT4SpecTester.scala b/arithmetic/tests/src/division/srt/SRT4SpecTester.scala deleted file mode 100644 index 5fe86ec..0000000 --- a/arithmetic/tests/src/division/srt/SRT4SpecTester.scala +++ /dev/null @@ -1,17 +0,0 @@ -package division.srt - -import chisel3._ -import chisel3.tester.{ChiselUtestTester} - -object SRT4SpecTester extends TestSuite with ChiselUtestTester{ - def tests: Tests = Tests{ - test("SRT4 should pass"){ - val u = ??? - - testCircuit(new SRT(32, 32, 32), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){ dut: SRT => - - } - } - } - -} \ No newline at end of file From c9c87caa87ee7e1537b1f2191a189485fc4dc464 Mon Sep 17 00:00:00 2001 From: GH Cheng <1536771081@qq.com> Date: Mon, 25 Apr 2022 19:47:48 +0800 Subject: [PATCH 16/31] add SRT4Test & fix test --- arithmetic/src/division/srt/OTF.scala | 10 +- arithmetic/src/division/srt/QDS.scala | 17 ++-- arithmetic/src/division/srt/SRT.scala | 93 ++++++++++--------- arithmetic/src/division/srt/SZ.scala | 12 +-- .../tests/src/division/srt/SRT4Test.scala | 35 +++++++ 5 files changed, 107 insertions(+), 60 deletions(-) create mode 100644 arithmetic/tests/src/division/srt/SRT4Test.scala diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala index 5236ee7..bedb574 100644 --- a/arithmetic/src/division/srt/OTF.scala +++ b/arithmetic/src/division/srt/OTF.scala @@ -24,7 +24,7 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { val qNext: UInt = Mux1H( Seq( input.selectedQuotientOH(0) -> "b110".U, - input.selectedQuotientOH(1) -> "b101".U, + input.selectedQuotientOH(1) -> "b111".U, input.selectedQuotientOH(2) -> "b000".U, input.selectedQuotientOH(3) -> "b001".U, input.selectedQuotientOH(4) -> "b010".U @@ -36,9 +36,9 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { val cShiftQ: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR val cShiftQM: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR - val qIn: UInt = (Mux(cShiftQ, qNext, radix.U + qNext))(1, 0) - val qmIn: UInt = (Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext))(1, 0) + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(1, 0) + val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(1, 0) - output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne) ## qIn - output.quotientMinusOne := Mux(cShiftQM, input.quotientMinusOne, input.quotient) ## qmIn + output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - 2, 0) ## qIn + output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - 2, 0) ## qmIn } diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index 9454c94..321d60f 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -1,7 +1,8 @@ package division.srt import chisel3._ -import chisel3.util.{log2Ceil, BitPat, RegEnable, Valid} +import chisel3.util.{BitPat, RegEnable, Valid} import chisel3.util.experimental.decode._ +import utils.extend class QDSInput(rWidth: Int) extends Bundle { val partialReminderCarry: UInt = UInt(rWidth.W) @@ -12,11 +13,11 @@ class QDSOutput(ohWidth: Int) extends Bundle { val selectedQuotientOH: UInt = UInt(ohWidth.W) } -class QDS(rWidth: Int, ohWidth: Int) extends Module { +class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { // IO val input = IO(Input(new QDSInput(rWidth))) val output = IO(Output(new QDSOutput(ohWidth))) - val partialDivider = IO(Flipped(Valid(UInt(3.W)))) + val partialDivider = IO(Flipped(Valid(UInt(partialDividerWidth.W)))) // State, in order to keep divider's value val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) @@ -37,7 +38,7 @@ class QDS(rWidth: Int, ohWidth: Int) extends Module { // Array(18, 6, -8, -20), // Array(20, 6, -8, -20), // Array(20, 8, -8, -22), - // Array(24, 8, -8, -24) + // Array(24, 8, -8, -24)/16 // ) val columnSelect = partialDividerLatch val selectRom: Vec[Vec[UInt]] = VecInit( @@ -50,10 +51,14 @@ class QDS(rWidth: Int, ohWidth: Int) extends Module { VecInit("b110_1100".U, "b111_1000".U, "b000_1000".U, "b001_0110".U), VecInit("b110_1000".U, "b111_1000".U, "b000_1000".U, "b001_1000".U) ) + val mkVec = selectRom(columnSelect) + val adderWidth = rWidth + 1 val selectPoints = VecInit(mkVec.map { mk => - // maybe have a problem."+&" extend signed to avoid overflow. only for srt4, because -44/16 < y^ < 42/16. - (input.partialReminderCarry +& input.partialReminderSum + mk).head(1) + // extend signed to avoid overflow. only for srt4, because -44/16 < y^ < 42/16. + (extend(input.partialReminderCarry, adderWidth).asUInt + + extend(input.partialReminderSum, adderWidth).asUInt + + extend(mk, adderWidth).asUInt).head(1) }).asUInt // decoder or findFirstOne here, prefer decoder, the decoder only for srt4 diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index ce6c603..81da38f 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -9,15 +9,15 @@ import chisel3.util.{log2Ceil, Counter, DecoupledIO, Mux1H, ValidIO} import scala.math.ceil /** SRT4 - * 1/2<= d < 1, 1/2 < rho <=1, 0 < q < 2 - * 0, radix = 4 - * a = 2, {-2, -1, 0, 1, -2}, - * t = 4 - * y^(xxx.xxxx), d^(0.1xxx) - * -44/16 < y^ < 42/16 - */ - -// TODO: width + * 1/2<= d < 1, 1/2 < rho <=1, 0 < q < 2 + * 0, radix = 4 + * a = 2, {-2, -1, 0, 1, -2}, + * t = 4 + * y^(xxx.xxxx), d^(0.1xxx) + * -44/16 < y^ < 42/16 + */ + +// TODO: counter & n class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { val dividend = UInt(dividendWidth.W) //0.1********** val divider = UInt(dividerWidth.W) //0.1********** @@ -33,12 +33,14 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { class SRT( dividendWidth: Int, dividerWidth: Int, - n: Int, // the longest width + n: Int, // the longest width, radixLog2: Int = 2, a: Int = 2, dTruncateWidth: Int = 4, rTruncateWidth: Int = 4) extends Module { + + val xLen: Int = dividendWidth + radixLog2 val ohWidth: Int = 2 * a + 1 // IO @@ -47,8 +49,8 @@ class SRT( // State // because we need a CSA to minimize the critical path - val partialReminderCarry = Reg(UInt((dividendWidth + radixLog2).W)) - val partialReminderSum = Reg(UInt((dividendWidth + radixLog2).W)) + val partialReminderCarry = Reg(UInt(xLen.W)) + val partialReminderSum = Reg(UInt(xLen.W)) val divider = RegInit(input.bits.divider) val quotient = Reg(UInt(n.W)) val quotientMinusOne = Reg(UInt(n.W)) @@ -59,67 +61,72 @@ class SRT( qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR // Datapath - val qds = Module(new QDS(rTruncateWidth, ohWidth)) - qds.input.partialReminderSum := partialReminderSum.head(1 + radixLog2 + rTruncateWidth) - qds.input.partialReminderCarry := partialReminderCarry.head(1 + radixLog2 + rTruncateWidth) - qds.partialDivider.bits := input.bits.divider.head(1 + radixLog2 + rTruncateWidth)(dTruncateWidth - 2, 0) - - counter := counter - radixLog2.U + val rWidth: Int = 1 + radixLog2 + rTruncateWidth + val qds = Module(new QDS(rWidth, ohWidth, dTruncateWidth - 1)) + qds.input.partialReminderSum := partialReminderSum.head(rWidth) + qds.input.partialReminderCarry := partialReminderCarry.head(rWidth) + qds.partialDivider.bits := input.bits.divider + .head(dTruncateWidth + 1)(dTruncateWidth - 2, 0) //0.1********** -> 0.1*** -> *** + + counter := counter - 1.U // if counter === 0.U && sz.output.sign, correct the quotient and reminder. valid = 1 // the output of srt - val sz = Module(new SZ(dividendWidth - 2)) - sz.input.partialReminderSum := partialReminderSum(partialReminderSum.getWidth-3, 0) - sz.input.partialReminderCarry := partialReminderCarry(partialReminderSum.getWidth-3, 0) - output.valid := Mux(counter === 0.U, true.B, false.B) - - // correcting maybe have problem - quotient := Mux(counter === 0.U && sz.output.sign, quotient - 1.U, quotient) - output.bits.reminder := Mux1H( - Map( - (counter === 0.U && sz.output.zero) -> 0.U, - (counter === 0.U && sz.output.sign) -> (sz.output.remainder + 1.U + divider), - (counter === 0.U && !sz.output.sign) -> (sz.output.remainder + 1.U) - ) - ) +// val sz = Module(new SZ(dividendWidth - 2)) +// sz.input.partialReminderSum := partialReminderSum(partialReminderSum.getWidth-3, 0) +// sz.input.partialReminderCarry := partialReminderCarry(partialReminderSum.getWidth-3, 0) +// // correcting maybe have problem +// quotient := quotient - Mux(sz.output.sign, 1.U, 0.U) +// output.bits.reminder := sz.output.remainder + Mux(sz.output.sign, divider, 0.U) +// output.bits.quotient := quotient + + // according two adders + val isLastCycle: Bool = !counter.orR + output.valid := Mux(isLastCycle, true.B, false.B) + val remainderNoCorrect: UInt = partialReminderSum(xLen - 3, 0) + partialReminderCarry(xLen - 3, 0) + val needCorrect: Bool = Mux(isLastCycle, remainderNoCorrect.head(1).asBool, false.B) + val remainderCorrect: UInt = partialReminderSum(xLen - 3, 0) + partialReminderCarry(xLen - 3, 0) + divider + + quotient := quotient - needCorrect.asUInt + output.bits.reminder := Mux(needCorrect, remainderNoCorrect, remainderCorrect) output.bits.quotient := quotient // for SRT4 -> CSA32 // for SRT8 -> CSA32+CSA32 // for SRT16 -> CSA53+CSA32 // SRT16 <- SRT4 + SRT4*5 - val csa = Module(new CarrySaveAdder(CSACompressor3_2, dividendWidth + radixLog2)) + val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen)) csa.in(0) := partialReminderSum - csa.in(1) := (partialReminderCarry ## !qdsSign) + csa.in(1) := (partialReminderCarry(xLen, 1) ## !qdsSign) csa.in(2) := Mux1H( qds.output.selectedQuotientOH, - // TODO: this is for SRT4, for SRT8 or SRT16, this should be changed + //this is for SRT4, for SRT8 or SRT16, this should be changed VecInit((-2 to 2).map { case -2 => divider << 1 case -1 => divider case 0 => 0.U - case 1 => extend(~divider, dividendWidth + radixLog2) - case 2 => extend((~divider) << 1, dividendWidth + radixLog2) + case 1 => extend(~divider, xLen) + case 2 => extend((~divider) << 1, xLen) }) ) + // TODO: sel maybe have a problem partialReminderSum := Mux1H( Map( (counter === input.bits.counter) -> input.bits.dividend, - (counter > 0.U) -> (csa.out(0) << radixLog2), - (counter === 0.U) -> partialReminderSum + counter.orR -> (csa.out(0) << radixLog2)(xLen - 1, 0), + isLastCycle -> partialReminderSum ) ) partialReminderCarry := Mux1H( Map( (counter === input.bits.counter) -> 0.U, - (counter > 0.U) -> (csa.out(1) << (radixLog2 - 1)), - (counter === 0.U) -> partialReminderCarry + counter.orR -> (csa.out(1) << radixLog2)(xLen - 1, 0), + isLastCycle -> partialReminderCarry ) ) - // On-The-Fly conversion - val otf = Module(new OTF((1 << radixLog2), n, ohWidth)) + val otf = Module(new OTF(1 << radixLog2, n, ohWidth)) otf.input.quotient := quotient otf.input.quotientMinusOne := quotientMinusOne otf.input.selectedQuotientOH := qds.output.selectedQuotientOH diff --git a/arithmetic/src/division/srt/SZ.scala b/arithmetic/src/division/srt/SZ.scala index 3a34489..6afb437 100644 --- a/arithmetic/src/division/srt/SZ.scala +++ b/arithmetic/src/division/srt/SZ.scala @@ -10,20 +10,20 @@ class SZInput(rWidth: Int) extends Bundle { } class SZOutput(rWidth: Int) extends Bundle { - val sign: Bool = Bool() - val zero: Bool = Bool() - val remainder: UInt = UInt((rWidth + 1).W) + val sign: Bool = Bool() + val zero: Bool = Bool() + val remainder: UInt = UInt((rWidth).W) } class SZ(rWidth: Int, prefixSum: PrefixSum = BrentKungSum) extends Module { val input = IO(Input(new SZInput(rWidth))) val output = IO(Output(new SZOutput(rWidth))) - //controlpath //datapath // csa(ws,wc,-2^-b) => Seq[(Bool,Bool)] // drop signed bits + // prefixtree by group val ws = input.partialReminderCarry.asBools val wc = input.partialReminderSum.asBools val psc: Seq[(Bool, Bool)] = ws.zip(wc).map { case (s, c) => (!(s ^ c), (s | c)) } @@ -34,8 +34,8 @@ class SZ(rWidth: Int, prefixSum: PrefixSum = BrentKungSum) extends Module { val ps: Vector[Bool] = pgs.map(_._1) val gs: Vector[Bool] = pgs.map(_._2) - val a: Vector[Bool] = false.B +: gs - val b: Seq[Bool] = pairs.map(_._1) :+ false.B + val a: Vector[Bool] = false.B +: gs + val b: Seq[Bool] = pairs.map(_._1) :+ false.B val sum: Seq[Bool] = a.zip(b).map { case (p, c) => p ^ c } // maybe have a problem. diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala new file mode 100644 index 0000000..0b2928e --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -0,0 +1,35 @@ +package division.srt + +import chisel3._ +import chisel3.tester.{ChiselUtestTester, testableClock, testableData} +import utest._ + +object SRT4Test extends TestSuite with ChiselUtestTester{ + def tests: Tests = Tests { + test("SRT4 should pass") { + // parameters + val dividendWidth: Int = 4 + val dividerWidth: Int = 3 + val n: Int = 3 +// val dividend: Int = 7 +// val divider: Int = 3 + val countr: Int = 2 + val remainder: Int = dividend / divider + val quotient: Int = dividend % divider + // test + testCircuit(new SRT(dividendWidth, dividerWidth, n), + Seq(chiseltest.internal.NoThreadingAnnotation, + chiseltest.simulator.WriteVcdAnnotation)){ + dut: SRT => + dut.clock.setTimeout(0) + dut.input.valid.poke(true.B) + dut.input.bits.dividend.poke("b0111".U) + dut.input.bits.divider.poke("b011".U) + dut.input.bits.counter.poke(countr.U) + dut.clock.step(countr) + dut.output.bits.quotient.expect(quotient.U) + dut.output.bits.reminder.expect(remainder.U) + } + } + } +} \ No newline at end of file From ea32e50645305baca9012ca5ed8e7da6ba69b041 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Mon, 9 May 2022 16:36:38 +0800 Subject: [PATCH 17/31] SRT4Test fix --- arithmetic/src/division/srt/OTF.scala | 4 +- arithmetic/src/division/srt/QDS.scala | 10 +- arithmetic/src/division/srt/SRT.scala | 119 ++++++++++-------- .../tests/src/division/srt/SRT4Test.scala | 45 ++++--- 4 files changed, 103 insertions(+), 75 deletions(-) diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala index bedb574..98cd94e 100644 --- a/arithmetic/src/division/srt/OTF.scala +++ b/arithmetic/src/division/srt/OTF.scala @@ -33,8 +33,8 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { // val cShiftQ: Bool = qNext >= 0.U // val cShiftQM: Bool = qNext <= 0.U - val cShiftQ: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR - val cShiftQM: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR + val cShiftQ: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR + val cShiftQM: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(1, 0) val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(1, 0) diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index 321d60f..921d304 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -66,12 +66,12 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { selectPoints, TruthTable( Seq( - BitPat("b1???") -> BitPat("b10000"), - BitPat("b01??") -> BitPat("b01000"), - BitPat("b001?") -> BitPat("b00100"), - BitPat("b0001") -> BitPat("b00010") + BitPat("b1???") -> BitPat("b00001"), //-2 + BitPat("b01??") -> BitPat("b00010"), //-1 + BitPat("b001?") -> BitPat("b00100"), //0 + BitPat("b0001") -> BitPat("b01000") //1 ), - BitPat("b00001") + BitPat("b10000") //2 ) ) } diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 81da38f..1bb87d4 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -2,15 +2,15 @@ package division.srt import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 -import utils.extend +import utils.{extend} import chisel3._ import chisel3.util.{log2Ceil, Counter, DecoupledIO, Mux1H, ValidIO} import scala.math.ceil /** SRT4 - * 1/2<= d < 1, 1/2 < rho <=1, 0 < q < 2 - * 0, radix = 4 + * 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2 + * radix = 4 * a = 2, {-2, -1, 0, 1, -2}, * t = 4 * y^(xxx.xxxx), d^(0.1xxx) @@ -44,94 +44,107 @@ class SRT( val ohWidth: Int = 2 * a + 1 // IO - val input = Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n))) - val output = ValidIO(new SRTOutput(dividerWidth, dividendWidth)) + val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) + val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) // State // because we need a CSA to minimize the critical path val partialReminderCarry = Reg(UInt(xLen.W)) val partialReminderSum = Reg(UInt(xLen.W)) - val divider = RegInit(input.bits.divider) + val divider = Reg(UInt(dividerWidth.W)) val quotient = Reg(UInt(n.W)) val quotientMinusOne = Reg(UInt(n.W)) - val counter = RegInit(input.bits.counter) + val counter = Reg(UInt(log2Ceil(n).W)) + // Control // sign of select quotient, true -> negative, false -> positive val qdsSign: Bool = Wire(Bool()) - qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR // Datapath - val rWidth: Int = 1 + radixLog2 + rTruncateWidth - val qds = Module(new QDS(rWidth, ohWidth, dTruncateWidth - 1)) - qds.input.partialReminderSum := partialReminderSum.head(rWidth) - qds.input.partialReminderCarry := partialReminderCarry.head(rWidth) - qds.partialDivider.bits := input.bits.divider - .head(dTruncateWidth + 1)(dTruncateWidth - 2, 0) //0.1********** -> 0.1*** -> *** - - counter := counter - 1.U - // if counter === 0.U && sz.output.sign, correct the quotient and reminder. valid = 1 - // the output of srt -// val sz = Module(new SZ(dividendWidth - 2)) -// sz.input.partialReminderSum := partialReminderSum(partialReminderSum.getWidth-3, 0) -// sz.input.partialReminderCarry := partialReminderCarry(partialReminderSum.getWidth-3, 0) -// // correcting maybe have problem -// quotient := quotient - Mux(sz.output.sign, 1.U, 0.U) -// output.bits.reminder := sz.output.remainder + Mux(sz.output.sign, divider, 0.U) -// output.bits.quotient := quotient - // according two adders val isLastCycle: Bool = !counter.orR output.valid := Mux(isLastCycle, true.B, false.B) + input.ready := Mux(isLastCycle, true.B, false.B) + val remainderNoCorrect: UInt = partialReminderSum(xLen - 3, 0) + partialReminderCarry(xLen - 3, 0) val needCorrect: Bool = Mux(isLastCycle, remainderNoCorrect.head(1).asBool, false.B) val remainderCorrect: UInt = partialReminderSum(xLen - 3, 0) + partialReminderCarry(xLen - 3, 0) + divider - - quotient := quotient - needCorrect.asUInt - output.bits.reminder := Mux(needCorrect, remainderNoCorrect, remainderCorrect) + // TODO: ">> radixLog2" is not a better op + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect) >> radixLog2 output.bits.quotient := quotient + // qds + val rWidth: Int = 1 + radixLog2 + rTruncateWidth + val qds = Module(new QDS(rWidth, ohWidth, dTruncateWidth - 1)) + qds.input.partialReminderSum := partialReminderSum.head(rWidth) + qds.input.partialReminderCarry := partialReminderCarry.head(rWidth) + qds.partialDivider.valid := input.valid + qds.partialDivider.bits := input.bits.divider + .head(dTruncateWidth + 1)(dTruncateWidth - 2, 0) //0.1********** -> 0.1*** -> *** + qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR + // for SRT4 -> CSA32 // for SRT8 -> CSA32+CSA32 // for SRT16 -> CSA53+CSA32 // SRT16 <- SRT4 + SRT4*5 val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen)) csa.in(0) := partialReminderSum - csa.in(1) := (partialReminderCarry(xLen, 1) ## !qdsSign) - csa.in(2) := Mux1H( - qds.output.selectedQuotientOH, - //this is for SRT4, for SRT8 or SRT16, this should be changed - VecInit((-2 to 2).map { - case -2 => divider << 1 - case -1 => divider - case 0 => 0.U - case 1 => extend(~divider, xLen) - case 2 => extend((~divider) << 1, xLen) - }) - ) + csa.in(1) := (partialReminderCarry(xLen - 1, 1) ## qdsSign) + csa.in(2) := + Mux1H( + qds.output.selectedQuotientOH, + //this is for SRT4, for SRT8 or SRT16, this should be changed + VecInit((-2 to 2).map { + case -2 => divider << 1 + case -1 => divider + case 0 => 0.U + case 1 => extend(~divider, xLen) + case 2 => extend(~divider << 1, xLen) + }) + ) + // On-The-Fly conversion + val otf = Module(new OTF(1 << radixLog2, n, ohWidth)) + otf.input.quotient := quotient + otf.input.quotientMinusOne := quotientMinusOne + otf.input.selectedQuotientOH := qds.output.selectedQuotientOH + + // divider := Mux(input.valid && input.ready, input.bits.divider, divider) + // counter := Mux(input.valid && input.ready, input.bits.counter, counter - 1.U) // TODO: sel maybe have a problem + divider := Mux(input.valid, input.bits.divider, divider) + counter := Mux(input.valid, input.bits.counter, counter - 1.U) + + quotientMinusOne := Mux1H( + Map( + input.valid -> 0.U, + (!input.valid & counter.orR) -> otf.output.quotientMinusOne, + isLastCycle -> quotientMinusOne + ) + ) + + quotient := Mux1H( + Map( + input.valid -> 0.U, + (!input.valid & counter.orR) -> otf.output.quotient, + isLastCycle -> (quotient - needCorrect.asUInt) + ) + ) + partialReminderSum := Mux1H( Map( - (counter === input.bits.counter) -> input.bits.dividend, - counter.orR -> (csa.out(0) << radixLog2)(xLen - 1, 0), + input.valid -> input.bits.dividend, + (!input.valid & counter.orR) -> (csa.out(1) << radixLog2)(xLen - 1, 0), isLastCycle -> partialReminderSum ) ) partialReminderCarry := Mux1H( Map( - (counter === input.bits.counter) -> 0.U, - counter.orR -> (csa.out(1) << radixLog2)(xLen - 1, 0), + input.valid -> 0.U, + (!input.valid & counter.orR) -> (csa.out(0) << radixLog2 + 1)(xLen - 1, 0), isLastCycle -> partialReminderCarry ) ) - // On-The-Fly conversion - val otf = Module(new OTF(1 << radixLog2, n, ohWidth)) - otf.input.quotient := quotient - otf.input.quotientMinusOne := quotientMinusOne - otf.input.selectedQuotientOH := qds.output.selectedQuotientOH - quotient := otf.output.quotient - quotientMinusOne := otf.output.quotientMinusOne - output.bits.quotient := quotient } diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index 0b2928e..61409d7 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -8,27 +8,42 @@ object SRT4Test extends TestSuite with ChiselUtestTester{ def tests: Tests = Tests { test("SRT4 should pass") { // parameters - val dividendWidth: Int = 4 - val dividerWidth: Int = 3 - val n: Int = 3 -// val dividend: Int = 7 -// val divider: Int = 3 - val countr: Int = 2 - val remainder: Int = dividend / divider - val quotient: Int = dividend % divider + val dividendWidth: Int = 8 + val dividerWidth: Int = 8 + val n: Int = 10 + val dividend: Int = 15 << 3 + val divider: Int = 3 << 5 +// val counter: Int = 2 + val counter: Int = 1 + val quotient: Int = dividend / divider + val remainder: Int = dividend % divider // test testCircuit(new SRT(dividendWidth, dividerWidth, n), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){ dut: SRT => - dut.clock.setTimeout(0) +// dut.clock.setTimeout(0) dut.input.valid.poke(true.B) - dut.input.bits.dividend.poke("b0111".U) - dut.input.bits.divider.poke("b011".U) - dut.input.bits.counter.poke(countr.U) - dut.clock.step(countr) - dut.output.bits.quotient.expect(quotient.U) - dut.output.bits.reminder.expect(remainder.U) +// dut.input.bits.dividend.poke("b01111000".U) +// dut.input.bits.divider.poke( "b01100000".U) + dut.input.bits.dividend.poke("b01111000".U) + dut.input.bits.divider.poke( "b01100000".U) + dut.input.bits.counter.poke(counter.U) + dut.clock.step() + dut.input.valid.poke(false.B) + var flag = false + for(a <- 1 to 20 if !flag) { + if(dut.output.valid.peek().litValue == 1) { + flag = true + dut.clock.step() +// dut.output.bits.quotient.expect(5.U) +// dut.output.bits.reminder.expect(0.U) + dut.output.bits.quotient.expect(1.U) + dut.output.bits.reminder.expect("b11000".U) + } + dut.clock.step() + } + utest.assert(flag) } } } From db72ed0520a455525ba726f00a98447f95cb426a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Mon, 9 May 2022 21:09:55 +0800 Subject: [PATCH 18/31] SRT fix --- arithmetic/src/division/srt/SRT.scala | 59 +++++++------------ .../tests/src/division/srt/SRT4Test.scala | 20 +++---- 2 files changed, 31 insertions(+), 48 deletions(-) diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 1bb87d4..386d994 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -17,7 +17,6 @@ import scala.math.ceil * -44/16 < y^ < 42/16 */ -// TODO: counter & n class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { val dividend = UInt(dividendWidth.W) //0.1********** val divider = UInt(dividerWidth.W) //0.1********** @@ -33,7 +32,7 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { class SRT( dividendWidth: Int, dividerWidth: Int, - n: Int, // the longest width, + n: Int, // the longest width radixLog2: Int = 2, a: Int = 2, dTruncateWidth: Int = 4, @@ -54,31 +53,31 @@ class SRT( val divider = Reg(UInt(dividerWidth.W)) val quotient = Reg(UInt(n.W)) val quotientMinusOne = Reg(UInt(n.W)) - val counter = Reg(UInt(log2Ceil(n).W)) + val counter = RegInit(0.U(log2Ceil(n).W)) // Control // sign of select quotient, true -> negative, false -> positive val qdsSign: Bool = Wire(Bool()) - // Datapath + // Datapath // according two adders val isLastCycle: Bool = !counter.orR - output.valid := Mux(isLastCycle, true.B, false.B) - input.ready := Mux(isLastCycle, true.B, false.B) + output.valid := isLastCycle + input.ready := isLastCycle + // lastCycle-> correct-> output val remainderNoCorrect: UInt = partialReminderSum(xLen - 3, 0) + partialReminderCarry(xLen - 3, 0) val needCorrect: Bool = Mux(isLastCycle, remainderNoCorrect.head(1).asBool, false.B) val remainderCorrect: UInt = partialReminderSum(xLen - 3, 0) + partialReminderCarry(xLen - 3, 0) + divider - // TODO: ">> radixLog2" is not a better op - output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect) >> radixLog2 - output.bits.quotient := quotient + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect) + output.bits.quotient := quotient - needCorrect.asUInt // qds val rWidth: Int = 1 + radixLog2 + rTruncateWidth val qds = Module(new QDS(rWidth, ohWidth, dTruncateWidth - 1)) qds.input.partialReminderSum := partialReminderSum.head(rWidth) qds.input.partialReminderCarry := partialReminderCarry.head(rWidth) - qds.partialDivider.valid := input.valid + qds.partialDivider.valid := input.valid && input.ready qds.partialDivider.bits := input.bits.divider .head(dTruncateWidth + 1)(dTruncateWidth - 2, 0) //0.1********** -> 0.1*** -> *** qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR @@ -109,42 +108,26 @@ class SRT( otf.input.quotientMinusOne := quotientMinusOne otf.input.selectedQuotientOH := qds.output.selectedQuotientOH - // divider := Mux(input.valid && input.ready, input.bits.divider, divider) - // counter := Mux(input.valid && input.ready, input.bits.counter, counter - 1.U) - // TODO: sel maybe have a problem - divider := Mux(input.valid, input.bits.divider, divider) - counter := Mux(input.valid, input.bits.counter, counter - 1.U) + divider := Mux(input.valid && input.ready, input.bits.divider, divider) + counter := Mux(input.valid && input.ready, input.bits.counter, counter - 1.U) - quotientMinusOne := Mux1H( - Map( - input.valid -> 0.U, - (!input.valid & counter.orR) -> otf.output.quotientMinusOne, - isLastCycle -> quotientMinusOne - ) - ) - - quotient := Mux1H( - Map( - input.valid -> 0.U, - (!input.valid & counter.orR) -> otf.output.quotient, - isLastCycle -> (quotient - needCorrect.asUInt) - ) - ) + quotient:= Mux(isLastCycle, 0.U, otf.output.quotient) + quotientMinusOne:= Mux(isLastCycle, 0.U, otf.output.quotientMinusOne) partialReminderSum := Mux1H( Map( - input.valid -> input.bits.dividend, - (!input.valid & counter.orR) -> (csa.out(1) << radixLog2)(xLen - 1, 0), - isLastCycle -> partialReminderSum + isLastCycle -> input.bits.dividend, + (counter > 1.U) -> (csa.out(1) << radixLog2)(xLen - 1, 0), + (counter === 1.U) -> csa.out(1)(xLen - 1, 0) ) ) - partialReminderCarry := Mux1H( Map( - input.valid -> 0.U, - (!input.valid & counter.orR) -> (csa.out(0) << radixLog2 + 1)(xLen - 1, 0), - isLastCycle -> partialReminderCarry + isLastCycle -> 0.U, + (counter > 1.U) -> (csa.out(0) << radixLog2 + 1)(xLen - 1, 0), + (counter === 1.U) -> (csa.out(0) << 1)(xLen - 1, 0) ) ) - } + + diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index 61409d7..9a97220 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -13,8 +13,8 @@ object SRT4Test extends TestSuite with ChiselUtestTester{ val n: Int = 10 val dividend: Int = 15 << 3 val divider: Int = 3 << 5 -// val counter: Int = 2 - val counter: Int = 1 + val counter: Int = 2 +// val counter: Int = 1 val quotient: Int = dividend / divider val remainder: Int = dividend % divider // test @@ -24,10 +24,10 @@ object SRT4Test extends TestSuite with ChiselUtestTester{ dut: SRT => // dut.clock.setTimeout(0) dut.input.valid.poke(true.B) -// dut.input.bits.dividend.poke("b01111000".U) -// dut.input.bits.divider.poke( "b01100000".U) - dut.input.bits.dividend.poke("b01111000".U) - dut.input.bits.divider.poke( "b01100000".U) + dut.input.bits.dividend.poke("b01111000".U) + dut.input.bits.divider.poke( "b01100000".U) +// dut.input.bits.dividend.poke("b01111000".U) +// dut.input.bits.divider.poke( "b01100000".U) dut.input.bits.counter.poke(counter.U) dut.clock.step() dut.input.valid.poke(false.B) @@ -36,10 +36,10 @@ object SRT4Test extends TestSuite with ChiselUtestTester{ if(dut.output.valid.peek().litValue == 1) { flag = true dut.clock.step() -// dut.output.bits.quotient.expect(5.U) -// dut.output.bits.reminder.expect(0.U) - dut.output.bits.quotient.expect(1.U) - dut.output.bits.reminder.expect("b11000".U) + dut.output.bits.quotient.expect(5.U) + dut.output.bits.reminder.expect(0.U) +// dut.output.bits.quotient.expect(1.U) +// dut.output.bits.reminder.expect("b11000".U) } dut.clock.step() } From 1660f1462fcfe7bb0a62ebab259b768bef5aa870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Tue, 10 May 2022 12:56:09 +0800 Subject: [PATCH 19/31] fix reformat --- arithmetic/src/division/srt/SRT.scala | 8 +++----- arithmetic/tests/src/division/srt/SRT4Test.scala | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 386d994..81fe723 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -111,13 +111,13 @@ class SRT( divider := Mux(input.valid && input.ready, input.bits.divider, divider) counter := Mux(input.valid && input.ready, input.bits.counter, counter - 1.U) - quotient:= Mux(isLastCycle, 0.U, otf.output.quotient) - quotientMinusOne:= Mux(isLastCycle, 0.U, otf.output.quotientMinusOne) + quotient := Mux(isLastCycle, 0.U, otf.output.quotient) + quotientMinusOne := Mux(isLastCycle, 0.U, otf.output.quotientMinusOne) partialReminderSum := Mux1H( Map( isLastCycle -> input.bits.dividend, - (counter > 1.U) -> (csa.out(1) << radixLog2)(xLen - 1, 0), + (counter > 1.U) -> (csa.out(1) << radixLog2)(xLen - 1, 0), (counter === 1.U) -> csa.out(1)(xLen - 1, 0) ) ) @@ -129,5 +129,3 @@ class SRT( ) ) } - - diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index 9a97220..80ae2eb 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -36,8 +36,8 @@ object SRT4Test extends TestSuite with ChiselUtestTester{ if(dut.output.valid.peek().litValue == 1) { flag = true dut.clock.step() - dut.output.bits.quotient.expect(5.U) - dut.output.bits.reminder.expect(0.U) + dut.output.bits.quotient.expect(0.U) + dut.output.bits.reminder.expect("b01111000".U) // dut.output.bits.quotient.expect(1.U) // dut.output.bits.reminder.expect("b11000".U) } From b2e89117866633377959f0494e343fb763368be4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Tue, 10 May 2022 13:08:34 +0800 Subject: [PATCH 20/31] reformat --- arithmetic/src/division/srt/SRT.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index d822ccd..81fe723 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -17,7 +17,6 @@ import scala.math.ceil * -44/16 < y^ < 42/16 */ - class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { val dividend = UInt(dividendWidth.W) //0.1********** val divider = UInt(dividerWidth.W) //0.1********** From 3f33c9933f0cdce22ef01ae25063c440fdeedb88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Wed, 11 May 2022 01:03:23 +0800 Subject: [PATCH 21/31] srt4 fix --- arithmetic/src/division/srt/SRT.scala | 36 ++++++++++++++----- .../tests/src/division/srt/SRT4Test.scala | 23 +++++------- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 81fe723..94c66a0 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -4,7 +4,7 @@ import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 import utils.{extend} import chisel3._ -import chisel3.util.{log2Ceil, Counter, DecoupledIO, Mux1H, ValidIO} +import chisel3.util.{log2Ceil, Counter, DecoupledIO, Fill, Mux1H, ValidIO} import scala.math.ceil @@ -18,8 +18,8 @@ import scala.math.ceil */ class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { - val dividend = UInt(dividendWidth.W) //0.1********** - val divider = UInt(dividerWidth.W) //0.1********** + val dividend = UInt(dividendWidth.W) //.1********** + val divider = UInt(dividerWidth.W) //.1********** val counter = UInt(log2Ceil(n).W) //the width of quotient. } @@ -39,7 +39,7 @@ class SRT( rTruncateWidth: Int = 4) extends Module { - val xLen: Int = dividendWidth + radixLog2 + val xLen: Int = dividendWidth + radixLog2 + 1 val ohWidth: Int = 2 * a + 1 // IO @@ -66,9 +66,10 @@ class SRT( input.ready := isLastCycle // lastCycle-> correct-> output + // only mux is in lastCycle, adder is not inlastCycle val remainderNoCorrect: UInt = partialReminderSum(xLen - 3, 0) + partialReminderCarry(xLen - 3, 0) - val needCorrect: Bool = Mux(isLastCycle, remainderNoCorrect.head(1).asBool, false.B) val remainderCorrect: UInt = partialReminderSum(xLen - 3, 0) + partialReminderCarry(xLen - 3, 0) + divider + val needCorrect: Bool = remainderNoCorrect.head(1).asBool output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect) output.bits.quotient := quotient - needCorrect.asUInt @@ -79,7 +80,7 @@ class SRT( qds.input.partialReminderCarry := partialReminderCarry.head(rWidth) qds.partialDivider.valid := input.valid && input.ready qds.partialDivider.bits := input.bits.divider - .head(dTruncateWidth + 1)(dTruncateWidth - 2, 0) //0.1********** -> 0.1*** -> *** + .head(dTruncateWidth)(dTruncateWidth - 1, 0) //.1********** -> .1*** -> *** qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR // for SRT4 -> CSA32 @@ -97,8 +98,8 @@ class SRT( case -2 => divider << 1 case -1 => divider case 0 => 0.U - case 1 => extend(~divider, xLen) - case 2 => extend(~divider << 1, xLen) + case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider + case 2 => Fill(radixLog2, 1.U(1.W)) ## (~divider << 1) }) ) @@ -113,6 +114,25 @@ class SRT( quotient := Mux(isLastCycle, 0.U, otf.output.quotient) quotientMinusOne := Mux(isLastCycle, 0.U, otf.output.quotientMinusOne) +// //shiftleft before csa +// partialReminderSum := Mux(isLastCycle, input.bits.dividend >> radixLog2, csa.out(1)) +// partialReminderCarry := Mux(isLastCycle, 0.U, csa.out(0) << 1) +// val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen)) +// //csa.in(0) := Mux(counter === input.bits.counter, input.bits.dividend, partialReminderSum ) +// csa.in(0) := partialReminderSum << radixLog2 +// csa.in(1) := ((partialReminderCarry << radixLog2)(xLen - 1, 1) ## qdsSign) +// csa.in(2) := +// Mux1H( +// qds.output.selectedQuotientOH, +// //this is for SRT4, for SRT8 or SRT16, this should be changed +// VecInit((-2 to 2).map { +// case -2 => divider << 1 +// case -1 => divider +// case 0 => 0.U +// case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider +// case 2 => Fill(radixLog2, 1.U(1.W)) ## (~divider << 1) +// }) +// ) partialReminderSum := Mux1H( Map( diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index 80ae2eb..abd87b0 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -8,13 +8,12 @@ object SRT4Test extends TestSuite with ChiselUtestTester{ def tests: Tests = Tests { test("SRT4 should pass") { // parameters - val dividendWidth: Int = 8 - val dividerWidth: Int = 8 + val dividendWidth: Int = 7 + val dividerWidth: Int = 7 val n: Int = 10 val dividend: Int = 15 << 3 val divider: Int = 3 << 5 val counter: Int = 2 -// val counter: Int = 1 val quotient: Int = dividend / divider val remainder: Int = dividend % divider // test @@ -22,12 +21,10 @@ object SRT4Test extends TestSuite with ChiselUtestTester{ Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){ dut: SRT => -// dut.clock.setTimeout(0) + dut.clock.setTimeout(0) dut.input.valid.poke(true.B) - dut.input.bits.dividend.poke("b01111000".U) - dut.input.bits.divider.poke( "b01100000".U) -// dut.input.bits.dividend.poke("b01111000".U) -// dut.input.bits.divider.poke( "b01100000".U) + dut.input.bits.dividend.poke("b1111000".U) + dut.input.bits.divider.poke( "b1100000".U) dut.input.bits.counter.poke(counter.U) dut.clock.step() dut.input.valid.poke(false.B) @@ -35,13 +32,11 @@ object SRT4Test extends TestSuite with ChiselUtestTester{ for(a <- 1 to 20 if !flag) { if(dut.output.valid.peek().litValue == 1) { flag = true - dut.clock.step() - dut.output.bits.quotient.expect(0.U) - dut.output.bits.reminder.expect("b01111000".U) -// dut.output.bits.quotient.expect(1.U) -// dut.output.bits.reminder.expect("b11000".U) + dut.output.bits.quotient.expect(5.U) + dut.output.bits.reminder.expect(0.U) } - dut.clock.step() + else + dut.clock.step() } utest.assert(flag) } From 7018ab8ce804cfe25370d20942d93ccdd9a2fc38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Sun, 15 May 2022 14:33:13 +0800 Subject: [PATCH 22/31] using RegEnable --- arithmetic/src/division/srt/SRT.scala | 99 ++++++++----------- .../tests/src/division/srt/SRT4Test.scala | 6 +- 2 files changed, 43 insertions(+), 62 deletions(-) diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index 94c66a0..ef04f80 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -2,9 +2,9 @@ package division.srt import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 -import utils.{extend} +import utils.extend import chisel3._ -import chisel3.util.{log2Ceil, Counter, DecoupledIO, Fill, Mux1H, ValidIO} +import chisel3.util.{log2Ceil, Counter, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} import scala.math.ceil @@ -40,45 +40,55 @@ class SRT( extends Module { val xLen: Int = dividendWidth + radixLog2 + 1 + val wLen: Int = xLen + radixLog2 val ohWidth: Int = 2 * a + 1 // IO val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) - // State - // because we need a CSA to minimize the critical path - val partialReminderCarry = Reg(UInt(xLen.W)) - val partialReminderSum = Reg(UInt(xLen.W)) - val divider = Reg(UInt(dividerWidth.W)) - val quotient = Reg(UInt(n.W)) - val quotientMinusOne = Reg(UInt(n.W)) - val counter = RegInit(0.U(log2Ceil(n).W)) + val partialReminderCarryNext = Wire(UInt(wLen.W)) + val partialReminderSumNext = Wire(UInt(wLen.W)) + val dividerNext = Wire(UInt(dividerWidth.W)) + val counterNext = Wire(UInt(n.W)) + val quotientNext = Wire(UInt(n.W)) + val quotientMinusOneNext = Wire(UInt(log2Ceil(n).W)) // Control // sign of select quotient, true -> negative, false -> positive val qdsSign: Bool = Wire(Bool()) + // sign of Cycle, true -> (counter === 0.U) + val isLastCycle: Bool = Wire(Bool()) + + // State + // because we need a CSA to minimize the critical path + val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), input.fire || !isLastCycle) + val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), input.fire || !isLastCycle) + val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), input.fire || !isLastCycle) + val quotient = RegEnable(quotientNext, 0.U(n.W), input.fire || !isLastCycle) + val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), input.fire || !isLastCycle) + val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), input.fire || !isLastCycle) // Datapath // according two adders - val isLastCycle: Bool = !counter.orR + isLastCycle := !counter.orR output.valid := isLastCycle input.ready := isLastCycle - // lastCycle-> correct-> output - // only mux is in lastCycle, adder is not inlastCycle - val remainderNoCorrect: UInt = partialReminderSum(xLen - 3, 0) + partialReminderCarry(xLen - 3, 0) - val remainderCorrect: UInt = partialReminderSum(xLen - 3, 0) + partialReminderCarry(xLen - 3, 0) + divider - val needCorrect: Bool = remainderNoCorrect.head(1).asBool + // only mux is in last Cycle, adder is in every Cycle + val remainderNoCorrect: UInt = partialReminderSum(wLen - 3, radixLog2) + partialReminderCarry(wLen - 3, radixLog2) + val remainderCorrect: UInt = + partialReminderSum(wLen - 3, radixLog2) + partialReminderCarry(wLen - 3, radixLog2) + divider + val needCorrect: Bool = remainderNoCorrect.head(1).asBool output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect) output.bits.quotient := quotient - needCorrect.asUInt // qds val rWidth: Int = 1 + radixLog2 + rTruncateWidth val qds = Module(new QDS(rWidth, ohWidth, dTruncateWidth - 1)) - qds.input.partialReminderSum := partialReminderSum.head(rWidth) - qds.input.partialReminderCarry := partialReminderCarry.head(rWidth) - qds.partialDivider.valid := input.valid && input.ready + qds.input.partialReminderSum := (partialReminderSum << radixLog2)(wLen - 1, wLen - rWidth) + qds.input.partialReminderCarry := (partialReminderCarry << radixLog2)(wLen - 1, wLen - rWidth) + qds.partialDivider.valid := input.fire qds.partialDivider.bits := input.bits.divider .head(dTruncateWidth)(dTruncateWidth - 1, 0) //.1********** -> .1*** -> *** qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR @@ -88,8 +98,8 @@ class SRT( // for SRT16 -> CSA53+CSA32 // SRT16 <- SRT4 + SRT4*5 val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen)) - csa.in(0) := partialReminderSum - csa.in(1) := (partialReminderCarry(xLen - 1, 1) ## qdsSign) + csa.in(0) := (partialReminderSum << radixLog2)(wLen - 1, radixLog2) + csa.in(1) := (partialReminderCarry << radixLog2)(wLen - 1, radixLog2 + 1) ## qdsSign csa.in(2) := Mux1H( qds.output.selectedQuotientOH, @@ -109,43 +119,12 @@ class SRT( otf.input.quotientMinusOne := quotientMinusOne otf.input.selectedQuotientOH := qds.output.selectedQuotientOH - divider := Mux(input.valid && input.ready, input.bits.divider, divider) - counter := Mux(input.valid && input.ready, input.bits.counter, counter - 1.U) - - quotient := Mux(isLastCycle, 0.U, otf.output.quotient) - quotientMinusOne := Mux(isLastCycle, 0.U, otf.output.quotientMinusOne) -// //shiftleft before csa -// partialReminderSum := Mux(isLastCycle, input.bits.dividend >> radixLog2, csa.out(1)) -// partialReminderCarry := Mux(isLastCycle, 0.U, csa.out(0) << 1) -// val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen)) -// //csa.in(0) := Mux(counter === input.bits.counter, input.bits.dividend, partialReminderSum ) -// csa.in(0) := partialReminderSum << radixLog2 -// csa.in(1) := ((partialReminderCarry << radixLog2)(xLen - 1, 1) ## qdsSign) -// csa.in(2) := -// Mux1H( -// qds.output.selectedQuotientOH, -// //this is for SRT4, for SRT8 or SRT16, this should be changed -// VecInit((-2 to 2).map { -// case -2 => divider << 1 -// case -1 => divider -// case 0 => 0.U -// case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider -// case 2 => Fill(radixLog2, 1.U(1.W)) ## (~divider << 1) -// }) -// ) - - partialReminderSum := Mux1H( - Map( - isLastCycle -> input.bits.dividend, - (counter > 1.U) -> (csa.out(1) << radixLog2)(xLen - 1, 0), - (counter === 1.U) -> csa.out(1)(xLen - 1, 0) - ) - ) - partialReminderCarry := Mux1H( - Map( - isLastCycle -> 0.U, - (counter > 1.U) -> (csa.out(0) << radixLog2 + 1)(xLen - 1, 0), - (counter === 1.U) -> (csa.out(0) << 1)(xLen - 1, 0) - ) - ) + dividerNext := Mux(input.fire, input.bits.divider, divider) + counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) + + quotientNext := Mux(input.fire, 0.U, otf.output.quotient) + quotientMinusOneNext := Mux(input.fire, 0.U, otf.output.quotientMinusOne) + + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa.out(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa.out(0) << 1 + radixLog2) } diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index abd87b0..986df44 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -11,12 +11,13 @@ object SRT4Test extends TestSuite with ChiselUtestTester{ val dividendWidth: Int = 7 val dividerWidth: Int = 7 val n: Int = 10 - val dividend: Int = 15 << 3 - val divider: Int = 3 << 5 + val dividend: Int = scala.util.Random.nextInt(scala.math.pow(2, n).toInt) + val divider: Int = scala.util.Random.nextInt(scala.math.pow(2, n).toInt) val counter: Int = 2 val quotient: Int = dividend / divider val remainder: Int = dividend % divider // test + //println(chisel3.stage.ChiselStage.emitVerilog(new SRT(dividendWidth, dividerWidth, n))) testCircuit(new SRT(dividendWidth, dividerWidth, n), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)){ @@ -39,6 +40,7 @@ object SRT4Test extends TestSuite with ChiselUtestTester{ dut.clock.step() } utest.assert(flag) + dut.clock.step(scala.util.Random.nextInt(5)) } } } From f0fc979a820e81231eeb1140c6b0275fee64b79c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Mon, 23 May 2022 18:57:00 +0800 Subject: [PATCH 23/31] srt4 debug --- arithmetic/src/division/srt/QDS.scala | 4 +- arithmetic/src/division/srt/SRT.scala | 12 +-- arithmetic/src/division/srt/SZ.scala | 45 ----------- .../tests/src/division/srt/SRT4Test.scala | 74 +++++++++++++------ 4 files changed, 59 insertions(+), 76 deletions(-) delete mode 100644 arithmetic/src/division/srt/SZ.scala diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index 757f113..fd7230b 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -47,14 +47,14 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { VecInit("b111_0010".U, "b111_1100".U, "b000_0110".U, "b000_1111".U), VecInit("b111_0001".U, "b111_1100".U, "b000_0110".U, "b001_0000".U), VecInit("b111_0000".U, "b111_1100".U, "b000_0110".U, "b001_0010".U), - VecInit("b110_1101".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), + VecInit("b110_1110".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), VecInit("b110_1100".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), VecInit("b110_1100".U, "b111_1000".U, "b000_1000".U, "b001_0110".U), VecInit("b110_1000".U, "b111_1000".U, "b000_1000".U, "b001_1000".U) ) val mkVec = selectRom(columnSelect) - val adderWidth = rWidth + 1 + val adderWidth = rWidth + 2 val selectPoints = VecInit(mkVec.map { mk => // extend signed to avoid overflow. only for srt4, because -44/16 < y^ < 42/16. (extend(input.partialReminderCarry, adderWidth).asUInt diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index ef04f80..a0e7310 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -2,11 +2,8 @@ package division.srt import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 -import utils.extend import chisel3._ -import chisel3.util.{log2Ceil, Counter, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} - -import scala.math.ceil +import chisel3.util.{log2Ceil, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} /** SRT4 * 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2 @@ -18,7 +15,7 @@ import scala.math.ceil */ class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { - val dividend = UInt(dividendWidth.W) //.1********** + val dividend = UInt(dividendWidth.W) //.*********** val divider = UInt(dividerWidth.W) //.1********** val counter = UInt(log2Ceil(n).W) //the width of quotient. } @@ -59,7 +56,6 @@ class SRT( val qdsSign: Bool = Wire(Bool()) // sign of Cycle, true -> (counter === 0.U) val isLastCycle: Bool = Wire(Bool()) - // State // because we need a CSA to minimize the critical path val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), input.fire || !isLastCycle) @@ -91,7 +87,7 @@ class SRT( qds.partialDivider.valid := input.fire qds.partialDivider.bits := input.bits.divider .head(dTruncateWidth)(dTruncateWidth - 1, 0) //.1********** -> .1*** -> *** - qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR + qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR // for SRT4 -> CSA32 // for SRT8 -> CSA32+CSA32 @@ -109,7 +105,7 @@ class SRT( case -1 => divider case 0 => 0.U case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider - case 2 => Fill(radixLog2, 1.U(1.W)) ## (~divider << 1) + case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) }) ) diff --git a/arithmetic/src/division/srt/SZ.scala b/arithmetic/src/division/srt/SZ.scala deleted file mode 100644 index 6afb437..0000000 --- a/arithmetic/src/division/srt/SZ.scala +++ /dev/null @@ -1,45 +0,0 @@ -package division.srt - -import chisel3._ -import addition.prefixadder._ -import addition.prefixadder.common.{BrentKungSum} - -class SZInput(rWidth: Int) extends Bundle { - val partialReminderCarry: UInt = UInt(rWidth.W) - val partialReminderSum: UInt = UInt(rWidth.W) -} - -class SZOutput(rWidth: Int) extends Bundle { - val sign: Bool = Bool() - val zero: Bool = Bool() - val remainder: UInt = UInt((rWidth).W) -} - -class SZ(rWidth: Int, prefixSum: PrefixSum = BrentKungSum) extends Module { - val input = IO(Input(new SZInput(rWidth))) - val output = IO(Output(new SZOutput(rWidth))) - //controlpath - - //datapath - // csa(ws,wc,-2^-b) => Seq[(Bool,Bool)] - // drop signed bits - // prefixtree by group - val ws = input.partialReminderCarry.asBools - val wc = input.partialReminderSum.asBools - val psc: Seq[(Bool, Bool)] = ws.zip(wc).map { case (s, c) => (!(s ^ c), (s | c)) } - - // call the prefixtree to associativeOp and compute last remainder - val pairs: Seq[(Bool, Bool)] = prefixSum.zeroLayer(psc.map(_._1) :+ false.B, false.B +: psc.map(_._2)) - val pgs: Vector[(Bool, Bool)] = prefixSum(pairs) - val ps: Vector[Bool] = pgs.map(_._1) - val gs: Vector[Bool] = pgs.map(_._2) - - val a: Vector[Bool] = false.B +: gs - val b: Seq[Bool] = pairs.map(_._1) :+ false.B - val sum: Seq[Bool] = a.zip(b).map { case (p, c) => p ^ c } - - // maybe have a problem. - output.zero := VecInit(ps).asUInt.head(1) - output.sign := (pairs(pairs.length - 1)._1 ^ gs(gs.length - 2)) & (!output.zero) - output.remainder := VecInit(sum).asUInt -} diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index 986df44..76c7f11 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -4,43 +4,75 @@ import chisel3._ import chisel3.tester.{ChiselUtestTester, testableClock, testableData} import utest._ -object SRT4Test extends TestSuite with ChiselUtestTester{ +object SRT4Test extends TestSuite with ChiselUtestTester { def tests: Tests = Tests { test("SRT4 should pass") { // parameters - val dividendWidth: Int = 7 - val dividerWidth: Int = 7 - val n: Int = 10 - val dividend: Int = scala.util.Random.nextInt(scala.math.pow(2, n).toInt) - val divider: Int = scala.util.Random.nextInt(scala.math.pow(2, n).toInt) - val counter: Int = 2 + val n: Int = 16 + val m: Int = n - 1 +// val dividend: Int = scala.util.Random.nextInt(scala.math.pow(2,n -2 ).toInt) +// val divider: Int = scala.util.Random.nextInt(scala.math.pow(2, n - 8).toInt) + val dividend: Int = 65 + val divider: Int = 1 + + def zeroCheck(x: Int): Int = { + var flag = false + var a: Int = m + while (!flag && (a >= -1)) { + flag = ((1 << a) & x) != 0 + a = a - 1 + } + a + 1 + } + val zeroHeadDividend: Int = m - zeroCheck(dividend) + val zeroHeadDivider: Int = m - zeroCheck(divider) + val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + 1 + val noguard: Boolean = needComputerWidth % 2 == 0 + + val counter: Int = (needComputerWidth + 1) / 2 val quotient: Int = dividend / divider - val remainder: Int = dividend % divider + val remainder: Int = dividend % divider + val leftShiftWidthDividend: Int = zeroHeadDividend - (if (noguard) 0 else 1) + val leftShiftWidthDivider: Int = zeroHeadDivider + + println("dividend = %8x, dividend = %d ".format(dividend, dividend)) + println("divider = %8x, divider = %d".format(divider, divider)) + println("zeroHeadDividend = %d, dividend << zeroHeadDividend = %d".format(zeroHeadDividend, dividend << leftShiftWidthDividend)) + println("zeroHeadDivider = %d, divider << zeroHeadDivider = %d".format(zeroHeadDivider, divider << leftShiftWidthDivider)) + println("quotient = %d, remainder = %d".format(quotient, remainder)) + println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) // test //println(chisel3.stage.ChiselStage.emitVerilog(new SRT(dividendWidth, dividerWidth, n))) - testCircuit(new SRT(dividendWidth, dividerWidth, n), + testCircuit(new SRT(n, n, n), Seq(chiseltest.internal.NoThreadingAnnotation, - chiseltest.simulator.WriteVcdAnnotation)){ - dut: SRT => + chiseltest.simulator.WriteVcdAnnotation)) { + dut: SRT => dut.clock.setTimeout(0) dut.input.valid.poke(true.B) - dut.input.bits.dividend.poke("b1111000".U) - dut.input.bits.divider.poke( "b1100000".U) + dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) + dut.input.bits.divider.poke((divider << leftShiftWidthDivider).U) dut.input.bits.counter.poke(counter.U) +// dut.input.bits.dividend.poke("b11_1111_1111".U) +// dut.input.bits.divider.poke( "b10_0000_0000".U) +// dut.input.bits.counter.poke(6.U) dut.clock.step() dut.input.valid.poke(false.B) var flag = false - for(a <- 1 to 20 if !flag) { - if(dut.output.valid.peek().litValue == 1) { + for (a <- 1 to 1000 if !flag) { + if (dut.output.valid.peek().litValue == 1) { flag = true - dut.output.bits.quotient.expect(5.U) - dut.output.bits.reminder.expect(0.U) + println(dut.output.bits.quotient.peek().litValue) + println(dut.output.bits.reminder.peek().litValue) +// println(dut.qds.partialDivider.peek().litValue) + utest.assert(dut.output.bits.quotient.peek().litValue == 61) + utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) +// utest.assert(dut.output.bits.quotient.peek().litValue == 31) +// utest.assert(dut.output.bits.reminder.peek().litValue == 0) } - else - dut.clock.step() + dut.clock.step() } - utest.assert(flag) - dut.clock.step(scala.util.Random.nextInt(5)) + utest.assert(flag) + dut.clock.step(scala.util.Random.nextInt(10)) } } } From ddadec6c9c5a22bf9d612e6be873b363e1371fcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Thu, 2 Jun 2022 23:14:42 +0800 Subject: [PATCH 24/31] srt4test fixed --- arithmetic/src/division/srt/QDS.scala | 20 ++- arithmetic/src/division/srt/SRT.scala | 28 ++-- .../tests/src/division/srt/SRT4Test.scala | 128 +++++++++--------- 3 files changed, 87 insertions(+), 89 deletions(-) diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index fd7230b..ded08d7 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -30,7 +30,7 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { // Datapath - // from XiangShan/P269 in : /16, should have got from SRTTable. + // from P269 in : /16, should have got from SRTTable. // val qSelTable = Array( // Array(12, 4, -4, -13), // Array(14, 4, -6, -15), @@ -53,12 +53,11 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { VecInit("b110_1000".U, "b111_1000".U, "b000_1000".U, "b001_1000".U) ) + val adderWidth = rWidth + 1 + val yTruncate: UInt = input.partialReminderCarry + input.partialReminderSum val mkVec = selectRom(columnSelect) - val adderWidth = rWidth + 2 val selectPoints = VecInit(mkVec.map { mk => - // extend signed to avoid overflow. only for srt4, because -44/16 < y^ < 42/16. - (extend(input.partialReminderCarry, adderWidth).asUInt - + extend(input.partialReminderSum, adderWidth).asUInt + (extend(yTruncate, adderWidth).asUInt + extend(mk, adderWidth).asUInt).head(1) }).asUInt @@ -67,13 +66,12 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { selectPoints, TruthTable( Seq( - BitPat("b1???") -> BitPat("b00001"), //-2 - BitPat("b01??") -> BitPat("b00010"), //-1 - BitPat("b001?") -> BitPat("b00100"), //0 - BitPat("b0001") -> BitPat("b01000") //1 + BitPat("b???0") -> BitPat("b10000"), //2 + BitPat("b??01") -> BitPat("b01000"), //1 + BitPat("b?011") -> BitPat("b00100"), //0 + BitPat("b0111") -> BitPat("b00010") //-1 ), - BitPat("b10000") //2 + BitPat("b00001") //-2 ) ) - } diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index a0e7310..59fb2fe 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -15,7 +15,7 @@ import chisel3.util.{log2Ceil, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} */ class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { - val dividend = UInt(dividendWidth.W) //.*********** + val dividend = UInt(dividendWidth.W) //000.***********00 val divider = UInt(dividerWidth.W) //.1********** val counter = UInt(log2Ceil(n).W) //the width of quotient. } @@ -47,23 +47,24 @@ class SRT( val partialReminderCarryNext = Wire(UInt(wLen.W)) val partialReminderSumNext = Wire(UInt(wLen.W)) val dividerNext = Wire(UInt(dividerWidth.W)) - val counterNext = Wire(UInt(n.W)) + val counterNext = Wire(UInt(log2Ceil(n).W)) val quotientNext = Wire(UInt(n.W)) - val quotientMinusOneNext = Wire(UInt(log2Ceil(n).W)) + val quotientMinusOneNext = Wire(UInt(n.W)) // Control // sign of select quotient, true -> negative, false -> positive val qdsSign: Bool = Wire(Bool()) // sign of Cycle, true -> (counter === 0.U) val isLastCycle: Bool = Wire(Bool()) + val enable: Bool = input.fire || !isLastCycle // State // because we need a CSA to minimize the critical path - val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), input.fire || !isLastCycle) - val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), input.fire || !isLastCycle) - val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), input.fire || !isLastCycle) - val quotient = RegEnable(quotientNext, 0.U(n.W), input.fire || !isLastCycle) - val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), input.fire || !isLastCycle) - val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), input.fire || !isLastCycle) + val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable) + val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable) + val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable) + val quotient = RegEnable(quotientNext, 0.U(n.W), enable) + val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable) + val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), enable) // Datapath // according two adders @@ -71,12 +72,11 @@ class SRT( output.valid := isLastCycle input.ready := isLastCycle - // only mux is in last Cycle, adder is in every Cycle - val remainderNoCorrect: UInt = partialReminderSum(wLen - 3, radixLog2) + partialReminderCarry(wLen - 3, radixLog2) + val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry val remainderCorrect: UInt = - partialReminderSum(wLen - 3, radixLog2) + partialReminderCarry(wLen - 3, radixLog2) + divider - val needCorrect: Bool = remainderNoCorrect.head(1).asBool - output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect) + partialReminderSum + partialReminderCarry + (divider << 2) + val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 3, radixLog2) output.bits.quotient := quotient - needCorrect.asUInt // qds diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index 76c7f11..5379f1c 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -3,77 +3,77 @@ package division.srt import chisel3._ import chisel3.tester.{ChiselUtestTester, testableClock, testableData} import utest._ +import scala.util.{Random} object SRT4Test extends TestSuite with ChiselUtestTester { def tests: Tests = Tests { test("SRT4 should pass") { - // parameters - val n: Int = 16 - val m: Int = n - 1 -// val dividend: Int = scala.util.Random.nextInt(scala.math.pow(2,n -2 ).toInt) -// val divider: Int = scala.util.Random.nextInt(scala.math.pow(2, n - 8).toInt) - val dividend: Int = 65 - val divider: Int = 1 - - def zeroCheck(x: Int): Int = { - var flag = false - var a: Int = m - while (!flag && (a >= -1)) { - flag = ((1 << a) & x) != 0 - a = a - 1 - } - a + 1 - } - val zeroHeadDividend: Int = m - zeroCheck(dividend) - val zeroHeadDivider: Int = m - zeroCheck(divider) - val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + 1 - val noguard: Boolean = needComputerWidth % 2 == 0 - - val counter: Int = (needComputerWidth + 1) / 2 - val quotient: Int = dividend / divider - val remainder: Int = dividend % divider - val leftShiftWidthDividend: Int = zeroHeadDividend - (if (noguard) 0 else 1) - val leftShiftWidthDivider: Int = zeroHeadDivider - - println("dividend = %8x, dividend = %d ".format(dividend, dividend)) - println("divider = %8x, divider = %d".format(divider, divider)) - println("zeroHeadDividend = %d, dividend << zeroHeadDividend = %d".format(zeroHeadDividend, dividend << leftShiftWidthDividend)) - println("zeroHeadDivider = %d, divider << zeroHeadDivider = %d".format(zeroHeadDivider, divider << leftShiftWidthDivider)) - println("quotient = %d, remainder = %d".format(quotient, remainder)) - println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) - // test - //println(chisel3.stage.ChiselStage.emitVerilog(new SRT(dividendWidth, dividerWidth, n))) - testCircuit(new SRT(n, n, n), - Seq(chiseltest.internal.NoThreadingAnnotation, - chiseltest.simulator.WriteVcdAnnotation)) { - dut: SRT => - dut.clock.setTimeout(0) - dut.input.valid.poke(true.B) - dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) - dut.input.bits.divider.poke((divider << leftShiftWidthDivider).U) - dut.input.bits.counter.poke(counter.U) -// dut.input.bits.dividend.poke("b11_1111_1111".U) -// dut.input.bits.divider.poke( "b10_0000_0000".U) -// dut.input.bits.counter.poke(6.U) - dut.clock.step() - dut.input.valid.poke(false.B) + def testcase: Unit ={ + // parameters + val n: Int = 64 + val m: Int = n - 1 + val p: Int = Random.nextInt(m) + val q: Int = Random.nextInt(m) + val dividend: BigInt = BigInt(p, Random) + val divider: BigInt = BigInt(q, Random) + def zeroCheck(x: BigInt): Int = { var flag = false - for (a <- 1 to 1000 if !flag) { - if (dut.output.valid.peek().litValue == 1) { - flag = true - println(dut.output.bits.quotient.peek().litValue) - println(dut.output.bits.reminder.peek().litValue) -// println(dut.qds.partialDivider.peek().litValue) - utest.assert(dut.output.bits.quotient.peek().litValue == 61) - utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) -// utest.assert(dut.output.bits.quotient.peek().litValue == 31) -// utest.assert(dut.output.bits.reminder.peek().litValue == 0) - } - dut.clock.step() + var a: Int = m + while (!flag && (a >= -1)) { + flag = ((BigInt(1) << a) & x) != 0 + a = a - 1 } - utest.assert(flag) - dut.clock.step(scala.util.Random.nextInt(10)) + a + 1 + } + val zeroHeadDividend: Int = m - zeroCheck(dividend) + val zeroHeadDivider: Int = m - zeroCheck(divider) + val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + 1 + val noguard: Boolean = needComputerWidth % 2 == 0 + val counter: Int = (needComputerWidth + 1) / 2 + if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0)) + return + val quotient: BigInt = dividend / divider + val remainder: BigInt = dividend % divider + val leftShiftWidthDividend: Int = zeroHeadDividend - (if (noguard) 0 else 1) + val leftShiftWidthDivider: Int = zeroHeadDivider +// println("dividend = %8x, dividend = %d ".format(dividend, dividend)) +// println("divider = %8x, divider = %d".format(divider, divider)) +// println("zeroHeadDividend = %d, dividend << zeroHeadDividend = %d".format(zeroHeadDividend, dividend << leftShiftWidthDividend)) +// println("zeroHeadDivider = %d, divider << zeroHeadDivider = %d".format(zeroHeadDivider, divider << leftShiftWidthDivider)) +// println("quotient = %d, remainder = %d".format(quotient, remainder)) +// println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) + // test + testCircuit(new SRT(n, n, n), + Seq(chiseltest.internal.NoThreadingAnnotation, + chiseltest.simulator.WriteVcdAnnotation)) { + dut: SRT => + dut.clock.setTimeout(0) + dut.input.valid.poke(true.B) + dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) + dut.input.bits.divider.poke((divider << leftShiftWidthDivider).U) + dut.input.bits.counter.poke(counter.U) + dut.clock.step() + dut.input.valid.poke(false.B) + var flag = false + for (a <- 1 to 1000 if !flag) { + if (dut.output.valid.peek().litValue == 1) { + flag = true + println(dut.output.bits.quotient.peek().litValue) + println(dut.output.bits.reminder.peek().litValue) + utest.assert(dut.output.bits.quotient.peek().litValue == quotient) + utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) + } + dut.clock.step() + } + utest.assert(flag) + dut.clock.step(scala.util.Random.nextInt(10)) + } } + + testcase +// for( i <- 1 to 1000){ +// testcase +// } } } } \ No newline at end of file From 9aa7dc50c5918fe44150419c8fdfa4364094af05 Mon Sep 17 00:00:00 2001 From: GH Cheng <1536771081@qq.com> Date: Fri, 3 Jun 2022 18:44:29 +0800 Subject: [PATCH 25/31] srt4 fix & add srt16 --- .gitignore | 3 +- arithmetic/src/division/srt/OTF.scala | 12 ++ arithmetic/src/division/srt/QDS.scala | 29 ++--- arithmetic/src/division/srt/SRT16.scala | 123 ++++++++++++++++++ .../division/srt/{SRT.scala => SRT4.scala} | 50 ++++--- arithmetic/src/utils/package.scala | 6 + .../tests/src/division/srt/SRT16Test.scala | 83 ++++++++++++ .../tests/src/division/srt/SRT4Test.scala | 4 +- 8 files changed, 266 insertions(+), 44 deletions(-) create mode 100644 arithmetic/src/division/srt/SRT16.scala rename arithmetic/src/division/srt/{SRT.scala => SRT4.scala} (75%) create mode 100644 arithmetic/tests/src/division/srt/SRT16Test.scala diff --git a/.gitignore b/.gitignore index c0b4ce2..beffdbc 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,5 @@ verdiLog *.out *.cmd *.log -*.json \ No newline at end of file +*.json +*.iml \ No newline at end of file diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala index 98cd94e..bc762ee 100644 --- a/arithmetic/src/division/srt/OTF.scala +++ b/arithmetic/src/division/srt/OTF.scala @@ -42,3 +42,15 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - 2, 0) ## qIn output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - 2, 0) ## qmIn } + +object OTF { + def apply(radix: Int, qWidth: Int, ohWidth: Int)(quotient: UInt, quotientMinusOne: UInt, selectedQuotientOH: UInt): Vec[UInt] = { + val m = new OTF(radix, qWidth, ohWidth) + m.input.quotient := quotient + m.input.quotientMinusOne := quotientMinusOne + m.input.selectedQuotientOH := selectedQuotientOH + val out = VecInit(m.output.quotient, m.output.quotientMinusOne) + out + } +} + diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index ded08d7..12eb04b 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -4,9 +4,10 @@ import chisel3.util.{BitPat, RegEnable, Valid} import chisel3.util.experimental.decode._ import utils.extend -class QDSInput(rWidth: Int) extends Bundle { +class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { val partialReminderCarry: UInt = UInt(rWidth.W) val partialReminderSum: UInt = UInt(rWidth.W) + val partialDivider: UInt = UInt(partialDividerWidth.W) } class QDSOutput(ohWidth: Int) extends Bundle { @@ -15,20 +16,8 @@ class QDSOutput(ohWidth: Int) extends Bundle { class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { // IO - val input = IO(Input(new QDSInput(rWidth))) + val input = IO(Input(new QDSInput(rWidth, partialDividerWidth))) val output = IO(Output(new QDSOutput(ohWidth))) - val partialDivider = IO(Flipped(Valid(UInt(partialDividerWidth.W)))) - - // State, in order to keep divider's value - val partialDividerReg = RegEnable(partialDivider.bits, partialDivider.valid) - // for the first cycle: use partialDivider on the IO - // for the reset of cycles: use partialDividerReg - // for synthesis: the constraint should be IO -> Output is a multi-cycle design - // Reg -> Output is single-cycle - // to avoid glitch, valid should be larger than raise time of partialDividerReg - val partialDividerLatch = Mux(partialDivider.valid, partialDivider.bits, partialDividerReg) - - // Datapath // from P269 in : /16, should have got from SRTTable. // val qSelTable = Array( @@ -41,7 +30,7 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { // Array(20, 8, -8, -22), // Array(24, 8, -8, -24)/16 // ) - val columnSelect = partialDividerLatch + val columnSelect = input.partialDivider val selectRom: Vec[Vec[UInt]] = VecInit( VecInit("b111_0100".U, "b111_1100".U, "b000_0100".U, "b000_1101".U), VecInit("b111_0010".U, "b111_1100".U, "b000_0110".U, "b000_1111".U), @@ -75,3 +64,13 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { ) ) } + +object QDS{ + def apply(rWidth: Int, ohWidth: Int, partialDividerWidth: Int)(partialReminderSum: UInt, partialReminderCarry: UInt, partialDivider: UInt): UInt = { + val m = new QDS(rWidth, ohWidth, partialDividerWidth) + m.input.partialReminderSum := partialReminderSum + m.input.partialReminderCarry := partialReminderCarry + m.input.partialDivider := partialDivider + m.output.selectedQuotientOH + } +} \ No newline at end of file diff --git a/arithmetic/src/division/srt/SRT16.scala b/arithmetic/src/division/srt/SRT16.scala new file mode 100644 index 0000000..7a943d2 --- /dev/null +++ b/arithmetic/src/division/srt/SRT16.scala @@ -0,0 +1,123 @@ +package division.srt + +import chisel3._ +import chisel3.util.{log2Ceil, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} +import utils.staticLeftShift + +/** RSRT16 with Two SRT4 Overlapped Stages + */ + +class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { + val dividend = UInt(dividendWidth.W) //.*********** + val divider = UInt(dividerWidth.W) //.1********** + val counter = UInt(log2Ceil(n).W) //the width of quotient. +} + +class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { + val reminder = UInt(reminderWidth.W) + val quotient = UInt(quotientWidth.W) +} + +// only SRT4 currently +class SRT16( + dividendWidth: Int, + dividerWidth: Int, + n: Int, // the longest width + radixLog2: Int = 2, + a: Int = 2, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4) + extends Module { + + val xLen: Int = dividendWidth + radixLog2 + 1 + val wLen: Int = xLen + radixLog2 + val ohWidth: Int = 2 * a + 1 + + // IO + val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) + val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) + + val partialReminderCarryNext = Wire(UInt(wLen.W)) + val partialReminderSumNext = Wire(UInt(wLen.W)) + val dividerNext = Wire(UInt(dividerWidth.W)) + val counterNext = Wire(UInt(log2Ceil(n).W)) + val quotientNext = Wire(UInt(n.W)) + val quotientMinusOneNext = Wire(UInt(n.W)) + + val ws1, wc1, ws2, wc2: UInt = Wire(UInt(wLen.W)) + + // Control + // sign of select quotient, true -> negative, false -> positive + val qdsSign, isLastCycle, enable: Bool = Wire(Bool()) + // State + // because we need a CSA to minimize the critical path + val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable) + val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable) + val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable) + val quotient = RegEnable(quotientNext, 0.U(n.W), enable) + val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable) + val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), enable) + + // Datapath + // according two adders + isLastCycle := !counter.orR + output.valid := isLastCycle + input.ready := isLastCycle + enable := input.fire || !isLastCycle + + val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry + val remainderCorrect: UInt = + partialReminderSum + partialReminderCarry + (divider << 2) + val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 3, radixLog2) + output.bits.quotient := quotient - needCorrect.asUInt + + // qds + val rWidth: Int = 1 + radixLog2 + rTruncateWidth + val qds1SelectedQuotientOH = QDS(rWidth, ohWidth, dTruncateWidth - 1)( + staticLeftShift(partialReminderSum, radixLog2).head(rWidth), + staticLeftShift(partialReminderCarry, radixLog2).head(rWidth), + dividerNext.head(dTruncateWidth)(dTruncateWidth - 1, 0)) + + val qds2SelectedQuotientOH = QDS(rWidth, ohWidth, dTruncateWidth - 1)( + staticLeftShift(ws1, radixLog2).head(rWidth), + staticLeftShift(wc1, radixLog2).head(rWidth), + dividerNext.head(dTruncateWidth)(dTruncateWidth - 1, 0)) + qdsSign := qds2SelectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR + + // for SRT16 + val dividerMap = VecInit((-2 to 2).map { + case -2 => divider << 1 + case -1 => divider + case 0 => 0.U + case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider + case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) + }) + val csa1In = VecInit( staticLeftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + staticLeftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 -1) ## qdsSign, + Mux1H(qds1SelectedQuotientOH, dividerMap) ) + val csa1 = addition.csa.c32(csa1In) + + val csa2In = VecInit( staticLeftShift(ws1, radixLog2).head(wLen - radixLog2), + staticLeftShift(ws2, radixLog2).head(wLen - radixLog2 -1) ## qdsSign, + Mux1H(qds1SelectedQuotientOH, dividerMap) ) + val csa2 = addition.csa.c32(csa1In) + + ws1 := csa1(1) << radixLog2 + wc1 := csa1(0) << radixLog2 + 1 + ws2 := csa2(1) << radixLog2 + wc2 := csa2(0) << radixLog2 +1 + + // On-The-Fly conversion + val otf1 = OTF(1 << radixLog2, n, ohWidth)(quotient, quotientMinusOne, qds1SelectedQuotientOH) + val otf2 = OTF(1 << radixLog2, n, ohWidth)(otf1(0), otf1(1), qds2SelectedQuotientOH) + + dividerNext := Mux(input.fire, input.bits.divider, divider) + counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) + + quotientNext := Mux(input.fire, 0.U, otf2(0)) + quotientMinusOneNext := Mux(input.fire, 0.U, otf2(1)) + + partialReminderSumNext := Mux(input.fire, input.bits.dividend, ws2) + partialReminderCarryNext := Mux(input.fire, 0.U, wc2) +} diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT4.scala similarity index 75% rename from arithmetic/src/division/srt/SRT.scala rename to arithmetic/src/division/srt/SRT4.scala index 59fb2fe..64786ce 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT4.scala @@ -4,6 +4,7 @@ import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 import chisel3._ import chisel3.util.{log2Ceil, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} +import utils.staticLeftShift /** SRT4 * 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2 @@ -26,15 +27,15 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { } // only SRT4 currently -class SRT( - dividendWidth: Int, - dividerWidth: Int, - n: Int, // the longest width - radixLog2: Int = 2, - a: Int = 2, - dTruncateWidth: Int = 4, - rTruncateWidth: Int = 4) - extends Module { +class SRT4( + dividendWidth: Int, + dividerWidth: Int, + n: Int, // the longest width + radixLog2: Int = 2, + a: Int = 2, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4) + extends Module { val xLen: Int = dividendWidth + radixLog2 + 1 val wLen: Int = xLen + radixLog2 @@ -44,19 +45,16 @@ class SRT( val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) - val partialReminderCarryNext = Wire(UInt(wLen.W)) - val partialReminderSumNext = Wire(UInt(wLen.W)) + val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W)) + val quotientNext, quotientMinusOneNext = Wire(UInt(n.W)) val dividerNext = Wire(UInt(dividerWidth.W)) val counterNext = Wire(UInt(log2Ceil(n).W)) - val quotientNext = Wire(UInt(n.W)) - val quotientMinusOneNext = Wire(UInt(n.W)) - + // Control // sign of select quotient, true -> negative, false -> positive - val qdsSign: Bool = Wire(Bool()) // sign of Cycle, true -> (counter === 0.U) - val isLastCycle: Bool = Wire(Bool()) - val enable: Bool = input.fire || !isLastCycle + val qdsSign, isLastCycle, enable: Bool = Wire(Bool()) + // State // because we need a CSA to minimize the critical path val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable) @@ -68,9 +66,10 @@ class SRT( // Datapath // according two adders - isLastCycle := !counter.orR + isLastCycle := !counter.orR output.valid := isLastCycle - input.ready := isLastCycle + input.ready := isLastCycle + enable := input.fire || !isLastCycle val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry val remainderCorrect: UInt = @@ -78,15 +77,14 @@ class SRT( val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 3, radixLog2) output.bits.quotient := quotient - needCorrect.asUInt +// output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) // qds val rWidth: Int = 1 + radixLog2 + rTruncateWidth val qds = Module(new QDS(rWidth, ohWidth, dTruncateWidth - 1)) - qds.input.partialReminderSum := (partialReminderSum << radixLog2)(wLen - 1, wLen - rWidth) - qds.input.partialReminderCarry := (partialReminderCarry << radixLog2)(wLen - 1, wLen - rWidth) - qds.partialDivider.valid := input.fire - qds.partialDivider.bits := input.bits.divider - .head(dTruncateWidth)(dTruncateWidth - 1, 0) //.1********** -> .1*** -> *** + qds.input.partialReminderSum := staticLeftShift(partialReminderSum, radixLog2).head(rWidth) + qds.input.partialReminderCarry := staticLeftShift(partialReminderCarry, radixLog2).head(rWidth) + qds.input.partialDivider := dividerNext.head(dTruncateWidth)(dTruncateWidth - 1, 0) //.1********** -> .1*** -> *** qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR // for SRT4 -> CSA32 @@ -94,8 +92,8 @@ class SRT( // for SRT16 -> CSA53+CSA32 // SRT16 <- SRT4 + SRT4*5 val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen)) - csa.in(0) := (partialReminderSum << radixLog2)(wLen - 1, radixLog2) - csa.in(1) := (partialReminderCarry << radixLog2)(wLen - 1, radixLog2 + 1) ## qdsSign + csa.in(0) := staticLeftShift(partialReminderSum, radixLog2).head(wLen - radixLog2) + csa.in(1) := staticLeftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 -1) ## qdsSign csa.in(2) := Mux1H( qds.output.selectedQuotientOH, diff --git a/arithmetic/src/utils/package.scala b/arithmetic/src/utils/package.scala index 6b884a1..19ecef0 100644 --- a/arithmetic/src/utils/package.scala +++ b/arithmetic/src/utils/package.scala @@ -55,4 +55,10 @@ package object utils { else BitPat((x + (1 << w)).U(w.W)) } + + def staticLeftShift(x: UInt, n: Int): UInt={ + val length: Int = x.getWidth + (x << n)(length- 1 - n, 0) + } + } diff --git a/arithmetic/tests/src/division/srt/SRT16Test.scala b/arithmetic/tests/src/division/srt/SRT16Test.scala new file mode 100644 index 0000000..3684e3d --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRT16Test.scala @@ -0,0 +1,83 @@ +package division.srt + +import chisel3._ +import chisel3.tester.{ChiselUtestTester, testableClock, testableData} +import utest._ + +import scala.util.Random + +object SRT16Test extends TestSuite with ChiselUtestTester { + def tests: Tests = Tests { + test("SRT4 should pass") { + def testcase: Unit ={ + // parameters + val radixLog2: Int = 4 + val n: Int = 64 + // guard + val m: Int = n - radixLog2 - 1 + val p: Int = Random.nextInt(m) + val q: Int = Random.nextInt(m) + val dividend: BigInt = BigInt(p, Random) + val divider: BigInt = BigInt(q, Random) + def zeroCheck(x: BigInt): Int = { + var flag = false + var a: Int = m + while (!flag && (a >= -1)) { + flag = ((BigInt(1) << a) & x) != 0 + a = a - 1 + } + a + 1 + } + val zeroHeadDividend: Int = m - zeroCheck(dividend) + val zeroHeadDivider: Int = m - zeroCheck(divider) + val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + 1 + val noguard: Boolean = needComputerWidth % radixLog2 == 0 + val guardWidth: Int = if (noguard) 0 else 4 - needComputerWidth % 4 + val counter: Int = (needComputerWidth + guardWidth) / radixLog2 + if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0)) + return + val quotient: BigInt = dividend / divider + val remainder: BigInt = dividend % divider + val leftShiftWidthDividend: Int = zeroHeadDividend - guardWidth + val leftShiftWidthDivider: Int = zeroHeadDivider +// println("dividend = %8x, dividend = %d ".format(dividend, dividend)) +// println("divider = %8x, divider = %d".format(divider, divider)) +// println("zeroHeadDividend = %d, dividend << zeroHeadDividend = %d".format(zeroHeadDividend, dividend << leftShiftWidthDividend)) +// println("zeroHeadDivider = %d, divider << zeroHeadDivider = %d".format(zeroHeadDivider, divider << leftShiftWidthDivider)) +// println("quotient = %d, remainder = %d".format(quotient, remainder)) +// println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) + // test + testCircuit(new SRT16(n, n, n), + Seq(chiseltest.internal.NoThreadingAnnotation, + chiseltest.simulator.WriteVcdAnnotation)) { + dut: SRT16 => + dut.clock.setTimeout(0) + dut.input.valid.poke(true.B) + dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) + dut.input.bits.divider.poke((divider << leftShiftWidthDivider).U) + dut.input.bits.counter.poke(counter.U) + dut.clock.step() + dut.input.valid.poke(false.B) + var flag = false + for (a <- 1 to 1000 if !flag) { + if (dut.output.valid.peek().litValue == 1) { + flag = true + println(dut.output.bits.quotient.peek().litValue) + println(dut.output.bits.reminder.peek().litValue) + utest.assert(dut.output.bits.quotient.peek().litValue == quotient) + utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) + } + dut.clock.step() + } + utest.assert(flag) + dut.clock.step(scala.util.Random.nextInt(10)) + } + } + + testcase +// for( i <- 1 to 1000){ +// testcase +// } + } + } +} \ No newline at end of file diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index 5379f1c..4a7d5e9 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -43,10 +43,10 @@ object SRT4Test extends TestSuite with ChiselUtestTester { // println("quotient = %d, remainder = %d".format(quotient, remainder)) // println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) // test - testCircuit(new SRT(n, n, n), + testCircuit(new SRT4(n, n, n), Seq(chiseltest.internal.NoThreadingAnnotation, chiseltest.simulator.WriteVcdAnnotation)) { - dut: SRT => + dut: SRT4 => dut.clock.setTimeout(0) dut.input.valid.poke(true.B) dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) From bb63956912fcb71135c18fa1af5659bda908db7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Sat, 4 Jun 2022 16:24:53 +0800 Subject: [PATCH 26/31] fix srt4 & naive srt16 implement --- arithmetic/src/division/srt/OTF.scala | 18 ++-- arithmetic/src/division/srt/QDS.scala | 17 +++- arithmetic/src/division/srt/SRT16.scala | 93 ++++++++----------- arithmetic/src/division/srt/SRT4.scala | 69 +++++++------- arithmetic/src/utils/package.scala | 5 +- .../tests/src/division/srt/SRT16Test.scala | 27 ++---- .../tests/src/division/srt/SRT4Test.scala | 15 +-- 7 files changed, 115 insertions(+), 129 deletions(-) diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala index bc762ee..30ce242 100644 --- a/arithmetic/src/division/srt/OTF.scala +++ b/arithmetic/src/division/srt/OTF.scala @@ -17,7 +17,6 @@ class OTFOutput(qWidth: Int) extends Bundle { class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { val input = IO(Input(new OTFInput(qWidth, ohWidth))) val output = IO(Output(new OTFOutput(qWidth))) - // control // datapath // q_j+1 in this circle, only for srt4 @@ -35,17 +34,23 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { // val cShiftQM: Bool = qNext <= 0.U val cShiftQ: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR val cShiftQM: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR - - val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(1, 0) - val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(1, 0) + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(1, 0) + val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(1, 0) output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - 2, 0) ## qIn output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - 2, 0) ## qmIn } object OTF { - def apply(radix: Int, qWidth: Int, ohWidth: Int)(quotient: UInt, quotientMinusOne: UInt, selectedQuotientOH: UInt): Vec[UInt] = { - val m = new OTF(radix, qWidth, ohWidth) + def apply( + radix: Int, + qWidth: Int, + ohWidth: Int + )(quotient: UInt, + quotientMinusOne: UInt, + selectedQuotientOH: UInt + ): Vec[UInt] = { + val m = Module(new OTF(radix, qWidth, ohWidth)) m.input.quotient := quotient m.input.quotientMinusOne := quotientMinusOne m.input.selectedQuotientOH := selectedQuotientOH @@ -53,4 +58,3 @@ object OTF { out } } - diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index 12eb04b..32e5301 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -7,7 +7,7 @@ import utils.extend class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { val partialReminderCarry: UInt = UInt(rWidth.W) val partialReminderSum: UInt = UInt(rWidth.W) - val partialDivider: UInt = UInt(partialDividerWidth.W) + val partialDivider: UInt = UInt(partialDividerWidth.W) } class QDSOutput(ohWidth: Int) extends Bundle { @@ -65,12 +65,19 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { ) } -object QDS{ - def apply(rWidth: Int, ohWidth: Int, partialDividerWidth: Int)(partialReminderSum: UInt, partialReminderCarry: UInt, partialDivider: UInt): UInt = { - val m = new QDS(rWidth, ohWidth, partialDividerWidth) +object QDS { + def apply( + rWidth: Int, + ohWidth: Int, + partialDividerWidth: Int + )(partialReminderSum: UInt, + partialReminderCarry: UInt, + partialDivider: UInt + ): UInt = { + val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth)) m.input.partialReminderSum := partialReminderSum m.input.partialReminderCarry := partialReminderCarry m.input.partialDivider := partialDivider m.output.selectedQuotientOH } -} \ No newline at end of file +} diff --git a/arithmetic/src/division/srt/SRT16.scala b/arithmetic/src/division/srt/SRT16.scala index 7a943d2..448e9d9 100644 --- a/arithmetic/src/division/srt/SRT16.scala +++ b/arithmetic/src/division/srt/SRT16.scala @@ -2,23 +2,10 @@ package division.srt import chisel3._ import chisel3.util.{log2Ceil, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} -import utils.staticLeftShift +import utils.leftShift /** RSRT16 with Two SRT4 Overlapped Stages */ - -class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { - val dividend = UInt(dividendWidth.W) //.*********** - val divider = UInt(dividerWidth.W) //.1********** - val counter = UInt(log2Ceil(n).W) //the width of quotient. -} - -class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { - val reminder = UInt(reminderWidth.W) - val quotient = UInt(quotientWidth.W) -} - -// only SRT4 currently class SRT16( dividendWidth: Int, dividerWidth: Int, @@ -37,18 +24,14 @@ class SRT16( val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) - val partialReminderCarryNext = Wire(UInt(wLen.W)) - val partialReminderSumNext = Wire(UInt(wLen.W)) + val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W)) val dividerNext = Wire(UInt(dividerWidth.W)) val counterNext = Wire(UInt(log2Ceil(n).W)) - val quotientNext = Wire(UInt(n.W)) - val quotientMinusOneNext = Wire(UInt(n.W)) - + val quotientNext, quotientMinusOneNext = Wire(UInt(n.W)) val ws1, wc1, ws2, wc2: UInt = Wire(UInt(wLen.W)) // Control - // sign of select quotient, true -> negative, false -> positive - val qdsSign, isLastCycle, enable: Bool = Wire(Bool()) + val qds1Sign, qds2Sign, isLastCycle, enable: Bool = Wire(Bool()) // State // because we need a CSA to minimize the critical path val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable) @@ -59,54 +42,60 @@ class SRT16( val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), enable) // Datapath - // according two adders - isLastCycle := !counter.orR + isLastCycle := !counter.orR output.valid := isLastCycle - input.ready := isLastCycle - enable := input.fire || !isLastCycle + input.ready := isLastCycle + enable := input.fire || !isLastCycle val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry val remainderCorrect: UInt = partialReminderSum + partialReminderCarry + (divider << 2) val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool - output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 3, radixLog2) - output.bits.quotient := quotient - needCorrect.asUInt + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2) + output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) // qds - val rWidth: Int = 1 + radixLog2 + rTruncateWidth - val qds1SelectedQuotientOH = QDS(rWidth, ohWidth, dTruncateWidth - 1)( - staticLeftShift(partialReminderSum, radixLog2).head(rWidth), - staticLeftShift(partialReminderCarry, radixLog2).head(rWidth), - dividerNext.head(dTruncateWidth)(dTruncateWidth - 1, 0)) - - val qds2SelectedQuotientOH = QDS(rWidth, ohWidth, dTruncateWidth - 1)( - staticLeftShift(ws1, radixLog2).head(rWidth), - staticLeftShift(wc1, radixLog2).head(rWidth), - dividerNext.head(dTruncateWidth)(dTruncateWidth - 1, 0)) - qdsSign := qds2SelectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR - - // for SRT16 - val dividerMap = VecInit((-2 to 2).map { + val rWidth: Int = 1 + radixLog2 + rTruncateWidth + val partialDivider: UInt = dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) + val qds1SelectedQuotientOH: UInt = + QDS(rWidth, ohWidth, dTruncateWidth - 1)( + leftShift(partialReminderSum, radixLog2).head(rWidth), + leftShift(partialReminderCarry, radixLog2).head(rWidth), + partialDivider + ) + val qds2SelectedQuotientOH: UInt = + QDS(rWidth, ohWidth, dTruncateWidth - 1)( + leftShift(ws1, radixLog2).head(rWidth), + leftShift(wc1, radixLog2).head(rWidth), + partialDivider + ) + qds1Sign := qds1SelectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR + qds2Sign := qds2SelectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR + + // CSA32 -> CSA32 + val dividerMap = VecInit((-2 to 2).map { case -2 => divider << 1 case -1 => divider case 0 => 0.U case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) }) - val csa1In = VecInit( staticLeftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), - staticLeftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 -1) ## qdsSign, - Mux1H(qds1SelectedQuotientOH, dividerMap) ) + val csa1In = VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qds1Sign, + Mux1H(qds1SelectedQuotientOH, dividerMap) + ) + val csa2In = VecInit( + leftShift(ws1, radixLog2).head(wLen - radixLog2), + leftShift(wc1, radixLog2).head(wLen - radixLog2 - 1) ## qds2Sign, + Mux1H(qds2SelectedQuotientOH, dividerMap) + ) val csa1 = addition.csa.c32(csa1In) - - val csa2In = VecInit( staticLeftShift(ws1, radixLog2).head(wLen - radixLog2), - staticLeftShift(ws2, radixLog2).head(wLen - radixLog2 -1) ## qdsSign, - Mux1H(qds1SelectedQuotientOH, dividerMap) ) - val csa2 = addition.csa.c32(csa1In) - + val csa2 = addition.csa.c32(csa2In) ws1 := csa1(1) << radixLog2 wc1 := csa1(0) << radixLog2 + 1 ws2 := csa2(1) << radixLog2 - wc2 := csa2(0) << radixLog2 +1 + wc2 := csa2(0) << radixLog2 + 1 // On-The-Fly conversion val otf1 = OTF(1 << radixLog2, n, ohWidth)(quotient, quotientMinusOne, qds1SelectedQuotientOH) @@ -114,10 +103,8 @@ class SRT16( dividerNext := Mux(input.fire, input.bits.divider, divider) counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) - quotientNext := Mux(input.fire, 0.U, otf2(0)) quotientMinusOneNext := Mux(input.fire, 0.U, otf2(1)) - partialReminderSumNext := Mux(input.fire, input.bits.dividend, ws2) partialReminderCarryNext := Mux(input.fire, 0.U, wc2) } diff --git a/arithmetic/src/division/srt/SRT4.scala b/arithmetic/src/division/srt/SRT4.scala index 64786ce..d012511 100644 --- a/arithmetic/src/division/srt/SRT4.scala +++ b/arithmetic/src/division/srt/SRT4.scala @@ -4,7 +4,7 @@ import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 import chisel3._ import chisel3.util.{log2Ceil, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} -import utils.staticLeftShift +import utils.leftShift /** SRT4 * 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2 @@ -16,7 +16,7 @@ import utils.staticLeftShift */ class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { - val dividend = UInt(dividendWidth.W) //000.***********00 + val dividend = UInt(dividendWidth.W) //.*********** val divider = UInt(dividerWidth.W) //.1********** val counter = UInt(log2Ceil(n).W) //the width of quotient. } @@ -28,14 +28,14 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { // only SRT4 currently class SRT4( - dividendWidth: Int, - dividerWidth: Int, - n: Int, // the longest width - radixLog2: Int = 2, - a: Int = 2, - dTruncateWidth: Int = 4, - rTruncateWidth: Int = 4) - extends Module { + dividendWidth: Int, + dividerWidth: Int, + n: Int, // the longest width + radixLog2: Int = 2, + a: Int = 2, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4) + extends Module { val xLen: Int = dividendWidth + radixLog2 + 1 val wLen: Int = xLen + radixLog2 @@ -49,7 +49,7 @@ class SRT4( val quotientNext, quotientMinusOneNext = Wire(UInt(n.W)) val dividerNext = Wire(UInt(dividerWidth.W)) val counterNext = Wire(UInt(log2Ceil(n).W)) - + // Control // sign of select quotient, true -> negative, false -> positive // sign of Cycle, true -> (counter === 0.U) @@ -66,37 +66,35 @@ class SRT4( // Datapath // according two adders - isLastCycle := !counter.orR + isLastCycle := !counter.orR output.valid := isLastCycle - input.ready := isLastCycle - enable := input.fire || !isLastCycle + input.ready := isLastCycle + enable := input.fire || !isLastCycle val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry val remainderCorrect: UInt = partialReminderSum + partialReminderCarry + (divider << 2) val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool - output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 3, radixLog2) - output.bits.quotient := quotient - needCorrect.asUInt -// output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2) + output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) // qds val rWidth: Int = 1 + radixLog2 + rTruncateWidth - val qds = Module(new QDS(rWidth, ohWidth, dTruncateWidth - 1)) - qds.input.partialReminderSum := staticLeftShift(partialReminderSum, radixLog2).head(rWidth) - qds.input.partialReminderCarry := staticLeftShift(partialReminderCarry, radixLog2).head(rWidth) - qds.input.partialDivider := dividerNext.head(dTruncateWidth)(dTruncateWidth - 1, 0) //.1********** -> .1*** -> *** - qdsSign := qds.output.selectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR - - // for SRT4 -> CSA32 - // for SRT8 -> CSA32+CSA32 - // for SRT16 -> CSA53+CSA32 - // SRT16 <- SRT4 + SRT4*5 + val selectedQuotientOH: UInt = + QDS(rWidth, ohWidth, dTruncateWidth - 1)( + leftShift(partialReminderSum, radixLog2).head(rWidth), + leftShift(partialReminderCarry, radixLog2).head(rWidth), + dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> *** + ) + qdsSign := selectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR + + // csa for SRT4 -> CSA32, SRT8 -> CSA32+CSA32, SRT16 -> CSA53+CSA32, SRT16 <- SRT4 + SRT4*5 val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen)) - csa.in(0) := staticLeftShift(partialReminderSum, radixLog2).head(wLen - radixLog2) - csa.in(1) := staticLeftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 -1) ## qdsSign + csa.in(0) := leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2) + csa.in(1) := leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign csa.in(2) := Mux1H( - qds.output.selectedQuotientOH, + selectedQuotientOH, //this is for SRT4, for SRT8 or SRT16, this should be changed VecInit((-2 to 2).map { case -2 => divider << 1 @@ -108,17 +106,12 @@ class SRT4( ) // On-The-Fly conversion - val otf = Module(new OTF(1 << radixLog2, n, ohWidth)) - otf.input.quotient := quotient - otf.input.quotientMinusOne := quotientMinusOne - otf.input.selectedQuotientOH := qds.output.selectedQuotientOH + val otf = OTF(1 << radixLog2, n, ohWidth)(quotient, quotientMinusOne, selectedQuotientOH) dividerNext := Mux(input.fire, input.bits.divider, divider) counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) - - quotientNext := Mux(input.fire, 0.U, otf.output.quotient) - quotientMinusOneNext := Mux(input.fire, 0.U, otf.output.quotientMinusOne) - + quotientNext := Mux(input.fire, 0.U, otf(0)) + quotientMinusOneNext := Mux(input.fire, 0.U, otf(1)) partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa.out(1) << radixLog2) partialReminderCarryNext := Mux(input.fire, 0.U, csa.out(0) << 1 + radixLog2) } diff --git a/arithmetic/src/utils/package.scala b/arithmetic/src/utils/package.scala index 19ecef0..f2cf29f 100644 --- a/arithmetic/src/utils/package.scala +++ b/arithmetic/src/utils/package.scala @@ -56,9 +56,10 @@ package object utils { BitPat((x + (1 << w)).U(w.W)) } - def staticLeftShift(x: UInt, n: Int): UInt={ + // keep the width of UInt + def leftShift(x: UInt, n: Int): UInt = { val length: Int = x.getWidth - (x << n)(length- 1 - n, 0) + (x << n)(length - 1, 0) } } diff --git a/arithmetic/tests/src/division/srt/SRT16Test.scala b/arithmetic/tests/src/division/srt/SRT16Test.scala index 3684e3d..c9e3074 100644 --- a/arithmetic/tests/src/division/srt/SRT16Test.scala +++ b/arithmetic/tests/src/division/srt/SRT16Test.scala @@ -8,15 +8,14 @@ import scala.util.Random object SRT16Test extends TestSuite with ChiselUtestTester { def tests: Tests = Tests { - test("SRT4 should pass") { - def testcase: Unit ={ + test("SRT16 should pass") { + def testcase(width: Int): Unit ={ // parameters val radixLog2: Int = 4 - val n: Int = 64 - // guard - val m: Int = n - radixLog2 - 1 - val p: Int = Random.nextInt(m) - val q: Int = Random.nextInt(m) + val n: Int = width + val m: Int = n - 1 + val p: Int = Random.nextInt(m - radixLog2 +1) //order to offer guardwidth + val q: Int = Random.nextInt(m - radixLog2 +1) val dividend: BigInt = BigInt(p, Random) val divider: BigInt = BigInt(q, Random) def zeroCheck(x: BigInt): Int = { @@ -40,12 +39,6 @@ object SRT16Test extends TestSuite with ChiselUtestTester { val remainder: BigInt = dividend % divider val leftShiftWidthDividend: Int = zeroHeadDividend - guardWidth val leftShiftWidthDivider: Int = zeroHeadDivider -// println("dividend = %8x, dividend = %d ".format(dividend, dividend)) -// println("divider = %8x, divider = %d".format(divider, divider)) -// println("zeroHeadDividend = %d, dividend << zeroHeadDividend = %d".format(zeroHeadDividend, dividend << leftShiftWidthDividend)) -// println("zeroHeadDivider = %d, divider << zeroHeadDivider = %d".format(zeroHeadDivider, divider << leftShiftWidthDivider)) -// println("quotient = %d, remainder = %d".format(quotient, remainder)) -// println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) // test testCircuit(new SRT16(n, n, n), Seq(chiseltest.internal.NoThreadingAnnotation, @@ -73,10 +66,10 @@ object SRT16Test extends TestSuite with ChiselUtestTester { dut.clock.step(scala.util.Random.nextInt(10)) } } - - testcase -// for( i <- 1 to 1000){ -// testcase + + testcase(64) +// for( i <- 1 to 100){ +// testcase(128) // } } } diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index 4a7d5e9..9f447c9 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -8,9 +8,10 @@ import scala.util.{Random} object SRT4Test extends TestSuite with ChiselUtestTester { def tests: Tests = Tests { test("SRT4 should pass") { - def testcase: Unit ={ + def testcase(width: Int): Unit ={ // parameters - val n: Int = 64 + val radixLog2: Int = 2 + val n: Int = width val m: Int = n - 1 val p: Int = Random.nextInt(m) val q: Int = Random.nextInt(m) @@ -28,7 +29,7 @@ object SRT4Test extends TestSuite with ChiselUtestTester { val zeroHeadDividend: Int = m - zeroCheck(dividend) val zeroHeadDivider: Int = m - zeroCheck(divider) val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + 1 - val noguard: Boolean = needComputerWidth % 2 == 0 + val noguard: Boolean = needComputerWidth % radixLog2 == 0 val counter: Int = (needComputerWidth + 1) / 2 if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0)) return @@ -69,10 +70,10 @@ object SRT4Test extends TestSuite with ChiselUtestTester { dut.clock.step(scala.util.Random.nextInt(10)) } } - - testcase -// for( i <- 1 to 1000){ -// testcase + + testcase(64) +// for( i <- 1 to 100){ +// testcase(128) // } } } From 6dc8605a9692e2c2b3f20c1cbe1f19a6b79016f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Sun, 5 Jun 2022 19:46:09 +0800 Subject: [PATCH 27/31] fix SRTTable --- arithmetic/src/division/srt/OTF.scala | 16 ++++++++-------- arithmetic/src/division/srt/QDS.scala | 2 +- arithmetic/src/division/srt/SRT16.scala | 4 ++-- arithmetic/src/division/srt/SRT4.scala | 2 +- arithmetic/src/division/srt/SRTTable.scala | 4 ++-- arithmetic/src/utils/package.scala | 5 ++--- arithmetic/tests/src/division/srt/SRTSpec.scala | 8 ++++++-- 7 files changed, 22 insertions(+), 19 deletions(-) diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/OTF.scala index 30ce242..73ac49f 100644 --- a/arithmetic/src/division/srt/OTF.scala +++ b/arithmetic/src/division/srt/OTF.scala @@ -14,10 +14,11 @@ class OTFOutput(qWidth: Int) extends Bundle { val quotientMinusOne = UInt(qWidth.W) } -class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { +class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module { val input = IO(Input(new OTFInput(qWidth, ohWidth))) val output = IO(Output(new OTFOutput(qWidth))) + val radix: Int = 1 << radixLog2 // datapath // q_j+1 in this circle, only for srt4 val qNext: UInt = Mux1H( @@ -34,8 +35,8 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { // val cShiftQM: Bool = qNext <= 0.U val cShiftQ: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR val cShiftQM: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR - val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(1, 0) - val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(1, 0) + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0) + val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0) output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - 2, 0) ## qIn output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - 2, 0) ## qmIn @@ -43,18 +44,17 @@ class OTF(radix: Int, qWidth: Int, ohWidth: Int) extends Module { object OTF { def apply( - radix: Int, + radixLog2: Int, qWidth: Int, ohWidth: Int )(quotient: UInt, quotientMinusOne: UInt, selectedQuotientOH: UInt - ): Vec[UInt] = { - val m = Module(new OTF(radix, qWidth, ohWidth)) + ): Seq[UInt] = { + val m = Module(new OTF(radixLog2, qWidth, ohWidth)) m.input.quotient := quotient m.input.quotientMinusOne := quotientMinusOne m.input.selectedQuotientOH := selectedQuotientOH - val out = VecInit(m.output.quotient, m.output.quotientMinusOne) - out + Seq(m.output.quotient, m.output.quotientMinusOne) } } diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/QDS.scala index 32e5301..a86b875 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/QDS.scala @@ -1,6 +1,6 @@ package division.srt import chisel3._ -import chisel3.util.{BitPat, RegEnable, Valid} +import chisel3.util.{BitPat} import chisel3.util.experimental.decode._ import utils.extend diff --git a/arithmetic/src/division/srt/SRT16.scala b/arithmetic/src/division/srt/SRT16.scala index 448e9d9..6dcdcf3 100644 --- a/arithmetic/src/division/srt/SRT16.scala +++ b/arithmetic/src/division/srt/SRT16.scala @@ -98,8 +98,8 @@ class SRT16( wc2 := csa2(0) << radixLog2 + 1 // On-The-Fly conversion - val otf1 = OTF(1 << radixLog2, n, ohWidth)(quotient, quotientMinusOne, qds1SelectedQuotientOH) - val otf2 = OTF(1 << radixLog2, n, ohWidth)(otf1(0), otf1(1), qds2SelectedQuotientOH) + val otf1 = OTF(radixLog2, n, ohWidth)(quotient, quotientMinusOne, qds1SelectedQuotientOH) + val otf2 = OTF(radixLog2, n, ohWidth)(otf1(0), otf1(1), qds2SelectedQuotientOH) dividerNext := Mux(input.fire, input.bits.divider, divider) counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) diff --git a/arithmetic/src/division/srt/SRT4.scala b/arithmetic/src/division/srt/SRT4.scala index d012511..0abc6e4 100644 --- a/arithmetic/src/division/srt/SRT4.scala +++ b/arithmetic/src/division/srt/SRT4.scala @@ -106,7 +106,7 @@ class SRT4( ) // On-The-Fly conversion - val otf = OTF(1 << radixLog2, n, ohWidth)(quotient, quotientMinusOne, selectedQuotientOH) + val otf = OTF(radixLog2, n, ohWidth)(quotient, quotientMinusOne, selectedQuotientOH) dividerNext := Mux(input.fire, input.bits.divider, divider) counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) diff --git a/arithmetic/src/division/srt/SRTTable.scala b/arithmetic/src/division/srt/SRTTable.scala index 2c8fcbf..56dc20b 100644 --- a/arithmetic/src/division/srt/SRTTable.scala +++ b/arithmetic/src/division/srt/SRTTable.scala @@ -68,7 +68,7 @@ case class SRTTable( (aMin.toInt to aMax.toInt).drop(1).map { k => k -> dSet.dropRight(1).map { d => val (floor, ceil) = xRange(k, d, d + deltaD) - val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= ceil && x >= floor } + val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= (ceil - deltaX) && x >= floor } (d, m) } } @@ -98,7 +98,7 @@ case class SRTTable( private val xStep = (xMax - xMin) / deltaX // @note 5.7 require(a >= radix / 2) - private val xSet = Seq.tabulate((xStep + 1).toInt) { n => xMin + deltaX * n } + private val xSet = Seq.tabulate((xStep/2 + 1).toInt) { n => deltaX * n } ++ Seq.tabulate((xStep/2 + 1).toInt) { n => -deltaX * n } private val dStep: Algebraic = (dMax - dMin) / deltaD assert((rho > 1 / 2) && (rho <= 1)) diff --git a/arithmetic/src/utils/package.scala b/arithmetic/src/utils/package.scala index f2cf29f..2edf1ab 100644 --- a/arithmetic/src/utils/package.scala +++ b/arithmetic/src/utils/package.scala @@ -56,10 +56,9 @@ package object utils { BitPat((x + (1 << w)).U(w.W)) } - // keep the width of UInt - def leftShift(x: UInt, n: Int): UInt = { + // left shift and keep the width of Bits + def leftShift(x: Bits, n: Int): UInt = { val length: Int = x.getWidth (x << n)(length - 1, 0) } - } diff --git a/arithmetic/tests/src/division/srt/SRTSpec.scala b/arithmetic/tests/src/division/srt/SRTSpec.scala index 85d143d..4b01197 100644 --- a/arithmetic/tests/src/division/srt/SRTSpec.scala +++ b/arithmetic/tests/src/division/srt/SRTSpec.scala @@ -3,11 +3,15 @@ package division.srt import utest._ + object SRTSpec extends TestSuite{ override def tests: Tests = Tests { test("SRT should draw PD") { - val srt = SRTTable(4, 2, 5, 5) - srt.dumpGraph(srt.pd, os.root / "tmp" / "srt4-2-5-5.png") + val srt = SRTTable(4, 2, 4, 4) +// val table = srt.tables.flatMap { +// case (i, ps) => ps.flatMap{ case (d, xs) => xs.map(x => (d.toDouble, x.toDouble*16)) }}.groupBy(_._1) +// table.map{case (x, y) => println(y)} + srt.dumpGraph(srt.pd, os.root / "tmp" / "srt4-2-4-4.png") } } } From 4d045c03de3b62e1bb06697321d9148063743b3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Fri, 10 Jun 2022 19:33:55 +0800 Subject: [PATCH 28/31] srt8 fixed & get tables from SRTTable --- arithmetic/src/division/srt/SRT16.scala | 110 -------------- arithmetic/src/division/srt/SRTIO.scala | 15 ++ arithmetic/src/division/srt/SRTTable.scala | 26 +++- arithmetic/src/division/srt/srt16/OTF.scala | 60 ++++++++ arithmetic/src/division/srt/srt16/QDS.scala | 78 ++++++++++ arithmetic/src/division/srt/srt16/SRT16.scala | 137 ++++++++++++++++++ .../src/division/srt/{ => srt4}/OTF.scala | 8 +- .../src/division/srt/{ => srt4}/QDS.scala | 44 ++++-- .../src/division/srt/{ => srt4}/SRT4.scala | 24 +-- arithmetic/src/division/srt/srt8/OTF.scala | 71 +++++++++ arithmetic/src/division/srt/srt8/QDS.scala | 89 ++++++++++++ arithmetic/src/division/srt/srt8/SRT8.scala | 122 ++++++++++++++++ .../tests/src/division/srt/SRT16Test.scala | 8 +- .../tests/src/division/srt/SRT4Test.scala | 6 +- .../tests/src/division/srt/SRT8Test.scala | 77 ++++++++++ .../tests/src/division/srt/SRTSpec.scala | 10 +- 16 files changed, 723 insertions(+), 162 deletions(-) delete mode 100644 arithmetic/src/division/srt/SRT16.scala create mode 100644 arithmetic/src/division/srt/SRTIO.scala create mode 100644 arithmetic/src/division/srt/srt16/OTF.scala create mode 100644 arithmetic/src/division/srt/srt16/QDS.scala create mode 100644 arithmetic/src/division/srt/srt16/SRT16.scala rename arithmetic/src/division/srt/{ => srt4}/OTF.scala (92%) rename arithmetic/src/division/srt/{ => srt4}/QDS.scala (64%) rename arithmetic/src/division/srt/{ => srt4}/SRT4.scala (84%) create mode 100644 arithmetic/src/division/srt/srt8/OTF.scala create mode 100644 arithmetic/src/division/srt/srt8/QDS.scala create mode 100644 arithmetic/src/division/srt/srt8/SRT8.scala create mode 100644 arithmetic/tests/src/division/srt/SRT8Test.scala diff --git a/arithmetic/src/division/srt/SRT16.scala b/arithmetic/src/division/srt/SRT16.scala deleted file mode 100644 index 6dcdcf3..0000000 --- a/arithmetic/src/division/srt/SRT16.scala +++ /dev/null @@ -1,110 +0,0 @@ -package division.srt - -import chisel3._ -import chisel3.util.{log2Ceil, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} -import utils.leftShift - -/** RSRT16 with Two SRT4 Overlapped Stages - */ -class SRT16( - dividendWidth: Int, - dividerWidth: Int, - n: Int, // the longest width - radixLog2: Int = 2, - a: Int = 2, - dTruncateWidth: Int = 4, - rTruncateWidth: Int = 4) - extends Module { - - val xLen: Int = dividendWidth + radixLog2 + 1 - val wLen: Int = xLen + radixLog2 - val ohWidth: Int = 2 * a + 1 - - // IO - val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) - val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) - - val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W)) - val dividerNext = Wire(UInt(dividerWidth.W)) - val counterNext = Wire(UInt(log2Ceil(n).W)) - val quotientNext, quotientMinusOneNext = Wire(UInt(n.W)) - val ws1, wc1, ws2, wc2: UInt = Wire(UInt(wLen.W)) - - // Control - val qds1Sign, qds2Sign, isLastCycle, enable: Bool = Wire(Bool()) - // State - // because we need a CSA to minimize the critical path - val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable) - val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable) - val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable) - val quotient = RegEnable(quotientNext, 0.U(n.W), enable) - val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable) - val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), enable) - - // Datapath - isLastCycle := !counter.orR - output.valid := isLastCycle - input.ready := isLastCycle - enable := input.fire || !isLastCycle - - val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry - val remainderCorrect: UInt = - partialReminderSum + partialReminderCarry + (divider << 2) - val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool - output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2) - output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) - - // qds - val rWidth: Int = 1 + radixLog2 + rTruncateWidth - val partialDivider: UInt = dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) - val qds1SelectedQuotientOH: UInt = - QDS(rWidth, ohWidth, dTruncateWidth - 1)( - leftShift(partialReminderSum, radixLog2).head(rWidth), - leftShift(partialReminderCarry, radixLog2).head(rWidth), - partialDivider - ) - val qds2SelectedQuotientOH: UInt = - QDS(rWidth, ohWidth, dTruncateWidth - 1)( - leftShift(ws1, radixLog2).head(rWidth), - leftShift(wc1, radixLog2).head(rWidth), - partialDivider - ) - qds1Sign := qds1SelectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR - qds2Sign := qds2SelectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR - - // CSA32 -> CSA32 - val dividerMap = VecInit((-2 to 2).map { - case -2 => divider << 1 - case -1 => divider - case 0 => 0.U - case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider - case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) - }) - val csa1In = VecInit( - leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), - leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qds1Sign, - Mux1H(qds1SelectedQuotientOH, dividerMap) - ) - val csa2In = VecInit( - leftShift(ws1, radixLog2).head(wLen - radixLog2), - leftShift(wc1, radixLog2).head(wLen - radixLog2 - 1) ## qds2Sign, - Mux1H(qds2SelectedQuotientOH, dividerMap) - ) - val csa1 = addition.csa.c32(csa1In) - val csa2 = addition.csa.c32(csa2In) - ws1 := csa1(1) << radixLog2 - wc1 := csa1(0) << radixLog2 + 1 - ws2 := csa2(1) << radixLog2 - wc2 := csa2(0) << radixLog2 + 1 - - // On-The-Fly conversion - val otf1 = OTF(radixLog2, n, ohWidth)(quotient, quotientMinusOne, qds1SelectedQuotientOH) - val otf2 = OTF(radixLog2, n, ohWidth)(otf1(0), otf1(1), qds2SelectedQuotientOH) - - dividerNext := Mux(input.fire, input.bits.divider, divider) - counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) - quotientNext := Mux(input.fire, 0.U, otf2(0)) - quotientMinusOneNext := Mux(input.fire, 0.U, otf2(1)) - partialReminderSumNext := Mux(input.fire, input.bits.dividend, ws2) - partialReminderCarryNext := Mux(input.fire, 0.U, wc2) -} diff --git a/arithmetic/src/division/srt/SRTIO.scala b/arithmetic/src/division/srt/SRTIO.scala new file mode 100644 index 0000000..ae8e7b0 --- /dev/null +++ b/arithmetic/src/division/srt/SRTIO.scala @@ -0,0 +1,15 @@ +package division.srt + +import chisel3._ +import chisel3.util.log2Ceil + +class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { + val dividend = UInt(dividendWidth.W) //.*********** + val divider = UInt(dividerWidth.W) //.1********** + val counter = UInt(log2Ceil(n).W) //the width of quotient. +} + +class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { + val reminder = UInt(reminderWidth.W) + val quotient = UInt(quotientWidth.W) +} diff --git a/arithmetic/src/division/srt/SRTTable.scala b/arithmetic/src/division/srt/SRTTable.scala index 56dc20b..8de5dfb 100644 --- a/arithmetic/src/division/srt/SRTTable.scala +++ b/arithmetic/src/division/srt/SRTTable.scala @@ -87,18 +87,28 @@ case class SRTTable( } // TODO: select a Constant from each m, then offer the table to QDS. - // select rule: symmetry and draw a line parallel to the Y-axis, how define the rule - // lazy val qdsTables: Seq[(Algebraic, Algebraic)] = { - // tables.map { - // case (i, ps) => - // ps.flatMap { case (d, xs) => xs.filter{ x: Algebraic => ??? }.map(x => ((d< + k -> dSet.dropRight(1).map { d => + val (floor, ceil) = xRange(k, d, d + deltaD) + val m: Seq[Algebraic] = xSet.filter { x: Algebraic => x <= (ceil - deltaX) && x >= floor } + (d, m.head) + } + } + }.flatMap { + case (i, ps) => + ps.map { + case (x, y) => (x.toDouble, y.toDouble * 16) + } + }.groupBy(_._1).toSeq.sortBy(_._1).map { case (x, y) => y.map { case (x, y) => y.toInt }.reverse } private val xStep = (xMax - xMin) / deltaX // @note 5.7 require(a >= radix / 2) - private val xSet = Seq.tabulate((xStep/2 + 1).toInt) { n => deltaX * n } ++ Seq.tabulate((xStep/2 + 1).toInt) { n => -deltaX * n } + private val xSet = Seq.tabulate((xStep / 2 + 1).toInt) { n => deltaX * n } ++ Seq.tabulate((xStep / 2 + 1).toInt) { + n => -deltaX * n + } private val dStep: Algebraic = (dMax - dMin) / deltaD assert((rho > 1 / 2) && (rho <= 1)) diff --git a/arithmetic/src/division/srt/srt16/OTF.scala b/arithmetic/src/division/srt/srt16/OTF.scala new file mode 100644 index 0000000..88fd12f --- /dev/null +++ b/arithmetic/src/division/srt/srt16/OTF.scala @@ -0,0 +1,60 @@ +package division.srt.srt16 + +import chisel3._ +import chisel3.util.Mux1H + +class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle { + val quotient = UInt(qWidth.W) + val quotientMinusOne = UInt(qWidth.W) + val selectedQuotientOH = UInt(ohWidth.W) +} + +class OTFOutput(qWidth: Int) extends Bundle { + val quotient = UInt(qWidth.W) + val quotientMinusOne = UInt(qWidth.W) +} + +class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module { + val input = IO(Input(new OTFInput(qWidth, ohWidth))) + val output = IO(Output(new OTFOutput(qWidth))) + + val radix: Int = 1 << radixLog2 + // datapath + // q_j+1 in this circle, only for srt4 + val qNext: UInt = Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b110".U, + input.selectedQuotientOH(1) -> "b111".U, + input.selectedQuotientOH(2) -> "b000".U, + input.selectedQuotientOH(3) -> "b001".U, + input.selectedQuotientOH(4) -> "b010".U + ) + ) + + // val cShiftQ: Bool = qNext >= 0.U + // val cShiftQM: Bool = qNext <= 0.U + val cShiftQ: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR + val cShiftQM: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0) + val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0) + + output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qIn + output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qmIn +} + +object OTF { + def apply( + radixLog2: Int, + qWidth: Int, + ohWidth: Int + )(quotient: UInt, + quotientMinusOne: UInt, + selectedQuotientOH: UInt + ): Seq[UInt] = { + val m = Module(new OTF(radixLog2, qWidth, ohWidth)) + m.input.quotient := quotient + m.input.quotientMinusOne := quotientMinusOne + m.input.selectedQuotientOH := selectedQuotientOH + Seq(m.output.quotient, m.output.quotientMinusOne) + } +} diff --git a/arithmetic/src/division/srt/srt16/QDS.scala b/arithmetic/src/division/srt/srt16/QDS.scala new file mode 100644 index 0000000..6fa0337 --- /dev/null +++ b/arithmetic/src/division/srt/srt16/QDS.scala @@ -0,0 +1,78 @@ +package division.srt.srt16 + +import chisel3._ +import chisel3.util.BitPat +import chisel3.util.experimental.decode._ +import utils.extend + +class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { + val partialReminderCarry: UInt = UInt(rWidth.W) + val partialReminderSum: UInt = UInt(rWidth.W) + val partialDivider: UInt = UInt(partialDividerWidth.W) +} + +class QDSOutput(ohWidth: Int) extends Bundle { + val selectedQuotientOH: UInt = UInt(ohWidth.W) +} + +class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module { + // IO + val input = IO(Input(new QDSInput(rWidth, partialDividerWidth))) + val output = IO(Output(new QDSOutput(ohWidth))) + + // get from SRTTable. + val selectRom = VecInit(tables.map { + case x => + VecInit(x.map { + case x => + new StringBuffer("b") + .append( + if ((-x).toBinaryString.length >= rWidth) (-x).toBinaryString.reverse.substring(0, rWidth).reverse + else (-x).toBinaryString + ) + .toString + .U + }) + }) + + val columnSelect = input.partialDivider + val adderWidth = rWidth + 1 + val yTruncate: UInt = input.partialReminderCarry + input.partialReminderSum + val mkVec = selectRom(columnSelect) + val selectPoints = VecInit(mkVec.map { mk => + (extend(yTruncate, adderWidth).asUInt + + extend(mk, adderWidth).asUInt).head(1) + }).asUInt + + // decoder or findFirstOne here, prefer decoder, the decoder only for srt4 + output.selectedQuotientOH := chisel3.util.experimental.decode.decoder( + selectPoints, + TruthTable( + Seq( + BitPat("b???0") -> BitPat("b10000"), //2 + BitPat("b??01") -> BitPat("b01000"), //1 + BitPat("b?011") -> BitPat("b00100"), //0 + BitPat("b0111") -> BitPat("b00010") //-1 + ), + BitPat("b00001") //-2 + ) + ) +} + +object QDS { + def apply( + rWidth: Int, + ohWidth: Int, + partialDividerWidth: Int, + tables: Seq[Seq[Int]] + )(partialReminderSum: UInt, + partialReminderCarry: UInt, + partialDivider: UInt + ): UInt = { + val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables)) + m.input.partialReminderSum := partialReminderSum + m.input.partialReminderCarry := partialReminderCarry + m.input.partialDivider := partialDivider + m.output.selectedQuotientOH + } +} diff --git a/arithmetic/src/division/srt/srt16/SRT16.scala b/arithmetic/src/division/srt/srt16/SRT16.scala new file mode 100644 index 0000000..2af05f1 --- /dev/null +++ b/arithmetic/src/division/srt/srt16/SRT16.scala @@ -0,0 +1,137 @@ +package division.srt.srt16 + +import division.srt._ +import chisel3._ +import chisel3.util.{log2Ceil, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} +import utils.leftShift + +/** RSRT16 with Two SRT4 Overlapped Stages + * n>=7 + * Reuse parameters, OTF and QDS of srt4 + */ +class SRT16( + dividendWidth: Int, + dividerWidth: Int, + n: Int, // the longest width + radixLog2: Int = 2, + a: Int = 2, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4) + extends Module { + + val xLen: Int = dividendWidth + radixLog2 + 1 + val wLen: Int = xLen + radixLog2 + val ohWidth: Int = 2 * a + 1 + + // IO + val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) + val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) + + val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W)) + val dividerNext = Wire(UInt(dividerWidth.W)) + val counterNext = Wire(UInt(log2Ceil(n).W)) + val quotientNext, quotientMinusOneNext = Wire(UInt(n.W)) + + // Control + val isLastCycle, enable: Bool = Wire(Bool()) + // State + // because we need a CSA to minimize the critical path + val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable) + val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable) + val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable) + val quotient = RegEnable(quotientNext, 0.U(n.W), enable) + val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable) + val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), enable) + + // Datapath + isLastCycle := !counter.orR + output.valid := isLastCycle + input.ready := isLastCycle + enable := input.fire || !isLastCycle + + val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry + val remainderCorrect: UInt = + partialReminderSum + partialReminderCarry + (divider << radixLog2) + val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2) + output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) + + // 5*CSA32 SRT16 <- SRT4 + SRT4*5 /SRT16 -> CSA53+CSA32 + val dividerMap = VecInit((-2 to 2).map { + case -2 => divider << 1 + case -1 => divider + case 0 => 0.U + case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider + case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) + }) + val csaIn1 = leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2) + val csaIn2 = leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) + val csa1 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## false.B, dividerMap(0))) // -2 + val csa2 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## false.B, dividerMap(1))) // -1 + val csa3 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## false.B, dividerMap(2))) // 0 + val csa4 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## true.B, dividerMap(3))) // 1 + val csa5 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## true.B, dividerMap(4))) // 2 + + // qds + val rWidth: Int = 1 + radixLog2 + rTruncateWidth + val tables: Seq[Seq[Int]] = division.srt.SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS + val partialDivider: UInt = dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) + val qdsOH0: UInt = + QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)( + leftShift(partialReminderSum, radixLog2).head(rWidth), + leftShift(partialReminderCarry, radixLog2).head(rWidth), + partialDivider + ) // q_j+1 oneHot + + def qds(a: Vec[UInt]): UInt = { + QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)( + leftShift(a(1), radixLog2).head(rWidth), + leftShift(a(0), radixLog2 + 1).head(rWidth), + partialDivider + ) + } + // q_j+2 oneHot precompute + val qds1SelectedQuotientOH: UInt = qds(csa1) // -2 + val qds2SelectedQuotientOH: UInt = qds(csa2) // -1 + val qds3SelectedQuotientOH: UInt = qds(csa3) // 0 + val qds4SelectedQuotientOH: UInt = qds(csa4) // 1 + val qds5SelectedQuotientOH: UInt = qds(csa5) // 2 + + val csa0OutMap = VecInit((-2 to 2).map { + case -2 => csa1 + case -1 => csa2 + case 0 => csa3 + case 1 => csa4 + case 2 => csa5 + }) + val qds1SelectedQuotientOHMap = VecInit((-2 to 2).map { + case -2 => qds1SelectedQuotientOH + case -1 => qds2SelectedQuotientOH + case 0 => qds3SelectedQuotientOH + case 1 => qds4SelectedQuotientOH + case 2 => qds5SelectedQuotientOH + }) + + val qdsOH1 = Mux1H(qdsOH0, qds1SelectedQuotientOHMap) // q_j+2 oneHot + val qds1sign = qdsOH1(ohWidth - 1, ohWidth / 2 + 1).orR + val csa0Out = Mux1H(qdsOH0, csa0OutMap) + val csa1Out = addition.csa.c32( + VecInit( + leftShift(csa0Out(1), radixLog2).head(wLen - radixLog2), + leftShift(csa0Out(0), radixLog2 + 1).head(wLen - radixLog2 - 1) ## qds1sign, + Mux1H(qdsOH1, dividerMap) + ) + ) + + // On-The-Fly conversion + // todo?: OTF input: Q, QM1, (q1 << 2 + q2) output: Q,QM1 + val otf0 = OTF(radixLog2, n, ohWidth)(quotient, quotientMinusOne, qdsOH0) + val otf1 = OTF(radixLog2, n, ohWidth)(otf0(0), otf0(1), qdsOH1) + + dividerNext := Mux(input.fire, input.bits.divider, divider) + counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) + quotientNext := Mux(input.fire, 0.U, otf1(0)) + quotientMinusOneNext := Mux(input.fire, 0.U, otf1(1)) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1Out(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1Out(0) << radixLog2 + 1) +} diff --git a/arithmetic/src/division/srt/OTF.scala b/arithmetic/src/division/srt/srt4/OTF.scala similarity index 92% rename from arithmetic/src/division/srt/OTF.scala rename to arithmetic/src/division/srt/srt4/OTF.scala index 73ac49f..3e94a10 100644 --- a/arithmetic/src/division/srt/OTF.scala +++ b/arithmetic/src/division/srt/srt4/OTF.scala @@ -1,7 +1,7 @@ -package division.srt +package division.srt.srt4 import chisel3._ -import chisel3.util.{Mux1H} +import chisel3.util.Mux1H class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle { val quotient = UInt(qWidth.W) @@ -38,8 +38,8 @@ class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module { val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0) val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0) - output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - 2, 0) ## qIn - output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - 2, 0) ## qmIn + output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qIn + output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qmIn } object OTF { diff --git a/arithmetic/src/division/srt/QDS.scala b/arithmetic/src/division/srt/srt4/QDS.scala similarity index 64% rename from arithmetic/src/division/srt/QDS.scala rename to arithmetic/src/division/srt/srt4/QDS.scala index a86b875..eff9ab4 100644 --- a/arithmetic/src/division/srt/QDS.scala +++ b/arithmetic/src/division/srt/srt4/QDS.scala @@ -1,7 +1,9 @@ -package division.srt +package division.srt.srt4 + import chisel3._ -import chisel3.util.{BitPat} +import chisel3.util.BitPat import chisel3.util.experimental.decode._ +import division.srt.SRTTable import utils.extend class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { @@ -30,18 +32,34 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { // Array(20, 8, -8, -22), // Array(24, 8, -8, -24)/16 // ) - val columnSelect = input.partialDivider - val selectRom: Vec[Vec[UInt]] = VecInit( - VecInit("b111_0100".U, "b111_1100".U, "b000_0100".U, "b000_1101".U), - VecInit("b111_0010".U, "b111_1100".U, "b000_0110".U, "b000_1111".U), - VecInit("b111_0001".U, "b111_1100".U, "b000_0110".U, "b001_0000".U), - VecInit("b111_0000".U, "b111_1100".U, "b000_0110".U, "b001_0010".U), - VecInit("b110_1110".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), - VecInit("b110_1100".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), - VecInit("b110_1100".U, "b111_1000".U, "b000_1000".U, "b001_0110".U), - VecInit("b110_1000".U, "b111_1000".U, "b000_1000".U, "b001_1000".U) - ) + // val selectRom: Vec[Vec[UInt]] = VecInit( + // VecInit("b111_0100".U, "b111_1100".U, "b000_0100".U, "b000_1101".U), + // VecInit("b111_0010".U, "b111_1100".U, "b000_0110".U, "b000_1111".U), + // VecInit("b111_0001".U, "b111_1100".U, "b000_0110".U, "b001_0000".U), + // VecInit("b111_0000".U, "b111_1100".U, "b000_0110".U, "b001_0010".U), + // VecInit("b110_1110".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), + // VecInit("b110_1100".U, "b111_1010".U, "b000_1000".U, "b001_0100".U), + // VecInit("b110_1100".U, "b111_1000".U, "b000_1000".U, "b001_0110".U), + // VecInit("b110_1000".U, "b111_1000".U, "b000_1000".U, "b001_1000".U) + // ) + // get from SRTTable. + val tables: Seq[Seq[Int]] = SRTTable(4, 2, 4, 4).tablesToQDS + lazy val selectRom = VecInit(tables.map { + case x => + VecInit(x.map { + case x => + new StringBuffer("b") + .append( + if ((-x).toBinaryString.length >= rWidth) (-x).toBinaryString.reverse.substring(0, rWidth).reverse + else (-x).toBinaryString + ) + .toString + .U + }) + }) + + val columnSelect = input.partialDivider val adderWidth = rWidth + 1 val yTruncate: UInt = input.partialReminderCarry + input.partialReminderSum val mkVec = selectRom(columnSelect) diff --git a/arithmetic/src/division/srt/SRT4.scala b/arithmetic/src/division/srt/srt4/SRT4.scala similarity index 84% rename from arithmetic/src/division/srt/SRT4.scala rename to arithmetic/src/division/srt/srt4/SRT4.scala index 0abc6e4..ca10fe9 100644 --- a/arithmetic/src/division/srt/SRT4.scala +++ b/arithmetic/src/division/srt/srt4/SRT4.scala @@ -1,32 +1,21 @@ -package division.srt +package division.srt.srt4 +import division.srt._ import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 import chisel3._ -import chisel3.util.{log2Ceil, DecoupledIO, Fill, Mux1H, RegEnable, ValidIO} +import chisel3.util._ import utils.leftShift /** SRT4 * 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2 * radix = 4 * a = 2, {-2, -1, 0, 1, -2}, - * t = 4 + * dTruncateWidth = 4, rTruncateWidth = 8 * y^(xxx.xxxx), d^(0.1xxx) * -44/16 < y^ < 42/16 */ -class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { - val dividend = UInt(dividendWidth.W) //.*********** - val divider = UInt(dividerWidth.W) //.1********** - val counter = UInt(log2Ceil(n).W) //the width of quotient. -} - -class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { - val reminder = UInt(reminderWidth.W) - val quotient = UInt(quotientWidth.W) -} - -// only SRT4 currently class SRT4( dividendWidth: Int, dividerWidth: Int, @@ -73,13 +62,14 @@ class SRT4( val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry val remainderCorrect: UInt = - partialReminderSum + partialReminderCarry + (divider << 2) + partialReminderSum + partialReminderCarry + (divider << radixLog2) val needCorrect: Bool = remainderNoCorrect(wLen - 3).asBool output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2) output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) // qds val rWidth: Int = 1 + radixLog2 + rTruncateWidth + val selectedQuotientOH: UInt = QDS(rWidth, ohWidth, dTruncateWidth - 1)( leftShift(partialReminderSum, radixLog2).head(rWidth), @@ -88,7 +78,7 @@ class SRT4( ) qdsSign := selectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR - // csa for SRT4 -> CSA32, SRT8 -> CSA32+CSA32, SRT16 -> CSA53+CSA32, SRT16 <- SRT4 + SRT4*5 + // csa for SRT4 -> CSA32 val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen)) csa.in(0) := leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2) csa.in(1) := leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign diff --git a/arithmetic/src/division/srt/srt8/OTF.scala b/arithmetic/src/division/srt/srt8/OTF.scala new file mode 100644 index 0000000..88196ef --- /dev/null +++ b/arithmetic/src/division/srt/srt8/OTF.scala @@ -0,0 +1,71 @@ +package division.srt.srt8 + +import chisel3._ +import chisel3.util.Mux1H + +class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle { + val quotient = UInt(qWidth.W) + val quotientMinusOne = UInt(qWidth.W) + val selectedQuotientOH = UInt(ohWidth.W) +} + +class OTFOutput(qWidth: Int) extends Bundle { + val quotient = UInt(qWidth.W) + val quotientMinusOne = UInt(qWidth.W) +} + +class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module { + val input = IO(Input(new OTFInput(qWidth, ohWidth))) + val output = IO(Output(new OTFOutput(qWidth))) + + val radix: Int = 1 << radixLog2 + // datapath + // q_j+1 in this circle, only for srt8(a = 7) + val qNext: UInt = Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b11110".U, // -2 + input.selectedQuotientOH(1) -> "b11111".U, // -1 + input.selectedQuotientOH(2) -> "b00000".U, // 0 + input.selectedQuotientOH(3) -> "b00001".U, // 1 + input.selectedQuotientOH(4) -> "b00010".U // 2 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(5) -> "b11000".U, // -8 + input.selectedQuotientOH(6) -> "b11100".U, // -4 + input.selectedQuotientOH(7) -> "b00000".U, // 0 + input.selectedQuotientOH(8) -> "b00100".U, // 4 + input.selectedQuotientOH(9) -> "b01000".U // 8 + ) + ) + + // val cShiftQ: Bool = qNext >= 0.U + // val cShiftQM: Bool = qNext <= 0.U + val cShiftQ: Bool = input.selectedQuotientOH(9, 8).orR || + (input.selectedQuotientOH(7) && input.selectedQuotientOH(4, 2).orR) + val cShiftQM: Bool = input.selectedQuotientOH(6, 5).orR || + (input.selectedQuotientOH(7) && input.selectedQuotientOH(2, 0).orR) + + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0) + val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0) + + output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qIn + output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qmIn +} + +object OTF { + def apply( + radixLog2: Int, + qWidth: Int, + ohWidth: Int + )(quotient: UInt, + quotientMinusOne: UInt, + selectedQuotientOH: UInt + ): Seq[UInt] = { + val m = Module(new OTF(radixLog2, qWidth, ohWidth)) + m.input.quotient := quotient + m.input.quotientMinusOne := quotientMinusOne + m.input.selectedQuotientOH := selectedQuotientOH + Seq(m.output.quotient, m.output.quotientMinusOne) + } +} diff --git a/arithmetic/src/division/srt/srt8/QDS.scala b/arithmetic/src/division/srt/srt8/QDS.scala new file mode 100644 index 0000000..eff243a --- /dev/null +++ b/arithmetic/src/division/srt/srt8/QDS.scala @@ -0,0 +1,89 @@ +package division.srt.srt8 + +import chisel3._ +import chisel3.util.{BitPat, ValidIO} +import chisel3.util.experimental.decode.{TruthTable, _} +import division.srt.SRTTable +import utils.extend + +class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { + val partialReminderCarry: UInt = UInt(rWidth.W) + val partialReminderSum: UInt = UInt(rWidth.W) + val partialDivider: UInt = UInt(partialDividerWidth.W) +} + +class QDSOutput(ohWidth: Int) extends Bundle { + val selectedQuotientOH: UInt = UInt(ohWidth.W) +} + +class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { + // IO + val input = IO(Input(new QDSInput(rWidth, partialDividerWidth))) + val output = IO(Output(new QDSOutput(ohWidth))) + + val columnSelect = input.partialDivider + // Seq[Seq[Int]] => Vec[Vec[UInt]] + val tables: Seq[Seq[Int]] = SRTTable(8, 7, 4, 4).tablesToQDS + lazy val selectRom = VecInit(tables.map { + case x => + VecInit(x.map { + case x => + new StringBuffer("b") + .append( + if ((-x).toBinaryString.length >= rWidth) (-x).toBinaryString.reverse.substring(0, rWidth).reverse + else (-x).toBinaryString + ) + .toString + .U(rWidth.W) + }) + }) + + val adderWidth = rWidth + 1 + val yTruncate: UInt = input.partialReminderCarry + input.partialReminderSum + val mkVec = selectRom(columnSelect) + val selectPoints = VecInit(mkVec.map { mk => + (extend(yTruncate, adderWidth).asUInt + + extend(mk, adderWidth).asUInt).head(1) + }).asUInt + + // decoder or findFirstOne here, prefer decoder, the decoder only for srt8(a = 7) + output.selectedQuotientOH := chisel3.util.experimental.decode.decoder( + selectPoints, + TruthTable( + Seq( // 8 4 0 -4 -8__2 1 0 -1 -2 + BitPat("b??_????_????_???0") -> BitPat("b10000_00010"), // 7 = +8 + (-1) + BitPat("b??_????_????_??01") -> BitPat("b01000_10000"), // 6 = +4 + (+2) + BitPat("b??_????_????_?011") -> BitPat("b01000_01000"), // 5 = +4 + (+1) + BitPat("b??_????_????_0111") -> BitPat("b01000_00100"), // 4 = +4 + ( 0) + BitPat("b??_????_???0_1111") -> BitPat("b01000_00010"), // 3 = +4 + (-1) + BitPat("b??_????_??01_1111") -> BitPat("b00100_10000"), // 2 = 0 + (+2) + BitPat("b??_????_?011_1111") -> BitPat("b00100_01000"), // 1 = 0 + (+1) + BitPat("b??_????_0111_1111") -> BitPat("b00100_00100"), // 0 = 0 + ( 0) + BitPat("b??_???0_1111_1111") -> BitPat("b00100_00010"), //-1 = 0 + (-1) + BitPat("b??_??01_1111_1111") -> BitPat("b00100_00001"), //-2 = 0 + (-2) + BitPat("b??_?011_1111_1111") -> BitPat("b00010_01000"), //-3 = -4 + ( 1) + BitPat("b??_0111_1111_1111") -> BitPat("b00010_00100"), //-4 = -4 + ( 0) + BitPat("b?0_1111_1111_1111") -> BitPat("b00010_00010"), //-5 = -4 + (-1) + BitPat("b01_1111_1111_1111") -> BitPat("b00010_00001") //-6 = -4 + (-2) + ), + BitPat("b00001_01000") //-7 = -8 + (+1) + ) + ) +} + +object QDS { + def apply( + rWidth: Int, + ohWidth: Int, + partialDividerWidth: Int + )(partialReminderSum: UInt, + partialReminderCarry: UInt, + partialDivider: UInt + ): UInt = { + val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth)) + m.input.partialReminderSum := partialReminderSum + m.input.partialReminderCarry := partialReminderCarry + m.input.partialDivider := partialDivider + m.output.selectedQuotientOH + } +} diff --git a/arithmetic/src/division/srt/srt8/SRT8.scala b/arithmetic/src/division/srt/srt8/SRT8.scala new file mode 100644 index 0000000..b0a889e --- /dev/null +++ b/arithmetic/src/division/srt/srt8/SRT8.scala @@ -0,0 +1,122 @@ +package division.srt.srt8 + +import division.srt._ +import division.srt.SRTTable +import chisel3._ +import chisel3.util._ +import utils.leftShift + +/** SRT8 + * 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2 + * radix = 8 + * a = 7, {-7, ... ,-2, -1, 0, 1, 2, ... 7}, + * dTruncateWidth = 4, rTruncateWidth = 4 + * y^(xxxx.xxxx), d^(0.1xxx) + * table from SRTTable + */ + +class SRT8( + dividendWidth: Int, + dividerWidth: Int, + n: Int, // the longest width + radixLog2: Int = 3, + a: Int = 7, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4) + extends Module { + + val xLen: Int = dividendWidth + radixLog2 + 1 + val wLen: Int = xLen + radixLog2 + val ohWidth: Int = 10 + + // IO + val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) + val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) + + val partialReminderCarryNext, partialReminderSumNext = Wire(UInt(wLen.W)) + val quotientNext, quotientMinusOneNext = Wire(UInt(n.W)) + val dividerNext = Wire(UInt(dividerWidth.W)) + val counterNext = Wire(UInt(log2Ceil(n).W)) + + // Control + // sign of select quotient, true -> negative, false -> positive + // sign of Cycle, true -> (counter === 0.U) + val qdsSign0, qdsSign1, isLastCycle, enable: Bool = Wire(Bool()) + + // State + // because we need a CSA to minimize the critical path + val partialReminderCarry = RegEnable(partialReminderCarryNext, 0.U(wLen.W), enable) + val partialReminderSum = RegEnable(partialReminderSumNext, 0.U(wLen.W), enable) + val divider = RegEnable(dividerNext, 0.U(dividerWidth.W), enable) + val quotient = RegEnable(quotientNext, 0.U(n.W), enable) + val quotientMinusOne = RegEnable(quotientMinusOneNext, 0.U(n.W), enable) + val counter = RegEnable(counterNext, 0.U(log2Ceil(n).W), enable) + + // Datapath + // according two adders + isLastCycle := !counter.orR + output.valid := isLastCycle + input.ready := isLastCycle + enable := input.fire || !isLastCycle + + val remainderNoCorrect: UInt = partialReminderSum + partialReminderCarry + val remainderCorrect: UInt = + partialReminderSum + partialReminderCarry + (divider << radixLog2) + val needCorrect: Bool = remainderNoCorrect(wLen - 4).asBool + output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 5, radixLog2) + output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) + + // qds + val rWidth: Int = 1 + radixLog2 + rTruncateWidth + val selectedQuotientOH: UInt = + QDS(rWidth, ohWidth, dTruncateWidth - 1)( + leftShift(partialReminderSum, radixLog2).head(rWidth), + leftShift(partialReminderCarry, radixLog2).head(rWidth), + dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> *** + ) + + qdsSign0 := selectedQuotientOH(9, 8).orR + qdsSign1 := selectedQuotientOH(4, 3).orR + + val qHigh: UInt = selectedQuotientOH(9, 5) + val qLow: UInt = selectedQuotientOH(4, 0) + // csa for SRT8 -> CSA32+CSA32 + val divideMap0 = VecInit((-2 to 2).map { + case -2 => divider << 3 // -8 + case -1 => divider << 2 // -4 + case 0 => 0.U // 0 + case 1 => Fill(2, 1.U(1.W)) ## ~(divider << 2) // 4 + case 2 => Fill(1, 1.U(1.W)) ## ~(divider << 3) // 8 + }) + val divideMap1 = VecInit((-2 to 2).map { + case -2 => divider << 1 // -2 + case -1 => divider // -1 + case 0 => 0.U // 0 + case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider // 1 + case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) // 2 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, + Mux1H(qHigh, divideMap0) + ) + ) + val csa1 = addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, + Mux1H(qLow, divideMap1) + ) + ) + + // On-The-Fly conversion + val otf = OTF(radixLog2, n, ohWidth)(quotient, quotientMinusOne, selectedQuotientOH) + + dividerNext := Mux(input.fire, input.bits.divider, divider) + counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) + quotientNext := Mux(input.fire, 0.U, otf(0)) + quotientMinusOneNext := Mux(input.fire, 0.U, otf(1)) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1(0) << 1 + radixLog2) +} diff --git a/arithmetic/tests/src/division/srt/SRT16Test.scala b/arithmetic/tests/src/division/srt/SRT16Test.scala index c9e3074..3bbf36a 100644 --- a/arithmetic/tests/src/division/srt/SRT16Test.scala +++ b/arithmetic/tests/src/division/srt/SRT16Test.scala @@ -1,4 +1,4 @@ -package division.srt +package division.srt.srt16 import chisel3._ import chisel3.tester.{ChiselUtestTester, testableClock, testableData} @@ -18,6 +18,8 @@ object SRT16Test extends TestSuite with ChiselUtestTester { val q: Int = Random.nextInt(m - radixLog2 +1) val dividend: BigInt = BigInt(p, Random) val divider: BigInt = BigInt(q, Random) +// val dividend: BigInt = BigInt("65") +// val divider: BigInt = BigInt("1") def zeroCheck(x: BigInt): Int = { var flag = false var a: Int = m @@ -67,8 +69,8 @@ object SRT16Test extends TestSuite with ChiselUtestTester { } } - testcase(64) -// for( i <- 1 to 100){ + testcase(16) +// for( i <- 1 to 50){ // testcase(128) // } } diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index 9f447c9..5a4d938 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -1,4 +1,4 @@ -package division.srt +package division.srt.srt4 import chisel3._ import chisel3.tester.{ChiselUtestTester, testableClock, testableData} @@ -17,6 +17,8 @@ object SRT4Test extends TestSuite with ChiselUtestTester { val q: Int = Random.nextInt(m) val dividend: BigInt = BigInt(p, Random) val divider: BigInt = BigInt(q, Random) +// val dividend: BigInt = BigInt("65") +// val divider: BigInt = BigInt("1") def zeroCheck(x: BigInt): Int = { var flag = false var a: Int = m @@ -73,7 +75,7 @@ object SRT4Test extends TestSuite with ChiselUtestTester { testcase(64) // for( i <- 1 to 100){ -// testcase(128) +// testcase(64) // } } } diff --git a/arithmetic/tests/src/division/srt/SRT8Test.scala b/arithmetic/tests/src/division/srt/SRT8Test.scala new file mode 100644 index 0000000..b436aa8 --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRT8Test.scala @@ -0,0 +1,77 @@ +package division.srt.srt8 + +import chisel3._ +import chisel3.tester.{ChiselUtestTester, testableClock, testableData} +import utest._ + +import scala.util.Random + +object SRT8Test extends TestSuite with ChiselUtestTester { + def tests: Tests = Tests { + test("SRT8 should pass") { + def testcase(width: Int): Unit ={ + // parameters + val radixLog2: Int = 3 + val n: Int = width + val m: Int = n - 1 + val p: Int = Random.nextInt(m - radixLog2 +1) //order to offer guardwidth + val q: Int = Random.nextInt(m - radixLog2 +1) + val dividend: BigInt = BigInt(p, Random) + val divider: BigInt = BigInt(q, Random) +// val dividend: BigInt = BigInt("65") +// val divider: BigInt = BigInt("1") + def zeroCheck(x: BigInt): Int = { + var flag = false + var a: Int = m + while (!flag && (a >= -1)) { + flag = ((BigInt(1) << a) & x) != 0 + a = a - 1 + } + a + 1 + } + val zeroHeadDividend: Int = m - zeroCheck(dividend) + val zeroHeadDivider: Int = m - zeroCheck(divider) + val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + radixLog2 -1 + val noguard: Boolean = needComputerWidth % radixLog2 == 0 + val guardWidth: Int = if (noguard) 0 else 3 - needComputerWidth % 3 + val counter: Int = (needComputerWidth + guardWidth) / radixLog2 + if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0)) + return + val quotient: BigInt = dividend / divider + val remainder: BigInt = dividend % divider + val leftShiftWidthDividend: Int = zeroHeadDividend - guardWidth + val leftShiftWidthDivider: Int = zeroHeadDivider + testCircuit(new SRT8(n, n, n), + Seq(chiseltest.internal.NoThreadingAnnotation, + chiseltest.simulator.WriteVcdAnnotation)) { + dut: SRT8 => + dut.clock.setTimeout(0) + dut.input.valid.poke(true.B) + dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) + dut.input.bits.divider.poke((divider << leftShiftWidthDivider).U) + dut.input.bits.counter.poke(counter.U) + dut.clock.step() + dut.input.valid.poke(false.B) + var flag = false + for (a <- 1 to 1000 if !flag) { + if (dut.output.valid.peek().litValue == 1) { + flag = true + println(dut.output.bits.quotient.peek().litValue) + println(dut.output.bits.reminder.peek().litValue) + utest.assert(dut.output.bits.quotient.peek().litValue == quotient) + utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) + } + dut.clock.step() + } + utest.assert(flag) + dut.clock.step(scala.util.Random.nextInt(10)) + } + } + + testcase(64) +// for( i <- 1 to 50){ +// testcase(128) +// } + } + } +} \ No newline at end of file diff --git a/arithmetic/tests/src/division/srt/SRTSpec.scala b/arithmetic/tests/src/division/srt/SRTSpec.scala index 4b01197..cedf006 100644 --- a/arithmetic/tests/src/division/srt/SRTSpec.scala +++ b/arithmetic/tests/src/division/srt/SRTSpec.scala @@ -1,17 +1,17 @@ package division.srt import utest._ +import chisel3._ +import utils.extend object SRTSpec extends TestSuite{ override def tests: Tests = Tests { test("SRT should draw PD") { - val srt = SRTTable(4, 2, 4, 4) -// val table = srt.tables.flatMap { -// case (i, ps) => ps.flatMap{ case (d, xs) => xs.map(x => (d.toDouble, x.toDouble*16)) }}.groupBy(_._1) -// table.map{case (x, y) => println(y)} - srt.dumpGraph(srt.pd, os.root / "tmp" / "srt4-2-4-4.png") + val srt = SRTTable(8, 7, 4, 4) +// println(srt.tablesToQDS) + srt.dumpGraph(srt.pd, os.root / "tmp" / "srt8-7-4-4.png") } } } From bf3eec4937408fc59a9c5937c4def75a8ff40ee0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Sat, 11 Jun 2022 21:16:02 +0800 Subject: [PATCH 29/31] fix srt16 & SRTTable fixed --- arithmetic/src/division/srt/SRTTable.scala | 2 +- arithmetic/src/division/srt/srt16/QDS.scala | 4 +- arithmetic/src/division/srt/srt16/SRT16.scala | 38 ++++++++++--------- arithmetic/src/division/srt/srt4/QDS.scala | 13 +++---- arithmetic/src/division/srt/srt4/SRT4.scala | 5 ++- arithmetic/src/division/srt/srt8/QDS.scala | 13 +++---- arithmetic/src/division/srt/srt8/SRT8.scala | 15 +++++--- .../tests/src/division/srt/SRT16Test.scala | 4 +- .../tests/src/division/srt/SRT4Test.scala | 4 +- .../tests/src/division/srt/SRT8Test.scala | 2 +- .../tests/src/division/srt/SRTSpec.scala | 5 ++- 11 files changed, 56 insertions(+), 49 deletions(-) diff --git a/arithmetic/src/division/srt/SRTTable.scala b/arithmetic/src/division/srt/SRTTable.scala index 8de5dfb..b76c3a0 100644 --- a/arithmetic/src/division/srt/SRTTable.scala +++ b/arithmetic/src/division/srt/SRTTable.scala @@ -99,7 +99,7 @@ case class SRTTable( }.flatMap { case (i, ps) => ps.map { - case (x, y) => (x.toDouble, y.toDouble * 16) + case (x, y) => (x.toDouble, y.toDouble * (1 << xTruncateWidth.toInt)) } }.groupBy(_._1).toSeq.sortBy(_._1).map { case (x, y) => y.map { case (x, y) => y.toInt }.reverse } diff --git a/arithmetic/src/division/srt/srt16/QDS.scala b/arithmetic/src/division/srt/srt16/QDS.scala index 6fa0337..f615d93 100644 --- a/arithmetic/src/division/srt/srt16/QDS.scala +++ b/arithmetic/src/division/srt/srt16/QDS.scala @@ -21,7 +21,7 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[I val output = IO(Output(new QDSOutput(ohWidth))) // get from SRTTable. - val selectRom = VecInit(tables.map { + lazy val selectRom = VecInit(tables.map { case x => VecInit(x.map { case x => @@ -31,7 +31,7 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[I else (-x).toBinaryString ) .toString - .U + .U(rWidth.W) }) }) diff --git a/arithmetic/src/division/srt/srt16/SRT16.scala b/arithmetic/src/division/srt/srt16/SRT16.scala index 2af05f1..210ba49 100644 --- a/arithmetic/src/division/srt/srt16/SRT16.scala +++ b/arithmetic/src/division/srt/srt16/SRT16.scala @@ -22,6 +22,7 @@ class SRT16( val xLen: Int = dividendWidth + radixLog2 + 1 val wLen: Int = xLen + radixLog2 val ohWidth: Int = 2 * a + 1 + val rWidth: Int = 1 + radixLog2 + rTruncateWidth // IO val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) @@ -64,17 +65,19 @@ class SRT16( case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) }) - val csaIn1 = leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2) - val csaIn2 = leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) - val csa1 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## false.B, dividerMap(0))) // -2 - val csa2 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## false.B, dividerMap(1))) // -1 - val csa3 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## false.B, dividerMap(2))) // 0 - val csa4 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## true.B, dividerMap(3))) // 1 - val csa5 = addition.csa.c32(VecInit(csaIn1, csaIn2 ## true.B, dividerMap(4))) // 2 + val csa0InWidth = rWidth + radixLog2 + 1 + val csaIn1 = leftShift(partialReminderSum, radixLog2).head(csa0InWidth) + val csaIn2 = leftShift(partialReminderCarry, radixLog2).head(csa0InWidth) + + val csa1 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(0).head(csa0InWidth))) // -2 csain 10bit + val csa2 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(1).head(csa0InWidth))) // -1 + val csa3 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(2).head(csa0InWidth))) // 0 + val csa4 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(3).head(csa0InWidth))) // 1 + val csa5 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(4).head(csa0InWidth))) // 2 // qds - val rWidth: Int = 1 + radixLog2 + rTruncateWidth - val tables: Seq[Seq[Int]] = division.srt.SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS + + val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS val partialDivider: UInt = dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) val qdsOH0: UInt = QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)( @@ -97,13 +100,6 @@ class SRT16( val qds4SelectedQuotientOH: UInt = qds(csa4) // 1 val qds5SelectedQuotientOH: UInt = qds(csa5) // 2 - val csa0OutMap = VecInit((-2 to 2).map { - case -2 => csa1 - case -1 => csa2 - case 0 => csa3 - case 1 => csa4 - case 2 => csa5 - }) val qds1SelectedQuotientOHMap = VecInit((-2 to 2).map { case -2 => qds1SelectedQuotientOH case -1 => qds2SelectedQuotientOH @@ -113,8 +109,16 @@ class SRT16( }) val qdsOH1 = Mux1H(qdsOH0, qds1SelectedQuotientOHMap) // q_j+2 oneHot + val qds0sign = qdsOH0(ohWidth - 1, ohWidth / 2 + 1).orR val qds1sign = qdsOH1(ohWidth - 1, ohWidth / 2 + 1).orR - val csa0Out = Mux1H(qdsOH0, csa0OutMap) + + val csa0Out = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qds0sign, + Mux1H(qdsOH0, dividerMap) + ) + ) val csa1Out = addition.csa.c32( VecInit( leftShift(csa0Out(1), radixLog2).head(wLen - radixLog2), diff --git a/arithmetic/src/division/srt/srt4/QDS.scala b/arithmetic/src/division/srt/srt4/QDS.scala index eff9ab4..b3c7799 100644 --- a/arithmetic/src/division/srt/srt4/QDS.scala +++ b/arithmetic/src/division/srt/srt4/QDS.scala @@ -2,8 +2,7 @@ package division.srt.srt4 import chisel3._ import chisel3.util.BitPat -import chisel3.util.experimental.decode._ -import division.srt.SRTTable +import chisel3.util.experimental.decode.{TruthTable} import utils.extend class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { @@ -16,7 +15,7 @@ class QDSOutput(ohWidth: Int) extends Bundle { val selectedQuotientOH: UInt = UInt(ohWidth.W) } -class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { +class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module { // IO val input = IO(Input(new QDSInput(rWidth, partialDividerWidth))) val output = IO(Output(new QDSOutput(ohWidth))) @@ -44,7 +43,6 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { // ) // get from SRTTable. - val tables: Seq[Seq[Int]] = SRTTable(4, 2, 4, 4).tablesToQDS lazy val selectRom = VecInit(tables.map { case x => VecInit(x.map { @@ -55,7 +53,7 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { else (-x).toBinaryString ) .toString - .U + .U(rWidth.W) }) }) @@ -87,12 +85,13 @@ object QDS { def apply( rWidth: Int, ohWidth: Int, - partialDividerWidth: Int + partialDividerWidth: Int, + tables: Seq[Seq[Int]] )(partialReminderSum: UInt, partialReminderCarry: UInt, partialDivider: UInt ): UInt = { - val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth)) + val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables)) m.input.partialReminderSum := partialReminderSum m.input.partialReminderCarry := partialReminderCarry m.input.partialDivider := partialDivider diff --git a/arithmetic/src/division/srt/srt4/SRT4.scala b/arithmetic/src/division/srt/srt4/SRT4.scala index ca10fe9..4111a27 100644 --- a/arithmetic/src/division/srt/srt4/SRT4.scala +++ b/arithmetic/src/division/srt/srt4/SRT4.scala @@ -14,6 +14,7 @@ import utils.leftShift * dTruncateWidth = 4, rTruncateWidth = 8 * y^(xxx.xxxx), d^(0.1xxx) * -44/16 < y^ < 42/16 + * floor((-r*rho - 2^-t)_t) <= y^ <= floor((r*rho - ulp)_t) */ class SRT4( @@ -69,9 +70,9 @@ class SRT4( // qds val rWidth: Int = 1 + radixLog2 + rTruncateWidth - + val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS val selectedQuotientOH: UInt = - QDS(rWidth, ohWidth, dTruncateWidth - 1)( + QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)( leftShift(partialReminderSum, radixLog2).head(rWidth), leftShift(partialReminderCarry, radixLog2).head(rWidth), dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> *** diff --git a/arithmetic/src/division/srt/srt8/QDS.scala b/arithmetic/src/division/srt/srt8/QDS.scala index eff243a..513c0e5 100644 --- a/arithmetic/src/division/srt/srt8/QDS.scala +++ b/arithmetic/src/division/srt/srt8/QDS.scala @@ -1,9 +1,8 @@ package division.srt.srt8 import chisel3._ -import chisel3.util.{BitPat, ValidIO} -import chisel3.util.experimental.decode.{TruthTable, _} -import division.srt.SRTTable +import chisel3.util.{BitPat} +import chisel3.util.experimental.decode.{TruthTable} import utils.extend class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { @@ -16,14 +15,13 @@ class QDSOutput(ohWidth: Int) extends Bundle { val selectedQuotientOH: UInt = UInt(ohWidth.W) } -class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int) extends Module { +class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module { // IO val input = IO(Input(new QDSInput(rWidth, partialDividerWidth))) val output = IO(Output(new QDSOutput(ohWidth))) val columnSelect = input.partialDivider // Seq[Seq[Int]] => Vec[Vec[UInt]] - val tables: Seq[Seq[Int]] = SRTTable(8, 7, 4, 4).tablesToQDS lazy val selectRom = VecInit(tables.map { case x => VecInit(x.map { @@ -75,12 +73,13 @@ object QDS { def apply( rWidth: Int, ohWidth: Int, - partialDividerWidth: Int + partialDividerWidth: Int, + tables: Seq[Seq[Int]] )(partialReminderSum: UInt, partialReminderCarry: UInt, partialDivider: UInt ): UInt = { - val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth)) + val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables)) m.input.partialReminderSum := partialReminderSum m.input.partialReminderCarry := partialReminderCarry m.input.partialDivider := partialDivider diff --git a/arithmetic/src/division/srt/srt8/SRT8.scala b/arithmetic/src/division/srt/srt8/SRT8.scala index b0a889e..5eeb982 100644 --- a/arithmetic/src/division/srt/srt8/SRT8.scala +++ b/arithmetic/src/division/srt/srt8/SRT8.scala @@ -4,7 +4,7 @@ import division.srt._ import division.srt.SRTTable import chisel3._ import chisel3.util._ -import utils.leftShift +import utils.{leftShift} /** SRT8 * 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2 @@ -13,6 +13,8 @@ import utils.leftShift * dTruncateWidth = 4, rTruncateWidth = 4 * y^(xxxx.xxxx), d^(0.1xxx) * table from SRTTable + * -129/16 < y^ < 127/16 + * floor((-r*rho - 2^-t)_t) <= y^ <= floor((r*rho - ulp)_t) */ class SRT8( @@ -68,8 +70,9 @@ class SRT8( // qds val rWidth: Int = 1 + radixLog2 + rTruncateWidth + val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS val selectedQuotientOH: UInt = - QDS(rWidth, ohWidth, dTruncateWidth - 1)( + QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)( leftShift(partialReminderSum, radixLog2).head(rWidth), leftShift(partialReminderCarry, radixLog2).head(rWidth), dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> *** @@ -81,14 +84,14 @@ class SRT8( val qHigh: UInt = selectedQuotientOH(9, 5) val qLow: UInt = selectedQuotientOH(4, 0) // csa for SRT8 -> CSA32+CSA32 - val divideMap0 = VecInit((-2 to 2).map { + val dividerMap0 = VecInit((-2 to 2).map { case -2 => divider << 3 // -8 case -1 => divider << 2 // -4 case 0 => 0.U // 0 case 1 => Fill(2, 1.U(1.W)) ## ~(divider << 2) // 4 case 2 => Fill(1, 1.U(1.W)) ## ~(divider << 3) // 8 }) - val divideMap1 = VecInit((-2 to 2).map { + val dividerMap1 = VecInit((-2 to 2).map { case -2 => divider << 1 // -2 case -1 => divider // -1 case 0 => 0.U // 0 @@ -99,14 +102,14 @@ class SRT8( VecInit( leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, - Mux1H(qHigh, divideMap0) + Mux1H(qHigh, dividerMap0) ) ) val csa1 = addition.csa.c32( VecInit( csa0(1).head(wLen - radixLog2), leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, - Mux1H(qLow, divideMap1) + Mux1H(qLow, dividerMap1) ) ) diff --git a/arithmetic/tests/src/division/srt/SRT16Test.scala b/arithmetic/tests/src/division/srt/SRT16Test.scala index 3bbf36a..47ffa9b 100644 --- a/arithmetic/tests/src/division/srt/SRT16Test.scala +++ b/arithmetic/tests/src/division/srt/SRT16Test.scala @@ -69,9 +69,9 @@ object SRT16Test extends TestSuite with ChiselUtestTester { } } - testcase(16) + testcase(64) // for( i <- 1 to 50){ -// testcase(128) +// testcase(64) // } } } diff --git a/arithmetic/tests/src/division/srt/SRT4Test.scala b/arithmetic/tests/src/division/srt/SRT4Test.scala index 5a4d938..efb42dd 100644 --- a/arithmetic/tests/src/division/srt/SRT4Test.scala +++ b/arithmetic/tests/src/division/srt/SRT4Test.scala @@ -30,7 +30,7 @@ object SRT4Test extends TestSuite with ChiselUtestTester { } val zeroHeadDividend: Int = m - zeroCheck(dividend) val zeroHeadDivider: Int = m - zeroCheck(divider) - val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + 1 + val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + radixLog2 - 1 val noguard: Boolean = needComputerWidth % radixLog2 == 0 val counter: Int = (needComputerWidth + 1) / 2 if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0)) @@ -74,7 +74,7 @@ object SRT4Test extends TestSuite with ChiselUtestTester { } testcase(64) -// for( i <- 1 to 100){ +// for( i <- 1 to 50){ // testcase(64) // } } diff --git a/arithmetic/tests/src/division/srt/SRT8Test.scala b/arithmetic/tests/src/division/srt/SRT8Test.scala index b436aa8..317b8e0 100644 --- a/arithmetic/tests/src/division/srt/SRT8Test.scala +++ b/arithmetic/tests/src/division/srt/SRT8Test.scala @@ -70,7 +70,7 @@ object SRT8Test extends TestSuite with ChiselUtestTester { testcase(64) // for( i <- 1 to 50){ -// testcase(128) +// testcase(64) // } } } diff --git a/arithmetic/tests/src/division/srt/SRTSpec.scala b/arithmetic/tests/src/division/srt/SRTSpec.scala index cedf006..4309e09 100644 --- a/arithmetic/tests/src/division/srt/SRTSpec.scala +++ b/arithmetic/tests/src/division/srt/SRTSpec.scala @@ -9,9 +9,10 @@ import utils.extend object SRTSpec extends TestSuite{ override def tests: Tests = Tests { test("SRT should draw PD") { - val srt = SRTTable(8, 7, 4, 4) + val srt = SRTTable(4,2,4,4) +// println(srt.tables) // println(srt.tablesToQDS) - srt.dumpGraph(srt.pd, os.root / "tmp" / "srt8-7-4-4.png") + srt.dumpGraph(srt.pd, os.root / "tmp" / "srt4-2-4-4.png") } } } From ffdb455dff4ced61018453d846643160e8e4e7c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Sun, 12 Jun 2022 11:46:40 +0800 Subject: [PATCH 30/31] fix selectRom & add selection of Radix --- arithmetic/src/division/srt/SRT.scala | 34 +++++++ arithmetic/src/division/srt/SRTIO.scala | 25 +++++- arithmetic/src/division/srt/srt16/OTF.scala | 12 +-- arithmetic/src/division/srt/srt16/QDS.scala | 23 +---- arithmetic/src/division/srt/srt4/OTF.scala | 12 +-- arithmetic/src/division/srt/srt4/QDS.scala | 25 ++---- arithmetic/src/division/srt/srt8/OTF.scala | 12 +-- arithmetic/src/division/srt/srt8/QDS.scala | 27 ++---- arithmetic/src/division/srt/srt8/SRT8.scala | 2 +- arithmetic/src/utils/package.scala | 2 +- .../tests/src/division/srt/SRTTest.scala | 88 +++++++++++++++++++ 11 files changed, 166 insertions(+), 96 deletions(-) create mode 100644 arithmetic/src/division/srt/SRT.scala create mode 100644 arithmetic/tests/src/division/srt/SRTTest.scala diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala new file mode 100644 index 0000000..dd1d961 --- /dev/null +++ b/arithmetic/src/division/srt/SRT.scala @@ -0,0 +1,34 @@ +package division.srt + +import division.srt.srt4._ +import division.srt.srt8._ +import division.srt.srt16._ +import chisel3._ +import chisel3.util.{DecoupledIO, ValidIO} + +class SRT( + dividendWidth: Int, + dividerWidth: Int, + n: Int, // the longest width + radixLog2: Int = 2, + a: Int = 2, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4) + extends Module { + val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) + val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) + // select radix + if (radixLog2 == 2) { // SRT4 + val srt = Module(new SRT4(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth)) + srt.input <> input + output <> srt.output + } else if (radixLog2 == 3) { // SRT8 + val srt = Module(new SRT8(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth)) + srt.input <> input + output <> srt.output + } else if (radixLog2 == 4) { //SRT16 + val srt = Module(new SRT16(dividendWidth, dividerWidth, n, radixLog2 >> 1, a, dTruncateWidth, rTruncateWidth)) + srt.input <> input + output <> srt.output + } +} diff --git a/arithmetic/src/division/srt/SRTIO.scala b/arithmetic/src/division/srt/SRTIO.scala index ae8e7b0..417aaa5 100644 --- a/arithmetic/src/division/srt/SRTIO.scala +++ b/arithmetic/src/division/srt/SRTIO.scala @@ -2,7 +2,7 @@ package division.srt import chisel3._ import chisel3.util.log2Ceil - +// SRTIO class SRTInput(dividendWidth: Int, dividerWidth: Int, n: Int) extends Bundle { val dividend = UInt(dividendWidth.W) //.*********** val divider = UInt(dividerWidth.W) //.1********** @@ -13,3 +13,26 @@ class SRTOutput(reminderWidth: Int, quotientWidth: Int) extends Bundle { val reminder = UInt(reminderWidth.W) val quotient = UInt(quotientWidth.W) } + +//OTFIO +class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle { + val quotient = UInt(qWidth.W) + val quotientMinusOne = UInt(qWidth.W) + val selectedQuotientOH = UInt(ohWidth.W) +} + +class OTFOutput(qWidth: Int) extends Bundle { + val quotient = UInt(qWidth.W) + val quotientMinusOne = UInt(qWidth.W) +} + +// QDSIO +class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { + val partialReminderCarry: UInt = UInt(rWidth.W) + val partialReminderSum: UInt = UInt(rWidth.W) + val partialDivider: UInt = UInt(partialDividerWidth.W) +} + +class QDSOutput(ohWidth: Int) extends Bundle { + val selectedQuotientOH: UInt = UInt(ohWidth.W) +} diff --git a/arithmetic/src/division/srt/srt16/OTF.scala b/arithmetic/src/division/srt/srt16/OTF.scala index 88fd12f..1c17a8e 100644 --- a/arithmetic/src/division/srt/srt16/OTF.scala +++ b/arithmetic/src/division/srt/srt16/OTF.scala @@ -1,19 +1,9 @@ package division.srt.srt16 +import division.srt._ import chisel3._ import chisel3.util.Mux1H -class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle { - val quotient = UInt(qWidth.W) - val quotientMinusOne = UInt(qWidth.W) - val selectedQuotientOH = UInt(ohWidth.W) -} - -class OTFOutput(qWidth: Int) extends Bundle { - val quotient = UInt(qWidth.W) - val quotientMinusOne = UInt(qWidth.W) -} - class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module { val input = IO(Input(new OTFInput(qWidth, ohWidth))) val output = IO(Output(new OTFOutput(qWidth))) diff --git a/arithmetic/src/division/srt/srt16/QDS.scala b/arithmetic/src/division/srt/srt16/QDS.scala index f615d93..bca9e6e 100644 --- a/arithmetic/src/division/srt/srt16/QDS.scala +++ b/arithmetic/src/division/srt/srt16/QDS.scala @@ -1,19 +1,11 @@ package division.srt.srt16 +import division.srt._ import chisel3._ import chisel3.util.BitPat +import chisel3.util.BitPat.bitPatToUInt import chisel3.util.experimental.decode._ -import utils.extend - -class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { - val partialReminderCarry: UInt = UInt(rWidth.W) - val partialReminderSum: UInt = UInt(rWidth.W) - val partialDivider: UInt = UInt(partialDividerWidth.W) -} - -class QDSOutput(ohWidth: Int) extends Bundle { - val selectedQuotientOH: UInt = UInt(ohWidth.W) -} +import utils.{extend, sIntToBitPat} class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module { // IO @@ -24,14 +16,7 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[I lazy val selectRom = VecInit(tables.map { case x => VecInit(x.map { - case x => - new StringBuffer("b") - .append( - if ((-x).toBinaryString.length >= rWidth) (-x).toBinaryString.reverse.substring(0, rWidth).reverse - else (-x).toBinaryString - ) - .toString - .U(rWidth.W) + case x => bitPatToUInt(sIntToBitPat(-x, rWidth)) }) }) diff --git a/arithmetic/src/division/srt/srt4/OTF.scala b/arithmetic/src/division/srt/srt4/OTF.scala index 3e94a10..ffe3830 100644 --- a/arithmetic/src/division/srt/srt4/OTF.scala +++ b/arithmetic/src/division/srt/srt4/OTF.scala @@ -1,19 +1,9 @@ package division.srt.srt4 +import division.srt._ import chisel3._ import chisel3.util.Mux1H -class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle { - val quotient = UInt(qWidth.W) - val quotientMinusOne = UInt(qWidth.W) - val selectedQuotientOH = UInt(ohWidth.W) -} - -class OTFOutput(qWidth: Int) extends Bundle { - val quotient = UInt(qWidth.W) - val quotientMinusOne = UInt(qWidth.W) -} - class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module { val input = IO(Input(new OTFInput(qWidth, ohWidth))) val output = IO(Output(new OTFOutput(qWidth))) diff --git a/arithmetic/src/division/srt/srt4/QDS.scala b/arithmetic/src/division/srt/srt4/QDS.scala index b3c7799..8897761 100644 --- a/arithmetic/src/division/srt/srt4/QDS.scala +++ b/arithmetic/src/division/srt/srt4/QDS.scala @@ -1,19 +1,11 @@ package division.srt.srt4 +import division.srt._ import chisel3._ import chisel3.util.BitPat -import chisel3.util.experimental.decode.{TruthTable} -import utils.extend - -class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { - val partialReminderCarry: UInt = UInt(rWidth.W) - val partialReminderSum: UInt = UInt(rWidth.W) - val partialDivider: UInt = UInt(partialDividerWidth.W) -} - -class QDSOutput(ohWidth: Int) extends Bundle { - val selectedQuotientOH: UInt = UInt(ohWidth.W) -} +import chisel3.util.BitPat.bitPatToUInt +import chisel3.util.experimental.decode.TruthTable +import utils.{extend, sIntToBitPat} class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module { // IO @@ -46,14 +38,7 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[I lazy val selectRom = VecInit(tables.map { case x => VecInit(x.map { - case x => - new StringBuffer("b") - .append( - if ((-x).toBinaryString.length >= rWidth) (-x).toBinaryString.reverse.substring(0, rWidth).reverse - else (-x).toBinaryString - ) - .toString - .U(rWidth.W) + case x => bitPatToUInt(sIntToBitPat(-x, rWidth)) }) }) diff --git a/arithmetic/src/division/srt/srt8/OTF.scala b/arithmetic/src/division/srt/srt8/OTF.scala index 88196ef..cefff67 100644 --- a/arithmetic/src/division/srt/srt8/OTF.scala +++ b/arithmetic/src/division/srt/srt8/OTF.scala @@ -1,19 +1,9 @@ package division.srt.srt8 +import division.srt._ import chisel3._ import chisel3.util.Mux1H -class OTFInput(qWidth: Int, ohWidth: Int) extends Bundle { - val quotient = UInt(qWidth.W) - val quotientMinusOne = UInt(qWidth.W) - val selectedQuotientOH = UInt(ohWidth.W) -} - -class OTFOutput(qWidth: Int) extends Bundle { - val quotient = UInt(qWidth.W) - val quotientMinusOne = UInt(qWidth.W) -} - class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module { val input = IO(Input(new OTFInput(qWidth, ohWidth))) val output = IO(Output(new OTFOutput(qWidth))) diff --git a/arithmetic/src/division/srt/srt8/QDS.scala b/arithmetic/src/division/srt/srt8/QDS.scala index 513c0e5..c306db6 100644 --- a/arithmetic/src/division/srt/srt8/QDS.scala +++ b/arithmetic/src/division/srt/srt8/QDS.scala @@ -1,19 +1,11 @@ package division.srt.srt8 +import division.srt._ import chisel3._ -import chisel3.util.{BitPat} -import chisel3.util.experimental.decode.{TruthTable} -import utils.extend - -class QDSInput(rWidth: Int, partialDividerWidth: Int) extends Bundle { - val partialReminderCarry: UInt = UInt(rWidth.W) - val partialReminderSum: UInt = UInt(rWidth.W) - val partialDivider: UInt = UInt(partialDividerWidth.W) -} - -class QDSOutput(ohWidth: Int) extends Bundle { - val selectedQuotientOH: UInt = UInt(ohWidth.W) -} +import chisel3.util.BitPat +import chisel3.util.BitPat.bitPatToUInt +import chisel3.util.experimental.decode.TruthTable +import utils.{extend, sIntToBitPat} class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module { // IO @@ -25,14 +17,7 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[I lazy val selectRom = VecInit(tables.map { case x => VecInit(x.map { - case x => - new StringBuffer("b") - .append( - if ((-x).toBinaryString.length >= rWidth) (-x).toBinaryString.reverse.substring(0, rWidth).reverse - else (-x).toBinaryString - ) - .toString - .U(rWidth.W) + case x => bitPatToUInt(sIntToBitPat(-x, rWidth)) }) }) diff --git a/arithmetic/src/division/srt/srt8/SRT8.scala b/arithmetic/src/division/srt/srt8/SRT8.scala index 5eeb982..a120a9b 100644 --- a/arithmetic/src/division/srt/srt8/SRT8.scala +++ b/arithmetic/src/division/srt/srt8/SRT8.scala @@ -4,7 +4,7 @@ import division.srt._ import division.srt.SRTTable import chisel3._ import chisel3.util._ -import utils.{leftShift} +import utils.leftShift /** SRT8 * 1/2 <= d < 1, 1/2 < rho <=1, 0 < q < 2 diff --git a/arithmetic/src/utils/package.scala b/arithmetic/src/utils/package.scala index 2edf1ab..b3e3d73 100644 --- a/arithmetic/src/utils/package.scala +++ b/arithmetic/src/utils/package.scala @@ -55,7 +55,7 @@ package object utils { else BitPat((x + (1 << w)).U(w.W)) } - + // left shift and keep the width of Bits def leftShift(x: Bits, n: Int): UInt = { val length: Int = x.getWidth diff --git a/arithmetic/tests/src/division/srt/SRTTest.scala b/arithmetic/tests/src/division/srt/SRTTest.scala new file mode 100644 index 0000000..588f44e --- /dev/null +++ b/arithmetic/tests/src/division/srt/SRTTest.scala @@ -0,0 +1,88 @@ +package division.srt + +import chisel3._ +import chisel3.tester.{ChiselUtestTester, testableClock, testableData} +import utest._ + +import scala.util.Random + +object SRTTest extends TestSuite with ChiselUtestTester { + def tests: Tests = Tests { + test("SRT should pass") { + def testcase(n: Int = 64, + radixLog2: Int = 4, + a: Int = 2, + dTruncateWidth: Int = 4, + rTruncateWidth: Int = 4): Unit ={ + //tips + println("SRT%d(width = %d, a = %d, dTruncateWidth = %d, rTruncateWidth = %d) should pass ".format( + 1 << radixLog2 , n , a, dTruncateWidth, rTruncateWidth)) + // parameters + val m: Int = n - 1 + val p: Int = Random.nextInt(m - radixLog2 +1) //order to offer guardwidth + val q: Int = Random.nextInt(m - radixLog2 +1) + val dividend: BigInt = BigInt(p, Random) + val divider: BigInt = BigInt(q, Random) + // val dividend: BigInt = BigInt("65") + // val divider: BigInt = BigInt("1") + def zeroCheck(x: BigInt): Int = { + var flag = false + var k: Int = m + while (!flag && (k >= -1)) { + flag = ((BigInt(1) << k) & x) != 0 + k = k - 1 + } + k + 1 + } + val zeroHeadDividend: Int = m - zeroCheck(dividend) + val zeroHeadDivider: Int = m - zeroCheck(divider) + val needComputerWidth: Int = zeroHeadDivider - zeroHeadDividend + 1 + (if(radixLog2 == 4) 2 else radixLog2) -1 + val noguard: Boolean = needComputerWidth % radixLog2 == 0 + val guardWidth: Int = if (noguard) 0 else radixLog2 - needComputerWidth % radixLog2 + val counter: Int = (needComputerWidth + guardWidth) / radixLog2 + if ((divider == 0) || (divider > dividend) || (needComputerWidth <= 0)) + return + val quotient: BigInt = dividend / divider + val remainder: BigInt = dividend % divider + val leftShiftWidthDividend: Int = zeroHeadDividend - guardWidth + val leftShiftWidthDivider: Int = zeroHeadDivider + // println("dividend = %8x, dividend = %d ".format(dividend, dividend)) + // println("divider = %8x, divider = %d".format(divider, divider)) + // println("zeroHeadDividend = %d, dividend << zeroHeadDividend = %d".format(zeroHeadDividend, dividend << leftShiftWidthDividend)) + // println("zeroHeadDivider = %d, divider << zeroHeadDivider = %d".format(zeroHeadDivider, divider << leftShiftWidthDivider)) + // println("quotient = %d, remainder = %d".format(quotient, remainder)) + // println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) + // test + testCircuit(new SRT(n, n, n, radixLog2, a, dTruncateWidth, rTruncateWidth), + Seq(chiseltest.internal.NoThreadingAnnotation, + chiseltest.simulator.WriteVcdAnnotation)) { + dut: SRT => + dut.clock.setTimeout(0) + dut.input.valid.poke(true.B) + dut.input.bits.dividend.poke((dividend << leftShiftWidthDividend).U) + dut.input.bits.divider.poke((divider << leftShiftWidthDivider).U) + dut.input.bits.counter.poke(counter.U) + dut.clock.step() + dut.input.valid.poke(false.B) + var flag = false + for (a <- 1 to 1000 if !flag) { + if (dut.output.valid.peek().litValue == 1) { + flag = true + println(dut.output.bits.quotient.peek().litValue) + println(dut.output.bits.reminder.peek().litValue) + utest.assert(dut.output.bits.quotient.peek().litValue == quotient) + utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) + } + dut.clock.step() + } + utest.assert(flag) + dut.clock.step(scala.util.Random.nextInt(5)) + } + } + testcase(64) + for( i <- 1 to 50){ + testcase(64,3,7,4) + } + } + } +} \ No newline at end of file From 2220115440b8f6ccd206e658f04b4bd9557eafd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9C=C3=A2wissy?= <1536771081@qq.com> Date: Tue, 14 Jun 2022 18:13:48 +0800 Subject: [PATCH 31/31] add selection of a & fix srt8 --- arithmetic/src/division/srt/SRT.scala | 36 +++- arithmetic/src/division/srt/SRTTable.scala | 2 +- arithmetic/src/division/srt/srt16/OTF.scala | 4 +- arithmetic/src/division/srt/srt16/SRT16.scala | 1 - arithmetic/src/division/srt/srt4/OTF.scala | 62 +++++-- arithmetic/src/division/srt/srt4/QDS.scala | 40 +++-- arithmetic/src/division/srt/srt4/SRT4.scala | 81 ++++++--- arithmetic/src/division/srt/srt8/OTF.scala | 117 ++++++++++--- arithmetic/src/division/srt/srt8/QDS.scala | 110 +++++++++--- arithmetic/src/division/srt/srt8/SRT8.scala | 163 ++++++++++++++---- arithmetic/src/utils/package.scala | 2 +- .../tests/src/division/srt/SRTSpec.scala | 4 +- .../tests/src/division/srt/SRTTest.scala | 18 +- 13 files changed, 483 insertions(+), 157 deletions(-) diff --git a/arithmetic/src/division/srt/SRT.scala b/arithmetic/src/division/srt/SRT.scala index dd1d961..539afb4 100644 --- a/arithmetic/src/division/srt/SRT.scala +++ b/arithmetic/src/division/srt/SRT.scala @@ -15,9 +15,35 @@ class SRT( dTruncateWidth: Int = 4, rTruncateWidth: Int = 4) extends Module { +// val x = (radixLog2, a, dTruncateWidth) +// val tips = x match { +// case (2,2,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// case (2,2,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// case (2,2,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// +// case (3,4,6) => require(rTruncateWidth >= 7, "rTruncateWidth need >= 7") +// case (3,4,7) => require(rTruncateWidth >= 6, "rTruncateWidth need >= 6") +// +// case (3,5,5) => require(rTruncateWidth >= 5, "rTruncateWidth need >= 5") +// case (3,5,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// +// case (3,6,4) => require(rTruncateWidth >= 6, "rTruncateWidth need >= 6") +// case (3,6,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// +// case (3,7,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// case (3,7,5) => require(rTruncateWidth >= 3, "rTruncateWidth need >= 3") +// +// case (4,2,4) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// case (4,2,5) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// case (4,2,6) => require(rTruncateWidth >= 4, "rTruncateWidth need >= 4") +// +// case _ => println("this srt is not supported") +// } + val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) - // select radix + +// select radix if (radixLog2 == 2) { // SRT4 val srt = Module(new SRT4(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth)) srt.input <> input @@ -31,4 +57,12 @@ class SRT( srt.input <> input output <> srt.output } + +// val srt = radixLog2 match { +// case 2 => Module(new SRT4(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth)) +// case 3 => Module(new SRT8(dividendWidth, dividerWidth, n, radixLog2, a, dTruncateWidth, rTruncateWidth)) +// case 4 => Module(new SRT16(dividendWidth, dividerWidth, n, radixLog2 >> 1, a, dTruncateWidth, rTruncateWidth)) +// } +// srt.input <> input +// output <> srt.output } diff --git a/arithmetic/src/division/srt/SRTTable.scala b/arithmetic/src/division/srt/SRTTable.scala index b76c3a0..da843da 100644 --- a/arithmetic/src/division/srt/SRTTable.scala +++ b/arithmetic/src/division/srt/SRTTable.scala @@ -87,7 +87,7 @@ case class SRTTable( } // TODO: select a Constant from each m, then offer the table to QDS. - // select rule: symmetry and draw a line parallel to the X-axis, how define the rule + // todo: ? select rule: symmetry and draw a line parallel to the X-axis, how define the rule lazy val tablesToQDS: Seq[Seq[Int]] = { (aMin.toInt to aMax.toInt).drop(1).map { k => k -> dSet.dropRight(1).map { d => diff --git a/arithmetic/src/division/srt/srt16/OTF.scala b/arithmetic/src/division/srt/srt16/OTF.scala index 1c17a8e..d9d6d8b 100644 --- a/arithmetic/src/division/srt/srt16/OTF.scala +++ b/arithmetic/src/division/srt/srt16/OTF.scala @@ -40,11 +40,11 @@ object OTF { )(quotient: UInt, quotientMinusOne: UInt, selectedQuotientOH: UInt - ): Seq[UInt] = { + ): Vec[UInt] = { val m = Module(new OTF(radixLog2, qWidth, ohWidth)) m.input.quotient := quotient m.input.quotientMinusOne := quotientMinusOne m.input.selectedQuotientOH := selectedQuotientOH - Seq(m.output.quotient, m.output.quotientMinusOne) + VecInit(m.output.quotient, m.output.quotientMinusOne) } } diff --git a/arithmetic/src/division/srt/srt16/SRT16.scala b/arithmetic/src/division/srt/srt16/SRT16.scala index 210ba49..7ea12f1 100644 --- a/arithmetic/src/division/srt/srt16/SRT16.scala +++ b/arithmetic/src/division/srt/srt16/SRT16.scala @@ -76,7 +76,6 @@ class SRT16( val csa5 = addition.csa.c32(VecInit(csaIn1, csaIn2, dividerMap(4).head(csa0InWidth))) // 2 // qds - val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS val partialDivider: UInt = dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) val qdsOH0: UInt = diff --git a/arithmetic/src/division/srt/srt4/OTF.scala b/arithmetic/src/division/srt/srt4/OTF.scala index ffe3830..f89106a 100644 --- a/arithmetic/src/division/srt/srt4/OTF.scala +++ b/arithmetic/src/division/srt/srt4/OTF.scala @@ -4,29 +4,52 @@ import division.srt._ import chisel3._ import chisel3.util.Mux1H -class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module { +class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int, a: Int) extends Module { val input = IO(Input(new OTFInput(qWidth, ohWidth))) val output = IO(Output(new OTFOutput(qWidth))) val radix: Int = 1 << radixLog2 // datapath // q_j+1 in this circle, only for srt4 - val qNext: UInt = Mux1H( - Seq( - input.selectedQuotientOH(0) -> "b110".U, - input.selectedQuotientOH(1) -> "b111".U, - input.selectedQuotientOH(2) -> "b000".U, - input.selectedQuotientOH(3) -> "b001".U, - input.selectedQuotientOH(4) -> "b010".U - ) - ) - // val cShiftQ: Bool = qNext >= 0.U // val cShiftQM: Bool = qNext <= 0.U - val cShiftQ: Bool = input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR - val cShiftQM: Bool = input.selectedQuotientOH(ohWidth / 2, 0).orR - val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0) - val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0) + val qNext: UInt = Wire(UInt(3.W)) + val cShiftQ, cShiftQM = Wire(Bool()) + + if (a == 2) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b110".U, //-2 + input.selectedQuotientOH(1) -> "b111".U, //-1 + input.selectedQuotientOH(2) -> "b000".U, // 0 + input.selectedQuotientOH(3) -> "b001".U, // 1 + input.selectedQuotientOH(4) -> "b010".U // 2 + ) + ) + cShiftQ := input.selectedQuotientOH(ohWidth - 1, ohWidth / 2).orR + cShiftQM := input.selectedQuotientOH(ohWidth / 2, 0).orR + } else if (a == 3) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b111".U, //-1 + input.selectedQuotientOH(1) -> "b000".U, // 0 + input.selectedQuotientOH(2) -> "b001".U // 1 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(3) -> "b110".U, // -2 + input.selectedQuotientOH(4) -> "b000".U, // 0 + input.selectedQuotientOH(5) -> "b010".U // 2 + ) + ) + cShiftQ := input.selectedQuotientOH(5) || + (input.selectedQuotientOH(4) && input.selectedQuotientOH(2, 1).orR) + cShiftQM := input.selectedQuotientOH(3) || + (input.selectedQuotientOH(4) && input.selectedQuotientOH(1, 0).orR) + } + + val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0) + val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0) output.quotient := Mux(cShiftQ, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qIn output.quotientMinusOne := Mux(!cShiftQM, input.quotient, input.quotientMinusOne)(qWidth - radixLog2, 0) ## qmIn @@ -36,15 +59,16 @@ object OTF { def apply( radixLog2: Int, qWidth: Int, - ohWidth: Int + ohWidth: Int, + a: Int )(quotient: UInt, quotientMinusOne: UInt, selectedQuotientOH: UInt - ): Seq[UInt] = { - val m = Module(new OTF(radixLog2, qWidth, ohWidth)) + ): Vec[UInt] = { + val m = Module(new OTF(radixLog2, qWidth, ohWidth, a)) m.input.quotient := quotient m.input.quotientMinusOne := quotientMinusOne m.input.selectedQuotientOH := selectedQuotientOH - Seq(m.output.quotient, m.output.quotientMinusOne) + VecInit(m.output.quotient, m.output.quotientMinusOne) } } diff --git a/arithmetic/src/division/srt/srt4/QDS.scala b/arithmetic/src/division/srt/srt4/QDS.scala index 8897761..fa40757 100644 --- a/arithmetic/src/division/srt/srt4/QDS.scala +++ b/arithmetic/src/division/srt/srt4/QDS.scala @@ -7,7 +7,7 @@ import chisel3.util.BitPat.bitPatToUInt import chisel3.util.experimental.decode.TruthTable import utils.{extend, sIntToBitPat} -class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module { +class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]], a: Int) extends Module { // IO val input = IO(Input(new QDSInput(rWidth, partialDividerWidth))) val output = IO(Output(new QDSOutput(ohWidth))) @@ -54,15 +54,30 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[I // decoder or findFirstOne here, prefer decoder, the decoder only for srt4 output.selectedQuotientOH := chisel3.util.experimental.decode.decoder( selectPoints, - TruthTable( - Seq( - BitPat("b???0") -> BitPat("b10000"), //2 - BitPat("b??01") -> BitPat("b01000"), //1 - BitPat("b?011") -> BitPat("b00100"), //0 - BitPat("b0111") -> BitPat("b00010") //-1 - ), - BitPat("b00001") //-2 - ) + a match { + case 2 => + TruthTable( + Seq( + BitPat("b???0") -> BitPat("b10000"), //2 + BitPat("b??01") -> BitPat("b01000"), //1 + BitPat("b?011") -> BitPat("b00100"), //0 + BitPat("b0111") -> BitPat("b00010") //-1 + ), + BitPat("b00001") //-2 + ) + case 3 => + TruthTable( + Seq( // 2 0 -2 1 0 -1 + BitPat("b??_???0") -> BitPat("b100_100"), //3 = 2 + 1 + BitPat("b??_??01") -> BitPat("b100_010"), //2 = 2 + 0 + BitPat("b??_?011") -> BitPat("b010_100"), //1 = 0 + 1 + BitPat("b??_0111") -> BitPat("b010_010"), //0 = 0 + 0 + BitPat("b?0_1111") -> BitPat("b010_001"), //-1 = 0 + -1 + BitPat("b01_1111") -> BitPat("b001_010") //-2 = -2 + 0 + ), + BitPat("b001_001") //-3 = -2 + -1 + ) + } ) } @@ -71,12 +86,13 @@ object QDS { rWidth: Int, ohWidth: Int, partialDividerWidth: Int, - tables: Seq[Seq[Int]] + tables: Seq[Seq[Int]], + a: Int )(partialReminderSum: UInt, partialReminderCarry: UInt, partialDivider: UInt ): UInt = { - val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables)) + val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables, a)) m.input.partialReminderSum := partialReminderSum m.input.partialReminderCarry := partialReminderCarry m.input.partialDivider := partialDivider diff --git a/arithmetic/src/division/srt/srt4/SRT4.scala b/arithmetic/src/division/srt/srt4/SRT4.scala index 4111a27..f10dfaa 100644 --- a/arithmetic/src/division/srt/srt4/SRT4.scala +++ b/arithmetic/src/division/srt/srt4/SRT4.scala @@ -5,6 +5,7 @@ import addition.csa.CarrySaveAdder import addition.csa.common.CSACompressor3_2 import chisel3._ import chisel3.util._ +import spire.math import utils.leftShift /** SRT4 @@ -26,11 +27,8 @@ class SRT4( dTruncateWidth: Int = 4, rTruncateWidth: Int = 4) extends Module { - - val xLen: Int = dividendWidth + radixLog2 + 1 - val wLen: Int = xLen + radixLog2 - val ohWidth: Int = 2 * a + 1 - + val xLen: Int = dividendWidth + radixLog2 + 1 + val wLen: Int = xLen + radixLog2 // IO val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) val output = IO(ValidIO(new SRTOutput(dividerWidth, dividendWidth))) @@ -41,9 +39,8 @@ class SRT4( val counterNext = Wire(UInt(log2Ceil(n).W)) // Control - // sign of select quotient, true -> negative, false -> positive // sign of Cycle, true -> (counter === 0.U) - val qdsSign, isLastCycle, enable: Bool = Wire(Bool()) + val isLastCycle, enable: Bool = Wire(Bool()) // State // because we need a CSA to minimize the critical path @@ -68,41 +65,77 @@ class SRT4( output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 4, radixLog2) output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) - // qds val rWidth: Int = 1 + radixLog2 + rTruncateWidth val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS + val ohWidth: Int = a match { + case 2 => 2 * a + 1 + case 3 => 6 + } + //qds val selectedQuotientOH: UInt = - QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)( + QDS(rWidth, ohWidth, dTruncateWidth - 1, tables, a)( leftShift(partialReminderSum, radixLog2).head(rWidth), leftShift(partialReminderCarry, radixLog2).head(rWidth), dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> *** ) - qdsSign := selectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR + // On-The-Fly conversion + val otf = OTF(radixLog2, n, ohWidth, a)(quotient, quotientMinusOne, selectedQuotientOH) - // csa for SRT4 -> CSA32 - val csa = Module(new CarrySaveAdder(CSACompressor3_2, xLen)) - csa.in(0) := leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2) - csa.in(1) := leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign - csa.in(2) := - Mux1H( - selectedQuotientOH, - //this is for SRT4, for SRT8 or SRT16, this should be changed - VecInit((-2 to 2).map { + val csa: Vec[UInt] = + if (a == 2) { // a == 2 + //csa + val dividerMap = VecInit((-2 to 2).map { case -2 => divider << 1 case -1 => divider case 0 => 0.U case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) }) - ) + val qdsSign = selectedQuotientOH(ohWidth - 1, ohWidth / 2 + 1).orR + addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign, + Mux1H(selectedQuotientOH, dividerMap) + ) + ) + } else { // a==3 + val qHigh = selectedQuotientOH(5, 3) + val qLow = selectedQuotientOH(2, 0) + val qds0Sign = qHigh.head(1) + val qds1Sign = qLow.head(1) - // On-The-Fly conversion - val otf = OTF(radixLog2, n, ohWidth)(quotient, quotientMinusOne, selectedQuotientOH) + // csa + val dividerHMap = VecInit((-1 to 1).map { + case -1 => divider << 1 // -2 + case 0 => 0.U // 0 + case 1 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) // 2 + }) + val dividerLMap = VecInit((-1 to 1).map { + case -1 => divider // -1 + case 0 => 0.U // 0 + case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider // 1 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qds0Sign, + Mux1H(qHigh, dividerHMap) + ) + ) + addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qds1Sign, + Mux1H(qLow, dividerLMap) + ) + ) + } dividerNext := Mux(input.fire, input.bits.divider, divider) counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) quotientNext := Mux(input.fire, 0.U, otf(0)) quotientMinusOneNext := Mux(input.fire, 0.U, otf(1)) - partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa.out(1) << radixLog2) - partialReminderCarryNext := Mux(input.fire, 0.U, csa.out(0) << 1 + radixLog2) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa(0) << 1 + radixLog2) } diff --git a/arithmetic/src/division/srt/srt8/OTF.scala b/arithmetic/src/division/srt/srt8/OTF.scala index cefff67..640a67f 100644 --- a/arithmetic/src/division/srt/srt8/OTF.scala +++ b/arithmetic/src/division/srt/srt8/OTF.scala @@ -4,37 +4,101 @@ import division.srt._ import chisel3._ import chisel3.util.Mux1H -class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int) extends Module { +class OTF(radixLog2: Int, qWidth: Int, ohWidth: Int, a: Int) extends Module { val input = IO(Input(new OTFInput(qWidth, ohWidth))) val output = IO(Output(new OTFOutput(qWidth))) val radix: Int = 1 << radixLog2 // datapath - // q_j+1 in this circle, only for srt8(a = 7) - val qNext: UInt = Mux1H( - Seq( - input.selectedQuotientOH(0) -> "b11110".U, // -2 - input.selectedQuotientOH(1) -> "b11111".U, // -1 - input.selectedQuotientOH(2) -> "b00000".U, // 0 - input.selectedQuotientOH(3) -> "b00001".U, // 1 - input.selectedQuotientOH(4) -> "b00010".U // 2 - ) - ) + Mux1H( - Seq( - input.selectedQuotientOH(5) -> "b11000".U, // -8 - input.selectedQuotientOH(6) -> "b11100".U, // -4 - input.selectedQuotientOH(7) -> "b00000".U, // 0 - input.selectedQuotientOH(8) -> "b00100".U, // 4 - input.selectedQuotientOH(9) -> "b01000".U // 8 - ) - ) - + // q_j+1 in this circle // val cShiftQ: Bool = qNext >= 0.U // val cShiftQM: Bool = qNext <= 0.U - val cShiftQ: Bool = input.selectedQuotientOH(9, 8).orR || + val qNext: UInt = Wire(UInt(5.W)) + val cShiftQ, cShiftQM: Bool = Wire(Bool()) + + if (a == 7) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b11110".U, // -2 + input.selectedQuotientOH(1) -> "b11111".U, // -1 + input.selectedQuotientOH(2) -> "b00000".U, // 0 + input.selectedQuotientOH(3) -> "b00001".U, // 1 + input.selectedQuotientOH(4) -> "b00010".U // 2 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(5) -> "b11000".U, // -8 + input.selectedQuotientOH(6) -> "b11100".U, // -4 + input.selectedQuotientOH(7) -> "b00000".U, // 0 + input.selectedQuotientOH(8) -> "b00100".U, // 4 + input.selectedQuotientOH(9) -> "b01000".U // 8 + ) + ) + cShiftQ := input.selectedQuotientOH(9, 8).orR || (input.selectedQuotientOH(7) && input.selectedQuotientOH(4, 2).orR) - val cShiftQM: Bool = input.selectedQuotientOH(6, 5).orR || + cShiftQM := input.selectedQuotientOH(6, 5).orR || (input.selectedQuotientOH(7) && input.selectedQuotientOH(2, 0).orR) + } else if (a == 6) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b11110".U, // -2 + input.selectedQuotientOH(1) -> "b11111".U, // -1 + input.selectedQuotientOH(2) -> "b00000".U, // 0 + input.selectedQuotientOH(3) -> "b00001".U, // 1 + input.selectedQuotientOH(4) -> "b00010".U // 2 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(5) -> "b11100".U, // -4 + input.selectedQuotientOH(6) -> "b00000".U, // 0 + input.selectedQuotientOH(7) -> "b00100".U // 4 + ) + ) + cShiftQ := input.selectedQuotientOH(7) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(4, 2).orR) + cShiftQM := input.selectedQuotientOH(5) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(2, 0).orR) + } else if (a == 5) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b11110".U, // -2 + input.selectedQuotientOH(1) -> "b11111".U, // -1 + input.selectedQuotientOH(2) -> "b00000".U, // 0 + input.selectedQuotientOH(3) -> "b00001".U, // 1 + input.selectedQuotientOH(4) -> "b00010".U // 2 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(5) -> "b11100".U, // -4 + input.selectedQuotientOH(6) -> "b00000".U, // 0 + input.selectedQuotientOH(7) -> "b00100".U // 4 + ) + ) + cShiftQ := input.selectedQuotientOH(7) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(4, 2).orR) + cShiftQM := input.selectedQuotientOH(5) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(2, 0).orR) + } else if (a == 4) { + qNext := Mux1H( + Seq( + input.selectedQuotientOH(0) -> "b11110".U, // -2 + input.selectedQuotientOH(1) -> "b11111".U, // -1 + input.selectedQuotientOH(2) -> "b00000".U, // 0 + input.selectedQuotientOH(3) -> "b00001".U, // 1 + input.selectedQuotientOH(4) -> "b00010".U // 2 + ) + ) + Mux1H( + Seq( + input.selectedQuotientOH(5) -> "b11110".U, // -2 + input.selectedQuotientOH(6) -> "b00000".U, // 0 + input.selectedQuotientOH(7) -> "b00010".U // 2 + ) + ) + cShiftQ := input.selectedQuotientOH(7) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(3, 2).orR) + cShiftQM := input.selectedQuotientOH(5) || + (input.selectedQuotientOH(6) && input.selectedQuotientOH(2, 1).orR) + } val qIn: UInt = Mux(cShiftQ, qNext, radix.U + qNext)(radixLog2 - 1, 0) val qmIn: UInt = Mux(!cShiftQM, qNext - 1.U, (radix - 1).U + qNext)(radixLog2 - 1, 0) @@ -47,15 +111,16 @@ object OTF { def apply( radixLog2: Int, qWidth: Int, - ohWidth: Int + ohWidth: Int, + a: Int )(quotient: UInt, quotientMinusOne: UInt, selectedQuotientOH: UInt - ): Seq[UInt] = { - val m = Module(new OTF(radixLog2, qWidth, ohWidth)) + ): Vec[UInt] = { + val m = Module(new OTF(radixLog2, qWidth, ohWidth, a)) m.input.quotient := quotient m.input.quotientMinusOne := quotientMinusOne m.input.selectedQuotientOH := selectedQuotientOH - Seq(m.output.quotient, m.output.quotientMinusOne) + VecInit(m.output.quotient, m.output.quotientMinusOne) } } diff --git a/arithmetic/src/division/srt/srt8/QDS.scala b/arithmetic/src/division/srt/srt8/QDS.scala index c306db6..e9adc83 100644 --- a/arithmetic/src/division/srt/srt8/QDS.scala +++ b/arithmetic/src/division/srt/srt8/QDS.scala @@ -7,7 +7,7 @@ import chisel3.util.BitPat.bitPatToUInt import chisel3.util.experimental.decode.TruthTable import utils.{extend, sIntToBitPat} -class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]]) extends Module { +class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[Int]], a: Int) extends Module { // IO val input = IO(Input(new QDSInput(rWidth, partialDividerWidth))) val output = IO(Output(new QDSOutput(ohWidth))) @@ -29,28 +29,91 @@ class QDS(rWidth: Int, ohWidth: Int, partialDividerWidth: Int, tables: Seq[Seq[I + extend(mk, adderWidth).asUInt).head(1) }).asUInt - // decoder or findFirstOne here, prefer decoder, the decoder only for srt8(a = 7) output.selectedQuotientOH := chisel3.util.experimental.decode.decoder( selectPoints, - TruthTable( - Seq( // 8 4 0 -4 -8__2 1 0 -1 -2 - BitPat("b??_????_????_???0") -> BitPat("b10000_00010"), // 7 = +8 + (-1) - BitPat("b??_????_????_??01") -> BitPat("b01000_10000"), // 6 = +4 + (+2) - BitPat("b??_????_????_?011") -> BitPat("b01000_01000"), // 5 = +4 + (+1) - BitPat("b??_????_????_0111") -> BitPat("b01000_00100"), // 4 = +4 + ( 0) - BitPat("b??_????_???0_1111") -> BitPat("b01000_00010"), // 3 = +4 + (-1) - BitPat("b??_????_??01_1111") -> BitPat("b00100_10000"), // 2 = 0 + (+2) - BitPat("b??_????_?011_1111") -> BitPat("b00100_01000"), // 1 = 0 + (+1) - BitPat("b??_????_0111_1111") -> BitPat("b00100_00100"), // 0 = 0 + ( 0) - BitPat("b??_???0_1111_1111") -> BitPat("b00100_00010"), //-1 = 0 + (-1) - BitPat("b??_??01_1111_1111") -> BitPat("b00100_00001"), //-2 = 0 + (-2) - BitPat("b??_?011_1111_1111") -> BitPat("b00010_01000"), //-3 = -4 + ( 1) - BitPat("b??_0111_1111_1111") -> BitPat("b00010_00100"), //-4 = -4 + ( 0) - BitPat("b?0_1111_1111_1111") -> BitPat("b00010_00010"), //-5 = -4 + (-1) - BitPat("b01_1111_1111_1111") -> BitPat("b00010_00001") //-6 = -4 + (-2) - ), - BitPat("b00001_01000") //-7 = -8 + (+1) - ) + a match { + case 7 => + TruthTable( + Seq( // 8 4 0 -4 -8__2 1 0 -1 -2 + BitPat("b??_????_????_???0") -> BitPat("b10000_00010"), // 7 = +8 + (-1) + BitPat("b??_????_????_??01") -> BitPat("b01000_10000"), // 6 = +4 + (+2) + BitPat("b??_????_????_?011") -> BitPat("b01000_01000"), // 5 = +4 + (+1) + BitPat("b??_????_????_0111") -> BitPat("b01000_00100"), // 4 = +4 + ( 0) + BitPat("b??_????_???0_1111") -> BitPat("b01000_00010"), // 3 = +4 + (-1) + BitPat("b??_????_??01_1111") -> BitPat("b00100_10000"), // 2 = 0 + (+2) + BitPat("b??_????_?011_1111") -> BitPat("b00100_01000"), // 1 = 0 + (+1) + BitPat("b??_????_0111_1111") -> BitPat("b00100_00100"), // 0 = 0 + ( 0) + BitPat("b??_???0_1111_1111") -> BitPat("b00100_00010"), //-1 = 0 + (-1) + BitPat("b??_??01_1111_1111") -> BitPat("b00100_00001"), //-2 = 0 + (-2) + BitPat("b??_?011_1111_1111") -> BitPat("b00010_01000"), //-3 = -4 + ( 1) + BitPat("b??_0111_1111_1111") -> BitPat("b00010_00100"), //-4 = -4 + ( 0) + BitPat("b?0_1111_1111_1111") -> BitPat("b00010_00010"), //-5 = -4 + (-1) + BitPat("b01_1111_1111_1111") -> BitPat("b00010_00001") // -6 = -4 + (-2) + ), + BitPat("b00001_01000") //-7 = -8 + (+1) + ) + case 6 => + TruthTable( + Seq( // 4 0 -4__2 1 0 -1 -2 + BitPat("b????_????_???0") -> BitPat("b100_10000"), // 6 = +4 + (+2) + BitPat("b????_????_??01") -> BitPat("b100_01000"), // 5 = +4 + (+1) + BitPat("b????_????_?011") -> BitPat("b100_00100"), // 4 = +4 + ( 0) + BitPat("b????_????_0111") -> BitPat("b100_00010"), // 3 = +4 + (-1) + BitPat("b????_???0_1111") -> BitPat("b010_10000"), // 2 = 0 + (+2) + BitPat("b????_??01_1111") -> BitPat("b010_01000"), // 1 = 0 + (+1) + BitPat("b????_?011_1111") -> BitPat("b010_00100"), // 0 = 0 + ( 0) + BitPat("b????_0111_1111") -> BitPat("b010_00010"), //-1 = 0 + (-1) + BitPat("b???0_1111_1111") -> BitPat("b010_00001"), //-2 = 0 + (-2) + BitPat("b??01_1111_1111") -> BitPat("b001_01000"), //-3 = -4 + ( 1) + BitPat("b?011_1111_1111") -> BitPat("b001_00100"), //-4 = -4 + ( 0) + BitPat("b0111_1111_1111") -> BitPat("b001_00010") // -5 = -4 + (-1) + ), + BitPat("b001_00001") //-6 = -4 + (-2) + ) + case 5 => + TruthTable( + Seq( // 4 0 -4__2 1 0 -1 -2 + BitPat("b??_????_???0") -> BitPat("b100_01000"), // 5 = +4 + (+1) + BitPat("b??_????_??01") -> BitPat("b100_00100"), // 4 = +4 + ( 0) + BitPat("b??_????_?011") -> BitPat("b100_00010"), // 3 = +4 + (-1) + BitPat("b??_????_0111") -> BitPat("b010_10000"), // 2 = 0 + (+2) + BitPat("b??_???0_1111") -> BitPat("b010_01000"), // 1 = 0 + (+1) + BitPat("b??_??01_1111") -> BitPat("b010_00100"), // 0 = 0 + ( 0) + BitPat("b??_?011_1111") -> BitPat("b010_00010"), //-1 = 0 + (-1) + BitPat("b??_0111_1111") -> BitPat("b010_00001"), //-2 = 0 + (-2) + BitPat("b?0_1111_1111") -> BitPat("b001_01000"), //-3 = -4 + ( 1) + BitPat("b01_1111_1111") -> BitPat("b001_00100") // -4 = -4 + ( 0) + ), + BitPat("b001_00010") //-5 = -4 + (-1) + ) + case 4 => + TruthTable( + Seq( // 2 0 -2__2 1 0 -1 -2 + BitPat("b????_???0") -> BitPat("b100_10000"), // 4 = +2 + ( 2) + BitPat("b????_??01") -> BitPat("b100_01000"), // 3 = +2 + ( 1) + BitPat("b????_?011") -> BitPat("b100_00100"), // 2 = 2 + ( 0) + BitPat("b????_0111") -> BitPat("b010_01000"), // 1 = 0 + (+1) + BitPat("b???0_1111") -> BitPat("b010_00100"), // 0 = 0 + ( 0) + BitPat("b??01_1111") -> BitPat("b010_00010"), //-1 = 0 + (-1) + BitPat("b?011_1111") -> BitPat("b001_00100"), //-2 = -2 + ( 0) + BitPat("b0111_1111") -> BitPat("b001_00010") // -3 = -2 + (-1) + ), + BitPat("b001_00001") //-4 = -2 + (-2) + ) + // TruthTable( + // Seq( // 4 0 -4__2 1 0 -1 -2 + // BitPat("b????_???0") -> BitPat("b100_00100"), // 4 = +4 + ( 0) + // BitPat("b????_??01") -> BitPat("b100_00010"), // 3 = +4 + (-1) + // BitPat("b????_?011") -> BitPat("b010_10000"), // 2 = 0 + (+2) + // BitPat("b????_0111") -> BitPat("b010_01000"), // 1 = 0 + (+1) + // BitPat("b???0_1111") -> BitPat("b010_00100"), // 0 = 0 + ( 0) + // BitPat("b??01_1111") -> BitPat("b010_00010"), //-1 = 0 + (-1) + // BitPat("b?011_1111") -> BitPat("b010_00001"), //-2 = 0 + (-2) + // BitPat("b0111_1111") -> BitPat("b001_01000") //-3 = -4 + ( 1) + // ), + // BitPat("b001_00100") //-4 = -4 + ( 0) + // ) + } ) } @@ -59,12 +122,13 @@ object QDS { rWidth: Int, ohWidth: Int, partialDividerWidth: Int, - tables: Seq[Seq[Int]] + tables: Seq[Seq[Int]], + a: Int )(partialReminderSum: UInt, partialReminderCarry: UInt, partialDivider: UInt ): UInt = { - val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables)) + val m = Module(new QDS(rWidth, ohWidth, partialDividerWidth, tables, a)) m.input.partialReminderSum := partialReminderSum m.input.partialReminderCarry := partialReminderCarry m.input.partialDivider := partialDivider diff --git a/arithmetic/src/division/srt/srt8/SRT8.scala b/arithmetic/src/division/srt/srt8/SRT8.scala index a120a9b..f45196d 100644 --- a/arithmetic/src/division/srt/srt8/SRT8.scala +++ b/arithmetic/src/division/srt/srt8/SRT8.scala @@ -27,9 +27,8 @@ class SRT8( rTruncateWidth: Int = 4) extends Module { - val xLen: Int = dividendWidth + radixLog2 + 1 - val wLen: Int = xLen + radixLog2 - val ohWidth: Int = 10 + val xLen: Int = dividendWidth + radixLog2 + 1 + val wLen: Int = xLen + radixLog2 // IO val input = IO(Flipped(DecoupledIO(new SRTInput(dividendWidth, dividerWidth, n)))) @@ -43,7 +42,7 @@ class SRT8( // Control // sign of select quotient, true -> negative, false -> positive // sign of Cycle, true -> (counter === 0.U) - val qdsSign0, qdsSign1, isLastCycle, enable: Bool = Wire(Bool()) + val isLastCycle, enable: Bool = Wire(Bool()) // State // because we need a CSA to minimize the critical path @@ -68,58 +67,150 @@ class SRT8( output.bits.reminder := Mux(needCorrect, remainderCorrect, remainderNoCorrect)(wLen - 5, radixLog2) output.bits.quotient := Mux(needCorrect, quotientMinusOne, quotient) - // qds val rWidth: Int = 1 + radixLog2 + rTruncateWidth val tables: Seq[Seq[Int]] = SRTTable(1 << radixLog2, a, dTruncateWidth, rTruncateWidth).tablesToQDS + + val ohWidth: Int = a match { + case 7 => 10 + case 6 => 8 + case 5 => 8 + case 4 => 8 + } + // qds val selectedQuotientOH: UInt = - QDS(rWidth, ohWidth, dTruncateWidth - 1, tables)( + QDS(rWidth, ohWidth, dTruncateWidth - 1, tables, a)( leftShift(partialReminderSum, radixLog2).head(rWidth), leftShift(partialReminderCarry, radixLog2).head(rWidth), dividerNext.head(dTruncateWidth)(dTruncateWidth - 2, 0) //.1********* -> 1*** -> *** ) + // On-The-Fly conversion + val otf = OTF(radixLog2, n, ohWidth, a)(quotient, quotientMinusOne, selectedQuotientOH) - qdsSign0 := selectedQuotientOH(9, 8).orR - qdsSign1 := selectedQuotientOH(4, 3).orR - - val qHigh: UInt = selectedQuotientOH(9, 5) - val qLow: UInt = selectedQuotientOH(4, 0) - // csa for SRT8 -> CSA32+CSA32 - val dividerMap0 = VecInit((-2 to 2).map { - case -2 => divider << 3 // -8 - case -1 => divider << 2 // -4 - case 0 => 0.U // 0 - case 1 => Fill(2, 1.U(1.W)) ## ~(divider << 2) // 4 - case 2 => Fill(1, 1.U(1.W)) ## ~(divider << 3) // 8 - }) - val dividerMap1 = VecInit((-2 to 2).map { + val dividerLMap = VecInit((-2 to 2).map { case -2 => divider << 1 // -2 case -1 => divider // -1 case 0 => 0.U // 0 case 1 => Fill(1 + radixLog2, 1.U(1.W)) ## ~divider // 1 case 2 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) // 2 }) - val csa0 = addition.csa.c32( - VecInit( - leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), - leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, - Mux1H(qHigh, dividerMap0) + + if (a == 7) { + val qHigh: UInt = selectedQuotientOH(9, 5) + val qLow: UInt = selectedQuotientOH(4, 0) + val qdsSign0: Bool = qHigh.head(2).orR + val qdsSign1: Bool = qLow.head(2).orR + // csa for SRT8 -> CSA32+CSA32 + val dividerHMap = VecInit((-2 to 2).map { + case -2 => divider << 3 // -8 + case -1 => divider << 2 // -4 + case 0 => 0.U // 0 + case 1 => Fill(2, 1.U(1.W)) ## ~(divider << 2) // 4 + case 2 => Fill(1, 1.U(1.W)) ## ~(divider << 3) // 8 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, + Mux1H(qHigh, dividerHMap) + ) ) - ) - val csa1 = addition.csa.c32( - VecInit( - csa0(1).head(wLen - radixLog2), - leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, - Mux1H(qLow, dividerMap1) + val csa1 = addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, + Mux1H(qLow, dividerLMap) + ) ) - ) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1(0) << 1 + radixLog2) + } else if (a == 6) { + val qHigh: UInt = selectedQuotientOH(7, 5) + val qLow: UInt = selectedQuotientOH(4, 0) + val qdsSign0: Bool = qHigh.head(1).asBool + val qdsSign1: Bool = qLow.head(2).orR - // On-The-Fly conversion - val otf = OTF(radixLog2, n, ohWidth)(quotient, quotientMinusOne, selectedQuotientOH) + // csa for SRT8 -> CSA32+CSA32 + val dividerHMap = VecInit((-1 to 1).map { + case -1 => divider << 2 // -4 + case 0 => 0.U // 0 + case 1 => Fill(2, 1.U(1.W)) ## ~(divider << 2) // 4 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, + Mux1H(qHigh, dividerHMap) + ) + ) + val csa1 = addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, + Mux1H(qLow, dividerLMap) + ) + ) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1(0) << 1 + radixLog2) + } else if (a == 5) { + val qHigh: UInt = selectedQuotientOH(7, 5) + val qLow: UInt = selectedQuotientOH(4, 0) + val qdsSign0: Bool = qHigh.head(1).asBool + val qdsSign1: Bool = qLow.head(2).orR + + // csa for SRT8 -> CSA32+CSA32 + val dividerHMap = VecInit((-1 to 1).map { + case -1 => divider << 2 // -4 + case 0 => 0.U // 0 + case 1 => Fill(2, 1.U(1.W)) ## ~(divider << 2) // 4 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, + Mux1H(qHigh, dividerHMap) + ) + ) + val csa1 = addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, + Mux1H(qLow, dividerLMap) + ) + ) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1(0) << 1 + radixLog2) + } else if (a == 4) { + val qHigh: UInt = selectedQuotientOH(7, 5) + val qLow: UInt = selectedQuotientOH(4, 0) + val qdsSign0: Bool = qHigh.head(1).asBool + val qdsSign1: Bool = qLow.head(2).orR + + // csa for SRT8 -> CSA32+CSA32 + val dividerHMap = VecInit((-1 to 1).map { + case -1 => divider << 1 // -2 + case 0 => 0.U // 0 + case 1 => Fill(radixLog2, 1.U(1.W)) ## ~(divider << 1) // 2 + }) + val csa0 = addition.csa.c32( + VecInit( + leftShift(partialReminderSum, radixLog2).head(wLen - radixLog2), + leftShift(partialReminderCarry, radixLog2).head(wLen - radixLog2 - 1) ## qdsSign0, + Mux1H(qHigh, dividerHMap) + ) + ) + val csa1 = addition.csa.c32( + VecInit( + csa0(1).head(wLen - radixLog2), + leftShift(csa0(0), 1).head(wLen - radixLog2 - 1) ## qdsSign1, + Mux1H(qLow, dividerLMap) + ) + ) + partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1(1) << radixLog2) + partialReminderCarryNext := Mux(input.fire, 0.U, csa1(0) << 1 + radixLog2) + } dividerNext := Mux(input.fire, input.bits.divider, divider) counterNext := Mux(input.fire, input.bits.counter, counter - 1.U) quotientNext := Mux(input.fire, 0.U, otf(0)) quotientMinusOneNext := Mux(input.fire, 0.U, otf(1)) - partialReminderSumNext := Mux(input.fire, input.bits.dividend, csa1(1) << radixLog2) - partialReminderCarryNext := Mux(input.fire, 0.U, csa1(0) << 1 + radixLog2) } diff --git a/arithmetic/src/utils/package.scala b/arithmetic/src/utils/package.scala index b3e3d73..2edf1ab 100644 --- a/arithmetic/src/utils/package.scala +++ b/arithmetic/src/utils/package.scala @@ -55,7 +55,7 @@ package object utils { else BitPat((x + (1 << w)).U(w.W)) } - + // left shift and keep the width of Bits def leftShift(x: Bits, n: Int): UInt = { val length: Int = x.getWidth diff --git a/arithmetic/tests/src/division/srt/SRTSpec.scala b/arithmetic/tests/src/division/srt/SRTSpec.scala index 4309e09..2d9c960 100644 --- a/arithmetic/tests/src/division/srt/SRTSpec.scala +++ b/arithmetic/tests/src/division/srt/SRTSpec.scala @@ -9,10 +9,10 @@ import utils.extend object SRTSpec extends TestSuite{ override def tests: Tests = Tests { test("SRT should draw PD") { - val srt = SRTTable(4,2,4,4) + val srt = SRTTable(8,5,5,5) // println(srt.tables) // println(srt.tablesToQDS) - srt.dumpGraph(srt.pd, os.root / "tmp" / "srt4-2-4-4.png") + srt.dumpGraph(srt.pd, os.root / "tmp" / "srt8-5-5-5.png") } } } diff --git a/arithmetic/tests/src/division/srt/SRTTest.scala b/arithmetic/tests/src/division/srt/SRTTest.scala index 588f44e..5c73846 100644 --- a/arithmetic/tests/src/division/srt/SRTTest.scala +++ b/arithmetic/tests/src/division/srt/SRTTest.scala @@ -46,12 +46,12 @@ object SRTTest extends TestSuite with ChiselUtestTester { val remainder: BigInt = dividend % divider val leftShiftWidthDividend: Int = zeroHeadDividend - guardWidth val leftShiftWidthDivider: Int = zeroHeadDivider - // println("dividend = %8x, dividend = %d ".format(dividend, dividend)) - // println("divider = %8x, divider = %d".format(divider, divider)) - // println("zeroHeadDividend = %d, dividend << zeroHeadDividend = %d".format(zeroHeadDividend, dividend << leftShiftWidthDividend)) - // println("zeroHeadDivider = %d, divider << zeroHeadDivider = %d".format(zeroHeadDivider, divider << leftShiftWidthDivider)) - // println("quotient = %d, remainder = %d".format(quotient, remainder)) - // println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) +// println("dividend = %8x, dividend = %d ".format(dividend, dividend)) +// println("divider = %8x, divider = %d".format(divider, divider)) +// println("zeroHeadDividend = %d, dividend << zeroHeadDividend = %d".format(zeroHeadDividend, dividend << leftShiftWidthDividend)) +// println("zeroHeadDivider = %d, divider << zeroHeadDivider = %d".format(zeroHeadDivider, divider << leftShiftWidthDivider)) +// println("quotient = %d, remainder = %d".format(quotient, remainder)) +// println("counter = %d, needComputerWidth = %d".format(counter, needComputerWidth)) // test testCircuit(new SRT(n, n, n, radixLog2, a, dTruncateWidth, rTruncateWidth), Seq(chiseltest.internal.NoThreadingAnnotation, @@ -69,7 +69,7 @@ object SRTTest extends TestSuite with ChiselUtestTester { if (dut.output.valid.peek().litValue == 1) { flag = true println(dut.output.bits.quotient.peek().litValue) - println(dut.output.bits.reminder.peek().litValue) + println(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider) utest.assert(dut.output.bits.quotient.peek().litValue == quotient) utest.assert(dut.output.bits.reminder.peek().litValue >> zeroHeadDivider == remainder) } @@ -79,9 +79,9 @@ object SRTTest extends TestSuite with ChiselUtestTester { dut.clock.step(scala.util.Random.nextInt(5)) } } - testcase(64) +// testcase(64) for( i <- 1 to 50){ - testcase(64,3,7,4) + testcase(n = 64, radixLog2 = 3, a = 7, dTruncateWidth = 4, rTruncateWidth = 4) } } }