Skip to content

Add common utils into package.scala #11

@pathikrit

Description

@pathikrit
import scala.collection.mutable

/**
  * Container to hold helpful "[pimps](http://www.artima.com/weblogs/viewpost.jsp?thread=179766)"
  * @see https://github.com/stacycurl/pimpathon/issues/120
  * @see https://github.com/cvogt/scala-extensions
  */
trait Pimps {
  implicit class StringInterpolations(sc: StringContext) {
    /**
      * Helper to do case insensitive string matches e.g.
      * {{{
      * "Hello" match {
      *   case ci"heLLO" => println("awesome")
      *   case _ => throw new IllegalStateException("Unreachable")
      * }
      * }}}
      * The literal parts of the interpolation are concatenated and compared with
      * `equalsIgnoreCase`; the extractor binds no variables.
      */
    def ci: CaseInsensitiveMatcher = new CaseInsensitiveMatcher(sc.parts.mkString)
  }

  /**
    * Boolean extractor behind the `ci"..."` pattern.
    * A named class is used instead of a structural `new { def unapply ... }` so that
    * calling `unapply` does not go through Java reflection.
    */
  class CaseInsensitiveMatcher(expected: String) {
    def unapply(other: String): Boolean = expected.equalsIgnoreCase(other)
  }

  implicit class AnyExtensions(a: Any) {
    /**
      * Safe type-cast.
      * The runtime check is driven by the implicit `ClassTag`, so only the erased class is
      * verified (e.g. any `List` passes `cast[List[Int]]`). Without the `ClassTag` the type
      * pattern would be unchecked and every non-null value would produce a bogus `Some`.
      * @return Some[A] if this is an instance of A else None (None for null as well)
      */
    def cast[A: scala.reflect.ClassTag]: Option[A] = a partialMatch {
      case x: A => x
    }
  }

  implicit class GenericExtensions[A](a: A) {
    /** Apply `f` if it is defined at this value, else None. */
    def partialMatch[B](f: PartialFunction[A, B]): Option[B] = f lift a

    /** Membership test written value-first: `x in set`. */
    def in(set: Set[A]) = set contains a
    def notIn(set: Set[A]) = !in(set)

    @inline def isNull = a == null
    @inline def isNotNull = a != null
  }

  /**
    * Tolerance-based comparisons for doubles; the tolerance is an implicit `Double`
    * in scope, or 1e-9 if none is found.
    * NOTE(review): the original operator names of these four methods were lost (the methods
    * were nameless and did not compile); these ASCII names are a reconstruction - confirm
    * against existing call sites. `~=` / `!~=` end in '=' and therefore have the lowest
    * operator precedence.
    */
  implicit class FuzzyDouble(x: Double)(implicit eps: Double = 1e-9) {
    /** x is greater than y by more than eps. */
    def >~(y: Double): Boolean = x > (y + eps)
    /** x is less than y by more than eps. */
    def <~(y: Double): Boolean = x < (y - eps)
    /** x and y differ by at most eps. */
    def ~=(y: Double): Boolean = (x - y).abs <= eps
    /** x and y differ by more than eps. */
    def !~=(y: Double): Boolean = (x - y).abs > eps
  }

  implicit class TraversableExtensions[A](t: Traversable[A]) {
    /**
      * Group by a key that must identify each element uniquely.
      * The map is built strictly, so a duplicate key fails immediately (via `ensuring`, i.e.
      * an `AssertionError`) instead of lazily on lookup as a `mapValues` view would.
      */
    def groupByIdentifier[K](f: A => K): Map[K, A] = t.groupBy(f).map { case (k, values) =>
      k -> values.head.ensuring(values.size == 1, s"Key maps to multiple values: $values")
    }

    /** Elements that occur more than once (each reported once). */
    def duplicates: Traversable[A] = t groupBy identity collect {case (k, vs) if vs.size > 1 => k}

    /** Pair every element with its image under `f`. */
    def zipWith[B](f: A => B): Traversable[(A, B)] = t map {i => i -> f(i)}

    /** Map each element to `f(element)`; fails if the same element appears twice. */
    def mapTo[B](f: A => B): Map[A, B] = zipWith(f).toUniqueMap

    /**
      * Lets you use X instead of double for-loops (cartesian product)
      */
    def X[B](u: Traversable[B]): Traversable[(A, B)] = for {a <- t; b <- u} yield a -> b

    /**
      * Fix the fact that certain operations (e.g. maxBy) on Traversables throw exceptions on empty collections
      * Example usage: `grades.whenNonEmpty(_.max)`
      */
    def whenNonEmpty[B](f: t.type => B): Option[B] = when(t.nonEmpty)(f(t))

    /** First `Some` produced by `f` (evaluated lazily through a view), flattened. */
    def collectFirstDefined[B](f: PartialFunction[A, Option[B]]): Option[B] = t.view collect f collectFirst {case Some(x) => x}

    /** Occurrence count per element, as a mutable map defaulting to 0 for absent keys. */
    def counts: mutable.Map[A, Int] = {
      val c = map[A] to 0
      t.foreach(i => c(i) += 1)
      c
    }
  }

  implicit class ArrayExtensions[A <: AnyRef](a: Array[A]) {
    /**
      * Sort a slice [from, until) of this array in place
      */
    def sort(from: Int, until: Int)(implicit cmp: Ordering[A]): Unit = java.util.Arrays.sort(a, from, until, cmp)
  }

