
Make an arbitrary class in Scala a monad instance

To make a type usable in a monadic context in Haskell, I just add an instance of the Monad class for that type, anywhere I like. I don't touch the source of the data type definition at all. For example (something artificial):

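import Control.Monad (ap, liftM)

data Z a = MyZLeft a | MyZRight a

swap :: Z a -> Z a
swap (MyZLeft x) = MyZRight x
swap (MyZRight x) = MyZLeft x

-- Since GHC 7.10, a Monad instance also needs Functor and Applicative instances
instance Functor Z where
  fmap = liftM

instance Applicative Z where
  pure = MyZRight
  (<*>) = ap

instance Monad Z where
  (>>=) x f = case x of
                MyZLeft s -> swap (f s)
                MyZRight s -> swap (f s)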

So I don't touch the definition of Z, but I still make it a monad.

How do I do this in Scala? It seems there is no way other than mixing some traits in and defining the methods map/flatMap/filter/withFilter?


Take a look at scalaz:

// You could use the implementation at the end of this answer instead of this import
import scalaz._, Scalaz._

sealed trait Z[T]
case class MyZLeft[T](t: T) extends Z[T]
case class MyZRight[T](t: T) extends Z[T]

def swap[T](z: Z[T]): Z[T] = z match {
  case MyZLeft(t) => MyZRight(t)
  case MyZRight(t) => MyZLeft(t)
}

// The instance is defined separately from Z, like a Haskell instance declaration
implicit object ZIsMonad extends Monad[Z] {
  def point[A](a: => A): Z[A] = MyZRight(a)
  def bind[A, B](fa: Z[A])(f: A => Z[B]): Z[B] = fa match {
    case MyZLeft(t) => swap(f(t))
    case MyZRight(t) => swap(f(t))
  }
}

val z = 1.point[Z]
// Z[Int] = MyZRight(1)

z map { _ + 2 }
// Z[Int] = MyZLeft(3)

z >>= { i => MyZLeft(i + "abc") }
// Z[String] = MyZRight(1abc)

z >>= { i => (i + "abc").point[Z] }
// Z[String] = MyZLeft(1abc)

For-comprehensions work as well (similar to Haskell's do-notation):

for {
  i <- z
  j <- (i + 1).point[Z]
  k = i + j
} yield i * j * k
// Z[Int] = MyZRight(6)

See also the Scalaz cheatsheet and Learning Scalaz.

There is no magic in Scalaz; you could implement this without it.

Here is the simplest implementation of Monad with syntax, in case you don't want to use Scalaz:

import scala.language.higherKinds

trait Monad[M[_]] {
  def point[A](a: => A): M[A]
  def bind[A, B](fa: M[A])(f: A => M[B]): M[B]
}

// Implicit classes cannot be top-level (the REPL wraps them for you),
// so put them in an object and bring them into scope with `import MonadSyntax._`
object MonadSyntax {
  implicit class MonadPointer[A](a: A) {
    def point[M[_]: Monad]: M[A] = implicitly[Monad[M]].point(a)
  }

  implicit class MonadWrapper[M[_]: Monad, A](t: M[A]) {
    private def m = implicitly[Monad[M]]
    def flatMap[B](f: A => M[B]): M[B] = m.bind(t)(f)
    def >>=[B](f: A => M[B]): M[B] = flatMap(f)
    def map[B](f: A => B): M[B] = m.bind(t)(a => m.point(f(a)))
    def flatten[B](implicit f: A => M[B]): M[B] = m.bind(t)(f)
  }
}
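To see the pieces working together without Scalaz, here is a sketch that pairs this kind of hand-rolled type class with the Z type from the Scalaz example. It is wrapped in a hypothetical `ZExample` object so the implicit classes compile outside the REPL, and it leaves out `>>=` and `flatten` for brevity. The values noted in the comments follow from the swap-based bind.

```scala
import scala.language.higherKinds

object ZExample {
  // Minimal type class, same shape as the hand-rolled Monad
  trait Monad[M[_]] {
    def point[A](a: => A): M[A]
    def bind[A, B](fa: M[A])(f: A => M[B]): M[B]
  }

  // Syntax: x.point[M]
  implicit class MonadPointer[A](a: A) {
    def point[M[_]: Monad]: M[A] = implicitly[Monad[M]].point(a)
  }

  // Syntax: m.flatMap(f) and m.map(f), enough for for-comprehensions
  implicit class MonadWrapper[M[_]: Monad, A](t: M[A]) {
    private def m = implicitly[Monad[M]]
    def flatMap[B](f: A => M[B]): M[B] = m.bind(t)(f)
    def map[B](f: A => B): M[B] = m.bind(t)(a => m.point(f(a)))
  }

  sealed trait Z[T]
  case class MyZLeft[T](t: T) extends Z[T]
  case class MyZRight[T](t: T) extends Z[T]

  def swap[T](z: Z[T]): Z[T] = z match {
    case MyZLeft(t)  => MyZRight(t)
    case MyZRight(t) => MyZLeft(t)
  }

  // The instance lives outside Z's definition, like a Haskell instance
  implicit object ZIsMonad extends Monad[Z] {
    def point[A](a: => A): Z[A] = MyZRight(a)
    def bind[A, B](fa: Z[A])(f: A => Z[B]): Z[B] = fa match {
      case MyZLeft(t)  => swap(f(t))
      case MyZRight(t) => swap(f(t))
    }
  }

  val z: Z[Int] = 1.point[Z]           // MyZRight(1)
  val mapped: Z[Int] = z.map(_ + 2)    // MyZLeft(3)

  // Only map and flatMap are needed: the `k = ...` definition adds no withFilter
  val result: Z[Int] = for {
    i <- z
    j <- (i + 1).point[Z]
    k = i + j
  } yield i * j * k                    // MyZRight(6)
}
```

The only thing tying Z to the type class is the implicit `ZIsMonad` in scope, which is the Scala counterpart of the Haskell instance declaration.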


To be a monad, a Scala class isn't required to extend a particular class or mix in a particular trait. It merely needs to:

  • be a type-parameterized class (SomeClass[T])
  • implement a "unit" method (any name will do, but it is often named after the class itself via the companion object's apply, c.f. List(x) and Try(doSomething()))
  • implement the flatMap method (a.k.a. "bind"):

    object SomeClass {
        def apply[T](t: T): SomeClass[T] = ???
    }
    class SomeClass[T] {
        def flatMap[U](f: T => SomeClass[U]): SomeClass[U] = ???
    }

This is a definition via structural typing / duck typing (for-comprehensions are simply desugared into flatMap/map calls on whatever type is at hand), as opposed to a definition via type extension.
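A minimal sketch of such a duck-typed monad, assuming a hypothetical `Box` class that holds exactly one value. It extends no Monad trait: unit is the companion's `apply` and bind is `flatMap`. A `map` method is also added, derived from `flatMap` and unit, because Scala's for/yield desugars the last generator into `map`.

```scala
object DuckTyped {
  // Hypothetical Box: extends no Monad trait at all
  final class Box[T](val value: T) {
    // bind
    def flatMap[U](f: T => Box[U]): Box[U] = f(value)
    // Not part of the definition, but for/yield needs it; derived from flatMap and unit
    def map[U](f: T => U): Box[U] = flatMap(t => Box(f(t)))
    override def toString: String = s"Box($value)"
  }

  object Box {
    // unit, named after the class, like List(x) or Try(...)
    def apply[T](t: T): Box[T] = new Box(t)
  }

  // Desugars to Box(20).flatMap(a => Box(22).map(b => a + b))
  val sum: Box[Int] = for {
    a <- Box(20)
    b <- Box(22)
  } yield a + b
}
```

The compiler never asks whether `Box` "is a monad"; it only looks for `flatMap` and `map` with suitable signatures.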

Additionally, to technically qualify as a Monad, the implementation must satisfy the three monad laws (where m is of type SomeClass[T] and unit(t) = SomeClass(t) for any t: T).

  1. Monad Identity Law: binding a monad with unit leaves it unchanged

      m flatMap unit = m flatMap (SomeClass(_)) = m
  2. Monad Unit Law: binding unit with an arbitrary function is the same as applying that function to the unit's value

      unit(t) flatMap f = SomeClass(t) flatMap f = f(t)           (where f: T => SomeClass[U])
  3. Monad Composition Law: bind is associative

      (m flatMap f) flatMap g = m flatMap (t => f(t) flatMap (u => g(u)))
      (where f: T => SomeClass[U] and g: U => SomeClass[V] for some U and V)
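The laws can be spot-checked in code. The sketch below assumes a small hypothetical `Bind` type class (unit plus bind) and tests each law at a single sample point, once for `Option` and once for the Z type from the question. A check at one point can only refute a law, never prove it; a property-based testing library would sample many values.

```scala
import scala.language.higherKinds

object LawCheck {
  // Hypothetical minimal type class: unit plus bind
  trait Bind[M[_]] {
    def unit[A](a: A): M[A]
    def bind[A, B](m: M[A])(f: A => M[B]): M[B]
  }

  // Law 1: m flatMap unit == m
  def rightIdentity[M[_], A](b: Bind[M])(m: M[A]): Boolean =
    b.bind(m)((a: A) => b.unit(a)) == m

  // Law 2: unit(a) flatMap f == f(a)
  def leftIdentity[M[_], A, B](b: Bind[M])(a: A, f: A => M[B]): Boolean =
    b.bind(b.unit(a))(f) == f(a)

  // Law 3: (m flatMap f) flatMap g == m flatMap (a => f(a) flatMap g)
  def associativity[M[_], A, B, C](b: Bind[M])(m: M[A], f: A => M[B], g: B => M[C]): Boolean =
    b.bind(b.bind(m)(f))(g) == b.bind(m)(a => b.bind(f(a))(g))

  val optionBind: Bind[Option] = new Bind[Option] {
    def unit[A](a: A): Option[A] = Some(a)
    def bind[A, B](m: Option[A])(f: A => Option[B]): Option[B] = m.flatMap(f)
  }

  // The Z type and its bind, as in the question
  sealed trait Z[T]
  final case class MyZLeft[T](t: T) extends Z[T]
  final case class MyZRight[T](t: T) extends Z[T]

  def swap[T](z: Z[T]): Z[T] = z match {
    case MyZLeft(t)  => MyZRight(t)
    case MyZRight(t) => MyZLeft(t)
  }

  val zBind: Bind[Z] = new Bind[Z] {
    def unit[A](a: A): Z[A] = MyZRight(a)
    def bind[A, B](m: Z[A])(f: A => Z[B]): Z[B] = m match {
      case MyZLeft(t)  => swap(f(t))
      case MyZRight(t) => swap(f(t))
    }
  }

  val incO: Int => Option[Int] = x => Some(x + 1)
  val tenO: Int => Option[Int] = x => if (x > 1) Some(x * 10) else None
  val incZ: Int => Z[Int] = x => MyZRight(x + 1)
  val tenZ: Int => Z[Int] = x => MyZLeft(x * 10)

  // Each list holds the results for law 1, law 2, law 3
  val optionResults: List[Boolean] = List(
    rightIdentity(optionBind)(Some(3)),
    leftIdentity(optionBind)(1, incO),
    associativity(optionBind)(Some(1), incO, tenO)
  )

  val zResults: List[Boolean] = List(
    rightIdentity(zBind)(MyZRight(3)),
    leftIdentity(zBind)(1, incZ),
    associativity(zBind)(MyZRight(1), incZ, tenZ)
  )
}
```

With these samples `Option` passes all three checks, while the question's Z fails all three: for instance, unit(1) flatMap f yields swap(f(1)) rather than f(1). Z can still be used in for-comprehensions, but it is not a lawful monad.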

Reference: http://james-iry.blogspot.com.au/2007/10/monads-are-elephants-part-3.html


If you're looking for a shortcut to the implementation, you can define a common ancestor that provides a standard definition of flatMap:

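trait Monad[T] {
  def map[U](f: T => U): Monad[U]
  // flatten only applies when T is itself a Monad[U]; the evidence parameter enforces that
  def flatten[U](implicit ev: T <:< Monad[U]): Monad[U]
  def flatMap[V](g: T => Monad[V]): Monad[V] = map(g).flatten[V]
}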

But you then must define concrete implementations of map and flatten. These are a matter of design: there are literally infinitely many implementations that meet these signatures (i.e. they can't be plucked out of the ether and aren't dictated by the laws of physics ;) )
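One possible concrete design, sketched under assumptions: a hypothetical `Box` case class holding a single value. The block repeats the trait so it stands alone; in it, `flatten` takes an evidence parameter `T <:< Monad[U]` (so it can only be called when the contents are themselves monads), which is one way to make `map(g).flatten` type-check.

```scala
object DerivedFlatMap {
  // Common ancestor: flatMap is derived from map and flatten
  trait Monad[T] {
    def map[U](f: T => U): Monad[U]
    // Only callable when T is itself a Monad[U]
    def flatten[U](implicit ev: T <:< Monad[U]): Monad[U]
    def flatMap[V](g: T => Monad[V]): Monad[V] = map(g).flatten[V]
  }

  // Hypothetical design: a box holding exactly one value
  final case class Box[T](value: T) extends Monad[T] {
    def map[U](f: T => U): Box[U] = Box(f(value))
    // The evidence converts the inner T into the Monad[U] it is known to be
    def flatten[U](implicit ev: T <:< Monad[U]): Monad[U] = ev(value)
  }

  // Uses the inherited flatMap: map gives Box(Box(42)), flatten unwraps one layer
  val viaFlatMap: Monad[Int] = Box(20).flatMap(a => Box(a + 22))

  val viaFor: Monad[Int] = for {
    a <- Box(20)
    b <- Box(22)
  } yield a + b
}
```

A different type (say, one holding zero or many values) would need entirely different map and flatten bodies, which is exactly the design freedom described above.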


Without even getting into the specifics of code implementation in Scala or Haskell, I want to note that having a class for which you know a way to add unit and multiplication is one thing, and the general case is another.

In the general case, the only solution I know is to throw in the free monad F |-> 1+F(1+F(1+F(...))), which might not exist at all.
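For a Scala type constructor F that comes with a map (a functor), the free monad can be written down directly as a data type: a value is either a plain `Pure(a)` or one layer of F wrapping another free value, which unrolls to A + F(A + F(A + ...)). The sketch below uses hypothetical `Functor`, `Free`, `Pure` and `Suspend` definitions (not those of any particular library) plus a made-up `Log` instruction type, and interprets a program with a hypothetical `run` function.

```scala
import scala.language.higherKinds

object FreeSketch {
  // Hypothetical Functor type class (not a particular library's)
  trait Functor[F[_]] {
    def map[A, B](fa: F[A])(f: A => B): F[B]
  }

  // Free[F, A] is either a value or one layer of F around another Free
  sealed trait Free[F[_], A] {
    def flatMap[B](f: A => Free[F, B])(implicit F: Functor[F]): Free[F, B] =
      this match {
        case Pure(a)     => f(a)
        // Push the continuation down through the F layer
        case Suspend(fa) => Suspend[F, B](F.map(fa)(_.flatMap(f)))
      }

    def map[B](f: A => B)(implicit F: Functor[F]): Free[F, B] =
      flatMap(a => Pure[F, B](f(a)))
  }
  final case class Pure[F[_], A](a: A) extends Free[F, A]
  final case class Suspend[F[_], A](next: F[Free[F, A]]) extends Free[F, A]

  // Hypothetical instruction type: emit a message, then continue with next
  final case class Log[A](msg: String, next: A)

  implicit val logFunctor: Functor[Log] = new Functor[Log] {
    def map[A, B](fa: Log[A])(f: A => B): Log[B] = Log(fa.msg, f(fa.next))
  }

  def log(msg: String): Free[Log, Unit] =
    Suspend[Log, Unit](Log[Free[Log, Unit]](msg, Pure[Log, Unit](())))

  // Interpreter: collect the messages and the final value (not stack-safe)
  def run[A](prog: Free[Log, A]): (List[String], A) = prog match {
    case Pure(a) => (Nil, a)
    case Suspend(Log(msg, next)) =>
      val (rest, a) = run(next)
      (msg :: rest, a)
  }

  // Defined after logFunctor so the implicit is initialized first
  val program: Free[Log, Int] = for {
    _ <- log("start")
    x <- Pure[Log, Int](21)
    _ <- log("doubling")
  } yield x * 2
}
```

Free only records the structure of the program; all meaning is supplied afterwards by an interpreter such as `run`.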

Otherwise, you have to prove that whatever you introduce as unit and multiplication satisfies the monad laws (see the response by GlenBest).

