diff --git a/service/src/main/scala/com/intuit/superglue/lineage/LineageService.scala b/service/src/main/scala/com/intuit/superglue/lineage/LineageService.scala index 882d2f5..bb50900 100644 --- a/service/src/main/scala/com/intuit/superglue/lineage/LineageService.scala +++ b/service/src/main/scala/com/intuit/superglue/lineage/LineageService.scala @@ -58,11 +58,15 @@ class LineageService(val repository: SuperglueRepository) { if (backwardFringe.isEmpty && forwardFringe.isEmpty) return Future.successful(traversed) if (backwardDepth.contains(0) && forwardDepth.contains(0)) return Future.successful(traversed) - // Don't traverse backward if the backward depth is zero - val backwardLineageViews = if (backwardDepth.contains(0)) Future.successful(Set.empty) else { - // Gets the backward lineage for a given set of tables by querying the cache - getLineageViewResultSet(backwardFringe, Output) - .map(_.filter(view => !backwardVisited.contains(view.outputTableId))) + + val backwardLineageViews = backwardDepth match { + // Don't traverse backward if the backward depth is zero + case Some(backwordDepthValue) if backwordDepthValue == 0 => + Future.successful(Set.empty) + case _ => + // Gets the backward lineage for a given set of tables by querying the cache + getLineageViewResultSet(backwardFringe, Output) + .map(_.filter(view => !backwardVisited.contains(view.outputTableId))) } // Don't traverse forward if the forward depth is zero diff --git a/service/src/test/scala/com/intuit/superglue/lineage/LineageServiceTest.scala b/service/src/test/scala/com/intuit/superglue/lineage/LineageServiceTest.scala index bbceaa7..0f2339b 100644 --- a/service/src/test/scala/com/intuit/superglue/lineage/LineageServiceTest.scala +++ b/service/src/test/scala/com/intuit/superglue/lineage/LineageServiceTest.scala @@ -157,4 +157,29 @@ class LineageServiceTest extends FlatSpec { val futureLineageGraph = lineageService.tableLineage("START", None, None) Await.result(futureLineageGraph, 1 second) // This will fail if we loop infinitely } + + it should "collect full backward and some forward lineage" in { + val views = (-100 to 100) + .map(num => LineageView(TablePK(num), "", TablePK(num + 1), "", ScriptPK(0), StatementPK(0))) + .map { + case LineageView(pk @ TablePK(0), _, a, b, c, d) => LineageView(pk, "START", a, b, c, d) + case LineageView(a, b, pk @ TablePK(0), _, c, d) => LineageView(a, b, pk, "START", c, d) + case other => other + }.toSet + val repository = makeRepository(views) + val lineageService = new LineageService(repository) + val fut = lineageService.tableLineage("START", None, Some(2)) + val lineageGraph = Await.result(fut, 1 second) + + val expectedLinks = (-100 until 2) + .map(num => Link(TableNode(TablePK(num), ""), TableNode(TablePK(num + 1), ""))) + .map { + // The table name for id 0 should be "START" + case Link(TableNode(pk @ TablePK(0), _, _), out) => Link(TableNode(pk, "START"), out) + case Link(in, TableNode(pk @ TablePK(0), _, _)) => Link(in, TableNode(pk, "START")) + case other => other + }.toSet + + assert(expectedLinks == lineageGraph.links) + } }