diff --git a/src/main/scala/esp/Generator.scala b/src/main/scala/esp/Generator.scala index 2254a5c..e813431 100644 --- a/src/main/scala/esp/Generator.scala +++ b/src/main/scala/esp/Generator.scala @@ -14,15 +14,16 @@ package esp -import chisel3.Driver +import chisel3._ -import esp.examples.CounterAccelerator +import esp.examples.{CounterAccelerator, MedianFilter} object Generator { def main(args: Array[String]): Unit = { val examples: Seq[(String, String, () => AcceleratorWrapper)] = - Seq( ("CounterAccelerator", "Default", (a: Int) => new CounterAccelerator(a)) ) + Seq( ("CounterAccelerator", "Default", (a: Int) => new CounterAccelerator(a)), + ("MedianFilterAccelerator", "Default", (a: Int) => new MedianFilter(a, 1024, UInt(a.W)))) .flatMap( a => Seq(32, 64, 128).map(b => (a._1, s"${a._2}_dma$b", () => new AcceleratorWrapper(b, a._3))) ) examples.map { case (name, impl, gen) => diff --git a/src/main/scala/esp/examples/MedianFilter.scala b/src/main/scala/esp/examples/MedianFilter.scala new file mode 100644 index 0000000..a80c792 --- /dev/null +++ b/src/main/scala/esp/examples/MedianFilter.scala @@ -0,0 +1,138 @@ +// Copyright 2019 IBM +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package esp.examples + +import chisel3._ +import chisel3.util.{log2Up, Enum, Valid} + +import esp.{Config, ConfigIO, Implementation, Parameter, Specification} + +import sys.process._ + +trait MedianFilterSpecification extends Specification { + + override lazy val config: Config = Config( + name = "MedianFilter", + description = "Bitonic sort median filter", + memoryFootprintMiB = 1, + deviceId = 0xD, + param = Array( + Parameter( + name = "git Hash", + description = Some("Git short SHA hash of the repo used to generate this accelerator"), + value = Some(Integer.parseInt(("git log -n1 --format=%h" !!).filter(_ >= ' '), 16))), + Parameter( + name = "nRows", + description = Some("Number of rows in the input image")), + Parameter( + name = "nCols", + description = Some("Number of columns in the input image")))) + +} + +class MedianFilterState extends Bundle { + val dmaReadReq = Bool() + val dmaReadResp = Bool() + val dmaWriteReq = Bool() + val compute = Bool() +} + +class MedianFilterRequest(val configIO: ConfigIO) extends Bundle { + val config = configIO + /* @todo brittle */ + val readLength = UInt(64.W) + val respLength = UInt(64.W) + val state = new MedianFilterState +} + +object MedianFilterRequest { + + def init(configIO: ConfigIO) = { + val a = Wire(new Valid(new MedianFilterRequest(configIO))) + a.valid := false.B + a.bits.config.getElements.foreach(_ := DontCare) + a.bits.readLength := DontCare + a.bits.respLength := DontCare + a.bits.state := (new MedianFilterState).fromBits(0.U) + a + } + +} + +class MedianFilter[A <: Data](dmaWidth: Int, scratchpadSize: Int, dataType: A) + extends Implementation(dmaWidth) with MedianFilterSpecification { + + require(dmaWidth == dataType.getWidth, "MedianFilter requires data type width to match dmaWidth") + + override val implementationName: String = "Default_medianFilter" + dmaWidth + + val scratchpad = SyncReadMem[A](scratchpadSize, dataType.cloneType) + + val req = RegInit(MedianFilterRequest.init(io.config.get.cloneType)) + + /* Compute the maximum address that needs to be read via DMA. This is rounded up if needed. */ + def maxAddr: UInt = { + val (remainder, dividend) = (io.config.get("nRows").asUInt * io.config.get("nCols").asUInt) + .toBools + .splitAt(dmaWidth / dataType.getWidth - 1) + if (remainder.isEmpty) { + Vec(dividend).asUInt + } else { + Vec(dividend).asUInt + Mux(remainder.reduce(_ || _), 1.U, 0.U) + } + } + + when (!req.valid && io.enable) { + printf { + val a: Printable = io.config.get.elements + .map{ case (name, data) => p"[info] - $name: $data\n" } + .reduce(_ + _) + p"[info] enabled:\n" + a + } + req.valid := true.B + req.bits.config := io.config.get + req.bits.readLength := maxAddr + req.bits.respLength := 1.U + req.bits.state.getElements.map(_ := false.B) + req.bits.state.dmaReadReq := true.B + } + + io.dma.readControl.valid := req.valid && req.bits.state.dmaReadReq + io.dma.readControl.bits.index := 0.U + io.dma.readControl.bits.length := req.bits.readLength + when (io.dma.readControl.fire) { + req.bits.state.dmaReadReq := false.B + req.bits.state.dmaReadResp := true.B + } + + io.dma.readChannel.ready := req.valid && req.bits.state.dmaReadResp + when (req.valid && req.bits.state.dmaReadResp && io.dma.readChannel.fire) { + printf(p"[info] Read: ${io.dma.readChannel.bits}\n") + req.bits.respLength := req.bits.respLength + 1.U + scratchpad(req.bits.respLength) := io.dma.readChannel.bits + when (req.bits.respLength === req.bits.readLength) { + printf(p"[info] done\n") + req.bits.state.dmaReadResp := false.B + req.bits.state.compute := true.B + } + } + + io.done := req.bits.state.compute + when (req.valid && req.bits.state.compute) { + req.valid := false.B + req.bits.state.getElements.map(_ := 0.U) + } + +} diff --git a/src/main/scala/esp/examples/ShiftArray.scala b/src/main/scala/esp/examples/ShiftArray.scala new file mode 100644 index 0000000..39c10d5 --- /dev/null +++ b/src/main/scala/esp/examples/ShiftArray.scala @@ -0,0 +1,50 @@ +// Copyright 2019 IBM +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package esp.examples + +import chisel3._ +import chisel3.util.Valid +import chisel3.experimental.withReset + +class ShiftArrayIO[A <: Data](rows: Int, cols: Int, gen: A) extends Bundle { + + val in = Flipped(Valid(Vec(rows, gen.cloneType))) + val out = Valid(Vec(rows, Vec(cols, gen.cloneType))) + +} + +class ShiftArray[A <: Data](val rows: Int, val cols: Int, gen: A) extends Module { + + val io = IO(new ShiftArrayIO(rows, cols, gen)) + + val regArray = Seq.fill(rows)(Seq.fill(cols)(Reg(gen))) + val count = RegInit(0.U(cols.W)) + + /* Shift regArray left when the input fires */ + when (io.in.fire()) { + count := count ## 1.U + regArray + .zip(io.in.bits) + .foreach{ case (a, in) => a.foldLeft(in){ case (r, l) => l := r; l } } + } + + /* Route the regArray to the output */ + io.out.bits.flatten + .zip(regArray.flatten) + .foreach{ case (l, r) => l := r } + + io.out.valid := count.toBools.last + +} diff --git a/src/main/scala/esp/examples/ShiftArraySimple.scala b/src/main/scala/esp/examples/ShiftArraySimple.scala new file mode 100644 index 0000000..2165208 --- /dev/null +++ b/src/main/scala/esp/examples/ShiftArraySimple.scala @@ -0,0 +1,46 @@ +// Copyright 2019 IBM +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package esp.examples + +import chisel3._ +import chisel3.util.{log2Up, Valid} + +class ShiftArraySimpleIO[A <: Data](rows: Int, cols: Int, gen: A) extends Bundle { + + val in = Flipped(Valid(gen.cloneType)) + val out = Valid(Vec(rows * cols, gen.cloneType)) + +} + +class ShiftArraySimple[A <: Data](rows: Int, cols: Int, gen: A) extends Module { + + val io = IO(new ShiftArraySimpleIO(rows, cols, gen)) + + val reg: Seq[A] = Seq.fill(rows * cols)(Reg(gen.cloneType)) + + val fullMask: UInt = RegInit(0.U((rows * cols).W)) + + val counterValid: UInt = RegInit(0.U(rows.W)) + + when (io.in.fire()) { + reg.foldLeft(io.in.bits){ case (r, l) => l := r; l } + fullMask := fullMask ## 1.U + counterValid := Mux(counterValid === (rows - 1).U, 0.U, counterValid + 1.U) + } + + io.out.valid := (fullMask.toBools.last) && (counterValid === 0.U) + io.out.bits := reg + +} diff --git a/src/main/scala/esp/simulation/Dma.scala b/src/main/scala/esp/simulation/Dma.scala index 94624db..26874bb 100644 --- a/src/main/scala/esp/simulation/Dma.scala +++ b/src/main/scala/esp/simulation/Dma.scala @@ -33,9 +33,7 @@ object DmaRequest { def init(memorySize: Int) = { val a = Wire(new Valid(new DmaRequest(memorySize))) a.valid := false.B - a.bits.index := DontCare - a.bits.length := DontCare - a.bits.tpe := DontCare + a.bits.getElements.foreach(_ := DontCare) a } } diff --git a/src/main/scala/esp/simulation/Tile.scala b/src/main/scala/esp/simulation/Tile.scala new file mode 100644 index 0000000..1dd607b --- /dev/null +++ b/src/main/scala/esp/simulation/Tile.scala @@ -0,0 +1,41 @@ +// Copyright 2019 IBM +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package esp.simulation + +import chisel3._ + +import esp.{Config, ConfigIO, Implementation, Specification} + +class TileIO(val espConfig: Config) extends Bundle { + val enable = Input(Bool()) + val config = ConfigIO(espConfig).map(Input(_)) + val done = Output(Bool()) + val debug = Output(UInt(32.W)) +} + +class Tile(memorySize: Int, gen: => Specification with Implementation, initFile: Option[String] = None) extends Module { + + lazy val io = IO(new TileIO(accelerator.config)) + + val accelerator: Implementation = Module(gen) + val dma = Module(new Dma(memorySize, UInt(accelerator.dmaWidth.W), initFile)) + + accelerator.io.dma <> dma.io + accelerator.io.enable := io.enable + accelerator.io.config.zip(io.config).map{ case (a, b) => a := b } + io.done := accelerator.io.done + io.debug := accelerator.io.debug + +} diff --git a/src/test/scala/esptests/examples/MedianFilterSpec.scala b/src/test/scala/esptests/examples/MedianFilterSpec.scala new file mode 100644 index 0000000..e9d4d21 --- /dev/null +++ b/src/test/scala/esptests/examples/MedianFilterSpec.scala @@ -0,0 +1,62 @@ +// Copyright 2019 IBM +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package esptests.examples + +import chisel3._ +import chisel3.iotesters.{ChiselFlatSpec, Driver, AdvTester} + +import esp.examples.MedianFilter +import esp.simulation.Tile + +import java.io.File + +/** A test that the [[CounterAccelerator]] asserts it's done when it should + * @param dut a [[CounterAccelerator]] + */ +class MedianFilterTester(dut: Tile) extends AdvTester(dut) { + def reset(): Unit = Seq(dut.io.enable).map(p => wire_poke(p, false)) + + def config(nRows: Int, nCols: Int): Unit = { + wire_poke(dut.io.config.get("nRows").asUInt, nRows) + wire_poke(dut.io.config.get("nCols").asUInt, nCols) + } + + reset() + step(1) + + config(3, 3) + step(1) + + wire_poke(dut.io.enable, 1) + + eventually(peek(dut.io.done) == 1) +} + +class MedianFilterSpec extends ChiselFlatSpec { + + val memFile: Option[String] = { + val resourceDir: File = new File(System.getProperty("user.dir"), "src/test/resources") + Some(new File(resourceDir, "linear-mem.txt").toString) + } + + behavior of "MedianFilter" + + it should "filter a 3x3 image to 1 pixel" in { + Driver(() => new Tile(1024, new MedianFilter(32, 1024, UInt(32.W)), memFile), "treadle") { + dut => new MedianFilterTester(dut) + } should be (true) + } + +} diff --git a/src/test/scala/esptests/examples/ShiftArraySimpleSpec.scala b/src/test/scala/esptests/examples/ShiftArraySimpleSpec.scala new file mode 100644 index 0000000..8939bbf --- /dev/null +++ b/src/test/scala/esptests/examples/ShiftArraySimpleSpec.scala @@ -0,0 +1,76 @@ +// Copyright 2019 IBM +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package esptests.examples + +import chisel3._ +import chisel3.iotesters.{ChiselFlatSpec, Driver, AdvTester} + +import esp.examples.ShiftArraySimple + +class ShiftArraySimpleTester[A <: Bits](dut: ShiftArraySimple[A]) extends AdvTester(dut) { + def init(): Unit = { + Seq(dut.io.in.valid).map(p => wire_poke(p, false)) + } + + def load(a: Seq[Int], valid: Seq[(Int, Int)]): Unit = { + require(a.size == valid.size) + wire_poke(dut.io.in.valid, true) + a.zip(valid).map{ case (aa, v) => + expect(peek(dut.io.out.valid) == v._1, s"dut.io.out.valid was NOT ${v._1} (was ${peek(dut.io.out.valid)})") + wire_poke(dut.io.in.bits, aa) + step(1) + expect(peek(dut.io.out.valid) == v._2, s"dut.io.out.valid was NOT ${v._2} (was ${peek(dut.io.out.valid)})") + } + wire_poke(dut.io.in.valid, false) + } + + reset(4) + + init() + step(1) + + load(0 until 9, Seq.fill(8)((0, 0)) :+ (0, 1)) + peek(dut.io.out.bits).zip(8 to 0 by -1).map{ case (out, expected) => expect(out == expected, s"$out != $expected") } + println(peek(dut.io.out.bits).mkString(", ")) + + load(9 until 12, Seq((1, 0), (0, 0), (0, 1))) + peek(dut.io.out.bits).zip(11 to 3 by -1).map{ case (out, expected) => expect(out == expected, s"$out != $expected") } + println(peek(dut.io.out.bits).mkString(", ")) + + load(12 until 15, Seq((1, 0), (0, 0), (0, 1))) + peek(dut.io.out.bits).zip(14 to 6 by -1).map{ case (out, expected) => expect(out == expected, s"$out != $expected") } + println(peek(dut.io.out.bits).mkString(", ")) + + reset(1) + load(15 until (15 + 9), Seq.fill(8)((0, 0)) :+ (0, 1)) + peek(dut.io.out.bits).zip(23 to 14 by -1).map{ case (out, expected) => expect(out == expected, s"$out != $expected") } + println(peek(dut.io.out.bits).mkString(", ")) + + load((15 + 9) until (15 + 9 + 3), Seq((1, 0), (0, 0), (0, 1))) + peek(dut.io.out.bits).zip(26 to 17 by -1).map{ case (out, expected) => expect(out == expected, s"$out != $expected") } + println(peek(dut.io.out.bits).mkString(", ")) +} + +class ShiftArraySimpleSpec extends ChiselFlatSpec { + + behavior of "ShiftArraySimple" + + it should "work" in { + Driver(() => new ShiftArraySimple(3, 3, UInt(16.W)), "verilator") { + dut => new ShiftArraySimpleTester(dut) + } should be (true) + } + +} diff --git a/src/test/scala/esptests/examples/ShiftArraySpec.scala b/src/test/scala/esptests/examples/ShiftArraySpec.scala new file mode 100644 index 0000000..4b4ae35 --- /dev/null +++ b/src/test/scala/esptests/examples/ShiftArraySpec.scala @@ -0,0 +1,82 @@ +// Copyright 2019 IBM +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package esptests.examples + +import chisel3._ +import chisel3.iotesters.{ChiselFlatSpec, Driver, AdvTester} + +import esp.examples.ShiftArray + +class ShiftArrayTester[A <: Bits](dut: ShiftArray[A]) extends AdvTester(dut) { + def init(): Unit = { + Seq(dut.io.in.valid).map(p => wire_poke(p, false)) + } + + def load(in: Seq[Int], outValid: Boolean): Unit = { + require(in.size % dut.rows == 0) + in + .grouped(dut.rows) + .foreach{ case a => + expect(dut.io.out.valid, outValid) + wire_poke(dut.io.in.valid, 1) + a.zipWithIndex.map{ case (b, i) => wire_poke(dut.io.in.bits(i), b) } + step(1) + wire_poke(dut.io.in.valid, 0) } + } + + def compare(in: Seq[Int]): Unit = { + expect(dut.io.out.valid, true) + val out: Seq[Int] = peek(dut.io.out.bits).map(_.toInt) + val expected: Seq[Int] = in + .grouped(dut.rows).toSeq + .transpose + .map(_.reverse) + .flatten + println(s"""read: ${out.mkString(", ")}""") + out.zip(expected).map{ case (o, e) => expect(o == e, s"($o should be $e)") } + } + + val input: Seq[Int] = 0 until dut.rows * dut.cols + + reset(4) + expect(dut.io.out.valid, false) + + init() + step(1) + + load(0 until 9, false) + compare(0 until 9) + + load(9 until 12, true) + compare(3 until 12) + + reset(1) + expect(dut.io.out.valid, false) + + load(12 until 21, false) + compare(12 until 21) +} + +class ShiftArraySpec extends ChiselFlatSpec { + + behavior of "ShiftArray" + + it should "present a 3x3 array" in { + Driver(() => new ShiftArray(3, 3, UInt(16.W)), "treadle") { + dut => new ShiftArrayTester(dut) + } should be (true) + } + +}