{-# LANGUAGE OverloadedStrings #-}
module Clash.Normalize.Transformations.SeparateArgs
( separateArguments
) where
import qualified Control.Lens as Lens
import Control.Monad.Writer (listen)
import qualified Data.List as List
import qualified Data.Monoid as Monoid
import GHC.Stack (HasCallStack)
import Clash.Core.HasType
import Clash.Core.Name (Name(..))
import Clash.Core.Subst (extendIdSubst, mkSubst, substTm)
import Clash.Core.Term (Term(..), collectArgsTicks, mkApps, mkLams, mkTicks)
import Clash.Core.Type (Type, mkPolyFunTy, splitFunForallTy)
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Util (Projections (..), shouldSplit)
import Clash.Core.Var (Id, TyVar, Var (..), isGlobalId, mkLocalId)
import Clash.Core.VarEnv (extendInScopeSet, uniqAway)
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Rewrite.Types (TransformContext(..), tcCache)
import Clash.Rewrite.Util (changed, mkDerivedName)
-- | Split apart (global) function arguments whose types 'shouldSplit' says
-- should be separated off. Works on both the definition side (a lambda, see
-- 'separateLambda') and the call site (an application whose head is a global
-- variable). Any other term is returned unchanged.
--
-- At a call site the type annotation on the global variable reference is
-- rebuilt from the (possibly split) argument list, so that the reference's
-- type agrees with the arguments it is now applied to.
separateArguments :: HasCallStack => NormRewrite
separateArguments ctx e0@(Lam b eb) = do
  tcm <- Lens.view tcCache
  case separateLambda tcm ctx b eb of
    Just e1 -> changed e1
    Nothing -> return e0

separateArguments (TransformContext is0 _) e@(collectArgsTicks -> (Var g, args, ticks))
  | isGlobalId g = do
  let (argTys0, resTy) = splitFunForallTy (coreTypeOf g)
      -- Only arguments that line up with a quantifier/parameter exposed by
      -- 'splitFunForallTy' are candidates for splitting. Keep whatever falls
      -- outside that overlap instead of letting 'zip' silently discard it:
      --   * 'extraArgs': applied arguments beyond the exposed parameters
      --     (e.g. when the result type is instantiated to a function type);
      --     they are re-applied unchanged after the split arguments.
      --   * 'unappliedTys': parameters not covered by 'args' (a partial
      --     application); they must remain part of the rebuilt type of 'g'.
      (appliedArgs, extraArgs) = splitAt (length argTys0) args
      unappliedTys = drop (length args) argTys0
  -- 'listen' observes whether any 'splitArg' call reported a change through
  -- 'changed'; that flag decides whether the application is rebuilt.
  (concat -> args1, Monoid.getAny -> hasChanged)
    <- listen (mapM (uncurry splitArg) (zip argTys0 appliedArgs))
  if hasChanged then
    -- 'changed' was already called inside 'splitArg', so a plain 'return'
    -- suffices here.
    let (argTys1, args2) = unzip args1
        gTy = mkPolyFunTy resTy (argTys1 ++ unappliedTys)
    in  return (mkApps (mkTicks (Var g {varType = gTy}) ticks)
                       (args2 ++ extraArgs))
  else
    return e

 where
  -- Split a single argument into its projections when its (inferred) type is
  -- one 'shouldSplit' wants separated off. Type arguments are never split.
  splitArg
    :: Either TyVar Type
    -- The quantifier/function argument type of the global variable
    -> Either Term Type
    -- The applied type argument or term argument
    -> NormalizeSession [(Either TyVar Type,Either Term Type)]
  splitArg tv arg@(Right _) = return [(tv,arg)]
  splitArg ty arg@(Left tmArg) = do
    tcm <- Lens.view tcCache
    let argTy = inferCoreTypeOf tcm tmArg
    case shouldSplit tcm argTy of
      Just (_,Projections projections,_) -> do
        tmArgs <- projections is0 tmArg
        -- NOTE(review): every projection is paired with the *unsplit*
        -- parameter type 'ty', so the rebuilt type of the global repeats
        -- 'ty' once per projection instead of listing the projection types.
        -- Confirm that downstream passes do not rely on these per-argument
        -- types being exact.
        changed (map ((ty,) . Left) tmArgs)
      _ ->
        return [(ty,arg)]

separateArguments _ e = return e
{-# SCC separateArguments #-}
-- | Split a lambda binder into several fresh binders when 'shouldSplit' says
-- the binder's type should be separated off.
--
-- The fresh binders (one per type returned by 'shouldSplit') are made unique
-- with respect to the context's in-scope set. Every occurrence of the old
-- binder in the body is replaced by the value rebuilt from the fresh binders.
separateLambda
  :: TyConMap
  -> TransformContext
  -> Id
  -- ^ Lambda binder
  -> Term
  -- ^ Lambda body
  -> Maybe Term
  -- ^ 'Just' the new lambda if the binder was split, 'Nothing' otherwise
separateLambda tcm ctx@(TransformContext is0 _) b eb0 =
  splitWith <$> shouldSplit tcm (coreTypeOf b)
 where
  -- Build the replacement lambda from the reconstruction function and the
  -- types of the fields the binder is split into.
  splitWith (rebuild, _, fieldTys) =
    let baseName = mkDerivedName ctx (nameOcc (varName b))
        -- Thread the in-scope set through the fresh binders so each one is
        -- unique with respect to all binders created before it.
        (isFinal, freshBs) =
          List.mapAccumL freshen is0 (map (`mkLocalId` baseName) fieldTys)
        subst = extendIdSubst (mkSubst isFinal) b (rebuild (map Var freshBs))
    in  mkLams (substTm "separateArguments" subst eb0) freshBs

  -- Rename a binder away from the in-scope set, then add it to that set.
  freshen inScope v =
    let v' = uniqAway inScope v
    in  (extendInScopeSet inScope v', v')
{-# SCC separateLambda #-}