雑記帳
僕用勉強ノート 「圏論」の巻

継続渡しスタイルでの階乗関数の構成を純粋圏論的にやってみた。【小ネタ】

(圏論シリーズロゴ)
継続渡しスタイルでの階乗関数の構成
導入
(殴り書き)
まず、継続渡しスタイルで書かれたコードに対する圏論的な意味付けが行いやすくなるように、コーディングに対して以下の縛りを課す。
  • 自己参照あるいは不動点コンビネータの使用を回避する。
  • 不必要なブラックボックス化は避ける
上の2点を加味したうえで、一旦構文ベースに「継続渡しスタイルでの階乗関数の特徴付け」を考えてみると、1例としては以下が得られると思う。
import Control.Monad.Cont

f1 = \rec n -> callCC $ \exit -> do
  t <- cont $ \k -> k $ n == 0
  if t then
    exit 1
  else
    return ()
  nm1 <- cont $ \k -> k $ n - 1
  f   <- rec(nm1)
  r   <- cont $ \k -> k $ n * f
  return $ r

fact n = runCont ((foldr (.) id (replicate (n+1) f1) return) n) id
上のコードを見れば大体推測できるように、この手続きであれば実際「自然数対象と全ての有限余積を持つカルテシアン閉圏」の一般論の中で意味付けすることができる。
先に結果を与えてしまうと、上の構文の意味に当たる射の構成は次のようになる。
f2 =
  trans(trans(pair(pair(prj1-:prj2,const'(nat(0)))-:trans(tw-:(id***eq')-:
  ev),trans(pair(pair(prj2,pair(pair(prj1-:prj2,const'(nat(1)))-:ev,const'(point-:return_Cont)))-:
  if_N,trans(pair(pair(prj1-:prj1-:prj1-:prj2,const'(nat(1)))-:trans(tw-:(id***sub_N)-:
  ev),trans(pair(pair(prj1-:prj1-:prj1-:prj1-:prj1,prj2)-:ev,trans(pair(pair(prj1-:prj1-:prj1-:
  prj1-:prj1-:prj2,prj2)-:trans(tw-:(id***mul_N)-:ev),trans(prj2-:return_Cont))-:bind_Cont))-:
  bind_Cont))-:bind_Cont))-:bind_Cont))-:bind_Cont)-:callCC')

fact' = pair(pair(succ'-:rec_N(arrToEl(return_Cont), f2),id)-:ev,const'(arrToEl(id)))-:ev
一方で、これだけ見ても 「bind_Cont って何?」とか「callCC' の純粋圏論的な構成は?」などなど色々と疑問は残ると思うので、そういった点についてを次節で説明する。
継続モナドに関係するパーツを構成する
モナドであるために必要な構造
  • return
  • bind
// 数学的な定義を書く
(..)
Haskell で書くと
return_Cont :: t -> r ^ (r ^ t)
return_Cont = trans(tw-:ev)

bind_Cont :: ((r ^ (r ^ t1)) *** ((r ^ (r ^ t2)) ^ t1)) -> r ^ (r ^ t2)
bind_Cont = trans(pair(prj1-:prj1,pair(prj1-:prj2,prj2))-:(id***trans(pair(pair(prj1-:prj1,prj2)-:ev,prj1-:prj2)-:ev))-:ev)
余談
純粋圏論的な構成を考えると、特に「関数への値の適用」の部分が非常に煩雑にはなるが、それまで区別できなかったことをより細分化して捉えることができるようになる利点もある。
継続モナドに固有の構造
  • callCC
