Generating SQL From LINQ Expression Trees
An expression tree is an abstract representation of code as data. In .NET, expression trees are primarily used for LINQ-style code. In C#, a lambda expression assigned to an Expression<TDelegate> is compiled into an expression tree instead of a delegate, so its structure can be inspected at runtime. Here is an example of a lambda expression and its expression tree:
MemberExpression: Accessing a property or field of an object. (A local variable captured by a lambda is implemented as a field on a class generated by the compiler, so references to it also appear as MemberExpressions.)
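To make these node types concrete, the sketch below lets the C# compiler build the tree for x => x.PosId == posId and reports what sits on each side of the ==. The Pos class and the TreeInspector helper are hypothetical names invented for this illustration; only the System.Linq.Expressions types are real API.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical entity used only for this illustration.
public class Pos
{
    public int PosId { get; set; }
    public string Name { get; set; }
}

public static class TreeInspector
{
    // Describes a lambda of the shape x => x.SomeProperty == someCapturedVariable.
    public static string Describe(Expression<Func<Pos, bool>> lambda)
    {
        var body = (BinaryExpression)lambda.Body;   // the == node, NodeType Equal
        var left = (MemberExpression)body.Left;     // x.PosId: a property access on the parameter
        var right = (MemberExpression)body.Right;   // posId: a field on the compiler-generated closure class
        return $"{body.NodeType}: {left.Member.MemberType} {left.Member.Name} vs "
             + $"{right.Member.MemberType} {right.Member.Name}";
    }
}
```

With a local int posId = 2, Describe returns "Equal: Property PosId vs Field posId". The column is a property access, while the captured variable shows up as a field on the closure class, which is why the translator below evaluates fields instead of treating them as columns.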
The following code recursively walks an expression tree and generates the equivalent WHERE clause in SQL, for sufficiently simple expressions. One of the trickier areas is SQL’s handling of NULL: I have to check the right side of a binary expression for NULL so I can generate “x IS NULL” instead of “x = NULL”. I used parentheses liberally to make the expressions easy to compose. Negation is handled naively; it could be cleaned up by propagating the negation into the child node.
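Propagating the negation into the child node could look roughly like this sketch, which rewrites the tree before any SQL is generated: a double negation cancels, De Morgan's laws turn a negated AND/OR into an OR/AND of negated children, and == and != swap. Anything else keeps an explicit NOT. NegationPusher is a hypothetical name, and this is not the author's code.

```csharp
using System;
using System.Linq.Expressions;

public static class NegationPusher
{
    // Returns an expression equivalent to !e, with the NOT pushed toward the leaves.
    public static Expression Negate(Expression e)
    {
        switch (e.NodeType)
        {
            case ExpressionType.Not:
                // !!a => a
                return ((UnaryExpression)e).Operand;
            case ExpressionType.AndAlso:
            {
                // !(a && b) => !a || !b
                var b = (BinaryExpression)e;
                return Expression.OrElse(Negate(b.Left), Negate(b.Right));
            }
            case ExpressionType.OrElse:
            {
                // !(a || b) => !a && !b
                var b = (BinaryExpression)e;
                return Expression.AndAlso(Negate(b.Left), Negate(b.Right));
            }
            case ExpressionType.Equal:
            {
                // !(a == b) => a != b
                var b = (BinaryExpression)e;
                return Expression.NotEqual(b.Left, b.Right);
            }
            case ExpressionType.NotEqual:
            {
                // !(a != b) => a == b
                var b = (BinaryExpression)e;
                return Expression.Equal(b.Left, b.Right);
            }
            default:
                // Bool columns, <, method calls, ...: keep an explicit NOT.
                return Expression.Not(e);
        }
    }
}
```

A translator could call Negate on the operand of a Not node instead of wrapping it, so !(x.PosId == 1) would reach the SQL generator as a plain != comparison. Ordering comparisons such as < are deliberately left wrapped, because flipping them to >= changes the result for nullable values in C#.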
Input                                     Output
x => x.PosId == 1                         ([PosId] = 1)
x => x.IsAborted                          ([IsAborted] = 1)
x => !x.IsAborted                         (NOT ([IsAborted] = 1))
x => x.Name == null                       ([Name] IS NULL)
x => x.PosId == posId (where posId = 2)   ([PosId] = 2)
x => x.PosId == 1 && x.Name == "Main"     (([PosId] = 1) AND ([Name] = 'Main'))
x => x.Name.Contains("Main")              ([Name] LIKE '%Main%')
x => x.Name.StartsWith("R")               ([Name] LIKE 'R%')
x => list.Contains(x.PosId)               ([PosId] IN (1, 2, 3))
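The code below calls NodeTypeToString(nodeType, rightIsNull), but its body is not shown. One hypothetical way to write it is a switch that only consults rightIsNull for == and !=. Assuming ValueToString renders a null value as the literal NULL, that is what turns x.Name == null into ([Name] IS NULL) rather than ([Name] = NULL).

```csharp
using System;
using System.Linq.Expressions;

public static class SqlOperators
{
    // Hypothetical NodeTypeToString: maps a LINQ node type to a SQL operator.
    // rightIsNull is true when the right operand rendered as the literal NULL.
    public static string NodeTypeToString(ExpressionType nodeType, bool rightIsNull)
    {
        switch (nodeType)
        {
            case ExpressionType.Equal:              return rightIsNull ? "IS" : "=";
            case ExpressionType.NotEqual:           return rightIsNull ? "IS NOT" : "<>";
            case ExpressionType.AndAlso:            return "AND";
            case ExpressionType.OrElse:             return "OR";
            case ExpressionType.Not:                return "NOT";
            case ExpressionType.LessThan:           return "<";
            case ExpressionType.LessThanOrEqual:    return "<=";
            case ExpressionType.GreaterThan:        return ">";
            case ExpressionType.GreaterThanOrEqual: return ">=";
            default:
                throw new NotSupportedException($"Unsupported node type: {nodeType}");
        }
    }
}
```

The check matters because [Name] = NULL evaluates to UNKNOWN rather than true under standard SQL NULL semantics, so it would silently match no rows.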
using System;
using System.Collections;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices; // ExtensionAttribute

// Recursively translates an expression tree into a SQL WHERE fragment.
// isUnary: true when the node is used on its own as a condition (for example the
//          operand of NOT); a bool column is then rendered as ([col] = 1).
// quote:   false when a string value is spliced into a LIKE pattern.
private string Recurse(Expression expression, bool isUnary = false, bool quote = true)
{
    if (expression is UnaryExpression)
    {
        var unary = (UnaryExpression)expression;
        var right = Recurse(unary.Operand, true);
        return "(" + NodeTypeToString(unary.NodeType, right == "NULL") + " " + right + ")";
    }
    if (expression is BinaryExpression)
    {
        var body = (BinaryExpression)expression;
        // Render the right side first so a NULL there can switch "=" to "IS".
        var right = Recurse(body.Right);
        return "(" + Recurse(body.Left) + " " + NodeTypeToString(body.NodeType, right == "NULL") + " " + right + ")";
    }
    if (expression is ConstantExpression)
    {
        var constant = (ConstantExpression)expression;
        return ValueToString(constant.Value, isUnary, quote);
    }
    if (expression is MemberExpression)
    {
        var member = (MemberExpression)expression;
        // Property access is treated as a column reference.
        if (member.Member is PropertyInfo)
        {
            var property = (PropertyInfo)member.Member;
            var colName = _tableDef.GetColumnNameFor(property.Name);
            if (isUnary && member.Type == typeof(bool))
            {
                return "([" + colName + "] = 1)";
            }
            return "[" + colName + "]";
        }
        // Field access is treated as a captured variable: its value is inlined.
        if (member.Member is FieldInfo)
        {
            return ValueToString(GetValue(member), isUnary, quote);
        }
        throw new Exception($"Expression does not refer to a property or field: {expression}");
    }
    if (expression is MethodCallExpression)
    {
        var methodCall = (MethodCallExpression)expression;
        // LIKE queries:
        if (methodCall.Method == typeof(string).GetMethod("Contains", new[] { typeof(string) }))
        {
            return "(" + Recurse(methodCall.Object) + " LIKE '%" + Recurse(methodCall.Arguments[0], quote: false) + "%')";
        }
        if (methodCall.Method == typeof(string).GetMethod("StartsWith", new[] { typeof(string) }))
        {
            return "(" + Recurse(methodCall.Object) + " LIKE '" + Recurse(methodCall.Arguments[0], quote: false) + "%')";
        }
        if (methodCall.Method == typeof(string).GetMethod("EndsWith", new[] { typeof(string) }))
        {
            return "(" + Recurse(methodCall.Object) + " LIKE '%" + Recurse(methodCall.Arguments[0], quote: false) + "')";
        }
        // IN queries: Enumerable.Contains(list, x) or list.Contains(x).
        if (methodCall.Method.Name == "Contains")
        {
            Expression collection;
            Expression property;
            if (methodCall.Method.IsDefined(typeof(ExtensionAttribute)) && methodCall.Arguments.Count == 2)
            {
                collection = methodCall.Arguments[0];
                property = methodCall.Arguments[1];
            }
            else if (!methodCall.Method.IsDefined(typeof(ExtensionAttribute)) && methodCall.Arguments.Count == 1)
            {
                collection = methodCall.Object;
                property = methodCall.Arguments[0];
            }
            else
            {
                throw new Exception("Unsupported method call: " + methodCall.Method.Name);
            }
            var values = (IEnumerable)GetValue(collection);
            var concated = "";
            foreach (var e in values)
            {
                concated += ValueToString(e, false, true) + ", ";
            }
            // An empty list would produce invalid "IN ()", so render the constant false instead.
            if (concated == "")
            {
                return ValueToString(false, true, false);
            }
            return "(" + Recurse(property) + " IN (" + concated.Substring(0, concated.Length - 2) + "))";
        }
        throw new Exception("Unsupported method call: " + methodCall.Method.Name);
    }
    throw new Exception("Unsupported expression: " + expression.GetType().Name);
}
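GetValue is not shown either. It has to turn a captured variable (posId) or a captured collection (list) into an actual runtime value. The sketch below is a hypothetical implementation, not the author's: it reads closure fields directly through reflection and falls back to compiling the sub-expression into a Func<object> and invoking it.

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

public static class ExpressionValues
{
    // Hypothetical GetValue: evaluates a sub-expression that does not depend on the
    // lambda parameter (a captured variable, a constant, a property chain on a captured object).
    public static object GetValue(Expression expression)
    {
        // Fast path: a captured local is a field on a constant closure object.
        if (expression is MemberExpression member
            && member.Member is FieldInfo field
            && member.Expression is ConstantExpression closure)
        {
            return field.GetValue(closure.Value);
        }

        // General path: box the result, compile to Func<object>, and invoke it.
        var boxed = Expression.Convert(expression, typeof(object));
        return Expression.Lambda<Func<object>>(boxed).Compile()();
    }
}
```

The reflection path covers the common case cheaply. Compile() generates code and is comparatively expensive, so it is better kept as the fallback for anything more complex.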
The following is a much improved version of the WHERE clause builder. This version generates parameterized queries, so it isn’t vulnerable to SQL injection. Using parameters also allowed me to simplify the logic, since I don’t need to worry about stringifying the values. (Parameters do not remove the NULL issue, however: with standard SQL NULL semantics, [Name] = @0 matches no rows when @0 is NULL, so a null comparison still needs IS NULL.)
I moved all the string concatenation into a separate class called WherePart. These objects are composable in a
structure similar to the source expression tree. Extracting this class is my favorite part of the refactoring.
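The WherePart class itself is not part of this excerpt. Judging only from how the code below calls it (WherePart.IsSql, WherePart.IsParameter and WherePart.Concat), a minimal version might hold a SQL fragment plus the parameter values it references. The Sql and Parameters members and the @0-style placeholder names are assumptions made for this sketch.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical reconstruction, inferred only from the calls WhereBuilder makes.
public class WherePart
{
    // SQL text with placeholders such as @0 in place of values.
    public string Sql { get; private set; }

    // Placeholder name -> value, to be sent to the database as real parameters.
    public Dictionary<string, object> Parameters { get; private set; } = new Dictionary<string, object>();

    // A literal SQL fragment with no parameters, such as a column name.
    public static WherePart IsSql(string sql) => new WherePart { Sql = sql };

    // A placeholder whose value travels separately from the SQL text.
    public static WherePart IsParameter(int count, object value)
    {
        var name = "@" + count;
        var part = new WherePart { Sql = name };
        part.Parameters.Add(name, value);
        return part;
    }

    // Unary form, e.g. (NOT ...).
    public static WherePart Concat(string op, WherePart operand) =>
        new WherePart
        {
            Sql = $"({op} {operand.Sql})",
            Parameters = new Dictionary<string, object>(operand.Parameters)
        };

    // Binary form, e.g. (left = right). Names never collide because the
    // ref int counter in Recurse hands out each index once.
    public static WherePart Concat(WherePart left, string op, WherePart right) =>
        new WherePart
        {
            Sql = $"({left.Sql} {op} {right.Sql})",
            Parameters = left.Parameters.Concat(right.Parameters).ToDictionary(kv => kv.Key, kv => kv.Value)
        };
}
```

Because values only ever live in Parameters, a comparison such as x.Name == "Main" reaches the database as ([Name] = @n) plus a separate value, never as SQL text.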
I’m still not happy with how I’m handling the LIKE queries. I have to pass prefix and postfix parameters down to the next level of recursion, which clutters up the method signature. It might be better to just build the string in place.
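Building the string in place could mean resolving the argument to a value first and wrapping it right in the method-call branch. The hypothetical LikePatterns helper below does that, and it also escapes the characters SQL Server's LIKE treats as wildcards (%, _ and [) by wrapping each in brackets, since a parameter value such as 50% would otherwise still act as a pattern. SQL Server's bracket escaping is an assumption about the target database.

```csharp
using System;
using System.Text;

public static class LikePatterns
{
    // Hypothetical helper: turns a search value into a LIKE pattern in one place,
    // instead of threading prefix/postfix arguments through the recursion.
    // Escaping uses SQL Server's bracket syntax: [%], [_], [[].
    public static string Build(string value, bool anyPrefix, bool anySuffix)
    {
        var sb = new StringBuilder();
        if (anyPrefix) sb.Append('%');
        foreach (var c in value)
        {
            if (c == '%' || c == '_' || c == '[')
                sb.Append('[').Append(c).Append(']');
            else
                sb.Append(c);
        }
        if (anySuffix) sb.Append('%');
        return sb.ToString();
    }

    // Contains, StartsWith and EndsWith differ only in where the wildcards go.
    public static string ForMethod(string methodName, string value)
    {
        switch (methodName)
        {
            case "Contains":   return Build(value, true, true);
            case "StartsWith": return Build(value, false, true);
            case "EndsWith":   return Build(value, true, false);
            default: throw new NotSupportedException("Unsupported LIKE method: " + methodName);
        }
    }
}
```

Inside WhereBuilder, the Contains branch could then evaluate methodCall.Arguments[0] with GetValue and hand LikePatterns.ForMethod(methodCall.Method.Name, value) straight to WherePart.IsParameter, dropping prefix and postfix from the Recurse signature.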
public class WhereBuilder
{
    private readonly IProvider _provider;
    private TableDefinition _tableDef;

    // i is the running parameter index, shared across the whole traversal via ref.
    // prefix/postfix wrap string values (the % wildcards for LIKE queries).
    WherePart Recurse(
        ref int i, Expression expression, bool isUnary = false, string prefix = null, string postfix = null)
    {
        if (expression is UnaryExpression)
        {
            var unary = (UnaryExpression)expression;
            return WherePart.Concat(NodeTypeToString(unary.NodeType), Recurse(ref i, unary.Operand, true));
        }
        if (expression is BinaryExpression)
        {
            var body = (BinaryExpression)expression;
            return WherePart.Concat(
                Recurse(ref i, body.Left), NodeTypeToString(body.NodeType), Recurse(ref i, body.Right));
        }
        if (expression is ConstantExpression)
        {
            var constant = (ConstantExpression)expression;
            var value = constant.Value;
            if (value is int)
                return WherePart.IsSql(value.ToString());
            if (value is string)
                value = prefix + (string)value + postfix;
            return WherePart.IsParameter(i++, value);
        }
        if (expression is MemberExpression)
        {
            var member = (MemberExpression)expression;
            if (member.Member is PropertyInfo)
            {
                var property = (PropertyInfo)member.Member;
                var colName = _tableDef.GetColumnNameFor(property.Name);
                if (isUnary && member.Type == typeof(bool))
                    return WherePart.Concat(Recurse(ref i, expression), "=", WherePart.IsParameter(i++, true));
                return WherePart.IsSql("[" + colName + "]");
            }
            if (member.Member is FieldInfo)
            {
                var value = GetValue(member);
                if (value is string)
                    value = prefix + (string)value + postfix;
                return WherePart.IsParameter(i++, value);
            }
            throw new Exception($"Expression does not refer to a property or field: {expression}");
        }
        if (expression is MethodCallExpression)
        {
            var methodCall = (MethodCallExpression)expression;
            // LIKE queries:
            if (methodCall.Method == typeof(string).GetMethod("Contains", new[] { typeof(string) }))
                return WherePart.Concat(
                    Recurse(ref i, methodCall.Object), "LIKE",
                    Recurse(ref i, methodCall.Arguments[0], prefix: "%", postfix: "%"));
            // IN queries:
            if (methodCall.Method.Name == "Contains")
            {
                Expression collection;
                Expression property;
                if (methodCall.Method.IsDefined(typeof(ExtensionAttribute)) && methodCall.Arguments.Count == 2)
                {
                    collection = methodCall.Arguments[0];
                    property = methodCall.Arguments[1];
                }
                else
                    throw new Exception("Unsupported method call: " + methodCall.Method.Name);