(* A Staged Implementation of FFT
   Author: Walid Taha, Oleg Kiselyov, Kedar N. Swadi
   Date: Thu Jul 8 20:24:46 CDT 2004
   Problem: See "A Methodology for Generating Verified Combinatorial Circuits" *)

(* Complex numbers are represented as (re, im) pairs throughout: pairs of
   floats in Steps 1-2, pairs of code values in the staged steps. *)

(* Step 1 *)
let pi = 4.0 *. atan(1.0)
(*elided*)

(* Twiddle factor exp(-2*pi*i*dir*j/n), returned as a (re, im) pair.
   dir = 1.0 for the forward transform, -1.0 for the inverse one.
   The inverse is not scaled by 1/n: a forward/inverse round trip
   multiplies the input by n (see the Step 2a tests). *)
let w dir n j =
  let theta = dir *. ((float_of_int (-2 * j)) *. pi) /. (float_of_int n) in
  ((cos theta), (sin theta))

(* Complex addition. *)
let add ((r1,i1), (r2, i2)) = ((r1 +. r2), (i1 +. i2))
(*elided*)

(* Complex subtraction. *)
let sub ((r1,i1), (r2, i2)) = ((r1 -. r2), (i1 -. i2))
(*elided*)

(* Complex multiplication. *)
let mult ((r1, i1), (r2, i2)) =
  let rp = (r1 *. r2) -. (i1 *. i2) in
  let ip = (r1 *. i2) +. (r2 *. i1) in
  (rp, ip)

(* Split a list into its elements at even positions and at odd positions,
   e.g. split [a; b; c; d] = ([a; c], [b; d]).
   Raises Invalid_argument on a list of odd length. *)
let rec split l = match l with
    [] -> ([], [])
  | [_] -> invalid_arg "split: list length must be even"
  | x::y::xs -> let (a,b) = split xs in (x::a, y::b)

(* Butterfly step. l1 is the transform of the even-position inputs and l2
   the transform of the odd-position inputs (assumed to have the same
   length, n/2). For each j, with z = w j * l2[j], output element j is
   l1[j] + z and output element j + n/2 is l1[j] - z. *)
let merge dir l1 l2 =
  let n = 2 * List.length l1 in
  let rec mg l1 l2 j =
    match (l1, l2) with
      (x::xs, y::ys) ->
        let z1 = mult (w dir n j, y) in
        let zx = add (x, z1) in
        let zy = sub (x, z1) in
        let (a,b) = (mg xs ys (j+1)) in
        (zx::a, zy::b)
    | _ -> ([], [])
  in
  let (a,b) = mg l1 l2 0 in
  (a @ b)

(* Radix-2 decimation-in-time FFT of a list of complex numbers.
   The length should be a power of two (split rejects odd lengths above 1).
   Lists of length 0 or 1 are returned unchanged. Testing only for length 1
   would make the empty list recurse forever, because split [] = ([], []). *)
let rec fft dir l =
  if List.length l <= 1 then l
  else
    let (e,o) = split l in
    let y0 = fft dir e in
    let y1 = fft dir o in
    merge dir y0 y1

(* Step 2a: The correctness should be obvious *)
(* List of size n with (1,0) at position l *)
let rec impulse n l =
  if n = 0 then []
  else
    let t = if l = 0 then (1.0,0.0) else (0.0,0.0) in
    t::(impulse (n-1) (l-1))

let () = assert ([(0., 0.); (4., 0.); (0., 0.); (0., 0.)] = fft (-1.0) (fft 1.0 (impulse 4 1)))
let () = assert ([(0., 0.); (8., 0.); (0., 0.); (0., 0.); (0., 0.); (0., 0.); (0., 0.); (0., 0.)] = fft (-1.0) (fft 1.0 (impulse 8 1)))
let test3_u = fft (-1.0) (fft 1.0 (impulse 16 1))

(* -------------------------------------------------------------------- *)
(* Step 2b: Converting the program into monadic form *)
(* A state-passing continuation monad. A computation takes a state s and a
   continuation k, and calls k with the state and its result. ret leaves the
   state unchanged; bind threads the state from one computation to the next.
   y_sm is a fixpoint combinator for open-recursive monadic functions.
   run f size nums reads size floats from nums as size/2 interleaved
   (re, im) pairs and runs f on that list, starting from the empty state [].
   It then writes the resulting list back into nums in place and returns
   nums. *)
