雑記帳
僕用勉強ノート 「レイトレーシング」の巻

Haskell でレイトレーシングのチュートリアルを追いかける その10 - アンチエイリアス (修正版)

小ネタとして4次方程式の解を求める関数を Haskell で書いてみた中で思ったことの1つとして「計算コストは確かに高いけど、実用として使える範囲内かも」というのがあった。
というわけで、「4次方程式の解の一般式を直接用いてトーラスの方程式を解く」というアプローチに切り替えたバージョンのHaskell製レイトレプログラムを、レイトレシリーズの続きとして書いてみることにした。
とはいえ、直近の「section 10-4 Schlick Approximation」からの続きとしてリニューアルしたコードを作るのはなんか違う気がするので、一旦よりシンプルな「section 7」まで戻った上でコードを書き換えた。
主な変更点は、
  • (先に述べた通り) 数値解析ではなく直接的に方程式の解を算出するようにした。
  • 座標軸の方向を、工学の人間にとって馴染み深そうなものに切り替えた。(xy-平面が地平面、z-軸が地平面の法線方向)
  • コードの意味が汲み取りにくいと思われる箇所を、わかりやすくなるように書き換えた。
  • 余積対象に相当する型を自前で定義するのではなく、ビルトインの Either a b 型を使用するようにした。
