// FunctionCaching.scala


import stainless.lang._
import stainless.collection._

/** Stainless example: a memoised function whose cache is threaded through an
  * implicit, mutable [[FunCache]], together with a proof that calling the
  * function twice with the same argument yields the same result even when other
  * calls happen in between.
  */
object FunctionCaching {

  /** Mutable memo table for [[fun]]: maps an argument to the result that was
    * previously computed for it.
    */
  case class FunCache(var cached: Map[BigInt, BigInt])

  /** The function being memoised, evaluated without consulting any cache.
    *
    * Both [[fun]] and the cache invariant of [[multipleCalls]] are stated in
    * terms of this single definition, so they cannot drift apart.
    */
  def uncachedFun(x: BigInt): BigInt = 2 * x + 42

  /** Returns the cached result for `x` if there is one; otherwise computes it
    * with [[uncachedFun]], records it in `funCache`, and returns it.
    *
    * @param x        argument to evaluate
    * @param funCache memo table; gains an entry for `x` on a cache miss
    * @return the cached or freshly computed result
    */
  def fun(x: BigInt)(implicit funCache: FunCache): BigInt = {
    funCache.cached.get(x) match {
      case None() =>
        val res = uncachedFun(x)
        funCache.cached = funCache.cached.updated(x, res)
        res
      case Some(cached) =>
        cached
    }
  }.ensuring(res => old(funCache).cached.get(x) match {
    // On a hit, the result must be exactly the value stored before the call.
    case None() => true
    case Some(v) => v == res
  })

  /** Property: starting from an empty cache, `fun(x)` returns the same value
    * before and after an arbitrary sequence of other calls to [[fun]] (one per
    * element of `trash`). `.holds` asks Stainless to prove this for every `x`
    * and `trash`.
    */
  def funProperlyCached(x: BigInt, trash: List[BigInt]): Boolean = {
    implicit val cache: FunCache = FunCache(Map())
    val res1 = fun(x)
    multipleCalls(trash, x)
    val res2 = fun(x)
    res1 == res2
  }.holds

  /** Calls [[fun]] on each element of `args`, in order.
    *
    * Precondition and postcondition state the same invariant: any cache entry
    * for `x` equals `uncachedFun(x)`. Keeping that invariant is what makes the
    * second `fun(x)` in [[funProperlyCached]] agree with the first.
    *
    * @param args     arguments to feed to [[fun]]
    * @param x        the argument whose cache entry must remain correct
    * @param funCache memo table shared with [[fun]]
    */
  def multipleCalls(args: List[BigInt], x: BigInt)(implicit funCache: FunCache): Unit = {
    require(funCache.cached.get(x).forall(_ == uncachedFun(x)))
    args match {
      case Nil() => ()
      case y :: ys =>
        fun(y)
        multipleCalls(ys, x)
    }
  }.ensuring(_ => funCache.cached.get(x).forall(_ == uncachedFun(x)))

}