let (ret,bind,y_sm, run) =
  let ret a = fun s k -> k s a in
  let bind a f = fun s k -> a s (fun s' b -> f b s' k) in
  let rec y_sm f = f (fun x s k -> y_sm f x s k) in
  let run f size nums =
    (* convert list to array code *)
    let list2array l arr =
      let rec lfta l m = match l with
          [] -> (arr)
        | ((x,y)::z) ->
            let sm = m+1 in
            let ssm = m+2 in
            (((arr).(m) <- x; (arr).(sm) <- y; ((lfta z (ssm)))))
      in lfta l 0 in
    (* convert input array to list format *)
    let array2list arr n k =
      let rec gl m l =
        if m = n then k l arr
        else
          let sm = m+1 in
          (let x1 = (arr).(m) in
           let y1 = (arr).(sm) in
           (gl (m+2) (l @ [(x1, y1)])))
      in gl 0 [] in
    let run1 f arr = f [] (fun s x -> list2array x arr) in
    (array2list nums size (fun l arr2 -> run1 (f l) arr2 ))
  in (ret,bind,y_sm,run)

(* Monadic version of merge. *)
let merge_m dir l1 l2 =
  let n = 2 * List.length l1 in
  let rec mg l1 l2 j =
    match (l1, l2) with
      (x::xs, y::ys) ->
        bind (ret (mult (w dir n j, y))) (fun z1 ->
        bind (ret (add (x, z1))) (fun zx ->
        bind (ret (sub (x, z1))) (fun zy ->
        bind (mg xs ys (j+1)) (fun (a,b) ->
        ret (zx::a, zy::b)))))
    | _ -> ret ([], [])
  in
  bind (mg l1 l2 0) (fun (a,b) -> ret (a @ b))

(* Monadic, open-recursive version of fft: f is the recursive call, to be
   tied with y_sm. Lists of length 0 or 1 are returned unchanged. *)
let fft_m dir f l =
  if List.length l <= 1 then ret l
  else
    let (e,o) = split l in
    bind (f e) (fun y0 ->
    bind (f o) (fun y1 ->
    merge_m dir y0 y1))

(*tests *)
(* Array of 2n floats (n interleaved complex numbers) with 1.0 as the real
   part of element l and 0.0 everywhere else. *)
let impulse n l =
  let arr = Array.make (2 * n) 0.0 in
  let () = arr.(2 * l) <- 1.0 in
  arr

let () = assert ([|0.; 0.; 4.; 0.; 0.; 0.; 0.; 0.;|] =
  let s1 = run (y_sm (fft_m (1.0))) 8 (impulse 4 1) in
  let s2 = run (y_sm (fft_m (-1.0))) 8 s1 in
  s2)
let () = assert ([|0.; 0.; 8.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.; 0.;|] =
  let s1 = run (y_sm (fft_m (1.0))) 16 (impulse 8 1) in
  let s2 = run (y_sm (fft_m (-1.0))) 16 s1 in
  s2)
let test3_u = run (y_sm (fft_m (-1.0))) 32 (run (y_sm (fft_m 1.0)) 32 (impulse 16 1))

(* -------------------------------------------------------------------- *)
(* Step 3,4: Staging the monadic program *)
(* The same monad, now producing code. retS passes a value on as is.
   retN a let-binds the code a in the generated program and passes the bound
   variable on, so the expression is computed once instead of being copied
   into every use. run f size nums generates code for three steps: read size
   floats from the code array nums into variables, apply f, and store the
   resulting code values back into nums. *)
let (retS,retN,bind,y_sm,run) =
  let retS a = fun s k -> k s a in
  let retN a = fun s k -> .<let t = .~a in .~(k s .<t>.)>. in
  let bind a f = fun s k -> a s (fun s' b -> f b s' k) in
  let rec y_sm f = f (fun x s k -> y_sm f x s k) in
  let run f size nums =
    (* convert list to array code *)
    let list2array_s l arr =
      let rec lfta l m = match l with
          [] -> (arr)
        | ((x,y)::z) ->
            let sm = m+1 in
            let ssm = m+2 in
            (.<((.~arr).(m) <- .~x; (.~arr).(sm) <- .~y; (.~(lfta z (ssm))))>.)
      in lfta l 0 in
    (* convert input array to list format *)
    let array2list_s arr n k =
      let rec gl m l =
        if m = n then k l arr
        else
          let sm = m+1 in
          (.<let x1 = (.~arr).(m) in
             let y1 = (.~arr).(sm) in
             .~(gl (m+2) (l @ [(.<x1>., .<y1>.)]))>.)
      in gl 0 [] in
    let run1 f a = f [] (fun s x -> list2array_s x a) in
    (array2list_s nums size (fun l arr2 -> run1 (f l) arr2 ))
  in (retS,retN,bind,y_sm,run)

