
Macro-based smart constructors

Mark Hopkins
@antiselfdual

April 13, 2016

Raw values

val customerID  = 13753634L
val format      = "YYYY-MM-dd HH:mm:ss"
val port        = 8080
val custDetails = ("Mary Wu", 13753634L, 1234, 1298, 431)

def f(customer: String, format: String, port: Int) = ...

We'll inevitably mix things up, causing bugs.

How can we add safety?

... without incurring too great a cost (run time or ease-of-use)

Better

val customerID  = CustomerID(13753634L)
val format      = Format("YYYY-MM-dd HH:mm:ss")
val port        = Port(8080)
val custDetails = 
  Customer(
    name = "Mary Wu",
    id = customerID,
    balance = Aud(1234),
    ...
  ) 
  
def f(c: CustomerID, f: Format, p: Port) = ...

Types avoid mixups.

Cheaper

case class Port(value: Int)
case class Port(value: Int) extends AnyVal

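Value classes usually remove the overhead of allocating a wrapper object. (The wrapper is still allocated when the value is used through a generic type or trait, stored in an array, or pattern-matched.)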

Safer

import scalaz.\/
import scalaz.syntax.either._

final class Port private (val value: Int)

object Port {
  def apply(p: Int): String \/ Port =
    if (1024 <= p && p <= 65535) new Port(p).right
    else s"$p is out of range".left
}
> val p: Port = 
  Port(-1) valueOr { s => throw new IllegalArgumentException(s) }

java.lang.IllegalArgumentException: -1 is out of range
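For comparison, here is a sketch of the same smart constructor using only the standard library, with scala.util.Either standing in for scalaz's \/ (our swap, not the deck's code):

```scala
// Stdlib-only sketch: Either[String, Port] plays the role of String \/ Port.
final class Port private (val value: Int) {
  override def toString: String = s"Port($value)"
}

object Port {
  // Same range check as the slide: only unprivileged ports are accepted.
  def apply(p: Int): Either[String, Port] =
    if (1024 <= p && p <= 65535) Right(new Port(p))
    else Left(s"$p is out of range")
}
```

Either is right-biased since Scala 2.12, so Port(8080).map(_.value) works without extra syntax.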

Variation

import scalaz.\/
import org.joda.time.format.DateTimeFormat

final case class TimeFormat private(pattern: String) extends AnyVal

object TimeFormat {
  def parse(pattern: String): String \/ TimeFormat =
    \/.fromTryCatchNonFatal(DateTimeFormat.forPattern(pattern))
      .bimap(
        e => e.getMessage,
        _ => TimeFormat(pattern)
      )
}      
> val attempt = TimeFormat.parse("YYYY-MM-tt") 
> attempt valueOr { s => throw new IllegalArgumentException(s) }

java.lang.IllegalArgumentException: Illegal pattern component: tt
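The same shape works without Joda-Time or scalaz. A sketch assuming java.time's DateTimeFormatter in place of Joda's DateTimeFormat, and Try in place of \/.fromTryCatchNonFatal (both catch only non-fatal exceptions). Pattern letters differ between the libraries: in java.time, Y means week-based year, so the checks below use yyyy.

```scala
import java.time.format.DateTimeFormatter
import scala.util.{Failure, Success, Try}

// Sketch: java.time stands in for Joda-Time, Either for scalaz's \/.
final class TimeFormat private (val pattern: String)

object TimeFormat {
  def parse(pattern: String): Either[String, TimeFormat] =
    // ofPattern throws IllegalArgumentException for unknown pattern letters.
    Try(DateTimeFormatter.ofPattern(pattern)) match {
      case Success(_) => Right(new TimeFormat(pattern))
      case Failure(e) => Left(e.getMessage)
    }
}
```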

Variation

Maybe just returning an Option[_] is enough.

class CustomerID private (val id: Long) extends AnyVal

object CustomerID {
  def parse(id: Long): Option[CustomerID] =
    if (0 < id && id.toString.length <= 10)
      Some(new CustomerID(id))
    else 
      None
}

Awkwardness

But sometimes we know our value is okay!

We want to write

def run(answer: Int, port: Port = Port(8080)) = ...
def testServerRunsOn8081 = {
    val port = Port(8081)
    
    Server.run(42, port)
    ...
}
val format       = Format("YYYY-MM-dd")
val testCustomer = CustomerID(123)

Awkwardness

Instead we now have to write

def run(quiddity: Int, port: Port = Port(8080).valueOr(error)) = ...
def testServerRunsOn8081 = {
    val port = Port(8081) valueOr error
    
    Server.run(42, port)
    ...
}
val format       = Format("YYYY-MM-dd") valueOr error
val testCustomer = CustomerID(123).get

Tradeoff?

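It seems we either have to

unwrap the result at every use site (valueOr error, .get), even for values we know are fine, or

give up the check and expose an unchecked public constructor.

Neither option is particularly nice!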

Badness

Worse, these correctness checks only happen at runtime.

If we forget a test...

We should have static checks for statically known values.

What we want

Can we use macros to improve the situation?

We want to allow literal parameters, with a compile-time check for correctness:

val f = Format("YYYY-MM-dd")

But otherwise require handling the result properly


val port = Port.parse(args.required("port")) match {
  case Some(p) => p
  case None    => error("Provide a valid port! That one was junk.")
}

Macros!

Can macros do that? Yes!

First, add the right dependency:

libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value
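One practical wrinkle: a macro can't be expanded in the same compilation run that defines it, so the macro code usually lives in its own sbt subproject that the rest of the code depends on. A sketch of such a build, with hypothetical project names:

```scala
// build.sbt: a sketch; the project names "macros" and "core" are hypothetical.
lazy val macros = (project in file("macros"))
  .settings(
    // The macro implementations need scala-reflect on the classpath.
    libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value
  )

// Code that calls the macros (TimeFormat("..."), etc.) lives here,
// so the macros are already compiled when this code is type-checked.
lazy val core = (project in file("core"))
  .dependsOn(macros)
```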

Idea

We'll hide the constructor and provide a public parse method.

But provide another secret method that only the macro can access.

The macro will use this secret method to create a new instance, once it's verified the value is correct.

No!


Unfortunately, that's not how macros work!

They just rewrite the AST.

So any code we generate has to involve only code that makes sense at the call site.

So no invocation of private methods.

Solution

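Check the value at compile time, and abort compilation if it's invalid. If it's valid, generate an ordinary call to the public parse method, unwrapping the result with sys.error.

It's okay that this generated code "could throw an exception", because we've already checked to make sure it won't.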
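Written out by hand, the expansion for a valid literal is plain code that touches only the public API. A sketch, with Either in place of \/ and a hypothetical Port.parse as the checked constructor:

```scala
final class Port private (val value: Int)

object Port {
  // The public, checked constructor: any call site may use it.
  def parse(p: Int): Either[String, Port] =
    if (1024 <= p && p <= 65535) Right(new Port(p))
    else Left(s"$p is out of range")
}

// Roughly what a macro Port(8080) would emit after checking 8080 at compile time.
// It type-checks at any call site because it uses no private members, and the
// sys.error branch can't run: the same check already passed during compilation.
val port: Port = Port.parse(8080).fold[Port](msg => sys.error(msg), p => p)
```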

Implementation

import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context
import scalaz.\/

object TimeFormat {
  def parse(pattern: String): String \/ TimeFormat = ...
    
  def apply(pattern: String): TimeFormat = macro impl

  def impl(c: Context)(pattern: c.Tree): c.Tree = {
    import c.universe._

    def fail(msg: String): Nothing = c.abort(c.enclosingPosition, msg)

    pattern match {
      case Literal(Constant(s: String)) =>
        // Run the real check now, at compile time; abort compilation if it fails.
        parse(s) valueOr fail
        // Only public members appear in the expansion, so it type-checks at the call site.
        val companion = Ident(weakTypeOf[TimeFormat].typeSymbol.companion)
        q"""
          $companion.parse($pattern)
          .valueOr(_root_.scala.sys.error)
        """
      case other => 
        fail(s"""
          ${show(other)} is not a String literal.
          Use TimeFormat.parse instead.
        """)
    }
  }
}

Extensions

What else?

As well as validating, a macro smart constructor can also calculate values:


// the programmer hashed the file and copied in the hash
val hash1 = Hash("476713e6b5cbcf5c84ef0953d1329e1b") 

// hashed at compile time
val hash2 = Hash(new File("my-file.txt")) 

hash2.value == "476713e6b5cbcf5c84ef0953d1329e1b" // true
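The compile-time work behind Hash(new File(...)) is ordinary code that the macro runs during compilation. A sketch of that computation, assuming MD5 (the deck doesn't name the algorithm; the 32 hex digits are merely consistent with it). HashCompute is a hypothetical name; a macro could call md5OfFile on the literal path and splice the result into its expansion as a string literal.

```scala
import java.nio.file.{Files, Paths}
import java.security.MessageDigest

// Hypothetical helper. Assumption: Hash uses MD5, which the deck doesn't state.
object HashCompute {
  // Lower-case hex MD5 digest of some bytes.
  def md5Hex(bytes: Array[Byte]): String =
    MessageDigest.getInstance("MD5").digest(bytes).map(b => "%02x".format(b & 0xff)).mkString

  // The step a Hash(new File(...)) macro could run at compile time.
  def md5OfFile(path: String): String =
    md5Hex(Files.readAllBytes(Paths.get(path)))
}
```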

Unwrapping

Implicit unwrapping can sometimes be useful.

Then we gain safety, but don't lose ease-of-use.

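import scala.language.experimental.macros
import scala.language.implicitConversions
import org.joda.time.format.DateTimeFormatter

case class Format private (formatter: DateTimeFormatter)

object Format {
  // Lets a Format be used wherever a DateTimeFormatter is expected.
  implicit def unwrap(f: Format): DateTimeFormatter = f.formatter
  def apply(pattern: String): Format = macro impl
  def impl(c: Context)(pattern: c.Tree) = ...
}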
> Format("YYYY-MM-dd").parseDateTime("2016-04-13")

2016-04-13T00:00:00.000+11:00
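A runnable sketch of the unwrapping trick with java.time instead of Joda-Time. Format.parse here is a hypothetical runtime-checked stand-in for the macro apply, and the scala.language.implicitConversions import avoids a feature warning on the conversion.

```scala
import java.time.LocalDate
import java.time.format.DateTimeFormatter
import scala.language.implicitConversions
import scala.util.Try

final class Format private (val formatter: DateTimeFormatter)

object Format {
  // Implicitly unwrap: DateTimeFormatter methods can be called directly on a Format.
  implicit def unwrap(f: Format): DateTimeFormatter = f.formatter

  // Hypothetical runtime-checked stand-in for the macro apply.
  def parse(pattern: String): Option[Format] =
    Try(DateTimeFormatter.ofPattern(pattern)).toOption.map(fmt => new Format(fmt))
}
```

Because the conversion lives in Format's companion, it is found without any import at the use site.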

String interpolation

Macro string interpolators can accomplish some of the same goals.

Normal string interpolator


object TimeFormatLiterals {
  implicit class TimeFormatContext(sc: StringContext) {
    // format() takes no arguments, so only interpolations without ${...} compile.
    def format() =
      TimeFormat.parse(sc.parts.head)
  }
}
import TimeFormatLiterals._

> val attempt = format"YYYY-MM-tt"
> attempt valueOr { s => throw new IllegalArgumentException(s) }

java.lang.IllegalArgumentException: Illegal pattern component: tt
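A variation on this runtime interpolator, sketched with the standard library only: it returns the java.time formatter directly and accepts interpolated arguments only to reject them with a Left. The interpolator name tf is hypothetical.

```scala
import java.time.format.DateTimeFormatter
import scala.util.{Failure, Success, Try}

object TimeFormatLiterals {
  implicit class TimeFormatContext(val sc: StringContext) {
    // Checked at runtime only: a bad pattern is a Left, not a compile error.
    def tf(args: Any*): Either[String, DateTimeFormatter] =
      if (args.nonEmpty) Left("interpolated arguments are not supported")
      else Try(DateTimeFormatter.ofPattern(sc.parts.head)) match {
        case Success(fmt) => Right(fmt)
        case Failure(e)   => Left(e.getMessage)
      }
  }
}

import TimeFormatLiterals._
```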


Macro string interpolator

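import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context
import org.joda.time.format.DateTimeFormat

object DateTimeFormatterLiterals {
  implicit class FormatContext(val sc: StringContext) {
    def format(): TimeFormat = macro formatImpl
  }

  def formatImpl(c: Context)(): c.Tree = {
    import c.universe._
    import scala.util.{Failure, Success, Try}

    def fail(msg: String): Nothing = c.abort(c.enclosingPosition, msg)

    c.prefix.tree match {
      // FormatContext(StringContext("...")): a single literal part, no ${...}
      case q"$_($_(${p: String}))" =>
        Try(DateTimeFormat.forPattern(p)) match {
          case Failure(e) =>
            fail("Invalid time format. " + e.getMessage)
          case Success(_) =>
            q"_root_.mypackage.TimeFormat.parse($p) valueOr _root_.scala.sys.error"
        }
      case _ =>
        fail("Expected a literal pattern with no interpolated arguments.")
    }
  }
}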

Macro string interpolator usage


> format"EEE d MMMM" print DateTime.now()

Wed 13 April

> format"EEE d tttt" print DateTime.now()

Invalid time format. Illegal pattern component: tttt
 
  format"EEE d tttt" print DateTime.now()
         ^  
Compilation failed         

Demo code

github.com/mjhopkins/macro-smart-constructors

Twitter

@antiselfdual

End