module Sudoku(solution) where
-- | A single cell entry. Filled cells hold one of 'digits'; the character
-- '0' marks an empty cell (see 'blank').
type Digit = Char

-- | A Sudoku board: a matrix of digit characters.
type Grid = Matrix Digit

-- | A matrix, stored row-major as a list of rows.
type Matrix a = [Row a]

-- | One row of a matrix: a plain list of cells.
type Row a = [a]
-- | Every digit a filled cell may hold, in ascending order.
digits :: [Char]
digits = "123456789"
-- | True exactly for the placeholder character '0', which marks an empty
-- cell in a 'Grid'.
blank :: Digit -> Bool
blank d = d == '0'
-- | The first solution found for a puzzle.
--
-- Blank cells are written as '0'. If the puzzle has no solution this fails
-- with a descriptive error instead of the opaque @Prelude.head: empty list@.
-- Because 'solve' is lazy, only the work needed for the first solution is
-- performed.
solution :: Grid -> Grid
solution g = case solve g of
  s : _ -> s
  []    -> error "Sudoku.solution: puzzle has no solution"
-- | Every completion of a grid. Each cell is first replaced by its list of
-- candidate digits ('choices'); that candidate matrix is then searched.
solve :: Grid -> [Grid]
solve g = search (choices g)
-- | Turns every cell into a list of candidate digits. A blank cell may hold
-- any of 'digits'; a filled cell is fixed to the digit it already holds.
choices :: Grid -> Matrix [Digit]
choices g = [map candidates row | row <- g]
  where
    candidates d
      | blank d   = digits
      | otherwise = [d]
-- | Cartesian product of a list of lists: every way of picking one element
-- from each inner list. Picks from earlier lists vary slowest.
--
-- An empty outer list has exactly one (empty) pick; any empty inner list
-- makes the whole product empty.
cp :: [[a]] -> [[a]]
cp = foldr step [[]]
  where
    -- 'rest' (the product of the remaining lists) is bound once and shared
    -- across every choice of head element.
    step options rest = [o : r | o <- options, r <- rest]
-- | Every grid obtained by picking one candidate for each cell. Each row is
-- expanded into all of its possible rows, and then one row is picked per
-- position.
expand :: Matrix [Digit] -> [Grid]
expand m = cp (map cp m)
-- | True when no element occurs twice in the list. Each element is compared
-- against all later ones, so this is quadratic; that is fine for the
-- nine-element units it is used on here.
nodups :: Eq a => [a] -> Bool
nodups xs = case xs of
  []     -> True
  y : ys -> all (/= y) ys && nodups ys
-- | The rows of a matrix. Matrices are stored row-major, so this returns
-- its argument unchanged.
rows :: Matrix a -> Matrix a
rows m = m
-- | Transposes a matrix, so that each column becomes a row.
--
-- An empty matrix transposes to an empty matrix. Previously there was no
-- equation for @[]@, so the function crashed with a non-exhaustive pattern
-- error. If rows have unequal lengths, the result is cut to the shortest
-- row, because 'zipWith' truncates.
cols :: Matrix a -> Matrix a
cols []         = []
cols [xs]       = [[x] | x <- xs]
cols (xs : xss) = zipWith (:) xs (cols xss)
-- | Splits a list into consecutive chunks of three. The final chunk is
-- shorter when the length is not a multiple of three.
group :: [a] -> [[a]]
group [] = []
group xs =
  let (chunk, rest) = splitAt 3 xs
  in  chunk : group rest
-- | Inverse of 'group': flattens a list of chunks back into a single list.
ungroup :: [[a]] -> [a]
ungroup xss = [x | xs <- xss, x <- xs]
-- | The 3x3 boxes of a (presumably 9x9) matrix, each returned as one row of
-- cells read row by row within the box. Boxes run left to right within a
-- band of three rows, and bands run top to bottom.
boxs :: Matrix a -> Matrix a
boxs m = map ungroup (ungroup (map cols bands))
  where
    -- Cut every row into triples, then gather the rows into bands of three.
    -- Transposing a band ('cols') lines up the triples of each box.
    bands = group (map group m)
