2009-10-27 8 views
5

Estoy tratando de encontrar una forma elegante de manejar algunos polinomios generados. Esta es la situación que nos centraremos en (exclusivamente) para esta pregunta:Métodos generados para la evaluación de polinomios

  1. fin es un parámetro en la generación de una n º polinomio de orden, donde n: = + 1. Para
  2. i es un parámetro entero en el rango 0..n
  3. El polinomio tiene ceros en x_j, donde j = 1..n y j ≠ i (debe estar claro en este punto que StackOverflow necesita una nueva característica o está presente y no lo sé)
  4. Evaluación polinomial es a 1 en x_i.

Dado que este ejemplo de código en particular genera x_1 .. x_n, explicaré cómo se encuentran en el código. Los puntos están espaciados uniformemente x_j = j * elementSize/order aparte, donde n = order + 1.

Genero un Func<double, double> para evaluar este polinomio¹.

private static Func<double, double> GeneratePsi(double elementSize, int order, int i) 
{ 
    if (order < 1) 
     throw new ArgumentOutOfRangeException("order", "order must be greater than 0."); 

    if (i < 0) 
     throw new ArgumentOutOfRangeException("i", "i cannot be less than zero."); 
    if (i > order) 
     throw new ArgumentException("i", "i cannot be greater than order"); 

    ParameterExpression xp = Expression.Parameter(typeof(double), "x"); 

    // generate the terms of the factored polynomial in form (x_j - x) 
    List<Expression> factors = new List<Expression>(); 
    for (int j = 0; j <= order; j++) 
    { 
     if (j == i) 
      continue; 

     double p = j * elementSize/order; 
     factors.Add(Expression.Subtract(Expression.Constant(p), xp)); 
    } 

    // evaluate the result at the point x_i to get scaleInv=1.0/scale. 
    double xi = i * elementSize/order; 
    double scaleInv = Enumerable.Range(0, order + 1).Aggregate(0.0, (product, j) => product * (j == i ? 1.0 : (j * elementSize/order - xi))); 

    /* generate an expression to evaluate 
    * (x_0 - x) * (x_1 - x) .. (x_n - x)/(x_i - x) 
    * obviously the term (x_i - x) is cancelled in this result, but included here to make the result clear 
    */ 
    Expression expr = factors.Skip(1).Aggregate(factors[0], Expression.Multiply); 
    // multiplying by scale forces the condition f(x_i)=1 
    expr = Expression.Multiply(Expression.Constant(1.0/scaleInv), expr); 

    Expression<Func<double, double>> lambdaMethod = Expression.Lambda<Func<double, double>>(expr, xp); 
    return lambdaMethod.Compile(); 
} 

El problema: que también tienen que evaluar ψ '= dψ/dx. Para hacer esto, puedo reescribir ψ = escala × (x_0 - x) (x_1 - x) × .. × (x_n - x)/(x_i - x) en la forma ψ = α_n × x^n + α_n × x^(n-1) + .. + α_1 × x + α_0. Esto da ψ '= n × α_n × x^(n-1) + (n-1) × α_n × x^(n-2) + .. + 1 × α_1.

Por razones de cálculo, podemos reescribir la respuesta final sin llamadas a Math.Pow escribiendo ψ '= x × (x × (x × (..) - β_2) - β_1) - β_0.

a hacer todo este "engaño" (todo el álgebra muy básico), necesito una manera limpia de:

  1. expandir una factorizado Expression contiene ConstantExpression y ParameterExpression hojas y las operaciones matemáticas básicas (terminan ya sea BinaryExpression con el NodeType configurado para la operación) - el resultado aquí puede incluir elementos InvocationExpression al MethodInfo para Math.Pow que manejaremos de manera especial en todo momento.
  2. Luego tomo la derivada con respecto a algunos ParameterExpression especificados. Los términos en el resultado donde el parámetro del lado derecho de una invocación de Math.Pow era la constante 2 son reemplazados por el ConstantExpression(2) multiplicado por lo que era el lado izquierdo (se elimina la invocación de Math.Pow(x,1)). Los términos en el resultado que se vuelven cero porque fueron constantes con respecto a x se eliminan.
  3. A continuación, factorizar las instancias de algunos ParameterExpression específicos donde ocurren como el parámetro del lado izquierdo de una invocación de Math.Pow. Cuando el lado derecho de la invocación se convierte en ConstantExpression con el valor 1, reemplazamos la invocación con solo el ParameterExpression.

