module Main where

import Control.Monad (replicateM)
import Data.Array
import Data.List (genericLength)
import Debug.Trace
import System.Random (RandomGen, getStdGen, randomR)
import Test.HUnit

-- | Colour of a single hat.
data HatColor = Blue | Red
  deriving (Enum, Eq, Show)

-- | One entry of a strategy: guess a colour, or pass.
data HatStrategyBit = Guess HatColor | NoGuess
  deriving (Eq)

-- | One-letter rendering (@B@, @R@ or @N@) so that 'showHS' prints a whole
-- strategy as a compact string.
instance Show HatStrategyBit where
  show (Guess Blue) = "B"
  show (Guess Red) = "R"
  show NoGuess = "N"

-- | A complete strategy: one 'HatStrategyBit' for every (person, view of the
-- other people's hats) pair.
type HatStrategy = [HatStrategyBit]

-- | The hat colour of every person, in person order.
type HatConfiguration = [HatColor]

-- | A strategy paired with its score.
type HatStrategyScore = (HatStrategy, Double)

-- Alternative representation kept for reference:
-- type HatGeneration = Array Int HatStrategyScore
type HatGeneration = [HatStrategyScore]

-- | Renders a strategy as the concatenation of its bits' 'show' output,
-- e.g. @[Guess Blue, NoGuess]@ becomes @"BN"@.
showHS :: HatStrategy -> String
showHS = concatMap show

-- | Turns a list of colours into the corresponding 'Guess' bits.
makeGuesses :: [HatColor] -> [HatStrategyBit]
makeGuesses = map Guess

-- | Removes the element at 0-based position @n@ from a list. A list too
-- short to have that position is returned unchanged.
removeElem :: (Eq b, Num b) => b -> [a] -> [a]
removeElem n l = removeElem' n 0 l

-- | Worker for 'removeElem'; @cur@ is the index of the current list head.
--
-- The @n == cur@ test needs an explicit 'Eq' constraint: since GHC 7.4,
-- 'Num' no longer has 'Eq' as a superclass.
removeElem' :: (Eq b, Num b) => b -> b -> [a] -> [a]
removeElem' _ _ [] = []
removeElem' n cur (x : xs)
  | n == cur = removeElem' n (cur + 1) xs
  | otherwise = x : removeElem' n (cur + 1) xs

--breedGenes :: (HatStrategy, HatStrategy) -> Double -> HatStrategy
--breedGenes (hs1, hs2) prob = mutate prob

-- | Prepares a generation for a call to 'selectWeighted'.
--
-- The first argument is a null element of the generation type. It is
-- appended as a sentinel, so every real element at index @i@ has an upper
-- bound stored at index @i + 1@. The returned array is indexed from 1. It
-- pairs each element with the sum of the weights of all elements before
-- it. The second component of the result is the total weight.
prepareArray :: a -> [(a, Double)] -> (Array Int (a, Double), Double)
prepareArray zero gen = (table, total)
  where
    members = map fst gen ++ [zero]
    cumulative = scanl (+) 0.0 (map snd gen)
    table = listArray (1, length members) (zip members cumulative)
    total = sum (map snd gen)

-- | Selects a gene to breed.
--
-- Draws a random point from @[0, total]@ with 'randomR' and locates it in
-- the cumulative weight table produced by 'prepareArray' using
-- 'binarySelect'. Entries with larger weights own wider intervals of the
-- table. Returns the chosen element together with the advanced generator.
selectWeighted :: RandomGen g => Array Int (a, Double) -> Double -> g -> (a, g)
selectWeighted hats total gen = (binarySelect hats total point (bounds hats), gen')
  where
    (point, gen') = randomR (0.0, total) gen

-- | Binary search over a cumulative weight table.
--
-- Entry @i@ of @hats@ pairs an element with the weight accumulated before
-- it. Element @i@ therefore owns the half-open interval
-- @[snd (hats ! i), snd (hats ! (i + 1)))@. The search looks between
-- indices @start@ and @end@ and returns the element whose interval
-- contains @rand@. The @total@ argument is not used by the search.
--
-- 'randomR' on 'Double' may return its upper bound, so @rand == total@
-- (which lies in no interval) can occur. Without the single-candidate guard
-- the search would recurse on the same two-element range forever. With it,
-- such a value selects the last real element. A one-entry table (empty
-- generation) yields the sentinel instead of indexing out of bounds.
binarySelect :: Array Int (a, Double) -> Double -> Double -> (Int, Int) -> a
binarySelect hats total rand (start, end)
  | end - start <= 1 = fst (hats ! start)
  | lowVal <= rand && rand < highVal = fst (hats ! half)
  | rand < highVal = binarySelect hats total rand (start, half)
  | otherwise = binarySelect hats total rand (half, end)
  where
    half = start + (end - start) `div` 2
    lowVal = snd (hats ! half)
    highVal = snd (hats ! (half + 1))

-- | Index into one person's block of a 'HatStrategy' for a configuration.
--
-- The hats of everybody except @personIdx@ are read, in order, as a binary
-- number with 'Red' as 1 and 'Blue' as 0, most significant bit first.
strategyIndexFromHatConfiguration :: (Integral a, Integral b) => a -> HatConfiguration -> b
strategyIndexFromHatConfiguration personIdx hc =
  strategyIndexFromHatConfiguration' (removeElem personIdx hc) (length hc - 2) 0

-- | Worker: @curIdx@ is the exponent of the bit contributed by the head hat,
-- and @accum@ is the value accumulated so far.
strategyIndexFromHatConfiguration' :: (Integral a, Integral b) => HatConfiguration -> a -> b -> b
strategyIndexFromHatConfiguration' [] _ accum = accum
strategyIndexFromHatConfiguration' (Blue : rest) curIdx accum =
  strategyIndexFromHatConfiguration' rest (curIdx - 1) accum
strategyIndexFromHatConfiguration' (Red : rest) curIdx accum =
  strategyIndexFromHatConfiguration' rest (curIdx - 1) (accum + 2 ^ curIdx)

-- | Whether a strategy wins for one hat configuration.
--
-- Person @i@ uses strategy bit
-- @2 ^ (numPpl - 1) * i + strategyIndexFromHatConfiguration i hc@.
-- The group wins when at least one person guesses and every guess made is
-- correct. Evaluation stops at the first wrong guess.
strategyPass :: Int -> HatStrategy -> HatConfiguration -> Bool
strategyPass numPpl hs hc = strategyPass' numPpl hs hc False 0

-- | Walks the people from @idx@ onwards. @accum@ records whether a (correct)
-- guess has been made so far.
strategyPass' :: Int -> HatStrategy -> HatConfiguration -> Bool -> Int -> Bool
strategyPass' numPpl hs hc accum idx
  | idx == length hc = accum
  | otherwise = strategyPass'' numPpl hs hc accum idx (hs !! bitIdx)
  where
    bitIdx = 2 ^ (numPpl - 1) * idx + strategyIndexFromHatConfiguration idx hc

-- | Applies person @idx@'s strategy bit. A pass moves on to the next person.
-- A guess must match the person's own hat for evaluation to continue.
strategyPass'' :: Int -> HatStrategy -> HatConfiguration -> Bool -> Int -> HatStrategyBit -> Bool
strategyPass'' numPpl hs hc accum idx NoGuess = strategyPass' numPpl hs hc accum (idx + 1)
strategyPass'' numPpl hs hc _ idx (Guess g) = strategyPassHelper numPpl hs hc (hc !! idx == g) (idx + 1)

-- | Stops with 'False' after a wrong guess. Otherwise it continues from
-- @idx@, remembering that a correct guess has been made.
strategyPassHelper :: Int -> HatStrategy -> HatConfiguration -> Bool -> Int -> Bool
strategyPassHelper _ _ _ False _ = False
strategyPassHelper numPpl hs hc True idx = strategyPass' numPpl hs hc True idx

-- | Fraction of all hat configurations for @numPpl@ people on which the
-- strategy wins.
scoreStrategy :: Int -> HatStrategy -> Double
scoreStrategy numPpl hs = genericLength (filter (strategyPass numPpl hs) cfgs) / genericLength cfgs
  where
    cfgs = allConfigs numPpl

-- | Every assignment of 'Blue'/'Red' to @numPpl@ people.
allConfigs :: Int -> [HatConfiguration]
allConfigs numPpl = replicateM numPpl [Blue, Red]

-- | Every strategy: three choices for each of the @2 ^ (numPpl - 1) * numPpl@
-- bits.
allStrategies :: Int -> [HatStrategy]
allStrategies numPpl = replicateM (2 ^ (numPpl - 1) * numPpl) [NoGuess, Guess Blue, Guess Red]

-- | Scores every strategy from 'allStrategies' and returns the first one
-- with the highest score.
--
-- The fold is strict: for @numPpl = 3@ there are @3 ^ 12@ candidates. A lazy
-- 'foldl' would build one unevaluated 'accumulateMax' thunk per candidate
-- before comparing any scores.
findBestStrategyExhaustive :: Int -> HatStrategy
findBestStrategyExhaustive numPpl = fst (go ([], -1.0) (allStrategies numPpl))
  where
    go best [] = best
    go best (s : ss) =
      let best' = accumulateMax numPpl best s
       in best' `seq` go best' ss

-- | 'accumulateMax' with its last two arguments swapped (accumulator last,
-- the shape 'foldr' expects).
accumulateMaxR :: Int -> HatStrategy -> (HatStrategy, Double) -> (HatStrategy, Double)
accumulateMaxR numPpl cur old = accumulateMax numPpl old cur

-- | Keeps whichever of the running best and the candidate scores higher.
-- On a tie the running best is kept.
accumulateMax :: Int -> (HatStrategy, Double) -> HatStrategy -> (HatStrategy, Double)
accumulateMax numPpl old@(_, oldScore) cur
  | curScore > oldScore = (cur, curScore)
  | otherwise = old
  where
    curScore = scoreStrategy numPpl cur

-- | Runs the unit tests, then exhaustively searches for the best strategy
-- for three people and prints it together with its score.
main :: IO ()
main = do
  _ <- runTestTT tests
  putStrLn $ "Best is " ++ showHS best ++ " with score " ++ show (scoreStrategy numPpl best)
  where
    numPpl :: Int
    numPpl = 3
    best = findBestStrategyExhaustive numPpl

-- tests

tests :: Test
tests =
  TestList
    [ TestLabel "stratIndexFromHatConfig" stratIndexFromHatConfigTests
    , TestLabel "strategyPass" strategyPassTests
    , TestLabel "scoreStrategy" scoreStrategyTests
    , TestLabel "selectWeighted" selectWeightedTests
    ]

stratIndexFromHatConfigTests :: Test
stratIndexFromHatConfigTests =
  TestList
    [ TestLabel "test1" test1
    , TestLabel "test1a" test1a
    , TestLabel "test2" test2
    , TestLabel "test3" test3
    ]

test1, test1a, test2, test3 :: Test
test1 = TestCase (assertEqual "" 0 (strategyIndexFromHatConfiguration 0 [Blue, Blue, Blue]))
test1a = TestCase (assertEqual "" 2 (strategyIndexFromHatConfiguration 0 [Red, Red, Blue]))
test2 = TestCase (2 @=? strategyIndexFromHatConfiguration 1 [Red, Blue, Blue])
test3 = TestCase (assertEqual "" 1 (strategyIndexFromHatConfiguration 2 [Blue, Red, Blue]))

strategyPassTests :: Test
strategyPassTests =
  TestList
    [ TestLabel "test4" test4
    , TestLabel "test5" test5
    , TestLabel "test6" test6
    , TestLabel "test7" test7
    , TestLabel "test8" test8
    , TestLabel "test9" test9
    , TestLabel "test10" test10
    , TestLabel "test11" test11
    , TestLabel "test12" test12
    , TestLabel "test13" test13
    , TestLabel "test14" test14
    , TestLabel "test15" test15
    , TestLabel "test16" test16
    , TestLabel "test17" test17
    , TestLabel "test18" test18
    , TestLabel "test19" test19
    , TestLabel "test20" test20
    , TestLabel "test21" test21
    ]

test4, test5, test6, test7, test8 :: Test
test4 = TestCase (False @=? strategyPass 3 (replicate 12 NoGuess) [Blue, Blue, Red])
test5 = TestCase (True @=? strategyPass 3 (replicate 12 (Guess Blue)) [Blue, Blue, Blue])
test6 = TestCase (False @=? strategyPass 3 (replicate 12 (Guess Blue)) [Blue, Red, Blue])
test7 = TestCase (False @=? strategyPass 3 (replicate 12 (Guess Blue)) [Blue, Blue, Red])
test8 = TestCase (False @=? strategyPass 3 (replicate 12 (Guess Blue)) [Red, Blue, Blue])

strat1 :: HatStrategy
strat1 = [Guess Blue, Guess Red, NoGuess, Guess Blue, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess]

test9, test10, test11, test12, test13 :: Test
test9 = TestCase (False @=? strategyPass 3 strat1 [Red, Blue, Blue])
test10 = TestCase (True @=? strategyPass 3 strat1 [Blue, Blue, Blue])
test11 = TestCase (False @=? strategyPass 3 strat1 [Red, Red, Blue])
test12 = TestCase (False @=? strategyPass 3 strat1 [Blue, Red, Blue])
test13 = TestCase (True @=? strategyPass 3 strat1 [Blue, Red, Red])

strat2 :: HatStrategy
strat2 = [NoGuess, NoGuess, NoGuess, NoGuess, Guess Blue, Guess Red, NoGuess, Guess Blue, NoGuess, NoGuess, NoGuess, NoGuess]

test14, test15, test16, test17 :: Test
test14 = TestCase (False @=? strategyPass 3 strat2 [Red, Red, Blue])
test15 = TestCase (False @=? strategyPass 3 strat2 [Red, Blue, Blue])
test16 = TestCase (True @=? strategyPass 3 strat2 [Blue, Blue, Blue])
test17 = TestCase (False @=? strategyPass 3 strat2 [Red, Red, Red])

strat3 :: HatStrategy
strat3 = [NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, Guess Blue, Guess Red, NoGuess, Guess Blue]

test18, test19, test20, test21 :: Test
test18 = TestCase (False @=? strategyPass 3 strat3 [Red, Blue, Blue])
test19 = TestCase (False @=? strategyPass 3 strat3 [Red, Blue, Red])
test20 = TestCase (True @=? strategyPass 3 strat3 [Red, Red, Blue])
test21 = TestCase (False @=? strategyPass 3 strat3 [Blue, Red, Blue])

scoreStrategyTests :: Test
scoreStrategyTests =
  TestList
    [ TestLabel "test22" test22
    , TestLabel "test23" test23
    , TestLabel "test24" test24
    , TestLabel "test25" test25
    , TestLabel "test26" test26
    , TestLabel "test27" test27
    ]

test22, test23, test24, test25, test26, test27 :: Test
test22 = TestCase (0.5 @=? scoreStrategy 3 [Guess Blue, Guess Blue, Guess Blue, Guess Blue, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess])
test23 = TestCase (0.25 @=? scoreStrategy 3 [Guess Blue, Guess Blue, Guess Blue, Guess Blue, NoGuess, NoGuess, NoGuess, NoGuess, Guess Red, Guess Red, Guess Red, Guess Red])
test24 = TestCase (0.25 @=? scoreStrategy 3 [Guess Blue, NoGuess, NoGuess, Guess Red, Guess Blue, NoGuess, NoGuess, Guess Red, Guess Blue, NoGuess, NoGuess, Guess Red])
test25 = TestCase (0.75 @=? scoreStrategy 3 [Guess Red, NoGuess, NoGuess, Guess Blue, Guess Red, NoGuess, NoGuess, Guess Blue, Guess Red, NoGuess, NoGuess, Guess Blue])
test26 = TestCase (0.75 @=? scoreStrategy 3 [NoGuess, Guess Blue, Guess Red, NoGuess, Guess Blue, NoGuess, NoGuess, Guess Red, NoGuess, Guess Red, Guess Blue, NoGuess])
test27 = TestCase (0.125 @=? scoreStrategy 3 [NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, NoGuess, Guess Blue])

selectWeightedTests :: Test
selectWeightedTests =
  TestList
    [ TestLabel "test28" test28
    , TestLabel "test29" test29
    , TestLabel "test30" test30
    , TestLabel "test31" test31
    , TestLabel "test32" test32
    , TestLabel "test33" test33
    , TestLabel "test34" test34
    , TestLabel "test35" test35
    , TestLabel "test36" test36
    ]

hats1 :: Array Int (String, Double)
hats1 = array (1, 4) [(1, ("a", 0.0)), (2, ("b", 0.5)), (3, ("c", 1.5)), (4, ("", 2.0))]

hatsorig :: [(String, Double)]
hatsorig = [("a", 0.5), ("b", 1.0), ("c", 0.5)]

test28, test29, test30, test31, test32, test33, test34, test35, test36 :: Test
test28 = TestCase ("a" @=? binarySelect hats1 2.0 0.0 (1, 4))
test29 = TestCase ("a" @=? binarySelect hats1 2.0 0.3 (1, 4))
test30 = TestCase ("b" @=? binarySelect hats1 2.0 0.5 (1, 4))
test31 = TestCase ("b" @=? binarySelect hats1 2.0 1.0 (1, 4))
test32 = TestCase ("c" @=? binarySelect hats1 2.0 1.5 (1, 4))
test33 = TestCase ("c" @=? binarySelect hats1 2.0 1.9 (1, 4))
test34 = TestCase ((hats1, 2.0) @=? prepareArray "" hatsorig)
-- randomR may return the upper bound itself; this must terminate.
test35 = TestCase ("c" @=? binarySelect hats1 2.0 2.0 (1, 4))
-- An empty generation leaves only the sentinel in the table.
test36 = TestCase ("" @=? binarySelect (fst (prepareArray "" [])) 0.0 0.0 (1, 1))