implementation module StrictnessPropagation

import StdEnum, StdArray, StdTuple, StdBool
from StdList import map, foldl, !!, zip2, unzip, filter, reverse, instance length []
import SaplParser, Flavour
from Set import :: Set, newSet, fromList, member, insert, delete, union, unions, intersection, intersections, difference
from Map import get, put
import Maybe

// isStrictArg ps flavour n nr_args i
// Decide whether the i-th argument of a call to `n` that is given `nr_args`
// arguments is in a strict position. `n` is looked up as a constructor first,
// then as a function recorded in the parser state, and finally as an inline
// function of the flavour. The answer is False for unknown names, for partial
// applications (fewer arguments than the arity) and for positions beyond the arity.
isStrictArg {ps_constructors, ps_functions} {builtInFunctions, inlineFunctions} n nr_args i
	= lookupConstructor (get n ps_constructors)
where
	// The call must be saturated and the position must exist
	inRange arity = nr_args >= arity && i < arity

	lookupConstructor (Just cons) = inRange cons.nr_args && isStrictVar (cons.args !! i)
	lookupConstructor Nothing = lookupFunction (get n ps_functions)

	lookupFunction (Just fargs) = inRange (length fargs) && isStrictVar (fargs !! i)
	lookupFunction Nothing = lookupInline (get n inlineFunctions)

	lookupInline (Just inl) = inRange inl.arity && inl.strictness.[i] == '1'
	lookupInline Nothing = False

// Run strictness propagation over every function, front to back.
// The parser state is threaded through the traversal, so strictness found
// for an earlier function is visible when later functions are analysed.
// Returns the processed functions in their original order and the final state.
doStrictnessPropagation :: !ParserState !Flavour ![FuncType] -> (![FuncType], !ParserState)
doStrictnessPropagation ps flavour funs = propagate funs [] ps
where
	// `done` collects the processed functions in reverse order
	propagate [] done st = (reverse done, st)
	propagate [f:fs] done st
		# (newF, nextSt) = propFunc st flavour f
		= propagate fs [newF:done] nextSt

// TODO: once strictness is added to a function's arguments, the whole
//		 propagation should be recomputed repeatedly until a fixpoint is reached.
//		 A single pass is enough only if the functions come in a suitable order
//		 (callees before their callers), which is the case when the code is linked.
// Propagate strictness into a single function definition.
// The body is analysed with propBody. Every plain (NormalVar) argument that the
// body certainly evaluates is turned into a strict argument. The new argument
// list is also stored in ps_functions under the function's name, so that calls
// analysed later (through isStrictArg) can use it.
// Any other kind of FuncType is returned unchanged, together with the unchanged state.
propFunc :: !ParserState !Flavour !FuncType -> (!FuncType, !ParserState)
propFunc ps=:{ps_functions} flavour (FTFunc name body args)
	# (strictVars, newBody) = propBody ps flavour newSet body
	# newArgs = map (markStrict strictVars) args
	= (FTFunc name newBody newArgs, {ps & ps_functions = put (unpackVar name) newArgs ps_functions})
where
	// Already-strict arguments are kept as they are
	markStrict _ arg=:(StrictVar _ _) = arg
	markStrict strictVars arg=:(NormalVar vn _)
		| member vn strictVars = toStrictVar arg
		= arg

propFunc ps _ f = (f, ps)

// Strictness analysis of a single SAPL term.
//
// propBody ps flavour sd body walks `body` and returns:
//	- `sd` extended with the names of the variables that are certainly
//	  evaluated whenever `body` is evaluated, and
//	- a rewritten term in which let bindings whose variable turned out to be
//	  strict have been converted with toStrictBind.
// isStrictArg decides whether an argument of an application is in a strict
// position, using the constructor/function information in `ps` and `flavour`.
propBody :: !ParserState !Flavour !(Set String) !SaplTerm -> (!Set String, !SaplTerm)
propBody ps flavour sd body = walk sd body
where
	// A variable reference forces that variable
	walk sd t=:(SVar var) = (insert (unpackVar var) sd, t)

	// Only arguments in strict positions of the called name contribute. Each is
	// analysed starting from an empty set, and the results are merged into `sd`.
	// NOTE(review): the applied name `var` itself is not added to the set.
	walk sd t=:(SApplication var args)
		// The term is returned unchanged. The rewritten arguments are not needed
		// because the arguments cannot contain let definitions.
		# nsds = map fst (map (walk newSet) strictArgs)
		= (unions [sd:nsds], t)
	where
		varName = unpackVar var
		nr_args = length args
		checkArg (arg, i) = isStrictArg ps flavour varName nr_args i
		// The arguments (positions dropped) that sit in strict positions
		strictArgs = map fst (filter checkArg (zip2 args [0..]))

	// The condition is always evaluated, but only one branch is. A variable
	// counts as strict in the branches only if it is strict in both of them.
	walk sd (SIf c l r)
		# (sdl, nl) = walk newSet l
		# (sdr, nr) = walk newSet r
		# (sdc, nc) = walk sd c
		= (union sdc (intersection sdl sdr), SIf nc nl nr)

	// The scrutinee is always evaluated. A variable is strict in the alternatives
	// only if it is strict in every one of them.
	// NOTE(review): the result for an empty `cases` list depends on how
	// Set.intersections handles [] — verify.
	walk sd (SSelect p cases)
		# (sdp, np) = walk sd p
		# (sdcs, ncases) = unzip (map walkcase cases)
		= (union sdp (intersections sdcs), SSelect np ncases)
	where
		// Variables bound by a constructor pattern are local to their
		// alternative, so they are removed from that alternative's set.
		walkcase (p, c)
			# (sd, nc) = walk newSet c
			= (difference sd (patternvars p), (p, nc))

		patternvars (PCons _ vars) = fromList (map unpackVar vars)
		patternvars _ = newSet

	// Bindings are assumed to be topologically sorted (presumably: a binding
	// refers only to bindings listed before it). The body is analysed first.
	// The bindings are then visited from last to first, so every use of a binding
	// in the body or in a later binding is already recorded when it is reached.
	walk sd (SLet body bnds)
		# (sdb, nbody) = walk newSet body
		# (sdl, nbnds) =  wbnds sdb (reverse bnds) [] // the reversed (last-to-first) order is essential, see above
		= (union sd sdl, SLet nbody nbnds)
	where
		// wbnds sd bnds acc visits the reversed bindings and marks a binding strict
		// when its variable is in `sd`. Consing onto `acc` restores the original
		// order of the bindings.
		wbnds sd [] nbnds = (sd, nbnds)
		wbnds sd [bnd:bnds] nbnds
			# nbnd = if (member vn sd) (toStrictBind bnd) bnd
			# nsd = walkbnd sd nbnd
			= wbnds nsd bnds [nbnd:nbnds]
		where
			vn = unpackVar (unpackBindVar bnd)

			// Only the body of a strict binding is analysed. Its variables are
			// evaluated whenever the let is evaluated. A lazy binding contributes
			// nothing. In both cases the bound variable itself is then removed,
			// because it is not visible outside the let.
			walkbnd sd (SaplLetDef (StrictVar vn _) body) = delete vn (fst (walk sd body)) // the new body is discarded: it cannot contain a let definition
			walkbnd sd (SaplLetDef (NormalVar vn _) body) = delete vn sd

	// Any other term contributes no strict variables and is left unchanged
	walk sd t = (sd, t)