Commit 1c15bee5 authored by simonpj@microsoft.com's avatar simonpj@microsoft.com

Add the ability to derive instances of Functor, Foldable, Traversable

This patch is a straightforward extension of the 'deriving' mechanism.
The ability to derive classes Functor, Foldable, Traverable is controlled
by a single flag  -XDeriveFunctor.  (Maybe that's a poor name.)

Still to come: documentation

Thanks to twanvl for developing the patch
parent 97d5e75a
......@@ -220,6 +220,7 @@ data DynFlag
| Opt_RelaxedPolyRec
| Opt_StandaloneDeriving
| Opt_DeriveDataTypeable
| Opt_DeriveFunctor
| Opt_TypeSynonymInstances
| Opt_FlexibleContexts
| Opt_FlexibleInstances
......@@ -1771,6 +1772,7 @@ xFlags = [
( "UnboxedTuples", Opt_UnboxedTuples, const Supported ),
( "StandaloneDeriving", Opt_StandaloneDeriving, const Supported ),
( "DeriveDataTypeable", Opt_DeriveDataTypeable, const Supported ),
( "DeriveFunctor", Opt_DeriveFunctor, const Supported ),
( "TypeSynonymInstances", Opt_TypeSynonymInstances, const Supported ),
( "FlexibleContexts", Opt_FlexibleContexts, const Supported ),
( "FlexibleInstances", Opt_FlexibleInstances, const Supported ),
......@@ -1809,6 +1811,7 @@ glasgowExtsFlags = [
, Opt_TypeSynonymInstances
, Opt_StandaloneDeriving
, Opt_DeriveDataTypeable
, Opt_DeriveFunctor
, Opt_FlexibleContexts
, Opt_FlexibleInstances
, Opt_ConstrainedClassMethods
......
......@@ -131,6 +131,9 @@ basicKnownKeyNames
realFloatClassName, -- numeric
dataClassName,
isStringClassName,
applicativeClassName,
foldableClassName,
traversableClassName,
-- Numeric stuff
negateName, minusName,
......@@ -164,7 +167,7 @@ basicKnownKeyNames
-- Read stuff
readClassName,
-- Stable pointers
newStablePtrName,
......@@ -204,11 +207,8 @@ basicKnownKeyNames
randomClassName, randomGenClassName, monadPlusClassName,
-- Annotation type checking
toAnnotationWrapperName,
toAnnotationWrapperName
-- Booleans
andName, orName
-- The Either type
, eitherTyConName, leftDataConName, rightDataConName
......@@ -236,10 +236,11 @@ pRELUDE = mkBaseModule_ pRELUDE_NAME
gHC_PRIM, gHC_TYPES, gHC_BOOL, gHC_UNIT, gHC_ORDERING, gHC_GENERICS, gHC_CLASSES, gHC_BASE, gHC_ENUM,
gHC_SHOW, gHC_READ, gHC_NUM, gHC_INTEGER, gHC_INTEGER_INTERNALS, gHC_LIST, gHC_PARR,
gHC_TUPLE, dATA_TUPLE, dATA_EITHER, dATA_STRING, gHC_PACK, gHC_CONC, gHC_IO_BASE,
gHC_TUPLE, dATA_TUPLE, dATA_EITHER, dATA_STRING, dATA_FOLDABLE, dATA_TRAVERSABLE,
gHC_PACK, gHC_CONC, gHC_IO_BASE,
gHC_ST, gHC_ARR, gHC_STABLE, gHC_ADDR, gHC_PTR, gHC_ERR, gHC_REAL,
gHC_FLOAT, gHC_TOP_HANDLER, sYSTEM_IO, dYNAMIC, tYPEABLE, gENERICS,
dOTNET, rEAD_PREC, lEX, gHC_INT, gHC_WORD, mONAD, mONAD_FIX, aRROW,
dOTNET, rEAD_PREC, lEX, gHC_INT, gHC_WORD, mONAD, mONAD_FIX, aRROW, cONTROL_APPLICATIVE,
gHC_DESUGAR, rANDOM, gHC_EXTS, cONTROL_EXCEPTION_BASE :: Module
gHC_PRIM = mkPrimModule (fsLit "GHC.Prim") -- Primitive types and values
gHC_TYPES = mkPrimModule (fsLit "GHC.Types")
......@@ -261,6 +262,8 @@ gHC_TUPLE = mkPrimModule (fsLit "GHC.Tuple")
dATA_TUPLE = mkBaseModule (fsLit "Data.Tuple")
dATA_EITHER = mkBaseModule (fsLit "Data.Either")
dATA_STRING = mkBaseModule (fsLit "Data.String")
dATA_FOLDABLE = mkBaseModule (fsLit "Data.Foldable")
dATA_TRAVERSABLE= mkBaseModule (fsLit "Data.Traversable")
gHC_PACK = mkBaseModule (fsLit "GHC.Pack")
gHC_CONC = mkBaseModule (fsLit "GHC.Conc")
gHC_IO_BASE = mkBaseModule (fsLit "GHC.IOBase")
......@@ -285,6 +288,7 @@ gHC_WORD = mkBaseModule (fsLit "GHC.Word")
mONAD = mkBaseModule (fsLit "Control.Monad")
mONAD_FIX = mkBaseModule (fsLit "Control.Monad.Fix")
aRROW = mkBaseModule (fsLit "Control.Arrow")
cONTROL_APPLICATIVE = mkBaseModule (fsLit "Control.Applicative")
gHC_DESUGAR = mkBaseModule (fsLit "GHC.Desugar")
rANDOM = mkBaseModule (fsLit "System.Random")
gHC_EXTS = mkBaseModule (fsLit "GHC.Exts")
......@@ -389,9 +393,6 @@ returnM_RDR = nameRdrName returnMName
bindM_RDR = nameRdrName bindMName
failM_RDR = nameRdrName failMName
and_RDR :: RdrName
and_RDR = nameRdrName andName
left_RDR, right_RDR :: RdrName
left_RDR = nameRdrName leftDataConName
right_RDR = nameRdrName rightDataConName
......@@ -443,8 +444,9 @@ compose_RDR :: RdrName
compose_RDR = varQual_RDR gHC_BASE (fsLit ".")
not_RDR, getTag_RDR, succ_RDR, pred_RDR, minBound_RDR, maxBound_RDR,
range_RDR, inRange_RDR, index_RDR,
and_RDR, range_RDR, inRange_RDR, index_RDR,
unsafeIndex_RDR, unsafeRangeSize_RDR :: RdrName
and_RDR = varQual_RDR gHC_CLASSES (fsLit "&&")
not_RDR = varQual_RDR gHC_CLASSES (fsLit "not")
getTag_RDR = varQual_RDR gHC_BASE (fsLit "getTag")
succ_RDR = varQual_RDR gHC_ENUM (fsLit "succ")
......@@ -502,6 +504,13 @@ inlDataCon_RDR = dataQual_RDR gHC_GENERICS (fsLit "Inl")
inrDataCon_RDR = dataQual_RDR gHC_GENERICS (fsLit "Inr")
genUnitDataCon_RDR = dataQual_RDR gHC_GENERICS (fsLit "Unit")
fmap_RDR, pure_RDR, ap_RDR, foldable_foldr_RDR, traverse_RDR :: RdrName
fmap_RDR = varQual_RDR gHC_BASE (fsLit "fmap")
pure_RDR = varQual_RDR cONTROL_APPLICATIVE (fsLit "pure")
ap_RDR = varQual_RDR cONTROL_APPLICATIVE (fsLit "<*>")
foldable_foldr_RDR = varQual_RDR dATA_FOLDABLE (fsLit "foldr")
traverse_RDR = varQual_RDR dATA_TRAVERSABLE (fsLit "traverse")
----------------------
varQual_RDR, tcQual_RDR, clsQual_RDR, dataQual_RDR
:: Module -> FastString -> RdrName
......@@ -573,13 +582,19 @@ bindMName = methName gHC_BASE (fsLit ">>=") bindMClassOpKey
returnMName = methName gHC_BASE (fsLit "return") returnMClassOpKey
failMName = methName gHC_BASE (fsLit "fail") failMClassOpKey
-- Classes (Applicative, Foldable, Traversable)
applicativeClassName, foldableClassName, traversableClassName :: Name
applicativeClassName = clsQual cONTROL_APPLICATIVE (fsLit "Applicative") applicativeClassKey
foldableClassName = clsQual dATA_FOLDABLE (fsLit "Foldable") foldableClassKey
traversableClassName = clsQual dATA_TRAVERSABLE (fsLit "Traversable") traversableClassKey
-- Functions for GHC extensions
groupWithName :: Name
groupWithName = varQual gHC_EXTS (fsLit "groupWith") groupWithIdKey
-- Random PrelBase functions
fromStringName, otherwiseIdName, foldrName, buildName, augmentName,
mapName, appendName, andName, orName, assertName,
mapName, appendName, assertName,
breakpointName, breakpointCondName, breakpointAutoName,
opaqueTyConName :: Name
fromStringName = methName dATA_STRING (fsLit "fromString") fromStringClassOpKey
......@@ -589,8 +604,6 @@ buildName = varQual gHC_BASE (fsLit "build") buildIdKey
augmentName = varQual gHC_BASE (fsLit "augment") augmentIdKey
mapName = varQual gHC_BASE (fsLit "map") mapIdKey
appendName = varQual gHC_BASE (fsLit "++") appendIdKey
andName = varQual gHC_CLASSES (fsLit "&&") andIdKey
orName = varQual gHC_CLASSES (fsLit "||") orIdKey
assertName = varQual gHC_BASE (fsLit "assert") assertIdKey
breakpointName = varQual gHC_BASE (fsLit "breakpoint") breakpointIdKey
breakpointCondName= varQual gHC_BASE (fsLit "breakpointCond") breakpointCondIdKey
......@@ -889,6 +902,11 @@ randomGenClassKey = mkPreludeClassUnique 32
isStringClassKey :: Unique
isStringClassKey = mkPreludeClassUnique 33
applicativeClassKey, foldableClassKey, traversableClassKey :: Unique
applicativeClassKey = mkPreludeClassUnique 34
foldableClassKey = mkPreludeClassUnique 35
traversableClassKey = mkPreludeClassUnique 36
\end{code}
%************************************************************************
......@@ -1156,9 +1174,7 @@ rootMainKey, runMainKey :: Unique
rootMainKey = mkPreludeMiscIdUnique 55
runMainKey = mkPreludeMiscIdUnique 56
andIdKey, orIdKey, thenIOIdKey, lazyIdKey, assertErrorIdKey :: Unique
andIdKey = mkPreludeMiscIdUnique 57
orIdKey = mkPreludeMiscIdUnique 58
thenIOIdKey, lazyIdKey, assertErrorIdKey :: Unique
thenIOIdKey = mkPreludeMiscIdUnique 59
lazyIdKey = mkPreludeMiscIdUnique 60
assertErrorIdKey = mkPreludeMiscIdUnique 61
......@@ -1260,6 +1276,7 @@ fromStringClassOpKey = mkPreludeMiscIdUnique 125
toAnnotationWrapperIdKey :: Unique
toAnnotationWrapperIdKey = mkPreludeMiscIdUnique 126
---------------- Template Haskell -------------------
-- USES IdUniques 200-399
-----------------------------------------------------
......@@ -1325,7 +1342,8 @@ standardClassKeys = derivableClassKeys ++ numericClassKeys
++ [randomClassKey, randomGenClassKey,
functorClassKey,
monadClassKey, monadPlusClassKey,
isStringClassKey
isStringClassKey,
applicativeClassKey, foldableClassKey, traversableClassKey
]
\end{code}
......
......@@ -49,6 +49,8 @@ import ListSetOps
import Outputable
import FastString
import Bag
import Control.Monad
\end{code}
%************************************************************************
......@@ -566,15 +568,12 @@ mkEqnHelp orig tvs cls cls_tys tc_app mtheta
className cls `elem` typeableClassNames)
(derivingHiddenErr tycon)
; mayDeriveDataTypeable <- doptM Opt_DeriveDataTypeable
; newtype_deriving <- doptM Opt_GeneralizedNewtypeDeriving
; dflags <- getDOpts
; if isDataTyCon rep_tc then
mkDataTypeEqn orig mayDeriveDataTypeable tvs cls cls_tys
mkDataTypeEqn orig dflags tvs cls cls_tys
tycon tc_args rep_tc rep_tc_args mtheta
else
mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving
tvs cls cls_tys
mkNewTypeEqn orig dflags tvs cls cls_tys
tycon tc_args rep_tc rep_tc_args mtheta }
| otherwise
= failWithTc (derivingThingErr cls cls_tys tc_app
......@@ -631,13 +630,21 @@ famInstNotFound tycon tys
%************************************************************************
\begin{code}
mkDataTypeEqn :: InstOrigin -> Bool -> [Var] -> Class -> [Type]
-> TyCon -> [Type] -> TyCon -> [Type] -> Maybe ThetaType
-> TcRn EarlyDerivSpec -- Return 'Nothing' if error
mkDataTypeEqn orig mayDeriveDataTypeable tvs cls cls_tys
mkDataTypeEqn :: InstOrigin
-> DynFlags
-> [Var] -- Universally quantified type variables in the instance
-> Class -- Class for which we need to derive an instance
-> [Type] -- Other parameters to the class except the last
-> TyCon -- Type constructor for which the instance is requested (last parameter to the type class)
-> [Type] -- Parameters to the type constructor
-> TyCon -- rep of the above (for type families)
-> [Type] -- rep of the above
-> Maybe ThetaType -- Context of the instance, for standalone deriving
-> TcRn EarlyDerivSpec -- Return 'Nothing' if error
mkDataTypeEqn orig dflags tvs cls cls_tys
tycon tc_args rep_tc rep_tc_args mtheta
= case checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tc of
= case checkSideConditions dflags cls cls_tys rep_tc of
-- NB: pass the *representation* tycon to checkSideConditions
CanDerive -> mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
NonDerivableClass -> bale_out (nonStdErr cls)
......@@ -656,7 +663,7 @@ mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
| otherwise
= do { dfun_name <- new_dfun_name cls tycon
; loc <- getSrcSpanM
; let ordinary_constraints
; let ordinary_constraints_simple
= [ mkClassPred cls [arg_ty]
| data_con <- tyConDataCons rep_tc,
arg_ty <- ASSERT( isVanillaDataCon data_con )
......@@ -665,13 +672,31 @@ mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
-- No constraints for unlifted types
-- Where they are legal we generate specilised function calls
-- constraints on all subtypes for classes like Functor
ordinary_constraints_deep
= [ mkClassPred cls [deept_ty]
| data_con <- tyConDataCons rep_tc,
arg_ty <- ASSERT( isVanillaDataCon data_con )
dataConInstOrigArgTys data_con (rep_tc_args++[mkTyVarTy dummy_ty]),
deept_ty <- deepSubtypesContaining dummy_ty arg_ty,
not (isUnLiftedType deept_ty) ]
where dummy_ty = last (tyConTyVars tycon) -- don't substitute the last var, this might not be a good idea
ordinary_constraints
| getUnique cls == functorClassKey = ordinary_constraints_deep
| getUnique cls == foldableClassKey = ordinary_constraints_deep
| getUnique cls == traversableClassKey = ordinary_constraints_deep
| otherwise = ordinary_constraints_simple
-- See Note [Superclasses of derived instance]
sc_constraints = substTheta (zipOpenTvSubst (classTyVars cls) inst_tys)
(classSCTheta cls)
inst_tys = [mkTyConApp tycon tc_args]
stupid_subst = zipTopTvSubst (tyConTyVars rep_tc) rep_tc_args
nonfree_tycon_vars = dropTail (classArity cls) (tyConTyVars rep_tc)
stupid_subst = zipTopTvSubst nonfree_tycon_vars rep_tc_args
stupid_constraints = substTheta stupid_subst (tyConStupidTheta rep_tc)
all_constraints = stupid_constraints ++ sc_constraints ++ ordinary_constraints
spec = DS { ds_loc = loc, ds_orig = orig
......@@ -712,6 +737,7 @@ mk_typeable_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
, ds_tc = rep_tc, ds_tc_args = rep_tc_args
, ds_theta = mtheta `orElse` [], ds_newtype = False }) }
------------------------------------------------------------------
-- Check side conditions that dis-allow derivability for particular classes
-- This is *apart* from the newtype-deriving mechanism
......@@ -724,10 +750,10 @@ data DerivStatus = CanDerive
| DerivableClassError SDoc -- Standard class, but can't do it
| NonDerivableClass -- Non-standard class
checkSideConditions :: Bool -> Class -> [TcType] -> TyCon -> DerivStatus
checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tc
checkSideConditions :: DynFlags -> Class -> [TcType] -> TyCon -> DerivStatus
checkSideConditions dflags cls cls_tys rep_tc
| Just cond <- sideConditions cls
= case (cond (mayDeriveDataTypeable, rep_tc)) of
= case (cond (dflags, rep_tc)) of
Just err -> DerivableClassError err -- Class-specific error
Nothing | null cls_tys -> CanDerive
| otherwise -> DerivableClassError ty_args_why -- e.g. deriving( Eq s )
......@@ -748,13 +774,17 @@ sideConditions cls
| cls_key == ixClassKey = Just (cond_std `andCond` cond_enumOrProduct)
| cls_key == boundedClassKey = Just (cond_std `andCond` cond_enumOrProduct)
| cls_key == dataClassKey = Just (cond_mayDeriveDataTypeable `andCond` cond_std `andCond` cond_noUnliftedArgs)
| cls_key == functorClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK True)
| cls_key == foldableClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK False)
| cls_key == traversableClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK False)
| getName cls `elem` typeableClassNames = Just (cond_mayDeriveDataTypeable `andCond` cond_typeableOK)
| otherwise = Nothing
where
cls_key = getUnique cls
type Condition = (Bool, TyCon) -> Maybe SDoc
-- Bool is whether or not we are allowed to derive Data and Typeable
type Condition = (DynFlags, TyCon) -> Maybe SDoc
-- first Bool is whether or not we are allowed to derive Data and Typeable
-- second Bool is whether or not we are allowed to derive Functor
-- TyCon is the *representation* tycon if the
-- data type is an indexed one
-- Nothing => OK
......@@ -835,13 +865,47 @@ cond_typeableOK (_, rep_tc)
fam_inst = quotes (pprSourceTyCon rep_tc) <+>
ptext (sLit "is a type family")
cond_functorOK :: Bool -> Condition
-- OK for Functor class
-- Currently: (a) at least one argument
-- (b) don't use argument contravariantly
-- (c) don't use argument in the wrong place, e.g. data T a = T (X a a)
-- (d) optionally: don't use function types
cond_functorOK allowFunctions (_, rep_tc) = msum (map check con_types)
where
data_cons = tyConDataCons rep_tc
con_types = concatMap dataConOrigArgTys data_cons
check = functorLikeTraverse
Nothing
Nothing
(Just covariant)
(\x y -> if allowFunctions then x `mplus` y else Just functions)
(\_ xs -> msum xs)
(\_ x -> x)
(Just wrong_arg)
(\_ x -> x)
(last (tyConTyVars rep_tc))
covariant = quotes (pprSourceTyCon rep_tc) <+>
ptext (sLit "uses the type variable in a function argument")
functions = quotes (pprSourceTyCon rep_tc) <+>
ptext (sLit "contains function types")
wrong_arg = quotes (pprSourceTyCon rep_tc) <+>
ptext (sLit "uses the type variable in an argument other than the last")
cond_mayDeriveDataTypeable :: Condition
cond_mayDeriveDataTypeable (mayDeriveDataTypeable, _)
| mayDeriveDataTypeable = Nothing
cond_mayDeriveDataTypeable (dflags, _)
| dopt Opt_DeriveDataTypeable dflags = Nothing
| otherwise = Just why
where
why = ptext (sLit "You need -XDeriveDataTypeable to derive an instance for this class")
cond_mayDeriveFunctor :: Condition
cond_mayDeriveFunctor (dflags, _)
| dopt Opt_DeriveFunctor dflags = Nothing
| otherwise = Just why
where
why = ptext (sLit "You need -XDeriveFunctor to derive an instance for this class")
std_class_via_iso :: Class -> Bool
std_class_via_iso clas -- These standard classes can be derived for a newtype
-- using the isomorphism trick *even if no -fglasgow-exts*
......@@ -890,11 +954,11 @@ a context for the Data instances:
%************************************************************************
\begin{code}
mkNewTypeEqn :: InstOrigin -> Bool -> Bool -> [Var] -> Class
mkNewTypeEqn :: InstOrigin -> DynFlags -> [Var] -> Class
-> [Type] -> TyCon -> [Type] -> TyCon -> [Type]
-> Maybe ThetaType
-> TcRn EarlyDerivSpec
mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
mkNewTypeEqn orig dflags tvs
cls cls_tys tycon tc_args rep_tycon rep_tc_args mtheta
-- Want: instance (...) => cls (cls_tys ++ [tycon tc_args]) where ...
| can_derive_via_isomorphism && (newtype_deriving || std_class_via_iso cls)
......@@ -919,7 +983,8 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
| newtype_deriving -> bale_out cant_derive_err -- Too hard, even with newtype deriving
| otherwise -> bale_out non_std_err -- Try newtype deriving!
where
check_conditions = checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tycon
newtype_deriving = dopt Opt_GeneralizedNewtypeDeriving dflags
check_conditions = checkSideConditions dflags cls cls_tys rep_tycon
bale_out msg = failWithTc (derivingThingErr cls cls_tys inst_ty msg)
non_std_err = nonStdErr cls $$
......@@ -1292,6 +1357,9 @@ genDerivBinds loc fix_env clas tycon
,(showClassKey, gen_Show_binds fix_env)
,(readClassKey, gen_Read_binds fix_env)
,(dataClassKey, gen_Data_binds)
,(functorClassKey, gen_Functor_binds)
,(foldableClassKey, gen_Foldable_binds)
,(traversableClassKey, gen_Traversable_binds)
]
\end{code}
......
......@@ -23,6 +23,9 @@ module TcGenDeriv (
gen_Show_binds,
gen_Data_binds,
gen_Typeable_binds,
gen_Functor_binds, functorLikeTraverse, deepSubtypesContaining,
gen_Foldable_binds,
gen_Traversable_binds,
genAuxBind
) where
......@@ -44,7 +47,12 @@ import TyCon
import TcType
import TysPrim
import TysWiredIn
import Type
import TypeRep
import VarSet
import State
import Util
import MonadUtils
import Outputable
import FastString
import OccName
......@@ -1203,6 +1211,302 @@ prefix_RDR = dataQual_RDR gENERICS (fsLit "Prefix")
infix_RDR = dataQual_RDR gENERICS (fsLit "Infix")
\end{code}
%************************************************************************
%* *
Functor instances
%* *
%************************************************************************
For the data type:
data T a = T1 Int a | T2 (T a)
We generate the instance:
instance Functor T where
fmap f (T1 b1 a) = T1 b1 (f a)
fmap f (T2 ta) = T2 (fmap f ta)
Notice that we don't simply apply 'fmap' to the constructor arguments.
Rather
- Do nothing to an argument whose type doesn't mention 'a'
- Apply 'f' to an argument of type 'a'
- Apply 'fmap f' to other arguments
That's why we have to recurse deeply into the constructor argument types,
rather than just one level, as we typically do.
What about types with more than one type parameter? In general, we only
derive Functor for the last position:
data S a b = S1 [b] | S2 a
instance Functor (S a) where
fmap f (S1 bs) = S1 (fmap f bs)
fmap f (S2 a) = S2 a
However, we have special cases for
- tuples
- functions
More formally, we write the derivation of fmap code over type variable
'a for type 'b as ($fmap 'a 'b). In this general notation the derived
instance for T is:
instance Functor T where
fmap f (T1 x1 x2) = T1 ($(fmap 'a 'b1) x1) ($(fmap 'a 'a) x2)
fmap f (T2 x1) = T2 ($(fmap 'a '(T a)) x1)
$(fmap 'a 'b) x = x -- when b does not contain a
$(fmap 'a 'a) x = f x
$(fmap 'a '(b1,b2)) x = case x of (x1,x2) -> ($(fmap 'a 'b1) x1, $(fmap 'a 'b2) x2)
$(fmap 'a '(T b1 b2)) x = fmap $(fmap 'a 'b2) x -- when a only occurs in the last parameter, b2
$(fmap 'a '(b -> c)) x = \b -> $(fmap 'a' 'c) (x ($(cofmap 'a 'b) b))
For functions, the type parameter 'a can occur in a contravariant position,
which means we need to derive a function like:
cofmap :: (a -> b) -> (f b -> f a)
This is pretty much the same as $fmap, only without the $(cofmap 'a 'a) case:
$(cofmap 'a 'b) x = x -- when b does not contain a
$(cofmap 'a 'a) x = error "type variable in contravariant position"
$(cofmap 'a '(b1,b2)) x = case x of (x1,x2) -> ($(cofmap 'a 'b1) x1, $(cofmap 'a 'b2) x2)
$(cofmap 'a '[b]) x = map $(cofmap 'a 'b) x
$(cofmap 'a '(T b1 b2)) x = fmap $(cofmap 'a 'b2) x -- when a only occurs in the last parameter, b2
$(cofmap 'a '(b -> c)) x = \b -> $(cofmap 'a' 'c) (x ($(fmap 'a 'c) b))
\begin{code}
gen_Functor_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
gen_Functor_binds loc tycon
= (listToBag [fmap_bind], [])
where
data_cons = tyConDataCons tycon
arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description
fmap_bind = L loc $ mkFunBind (L loc fmap_RDR) (map fmap_eqn data_cons)
fmap_eqn con = evalState (match_for_con [f_Pat] con parts) bs_RDRs
where parts = map derive_fmap_type (dataConOrigArgTys con)
derive_fmap_type :: Type -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName)
derive_fmap_type = functorLikeTraverse
(\ x -> return x) -- fmap f x = x
(\ x -> return (nlHsApp f_Expr x)) -- fmap f x = f x
(panic "contravariant")
(\g h x -> mkSimpleLam (\b -> h =<< (nlHsApp x `fmap` g b))) -- fmap f x = \b -> h (x (g b))
(mkSimpleTupleCase match_for_con) -- fmap f x = case x of (a1,a2,..) -> (g1 a1,g2 a2,..)
(\_ g x -> do gg <- mkSimpleLam g
return $ nlHsApps fmap_RDR [gg,x]) -- fmap f x = fmap g x
(panic "in other argument")
(\_ g x -> g x)
arg
match_for_con = mkSimpleConMatch $
\con_name xsM -> do xs <- sequence xsM
return (nlHsApps con_name xs) -- Con (g1 v1) (g2 v2) ..
\end{code}
Utility functions related to Functor deriving.
Since several things use the same pattern of traversal, this is abstracted into functorLikeTraverse.
This function works like a fold: it makes a value of type 'a' in a bottom up way.
\begin{code}
-- Generic traversal for Functor deriving
functorLikeTraverse :: a -- ^ Case: does not contain variable
-> a -- ^ Case: the variable itself
-> a -- ^ Case: the variable itself, contravariantly
-> (a -> a -> a) -- ^ Case: function type
-> (Boxity -> [a] -> a) -- ^ Case: tuple type
-> (Type -> a -> a) -- ^ Case: other tycon, variable only in last argument
-> a -- ^ Case: other tycon, variable only in last argument
-> (TcTyVar -> a -> a) -- ^ Case: forall type
-> TcTyVar -- ^ Variable to look for
-> Type -- ^ Type to process
-> a
functorLikeTraverse caseTrivial caseVar caseCoVar caseFun caseTuple caseTyApp caseWrongArg caseForAll var ty
= fst (go False ty)
where -- go returns (result of type a, does type contain var)
go co ty | Just ty' <- coreView ty = go co ty'
go co (TyVarTy v) | v == var = (if co then caseCoVar else caseVar,True)
go co (FunTy (PredTy _) b) = go co b
go co (FunTy x y) | xc || yc = (caseFun xr yr,True)
where (xr,xc) = go (not co) x
(yr,yc) = go co y
go co (AppTy x y) | xc = (caseWrongArg,True)
| yc = (caseTyApp x yr,True)
where (_, xc) = go co x
(yr,yc) = go co y
go co ty@(TyConApp con args)
| isTupleTyCon con = (caseTuple (tupleTyConBoxity con) xrs,True)
| null args = (caseTrivial,False)
| or (init xcs) = (caseWrongArg,True)
| (last xcs) = (caseTyApp (fst (splitAppTy ty)) (last xrs),True)
where (xrs,xcs) = unzip (map (go co) args)
go co (ForAllTy v x) | v /= var && xc = (caseForAll v xr,True)
where (xr,xc) = go co x
go _ _ = (caseTrivial,False)
-- return all subtypes of ty that contain var somewhere
-- these are the things that should appear in instance constraints
deepSubtypesContaining :: TcTyVar -> TcType -> [TcType]
deepSubtypesContaining = functorLikeTraverse
[]
[]
(panic "contravariant")
(\x y -> x ++ y) -- function
(\_ xs -> concat xs) -- tuple
(\ty x -> ty : x) -- tyapp
(panic "in other argument")
(\v x -> filter (not . (v `elemVarSet`) . tyVarsOfType) x) -- forall v
-- Make a HsLam using a fresh variable from a State monad
mkSimpleLam :: (LHsExpr id -> State [id] (LHsExpr id)) -> State [id] (LHsExpr id)
mkSimpleLam lam = do
(n:names) <- get
put names
body <- lam (nlHsVar n)
return (mkHsLam [nlVarPat n] body)
mkSimpleLam2 :: (LHsExpr id -> LHsExpr id -> State [id] (LHsExpr id)) -> State [id] (LHsExpr id)
mkSimpleLam2 lam = do
(n1:n2:names) <- get
put names
body <- lam (nlHsVar n1) (nlHsVar n2)
return (mkHsLam [nlVarPat n1,nlVarPat n2] body)
-- "Con a1 a2 a3 -> fold [x1 a1, x2 a2, x3 a3]"
mkSimpleConMatch :: Monad m => (RdrName -> [a] -> m (LHsExpr RdrName)) -> [LPat RdrName] -> DataCon -> [LHsExpr RdrName -> a] -> m (LMatch RdrName)
mkSimpleConMatch fold extra_pats con insides = do
let con_name = getRdrName con
let vars_needed = takeList insides as_RDRs
let pat = nlConVarPat con_name vars_needed
rhs <- fold con_name (zipWith ($) insides (map nlHsVar vars_needed))
return $ mkMatch (extra_pats ++ [pat]) rhs emptyLocalBinds
-- "case x of (a1,a2,a3) -> fold [x1 a1, x2 a2, x3 a3]"
mkSimpleTupleCase :: Monad m => ([LPat RdrName] -> DataCon -> [LHsExpr RdrName -> a] -> m (LMatch RdrName))
-> Boxity -> [LHsExpr RdrName -> a] -> LHsExpr RdrName -> m (LHsExpr RdrName)
mkSimpleTupleCase match_for_con boxity insides x = do
let con = tupleCon boxity (length insides)
match <- match_for_con [] con insides
return $ nlHsCase x [match]
\end{code}
%************************************************************************
%* *
Foldable instances
%* *
%************************************************************************
Deriving Foldable instances works the same way as Functor instances,
only Foldable instances are not possible for function types at all.
Here the derived instance for the type T above is:
instance Foldable T where
foldr f z (T1 x1 x2 x3) = $(foldr 'a 'b1) x1 ( $(foldr 'a 'a) x2 ( $(foldr 'a 'b2) x3 z ) )
The cases are:
$(foldr 'a 'b) x z = z -- when b does not contain a
$(foldr 'a 'a) x z = f x z
$(foldr 'a '(b1,b2)) x z = case x of (x1,x2) -> $(foldr 'a 'b1) x1 ( $(foldr 'a 'b2) x2 z )
$(foldr 'a '(T b1 b2)) x z = foldr $(foldr 'a 'b2) x z -- when a only occurs in the last parameter, b2
Note that the arguments to the real foldr function are the wrong way around,
since (f :: a -> b -> b), while (foldr f :: b -> t a -> b).
\begin{code}
gen_Foldable_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
gen_Foldable_binds loc tycon
= (listToBag [foldr_bind], [])
where
data_cons = tyConDataCons tycon
arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description
foldr_bind = L loc $ mkFunBind (L loc foldr_RDR) (map foldr_eqn data_cons)
foldr_eqn con = evalState (match_for_con z_Expr [f_Pat,z_Pat] con parts) bs_RDRs
where parts = map derive_foldr_type (dataConOrigArgTys con)
derive_foldr_type :: Type -> LHsExpr RdrName -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName)
derive_foldr_type = functorLikeTraverse
(\ _ z -> return z) -- foldr f z x = z
(\ x z -> return (nlHsApps f_RDR [x,z])) -- foldr f z x = f x z
(panic "function")
(panic "function")
(\b gs x z -> mkSimpleTupleCase (match_for_con z) b gs x)
(\_ g x z -> do gg <- mkSimpleLam2 g -- foldr f z x = foldr (\xx zz -> g xx zz) z x
return $ nlHsApps foldable_foldr_RDR [gg,z,x])
(panic "in other argument")
(\_ g x z -> g x z)
arg
match_for_con z = mkSimpleConMatch (\_con_name -> foldrM ($) z) -- g1 v1 (g2 v2 (.. z))
\end{code}
%************************************************************************
%* *
Traversable instances