-- | True when no row, column or box of the grid contains a repeated entry.
-- Units are checked in the order rows, columns, boxes, stopping at the first
-- failure.
valid :: Grid -> Bool
valid g = all (all nodups) [rows g, cols g, boxs g]
-- | Removes from every undecided cell of a row each digit that is already
-- fixed (held by a singleton cell) elsewhere in that row.
--
-- Fixed cells are left untouched, so a digit is never removed from its own
-- cell. An undecided cell can become empty when all of its candidates are
-- fixed elsewhere in the row.
pruneRow :: Row [Digit] -> Row [Digit]
pruneRow row = map (remove fixed) row
  where
    -- Digits already settled somewhere in this row.
    fixed = [d | [d] <- row]
    -- Use Prelude's 'notElem' rather than a local re-implementation that
    -- shadowed it.
    remove _  [x] = [x]
    remove ds xs  = filter (`notElem` ds) xs
-- | One pruning pass over the rows, then the columns, then the boxes of a
-- candidate matrix (see 'pruneRow').
--
-- Each pass views the matrix through a unit function, prunes every resulting
-- row, and then applies the same function again to map back. This assumes
-- 'rows', 'cols' and 'boxs' are each their own inverse on the boards passed
-- in (true for a 9x9 matrix).
prune :: Matrix [Digit] -> Matrix [Digit]
prune m = pruneBy boxs (pruneBy cols (pruneBy rows m))
  where
    pruneBy view = view . map pruneRow . view
-- | Applies @f@ repeatedly, starting from @x@, until a value is reached that
-- @f@ maps to something equal, and returns that value. Does not terminate if
-- no such fixed point is ever reached.
many :: Eq a => (a -> a) -> a -> a
many f x
  | x == next = x
  | otherwise = many f next
  where
    next = f x
-- | Depth-first search for every completion of a candidate matrix.
--
-- The matrix is pruned once per step. A branch is abandoned as soon as two
-- fixed cells in the same row, column or box hold the same digit ('safe').
-- A matrix whose cells are all fixed is a solution. Otherwise the search
-- branches on an undecided cell with the fewest candidates ('expand1').
-- Solutions are produced lazily, so callers that need only the first one
-- do not explore the rest of the tree.
search :: Matrix [Digit] -> [Grid]
search cm
  | not (safe pm) = []
  | complete pm   = [extract pm]
  | otherwise     = concatMap search (expand1 pm)
  where
    pm = prune cm
-- | True when every cell of the candidate matrix holds exactly one digit.
complete :: Matrix [Digit] -> Bool
complete cm = and [single cell | row <- cm, cell <- row]
-- | True exactly for one-element lists. Matching on the list's shape (rather
-- than using 'length') inspects at most two constructors, so it is constant
-- time and also terminates on infinite lists.
single :: [a] -> Bool
single xs = case xs of
  [_] -> True
  _   -> False
-- | Reads the digit out of every cell. This is only meaningful once the
-- matrix is 'complete': an empty cell makes 'head' fail, and any extra
-- candidates in a cell are silently dropped.
extract :: Matrix [Digit] -> Grid
extract cm = [map head row | row <- cm]
-- | False when two fixed (singleton) cells in the same row, column or box
-- hold the same digit. Undecided cells, including empty ones, are ignored.
safe :: Matrix [Digit] -> Bool
safe cm = all (all ok) [rows cm, cols cm, boxs cm]
  where
    ok unit = nodups [d | [d] <- unit]
-- | Branches on one undecided cell: the first cell, in row-major order,
-- among those with the fewest candidates (singleton cells are excluded).
-- The result has one matrix per candidate, with that cell fixed to the
-- candidate.
--
-- If some cell has no candidates left, it is the one chosen (count 0), and
-- the result is empty. That is how contradictory branches die out.
--
-- Precondition: at least one cell is not a singleton. Otherwise 'minimum'
-- is applied to an empty list and fails, so call this only on matrices
-- that are not 'complete' (as 'search' does).
--
-- Local names were changed from the original. The parameter @rows@
-- shadowed the top-level 'rows' function, and @smallest cs@ shadowed the
-- outer @cs@.
expand1 :: Matrix [Digit] -> [Matrix [Digit]]
expand1 cm =
  [ before ++ [left ++ [d] : right] ++ after | d <- candidates ]
  where
    -- The first row containing a minimal cell, then that cell within it.
    -- These patterns cannot fail while the precondition holds.
    (before, row : after)      = break (any isSmallest) cm
    (left, candidates : right) = break isSmallest row
    isSmallest cell = length cell == n
    n = minimum (counts cm)
    -- Candidate counts of every cell that is not yet fixed.
    counts = filter (/= 1) . map length . concat