Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/main/scala/esp/Generator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@

package esp

import chisel3.Driver
import chisel3._

import esp.examples.CounterAccelerator
import esp.examples.{CounterAccelerator, MedianFilter}

object Generator {

def main(args: Array[String]): Unit = {
val examples: Seq[(String, String, () => AcceleratorWrapper)] =
Seq( ("CounterAccelerator", "Default", (a: Int) => new CounterAccelerator(a)) )
Seq( ("CounterAccelerator", "Default", (a: Int) => new CounterAccelerator(a)),
("MedianFilterAccelerator", "Default", (a: Int) => new MedianFilter(a, 1024, UInt(a.W))))
.flatMap( a => Seq(32, 64, 128).map(b => (a._1, s"${a._2}_dma$b", () => new AcceleratorWrapper(b, a._3))) )

examples.map { case (name, impl, gen) =>
Expand Down
138 changes: 138 additions & 0 deletions src/main/scala/esp/examples/MedianFilter.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright 2019 IBM
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package esp.examples

import chisel3._
import chisel3.util.{log2Up, Enum, Valid}

import esp.{Config, ConfigIO, Implementation, Parameter, Specification}

import sys.process._

trait MedianFilterSpecification extends Specification {

override lazy val config: Config = Config(
name = "MedianFilter",
description = "Bitonic sort median filter",
memoryFootprintMiB = 1,
deviceId = 0xD,
param = Array(
Parameter(
name = "git Hash",
description = Some("Git short SHA hash of the repo used to generate this accelerator"),
value = Some(Integer.parseInt(("git log -n1 --format=%h" !!).filter(_ >= ' '), 16))),
Parameter(
name = "nRows",
description = Some("Number of rows in the input image")),
Parameter(
name = "nCols",
description = Some("Number of columns in the input image"))))

}

class MedianFilterState extends Bundle {
val dmaReadReq = Bool()
val dmaReadResp = Bool()
val dmaWriteReq = Bool()
val compute = Bool()
}

class MedianFilterRequest(val configIO: ConfigIO) extends Bundle {
val config = configIO
/* @todo brittle */
val readLength = UInt(64.W)
val respLength = UInt(64.W)
val state = new MedianFilterState
}

object MedianFilterRequest {

def init(configIO: ConfigIO) = {
val a = Wire(new Valid(new MedianFilterRequest(configIO)))
a.valid := false.B
a.bits.config.getElements.foreach(_ := DontCare)
a.bits.readLength := DontCare
a.bits.respLength := DontCare
a.bits.state := (new MedianFilterState).fromBits(0.U)
a
}

}

class MedianFilter[A <: Data](dmaWidth: Int, scratchpadSize: Int, dataType: A)
extends Implementation(dmaWidth) with MedianFilterSpecification {

require(dmaWidth == dataType.getWidth, "MedianFilter requires data type width to match dmaWidth")

override val implementationName: String = "Default_medianFilter" + dmaWidth

val scratchpad = SyncReadMem[A](scratchpadSize, dataType.cloneType)

val req = RegInit(MedianFilterRequest.init(io.config.get.cloneType))

/* Compute the maximum address that needs to be read via DMA. This is rounded up if needed. */
def maxAddr: UInt = {
val (remainder, dividend) = (io.config.get("nRows").asUInt * io.config.get("nCols").asUInt)
.toBools
.splitAt(dmaWidth / dataType.getWidth - 1)
if (remainder.isEmpty) {
Vec(dividend).asUInt
} else {
Vec(dividend).asUInt + Mux(remainder.reduce(_ || _), 1.U, 0.U)
}
}

when (!req.valid && io.enable) {
printf {
val a: Printable = io.config.get.elements
.map{ case (name, data) => p"[info] - $name: $data\n" }
.reduce(_ + _)
p"[info] enabled:\n" + a
}
req.valid := true.B
req.bits.config := io.config.get
req.bits.readLength := maxAddr
req.bits.respLength := 1.U
req.bits.state.getElements.map(_ := false.B)
req.bits.state.dmaReadReq := true.B
}

io.dma.readControl.valid := req.valid && req.bits.state.dmaReadReq
io.dma.readControl.bits.index := 0.U
io.dma.readControl.bits.length := req.bits.readLength
when (io.dma.readControl.fire) {
req.bits.state.dmaReadReq := false.B
req.bits.state.dmaReadResp := true.B
}

io.dma.readChannel.ready := req.valid && req.bits.state.dmaReadResp
when (req.valid && req.bits.state.dmaReadResp && io.dma.readChannel.fire) {
printf(p"[info] Read: ${io.dma.readChannel.bits}\n")
req.bits.respLength := req.bits.respLength + 1.U
scratchpad(req.bits.respLength) := io.dma.readChannel.bits
when (req.bits.respLength === req.bits.readLength) {
printf(p"[info] done\n")
req.bits.state.dmaReadResp := false.B
req.bits.state.compute := true.B
}
}

io.done := req.bits.state.compute
when (req.valid && req.bits.state.compute) {
req.valid := false.B
req.bits.state.getElements.map(_ := 0.U)
}

}
50 changes: 50 additions & 0 deletions src/main/scala/esp/examples/ShiftArray.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright 2019 IBM
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package esp.examples

import chisel3._
import chisel3.util.Valid
import chisel3.experimental.withReset

class ShiftArrayIO[A <: Data](rows: Int, cols: Int, gen: A) extends Bundle {

val in = Flipped(Valid(Vec(rows, gen.cloneType)))
val out = Valid(Vec(rows, Vec(cols, gen.cloneType)))

}

class ShiftArray[A <: Data](val rows: Int, val cols: Int, gen: A) extends Module {

val io = IO(new ShiftArrayIO(rows, cols, gen))

val regArray = Seq.fill(rows)(Seq.fill(cols)(Reg(gen)))
val count = RegInit(0.U(cols.W))

/* Shift regArray left when the input fires */
when (io.in.fire()) {
count := count ## 1.U
regArray
.zip(io.in.bits)
.foreach{ case (a, in) => a.foldLeft(in){ case (r, l) => l := r; l } }
}

/* Route the regArray to the output */
io.out.bits.flatten
.zip(regArray.flatten)
.foreach{ case (l, r) => l := r }

io.out.valid := count.toBools.last

}
46 changes: 46 additions & 0 deletions src/main/scala/esp/examples/ShiftArraySimple.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2019 IBM
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package esp.examples

import chisel3._
import chisel3.util.{log2Up, Valid}

class ShiftArraySimpleIO[A <: Data](rows: Int, cols: Int, gen: A) extends Bundle {

val in = Flipped(Valid(gen.cloneType))
val out = Valid(Vec(rows * cols, gen.cloneType))

}

class ShiftArraySimple[A <: Data](rows: Int, cols: Int, gen: A) extends Module {

val io = IO(new ShiftArraySimpleIO(rows, cols, gen))

val reg: Seq[A] = Seq.fill(rows * cols)(Reg(gen.cloneType))

val fullMask: UInt = RegInit(0.U((rows * cols).W))

val counterValid: UInt = RegInit(0.U(rows.W))

when (io.in.fire()) {
reg.foldLeft(io.in.bits){ case (r, l) => l := r; l }
fullMask := fullMask ## 1.U
counterValid := Mux(counterValid === (rows - 1).U, 0.U, counterValid + 1.U)
}

io.out.valid := (fullMask.toBools.last) && (counterValid === 0.U)
io.out.bits := reg

}
4 changes: 1 addition & 3 deletions src/main/scala/esp/simulation/Dma.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ object DmaRequest {
def init(memorySize: Int) = {
val a = Wire(new Valid(new DmaRequest(memorySize)))
a.valid := false.B
a.bits.index := DontCare
a.bits.length := DontCare
a.bits.tpe := DontCare
a.bits.getElements.foreach(_ := DontCare)
a
}
}
Expand Down
41 changes: 41 additions & 0 deletions src/main/scala/esp/simulation/Tile.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2019 IBM
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package esp.simulation

import chisel3._

import esp.{Config, ConfigIO, Implementation, Specification}

class TileIO(val espConfig: Config) extends Bundle {
val enable = Input(Bool())
val config = ConfigIO(espConfig).map(Input(_))
val done = Output(Bool())
val debug = Output(UInt(32.W))
}

class Tile(memorySize: Int, gen: => Specification with Implementation, initFile: Option[String] = None) extends Module {

lazy val io = IO(new TileIO(accelerator.config))

val accelerator: Implementation = Module(gen)
val dma = Module(new Dma(memorySize, UInt(accelerator.dmaWidth.W), initFile))

accelerator.io.dma <> dma.io
accelerator.io.enable := io.enable
accelerator.io.config.zip(io.config).map{ case (a, b) => a := b }
io.done := accelerator.io.done
io.debug := accelerator.io.debug

}
62 changes: 62 additions & 0 deletions src/test/scala/esptests/examples/MedianFilterSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2019 IBM
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package esptests.examples

import chisel3._
import chisel3.iotesters.{ChiselFlatSpec, Driver, AdvTester}

import esp.examples.MedianFilter
import esp.simulation.Tile

import java.io.File

/** A test that the [[CounterAccelerator]] asserts it's done when it should
* @param dut a [[CounterAccelerator]]
*/
class MedianFilterTester(dut: Tile) extends AdvTester(dut) {
def reset(): Unit = Seq(dut.io.enable).map(p => wire_poke(p, false))

def config(nRows: Int, nCols: Int): Unit = {
wire_poke(dut.io.config.get("nRows").asUInt, nRows)
wire_poke(dut.io.config.get("nCols").asUInt, nCols)
}

reset()
step(1)

config(3, 3)
step(1)

wire_poke(dut.io.enable, 1)

eventually(peek(dut.io.done) == 1)
}

class MedianFilterSpec extends ChiselFlatSpec {

val memFile: Option[String] = {
val resourceDir: File = new File(System.getProperty("user.dir"), "src/test/resources")
Some(new File(resourceDir, "linear-mem.txt").toString)
}

behavior of "MedianFilter"

it should "filter a 3x3 image to 1 pixel" in {
Driver(() => new Tile(1024, new MedianFilter(32, 1024, UInt(32.W)), memFile), "treadle") {
dut => new MedianFilterTester(dut)
} should be (true)
}

}
Loading