From 65aa60b5409e0932ff771c54383ba8f8c133991b Mon Sep 17 00:00:00 2001
From: "Iavor S. Diatchki" <diatchki@galois.com>
Date: Tue, 18 Mar 2014 18:54:23 -0700
Subject: [PATCH] Implement ordering comparisons for type-level naturals and
 symbols.

This is done with two built-in type families: `CmpNat and `CmpSymbol`.
Both of these return a promoted `Ordering` type (EQ, LT, or GT).

(cherry picked from commit 5e4bdb5fc5e741522cbb787731422da3f12aa398)
---
 compiler/prelude/PrelNames.lhs   |   3 +
 compiler/prelude/TysWiredIn.lhs  |  16 ++++
 compiler/typecheck/TcTypeNats.hs | 142 ++++++++++++++++++++++++++++++-
 3 files changed, 160 insertions(+), 1 deletion(-)

diff --git a/compiler/prelude/PrelNames.lhs b/compiler/prelude/PrelNames.lhs
index 0a440032449c..c0c1615daa82 100644
--- a/compiler/prelude/PrelNames.lhs
+++ b/compiler/prelude/PrelNames.lhs
@@ -1472,6 +1472,7 @@ rep1TyConKey = mkPreludeTyConUnique 156
 typeNatKindConNameKey, typeSymbolKindConNameKey,
   typeNatAddTyFamNameKey, typeNatMulTyFamNameKey, typeNatExpTyFamNameKey,
   typeNatLeqTyFamNameKey, typeNatSubTyFamNameKey
+  , typeSymbolCmpTyFamNameKey, typeNatCmpTyFamNameKey
   :: Unique
 typeNatKindConNameKey     = mkPreludeTyConUnique 160
 typeSymbolKindConNameKey  = mkPreludeTyConUnique 161
@@ -1480,6 +1481,8 @@ typeNatMulTyFamNameKey    = mkPreludeTyConUnique 163
 typeNatExpTyFamNameKey    = mkPreludeTyConUnique 164
 typeNatLeqTyFamNameKey    = mkPreludeTyConUnique 165
 typeNatSubTyFamNameKey    = mkPreludeTyConUnique 166
+typeSymbolCmpTyFamNameKey = mkPreludeTyConUnique 167
+typeNatCmpTyFamNameKey    = mkPreludeTyConUnique 168
 
 ntTyConKey:: Unique
 ntTyConKey = mkPreludeTyConUnique 174
diff --git a/compiler/prelude/TysWiredIn.lhs b/compiler/prelude/TysWiredIn.lhs
index bf1907d161e0..aa1b169c0c98 100644
--- a/compiler/prelude/TysWiredIn.lhs
+++ b/compiler/prelude/TysWiredIn.lhs
@@ -20,6 +20,8 @@ module TysWiredIn (
         ltDataCon, ltDataConId,
         eqDataCon, eqDataConId,
         gtDataCon, gtDataConId,
+        promotedOrderingTyCon,
+        promotedLTDataCon, promotedEQDataCon, promotedGTDataCon,
 
         -- * Char
         charTyCon, charDataCon, charTyCon_RDR,
@@ -831,5 +833,19 @@ promotedTrueDataCon   = promoteDataCon trueDataCon
 promotedFalseDataCon  = promoteDataCon falseDataCon
 \end{code}
 
+Promoted Ordering
+
+\begin{code}
+promotedOrderingTyCon
+  , promotedLTDataCon
+  , promotedEQDataCon
+  , promotedGTDataCon
+  :: TyCon
+promotedOrderingTyCon = promoteTyCon orderingTyCon
+promotedLTDataCon     = promoteDataCon ltDataCon
+promotedEQDataCon     = promoteDataCon eqDataCon
+promotedGTDataCon     = promoteDataCon gtDataCon
+\end{code}
+
 
 
diff --git a/compiler/typecheck/TcTypeNats.hs b/compiler/typecheck/TcTypeNats.hs
index c19164bf4b04..37fc6e0cdbcd 100644
--- a/compiler/typecheck/TcTypeNats.hs
+++ b/compiler/typecheck/TcTypeNats.hs
@@ -12,9 +12,14 @@ import Coercion   ( Role(..) )
 import TcRnTypes  ( Xi )
 import CoAxiom    ( CoAxiomRule(..), BuiltInSynFamily(..) )
 import Name       ( Name, BuiltInSyntax(..) )
-import TysWiredIn ( typeNatKind, mkWiredInTyConName
+import TysWiredIn ( typeNatKind, typeSymbolKind
+                  , mkWiredInTyConName
                   , promotedBoolTyCon
                   , promotedFalseDataCon, promotedTrueDataCon
+                  , promotedOrderingTyCon
+                  , promotedLTDataCon
+                  , promotedEQDataCon
+                  , promotedGTDataCon
                   )
 import TysPrim    ( tyVarList, mkArrowKinds )
 import PrelNames  ( gHC_TYPELITS
@@ -23,6 +28,8 @@ import PrelNames  ( gHC_TYPELITS
                   , typeNatExpTyFamNameKey
                   , typeNatLeqTyFamNameKey
                   , typeNatSubTyFamNameKey
+                  , typeNatCmpTyFamNameKey
+                  , typeSymbolCmpTyFamNameKey
                   )
 import FastString ( FastString, fsLit )
 import qualified Data.Map as Map
@@ -39,6 +46,8 @@ typeNatTyCons =
   , typeNatExpTyCon
   , typeNatLeqTyCon
   , typeNatSubTyCon
+  , typeNatCmpTyCon
+  , typeSymbolCmpTyCon
   ]
 
 typeNatAddTyCon :: TyCon
@@ -103,6 +112,45 @@ typeNatLeqTyCon =
     , sfInteractInert = interactInertLeq
     }
 
+typeNatCmpTyCon :: TyCon
+typeNatCmpTyCon =
+  mkSynTyCon name
+    (mkArrowKinds [ typeNatKind, typeNatKind ] orderingKind)
+    (take 2 $ tyVarList typeNatKind)
+    [Nominal,Nominal]
+    (BuiltInSynFamTyCon ops)
+    NoParentTyCon
+
+  where
+  name = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "CmpNat")
+                typeNatCmpTyFamNameKey typeNatCmpTyCon
+  ops = BuiltInSynFamily
+    { sfMatchFam      = matchFamCmpNat
+    , sfInteractTop   = interactTopCmpNat
+    , sfInteractInert = \_ _ _ _ -> []
+    }
+
+typeSymbolCmpTyCon :: TyCon
+typeSymbolCmpTyCon =
+  mkSynTyCon name
+    (mkArrowKinds [ typeSymbolKind, typeSymbolKind ] orderingKind)
+    (take 2 $ tyVarList typeSymbolKind)
+    [Nominal,Nominal]
+    (BuiltInSynFamTyCon ops)
+    NoParentTyCon
+
+  where
+  name = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "CmpSymbol")
+                typeSymbolCmpTyFamNameKey typeSymbolCmpTyCon
+  ops = BuiltInSynFamily
+    { sfMatchFam      = matchFamCmpSymbol
+    , sfInteractTop   = interactTopCmpSymbol
+    , sfInteractInert = \_ _ _ _ -> []
+    }
+
+
+
+
 
 -- Make a binary built-in constructor of kind: Nat -> Nat -> Nat
 mkTypeNatFunTyCon2 :: Name -> BuiltInSynFamily -> TyCon
@@ -127,6 +175,8 @@ axAddDef
   , axMulDef
   , axExpDef
   , axLeqDef
+  , axCmpNatDef
+  , axCmpSymbolDef
   , axAdd0L
   , axAdd0R
   , axMul0L
@@ -137,6 +187,8 @@ axAddDef
   , axExp0R
   , axExp1R
   , axLeqRefl
+  , axCmpNatRefl
+  , axCmpSymbolRefl
   , axLeq0L
   , axSubDef
   , axSub0R
@@ -154,6 +206,25 @@ axExpDef = mkBinAxiom "ExpDef" typeNatExpTyCon $
 axLeqDef = mkBinAxiom "LeqDef" typeNatLeqTyCon $
               \x y -> Just $ bool (x <= y)
 
+axCmpNatDef   = mkBinAxiom "CmpNatDef" typeNatCmpTyCon
+              $ \x y -> Just $ ordering (compare x y)
+
+axCmpSymbolDef =
+  CoAxiomRule
+    { coaxrName      = fsLit "CmpSymbolDef"
+    , coaxrTypeArity = 2
+    , coaxrAsmpRoles = []
+    , coaxrRole      = Nominal
+    , coaxrProves    = \ts cs ->
+        case (ts,cs) of
+          ([s,t],[]) ->
+            do x <- isStrLitTy s
+               y <- isStrLitTy t
+               return (mkTyConApp typeSymbolCmpTyCon [s,t] ===
+                      ordering (compare x y))
+          _ -> Nothing
+    }
+
 axSubDef = mkBinAxiom "SubDef" typeNatSubTyCon $
               \x y -> fmap num (minus x y)
 
@@ -168,6 +239,10 @@ axExp1L     = mkAxiom1 "Exp1L"    $ \t -> (num 1 .^. t) === num 1
 axExp0R     = mkAxiom1 "Exp0R"    $ \t -> (t .^. num 0) === num 1
 axExp1R     = mkAxiom1 "Exp1R"    $ \t -> (t .^. num 1) === t
 axLeqRefl   = mkAxiom1 "LeqRefl"  $ \t -> (t <== t) === bool True
+axCmpNatRefl    = mkAxiom1 "CmpNatRefl"
+                $ \t -> (cmpNat t t) === ordering EQ
+axCmpSymbolRefl = mkAxiom1 "CmpSymbolRefl"
+                $ \t -> (cmpSymbol t t) === ordering EQ
 axLeq0L     = mkAxiom1 "Leq0L"    $ \t -> (num 0 <== t) === bool True
 
 typeNatCoAxiomRules :: Map.Map FastString CoAxiomRule
@@ -176,6 +251,8 @@ typeNatCoAxiomRules = Map.fromList $ map (\x -> (coaxrName x, x))
   , axMulDef
   , axExpDef
   , axLeqDef
+  , axCmpNatDef
+  , axCmpSymbolDef
   , axAdd0L
   , axAdd0R
   , axMul0L
@@ -186,6 +263,8 @@ typeNatCoAxiomRules = Map.fromList $ map (\x -> (coaxrName x, x))
   , axExp0R
   , axExp1R
   , axLeqRefl
+  , axCmpNatRefl
+  , axCmpSymbolRefl
   , axLeq0L
   , axSubDef
   ]
@@ -211,6 +290,12 @@ s .^. t = mkTyConApp typeNatExpTyCon [s,t]
 (<==) :: Type -> Type -> Type
 s <== t = mkTyConApp typeNatLeqTyCon [s,t]
 
+cmpNat :: Type -> Type -> Type
+cmpNat s t = mkTyConApp typeNatCmpTyCon [s,t]
+
+cmpSymbol :: Type -> Type -> Type
+cmpSymbol s t = mkTyConApp typeSymbolCmpTyCon [s,t]
+
 (===) :: Type -> Type -> Pair Type
 x === y = Pair x y
 
@@ -232,6 +317,25 @@ isBoolLitTy tc =
          | tc == promotedTrueDataCon  -> return True
          | otherwise                   -> Nothing
 
+orderingKind :: Kind
+orderingKind = mkTyConApp promotedOrderingTyCon []
+
+ordering :: Ordering -> Type
+ordering o =
+  case o of
+    LT -> mkTyConApp promotedLTDataCon []
+    EQ -> mkTyConApp promotedEQDataCon []
+    GT -> mkTyConApp promotedGTDataCon []
+
+isOrderingLitTy :: Type -> Maybe Ordering
+isOrderingLitTy tc =
+  do (tc1,[]) <- splitTyConApp_maybe tc
+     case () of
+       _ | tc1 == promotedLTDataCon -> return LT
+         | tc1 == promotedEQDataCon -> return EQ
+         | tc1 == promotedGTDataCon -> return GT
+         | otherwise                -> Nothing
+
 known :: (Integer -> Bool) -> TcType -> Bool
 known p x = case isNumLitTy x of
               Just a  -> p a
@@ -258,6 +362,8 @@ mkBinAxiom str tc f =
           _ -> Nothing
     }
 
+
+
 mkAxiom1 :: String -> (Type -> Pair Type) -> CoAxiomRule
 mkAxiom1 str f =
   CoAxiomRule
@@ -328,6 +434,25 @@ matchFamLeq [s,t]
         mbY = isNumLitTy t
 matchFamLeq _ = Nothing
 
+matchFamCmpNat :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
+matchFamCmpNat [s,t]
+  | Just x <- mbX, Just y <- mbY =
+    Just (axCmpNatDef, [s,t], ordering (compare x y))
+  | tcEqType s t = Just (axCmpNatRefl, [s], ordering EQ)
+  where mbX = isNumLitTy s
+        mbY = isNumLitTy t
+matchFamCmpNat _ = Nothing
+
+matchFamCmpSymbol :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
+matchFamCmpSymbol [s,t]
+  | Just x <- mbX, Just y <- mbY =
+    Just (axCmpSymbolDef, [s,t], ordering (compare x y))
+  | tcEqType s t = Just (axCmpSymbolRefl, [s], ordering EQ)
+  where mbX = isStrLitTy s
+        mbY = isStrLitTy t
+matchFamCmpSymbol _ = Nothing
+
+
 {-------------------------------------------------------------------------------
 Interact with axioms
 -------------------------------------------------------------------------------}
@@ -415,6 +540,17 @@ interactTopLeq [s,t] r
   mbZ = isBoolLitTy r
 interactTopLeq _ _ = []
 
+interactTopCmpNat :: [Xi] -> Xi -> [Pair Type]
+interactTopCmpNat [s,t] r
+  | Just EQ <- isOrderingLitTy r = [ s === t ]
+interactTopCmpNat _ _ = []
+
+interactTopCmpSymbol :: [Xi] -> Xi -> [Pair Type]
+interactTopCmpSymbol [s,t] r
+  | Just EQ <- isOrderingLitTy r = [ s === t ]
+interactTopCmpSymbol _ _ = []
+
+
 
 
 {-------------------------------------------------------------------------------
@@ -466,6 +602,10 @@ interactInertLeq _ _ _ _ = []
 
 
 
+
+
+
+
 {- -----------------------------------------------------------------------------
 These inverse functions are used for simplifying propositions using
 concrete natural numbers.
-- 
GitLab