Added count(*) support to sql parser. Fixed warnings in sql parser. (#7490)

This commit is contained in:
treeform
2018-04-12 08:49:24 -07:00
committed by Andreas Rumpf
parent 63160855aa
commit f3db632b1d
2 changed files with 123 additions and 77 deletions

View File

@@ -11,7 +11,7 @@
## parser. It parses PostgreSQL syntax and the SQL ANSI standard.
import
hashes, strutils, lexbase, streams
hashes, strutils, lexbase
# ------------------- scanner -------------------------------------------------
@@ -62,10 +62,6 @@ const
"count",
]
proc open(L: var SqlLexer, input: Stream, filename: string) =
lexbase.open(L, input)
L.filename = filename
proc close(L: var SqlLexer) =
lexbase.close(L)
@@ -496,6 +492,7 @@ type
SqlNodeKind* = enum ## kind of SQL abstract syntax tree
nkNone,
nkIdent,
nkQuotedIdent,
nkStringLit,
nkBitStringLit,
nkHexStringLit,
@@ -551,13 +548,18 @@ type
nkCreateIndexIfNotExists,
nkEnumDef
const
LiteralNodes = {
nkIdent, nkQuotedIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
nkIntegerLit, nkNumericLit
}
type
SqlParseError* = object of ValueError ## Invalid SQL encountered
SqlNode* = ref SqlNodeObj ## an SQL abstract syntax tree node
SqlNodeObj* = object ## an SQL abstract syntax tree node
case kind*: SqlNodeKind ## kind of syntax tree
of nkIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
nkIntegerLit, nkNumericLit:
of LiteralNodes:
strVal*: string ## AST leaf: the identifier, numeric literal
## string literal, etc.
else:
@@ -566,21 +568,26 @@ type
SqlParser* = object of SqlLexer ## SQL parser object
tok: Token
{.deprecated: [EInvalidSql: SqlParseError, PSqlNode: SqlNode,
TSqlNode: SqlNodeObj, TSqlParser: SqlParser, TSqlNodeKind: SqlNodeKind].}
proc newNode(k: SqlNodeKind): SqlNode =
proc newNode*(k: SqlNodeKind): SqlNode =
new(result)
result.kind = k
proc newNode(k: SqlNodeKind, s: string): SqlNode =
proc newNode*(k: SqlNodeKind, s: string): SqlNode =
new(result)
result.kind = k
result.strVal = s
proc newNode*(k: SqlNodeKind, sons: seq[SqlNode]): SqlNode =
new(result)
result.kind = k
result.sons = sons
proc len*(n: SqlNode): int =
if n.kind in {nkIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
nkIntegerLit, nkNumericLit}:
if n.kind in LiteralNodes:
result = 0
else:
result = n.sons.len
@@ -630,7 +637,7 @@ proc eat(p: var SqlParser, keyw: string) =
if isKeyw(p, keyw):
getTok(p)
else:
sqlError(p, keyw.toUpper() & " expected")
sqlError(p, keyw.toUpperAscii() & " expected")
proc opt(p: var SqlParser, kind: TokKind) =
if p.tok.kind == kind: getTok(p)
@@ -689,7 +696,10 @@ proc parseSelect(p: var SqlParser): SqlNode
proc identOrLiteral(p: var SqlParser): SqlNode =
case p.tok.kind
of tkIdentifier, tkQuotedIdentifier:
of tkQuotedIdentifier:
result = newNode(nkQuotedIdent, p.tok.literal)
getTok(p)
of tkIdentifier:
result = newNode(nkIdent, p.tok.literal)
getTok(p)
of tkStringConstant, tkEscapeConstant, tkDollarQuotedConstant:
@@ -713,11 +723,15 @@ proc identOrLiteral(p: var SqlParser): SqlNode =
result.add(parseExpr(p))
eat(p, tkParRi)
else:
sqlError(p, "expression expected")
getTok(p) # we must consume a token here to prevend endless loops!
if p.tok.literal == "*":
result = newNode(nkIdent, p.tok.literal)
getTok(p)
else:
sqlError(p, "expression expected")
getTok(p) # we must consume a token here to prevend endless loops!
proc primary(p: var SqlParser): SqlNode =
if p.tok.kind == tkOperator or isKeyw(p, "not"):
if (p.tok.kind == tkOperator and (p.tok.literal == "+" or p.tok.literal == "-")) or isKeyw(p, "not"):
result = newNode(nkPrefix)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
@@ -762,7 +776,7 @@ proc lowestExprAux(p: var SqlParser, v: var SqlNode, limit: int): int =
result = opPred
while opPred > limit:
node = newNode(nkInfix)
opNode = newNode(nkIdent, p.tok.literal.toLower())
opNode = newNode(nkIdent, p.tok.literal.toLowerAscii())
getTok(p)
result = lowestExprAux(p, v2, opPred)
node.add(opNode)
@@ -1078,11 +1092,23 @@ proc parseSelect(p: var SqlParser): SqlNode =
if p.tok.kind != tkComma: break
getTok(p)
result.add(g)
if isKeyw(p, "limit"):
if isKeyw(p, "order"):
getTok(p)
var l = newNode(nkLimit)
l.add(parseExpr(p))
result.add(l)
eat(p, "by")
var n = newNode(nkOrder)
while true:
var e = parseExpr(p)
if isKeyw(p, "asc"):
getTok(p) # is default
elif isKeyw(p, "desc"):
getTok(p)
var x = newNode(nkDesc)
x.add(e)
e = x
n.add(e)
if p.tok.kind != tkComma: break
getTok(p)
result.add(n)
if isKeyw(p, "having"):
var h = newNode(nkHaving)
while true:
@@ -1099,22 +1125,6 @@ proc parseSelect(p: var SqlParser): SqlNode =
elif isKeyw(p, "except"):
result.add(newNode(nkExcept))
getTok(p)
if isKeyw(p, "order"):
getTok(p)
eat(p, "by")
var n = newNode(nkOrder)
while true:
var e = parseExpr(p)
if isKeyw(p, "asc"): getTok(p) # is default
elif isKeyw(p, "desc"):
getTok(p)
var x = newNode(nkDesc)
x.add(e)
e = x
n.add(e)
if p.tok.kind != tkComma: break
getTok(p)
result.add(n)
if isKeyw(p, "join") or isKeyw(p, "inner") or isKeyw(p, "outer") or isKeyw(p, "cross"):
var join = newNode(nkJoin)
result.add(join)
@@ -1122,12 +1132,17 @@ proc parseSelect(p: var SqlParser): SqlNode =
join.add(newNode(nkIdent, ""))
getTok(p)
else:
join.add(newNode(nkIdent, p.tok.literal.toLower()))
join.add(newNode(nkIdent, p.tok.literal.toLowerAscii()))
getTok(p)
eat(p, "join")
join.add(parseFromItem(p))
eat(p, "on")
join.add(parseExpr(p))
if isKeyw(p, "limit"):
getTok(p)
var l = newNode(nkLimit)
l.add(parseExpr(p))
result.add(l)
proc parseStmt(p: var SqlParser; parent: SqlNode) =
if isKeyw(p, "create"):
@@ -1161,14 +1176,6 @@ proc parseStmt(p: var SqlParser; parent: SqlNode) =
else:
sqlError(p, "SELECT, CREATE, UPDATE or DELETE expected")
proc open(p: var SqlParser, input: Stream, filename: string) =
## opens the parser `p` and assigns the input stream `input` to it.
## `filename` is only used for error messages.
open(SqlLexer(p), input, filename)
p.tok.kind = tkInvalid
p.tok.literal = ""
getTok(p)
proc parse(p: var SqlParser): SqlNode =
## parses the content of `p`'s input stream and returns the SQL AST.
## Syntax errors raise an `SqlParseError` exception.
@@ -1183,24 +1190,6 @@ proc close(p: var SqlParser) =
## closes the parser `p`. The associated input stream is closed too.
close(SqlLexer(p))
proc parseSQL*(input: Stream, filename: string): SqlNode =
## parses the SQL from `input` into an AST and returns the AST.
## `filename` is only used for error messages.
## Syntax errors raise an `SqlParseError` exception.
var p: SqlParser
open(p, input, filename)
try:
result = parse(p)
finally:
close(p)
proc parseSQL*(input: string, filename=""): SqlNode =
## parses the SQL from `input` into an AST and returns the AST.
## `filename` is only used for error messages.
## Syntax errors raise an `SqlParseError` exception.
parseSQL(newStringStream(input), "")
type
SqlWriter = object
indent: int
@@ -1218,12 +1207,12 @@ proc add(s: var SqlWriter, thing: string) =
proc addKeyw(s: var SqlWriter, thing: string) =
var keyw = thing
if s.upperCase:
keyw = keyw.toUpper()
keyw = keyw.toUpperAscii()
s.add(keyw)
proc addIden(s: var SqlWriter, thing: string) =
var iden = thing
if iden.toLower() in reservedKeywords:
if iden.toLowerAscii() in reservedKeywords:
iden = '"' & iden & '"'
s.add(iden)
@@ -1251,15 +1240,20 @@ proc addMulti(s: var SqlWriter, n: SqlNode, sep = ',', prefix, suffix: char) =
ra(n.sons[i], s)
s.add(suffix)
proc quoted(s: string): string =
"\"" & replace(s, "\"", "\"\"") & "\""
proc ra(n: SqlNode, s: var SqlWriter) =
if n == nil: return
case n.kind
of nkNone: discard
of nkIdent:
if allCharsInSet(n.strVal, {'\33'..'\127'}) and n.strVal.toLower() notin reservedKeywords:
if allCharsInSet(n.strVal, {'\33'..'\127'}):
s.add(n.strVal)
else:
s.add("\"" & replace(n.strVal, "\"", "\"\"") & "\"")
s.add(quoted(n.strVal))
of nkQuotedIdent:
s.add(quoted(n.strVal))
of nkStringLit:
s.add(escape(n.strVal, "'", "'"))
of nkBitStringLit:
@@ -1361,18 +1355,19 @@ proc ra(n: SqlNode, s: var SqlWriter) =
s.addKeyw("select")
if n.kind == nkSelectDistinct:
s.addKeyw("distinct")
s.addMulti(n.sons[0])
for i in 1 .. n.len-1:
for i in 0 ..< n.len:
ra(n.sons[i], s)
of nkSelectColumns:
assert(false)
for i, column in n.sons:
if i > 0: s.add(',')
ra(column, s)
of nkSelectPair:
ra(n.sons[0], s)
if n.sons.len == 2:
s.addKeyw("as")
ra(n.sons[1], s)
of nkFromItemPair:
if n.sons[0].kind == nkIdent:
if n.sons[0].kind in {nkIdent, nkQuotedIdent}:
ra(n.sons[0], s)
else:
assert n.sons[0].kind == nkSelect
@@ -1472,3 +1467,35 @@ proc renderSQL*(n: SqlNode, upperCase=false): string =
proc `$`*(n: SqlNode): string =
## an alias for `renderSQL`.
renderSQL(n)
when not defined(js):
import streams
proc open(L: var SqlLexer, input: Stream, filename: string) =
lexbase.open(L, input)
L.filename = filename
proc open(p: var SqlParser, input: Stream, filename: string) =
## opens the parser `p` and assigns the input stream `input` to it.
## `filename` is only used for error messages.
open(SqlLexer(p), input, filename)
p.tok.kind = tkInvalid
p.tok.literal = ""
getTok(p)
proc parseSQL*(input: Stream, filename: string): SqlNode =
## parses the SQL from `input` into an AST and returns the AST.
## `filename` is only used for error messages.
## Syntax errors raise an `SqlParseError` exception.
var p: SqlParser
open(p, input, filename)
try:
result = parse(p)
finally:
close(p)
proc parseSQL*(input: string, filename=""): SqlNode =
## parses the SQL from `input` into an AST and returns the AST.
## `filename` is only used for error messages.
## Syntax errors raise an `SqlParseError` exception.
parseSQL(newStringStream(input), "")

