mirror of
https://github.com/nim-lang/Nim.git
synced 2026-01-03 11:42:33 +00:00
Added count(*) support to sql parser. Fixed warnings in sql parser. (#7490)
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
## parser. It parses PostgreSQL syntax and the SQL ANSI standard.
|
||||
|
||||
import
|
||||
hashes, strutils, lexbase, streams
|
||||
hashes, strutils, lexbase
|
||||
|
||||
# ------------------- scanner -------------------------------------------------
|
||||
|
||||
@@ -62,10 +62,6 @@ const
|
||||
"count",
|
||||
]
|
||||
|
||||
proc open(L: var SqlLexer, input: Stream, filename: string) =
|
||||
lexbase.open(L, input)
|
||||
L.filename = filename
|
||||
|
||||
proc close(L: var SqlLexer) =
|
||||
lexbase.close(L)
|
||||
|
||||
@@ -496,6 +492,7 @@ type
|
||||
SqlNodeKind* = enum ## kind of SQL abstract syntax tree
|
||||
nkNone,
|
||||
nkIdent,
|
||||
nkQuotedIdent,
|
||||
nkStringLit,
|
||||
nkBitStringLit,
|
||||
nkHexStringLit,
|
||||
@@ -551,13 +548,18 @@ type
|
||||
nkCreateIndexIfNotExists,
|
||||
nkEnumDef
|
||||
|
||||
const
|
||||
LiteralNodes = {
|
||||
nkIdent, nkQuotedIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
|
||||
nkIntegerLit, nkNumericLit
|
||||
}
|
||||
|
||||
type
|
||||
SqlParseError* = object of ValueError ## Invalid SQL encountered
|
||||
SqlNode* = ref SqlNodeObj ## an SQL abstract syntax tree node
|
||||
SqlNodeObj* = object ## an SQL abstract syntax tree node
|
||||
case kind*: SqlNodeKind ## kind of syntax tree
|
||||
of nkIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
|
||||
nkIntegerLit, nkNumericLit:
|
||||
of LiteralNodes:
|
||||
strVal*: string ## AST leaf: the identifier, numeric literal
|
||||
## string literal, etc.
|
||||
else:
|
||||
@@ -566,21 +568,26 @@ type
|
||||
SqlParser* = object of SqlLexer ## SQL parser object
|
||||
tok: Token
|
||||
|
||||
|
||||
{.deprecated: [EInvalidSql: SqlParseError, PSqlNode: SqlNode,
|
||||
TSqlNode: SqlNodeObj, TSqlParser: SqlParser, TSqlNodeKind: SqlNodeKind].}
|
||||
|
||||
proc newNode(k: SqlNodeKind): SqlNode =
|
||||
proc newNode*(k: SqlNodeKind): SqlNode =
|
||||
new(result)
|
||||
result.kind = k
|
||||
|
||||
proc newNode(k: SqlNodeKind, s: string): SqlNode =
|
||||
proc newNode*(k: SqlNodeKind, s: string): SqlNode =
|
||||
new(result)
|
||||
result.kind = k
|
||||
result.strVal = s
|
||||
|
||||
proc newNode*(k: SqlNodeKind, sons: seq[SqlNode]): SqlNode =
|
||||
new(result)
|
||||
result.kind = k
|
||||
result.sons = sons
|
||||
|
||||
proc len*(n: SqlNode): int =
|
||||
if n.kind in {nkIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
|
||||
nkIntegerLit, nkNumericLit}:
|
||||
if n.kind in LiteralNodes:
|
||||
result = 0
|
||||
else:
|
||||
result = n.sons.len
|
||||
@@ -630,7 +637,7 @@ proc eat(p: var SqlParser, keyw: string) =
|
||||
if isKeyw(p, keyw):
|
||||
getTok(p)
|
||||
else:
|
||||
sqlError(p, keyw.toUpper() & " expected")
|
||||
sqlError(p, keyw.toUpperAscii() & " expected")
|
||||
|
||||
proc opt(p: var SqlParser, kind: TokKind) =
|
||||
if p.tok.kind == kind: getTok(p)
|
||||
@@ -689,7 +696,10 @@ proc parseSelect(p: var SqlParser): SqlNode
|
||||
|
||||
proc identOrLiteral(p: var SqlParser): SqlNode =
|
||||
case p.tok.kind
|
||||
of tkIdentifier, tkQuotedIdentifier:
|
||||
of tkQuotedIdentifier:
|
||||
result = newNode(nkQuotedIdent, p.tok.literal)
|
||||
getTok(p)
|
||||
of tkIdentifier:
|
||||
result = newNode(nkIdent, p.tok.literal)
|
||||
getTok(p)
|
||||
of tkStringConstant, tkEscapeConstant, tkDollarQuotedConstant:
|
||||
@@ -713,11 +723,15 @@ proc identOrLiteral(p: var SqlParser): SqlNode =
|
||||
result.add(parseExpr(p))
|
||||
eat(p, tkParRi)
|
||||
else:
|
||||
sqlError(p, "expression expected")
|
||||
getTok(p) # we must consume a token here to prevend endless loops!
|
||||
if p.tok.literal == "*":
|
||||
result = newNode(nkIdent, p.tok.literal)
|
||||
getTok(p)
|
||||
else:
|
||||
sqlError(p, "expression expected")
|
||||
getTok(p) # we must consume a token here to prevend endless loops!
|
||||
|
||||
proc primary(p: var SqlParser): SqlNode =
|
||||
if p.tok.kind == tkOperator or isKeyw(p, "not"):
|
||||
if (p.tok.kind == tkOperator and (p.tok.literal == "+" or p.tok.literal == "-")) or isKeyw(p, "not"):
|
||||
result = newNode(nkPrefix)
|
||||
result.add(newNode(nkIdent, p.tok.literal))
|
||||
getTok(p)
|
||||
@@ -762,7 +776,7 @@ proc lowestExprAux(p: var SqlParser, v: var SqlNode, limit: int): int =
|
||||
result = opPred
|
||||
while opPred > limit:
|
||||
node = newNode(nkInfix)
|
||||
opNode = newNode(nkIdent, p.tok.literal.toLower())
|
||||
opNode = newNode(nkIdent, p.tok.literal.toLowerAscii())
|
||||
getTok(p)
|
||||
result = lowestExprAux(p, v2, opPred)
|
||||
node.add(opNode)
|
||||
@@ -1078,11 +1092,23 @@ proc parseSelect(p: var SqlParser): SqlNode =
|
||||
if p.tok.kind != tkComma: break
|
||||
getTok(p)
|
||||
result.add(g)
|
||||
if isKeyw(p, "limit"):
|
||||
if isKeyw(p, "order"):
|
||||
getTok(p)
|
||||
var l = newNode(nkLimit)
|
||||
l.add(parseExpr(p))
|
||||
result.add(l)
|
||||
eat(p, "by")
|
||||
var n = newNode(nkOrder)
|
||||
while true:
|
||||
var e = parseExpr(p)
|
||||
if isKeyw(p, "asc"):
|
||||
getTok(p) # is default
|
||||
elif isKeyw(p, "desc"):
|
||||
getTok(p)
|
||||
var x = newNode(nkDesc)
|
||||
x.add(e)
|
||||
e = x
|
||||
n.add(e)
|
||||
if p.tok.kind != tkComma: break
|
||||
getTok(p)
|
||||
result.add(n)
|
||||
if isKeyw(p, "having"):
|
||||
var h = newNode(nkHaving)
|
||||
while true:
|
||||
@@ -1099,22 +1125,6 @@ proc parseSelect(p: var SqlParser): SqlNode =
|
||||
elif isKeyw(p, "except"):
|
||||
result.add(newNode(nkExcept))
|
||||
getTok(p)
|
||||
if isKeyw(p, "order"):
|
||||
getTok(p)
|
||||
eat(p, "by")
|
||||
var n = newNode(nkOrder)
|
||||
while true:
|
||||
var e = parseExpr(p)
|
||||
if isKeyw(p, "asc"): getTok(p) # is default
|
||||
elif isKeyw(p, "desc"):
|
||||
getTok(p)
|
||||
var x = newNode(nkDesc)
|
||||
x.add(e)
|
||||
e = x
|
||||
n.add(e)
|
||||
if p.tok.kind != tkComma: break
|
||||
getTok(p)
|
||||
result.add(n)
|
||||
if isKeyw(p, "join") or isKeyw(p, "inner") or isKeyw(p, "outer") or isKeyw(p, "cross"):
|
||||
var join = newNode(nkJoin)
|
||||
result.add(join)
|
||||
@@ -1122,12 +1132,17 @@ proc parseSelect(p: var SqlParser): SqlNode =
|
||||
join.add(newNode(nkIdent, ""))
|
||||
getTok(p)
|
||||
else:
|
||||
join.add(newNode(nkIdent, p.tok.literal.toLower()))
|
||||
join.add(newNode(nkIdent, p.tok.literal.toLowerAscii()))
|
||||
getTok(p)
|
||||
eat(p, "join")
|
||||
join.add(parseFromItem(p))
|
||||
eat(p, "on")
|
||||
join.add(parseExpr(p))
|
||||
if isKeyw(p, "limit"):
|
||||
getTok(p)
|
||||
var l = newNode(nkLimit)
|
||||
l.add(parseExpr(p))
|
||||
result.add(l)
|
||||
|
||||
proc parseStmt(p: var SqlParser; parent: SqlNode) =
|
||||
if isKeyw(p, "create"):
|
||||
@@ -1161,14 +1176,6 @@ proc parseStmt(p: var SqlParser; parent: SqlNode) =
|
||||
else:
|
||||
sqlError(p, "SELECT, CREATE, UPDATE or DELETE expected")
|
||||
|
||||
proc open(p: var SqlParser, input: Stream, filename: string) =
|
||||
## opens the parser `p` and assigns the input stream `input` to it.
|
||||
## `filename` is only used for error messages.
|
||||
open(SqlLexer(p), input, filename)
|
||||
p.tok.kind = tkInvalid
|
||||
p.tok.literal = ""
|
||||
getTok(p)
|
||||
|
||||
proc parse(p: var SqlParser): SqlNode =
|
||||
## parses the content of `p`'s input stream and returns the SQL AST.
|
||||
## Syntax errors raise an `SqlParseError` exception.
|
||||
@@ -1183,24 +1190,6 @@ proc close(p: var SqlParser) =
|
||||
## closes the parser `p`. The associated input stream is closed too.
|
||||
close(SqlLexer(p))
|
||||
|
||||
proc parseSQL*(input: Stream, filename: string): SqlNode =
|
||||
## parses the SQL from `input` into an AST and returns the AST.
|
||||
## `filename` is only used for error messages.
|
||||
## Syntax errors raise an `SqlParseError` exception.
|
||||
var p: SqlParser
|
||||
open(p, input, filename)
|
||||
try:
|
||||
result = parse(p)
|
||||
finally:
|
||||
close(p)
|
||||
|
||||
proc parseSQL*(input: string, filename=""): SqlNode =
|
||||
## parses the SQL from `input` into an AST and returns the AST.
|
||||
## `filename` is only used for error messages.
|
||||
## Syntax errors raise an `SqlParseError` exception.
|
||||
parseSQL(newStringStream(input), "")
|
||||
|
||||
|
||||
type
|
||||
SqlWriter = object
|
||||
indent: int
|
||||
@@ -1218,12 +1207,12 @@ proc add(s: var SqlWriter, thing: string) =
|
||||
proc addKeyw(s: var SqlWriter, thing: string) =
|
||||
var keyw = thing
|
||||
if s.upperCase:
|
||||
keyw = keyw.toUpper()
|
||||
keyw = keyw.toUpperAscii()
|
||||
s.add(keyw)
|
||||
|
||||
proc addIden(s: var SqlWriter, thing: string) =
|
||||
var iden = thing
|
||||
if iden.toLower() in reservedKeywords:
|
||||
if iden.toLowerAscii() in reservedKeywords:
|
||||
iden = '"' & iden & '"'
|
||||
s.add(iden)
|
||||
|
||||
@@ -1251,15 +1240,20 @@ proc addMulti(s: var SqlWriter, n: SqlNode, sep = ',', prefix, suffix: char) =
|
||||
ra(n.sons[i], s)
|
||||
s.add(suffix)
|
||||
|
||||
proc quoted(s: string): string =
|
||||
"\"" & replace(s, "\"", "\"\"") & "\""
|
||||
|
||||
proc ra(n: SqlNode, s: var SqlWriter) =
|
||||
if n == nil: return
|
||||
case n.kind
|
||||
of nkNone: discard
|
||||
of nkIdent:
|
||||
if allCharsInSet(n.strVal, {'\33'..'\127'}) and n.strVal.toLower() notin reservedKeywords:
|
||||
if allCharsInSet(n.strVal, {'\33'..'\127'}):
|
||||
s.add(n.strVal)
|
||||
else:
|
||||
s.add("\"" & replace(n.strVal, "\"", "\"\"") & "\"")
|
||||
s.add(quoted(n.strVal))
|
||||
of nkQuotedIdent:
|
||||
s.add(quoted(n.strVal))
|
||||
of nkStringLit:
|
||||
s.add(escape(n.strVal, "'", "'"))
|
||||
of nkBitStringLit:
|
||||
@@ -1361,18 +1355,19 @@ proc ra(n: SqlNode, s: var SqlWriter) =
|
||||
s.addKeyw("select")
|
||||
if n.kind == nkSelectDistinct:
|
||||
s.addKeyw("distinct")
|
||||
s.addMulti(n.sons[0])
|
||||
for i in 1 .. n.len-1:
|
||||
for i in 0 ..< n.len:
|
||||
ra(n.sons[i], s)
|
||||
of nkSelectColumns:
|
||||
assert(false)
|
||||
for i, column in n.sons:
|
||||
if i > 0: s.add(',')
|
||||
ra(column, s)
|
||||
of nkSelectPair:
|
||||
ra(n.sons[0], s)
|
||||
if n.sons.len == 2:
|
||||
s.addKeyw("as")
|
||||
ra(n.sons[1], s)
|
||||
of nkFromItemPair:
|
||||
if n.sons[0].kind == nkIdent:
|
||||
if n.sons[0].kind in {nkIdent, nkQuotedIdent}:
|
||||
ra(n.sons[0], s)
|
||||
else:
|
||||
assert n.sons[0].kind == nkSelect
|
||||
@@ -1472,3 +1467,35 @@ proc renderSQL*(n: SqlNode, upperCase=false): string =
|
||||
proc `$`*(n: SqlNode): string =
|
||||
## an alias for `renderSQL`.
|
||||
renderSQL(n)
|
||||
|
||||
when not defined(js):
|
||||
import streams
|
||||
|
||||
proc open(L: var SqlLexer, input: Stream, filename: string) =
|
||||
lexbase.open(L, input)
|
||||
L.filename = filename
|
||||
|
||||
proc open(p: var SqlParser, input: Stream, filename: string) =
|
||||
## opens the parser `p` and assigns the input stream `input` to it.
|
||||
## `filename` is only used for error messages.
|
||||
open(SqlLexer(p), input, filename)
|
||||
p.tok.kind = tkInvalid
|
||||
p.tok.literal = ""
|
||||
getTok(p)
|
||||
|
||||
proc parseSQL*(input: Stream, filename: string): SqlNode =
|
||||
## parses the SQL from `input` into an AST and returns the AST.
|
||||
## `filename` is only used for error messages.
|
||||
## Syntax errors raise an `SqlParseError` exception.
|
||||
var p: SqlParser
|
||||
open(p, input, filename)
|
||||
try:
|
||||
result = parse(p)
|
||||
finally:
|
||||
close(p)
|
||||
|
||||
proc parseSQL*(input: string, filename=""): SqlNode =
|
||||
## parses the SQL from `input` into an AST and returns the AST.
|
||||
## `filename` is only used for error messages.
|
||||
## Syntax errors raise an `SqlParseError` exception.
|
||||
parseSQL(newStringStream(input), "")
|
||||
|
||||
@@ -26,10 +26,9 @@ doAssert $parseSQL("SELECT foo, bar, baz FROM table limit 10") == "select foo, b
|
||||
doAssert $parseSQL("SELECT foo AS bar FROM table") == "select foo as bar from table;"
|
||||
doAssert $parseSQL("SELECT foo AS foo_prime, bar AS bar_prime, baz AS baz_prime FROM table") == "select foo as foo_prime, bar as bar_prime, baz as baz_prime from table;"
|
||||
doAssert $parseSQL("SELECT * FROM table") == "select * from table;"
|
||||
|
||||
|
||||
#TODO add count(*)
|
||||
#doAssert $parseSQL("SELECT COUNT(*) FROM table"
|
||||
doAssert $parseSQL("SELECT count(*) FROM table") == "select count(*) from table;"
|
||||
doAssert $parseSQL("SELECT count(*) as 'Total' FROM table") == "select count(*) as 'Total' from table;"
|
||||
doAssert $parseSQL("SELECT count(*) as 'Total', sum(a) as 'Aggr' FROM table") == "select count(*) as 'Total', sum(a) as 'Aggr' from table;"
|
||||
|
||||
doAssert $parseSQL("""
|
||||
SELECT * FROM table
|
||||
@@ -50,6 +49,23 @@ WHERE
|
||||
a and not b
|
||||
""") == "select * from table where a and not b;"
|
||||
|
||||
doAssert $parseSQL("""
|
||||
SELECT * FROM table
|
||||
ORDER BY 1
|
||||
""") == "select * from table order by 1;"
|
||||
|
||||
doAssert $parseSQL("""
|
||||
SELECT * FROM table
|
||||
GROUP BY 1
|
||||
ORDER BY 1
|
||||
""") == "select * from table group by 1 order by 1;"
|
||||
|
||||
doAssert $parseSQL("""
|
||||
SELECT * FROM table
|
||||
ORDER BY 1
|
||||
LIMIT 100
|
||||
""") == "select * from table order by 1 limit 100;"
|
||||
|
||||
doAssert $parseSQL("""
|
||||
SELECT * FROM table
|
||||
WHERE a = b and c = d or n is null and not b + 1 = 3
|
||||
@@ -185,7 +201,10 @@ AND Country='USA'
|
||||
ORDER BY CustomerName;
|
||||
""") == "select * from Customers where(CustomerName like 'L%' or CustomerName like 'R%' or CustomerName like 'W%') and Country = 'USA' order by CustomerName;"
|
||||
|
||||
# parse keywords as identifiers
|
||||
# parse quoted keywords as identifiers
|
||||
doAssert $parseSQL("""
|
||||
SELECT `SELECT`, `FROM` as `GROUP` FROM `WHERE`;
|
||||
""") == """select "SELECT", "FROM" as "GROUP" from "WHERE";"""
|
||||
doAssert $parseSQL("""
|
||||
SELECT "SELECT", "FROM" as "GROUP" FROM "WHERE";
|
||||
""") == """select "SELECT", "FROM" as "GROUP" from "WHERE";"""
|
||||
|
||||
Reference in New Issue
Block a user