TmOracle.hs 10.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
{-
Author: George Karachalias <george.karachalias@cs.kuleuven.be>

The term equality oracle. The main export of the module is function `tmOracle'.
-}

{-# LANGUAGE CPP, MultiWayIf #-}

module TmOracle (

        -- re-exported from PmExpr
        PmExpr(..), PmLit(..), SimpleEq, ComplexEq, PmVarEnv, falsePmExpr,
13 14
        eqPmLit, filterComplex, isNotPmExprOther, runPmPprM, lhsExprToPmExpr,
        hsExprToPmExpr, pprPmExprWithParens,
15 16

        -- the term oracle
17
        tmOracle, TmState, initialTmState, solveOneEq, extendSubst, canDiverge,
18 19

        -- misc.
20
        toComplex, exprDeepLookup, pmLitType, flattenPmVarEnv
21 22 23 24 25 26 27
    ) where

#include "HsVersions.h"

import PmExpr

import Id
28
import Name
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
import Type
import HsLit
import TcHsSyn
import MonadUtils
import Util

import qualified Data.Map as Map

{-
%************************************************************************
%*                                                                      *
                      The term equality oracle
%*                                                                      *
%************************************************************************
-}

-- | The type of substitutions: a finite map from variable names to the
-- expressions they are bound to.
type PmVarEnv = Map.Map Name PmExpr
47 48 49 50 51 52 53 54

-- | The environment of the oracle. It is a pair of
--     1. A Bool flag: are there any constraints we cannot handle
--        (i.e. ones that involve a PmExprOther)?
--     2. The substitution, which we extend with every step and return as
--        part of the result.
type TmOracleEnv = (Bool, PmVarEnv)

-- | Check whether a constraint (x ~ BOT) can succeed,
-- given the resulting state of the term oracle.
--
-- The answer is True only when @x@ resolves (through the substitution) to a
-- variable, and neither @x@ nor that variable occurs in any of the equalities
-- that are still waiting on the standby list.
canDiverge :: Name -> TmState -> Bool
canDiverge x (standby, (_unhandled, env))
  -- If the variable seems not evaluated, there is a possibility for
  -- constraint x ~ BOT to be satisfiable.
  | PmExprVar y <- varDeepLookup env x -- seems not forced
  -- If it is involved (directly or indirectly) in any equality in the
  -- worklist, we can assume that it is already indirectly evaluated,
  -- as a side-effect of equality checking. If not, then we can assume
  -- that the constraint is satisfiable.
  = not (any (isForcedByEq x) standby || any (isForcedByEq y) standby)
  -- Variable x is already in WHNF so the constraint is non-satisfiable
  | otherwise = False

  where
    -- Does the variable occur on either side of the equality?
    isForcedByEq :: Name -> ComplexEq -> Bool
    isForcedByEq v (e1, e2) = varIn v e1 || varIn v e2

-- | Check whether a variable is in the free variables of an expression.
-- Literals and PmExprOther are treated as containing no variables.
varIn :: Name -> PmExpr -> Bool
varIn x (PmExprVar y)    = x == y
varIn x (PmExprCon _ es) = any (x `varIn`) es
varIn _ (PmExprLit _)    = False
varIn x (PmExprEq e1 e2) = (x `varIn` e1) || (x `varIn` e2)
varIn _ (PmExprOther _)  = False

-- | Flatten the DAG: every expression stored in the substitution is itself
-- fully resolved against that same substitution.
-- (Could be improved in terms of performance.)
flattenPmVarEnv :: PmVarEnv -> PmVarEnv
flattenPmVarEnv env = Map.map resolve env
  where
    resolve = exprDeepLookup env

-- | The state of the term oracle: the complex constraints that cannot
-- progress unless we get more information (the "standby" list), paired
-- with the oracle environment.
type TmState = ([ComplexEq], TmOracleEnv)

-- | The state the oracle starts from: nothing on standby, nothing
-- unhandled, and an empty substitution.
initialTmState :: TmState
initialTmState = (standby, (unhandled, env))
  where
    standby   = []
    unhandled = False
    env       = Map.empty

93 94 95
-- | Solve a complex equality (top-level).
--
-- The equality is first rewritten with the current substitution, then
-- simplified, and only then merged into the given state.
solveOneEq :: TmState -> ComplexEq -> Maybe TmState
solveOneEq solver_env@(_,(_,env)) complex
  = solveComplexEq solver_env -- do the actual *merging* with existing state
  $ simplifyComplexEq               -- simplify as much as you can
  $ applySubstComplexEq env complex -- replace everything we already know
99 100 101 102 103 104 105 106 107

-- | Solve a complex equality against the current oracle state.
--
-- Returns Nothing if the equality is contradictory (two literals that are
-- not equal, or two different constructors), and otherwise Just the new
-- state: the substitution may have been extended, the equality may have been
-- postponed on the standby list, or the "unhandled" flag may have been set.
solveComplexEq :: TmState -> ComplexEq -> Maybe TmState
solveComplexEq solver_state@(standby, (unhandled, env)) eq@(e1, e2) = case eq of
  -- We cannot do a thing about these cases
  (PmExprOther _,_)            -> Just (standby, (True, env))
  (_,PmExprOther _)            -> Just (standby, (True, env))

  (PmExprLit l1, PmExprLit l2)
    -- See Note [Undecidable Equality for Overloaded Literals]
    | eqPmLit l1 l2 -> Just solver_state
    | otherwise     -> Nothing

  -- Same constructor: solve the arguments pairwise; different ones clash
  (PmExprCon c1 ts1, PmExprCon c2 ts2)
    | c1 == c2  -> foldlM solveComplexEq solver_state (zip ts1 ts2)
    | otherwise -> Nothing

  -- (True ~ (t1 ~ t2)) is just (t1 ~ t2); (False ~ (t1 ~ t2)) has to wait.
  -- If the nullary constructor is neither, we fall through to the cases below.
  (PmExprCon _ [], PmExprEq t1 t2)
    | isTruePmExpr e1  -> solveComplexEq solver_state (t1, t2)
    | isFalsePmExpr e1 -> Just (eq:standby, (unhandled, env))
  (PmExprEq t1 t2, PmExprCon _ [])
    | isTruePmExpr e2   -> solveComplexEq solver_state (t1, t2)
    | isFalsePmExpr e2  -> Just (eq:standby, (unhandled, env))

  (PmExprVar x, PmExprVar y)
    | x == y    -> Just solver_state
    | otherwise -> extendSubstAndSolve x e2 solver_state

  (PmExprVar x, _) -> extendSubstAndSolve x e2 solver_state
  (_, PmExprVar x) -> extendSubstAndSolve x e1 solver_state

  -- Two unresolved equalities: postpone until we know more
  (PmExprEq _ _, PmExprEq _ _) -> Just (eq:standby, (unhandled, env))

  -- Everything else is marked as unhandled (catch-all)
  _ -> Just (standby, (True, env))

-- | Extend the substitution with (x |-> e) and solve the (possibly updated)
-- constraints from the standby list.
extendSubstAndSolve :: Name -> PmExpr -> TmState -> Maybe TmState
extendSubstAndSolve x e (standby, (unhandled, env))
  = foldlM solveComplexEq new_incr_state (map simplifyComplexEq changed)
  where
    -- Apply the substitution to the worklist and partition them to the ones
    -- that had some progress and the rest. Then, recurse over the ones that
    -- had some progress. Careful about performance:
    -- See Note [Representation of Term Equalities] in deSugar/Check.hs
    (changed, unchanged) = partitionWith (substComplexEq x e) standby
    -- The state we resume from: untouched equalities stay on standby and
    -- the substitution now also maps x to e.
    new_incr_state       = (unchanged, (unhandled, Map.insert x e env))

145 146 147 148 149 150 151 152 153 154 155 156 157
-- | Extend the substitution for a variable that is known to be fresh.
-- Since nothing already in the state can mention a fresh variable, there
-- is no need to revisit the standby list (unlike `extendSubstAndSolve`).
-- The expression is resolved and simplified first; if the result is a
-- PmExprOther we only set the "unhandled" flag and leave the substitution
-- as it was.
extendSubst :: Id -> PmExpr -> TmState -> TmState
extendSubst y e (standby, (unhandled, env)) =
  if isNotPmExprOther simpl_e
    then (standby, (unhandled, Map.insert (idName y) simpl_e env))
    else (standby, (True, env))
  where
    (simpl_e, _) = simplifyPmExpr (exprDeepLookup env e)

158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
-- | Simplify both sides of a complex equality, discarding the progress flags.
simplifyComplexEq :: ComplexEq -> ComplexEq
simplifyComplexEq (lhs, rhs) = (simplify lhs, simplify rhs)
  where
    simplify = fst . simplifyPmExpr

-- | Simplify an expression. The returned Bool is True if some
-- simplification took place and False if the operation was a no-op.
simplifyPmExpr :: PmExpr -> (PmExpr, Bool)
-- See Note [Deep equalities]
simplifyPmExpr (PmExprCon c ts) =
  -- simplify every argument; progress in any of them counts as progress
  case mapAndUnzip simplifyPmExpr ts of
    (ts', bs) -> (PmExprCon c ts', or bs)
simplifyPmExpr (PmExprEq t1 t2) = simplifyEqExpr t1 t2
simplifyPmExpr terminal         = (terminal, False) -- nothing to simplify

-- | Simplify an equality expression. The equality is given in parts.
-- The result is either truePmExpr, falsePmExpr, or a (possibly partially
-- simplified) equality; the Bool says whether any progress was made.
simplifyEqExpr :: PmExpr -> PmExpr -> (PmExpr, Bool)
-- See Note [Deep equalities]
simplifyEqExpr e1 e2 = case (e1, e2) of
  -- Variables
  (PmExprVar x, PmExprVar y)
    | x == y -> (truePmExpr, True)

  -- Literals
  (PmExprLit l1, PmExprLit l2)
    -- See Note [Undecidable Equality for Overloaded Literals]
    | eqPmLit l1 l2 -> (truePmExpr,  True)
    | otherwise     -> (falsePmExpr, True)

  -- Can potentially be simplified
  (PmExprEq {}, _) -> simplifySidesAndRetry
  (_, PmExprEq {}) -> simplifySidesAndRetry

  -- Constructors
  (PmExprCon c1 ts1, PmExprCon c2 ts2)
    | c1 == c2 ->
        let (ts1', bs1) = mapAndUnzip simplifyPmExpr ts1
            (ts2', bs2) = mapAndUnzip simplifyPmExpr ts2
            (tss, _bss) = zipWithAndUnzip simplifyEqExpr ts1' ts2'
            worst_case  = PmExprEq (PmExprCon c1 ts1') (PmExprCon c2 ts2')
        in  if | not (or bs1 || or bs2) -> (worst_case, False) -- no progress
               | all isTruePmExpr  tss  -> (truePmExpr, True)
               | any isFalsePmExpr tss  -> (falsePmExpr, True)
               | otherwise              -> (worst_case, False)
    | otherwise -> (falsePmExpr, True)

  -- We cannot do anything about the rest..
  _other_equality -> (original, False)

  where
    original = PmExprEq e1 e2 -- reconstruct equality

    -- Simplify both sides; if either made progress, try the whole equality
    -- again on the simplified sides, otherwise give up.
    simplifySidesAndRetry = case (simplifyPmExpr e1, simplifyPmExpr e2) of
      ((e1', True ), (e2', _    )) -> simplifyEqExpr e1' e2'
      ((e1', _    ), (e2', True )) -> simplifyEqExpr e1' e2'
      ((e1', False), (e2', False)) -> (PmExprEq e1' e2', False) -- cannot progress

215 216 217
-- | Apply an (un-flattened) substitution to both sides of a complex equality.
applySubstComplexEq :: PmVarEnv -> ComplexEq -> ComplexEq
applySubstComplexEq env (lhs, rhs) = (resolve lhs, resolve rhs)
  where
    resolve = exprDeepLookup env
218

219
-- | Apply an (un-flattened) substitution to a variable. A bound variable is
-- replaced by its binding, which is in turn resolved recursively; an unbound
-- variable is returned as it is.
varDeepLookup :: PmVarEnv -> Name -> PmExpr
varDeepLookup env x = case Map.lookup x env of
  Just e  -> exprDeepLookup env e -- go deeper
  Nothing -> PmExprVar x          -- terminal
{-# INLINE varDeepLookup #-}

226
-- | Apply an (un-flattened) substitution to an expression: variables are
-- resolved with `varDeepLookup`, constructor arguments and both sides of an
-- equality are traversed recursively, and everything else is left untouched.
exprDeepLookup :: PmVarEnv -> PmExpr -> PmExpr
exprDeepLookup env (PmExprVar x)    = varDeepLookup env x
exprDeepLookup env (PmExprCon c es) = PmExprCon c (map (exprDeepLookup env) es)
exprDeepLookup env (PmExprEq e1 e2) = PmExprEq (exprDeepLookup env e1)
                                               (exprDeepLookup env e2)
exprDeepLookup _   other_expr       = other_expr -- PmExprLit, PmExprOther

-- | External interface to the term oracle: solve the given equalities one
-- after the other (left to right), threading the state through. The result
-- is Nothing as soon as `solveOneEq` fails on one of them.
tmOracle :: TmState -> [ComplexEq] -> Maybe TmState
tmOracle tm_state eqs = foldlM solveOneEq tm_state eqs
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257

-- | The type of a literal, computed from the underlying (plain or
-- overloaded) literal it wraps.
pmLitType :: PmLit -> Type -- should be in PmExpr but gives cyclic imports :(
pmLitType pm_lit = case pm_lit of
  PmSLit   lit -> hsLitType   lit
  PmOLit _ lit -> overLitType lit

{- Note [Deep equalities]
~~~~~~~~~~~~~~~~~~~~~~~~~
Solving nested equalities is the most difficult part. The general strategy
is the following:

  * Equalities of the form (True ~ (e1 ~ e2)) are transformed to just
    (e1 ~ e2) and then treated recursively.

  * Equalities of the form (False ~ (e1 ~ e2)) cannot be analyzed unless
    we know more about the inner equality (e1 ~ e2). That's exactly what
    `simplifyEqExpr' tries to do: It takes e1 and e2 and returns either
    truePmExpr, falsePmExpr, or (e1' ~ e2') when the result is uncertain.
    Note that the uncertain result mentions e1' and e2' rather than e1 and
    e2, since the subexpressions themselves may have been simplified.
-}