¹ En el futuro, me gustaría que el método para tomar una ParameterExpression y devuelve un Expression que evalúa basa en ese parámetro. De esa forma puedo agregar funciones generadas. No estoy allí todavía. ² En el futuro, espero lanzar una biblioteca general para trabajar con LINQ Expressions como matemática simbólica.

+4

+1 para mí perder después de 5 líneas ... debe ser muy una pregunta inteligente;) –

+0

Yo, por otro lado, entiendo todas las matemáticas y no sé nada sobre LINQ! Parece que ya tienes tus algoritmos prácticamente resueltos. ¡Y buena suerte con esa biblioteca! – Cascabel

+0

@Jefromi: puedo generar un árbol de expresiones bien. Lo que quiero construir es una forma elegante de transformar los árboles, tratándolos como expresiones de matemática simbólica. :) –

Respuesta

6

Escribí los conceptos básicos de varias características simbólicas de matemáticas usando el tipo ExpressionVisitor en .NET 4. No es perfecto, pero parece la base de una solución viable.

  • Symbolic es una clase estática pública exponiendo métodos como Expand, Simplify y PartialDerivative
  • ExpandVisitor es un tipo de ayuda interna que se expande expresiones
  • SimplifyVisitor es un tipo de ayuda interna que simplifica expresiones
  • DerivativeVisitor es un tipo de ayuda interna que toma la derivada de una expresión
  • ListPrintVisitor es un helper de tipo interno que convierte una Expression a una notación de prefijo con una sintaxis Lisp

Symbolic

public static class Symbolic 
{ 
    public static Expression Expand(Expression expression) 
    { 
     return new ExpandVisitor().Visit(expression); 
    } 

    public static Expression Simplify(Expression expression) 
    { 
     return new SimplifyVisitor().Visit(expression); 
    } 

    public static Expression PartialDerivative(Expression expression, ParameterExpression parameter) 
    { 
     bool totalDerivative = false; 
     return new DerivativeVisitor(parameter, totalDerivative).Visit(expression); 
    } 

    public static string ToString(Expression expression) 
    { 
     ConstantExpression result = (ConstantExpression)new ListPrintVisitor().Visit(expression); 
     return result.Value.ToString(); 
    } 
} 

expresiones expansión con ExpandVisitor

internal class ExpandVisitor : ExpressionVisitor 
{ 
    protected override Expression VisitBinary(BinaryExpression node) 
    { 
     var left = Visit(node.Left); 
     var right = Visit(node.Right); 

     if (node.NodeType == ExpressionType.Multiply) 
     { 
      Expression[] leftNodes = GetAddedNodes(left).ToArray(); 
      Expression[] rightNodes = GetAddedNodes(right).ToArray(); 
      var result = 
       leftNodes 
       .SelectMany(x => rightNodes.Select(y => Expression.Multiply(x, y))) 
       .Aggregate((sum, term) => Expression.Add(sum, term)); 

      return result; 
     } 

     if (node.Left == left && node.Right == right) 
      return node; 

     return Expression.MakeBinary(node.NodeType, left, right, node.IsLiftedToNull, node.Method, node.Conversion); 
    } 

    /// <summary> 
    /// Treats the <paramref name="node"/> as the sum (or difference) of one or more child nodes and returns the 
    /// the individual addends in the sum. 
    /// </summary> 
    private static IEnumerable<Expression> GetAddedNodes(Expression node) 
    { 
     BinaryExpression binary = node as BinaryExpression; 
     if (binary != null) 
     { 
      switch (binary.NodeType) 
      { 
      case ExpressionType.Add: 
       foreach (var n in GetAddedNodes(binary.Left)) 
        yield return n; 

       foreach (var n in GetAddedNodes(binary.Right)) 
        yield return n; 

       yield break; 

      case ExpressionType.Subtract: 
       foreach (var n in GetAddedNodes(binary.Left)) 
        yield return n; 

       foreach (var n in GetAddedNodes(binary.Right)) 
        yield return Expression.Negate(n); 

       yield break; 

      default: 
       break; 
      } 
     } 

     yield return node; 
    } 
} 

Tomando un derivado con DerivativeVisitor

internal class DerivativeVisitor : ExpressionVisitor 
{ 
    private ParameterExpression _parameter; 
    private bool _totalDerivative; 

    public DerivativeVisitor(ParameterExpression parameter, bool totalDerivative) 
    { 
     if (_totalDerivative) 
      throw new NotImplementedException(); 

     _parameter = parameter; 
     _totalDerivative = totalDerivative; 
    } 

