diff --git a/core/src/Ohua/ALang/Passes/Smap.hs b/core/src/Ohua/ALang/Passes/Smap.hs index 94282a9..01bcd15 100644 --- a/core/src/Ohua/ALang/Passes/Smap.hs +++ b/core/src/Ohua/ALang/Passes/Smap.hs @@ -38,16 +38,17 @@ collectSf :: Expression collectSf = Lit $ FunRefLit $ FunRef Refs.collect Nothing smapRewrite :: (Monad m, MonadGenBnd m) => Expression -> m Expression -smapRewrite (Let v a b) = Let v <$> smapRewrite a <*> smapRewrite b -smapRewrite (Lambda v e) = Lambda v <$> smapRewrite e -smapRewrite e@(Apply (Apply (Lit (FunRefLit (FunRef "ohua.lang/smap" Nothing))) lamExpr) dataGen) = do - lamExpr' <- smapRewrite lamExpr +smapRewrite = + rewriteM $ \case + PureFunction op _ `Apply` lamExpr `Apply` dataGen + | op == Refs.smap -> Just <$> do + lamExpr' <- smapRewrite lamExpr -- post traversal optimization - ctrlVar <- generateBindingWith "ctrl" - lamExpr'' <- liftIntoCtrlCtxt ctrlVar lamExpr' - let ((inBnd:[]), expr) = lambdaArgsAndBody lamExpr'' - d <- generateBindingWith "d" - let expr' = renameVar expr (Var inBnd, d) + ctrlVar <- generateBindingWith "ctrl" + lamExpr'' <- liftIntoCtrlCtxt ctrlVar lamExpr' + let ((inBnd:[]), expr) = lambdaArgsAndBody lamExpr'' + d <- generateBindingWith "d" + let expr' = renameVar expr (Var inBnd, d) -- [ohualang| -- let (d, $var:ctrlVar, size) = ohua.lang/smapFun $var:dataGen in -- let (a,b,c) = ctrl $var:ctrlVar a b c in @@ -55,14 +56,16 @@ smapRewrite e@(Apply (Apply (Lit (FunRefLit (FunRef "ohua.lang/smap" Nothing))) -- let resultList = collect size result in -- resultList -- (this breaks haddock) |] - size <- generateBindingWith "size" - ctrls <- generateBindingWith "ctrls" - result <- generateBindingWith "result" - resultList <- generateBindingWith "resultList" - return $ - Let ctrls (Apply smapSfFun dataGen) $ - mkDestructured [d, ctrlVar, size] ctrls $ - Let result expr' $ - Let resultList (Apply (Apply collectSf $ Var size) $ Var result) $ - Var resultList -smapRewrite e = return e + size <- generateBindingWith "size" + ctrls <- generateBindingWith "ctrls" + result <- generateBindingWith "result" + resultList <- generateBindingWith "resultList" + return $ + Let ctrls (Apply smapSfFun dataGen) $ + mkDestructured [d, ctrlVar, size] ctrls $ + Let result expr' $ + Let + resultList + (Apply (Apply collectSf $ Var size) $ Var result) $ + Var resultList + _ -> pure Nothing diff --git a/core/src/Ohua/DFLang/Passes.hs b/core/src/Ohua/DFLang/Passes.hs index 7e40a7e..34ef4bc 100644 --- a/core/src/Ohua/DFLang/Passes.hs +++ b/core/src/Ohua/DFLang/Passes.hs @@ -162,7 +162,7 @@ handleApplyExpr g = failWith $ "Expected apply but got: " <> show g -- | Inspect an expression expecting something which can be captured -- in a DFVar otherwise throws appropriate errors. -expectVar :: MonadError Error m => Expression -> m DFVar +expectVar :: (HasCallStack, MonadError Error m) => Expression -> m DFVar expectVar (Var bnd) = pure $ DFVar bnd -- TODO currently only allowed for the unitFn function -- expectVar r@PureFunction {} = @@ -171,7 +171,7 @@ expectVar (Var bnd) = pure $ DFVar bnd -- show (pretty r) expectVar (Lit l) = pure $ DFEnvVar l expectVar a = - failWith $ "Argument must be local binding or literal, was " <> show a + throwErrorS $ "Argument must be local binding or literal, was " <> show a -- In this function I use the so called 'Tardis' monad, which is a special state -- monad. It has one state that travels "forward" in time, which is the same as