View File

@@ -26,10 +26,9 @@ doAssert $parseSQL("SELECT foo, bar, baz FROM table limit 10") == "select foo, b
doAssert $parseSQL("SELECT foo AS bar FROM table") == "select foo as bar from table;"
doAssert $parseSQL("SELECT foo AS foo_prime, bar AS bar_prime, baz AS baz_prime FROM table") == "select foo as foo_prime, bar as bar_prime, baz as baz_prime from table;"
doAssert $parseSQL("SELECT * FROM table") == "select * from table;"
#TODO add count(*)
#doAssert $parseSQL("SELECT COUNT(*) FROM table"
doAssert $parseSQL("SELECT count(*) FROM table") == "select count(*) from table;"
doAssert $parseSQL("SELECT count(*) as 'Total' FROM table") == "select count(*) as 'Total' from table;"
doAssert $parseSQL("SELECT count(*) as 'Total', sum(a) as 'Aggr' FROM table") == "select count(*) as 'Total', sum(a) as 'Aggr' from table;"
doAssert $parseSQL("""
SELECT * FROM table
@@ -50,6 +49,23 @@ WHERE
a and not b
""") == "select * from table where a and not b;"
doAssert $parseSQL("""
SELECT * FROM table
ORDER BY 1
""") == "select * from table order by 1;"
doAssert $parseSQL("""
SELECT * FROM table
GROUP BY 1
ORDER BY 1
""") == "select * from table group by 1 order by 1;"
doAssert $parseSQL("""
SELECT * FROM table
ORDER BY 1
LIMIT 100
""") == "select * from table order by 1 limit 100;"
doAssert $parseSQL("""
SELECT * FROM table
WHERE a = b and c = d or n is null and not b + 1 = 3
@@ -185,7 +201,10 @@ AND Country='USA'
ORDER BY CustomerName;
""") == "select * from Customers where(CustomerName like 'L%' or CustomerName like 'R%' or CustomerName like 'W%') and Country = 'USA' order by CustomerName;"
# parse keywords as identifires
# parse quoted keywords as identifires
doAssert $parseSQL("""
SELECT `SELECT`, `FROM` as `GROUP` FROM `WHERE`;
""") == """select "SELECT", "FROM" as "GROUP" from "WHERE";"""
doAssert $parseSQL("""
SELECT "SELECT", "FROM" as "GROUP" FROM "WHERE";
""") == """select "SELECT", "FROM" as "GROUP" from "WHERE";"""