    protected override Expression VisitBinary(BinaryExpression node) 
    { 
     switch (node.NodeType) 
     { 
     case ExpressionType.Add: 
     case ExpressionType.Subtract: 
      return Expression.MakeBinary(node.NodeType, Visit(node.Left), Visit(node.Right)); 

     case ExpressionType.Multiply: 
      return Expression.Add(Expression.Multiply(node.Left, Visit(node.Right)), Expression.Multiply(Visit(node.Left), node.Right)); 

     case ExpressionType.Divide: 
      return Expression.Divide(Expression.Subtract(Expression.Multiply(Visit(node.Left), node.Right), Expression.Multiply(node.Left, Visit(node.Right))), Expression.Power(node.Right, Expression.Constant(2))); 

     case ExpressionType.Power: 
      if (node.Right is ConstantExpression) 
      { 
       return Expression.Multiply(node.Right, Expression.Multiply(Visit(node.Left), Expression.Subtract(node.Right, Expression.Constant(1)))); 
      } 
      else if (node.Left is ConstantExpression) 
      { 
       return Expression.Multiply(node, MathExpressions.Log(node.Left)); 
      } 
      else 
      { 
       return Expression.Multiply(node, Expression.Add(
        Expression.Multiply(Visit(node.Left), Expression.Divide(node.Right, node.Left)), 
        Expression.Multiply(Visit(node.Right), MathExpressions.Log(node.Left)) 
        )); 
      } 

     default: 
      throw new NotImplementedException(); 
     } 
    } 

    protected override Expression VisitConstant(ConstantExpression node) 
    { 
     return MathExpressions.Zero; 
    } 

    protected override Expression VisitInvocation(InvocationExpression node) 
    { 
     MemberExpression memberExpression = node.Expression as MemberExpression; 
     if (memberExpression != null) 
     { 
      var member = memberExpression.Member; 
      if (member.DeclaringType != typeof(Math)) 
       throw new NotImplementedException(); 

      switch (member.Name) 
      { 
      case "Log": 
       return Expression.Divide(Visit(node.Expression), node.Expression); 

      case "Log10": 
       return Expression.Divide(Visit(node.Expression), Expression.Multiply(Expression.Constant(Math.Log(10)), node.Expression)); 

      case "Exp": 
      case "Sin": 
      case "Cos": 
      default: 
       throw new NotImplementedException(); 
      } 
     } 

     throw new NotImplementedException(); 
    } 

    protected override Expression VisitParameter(ParameterExpression node) 
    { 
     if (node == _parameter) 
      return MathExpressions.One; 

     return MathExpressions.Zero; 
    } 
} 

Simplificar expresiones con SimplifyVisitor

internal class SimplifyVisitor : ExpressionVisitor 
{ 
    protected override Expression VisitBinary(BinaryExpression node) 
    { 
     var left = Visit(node.Left); 
     var right = Visit(node.Right); 

     ConstantExpression leftConstant = left as ConstantExpression; 
     ConstantExpression rightConstant = right as ConstantExpression; 
     if (leftConstant != null && rightConstant != null 
      && (leftConstant.Value is double) && (rightConstant.Value is double)) 
     { 
      double leftValue = (double)leftConstant.Value; 
      double rightValue = (double)rightConstant.Value; 

      switch (node.NodeType) 
      { 
      case ExpressionType.Add: 
       return Expression.Constant(leftValue + rightValue); 
      case ExpressionType.Subtract: 
       return Expression.Constant(leftValue - rightValue); 
      case ExpressionType.Multiply: 
       return Expression.Constant(leftValue * rightValue); 
      case ExpressionType.Divide: 
       return Expression.Constant(leftValue/rightValue); 
      default: 
       throw new NotImplementedException(); 
      } 
     } 

     switch (node.NodeType) 
     { 
     case ExpressionType.Add: 
      if (IsZero(left)) 
       return right; 
      if (IsZero(right)) 
       return left; 
      break; 

     case ExpressionType.Subtract: 
      if (IsZero(left)) 
       return Expression.Negate(right); 
      if (IsZero(right)) 
       return left; 
      break; 

     case ExpressionType.Multiply: 
      if (IsZero(left) || IsZero(right)) 
       return MathExpressions.Zero; 
      if (IsOne(left)) 
       return right; 
      if (IsOne(right)) 
       return left; 
      break; 

     case ExpressionType.Divide: 
      if (IsZero(right)) 
       throw new DivideByZeroException(); 
      if (IsZero(left)) 
       return MathExpressions.Zero; 
      if (IsOne(right)) 
       return left; 
      break; 

     default: 
      throw new NotImplementedException(); 
     } 

     return Expression.MakeBinary(node.NodeType, left, right); 
    } 

