/**
 * A unit of work that can be started and later polled for completion.
 *
 * `run()` starts the work and `isDone` reports whether it has finished.
 */
trait Task 
{
	def run() : Unit
	def isDone : Boolean

	/**
	 * Combines this task with `other` into a single composite task.
	 *
	 * Running the composite runs this task first and then `other`; the
	 * composite is done only when both parts report done.
	 */
	def join(other : Task) =
	{
		// Alias the outer task so the anonymous class below can refer to it plainly.
		val first = this
		new Task
		{
			def run() = { first.run(); other.run() }
			def isDone = first.isDone && other.isDone
		}
	}
}
// An Executor maps a by-name block of code to a Task.
type Executor = (=>Unit)=>Task // nicer name

// Convenience: adapts a by-name block of code to a java.lang.Runnable.
// The block is evaluated each time the resulting Runnable's run() is called,
// not when the Runnable is created.
implicit def run(code: =>Unit) =
{
	new Runnable
	{
		def run() : Unit = { code }
	}
}

/**
 * Wraps an array so that `map` and `foreach` (and therefore for-comprehensions
 * over a Parallel) split their work across the implicitly supplied executors.
 *
 * Both operations block the calling thread until the combined task reports
 * done. The function passed in may be invoked from several tasks at once, so
 * it must take care of its own synchronisation when it touches shared state.
 *
 * @param inner     the array whose elements are processed
 * @param executors the executors to distribute work over; must be non-empty
 *                  whenever `inner` is non-empty
 */
class Parallel[A] (inner : Array[A])(implicit executors : Array[Executor])
{
	private[this] type SubArray[T] = Array.Projection[T] // shorter name

	/**
	 * Applies `f` to every element and returns the results in a new array,
	 * with result(i) == f(inner(i)).
	 */
	def map[B](f : (A) => B) : Array[B] =
	{
		val result = new Array[B](inner.length)
		// mapDivide assumes non-empty input, so skip scheduling for an empty array.
		if( inner.length > 0 )
		{
			runAndWait ( mapDivide ( inner.projection, result.projection, f, executors.projection ) )
		}
		result
	}

	/** Applies `f` to every element for its side effects. */
	def foreach(f : (A) => Unit) : Unit =
	{
		if( inner.length > 0 )
		{
			runAndWait ( foreachDivide ( inner.projection, f, executors.projection ) )
		}
	}

	/**
	 * Starts `execution` and blocks until it reports done.
	 *
	 * Task only exposes polling, so this is still a poll loop, but it yields
	 * between polls instead of spinning flat out, so the waiting thread does
	 * not compete as hard with the workers for CPU time.
	 */
	private[this] def runAndWait(execution : Task) : Unit =
	{
		execution.run()
		while(!execution.isDone) { Thread.`yield`() }
	}

	/**
	 * Builds the task that fills `to` with `f` applied to `from`.
	 *
	 * The data and the executors are halved together until either a single
	 * element or a single executor remains; that leaf is handed to the first
	 * remaining executor. `from` and `to` are split at the same index so the
	 * two views stay aligned.
	 */
	private [this] def mapDivide[B](from : SubArray[A], to : SubArray[B], f : (A) => B, executors : SubArray[Executor]) : Task = 
	{
		assume (from.length == to.length && from.length > 0 && executors.length > 0)

		if( from.length == 1 || executors.length == 1) 
		{
			val worker = executors(0)
			worker {
				for (i <- 0 until from.length)
				{
					to(i) = f ( from(i) )
				}
			}
		}
		else
		{
			val left = mapDivide ( from.take(from.length/2), to.take(to.length/2), f, executors.take(executors.length/2) )
			val right = mapDivide ( from.drop(from.length/2), to.drop(to.length/2), f, executors.drop(executors.length/2) )

			left.join(right)
		}
	}

	/**
	 * Builds the task that applies `f` to every element of `sub`, using the
	 * same halving scheme as mapDivide.
	 */
	private [this] def foreachDivide(sub : SubArray[A], f : (A) => Unit, executors :SubArray[Executor]) : Task = 
	{
		assume (executors.length > 0)

		if( sub.length == 1 || executors.length == 1) 
		{
			val worker = executors(0)
			worker {
				for (elem <- sub)
				{
					f ( elem )
				}
			}
		}
		else
		{
			val left = foreachDivide ( sub.take(sub.length/2), f, executors.take(executors.length/2) )
			val right = foreachDivide ( sub.drop(sub.length/2), f, executors.drop(executors.length/2) )

			left.join(right)
		}
	}