  implicit class ArrayExtensions2[A](a: Array[A]) {
    /**
      * Sort a slice [from, until) of this array, for any element type: the slice is copied,
      * sorted and written back starting at `from`. `slice` is used rather than the
      * `view(from, until)` overload, which no longer exists in Scala 2.13.
      */
    def sort(from: Int, until: Int)(implicit cmp: Ordering[A]): Unit = a.slice(from, until).sorted.copyToArray(a, from)
  }

  implicit class IndexedSeqExtensions[A](s: IndexedSeq[A]) {
    /**
      * New sequence with `offset` placeholder elements (`null.asInstanceOf[A]`: null for
      * references, zero for primitives) prepended, e.g. to allow 1-based indexing.
      */
    def withOffset(offset: Int): IndexedSeq[A] = Iterator.fill(offset)(null.asInstanceOf[A]) ++: s
  }

  implicit class PairsExtensions[A, B](t: Traversable[(A, B)]) {
    /**
      * Unlike `.toMap`, this ensures that each key is mapped uniquely to a value
      * (fails eagerly with an `AssertionError` on a repeated key).
      */
    def toUniqueMap: Map[A, B] = t.groupByIdentifier(_._1).map { case (k, (_, v)) => k -> v }

    /** Collect every value seen for each key (strict, not a lazy view). */
    def toMultiMap: Map[A, Traversable[B]] = t.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2) }

    def swap: Traversable[(B, A)] = t.map(_.swap)

    /**
      * Helper for triple loop
      */
    def X[C](u: Traversable[C]): Traversable[(A, B, C)] = for {(a, b) <- t; c <- u} yield (a, b, c)
  }

  implicit class MapExtensions[K, V](map: Map[K, V]) {
    /** Map each value to the set of keys that point to it. */
    def invert: Map[V, Set[K]] = map.swap.toMultiMap.map { case (v, ks) => v -> ks.toSet }

    /**
      * Rename keys with `f`, failing if two keys collide. The pairs are mapped as a `Seq`
      * because mapping a `Map` directly would silently drop colliding entries before
      * `toUniqueMap` could see them.
      */
    def mapKeys[K2](f: K => K2): Map[K2, V] = map.toSeq.map { case (k, v) => f(k) -> v }.toUniqueMap
  }

  implicit class IntExtensions(x: Int) {
    /**
      * @return m = x mod y such that we preserve the relation (x/y)*y + m == x,
      *         i.e. the truncated remainder `x % y` (sign follows x)
      */
    def mod(y: Int): Int = x % y

    /**
      * @return range that goes forward or backward depending on x and y
      */
    def -->(y: Int) = x to y by (if (x < y) 1 else -1)

    /**
      * Fix the fact that indexOf etc returns -1 instead of None when not found
      * @see http://stackoverflow.com/questions/25455831
      * Example usage: `students.indexOf("Greg").nonNegative`
      */
    def nonNegative: Option[Int] = when(x >= 0)(x)
  }

  implicit class LongExtensions(x: Long) {
    /**
      * count set bits
      */
    def bitCount = java.lang.Long.bitCount(x)

    /**
      * x raised to the non-negative power i by repeated squaring; overflow wraps like `*`.
      * @throws IllegalArgumentException if i is negative (previously produced garbage)
      */
    def pow(i: Int): Long = {
      require(i >= 0, s"Exponent must be non-negative but was $i")
      @scala.annotation.tailrec
      def loop(base: Long, exp: Int, acc: Long): Long =
        if (exp == 0) acc
        else loop(base * base, exp >> 1, if ((exp & 1) == 1) acc * base else acc)
      loop(x, i, 1L)
    }
  }

  /** The prime 10^9 + 7, commonly used as a modulus. */
  val mod: Int = 1000000007

  /**
    * Use this to sort descending e.g. `students.sortBy(_.score)(desc)`
    */
  def desc[A: Ordering]: Ordering[A] = asc[A].reverse
  def asc[A: Ordering]: Ordering[A] = implicitly[Ordering[A]]

  /**
    * Quick way to create a mutable Map with default value e.g. `val counts = map[String] to 0`
    */
  def map[K]: DefaultMapBuilder[K] = new DefaultMapBuilder[K]

  /** Second half of `map[K] to default`; a named class avoids reflective structural calls. */
  class DefaultMapBuilder[K] {
    def to[V](default: V): mutable.Map[K, V] = mutable.Map.empty[K, V] withDefaultValue default
  }

  /**
    * Memoize a function: `memoize[A](f)` caches `f` by its argument, while
    * `memoize[A].using(encode)(f)` caches by `encode(key)`. Caches are unbounded,
    * unsynchronized `mutable` maps.
    */
  def memoize[A]: Memoizer[A] = new Memoizer[A]

  class Memoizer[A] {
    def using[B, C](encode: B => A)(f: B => C): (B => C) = new (B => C) {
      val cache = mutable.Map.empty[A, C]
      override def apply(key: B) = cache.getOrElseUpdate(encode(key), f(key))
    }

    def apply[B](f: A => B): (A => B) = new mutable.HashMap[A, B]() {
      override def apply(key: A) = getOrElseUpdate(key, f(key))
    }
  }

  /** Some(f) if check holds, else None (f is not evaluated). */
  @inline def when[A](check: Boolean)(f: => A): Option[A] = if (check) Some(f) else None

  /** Evaluate f n times (not at all if n <= 0). */
  @inline def repeat(n: Int)(f: => Unit): Unit = (1 to n).foreach(_ => f)
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions