Skip to content

Modular constant folding

The following code (taken from #16329):

-- | Sum the residues mod 3 of @n+1@, @n+2@ and @n+3@, threading the
-- accumulator (starting at 0) through 'foldM' in 'IO'.
--
-- Kept exactly as in the original report (#16329); the Core dump for func
-- in this issue is of this code, so it is intentionally left untouched.
--
-- NOTE(review): 'foldM' is not exported by Prelude; this snippet assumes
-- @import Control.Monad (foldM)@, which is not shown here.
--
-- NOTE(review): the @case@ matches only 0, 1 and 2. Those are the only
-- values @x `mod` 3@ can take for 'Int', but GHC's coverage checker cannot
-- prove that, so -Wincomplete-patterns reports the match as non-exhaustive.
-- Presumably the real Core therefore keeps a pattern-match-failure branch
-- that the abbreviated dump omits. Confirm with -ddump-simpl.
func :: Int -> IO Int
func n = foldM step 0 xs
    where
    -- The three values whose residues are summed: n+1, n+2 and n+3.
    xs = map (n+) [1,2,3]
    -- Add the residue of x mod 3 to the running total.
    step acc x =
        case x `mod` 3  of
            0 -> pure acc
            1 -> pure $ acc + 1
            2 -> pure $ acc + 2

-- | Pure counterpart of 'func': sums the residues mod 3 of @n+1@, @n+2@ and
-- @n+3@ with a left fold, starting from 0.
--
-- Kept exactly as in the original report (#16329); the Core dump for func'
-- in this issue is of this code, so it is intentionally left untouched.
--
-- NOTE(review): there is no type signature. As written, GHC infers the
-- polymorphic type @(Integral a, Num b) => a -> b@. The Core dump for func'
-- looks like it was taken at a monomorphic type such as 'Int'. Confirm
-- against the original report.
--
-- NOTE(review): the @case@ matches only 0, 1 and 2. For the standard
-- integral types those are the only values @x `mod` 3@ can take, but GHC's
-- coverage checker cannot prove that, so -Wincomplete-patterns reports the
-- match as non-exhaustive.
--
-- The lazy @foldl@ is harmless here because the list has only three elements.
func' n = foldl step 0 xs
    where
    -- The three values whose residues are summed: n+1, n+2 and n+3.
    xs = map (n+) [1,2,3]
    -- Add the residue of x mod 3 to the accumulator.
    step acc x =
        case x `mod` 3 of
            0 -> acc
            1 -> acc + 1
            2 -> acc + 2

is simplified by GHC to the following Core (written here in an abbreviated, Haskell-like notation):

func n =
    case (n + 1) `mod` 3 of
        0 -> case (n + 2) `mod` 3 of
            0 -> case (n + 3) `mod` 3 of
                0 -> pure 0
                1 -> pure 1
                2 -> pure 2
            1 -> case (n + 3) `mod` 3 of
                0 -> pure 1
                1 -> pure 2
                2 -> pure 3
            2 -> case (n + 3) `mod` 3 of
                0 -> pure 2
                1 -> pure 3
                2 -> pure 4
        1 -> case (n + 2) `mod` 3 of
            0 -> case (n + 3) `mod` 3 of
                0 -> pure 1
                1 -> pure 2
                2 -> pure 3
            1 -> case (n + 3) `mod` 3 of
                0 -> pure 2
                1 -> pure 3
                2 -> pure 4
            2 -> case (n + 3) `mod` 3 of
                0 -> pure 3
                1 -> pure 4
                2 -> pure 5
        2 -> case (n + 2) `mod` 3 of
            0 -> case (n + 3) `mod` 3 of
                0 -> pure 2
                1 -> pure 3
                2 -> pure 4
            1 -> case (n + 3) `mod` 3 of
                0 -> pure 3
                1 -> pure 4
                2 -> pure 5
            2 -> case (n + 3) `mod` 3 of
                0 -> pure 4
                1 -> pure 5
                2 -> pure 6


func' n =
    join j2 w2 =
        join j1 w1 =
            case (n + 3) `mod` 3 of
                0 -> w1
                1 -> w1 + 1
                2 -> w1 + 2
        in case (n + 2) `mod` 3 of
            0 -> jump j1 w2
            1 -> jump j1 (w2 + 1)
            2 -> jump j1 (w2 + 2)
    in case (n + 1) `mod` 3 of
        0 -> jump j2 0
        1 -> jump j2 1
        2 -> jump j2 2

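Case-folding with modular arithmetic should remove the nesting. The three scrutinees are `(n + 1) `mod` 3`, `(n + 2) `mod` 3` and `(n + 3) `mod` 3`. As long as the additions do not overflow, the outer residue determines the inner ones: for example, if `(n + 1) `mod` 3 == 0` then `(n + 2) `mod` 3 == 1` and `(n + 3) `mod` 3 == 2`. Only 3 of the 27 leaves of `func` are then reachable, and each of them is `pure 3`. Any such rule must still respect wrap-around. `Int` arithmetic wraps modulo a power of two, which is not a multiple of 3, so for `n + 1 == maxBound` the residues are 1, 1, 2 rather than 1, 2, 0.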

Edited by Sylvain Henry
To upload designs, you'll need to enable LFS and have an admin enable hashed storage. More information