{-# OPTIONS -cpp #-}
{-# LANGUAGE BangPatterns #-}

{- Matrix multiplication on a torus (Gentleman's algorithm) -- FR10 -- -}

{-
RL/JB ParCo2005: eliminate result communication (Maybe-Type)

JB PhD2008: adapt for simple PhD skeleton tests

JB MSR07/2008: modified to use all available toroid skeletons.

JB MSR07/2008: derived a straightforward GpH program using identical
               helpers and strategies

JB optimised prodEscalar

JB for ghc-6.9: replaced Control.Parallel.Strategies by a workaround
   (re-exporting what should work)

-}
module Main(main) where

import System.Environment
import Data.List hiding (foldl', foldl1')

import ListAux
import Control.DeepSeq

-- replaced by StratWorkaround, excluding what does not work with
-- ghc-6.9
#ifdef WORKAROUND
import StratWorkaround
-- workaround hacks needed for ghc-6.9:
parList :: Strategy a -> Strategy [a]
-- NB: parListStrict is strict in the list, as its name says.
parList = parListStrict
parListChunk :: Int -> Strategy a -> Strategy [a]
parListChunk c strat l = parListStrict rnf chunkResults
  where
    -- splitAtN comes from ListAux; presumably it cuts l into pieces
    -- governed by c (TODO confirm). Each piece is evaluated with strat
    -- via seqList, and the per-piece results are sparked together.
    chunkResults = map (seqList strat) (splitAtN c l) :: [()]

#else
import Control.Parallel.Strategies
#endif

-----------  matrix strategies here:
-- | Parallel evaluation strategies, indexed by the "version" command-line
-- argument. Each takes the granularity parameter (chunk size / block count).
-- Index 0 is only a placeholder: 'mult' handles version 0 (sequential)
-- itself and never looks a strategy up for it. (An older comment gave the
-- element type as Int -> [[Int]] -> Done, from the Done-based strategies API.)
strats :: [ Int -> Strategy Matrix ]
strats = [ undefined -- do not use it!
         , lineStrat, blockStrat, columnStrat ]

-- | Display names, index-aligned with 'strats'.
names :: [String]
names = ["sequential", "linewise", "blockwise", "columnwise"]

-- | Row-wise: spark chunks of @c@ rows, each evaluated to normal form.
lineStrat :: Int -> Strategy Matrix
lineStrat c = parListChunk c rdeepseq

-- | Column-wise: spark chunks of @c@ columns. The columns are evaluated through
-- the transposed matrix and then transposed back. With @x `using` s@ the
-- strategy's result IS the value the caller gets, so a strategy must return
-- its input unchanged (here: in its original orientation).
columnStrat :: Int -> Strategy Matrix
columnStrat c matrix = fmap transpose (parListChunk c rdeepseq (transpose matrix))

-- | Block-wise: cut the matrix into about numB x numB blocks and spark each
-- block. numB = round (sqrt (rows / c)), so numB^2 is about rows / c, i.e.
-- roughly the same number of sparks as 'lineStrat' with the same @c@.
-- The evaluated blocks are glued back into the original row layout ('join'
-- per band of rows, then concatenated), so the result equals the input.
-- Assumes a square matrix, as 'splitIntoClusters' does.
blockStrat :: Int -> Strategy Matrix
blockStrat _ []     = return []
blockStrat c matrix = fmap (concatMap join) (mapM (parList rdeepseq) blockRows)
  where
    blockRows = splitIntoClusters numB matrix
    numB      = round (sqrt (fromIntegral (length matrix) / fromIntegral c :: Double))

-- | Stand-in for version numbers that have no strategy; fails when applied.
undef :: Int -> Strategy Matrix
undef _ _ = error "undefined strategy"

-------------------------------------

-- | A vector is a plain list of Ints.
type Vector = [Int]
-- | A matrix is a list of rows (each row a 'Vector').
type Matrix = [Vector]

-- main computation, different versions:
-- | Multiply @m1@ by @m2@ using strategy number @v@ (an index into 'strats';
-- 0 = plain sequential product) with granularity parameter @c@.
-- Unless compiled with -DOUTPUT, the product is forced with rnf and only a
-- [[Nothing]] placeholder is returned, so no result data is handed back.
-- A version with no strategy (negative or past the end of 'strats') maps to
-- 'undef' and fails with "undefined strategy" once the product is demanded.
mult :: Int -> Matrix -> Matrix -> Int -> [[Maybe Matrix]]
mult 0 m1 m2 _ =
#ifdef OUTPUT
    [[Just $ multMatricesTr m1 (transpose m2)]]
#else
    rnf (multMatricesTr m1 (transpose m2)) `seq` [[Nothing]]
#endif
mult v m1 m2 c = results
  where
    results :: [[Maybe Matrix]]
#ifdef OUTPUT
    results = [[Just computed]]
#else
    results = rnf computed `seq` [[Nothing]]
#endif
    -- multMatricesTr expects its second operand transposed
    computed = multMatricesTr m1 (transpose m2) `using` strategy c
    -- total lookup instead of (!!), which crashes on a negative index
    strategy = case drop v strats of
                 (s:_) | v > 0 -> s
                 _             -> undef

-- | Matrix product from a pair (first matrix, second matrix transposed):
-- entry (i,j) is 'prVV' of row i of the first and row j of the second.
prMM' :: (Matrix,Matrix) -> Matrix
prMM' (rows, colsT) = map (\row -> map (prVV row) colsT) rows

-- | Dot product: sum of pairwise products over the common length.
prVV :: Vector -> Vector -> Int
prVV xs ys = sum [x * y | (x, y) <- zip xs ys]

-- | Rotate row i (counting from 0) left by (c - i) positions. Rows where
-- that amount is <= 0 or >= the row length are left unchanged.
shiftRight :: Int -> [[a]] -> [[a]]
shiftRight c rows = zipWith rotate (iterate (subtract 1) c) rows
  where
    rotate n row = let (front, back) = splitAt n row in back ++ front

-- | 'shiftRight' applied to the columns instead of the rows.
shiftDown :: Int -> [[a]] -> [[a]]
shiftDown c = transpose . shiftRight c . transpose

-- | Place two matrices side by side: row i of the result is row i of the
-- first followed by row i of the second (only as many rows as the shorter).
join2 :: Matrix -> Matrix -> Matrix
join2 = zipWith (++)

-- | Place a list of matrices side by side, left to right. The fold starts
-- from an infinite supply of empty rows, so the result has as many rows as
-- the shortest input (and is an infinite list of [] for an empty input).
join :: [Matrix] -> Matrix
join = foldr join2 (repeat [])

-- | Cut a matrix into a grid of blocks. The row count is split into @c@
-- near-equal heights with 'kPartition'. Each resulting band of rows is then
-- cut at the SAME sizes along the columns, giving a c x c grid for a square
-- matrix. For wider input, columns beyond the total of the sizes are dropped;
-- for narrower input, the column split stops once the rows run out.
-- @c < 1@ is treated as 1.
--
-- Result: outer list = bands of rows (top to bottom), inner list = the
-- blocks of that band (left to right).
splitIntoClusters :: Int -> Matrix -> [[Matrix]]
splitIntoClusters c m | c < 1 = splitIntoClusters 1 m
splitIntoClusters c m1 = mss
  where
    bh  = kPartition (length m1) c
    ms  = bhsplit bh m1           -- bands of rows
    mss = map (colsplit bh) ms    -- each band cut into column blocks

    -- split the row list into consecutive runs of the given lengths
    bhsplit [] []     = []
    bhsplit [] _      = error "some elements left over"
    bhsplit (t:ts) xs = hs : bhsplit ts rest
      where (hs, rest) = splitAt t xs

    -- split every row of a band at the same column offsets; stop once the
    -- rows have no columns left. A band with no rows (possible when c
    -- exceeds the row count) now yields empty blocks instead of crashing
    -- on 'head'.
    colsplit [] _ = []
    colsplit (t:ts) rs
      | (r:_) <- rs, null r = []
      | otherwise           = cab : colsplit ts resto
      where (cab, resto) = unzip (map (splitAt t) rs)

--        mss = map (repartir (length m1 `div` c)) ms
--        repartir c xs
--         | head xs == [] = []
--         | otherwise = (cab:repartir c resto)
--          where  (cab,resto) = unzip (map (splitAt c) xs)

-- | Helper for splitIntoClusters (formerly "bresenham"): for k > 0, split
-- n into k near-equal parts summing to n. The first (n `mod` k) parts get
-- one extra, e.g. kPartition 10 3 == [4,3,3]. Gives [] for k < 0; k == 0
-- raises a divide-by-zero error.
kPartition :: Int -> Int -> [Int]
kPartition n k = zipWith (+) extras (replicate k base)
  where
    (base, rest) = n `divMod` k
    extras       = replicate rest 1 ++ repeat 0

-- | One node of the torus computation. Inputs: the node's own block pair
-- (sm1, sm2) and the blocks it receives (sm1s, sm2s). Output triple:
--
--   * the local result: the sum of the products of corresponding blocks of
--     sm1:sm1s and sm2':sm2s (sm2' = transpose sm2). It is forced with rnf
--     and returned only when built with -DOUTPUT;
--   * the first nc-1 blocks of sm1:sm1s (passed on "to the right");
--   * the first nr-1 blocks of sm2':sm2s (passed on "down").
--
-- Since sm2 is forwarded in transposed form, the blocks in sm2s are
-- presumably already transposed by their sender (TODO confirm against the
-- skeleton that wires this up).
mult' :: Int -> Int -> ((Matrix,Matrix),[Matrix],[Matrix]) -> (Maybe Matrix,[Matrix],[Matrix])
mult' nc nr ((sm1,sm2),sm1s,sm2s)
#ifdef OUTPUT
    =  (Just result,toRight,toDown)
#else
    =  (rnf result `seq` Nothing ,toRight,toDown)
#endif
  where toRight  = take (nc-1) (sm1:sm1s)
        toDown   = take (nr-1) (sm2':sm2s)
        sm2'     = transpose sm2
        sms      = zipWith multMatricesTr (sm1:sm1s) (sm2':sm2s)
        -- foldl1' is this file's rnf-forcing fold: every partial sum is fully
        -- evaluated (an earlier note says foldr1 did not demand enough)
        result   = foldl1' addMatrices sms

-- | Element-wise sum of two matrices. As with zipWith, extra rows or
-- columns of the larger operand are ignored.
addMatrices :: Matrix -> Matrix -> Matrix
addMatrices m1 m2 = zipWith (zipWith (+)) m1 m2

-- | Matrix product where the second operand is passed ALREADY TRANSPOSED
-- (i.e. as the list of its columns): entry (i,j) is the dot product of row i
-- of @m1@ and row j of @m2@, computed with 'prodEscalar2'.
multMatricesTr :: Matrix -> Matrix -> Matrix
multMatricesTr m1 m2 = map (\row -> map (prodEscalar2 row) m2) m1

-- | Dot product with an explicit, tail-recursive accumulator loop.
-- JB 2008: a lot faster, consuming both lists directly (optimised with -O2).
-- Vectors of different lengths are rejected with an error rather than
-- truncated.
prodEscalar2JB :: Vector -> Vector -> Int
prodEscalar2JB xs ys = loop xs ys 0
  where
    loop :: Vector -> Vector -> Int -> Int
    loop []        []        total = total
    loop (a:rest1) (b:rest2) total = loop rest1 rest2 (total + a*b)
    loop _         _         _     = error "addProd: length does not match"

-- JB 2008: identical when using ghc-6.8.3, avoids bug in ghc-HEAD. Version suggested by SM
-- Dot product of two vectors (the variant multMatricesTr uses). Per the
-- note above, this exact code shape works around a compiler bug, so keep
-- it as is rather than restyling it.
prodEscalar2 :: Vector -> Vector -> Int
prodEscalar2 v1 v2 = addProd v1 v2 0
-- Accumulating loop for prodEscalar2; the bang patterns keep the running
-- sum evaluated. Unlike prodEscalar2JB it does not check lengths: it stops
-- at the end of the shorter vector.
addProd :: Vector -> Vector -> Int -> Int
addProd (v:vs) (w:ws) !acc = addProd vs ws (acc + v*w)
addProd _ _ !n = n


-- | Reference dot product: sum of pairwise products over the common length.
prodEscalar :: Vector -> Vector -> Int
prodEscalar v1 v2 = sum [x * y | (x, y) <- zip v1 v2]


------- foldl, strict in head element
-- | Left fold over a non-empty list that fully evaluates (rnf) every
-- intermediate accumulator. Fails with a descriptive error on [].
foldl1' :: NFData a => (a -> a -> a) -> [a] -> a
foldl1' f (x:xs) = strictFoldl f x xs
foldl1' _ []     = error "foldl1': empty list"

-- | Left fold that fully evaluates (rnf) the accumulator at every step.
-- WHNF forcing (as in Data.List.foldl') is not enough when the accumulator
-- is a nested structure such as a matrix.
foldl' :: NFData a => (a -> b -> a) -> a -> [b] -> a
foldl' = strictFoldl

-- Shared worker for both folds. They call it instead of each other, so the
-- name foldl' is never used unqualified here: base >= 4.20 also exports
-- foldl' from the Prelude, which would make such a use ambiguous.
strictFoldl :: NFData a => (a -> b -> a) -> a -> [b] -> a
strictFoldl _ acc []     = acc
strictFoldl f acc (y:ys) = let acc' = f acc y
                           in rnf acc' `seq` strictFoldl f acc' ys


-- | Usage text for the given program name. It also lists the selectable
-- versions as (index, name) pairs.
usage :: String -> String
usage progName = concat
    [ "Cannon's algorithm: Usage:\n\t "
    , progName
    , " <matrix size> <version> <blocksPerRow> \n"
    , "Version selects from "
    , show (zip [0 :: Integer ..] names)
    ]

-- | Entry point. Arguments: <matrix size> <version> <blocksPerRow>, all
-- optional, with defaults 100, 0 and 1. With no arguments at all, the usage
-- text is printed first.
--
-- Built with -DOUTPUT, the gathered result is printed and compared against a
-- plain sequential product. Otherwise only the placeholder returned by
-- 'mult' is printed.
main :: IO ()
main = do
    args <- getArgs
    if null args
      then do
        progName <- getProgName
        putStrLn (usage progName)
        putStrLn "\n *** defaults: size 100, seq. computation ***"
      else return ()
    -- i-th positional argument, or the default when absent (total lookup
    -- instead of head / (!!); a malformed number still fails in 'read')
    let argOr i def = case drop i args of
                        (s:_) -> read s
                        []    -> def
        size  = argOr 0 100
        opt   = argOr 1 0
        chunk = argOr 2 1
        res   = mult opt (mA size) (mB size) chunk
#ifdef OUTPUT
    putStrLn "Output wanted, checking result for correctness..."
    let expected  = multMatricesTr (mA size) (transpose (mB size))
        computed  = map (map fromJust) res
        computed' = concatMap join computed
    printMat computed'
    if expected == computed'
      then putStrLn "Correct!"
      else do
        putStrLn "WRONG RESULT! Should be"
        printMat expected
#else
    -- no output wanted: the matrix stays distributed, only the placeholder
    print res
#endif



-- | Test input: @size@ identical rows @[1..size]@.
m1 :: Int -> Matrix
m1 size = replicate size [1..size]

-- | Test input: the numbers 1..size*size laid out row by row.
m2 :: Int -> Matrix
m2 size = listToListList size [1..size*size]

-- | First operand: 'm1' up to size 4000. Above that, size*size elements are
-- taken by cycling through @[1 .. size*size `div` 20]@.
mA :: Int -> Matrix
mA size
  | size <= 4000 = m1 size
  | otherwise    = squareFrom size [1 .. (size*size) `div` 20]

-- | Second operand: 'm1' up to size 4000. Above that, size*size elements are
-- taken by cycling through @[0,2 .. size*size `div` 20 - 2]@.
mB :: Int -> Matrix
mB size
  | size <= 4000 = m1 size
  | otherwise    = squareFrom size [0,2 .. ((size*size) `div` 20)-2]

-- Lay out exactly size*size elements, repeating @pat@, as @size@ rows.
-- Concatenating 20 copies of the pattern (as before) came up short whenever
-- size*size was not a multiple of 20. For mB's pattern, which has only
-- about size*size/40 elements, it gave roughly half a matrix (non-square).
squareFrom :: Int -> [Int] -> Matrix
squareFrom size pat = listToListList size (take (size*size) (cycle pat))

-- | Cut a flat list into rows of @c@ elements; the last row may be shorter.
-- Input with at most @c@ elements gives a single row, so [] gives [[]].
-- Testing splitAt's remainder, instead of calling 'length' on every step,
-- keeps this linear and lazy (it works on infinite lists).
listToListList :: Int -> [a] -> [[a]]
listToListList c m = case splitAt c m of
    (row, [])   -> [row]
    (row, rest) -> row : listToListList c rest


-- | Print the header "Matrix: <columns> x <rows>", followed by 'showMat'.
-- The column count comes from the first row; an empty matrix reports 0
-- instead of failing on 'head'.
printMat :: Matrix -> IO ()
printMat m = putStrLn ("Matrix: " ++ show cols
                       ++ " x " ++ show (length m) ++ "\n"
                       ++ showMat m)
  where
    cols = case m of
             (row:_) -> length row
             []      -> 0

-- | Render a matrix between "<<" and ">>", one line per row, with every
-- element followed by a single space.
showMat :: Show a => [[a]] -> String
showMat rows = "<<" ++ concatMap renderRow rows ++ ">>"
  where renderRow row = concat [show x ++ " " | x <- row] ++ "\n"

-- | Unwrap a 'Just'; fails with "fromJust" on 'Nothing'.
fromJust :: Maybe a -> a
fromJust = maybe (error "fromJust") id