コードの実行結果
実行結果
(地面の色やトーラスの色が修正前のものと異なるのは、座標軸の取り方の変更によるもの。)
ソースコード
{-# LANGUAGE TypeOperators #-}

module Main where

import Data.Complex
import Control.Lens
import System.Random
import Linear.Vector
import Linear.Metric
import Linear.V3

-- https://raytracing.github.io/books/RayTracingInOneWeekend.html section 7 --

main :: IO ()
main = do
  let
    -- Image
    aspect_ratio = 16.0 / 9
    image_width = 256
    image_height = round $ fromIntegral image_width / aspect_ratio
    samples_per_pixel = 100
    -- World
    world = []
      `add` RT_Torus{
        centerOfTorus = V3 0 (-1) 0,
        majorRadius = 0.35,
        minorRadius = 0.15,
        orientationOfTorus = normalize $ V3 (-0.2) 1.9 1.2
      }
      `add` RT_Sphere{center = V3 0 (-1) (-10000.5), radius = 10000}
    -- Camera
    camera = Camera {
      viewport_height = 2.0,
      viewport_width = aspect_ratio * viewport_height camera,
      focal_length = 1.0,
      origin = zero,
      horizontal = viewport_width camera *^ unit _x,
      vertical = viewport_height camera *^ unit _z,
      lower_left_corner =
        origin camera - horizontal camera ^/2 - vertical camera ^/2
        - focal_length camera *^ unit _y
      }
    img_data = "P3\n" ++ show image_width ++ " " ++ show image_height ++ "\n255\n"

  putStr $ img_data


  foldr (>>) (return ()) $ do
    let
      indices = [image_height - 1, image_height - 2 .. 0] `prod` [0 .. image_width - 1]
      seeds = (randomRs (0, 536870912) (mkStdGen 21) :: [Int])
    ((j,i), seed) <- zip indices seeds
    return $ do
        let
          rnds = take (2*samples_per_pixel) $ (randoms (mkStdGen seed) :: [Double])
          pixcel_color = foldr (+) 0 $ do
            s <- [0 .. samples_per_pixel - 1]
            let
              randNum1 = rnds !! (2*s + 0)
              randNum2 = rnds !! (2*s + 1)
              u = (fromIntegral i + randNum1) / (fromIntegral image_width - 1.0)
              v = (fromIntegral j + randNum2) / (fromIntegral image_height - 1.0)
              r = get_ray camera (u, v)
            return $ ray_color r world
        write_color pixcel_color samples_per_pixel


data Ray = Ray {
  orig :: V3 Double,
  dir :: V3 Double
} deriving (Show)


at' :: Ray -> Double -> V3 Double
at' r t = (orig r) + t *^ (dir r)


data Camera = Camera {
  viewport_height :: Double,
  viewport_width :: Double,
  focal_length :: Double,
  origin :: V3 Double,
  horizontal :: V3 Double,
  vertical :: V3 Double,
  lower_left_corner :: V3 Double
} deriving (Show)

get_ray :: Camera -> (Double, Double) -> Ray
get_ray this (u, v) =
  Ray {
    orig = origin this,
    dir = lower_left_corner this + u *^ horizontal this + v *^ vertical this - origin this
    }



type HittableData = (RT_Sphere + RT_Torus) + RT_Sphere

class Hittable a where
  toSum :: a -> HittableData
  hit :: a -> Ray -> Double -> Double -> Maybe HitRecord

instance (Hittable a, Hittable b) => Hittable (Either a b) where
  toSum = coPair(toSum, toSum)
  hit = coPair(hit, hit)

add :: Hittable a => [HittableData] -> a -> [HittableData]
add list obj = (toSum obj) : list

data HitRecord = HitRecord {
  p :: V3 Double,
  normal :: V3 Double,
  t :: Double,
  front_face :: Bool
} deriving (Show)


set_face_normal :: HitRecord -> Ray -> V3 Double -> HitRecord
set_face_normal this r outward_normal =
  HitRecord {
    p = p this,
    normal = if dir r `dot` outward_normal < 0 then outward_normal else -outward_normal,
    t = t this,
    front_face = (dir r `dot` outward_normal < 0)
  }

hitSomething :: [HittableData] -> Ray -> Double -> Double -> Maybe HitRecord
hitSomething list r t_min t_max =
  let
    f (list', r', closest_so_far, currentRecord) =
      case list' of
        x:xs ->
          let
            temp = hit x r' t_min t_max
          in
            case temp of
              Just a ->
                f $ (xs, r', t a, temp)
              Nothing ->
                f $ (xs, r', closest_so_far, currentRecord)
        [] ->
          currentRecord

  in
    f $ (list, r, t_max, Nothing)

-- Sphere
data RT_Sphere = RT_Sphere {
  center :: V3 Double,
  radius :: Double
} deriving (Show)

instance Hittable RT_Sphere where
  toSum = inj1 -: inj1
  hit obj r t_min t_max =
    let
      p0 = orig r
      c1 = center obj
      r1 = radius obj
      oc = p0 - c1
      a = quadrance (dir r)
      half_b = oc `dot` dir r
      c = quadrance oc - (radius obj) ^ 2
      discriminant = half_b ^ 2 - a*c in

      if discriminant > 0 then
        let
          root = sqrt discriminant
          f k =
            case k of
              x:xs ->
                if t_min < x && x < t_max then
                  return $ set_face_normal HitRecord {
                    p = at' r x,
                    normal = zero,
                    t = x,
                    front_face = False
                    } r ((at' r x - c1) ^/ r1)
                else
                  f $ xs

              [] ->
                Nothing

        in
          f $ [(-half_b - root) / a, (-half_b + root) / a]
      else
        Nothing

-- Torus
data RT_Torus = RT_Torus {
  centerOfTorus :: V3 Double,
  majorRadius :: Double,
  minorRadius :: Double,
  orientationOfTorus :: V3 Double
} deriving (Show)

instance Hittable RT_Torus where
  toSum = inj2 -: inj1
  hit obj r t_min t_max =
    let
      p0 = orig r
      a  = dir r
      c1 = centerOfTorus obj
      r1 = majorRadius obj
      r2 = minorRadius obj
      n = orientationOfTorus obj
      s = getIntersection_forTorus (p0,a,c1,r1,r2,n)
      oc = p0 - c1
      a_sq  = quadrance (dir r)
      half_b = oc `dot` dir r
      c = quadrance oc - (r1 + r2 + 0.01) ^ 2
      discriminant = half_b ^ 2 - a_sq*c
    in
      if discriminant > 0 then
        if null s then
          Nothing
        else
          let
            k = minimum s
            x = at' r k - c1
          in
            if t_min < k && k < t_max then
              return $ HitRecord {
                  p = c1 + x,
                  normal = (x - (r1 *^ (normalize $ x - (n `dot` x) *^ n))) ^/ r2,
                  t = k,
                  front_face = False
                  }
            else
              Nothing
      else
        Nothing


write_color :: RealFrac a => V3 a -> Int -> IO ()
write_color v samples_per_pixel =
  let
    v' = v ^/ fromIntegral samples_per_pixel
    f = show.floor.(255.999*)
  in
    putStr $ f(v' ^._x) ++ " " ++ f(v' ^._y) ++ " " ++ f(v' ^._z) ++ "\n"


ray_color :: Ray -> [HittableData] -> V3 Double
ray_color r objects =
  let
    record = hitSomething objects r 0 infinity
  in
    case record of
      Just a ->
        0.5 *^ ((normal a) + (V3 1 1 1))

      Nothing ->
        let
          unit_direction = normalize $ (dir r)
          s = 0.5 * (unit_direction ^._z + 1.0)
        in
          lerp s (V3 0.5 0.7 1.0) (V3 1.0 1.0 1.0)



infinity :: RealFloat a => a
infinity = encodeFloat (floatRadix 0 - 1) (snd $ floatRange 0)

deg2rad :: Floating a => a -> a
deg2rad degrees = degrees * pi / 180

clamp :: (Ord a, Num a) => a -> a -> a -> a
clamp x y val = (max x).(min y) $ val


getIntersection_forTorus :: (V3 Double, V3 Double, V3 Double, Double, Double, V3 Double) -> [Double]
getIntersection_forTorus = solveQuarticEq . genCoefficients

genCoefficients (x0,a,c,r1,r2,n) = (b4,b3,b2,b1,b0)
  where
    d0 = x0 - c
    k = (r1^2) - (r2^2)
    a_sq = quadrance a
    d0_sq = quadrance d0

    b4 = a_sq^2                                       
    b3 = 4*(d0 `dot` a)*a_sq                          
    b2 = 2*d0_sq*a_sq+4*((d0 `dot` a)^2) + 2*k*a_sq - 4*(r1^2)*a_sq         + 4*(r1^2)*(n `dot` a)^2
    b1 = 4*d0_sq*(d0 `dot` a)+4*k*(d0 `dot` a)      - 8*(r1^2)*(d0 `dot` a) + 8*(r1^2)*(n `dot` d0)*(n `dot` a)
    b0 = d0_sq*d0_sq+2*k*d0_sq+k^2                  - 4*(r1^2)*d0_sq        + 4*(r1^2)*(n `dot` d0)^2

solveQuarticEq (a4,a3,a2,a1,a0) =
  let
    sol = do
      (x_Re :+ x_Im) <- [x1,x2,x3,x4]
      if (abs(x_Im) < 1.0E-9) && (1.0E-9 <= x_Re) then
        return x_Re
      else
        []
  in
    sol
  where
    l1 = (toCmp $ k3/4)/sqrt(k4)
    l2 = (toCmp $ (cbrt(2)*k5)/(3*a4))/k8 + k8/(toCmp $ 3*cbrt(2)*a4)
    l3 = (toCmp $ (a3^2)/(2*a4^2) - (4*a2)/(3*a4)) - l2
    k1 = l1 + l3
    k2 = -l1 + l3
    k3 = -((a3/a4)^3) + (4*a2*a3)/(a4^2) - (8*a1)/a4
    k4 = (toCmp $ ((a3/(2*a4))^2) - (2*a2)/(3*a4)) + l2
    k5 = a2^2 - 3*a1*a3 + 12*a0*a4
    k6 = 2*a2^3 - 9*a1*a2*a3 + 27*a0*a3^2 + 27*a1^2*a4 - 72*a0*a2*a4
    k7 = -4*k5^3 + k6^2
    k8 = cbrt((toCmp $ k6) + sqrt(toCmp $ k7))

    l4 = toCmp $ -a3/(4*a4)
    l5 = sqrt(k2)/2
    l6 = sqrt(k1)/2
    l7 = sqrt(k4)/2

    x1 = l4 - l5 - l7
    x2 = l4 + l5 - l7
    x3 = l4 - l6 + l7
    x4 = l4 + l6 + l7

cbrt x = x ** (1/3)
toCmp x = x :+ 0
prod x y = x >>= (\u -> zip (repeat u) y)


(-:) = flip (.)

type (+)  a b = Either a b

inj1 :: a -> a + b
inj1 = Left

inj2 :: b -> a + b
inj2 = Right

coPair :: (a1 -> b, a2 -> b) -> (a1 + a2 -> b)
coPair = uncurry either