I'm trying to port this program, which symbolically computes the $n$th derivative of $x^x$, to Rust. It seems to be mostly easy:
use std::rc::Rc;
/// Handle to an immutable expression node. Nodes are reference counted, so the
/// same sub-expression can appear in several larger expressions without copying.
type Expr = Rc<Expr2>;

/// A symbolic expression over integer constants and named variables.
///
/// Patterns cannot look through the `Rc` in `Expr`, so code matching nested
/// shapes matches on `&*expr` (or `**child`) rather than on the `Rc` itself.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Expr2 {
    /// Integer constant.
    Int(i32),
    /// Named variable, e.g. `x`.
    Var(String),
    /// Sum `f + g`.
    Add(Expr, Expr),
    /// Product `f * g`.
    Mul(Expr, Expr),
    /// Power `f ^ g`.
    Pow(Expr, Expr),
    /// Natural logarithm `ln f`.
    Ln(Expr),
}
use Expr2::*;
/// Integer power `a^b` by recursive squaring.
///
/// Intended for `b >= 0`. Because `/` and `%` truncate toward zero, a negative
/// `b` produces `a^|b|` rather than a reciprocal. Overflow follows ordinary
/// `i32` arithmetic (a panic in debug builds).
fn pown(a: i32, b: i32) -> i32 {
    if b == 0 {
        return 1;
    }
    if b == 1 {
        return a;
    }
    let half = pown(a, b / 2);
    let squared = half * half;
    // Odd exponents need one extra factor of the base.
    if b % 2 != 0 { squared * a } else { squared }
}
fn add(f: Expr, g: Expr) -> Expr {
match (f, g) {
(Int(m), Int(n)) => Int(m + n),
(Int(0), f) => f,
(f, Int(n)) => add(Int(n), f),
(f, Add(Int(n), g)) => add(Int(n), add(f, g)),
(Add(f, g), h) => add(f, add(g, h)),
(f, g) => Add(f, g),
}
}
fn mul(f: Expr, g: Expr) -> Expr {
match (f, g) {
(Int(m), Int(n)) => Int(m * n),
(Int(0), f) => Int(0),
(Int(1), f) => f,
(f, Int(n)) => mul(Int(n), f),
(f, Mul(Int(n), g)) => mul(Int(n), mul(f, g)),
(Mul(f, g), h) => mul(f, mul(g, h)),
(f, g) => Mul(f, g),
}
}
fn pow(f: Expr, g: Expr) -> Expr {
match (f, g) {
(Int(m), Int(n)) => Int(pown(m, n)),
(f, Int(0)) => Int(1),
(f, Int(1)) => f,
(Int(0), f) => Int(1),
(f, g) => Pow(f, g),
}
}
fn ln(f: Expr) -> Expr {
match f {
Int(1) => Int(0),
f => Ln(f),
}
}
fn d(x: String, f: Expr) -> Expr {
match f {
Int(_) => Int(0),
Var(y) => if x == y {
x
} else {
y
},
Add(f, g) => add(d(x, f), d(x, g)),
Mul(f, g) => add(mul(f, d(x, g)), mul(g, d(x, f))),
Pow(f, g) => mul(
pow(f, g),
add(mul(mul(g, d(x, f)), pow(f, Int(-1))), mul(ln(f), d(x, g))),
),
Ln(f) => mul(d(x, f), pow(f, Int(-1))),
}
}
fn count(f: Expr) -> i32 {
match f {
Int(_) | Var(_) => 1,
Add(f, g) | Mul(f, g) | Pow(f, g) => count(f) + count(g),
Ln(f) => count(f),
}
}
fn string_of_expr(f: Expr) -> String {
count(f).to_string();
}
/// Applies `f` to `x` `n` times and returns the result, so
/// `nest(2, f, x) == f(f(x))`.
///
/// A zero or negative `n` returns `x` unchanged. Works for any value type;
/// in this program it is used as `nest(n, deriv, expr)`.
fn nest<T, F: FnMut(T) -> T>(n: i32, mut f: F, x: T) -> T {
    // A fold instead of recursion: the stack does not grow with `n`, and a
    // negative `n` cannot recurse forever.
    (0..n).fold(x, |acc, _| f(acc))
}
fn deriv(f: Expr) -> Expr {
let df = d("x", f);
format!("D({}) = {}", string_of_expr(f), string_of_expr(df));
df
}
fn main() {
let x = "x";
let f = pow(x, x);
// FIXME: Read command-line argument
let df = nest(9, deriv, f);
format!("{}", count(df));
}
The type needs to be converted into a reference-counted enum in Rust, and pattern matching makes for very similar code, except... it doesn't work. From what I can gather, patterns in Rust cannot match on the result of dereferencing an `Rc`. So, no matter what I do, it fails on nested patterns like `(f, Add(Int(n), g))`.
Am I missing something, or is it really impossible for nested patterns to match over recursive data types in Rust? Apparently there is something called "box syntax" that allows dereferencing inside a pattern (amongst other things), but it has been on the drawing board for four years.