(* Standard auxiliary monadic operators *)
let liftM f = fun x -> retS (f x)
let compM a b = fun x -> bind (b x) (fun nx -> a nx)
let liftcM op (x,y) = bind (op x) (fun nx -> bind (op y) (fun ny -> retS (nx,ny)))

(* Staged twiddle factor: computed at generation time and embedded in the
   generated code as constants. dir = 1.0 for the forward transform. *)
let w_s dir n j =
  let theta = dir *. ((float_of_int (-2 * j)) *. pi) /. (float_of_int n) in
  let ct = cos theta in
  let st = sin theta in
  (.<ct>., .<st>.)

(* Staged complex addition, subtraction and multiplication. *)
let add_s ((r1,i1), (r2, i2)) = ((.<.~r1 +. .~r2>.), (.<.~i1 +. .~i2>.))
(*elided*)
let sub_s ((r1,i1), (r2, i2)) = ((.<.~r1 -. .~r2>.), (.<.~i1 -. .~i2>.))
(*elided*)
let mult_s ((r1, i1), (r2, i2)) =
  let rp = .<(.~r1 *. .~r2) -. (.~i1 *. .~i2)>. in
  let ip = .<(.~r1 *. .~i2) +. (.~r2 *. .~i1)>. in
  (rp, ip)

(* Staged merge. Every step uses retS, so no let-bindings are generated. *)
let merge_ms dir l1 l2 =
  let n = 2 * List.length l1 in
  let rec mg l1 l2 j =
    match (l1, l2) with
      (x::xs, y::ys) ->
        bind (retS (mult_s (w_s dir n j, y))) (fun z1 ->
        bind (retS (add_s (x, z1))) (fun zx ->
        bind (retS (sub_s (x, z1))) (fun zy ->
        bind (mg xs ys (j+1)) (fun (a,b) ->
        retS (zx::a, zy::b)))))
    | _ -> retS ([], [])
  in
  bind (mg l1 l2 0) (fun (a,b) -> retS (a @ b))

(* Staged, open-recursive fft. Lists of length 0 or 1 are returned unchanged. *)
let fft_ms dir f l =
  if List.length l <= 1 then retS l
  else
    let (e,o) = split l in
    bind (f e) (fun y0 ->
    bind (f o) (fun y1 ->
    merge_ms dir y0 y1))

(* Testing the generated code *)
(* Generated forward transform of 8 complex numbers (16 floats), as a
   function of the array that returns 0 (presumably the form expected by the
   commented-out Trx C translation below; TODO confirm). *)
let s1m = .<fun a -> .~(run (y_sm (fft_ms 1.0)) (8 * 2) .<a>.); 0>.
(* let _ = Trx.C.addToLibrary (Trx.C.toC s1m) "." "test";; *)

(* Generate the transform with direction dir for the size of arr, compile it
   with .! and apply it to arr. The generated code stores its results back
   into the array it is given. *)
let test_ms dir arr =
  let n = Array.length arr in
  let tranc = .<fun a -> .~(run (y_sm (fft_ms dir)) n .<a>.)>. in
  (.!tranc) (arr)

let () = assert ([|0.; 0.; 4.; 0.; 0.; 0.; 0.; 0.;|] = test_ms (-1.0) (test_ms 1.0 (impulse 4 1)))
let () = assert ([|0.; 0.; 8.00000000000512; 0.; 0.; 0.; 0.; 0.; 0.; 0.; -5.11946041115152184e-12; 0.; 0.; 0.; 0.; 0.|] = test_ms (-1.0) (test_ms 1.0 (impulse 8 1)))
let test3_ms = test_ms (-1.0) (test_ms 1.0 (impulse 16 1))

(* Like the first merge_ms, but the product of the twiddle factor and y is
   let-bound with retN. Its code is then computed once instead of being
   copied into both add_s and sub_s. *)