	// TODO: filter, flatMap

}


// use a heavy thread for each executor
implicit val executors : Array[Executor] = 
{	
	/**
	 * Builds a Task that evaluates `code` on a dedicated thread each time the
	 * task is run.
	 *
	 * @param name number used to label the worker thread (debugging aid only)
	 * @param code the by-name block to evaluate on the worker thread
	 */
	def thread(name : Int)( code : => Unit) = new Task
	{
		// Single monitor that guards `done` and that overlapping run() calls wait on.
		// Waiting and notifying must use the same object whose lock is held,
		// otherwise wait() throws IllegalMonitorStateException.
		private val lock = new Object

		// true while no run is in flight, so a task that was never started reports done.
		private var done = true

		def run() = 
		{
			lock.synchronized 
			{ 
				// A loop, not an `if`, so a spurious wake-up cannot let two runs overlap.
				while(!done) lock.wait()
				done = false 
			}

			// A java.lang.Thread can be started only once, so every run gets a fresh thread.
			val worker = new Thread("executor-" + name)
			{
				override def run() : Unit =
				{
					// Mark completion even when `code` throws; otherwise anyone
					// polling isDone (or waiting to re-run) would hang forever.
					try { code }
					finally { lock.synchronized { done = true; lock.notifyAll() } }
				}
			}
			worker.start()
		} 

		def isDone = lock.synchronized { done }
	}
	Array[Executor] ( thread(1)_ , thread(2)_, thread(3)_, thread(4)_  )
}




/**
 * Times a block of code.
 *
 * Prints the label, evaluates `code` once, prints the elapsed wall-clock
 * time in milliseconds, and returns whatever `code` produced.
 */
def benchmark[A] (label : String)(code: =>A) : A =
{
	println( "Benchmark: " + label)

	val startedAt = System.currentTimeMillis()
	val outcome = code
	val elapsed = System.currentTimeMillis() - startedAt

	println( "Duration: " + elapsed)
	outcome
}


/**
 * Artificial CPU-bound workload: sums cos(sin(exp(sqrt(x + i)))) for
 * i = 1 .. loops, accumulating in increasing order of i.
 * Returns 0.0 when `loops` is less than 1.
 */
def fun( x : Double)(implicit loops : Int) : Double = 
{
	import Math._

	// One summand of the series for step i.
	def term(i : Int) : Double = cos(sin(exp(sqrt(x + i))))

	(1 to loops).foldLeft(0.0)((acc, i) => acc + term(i))
}

// Iteration count picked up implicitly by fun's `loops` parameter.
implicit val inner_loops : Int = 100


// Benchmark input: 12000 pairs (value, index), where value == index.
val array : Array[(Int, Int)] = Array.range(0, 12 * 1000).zipWithIndex
// Holds the result of the most recent benchmark run.
var result : Double = 0.0


// Baseline: sum fun(value) + index over the whole array on the calling thread.
println("\n")
result = benchmark("Sequential") {
	array.foldLeft(0.0)((acc, x) => acc + ( fun (x._1) + x._2 ))
}
println ( "Result: " + result.toString )


// Parallel variant 1: every task adds straight into one shared accumulator.
println("\n")
result = benchmark("Parallel (synchronized)") {
	var sum = 0.0
	(new Parallel ( array )).foreach { x =>
		// Do the expensive part outside the lock ...
		val contribution = fun (x._1) + x._2
		// ... and only guard the update, since several tasks write to `sum`.
		synchronized { sum += contribution }
	}
	sum
}
println ( "Result: " + result.toString )


// Parallel variant 2: compute every term in parallel, then add them up sequentially.
println("\n")
result = benchmark("Parallel (2-stage)") {
	val partials = (new Parallel ( array )).map { x => fun (x._1) + x._2 }
	partials.foldLeft(0.0)((acc, term) => acc + term)
}
println ( "Result: " + result.toString )