    protected override Expression VisitUnary(UnaryExpression node) 
    { 
     var operand = Visit(node.Operand); 

     ConstantExpression operandConstant = operand as ConstantExpression; 
     if (operandConstant != null && (operandConstant.Value is double)) 
     { 
      double operandValue = (double)operandConstant.Value; 

      switch (node.NodeType) 
      { 
      case ExpressionType.Negate: 
       if (operandValue == 0.0) 
        return MathExpressions.Zero; 

       return Expression.Constant(-operandValue); 

      default: 
       throw new NotImplementedException(); 
      } 
     } 

     switch (node.NodeType) 
     { 
     case ExpressionType.Negate: 
      if (operand.NodeType == ExpressionType.Negate) 
      { 
       return ((UnaryExpression)operand).Operand; 
      } 

      break; 

     default: 
      throw new NotImplementedException(); 
     } 

     return Expression.MakeUnary(node.NodeType, operand, node.Type); 
    } 

    private static bool IsZero(Expression expression) 
    { 
     ConstantExpression constant = expression as ConstantExpression; 
     if (constant != null) 
     { 
      if (constant.Value.Equals(0.0)) 
       return true; 
     } 

     return false; 
    } 

    private static bool IsOne(Expression expression) 
    { 
     ConstantExpression constant = expression as ConstantExpression; 
     if (constant != null) 
     { 
      if (constant.Value.Equals(1.0)) 
       return true; 
     } 

     return false; 
    } 
} 

expresiones de formato para pantalla con ListPrintVisitor

internal class ListPrintVisitor : ExpressionVisitor 
{ 
    protected override Expression VisitBinary(BinaryExpression node) 
    { 
     string op = null; 

     switch (node.NodeType) 
     { 
     case ExpressionType.Add: 
      op = "+"; 
      break; 
     case ExpressionType.Subtract: 
      op = "-"; 
      break; 
     case ExpressionType.Multiply: 
      op = "*"; 
      break; 
     case ExpressionType.Divide: 
      op = "/"; 
      break; 
     default: 
      throw new NotImplementedException(); 
     } 

     var left = Visit(node.Left); 
     var right = Visit(node.Right); 
     string result = string.Format("({0} {1} {2})", op, ((ConstantExpression)left).Value, ((ConstantExpression)right).Value); 
     return Expression.Constant(result); 
    } 

    protected override Expression VisitConstant(ConstantExpression node) 
    { 
     if (node.Value is string) 
      return node; 

     return Expression.Constant(node.Value.ToString()); 
    } 

    protected override Expression VisitParameter(ParameterExpression node) 
    { 
     return Expression.Constant(node.Name); 
    } 
} 

Prueba de los resultados

[TestMethod] 
public void BasicSymbolicTest() 
{ 
    ParameterExpression x = Expression.Parameter(typeof(double), "x"); 
    Expression linear = Expression.Add(Expression.Constant(3.0), x); 
    Assert.AreEqual("(+ 3 x)", Symbolic.ToString(linear)); 

    Expression quadratic = Expression.Multiply(linear, Expression.Add(Expression.Constant(2.0), x)); 
    Assert.AreEqual("(* (+ 3 x) (+ 2 x))", Symbolic.ToString(quadratic)); 

    Expression expanded = Symbolic.Expand(quadratic); 
    Assert.AreEqual("(+ (+ (+ (* 3 2) (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(expanded)); 
    Assert.AreEqual("(+ (+ (+ 6 (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(Symbolic.Simplify(expanded))); 

    Expression derivative = Symbolic.PartialDerivative(expanded, x); 
    Assert.AreEqual("(+ (+ (+ (+ (* 3 0) (* 0 2)) (+ (* 3 1) (* 0 x))) (+ (* x 0) (* 1 2))) (+ (* x 1) (* 1 x)))", Symbolic.ToString(derivative)); 

    Expression simplified = Symbolic.Simplify(derivative); 
    Assert.AreEqual("(+ 5 (+ x x))", Symbolic.ToString(simplified)); 
} 
Cuestiones relacionadas