let merge_ms dir l1 l2 =
  let n = 2 * List.length l1 in
  let rec mg l1 l2 j =
    match (l1, l2) with
      (x::xs, y::ys) ->
        bind ((liftcM retN) (mult_s (w_s dir n j, y))) (fun z1 ->
        bind (retS (add_s (x, z1))) (fun zx ->
        bind (retS (sub_s (x, z1))) (fun zy ->
        bind (mg xs ys (j+1)) (fun (a,b) ->
        retS (zx::a, zy::b)))))
    | _ -> retS ([], [])
  in
  bind (mg l1 l2 0) (fun (a,b) -> retS (a @ b))

(* Staged, open-recursive fft using the merge_ms just above.
   Lists of length 0 or 1 are returned unchanged; testing only for length 1
   would recurse forever on [], because split [] = ([], []). *)
let fft_ms dir f l =
  if List.length l <= 1 then retS l
  else
    let (e,o) = split l in
    bind (f e) (fun y0 ->
    bind (f o) (fun y1 ->
    merge_ms dir y0 y1))

let s1m = .<fun a -> .~(run (y_sm (fft_ms 1.0)) 16 .<a>.); 0>.
(* let _ = Trx.C.addToLibrary (Trx.C.toC s1m) "." "test2";; *)

(* Step 4b: Changing some of the implicit retS -> retN *)
(* We still have some code duplication: the code for x is spliced into both
   add_s and sub_s. Note that the code above has a few _implicit_ retS.
   Passing x and y along directly is the same as
     (bind (retS x)) (fun x -> (bind (retS y)) (fun y -> ...
   because bind (retS v) f is just f v.
   So, we replace those implicit retS with retN: *)
let merge_ms1 dir l1 l2 =
  let n = 2 * List.length l1 in
  let rec mg l1 l2 j =
    match (l1, l2) with
      (x::xs, y::ys) ->
        bind ((liftcM retN) x) (fun x ->
        bind ((liftcM retN) y) (fun y ->
        bind ((liftcM retN) (mult_s (w_s dir n j, y))) (fun z1 ->
        bind (retS (add_s (x, z1))) (fun zx ->
        bind (retS (sub_s (x, z1))) (fun zy ->
        bind (mg xs ys (j+1)) (fun (a,b) ->
        retS (zx::a, zy::b)))))))
    | _ -> retS ([], [])
  in
  bind (mg l1 l2 0) (fun (a,b) -> retS (a @ b))

(* Staged, open-recursive fft using merge_ms1.
   Lists of length 0 or 1 are returned unchanged. *)
let fft_ms1 dir f l =
  if List.length l <= 1 then retS l
  else
    let (e,o) = split l in
    bind (f e) (fun y0 ->
    bind (f o) (fun y1 ->
    merge_ms1 dir y0 y1))

let s1m = .<fun a -> .~(run (y_sm (fft_ms1 1.0)) 16 .<a>.); 0>.
(* let _ = Trx.C.addToLibrary (Trx.C.toC s1m) "." "test21";; *)

(* Same as test_ms, but generating code with fft_ms1. *)
let test_ms1 dir arr =
  let n = Array.length arr in
  let tranc = .<fun a -> .~(run (y_sm (fft_ms1 dir)) n .<a>.)>. in
  (.!tranc) (arr)

let () = assert ([|0.; 0.; 4.; 0.; 0.; 0.; 0.; 0.;|] = test_ms1 (-1.0) (test_ms1 1.0 (impulse 4 1)))
let () = assert ([|0.; 0.; 8.00000000000512; 0.; 0.; 0.; 0.; 0.; 0.; 0.; -5.11946041115152184e-12; 0.; 0.; 0.; 0.; 0.|] = test_ms1 (-1.0) (test_ms1 1.0 (impulse 8 1)))
let test3_ms1 = test_ms1 (-1.0) (test_ms1 1.0 (impulse 16 1))

(* -------------------------------------------------------------------- *)
(* Step 5. Eliminating trivial bindings. If the code value is simple, *)
(* don't generate the corresponding binding. *)

(* An abstract domain that remembers if a value is cheap to duplicate *)
(* Val: code that is cheap to duplicate (e.g. a bound variable or a constant);
   Exp: any other code expression. *)
type 'a maybeValue = Val of ('a, float) code | Exp of ('a, float) code

(* The concretization function *)
let mVconc = function (Val x) -> x | (Exp x) -> x

(* Twiddle factor as maybeValue constants (cheap to duplicate).
   dir = 1.0 for the forward transform. *)
let w_sv dir n j =
  let theta = dir *. ((float_of_int (-2 * j)) *. pi) /. (float_of_int n) in
  let ct = cos theta in
  let st = sin theta in
  (Val .<ct>., Val .<st>.)

(* Helper functions for use in abstract operations *)
let mV_add x y =
  let xc = mVconc x in
  let yc = mVconc y in
  Exp .<.~xc +. .~yc>.
let mV_sub x y =
  let xc = mVconc x in
  let yc = mVconc y in
  Exp .<.~xc -. .~yc>.

let mV_mul x y =
  let xc = mVconc x in
  let yc = mVconc y in
  Exp .<.~xc *. .~yc>.

(* Complex addition, subtraction and multiplication over maybeValue. *)
let add_sv ((r1,i1), (r2, i2)) = (mV_add r1 r2, mV_add i1 i2)
(*elided*)
let sub_sv ((r1,i1), (r2, i2)) = (mV_sub r1 r2, mV_sub i1 i2)
(*elided*)
let mult_sv ((r1, i1), (r2, i2)) =
  let rp = mV_sub (mV_mul r1 r2) (mV_mul i1 i2) in
  let ip = mV_add (mV_mul r1 i2) (mV_mul r2 i1) in
  (rp, ip)

(* Let-bind only what needs it: a Val is passed on as is, while an Exp is
   let-bound with retN and the bound variable is passed on as a Val. *)
let retN_v = function
    Val _ as v -> retS v
  | Exp _ as x -> bind (retN (mVconc x)) (fun x -> retS (Val x))

let merge_mv dir l1 l2 =
  let n = 2 * List.length l1 in
  let rec mg l1 l2 j =
    match (l1, l2) with
      (x::xs, y::ys) ->
        bind ((liftcM retN_v) x) (fun x ->
        bind ((liftcM retN_v) y) (fun y ->
        bind ((liftcM retN_v) (mult_sv (w_sv dir n j, y))) (fun z1 ->
        bind (retS (add_sv (x, z1))) (fun zx ->
        bind (retS (sub_sv (x, z1))) (fun zy ->
        bind (mg xs ys (j+1)) (fun (a,b) ->
        retS (zx::a, zy::b)))))))
    | _ -> retS ([], [])
  in
  bind (mg l1 l2 0) (fun (a,b) -> retS (a @ b))

(* Open-recursive fft over maybeValue. Lists of length 0 or 1 are returned
   unchanged; testing only for length 1 would recurse forever on [],
   because split [] = ([], []). *)
let fft_mv dir f l =
  if List.length l <= 1 then retS l
  else
    let (e,o) = split l in
    bind (f e) (fun y0 ->
    bind (f o) (fun y1 ->
    merge_mv dir y0 y1))

(* Adapt run to maybeValue: the input variables are wrapped as Val and the
   results are concretized back to plain code. *)
let run_v f =
  let unmap l = List.map (fun (x,y) -> (mVconc x,mVconc y)) l in
  let f1 l l' k = f (List.map (fun (x,y) -> (Val x,Val y)) l) l' (fun s l -> k s (unmap l)) in
  run f1

let s1v = .<fun a -> .~(run_v (y_sm (fft_mv 1.0)) 16 .<a>.); 0>.
(* let _ = Trx.C.addToLibrary (Trx.C.toC s1v) "." "test5";; *)

(* Generate the fft_mv transform for the size of arr, compile it and apply
   it to arr. *)
let test_mv dir arr =
  let n = Array.length arr in
  let tranc = .<fun a -> .~(run_v (y_sm (fft_mv dir)) n .<a>.)>. in
  (.!tranc) (arr)

let () = assert ([|0.; 0.; 4.; 0.; 0.; 0.; 0.; 0.;|] = test_mv (-1.0) (test_mv 1.0 (impulse 4 1)))
let () = assert ([|0.; 0.; 8.00000000000512; 0.; 0.; 0.; 0.; 0.; 0.; 0.; -5.11946041115152184e-12; 0.; 0.; 0.; 0.; 0.|] = test_mv (-1.0) (test_mv 1.0 (impulse 8 1)))
let test3_mv = test_mv (-1.0) (test_mv 1.0 (impulse 16 1))

(* Now we notice that all trivial bindings are gone. But multiplication by
   1., 0. and 6.12303176911e-17 (presumably the rounding residue of
   cos(pi/2), which should be 0) is still present. We need the last step:
   true abstract interpretation of multiplication and addition. *)

(* -------------------------------------------------------------------- *)
(* Step 6 : Improving accuracy and making complex operations smarter *)

(* abstract domain *)
(* Lit c: a constant known at generation time.
   Any (f, x): the value f *. x, with the factor f known at generation time. *)
type 'a abstract_code = Lit of float | Any of float * 'a maybeValue

(* concretization function (from abstract_code to maybeValue) *)
let conc = function
    (Lit x) -> Val .<x>.
  | Any (1.0,x) -> x
  | Any (-1.0,x) -> Exp .<-. .~(mVconc x)>.
  | Any (factor,x) -> Exp .<factor *. .~(mVconc x)>.

(* Twiddle factor exp(dir * -2*pi*i*j/n), where dir is +1 or -1.
   The cases j = 0, n/2, n/4, 3n/4 and odd multiples of n/8 get exact values
   (0, 1, -1, or +-cos(pi/4)) instead of computed cos/sin. *)
let w_a dir n j =
  if j = 0 then (Lit 1.0, Lit 0.0)
  else if 2*j = n then (* exp(dir* -PI I) *)
    (Lit (-1.0), Lit 0.0)
  else if 4*j = n then (* exp( dir* -PI/2 I) *)
    (Lit 0.0, Lit (-. dir))
  else if 4*j = 3*n then (* exp( dir* -3*PI/2 I) *)
    (Lit 0.0, Lit dir)
  else if 8*j mod n = 0 then
    (* 8j/n must be odd *)
    let quadrant = ((8*j / n) - 1)/2
    and (* unnorm *) cos_signs = [| 1.0; -1.0; -1.0; 1.0 |]
    and sin_signs = [| 1.0; 1.0; -1.0;-1.0 |]
    and csh = cos (pi /. 4.0) in
    let quadrant = if dir = -1.0 then quadrant else 3 - quadrant in
    (Lit (csh *. cos_signs .(quadrant)), Lit (csh *. sin_signs .(quadrant)))
  else
    let theta = dir *. ((float_of_int (-2 * j)) *. pi) /. (float_of_int n) in
    (Lit (cos theta), Lit (sin theta))

(* Lifting from maybeValue to abstract_code *)
let lifta x = Any (1.0, x)

(* Abstract addition. Adding Lit 0.0 is a no-op and two literals are added
   at generation time. Factors 1.0 and -1.0 turn into plain additions and
   subtractions, and equal factors are kept symbolic (f*x + f*y = f*(x+y)).
   Any other factor is concretized first. *)
let rec add_a (n1, n2) =
  let signf f = if f > 0.0 then 1.0 else -1.0 in
  let setf f (Any (_,v)) = Any (f,v) in
  match (n1, n2) with
    (Lit 0.0,x) -> x
  | (x, Lit 0.0) -> x
  | (Lit x, Lit y) -> (Lit (x +. y))
  | (Lit _, Any (1.0,y)) -> lifta (mV_add (conc n1) y)
  | (Lit _, Any (-1.0,y)) -> lifta (mV_sub (conc n1) y)
  | (Lit _, Any (factor,y)) ->
      if factor > 0.0 then add_a (n1,lifta (conc n2))
      else
        let abs_factor = abs_float factor in
        add_a (n1, Any (-1.0,conc (setf abs_factor n2)))
  | (Any _, Lit _) -> add_a (n2,n1)
  | (Any (1.0,x), Any (-1.0,y)) -> lifta (mV_sub x y)
  | (Any (-1.0,x), Any (1.0,y)) -> lifta (mV_sub y x)
  | (Any (fx,x), Any (fy,y)) ->
      if fx = fy then Any (fx,mV_add x y)
      else if abs_float fx = abs_float fy then
        setf (abs_float fx) (add_a ((setf (signf fx) n1), (setf (signf fy) n2)))
      else
        add_a (Any ((signf fx),(conc (setf (abs_float fx) n1))),
               Any ((signf fy),(conc (setf (abs_float fy) n2))))

(* Abstract negation: flips the literal or the factor; generates no code. *)
let neg_a n = match n with
    Lit x -> Lit (-. x)
  | Any (s,x) -> Any (-. s,x)

(* Abstract multiplication. Multiplying by Lit 0.0 gives Lit 0.0, literals
   are multiplied at generation time, and 1.0 / -1.0 are identity and
   negation. A literal is folded into the factor of an Any. *)
let rec mul_a (n1,n2) =
  match (n1,n2) with
    (Lit 0.0,x) -> Lit 0.0
  | (x,Lit 0.0) -> Lit 0.0
  | (Lit x,Lit y) -> Lit (x *. y)
  | (Lit 1.0,x) -> x
  | (Lit (-1.0),x) -> neg_a x
  | (Lit x,Any (f,y)) -> Any (x *. f,y)
  | (_,Lit _) -> mul_a (n2,n1)
  | (Any (fx,x),Any(fy,y)) -> Any (fx *. fy,mV_mul x y)

(* Complex multiplication, subtraction and addition over abstract_code. *)
let mul_u_a ((r1, i1), (r2, i2)) =
  let rp = add_a(mul_a (r1,r2),neg_a(mul_a (i1,i2))) in
  let ip = add_a(mul_a (r1,i2),mul_a(r2,i1)) in
  (rp, ip)

let sub_u_a ((r1, i1), (r2, i2)) = (add_a (r1, neg_a r2), add_a (i1, neg_a i2))

let add_u_a ((r1, i1), (r2, i2)) = (add_a (r1, r2), add_a (i1, i2))

(* Let-bind an abstract value. A negative factor sign is kept at generation
   time. The remaining magnitude is concretized and bound via retN_v, which
   skips values that are cheap to duplicate. The result is re-wrapped as
   Any with that sign; a Lit comes back as Any (1.0, ...). *)
let retN_va a =
  let take_sign a = match a with
      Any (f,x) -> if f > 0.0 then (1.0,a) else (-1.0,Any (abs_float f,x))
    | _ -> (1.0,a) in
  let lifta_s s x = Any (s,x) in
  let (s,na) = take_sign a in
  (compM (liftM (lifta_s s)) (compM retN_v (liftM conc))) na

let merge_a dir (l1, l2) =
  let n = 2 * List.length l1 in
  let rec mg l1 l2 j =
    match (l1, l2) with
      (x::xs, y::ys) ->
        bind ((liftcM retN_va) x) (fun x ->
        bind ((liftcM retN_va) y) (fun y ->
        bind ((liftcM retN_va) (mul_u_a (w_a dir n j, y))) (fun z1 ->
        bind (retS (add_u_a (x, z1))) (fun zx ->
        bind (retS (sub_u_a (x, z1))) (fun zy ->
        bind ((mg xs ys (j+1))) (fun (a,b) ->
        retS (zx::a, zy::b)))))))
    | _ -> retS ([], [])
  in
  bind (mg l1 l2 0) (fun (a,b) -> retS (a @ b))

(* Open-recursive fft over abstract_code. Lists of length 0 or 1 are
   returned unchanged; testing only for length 1 would recurse forever on
   [], because split [] = ([], []). *)
let fft_a dir f l =
  if List.length l <= 1 then retS l
  else
    let (e,o) = split l in
    bind (f e) (fun y0 ->
    bind (f o) (fun y1 ->
    merge_a dir (y0, y1)))

(* Adapt run to abstract_code: input variables become Any (1.0, Val x) and
   results are concretized back to plain code. *)
let run_a f =
  let unmap l = List.map (fun (x,y) -> (mVconc (conc x),mVconc (conc y))) l in
  let f1 l l' k = f (List.map (fun (x,y) -> (lifta (Val x), lifta (Val y))) l) l' (fun s l -> k s (unmap l)) in
  run f1

let s1a = .<fun a -> .~(run_a (y_sm (fft_a 1.0)) 8 .<a>.); 0>.

(* Generate the fft_a transform for the size of arr, compile it and apply
   it to arr. *)
let test_a dir arr =
  let n = Array.length arr in
  let tranc = .<fun a -> .~(run_a (y_sm (fft_a dir)) n .<a>.)>. in
  (.!tranc) (arr)

let () = assert ([|0.; 0.; 4.; 0.; 0.; 0.; 0.; 0.;|] = test_a (-1.0) (test_a 1.0 (impulse 4 1)))
let () = assert ([|0.; 0.; 8.00000000000512; 0.; 0.; 0.; 0.; 0.; 0.; 0.; -5.11946041115152184e-12; 0.; 0.; 0.; 0.; 0.|] = test_a (-1.0) (test_a 1.0 (impulse 8 1)))
let test3_a = test_a (-1.0) (test_a 1.0 (impulse 16 1))

(* -------------------------------------------------------------------- *)
let () = print_string "\nDone\n" ;;