// 数学的な定義を書く
(..)
Haskell で書くと
callCC' :: ((r ^ (r ^ t1)) ^ ((r ^ (r ^ t2)) ^ t1)) -> r ^ (r ^ t1)
callCC' = trans(pair(pair(prj1,trans(trans(pair(prj1-:prj1-:prj2,prj1-:prj2)-:ev)))-:ev,prj2)-:ev)
ソースコード
{-# LANGUAGE TypeOperators #-}

import Data.Void
import Control.Monad.Fix
import Control.Monad.Cont

main :: IO ()
main = do
  print   $ fact(5)
  printEl $ nat(5) -: fact'


f1 = \rec n -> callCC $ \exit -> do
  t <- cont $ \k -> k $ n == 0
  if t then
    exit 1
  else
    return ()
  nm1 <- cont $ \k -> k $ n - 1
  f   <- rec(nm1)
  r   <- cont $ \k -> k $ n * f
  return $ r

fact n = runCont ((foldr (.) id (replicate (n+1) f1) return) n) id


return_Cont :: t -> r ^ (r ^ t)
return_Cont =
  trans(tw-:ev)

bind_Cont :: ((r ^ (r ^ t1)) *** ((r ^ (r ^ t2)) ^ t1)) -> r ^ (r ^ t2)
bind_Cont =
  trans(pair(prj1-:prj1,pair(prj1-:prj2,prj2))-:(id***trans(pair(pair(prj1-:prj1,prj2)-:ev,prj1-:prj2)-:ev))-:ev)

callCC' :: ((r ^ (r ^ t1)) ^ ((r ^ (r ^ t2)) ^ t1)) -> r ^ (r ^ t1)
callCC' =
  trans(pair(pair(prj1,trans(trans(pair(prj1-:prj1-:prj2,prj1-:prj2)-:ev)))-:ev,prj2)-:ev)


f2 =
  trans(trans(pair(pair(prj1-:prj2,const'(nat(0)))-:trans(tw-:(id***eq')-:
  ev),trans(pair(pair(prj2,pair(pair(prj1-:prj2,const'(nat(1)))-:ev,const'(point-:return_Cont)))-:
  if_N,trans(pair(pair(prj1-:prj1-:prj1-:prj2,const'(nat(1)))-:trans(tw-:(id***sub_N)-:
  ev),trans(pair(pair(prj1-:prj1-:prj1-:prj1-:prj1,prj2)-:ev,trans(pair(pair(prj1-:prj1-:prj1-:
  prj1-:prj1-:prj2,prj2)-:trans(tw-:(id***mul_N)-:ev),trans(prj2-:return_Cont))-:bind_Cont))-:
  bind_Cont))-:bind_Cont))-:bind_Cont))-:bind_Cont)-:callCC')

fact' = pair(pair(succ'-:rec_N(arrToEl(return_Cont), f2),id)-:ev,const'(arrToEl(id)))-:ev


class (MyShow a) where
  myShow :: a -> String

instance MyShow () where
  myShow = const "*"

instance (MyShow a, MyShow b) => MyShow (Either a b) where
  myShow = either
    (\z -> if (myShow z == "*") then "inj1" else (myShow z) ++ ";inj1")
      (\z -> if (myShow z == "*") then "inj2" else (myShow z) ++ ";inj2")

instance (MyShow a, MyShow b) => MyShow (a,b) where
  myShow (x,y) = "(" ++ myShow x ++ "," ++ myShow y ++ ")"

instance MyShow Int where
  myShow = show

instance MyShow Nat where
  myShow (Nat i) = myShow (length i)
  --myShow (Nat i) = "zero" ++ (foldr ((++).(const ";succ")) [] i)

instance MyShow (a -> b) where
  myShow = const "(AN ARROW)"


-- X の要素を圏論に倣って終対象から X への射(Global element)として扱うための関数
el :: a -> (Pt -> a)
el = (const::a -> (Pt -> a))


-- Global elements 用 ユーティリティ
(===) :: Eq a =>  (Pt -> a) -> (Pt -> a) -> Bool
(===) x y = (x() == y())

showEl :: MyShow a => (Pt -> a) -> String
showEl x = (myShow $ x())

printEl :: MyShow a => (Pt -> a) -> IO ()
printEl = putStrLn . showEl


-- Diagrammatic-order な射の合成演算
(-:) = flip (.)

-- # 始対象と終対象
type Empty = Void
type Pt = ()

initArr :: Empty -> a
initArr = absurd

termArr :: a -> Pt
termArr = const ()

point :: Pt -> Pt
point = id

const' x = termArr -: x

-- # 余積対象と積対象
type (+++)  a b = Either a b
type (***) a b = (a,b)

-- 入射
inj1 :: a -> a +++ b
inj1 = Left

inj2 :: b -> a +++ b
inj2 = Right

-- 射影
prj1 :: a *** b -> a
prj1 = fst

prj2 :: a *** b -> b
prj2 = snd

-- 余積対象の仲介射
coPair :: (a -> c, b -> c) -> (a +++ b -> c)
coPair = uncurry either

-- 積対象の仲介射
pair   :: (c -> a, c -> b) -> (c -> a *** b)
pair = uncurry $ (<*>) . fmap (,)

-- 畳み込み
fol = coPair(id, id)

-- 対角射
dup = pair(id, id)

-- 射同士の余積
(+++) :: (a1 -> b1) -> (a2 -> b2) -> (a1 +++ a2 -> b1 +++ b2)
(+++) f g = coPair(f -: inj1 , g -: inj2)

-- 射同士の積
(***) :: (a1 -> b1) -> (a2 -> b2) -> (a1 *** a2 -> b1 *** b2)
(***) f g =   pair(prj1 -: f, prj2 -: g)

-- Twist の形式的双対
coTw :: a +++ b -> b +++ a
coTw = coPair(inj2, inj1)

-- Twist
tw :: a *** b -> b *** a
tw = pair(prj2, prj1)

-- # Exponential 対象
type (^) b a = a -> b

-- 評価射
ev :: (b ^ a) *** a -> b
ev = uncurry id

-- 射の転置 (transpose) の構成
trans :: (c *** a -> b) -> (c -> b ^ a)
trans = curry

-- 射 h:a->b の Exponential 対象 (Exp b a) の要素への変換
arrToEl :: (a -> b) -> (Pt -> b ^ a)
arrToEl h = trans(prj2-:h)

elToArr x = pair(termArr, id) -: (x *** id) -: ev

-- # 自然数対象 (NNO)
data Nat = Nat{imp::[()]} deriving Eq

_Nat :: Nat -> Nat
_Nat = id

zero :: Pt -> Nat
zero = el (Nat [])

succ' :: Nat -> Nat
succ' (Nat i) = Nat (():i)

-- 整数リテラルを使って NNO の Global elements としての自然数を得るための小細工
nat :: Int -> (Pt -> Nat)
-- nat i = zero -: (foldr (.) id (replicate i succ'))
nat i = el (Nat (replicate i ()))

-- recursion data x_0:1->X と f:X->X から rec_N(x_0, f):Nat->X を構成する関数
rec_N :: (Pt -> a, a -> a) -> (Nat -> a)
rec_N = ((flip ($) ())***id)-:((curry((id***(length.imp))-:uncurry(!!))).(uncurry.flip $ iterate))

add_N = (rec_N(arrToEl(_Nat), trans(ev-:succ'))***_Nat)-:ev
mul_N = (rec_N(arrToEl(termArr-:nat(0)), trans(pair(ev,prj2)-:add_N))***_Nat)-:ev
mul_N_fast (Nat x, Nat y) = Nat (replicate (length x * length y) ()) -- デバッグ用
sq_N  = dup-:mul_N

pred' = rec_N(pair(nat(0),nat(0)), pair(prj2, prj2 -:succ')) -: prj1
sub_N = tw-:(rec_N(arrToEl(_Nat), trans(ev-:pred'))***_Nat)-:ev

true :: Pt -> Nat
true = nat(1)

false :: Pt -> Nat
false = nat(0)

le' :: Nat *** Nat -> Nat
le' = sub_N -: not'

eq' :: Nat *** Nat -> Nat
eq' = pair(le',tw-:le')-:and'

not' :: Nat -> Nat
not' = rec_N(pair(nat(1),nat(0)), pair(prj2, prj2)) -: prj1

and' :: Nat *** Nat -> Nat
and' = mul_N

or' :: Nat *** Nat -> Nat
or' = add_N

if_N :: (Nat *** (a *** a)) -> a
if_N = ((rec_N(inj2,fol-:inj1) -: coPair(arrToEl(prj1),arrToEl(prj2)))***id)-:ev