mirror of
https://github.com/rbock/sqlpp11.git
synced 2024-11-15 20:31:16 +08:00
Partial rewrite of ddl2cpp
Initially setting out to fix #418, this change - accepts more SQL expressions - uses slightly more idiomatic pyparsing, I believe - uses black formatter - comes with some unit tests for the parser - simplifies options Tested with all SQL files in the repo.
This commit is contained in:
parent
d8a76fa282
commit
5b3abca4b1
816
scripts/ddl2cpp
816
scripts/ddl2cpp
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
#!/usr/bin/env python3
|
||||
|
||||
##
|
||||
# Copyright (c) 2013-2015, Roland Bock
|
||||
# Copyright (c) 2013-2022, Roland Bock
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
@ -25,406 +25,674 @@
|
||||
# OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
##
|
||||
|
||||
from __future__ import print_function
|
||||
import pyparsing as pp
|
||||
import sys
|
||||
import re
|
||||
import os
|
||||
|
||||
|
||||
# error codes, we should refactor this later
|
||||
ERROR_BAD_ARGS = 1
|
||||
ERROR_DATA_TYPE = 10
|
||||
ERROR_STRANGE_PARSING = 20
|
||||
|
||||
# Rather crude SQL expression parser.
|
||||
# This is not geared at correctly interpreting SQL, but at identifying (and ignoring) expressions for instance in DEFAULT expressions
|
||||
ddlLeft, ddlRight = map(pp.Suppress, "()")
|
||||
ddlNumber = pp.Word(pp.nums + "+-.", pp.nums + "+-.Ee")
|
||||
ddlString = (
|
||||
pp.QuotedString("'") | pp.QuotedString('"', escQuote='""') | pp.QuotedString("`")
|
||||
)
|
||||
ddlTerm = pp.Word(pp.alphas + "_", pp.alphanums + "_.$")
|
||||
ddlName = pp.Or([ddlTerm, ddlString, pp.Combine(ddlString + "." + ddlString)])
|
||||
ddlOperator = pp.Or(
|
||||
map(pp.CaselessLiteral, ["+", "-", "*", "/", "<", "<=", ">", ">=", "=", "%"])
|
||||
)
|
||||
|
||||
from pyparsing import CaselessLiteral, Literal, SkipTo, restOfLine, oneOf, ZeroOrMore, Optional, Combine, Suppress, \
|
||||
WordStart, WordEnd, Word, alphas, alphanums, nums, QuotedString, nestedExpr, MatchFirst, OneOrMore, delimitedList, \
|
||||
Or, Group, ParseException
|
||||
ddlBracedExpression = pp.Forward()
|
||||
ddlFunctionCall = pp.Forward()
|
||||
ddlCast = ddlString + "::" + ddlTerm
|
||||
ddlExpression = pp.OneOrMore(
|
||||
ddlBracedExpression
|
||||
| ddlFunctionCall
|
||||
| ddlCast
|
||||
| ddlOperator
|
||||
| ddlString
|
||||
| ddlTerm
|
||||
| ddlNumber
|
||||
)
|
||||
|
||||
ddlBracedExpression << ddlLeft + ddlExpression + ddlRight
|
||||
|
||||
ddlArguments = pp.Suppress(pp.Group(pp.delimitedList(ddlExpression)))
|
||||
ddlFunctionCall << ddlName + ddlLeft + pp.Optional(ddlArguments) + ddlRight
|
||||
|
||||
# Column and constraint parsers
|
||||
ddlBooleanTypes = [
|
||||
"bool",
|
||||
"boolean",
|
||||
]
|
||||
|
||||
ddlBoolean = pp.Or(
|
||||
map(pp.CaselessLiteral, sorted(ddlBooleanTypes, reverse=True))
|
||||
).setParseAction(pp.replaceWith("boolean"))
|
||||
|
||||
ddlIntegerTypes = [
|
||||
"bigint",
|
||||
"int",
|
||||
"int2", # PostgreSQL
|
||||
"int4", # PostgreSQL
|
||||
"int8", # PostgreSQL
|
||||
"integer",
|
||||
"mediumint",
|
||||
"smallint",
|
||||
"tinyint",
|
||||
]
|
||||
ddlInteger = pp.Or(
|
||||
map(pp.CaselessLiteral, sorted(ddlIntegerTypes, reverse=True))
|
||||
).setParseAction(pp.replaceWith("integer"))
|
||||
|
||||
ddlSerialTypes = [
|
||||
"bigserial", # PostgreSQL
|
||||
"serial", # PostgreSQL
|
||||
"smallserial", # PostgreSQL
|
||||
]
|
||||
ddlSerial = (
|
||||
pp.Or(map(pp.CaselessLiteral, sorted(ddlSerialTypes, reverse=True)))
|
||||
.setParseAction(pp.replaceWith("integer"))
|
||||
.setResultsName("hasAutoValue")
|
||||
)
|
||||
|
||||
ddlFloatingPointTypes = [
|
||||
"decimal", # MYSQL
|
||||
"double",
|
||||
"float8", # PostgreSQL
|
||||
"float",
|
||||
"float4", # PostgreSQL
|
||||
"numeric", # PostgreSQL
|
||||
"real",
|
||||
]
|
||||
ddlFloatingPoint = pp.Or(
|
||||
map(pp.CaselessLiteral, sorted(ddlFloatingPointTypes, reverse=True))
|
||||
).setParseAction(pp.replaceWith("floating_point"))
|
||||
|
||||
ddlTextTypes = [
|
||||
"char",
|
||||
"varchar",
|
||||
"character varying", # PostgreSQL
|
||||
"text",
|
||||
"clob",
|
||||
"enum", # MYSQL
|
||||
"set",
|
||||
"longtext", # MYSQL
|
||||
"jsonb", # PostgreSQL
|
||||
"json", # PostgreSQL
|
||||
"tinytext", # MYSQL
|
||||
]
|
||||
|
||||
ddlText = pp.Or(
|
||||
map(pp.CaselessLiteral, sorted(ddlTextTypes, reverse=True))
|
||||
).setParseAction(pp.replaceWith("text"))
|
||||
|
||||
ddlBlobTypes = [
|
||||
"bytea",
|
||||
"tinyblob",
|
||||
"blob",
|
||||
"mediumblob",
|
||||
"longblob",
|
||||
"binary", # MYSQL
|
||||
"varbinary", # MYSQL
|
||||
]
|
||||
|
||||
ddlBlob = pp.Or(
|
||||
map(pp.CaselessLiteral, sorted(ddlBlobTypes, reverse=True))
|
||||
).setParseAction(pp.replaceWith("blob"))
|
||||
|
||||
|
||||
ddlDateTypes = [
|
||||
"date",
|
||||
]
|
||||
|
||||
ddlDate = (
|
||||
pp.Or(map(pp.CaselessLiteral, sorted(ddlDateTypes, reverse=True)))
|
||||
.setParseAction(pp.replaceWith("day_point"))
|
||||
.setResultsName("warnTimezone")
|
||||
)
|
||||
|
||||
ddlDateTimeTypes = [
|
||||
"datetime",
|
||||
"timestamp",
|
||||
"timestamp without time zone", # PostgreSQL
|
||||
"timestamp with time zone", # PostgreSQL
|
||||
"timestamptz", # PostgreSQL
|
||||
]
|
||||
|
||||
ddlDateTime = pp.Or(
|
||||
map(pp.CaselessLiteral, sorted(ddlDateTimeTypes, reverse=True))
|
||||
).setParseAction(pp.replaceWith("time_point"))
|
||||
|
||||
ddlTimeTypes = [
|
||||
"time",
|
||||
"time without time zone", # PostgreSQL
|
||||
"time with time zone", # PostgreSQL
|
||||
]
|
||||
|
||||
ddlTime = pp.Or(
|
||||
map(pp.CaselessLiteral, sorted(ddlTimeTypes, reverse=True))
|
||||
).setParseAction(pp.replaceWith("time_of_day"))
|
||||
|
||||
|
||||
ddlUnknown = pp.Word(pp.alphanums).setParseAction(pp.replaceWith("UNKNOWN"))
|
||||
|
||||
ddlType = (
|
||||
ddlBoolean
|
||||
| ddlInteger
|
||||
| ddlSerial
|
||||
| ddlFloatingPoint
|
||||
| ddlText
|
||||
| ddlBlob
|
||||
| ddlDateTime
|
||||
| ddlDate
|
||||
| ddlTime
|
||||
| ddlUnknown
|
||||
)
|
||||
|
||||
ddlUnsigned = pp.CaselessLiteral("UNSIGNED").setResultsName("isUnsigned")
|
||||
ddlWidth = ddlLeft + pp.Word(pp.nums) + ddlRight
|
||||
ddlTimezone = (
|
||||
(pp.CaselessLiteral("with") | pp.CaselessLiteral("without"))
|
||||
+ pp.CaselessLiteral("time")
|
||||
+ pp.CaselessLiteral("zone")
|
||||
)
|
||||
|
||||
ddlNotNull = pp.Group(
|
||||
pp.CaselessLiteral("NOT") + pp.CaselessLiteral("NULL")
|
||||
).setResultsName("notNull")
|
||||
ddlDefaultValue = pp.CaselessLiteral("DEFAULT").setResultsName("hasDefaultValue")
|
||||
|
||||
ddlAutoKeywords = [
|
||||
"AUTO_INCREMENT",
|
||||
"AUTOINCREMENT",
|
||||
"SMALLSERIAL",
|
||||
"SERIAL",
|
||||
"BIGSERIAL",
|
||||
]
|
||||
ddlAutoValue = pp.Or(map(pp.CaselessLiteral, sorted(ddlAutoKeywords, reverse=True)))
|
||||
|
||||
ddlColumn = pp.Group(
|
||||
ddlName("name")
|
||||
+ ddlType("type")
|
||||
+ pp.Suppress(pp.Optional(ddlWidth))
|
||||
+ pp.Suppress(pp.Optional(ddlTimezone))
|
||||
+ pp.ZeroOrMore(
|
||||
ddlUnsigned("isUnsigned")
|
||||
| ddlNotNull("notNull")
|
||||
| pp.CaselessLiteral("null")
|
||||
| ddlAutoValue("hasAutoValue")
|
||||
| ddlDefaultValue("hasDefaultValue")
|
||||
| pp.Suppress(ddlExpression)
|
||||
)
|
||||
)
|
||||
|
||||
ddlConstraintKeywords = [
|
||||
"CONSTRAINT",
|
||||
"PRIMARY",
|
||||
"FOREIGN",
|
||||
"KEY",
|
||||
"FULLTEXT",
|
||||
"INDEX",
|
||||
"UNIQUE",
|
||||
"CHECK",
|
||||
]
|
||||
ddlConstraint = pp.Group(
|
||||
pp.Or(map(pp.CaselessLiteral, sorted(ddlConstraintKeywords, reverse=True)))
|
||||
+ ddlExpression
|
||||
).setResultsName("isConstraint")
|
||||
|
||||
# CREATE TABLE parser
|
||||
ddlIfNotExists = pp.Group(
|
||||
pp.CaselessLiteral("IF") + pp.CaselessLiteral("NOT") + pp.CaselessLiteral("EXISTS")
|
||||
).setResultsName("ifNotExists")
|
||||
ddlCreateTable = pp.Group(
|
||||
pp.CaselessLiteral("CREATE")
|
||||
+ pp.CaselessLiteral("TABLE")
|
||||
+ pp.Suppress(pp.Optional(ddlIfNotExists))
|
||||
+ ddlName.setResultsName("tableName")
|
||||
+ ddlLeft
|
||||
+ pp.Group(pp.delimitedList(ddlColumn | pp.Suppress(ddlConstraint))).setResultsName(
|
||||
"columns"
|
||||
)
|
||||
+ ddlRight
|
||||
).setResultsName("create")
|
||||
# ddlString.setDebug(True) #uncomment to debug pyparsing
|
||||
|
||||
ddl = pp.OneOrMore(pp.Suppress(pp.SkipTo(ddlCreateTable, False)) + ddlCreateTable)
|
||||
|
||||
ddlComment = pp.oneOf(["--", "#"]) + pp.restOfLine
|
||||
ddl.ignore(ddlComment)
|
||||
|
||||
|
||||
def testBoolean():
|
||||
for t in ddlBooleanTypes:
|
||||
result = ddlType.parseString(t, parseAll=True)
|
||||
assert result[0] == "boolean"
|
||||
|
||||
|
||||
def testInteger():
|
||||
for t in ddlIntegerTypes:
|
||||
result = ddlType.parseString(t, parseAll=True)
|
||||
assert result[0] == "integer"
|
||||
|
||||
|
||||
def testSerial():
|
||||
for t in ddlSerialTypes:
|
||||
result = ddlType.parseString(t, parseAll=True)
|
||||
assert result[0] == "integer"
|
||||
assert result.hasAutoValue
|
||||
|
||||
|
||||
def testFloatingPoint():
|
||||
for t in ddlFloatingPointTypes:
|
||||
result = ddlType.parseString(t, parseAll=True)
|
||||
assert result[0] == "floating_point"
|
||||
|
||||
|
||||
def testText():
|
||||
for t in ddlTextTypes:
|
||||
result = ddlType.parseString(t, parseAll=True)
|
||||
assert result[0] == "text"
|
||||
|
||||
|
||||
def testBlob():
|
||||
for t in ddlBlobTypes:
|
||||
result = ddlType.parseString(t, parseAll=True)
|
||||
assert result[0] == "blob"
|
||||
|
||||
|
||||
def testDate():
|
||||
for t in ddlDateTypes:
|
||||
result = ddlType.parseString(t, parseAll=True)
|
||||
assert result[0] == "day_point"
|
||||
|
||||
|
||||
def testDateTime():
|
||||
for t in ddlDateTimeTypes:
|
||||
result = ddlType.parseString(t, parseAll=True)
|
||||
assert result[0] == "time_point"
|
||||
|
||||
|
||||
def testTime():
|
||||
for t in ddlTimeTypes:
|
||||
result = ddlType.parseString(t, parseAll=True)
|
||||
assert result[0] == "time_of_day"
|
||||
|
||||
|
||||
def testUnknown():
|
||||
for t in ["cheesecake", "blueberry"]:
|
||||
result = ddlType.parseString(t, parseAll=True)
|
||||
assert result[0] == "UNKNOWN"
|
||||
|
||||
|
||||
def testAutoValue():
|
||||
def test(s, expected):
|
||||
results = ddlAutoValue.parseString(s, parseAll=True)
|
||||
print(results)
|
||||
|
||||
|
||||
def testColumn():
|
||||
text = "\"id\" int(8) unsigned NOT NULL DEFAULT nextval('dk_id_seq'::regclass)"
|
||||
result = ddlColumn.parseString(text, parseAll=True)
|
||||
column = result[0]
|
||||
assert column.name == "id"
|
||||
assert column.type == "integer"
|
||||
assert column.isUnsigned
|
||||
assert column.notNull
|
||||
assert not column.hasAutoValue
|
||||
|
||||
|
||||
def testConstraint():
|
||||
for text in [
|
||||
"CONSTRAINT unique_person UNIQUE (first_name, last_name)",
|
||||
"UNIQUE (id)",
|
||||
]:
|
||||
result = ddlConstraint.parseString(text, parseAll=True)
|
||||
assert result.isConstraint
|
||||
|
||||
|
||||
def testTable():
|
||||
text = """
|
||||
CREATE TABLE "public"."dk" (
|
||||
"id" int8 NOT NULL DEFAULT nextval('dk_id_seq'::regclass),
|
||||
"last_update" timestamp(6) DEFAULT now(),
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
"""
|
||||
result = ddlCreateTable.parseString(text, parseAll=True)
|
||||
|
||||
|
||||
def testParser():
|
||||
testBoolean()
|
||||
testInteger()
|
||||
testSerial()
|
||||
testFloatingPoint()
|
||||
testText()
|
||||
testBlob()
|
||||
testDate()
|
||||
testTime()
|
||||
testUnknown()
|
||||
testDateTime()
|
||||
testColumn()
|
||||
testConstraint()
|
||||
testTable()
|
||||
|
||||
|
||||
# CODE GENERATOR
|
||||
# HELPERS
|
||||
|
||||
|
||||
def get_include_guard_name(namespace, inputfile):
|
||||
val = re.sub("[^A-Za-z0-9]+", "_", namespace + '_' + os.path.basename(inputfile))
|
||||
val = re.sub("[^A-Za-z0-9]+", "_", namespace + "_" + os.path.basename(inputfile))
|
||||
return val.upper()
|
||||
|
||||
|
||||
def identity_naming_func(s):
|
||||
return s
|
||||
|
||||
|
||||
def repl_camel_case_func(m):
|
||||
if m.group(1) == '_':
|
||||
if m.group(1) == "_":
|
||||
return m.group(2).upper()
|
||||
else:
|
||||
return m.group(1) + m.group(2).upper()
|
||||
|
||||
|
||||
def class_name_naming_func(s):
|
||||
s = s.replace('.', '_')
|
||||
s = s.replace(".", "_")
|
||||
return re.sub("(^|\s|[_0-9])(\S)", repl_camel_case_func, s)
|
||||
|
||||
|
||||
def member_name_naming_func(s):
|
||||
s = s.replace('.', '_')
|
||||
s = s.replace(".", "_")
|
||||
return re.sub("(\s|_|[0-9])(\S)", repl_camel_case_func, s)
|
||||
|
||||
toClassName = class_name_naming_func
|
||||
toMemberName = member_name_naming_func
|
||||
|
||||
def repl_func_for_args(m):
|
||||
if m.group(1) == '-':
|
||||
if m.group(1) == "-":
|
||||
return m.group(2).upper()
|
||||
else:
|
||||
return m.group(1) + m.group(2).upper()
|
||||
|
||||
|
||||
def setArgumentBool(s, bool_value):
|
||||
first_lower = lambda s: s[:1].lower() + s[1:] if s else '' # http://stackoverflow.com/a/3847369/5006740
|
||||
first_lower = (
|
||||
lambda s: s[:1].lower() + s[1:] if s else ""
|
||||
) # http://stackoverflow.com/a/3847369/5006740
|
||||
var_name = first_lower(re.sub("(\s|-|[0-9])(\S)", repl_func_for_args, s))
|
||||
globals()[var_name] = bool_value
|
||||
|
||||
|
||||
def usage(optionalArgs = {}):
|
||||
print('\
|
||||
Usage: ddl2cpp <path to ddl> <path to target (without extension, e.g. /tmp/MyTable)> <namespace>\n\
|
||||
ddl2cpp -help')
|
||||
def escape_if_reserved(name):
|
||||
reserved_names = [
|
||||
"BEGIN",
|
||||
"END",
|
||||
"GROUP",
|
||||
"ORDER",
|
||||
]
|
||||
if name.upper() in reserved_names:
|
||||
return "!{}".format(name)
|
||||
return name
|
||||
|
||||
def beginHeader(pathToHeader, nsList):
|
||||
header = open(pathToHeader, 'w')
|
||||
print('// generated by ' + ' '.join(sys.argv), file=header)
|
||||
print('#ifndef '+get_include_guard_name(namespace, pathToHeader), file=header)
|
||||
print('#define '+get_include_guard_name(namespace, pathToHeader), file=header)
|
||||
print('', file=header)
|
||||
print('#include <' + INCLUDE + '/table.h>', file=header)
|
||||
print('#include <' + INCLUDE + '/data_types.h>', file=header)
|
||||
print('#include <' + INCLUDE + '/char_sequence.h>', file=header)
|
||||
print('', file=header)
|
||||
|
||||
def beginHeader(pathToHeader, namespace, nsList):
|
||||
header = open(pathToHeader, "w")
|
||||
print("// generated by " + " ".join(sys.argv), file=header)
|
||||
print("#ifndef " + get_include_guard_name(namespace, pathToHeader), file=header)
|
||||
print("#define " + get_include_guard_name(namespace, pathToHeader), file=header)
|
||||
print("", file=header)
|
||||
print("#include <sqlpp11/table.h>", file=header)
|
||||
print("#include <sqlpp11/data_types.h>", file=header)
|
||||
print("#include <sqlpp11/char_sequence.h>", file=header)
|
||||
print("", file=header)
|
||||
for ns in nsList:
|
||||
print('namespace ' + ns, file=header)
|
||||
print('{', file=header)
|
||||
print("namespace " + namespace, file=header)
|
||||
print("{", file=header)
|
||||
return header
|
||||
|
||||
|
||||
def endHeader(header, nsList):
|
||||
for ns in nsList:
|
||||
print('} // namespace ' + ns, file=header)
|
||||
print('#endif', file=header)
|
||||
print("} // namespace " + ns, file=header)
|
||||
print("#endif", file=header)
|
||||
header.close()
|
||||
|
||||
|
||||
def help_message():
|
||||
arg_string = '\n'
|
||||
arg_string = ""
|
||||
pad = 0
|
||||
padding = 0
|
||||
for argument in list(optionalArgs.keys()):
|
||||
if len(argument) > pad:
|
||||
pad = len(argument)
|
||||
for argument in list(optionalArgs.keys()):
|
||||
if argument == '-help':
|
||||
continue
|
||||
if len(argument) < pad:
|
||||
padding = " " * (pad - len(argument))
|
||||
else:
|
||||
padding = ''
|
||||
arg_string = arg_string + ' [-[no]'+argument+']: ' + padding + optionalArgs[argument] + '\n'
|
||||
print('Usage: ddl2cpp [-help]\n\n OPTIONAL ARGUMENTS:\n' + arg_string +' \n \
|
||||
<path to ddl> <path to target> <namespace>\n\
|
||||
\n\
|
||||
<path to ddl> path to your SQL database/table definitions (SHOW CREATE TABLE SomeTable) \n\
|
||||
<path to target> path to a generated C++ header file. Without extension (no *.h). \n\
|
||||
<namespace> namespace you want. Usually a project/database name\n')
|
||||
padding = ""
|
||||
arg_string = (
|
||||
arg_string + argument + ": " + padding + optionalArgs[argument] + "\n"
|
||||
)
|
||||
print(
|
||||
"Usage:\n"
|
||||
"ddl2cpp [optional args] <path to ddl> <path to target> <namespace>\n\n"
|
||||
"OPTIONAL ARGUMENTS:\n" + arg_string + "\n"
|
||||
"<path to ddl> path to your SQL database/table definitions (SHOW CREATE TABLE SomeTable) \n"
|
||||
"<path to target> path to a generated C++ header file without extension (no *.h). \n"
|
||||
"<namespace> namespace you want. Usually a project/database name\n"
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
optionalArgs = {
|
||||
# if -some-key is present, it will set variable someKey to True
|
||||
# if -no-some-key is present, it will set variable someKey to False
|
||||
'-timestamp-warning': "show warning about mysql timestamp data type", # timeStampWarning = True
|
||||
# '-no-time-stamp-warning' # timeStampWarning = False
|
||||
'-fail-on-parse': "abort instead of silent genereation of unusable headers", # failOnParse = True
|
||||
'-warn-on-parse': "warn about unusable headers, but continue", # warnOnParse = True
|
||||
'-auto-id': "Assume column 'id' to have an automatic value as if AUTO_INCREMENT was specified (e.g. implicit for SQLite ROWID)", # autoId = True
|
||||
'-identity-naming': "Use table and column names from the ddl (defaults to UpperCamelCase for tables and lowerCamelCase for columns)", # identityNaming = True
|
||||
'-split-tables': "Make a header for each table name, using target as a directory", # splitTables = True
|
||||
'-help': "show this help"
|
||||
"-no-timestamp-warning": "show warning about date / time data types", # noTimeStampWarning = True
|
||||
"-auto-id": "Assume column 'id' to have an automatic value as if AUTO_INCREMENT was specified (e.g. implicit for SQLite ROWID)", # autoId = True
|
||||
"-identity-naming": "Use table and column names from the ddl (defaults to UpperCamelCase for tables and lowerCamelCase for columns)", # identityNaming = True
|
||||
"-split-tables": "Make a header for each table name, using target as a directory", # splitTables = True
|
||||
"--help": "show this help",
|
||||
"--test": "run parser self-test",
|
||||
}
|
||||
|
||||
if '-help' in sys.argv:
|
||||
help_message()
|
||||
# ARGUMENT PARSING
|
||||
if len(sys.argv) < (4):
|
||||
usage(optionalArgs)
|
||||
sys.exit(ERROR_BAD_ARGS)
|
||||
|
||||
firstPositional = 1
|
||||
timestampWarning = True
|
||||
failOnParse = False
|
||||
warnOnParse = False
|
||||
parseError = "Parsing error, possible reason: can't parse default value for a field"
|
||||
noTimestampWarning = False
|
||||
autoId = False
|
||||
identityNaming = False
|
||||
splitTables = False
|
||||
|
||||
|
||||
def createHeader():
|
||||
global noTimestampWarning
|
||||
# ARGUMENT PARSING
|
||||
if len(sys.argv) < (4):
|
||||
help_message()
|
||||
sys.exit(ERROR_BAD_ARGS)
|
||||
|
||||
firstPositional = 1
|
||||
if len(sys.argv) >= 4:
|
||||
for arg in sys.argv:
|
||||
noArg = arg.replace('-no-', '-')
|
||||
if arg in list(optionalArgs.keys()):
|
||||
setArgumentBool(arg, True)
|
||||
firstPositional += 1
|
||||
elif noArg in optionalArgs:
|
||||
setArgumentBool(noArg, False)
|
||||
firstPositional += 1
|
||||
else:
|
||||
pass
|
||||
|
||||
if identityNaming:
|
||||
toClassName = identity_naming_func
|
||||
toMemberName = identity_naming_func
|
||||
else:
|
||||
toClassName = class_name_naming_func
|
||||
toMemberName = member_name_naming_func
|
||||
|
||||
pathToDdl = sys.argv[firstPositional]
|
||||
|
||||
pathToHeader = sys.argv[firstPositional + 1] + ('/' if splitTables else '.h')
|
||||
pathToHeader = sys.argv[firstPositional + 1] + ("/" if splitTables else ".h")
|
||||
namespace = sys.argv[firstPositional + 2]
|
||||
|
||||
|
||||
INCLUDE = 'sqlpp11'
|
||||
NAMESPACE = 'sqlpp'
|
||||
|
||||
|
||||
|
||||
# PARSER
|
||||
def ddlWord(string):
|
||||
return WordStart(alphanums + "_") + CaselessLiteral(string) + WordEnd(alphanums + "_")
|
||||
|
||||
# This function should be refactored if we find some database function which needs parameters
|
||||
# Right now it works only for something like NOW() in MySQL default field value
|
||||
def ddlFunctionWord(string):
|
||||
return CaselessLiteral(string) + OneOrMore("(") + ZeroOrMore(" ") + OneOrMore(")")
|
||||
|
||||
ddlString = Or([QuotedString("'"), QuotedString("\"", escQuote='""'), QuotedString("`")])
|
||||
negativeSign = Literal('-')
|
||||
ddlNum = Combine(Optional(negativeSign) + Word(nums + "."))
|
||||
ddlTerm = Word(alphanums + "_$")
|
||||
ddlName = Or([ddlTerm, ddlString, Combine(ddlString + "." + ddlString)])
|
||||
ddlMathOp = Word("+><=-")
|
||||
ddlBoolean = Or([ddlWord("AND"), ddlWord("OR"), ddlWord("NOT")])
|
||||
ddlArguments = "(" + delimitedList(Or([ddlString, ddlTerm, ddlNum])) + ")"
|
||||
ddlMathCond = "(" + delimitedList(
|
||||
Or([
|
||||
Group(ddlName + ddlMathOp + ddlName),
|
||||
Group(ddlName + ddlWord("NOT") + ddlWord("NULL")),
|
||||
]),
|
||||
delim=ddlBoolean) + ")"
|
||||
|
||||
ddlUnsigned = ddlWord("unsigned").setResultsName("isUnsigned")
|
||||
ddlNotNull = Group(ddlWord("NOT") + ddlWord("NULL")).setResultsName("notNull")
|
||||
ddlDefaultValue = ddlWord("DEFAULT").setResultsName("hasDefaultValue")
|
||||
ddlAutoValue = Or([
|
||||
ddlWord("AUTO_INCREMENT"),
|
||||
ddlWord("AUTOINCREMENT"),
|
||||
ddlWord("SMALLSERIAL"),
|
||||
ddlWord("SERIAL"),
|
||||
ddlWord("BIGSERIAL"),
|
||||
]).setResultsName("hasAutoValue")
|
||||
ddlColumnComment = Group(ddlWord("COMMENT") + ddlString).setResultsName("comment")
|
||||
ddlConstraint = Or([
|
||||
ddlWord("CONSTRAINT"),
|
||||
ddlWord("PRIMARY"),
|
||||
ddlWord("FOREIGN"),
|
||||
ddlWord("KEY"),
|
||||
ddlWord("FULLTEXT"),
|
||||
ddlWord("INDEX"),
|
||||
ddlWord("UNIQUE"),
|
||||
ddlWord("CHECK")
|
||||
])
|
||||
ddlColumn = Group(Optional(ddlConstraint).setResultsName("isConstraint") + OneOrMore(MatchFirst([ddlUnsigned, ddlNotNull, ddlAutoValue, ddlDefaultValue, ddlFunctionWord("NOW"), ddlFunctionWord("current_timestamp"), ddlTerm, ddlNum, ddlColumnComment, ddlString, ddlArguments, ddlMathCond])))
|
||||
ddlIfNotExists = Optional(Group(ddlWord("IF") + ddlWord("NOT") + ddlWord("EXISTS")).setResultsName("ifNotExists"))
|
||||
createTable = Group(ddlWord("CREATE") + ddlWord("TABLE") + ddlIfNotExists + ddlName.setResultsName("tableName") + "(" + Group(delimitedList(ddlColumn)).setResultsName("columns") + ")").setResultsName("create")
|
||||
#ddlString.setDebug(True) #uncomment to debug pyparsing
|
||||
|
||||
ddl = ZeroOrMore(Suppress(SkipTo(createTable, False)) + createTable)
|
||||
|
||||
ddlComment = oneOf(["--", "#"]) + restOfLine
|
||||
ddl.ignore(ddlComment)
|
||||
|
||||
# MAP SQL TYPES
|
||||
types = {
|
||||
'tinyint': 'tinyint',
|
||||
'smallint': 'smallint',
|
||||
'smallserial': 'smallint', # PostgreSQL
|
||||
'int2': 'smallint', #PostgreSQL
|
||||
'integer': 'integer',
|
||||
'int': 'integer',
|
||||
'serial': 'integer', # PostgreSQL
|
||||
'int4': 'integer', #PostgreSQL
|
||||
'mediumint' : 'integer',
|
||||
'bigint': 'bigint',
|
||||
'bigserial': 'bigint', # PostgreSQL
|
||||
'int8': 'bigint', #PostgreSQL
|
||||
'char': 'char_',
|
||||
'varchar': 'varchar',
|
||||
'character varying': 'varchar', #PostgreSQL
|
||||
'text': 'text',
|
||||
'clob': 'text',
|
||||
'bytea': 'blob',
|
||||
'tinyblob': 'blob',
|
||||
'blob': 'blob',
|
||||
'mediumblob': 'blob',
|
||||
'longblob': 'blob',
|
||||
'bool': 'boolean',
|
||||
'boolean': 'boolean',
|
||||
'double': 'floating_point',
|
||||
'float8': 'floating_point', # PostgreSQL
|
||||
'float': 'floating_point',
|
||||
'float4': 'floating_point', # PostgreSQL
|
||||
'real': 'floating_point',
|
||||
'numeric': 'floating_point', # PostgreSQL
|
||||
'decimal' : 'floating_point', # MYSQL
|
||||
'date': 'day_point',
|
||||
'datetime': 'time_point',
|
||||
'time': 'time_of_day',
|
||||
'time without time zone': 'time_point', # PostgreSQL
|
||||
'time with time zone': 'time_point', # PostgreSQL
|
||||
'timestamp': 'time_point',
|
||||
'timestamp without time zone': 'time_point', # PostgreSQL
|
||||
'timestamp with time zone': 'time_point', # PostgreSQL
|
||||
'timestamptz': 'time_point', # PostgreSQL
|
||||
'enum': 'text', # MYSQL
|
||||
'set': 'text', # MYSQL,
|
||||
'longtext' : 'text', #MYSQL
|
||||
'json' : 'text', # PostgreSQL
|
||||
'jsonb' : 'text', # PostgreSQL
|
||||
'tinyint unsigned': 'tinyint_unsigned', #MYSQL
|
||||
'smallint unsigned': 'smallint_unsigned', #MYSQL
|
||||
'integer unsigned': 'integer_unsigned', #MYSQL
|
||||
'int unsigned': 'integer_unsigned', #MYSQL
|
||||
'bigint unsigned': 'bigint_unsigned', #MYSQL
|
||||
'mediumint unsigned' : 'integer', #MYSQL
|
||||
'tinytext' : 'text', #MYSQL
|
||||
'binary' : 'blob', #MYSQL
|
||||
'varbinary' : 'blob', #MYSQL
|
||||
'float unsigned' : 'floating_point', #MYSQL
|
||||
}
|
||||
|
||||
if failOnParse:
|
||||
ddl = OneOrMore(Suppress(SkipTo(createTable, False)) + createTable)
|
||||
ddl.ignore(ddlComment)
|
||||
try:
|
||||
tableCreations = ddl.parseFile(pathToDdl)
|
||||
except ParseException as e:
|
||||
print(parseError + '. Exiting [-no-fail-on-parse]')
|
||||
except pp.ParseException as e:
|
||||
print("ERROR: Could not parse any CREATE TABLE statement in " + pathToDdl)
|
||||
# print(pp.parseError)
|
||||
sys.exit(ERROR_STRANGE_PARSING)
|
||||
else:
|
||||
ddl = ZeroOrMore(Suppress(SkipTo(createTable, False)) + createTable)
|
||||
ddl.ignore(ddlComment)
|
||||
tableCreations = ddl.parseFile(pathToDdl)
|
||||
|
||||
if warnOnParse:
|
||||
print(parseError + '. Continuing [-no-warn-on-parse]')
|
||||
|
||||
nsList = namespace.split('::')
|
||||
|
||||
def escape_if_reserved(name):
|
||||
reserved_names = [
|
||||
'BEGIN',
|
||||
'END',
|
||||
'GROUP',
|
||||
'ORDER',
|
||||
]
|
||||
if name.upper() in reserved_names:
|
||||
return '!{}'.format(name)
|
||||
return name
|
||||
|
||||
nsList = namespace.split("::")
|
||||
|
||||
# PROCESS DDL
|
||||
tableCreations = ddl.parseFile(pathToDdl)
|
||||
|
||||
header = 0
|
||||
if not splitTables:
|
||||
header = beginHeader(pathToHeader, nsList)
|
||||
header = beginHeader(pathToHeader, namespace, nsList)
|
||||
DataTypeError = False
|
||||
for create in tableCreations:
|
||||
sqlTableName = create.tableName
|
||||
if splitTables:
|
||||
header = beginHeader(pathToHeader + sqlTableName + '.h', nsList)
|
||||
header = beginHeader(pathToHeader + sqlTableName + ".h", namespace, nsList)
|
||||
tableClass = toClassName(sqlTableName)
|
||||
tableMember = toMemberName(sqlTableName)
|
||||
tableNamespace = tableClass + '_'
|
||||
tableNamespace = tableClass + "_"
|
||||
tableTemplateParameters = tableClass
|
||||
print(' namespace ' + tableNamespace, file=header)
|
||||
print(' {', file=header)
|
||||
print(" namespace " + tableNamespace, file=header)
|
||||
print(" {", file=header)
|
||||
for column in create.columns:
|
||||
if column.isConstraint:
|
||||
continue
|
||||
sqlColumnName = column[0]
|
||||
sqlColumnName = column.name
|
||||
columnClass = toClassName(sqlColumnName)
|
||||
tableTemplateParameters += ',\n ' + tableNamespace + '::' + columnClass
|
||||
tableTemplateParameters += (
|
||||
",\n " + tableNamespace + "::" + columnClass
|
||||
)
|
||||
columnMember = toMemberName(sqlColumnName)
|
||||
sqlColumnType = column[1].lower()
|
||||
if column.isUnsigned:
|
||||
sqlColumnType = sqlColumnType + ' unsigned';
|
||||
if sqlColumnType == 'timestamp' and timestampWarning:
|
||||
print("Warning: timestamp is mapped to sqlpp::time_point like datetime")
|
||||
print("Warning: You have to take care of timezones yourself")
|
||||
print("You can disable this warning using -no-timestamp-warning")
|
||||
columnCanBeNull = not column.notNull
|
||||
print(' struct ' + columnClass, file=header)
|
||||
print(' {', file=header)
|
||||
print(' struct _alias_t', file=header)
|
||||
print(' {', file=header)
|
||||
print(' static constexpr const char _literal[] = "' + escape_if_reserved(sqlColumnName) + '";', file=header)
|
||||
print(' using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;', file=header)
|
||||
print(' template<typename T>', file=header)
|
||||
print(' struct _member_t', file=header)
|
||||
print(' {', file=header)
|
||||
print(' T ' + columnMember + ';', file=header)
|
||||
print(' T& operator()() { return ' + columnMember + '; }', file=header)
|
||||
print(' const T& operator()() const { return ' + columnMember + '; }', file=header)
|
||||
print(' };', file=header)
|
||||
print(' };', file=header)
|
||||
try:
|
||||
traitslist = [NAMESPACE + '::' + types[sqlColumnType]]
|
||||
except KeyError as e:
|
||||
print ('Error: datatype "' + sqlColumnType + '"" is not supported.')
|
||||
columnType = column.type
|
||||
if columnType == "UNKNOWN":
|
||||
print(
|
||||
"Error: datatype of %s.%s is not supported."
|
||||
% (sqlTableName, sqlColumnName)
|
||||
)
|
||||
DataTypeError = True
|
||||
if columnType == "integer" and column.isUnsigned:
|
||||
columnType = columnType + "_unsigned"
|
||||
if columnType == "time_point" and not noTimestampWarning:
|
||||
print(
|
||||
"Warning: date and time values are assumed to be without timezone."
|
||||
)
|
||||
print(
|
||||
"Warning: If you are using types WITH timezones, your code has to deal with that."
|
||||
)
|
||||
print("You can disable this warning using -no-timestamp-warning")
|
||||
noTimestampWarning = True
|
||||
traitslist = ["sqlpp::" + columnType]
|
||||
columnCanBeNull = not column.notNull
|
||||
print(" struct " + columnClass, file=header)
|
||||
print(" {", file=header)
|
||||
print(" struct _alias_t", file=header)
|
||||
print(" {", file=header)
|
||||
print(
|
||||
' static constexpr const char _literal[] = "'
|
||||
+ escape_if_reserved(sqlColumnName)
|
||||
+ '";',
|
||||
file=header,
|
||||
)
|
||||
print(
|
||||
" using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;",
|
||||
file=header,
|
||||
)
|
||||
print(" template<typename T>", file=header)
|
||||
print(" struct _member_t", file=header)
|
||||
print(" {", file=header)
|
||||
print(" T " + columnMember + ";", file=header)
|
||||
print(
|
||||
" T& operator()() { return " + columnMember + "; }",
|
||||
file=header,
|
||||
)
|
||||
print(
|
||||
" const T& operator()() const { return "
|
||||
+ columnMember
|
||||
+ "; }",
|
||||
file=header,
|
||||
)
|
||||
print(" };", file=header)
|
||||
print(" };", file=header)
|
||||
requireInsert = True
|
||||
hasAutoValue = column.hasAutoValue or (autoId and sqlColumnName == 'id')
|
||||
hasAutoValue = column.hasAutoValue or (autoId and sqlColumnName == "id")
|
||||
if hasAutoValue:
|
||||
traitslist.append(NAMESPACE + '::tag::must_not_insert')
|
||||
traitslist.append(NAMESPACE + '::tag::must_not_update')
|
||||
traitslist.append("sqlpp::tag::must_not_insert")
|
||||
traitslist.append("sqlpp::tag::must_not_update")
|
||||
requireInsert = False
|
||||
if not column.notNull:
|
||||
traitslist.append(NAMESPACE + '::tag::can_be_null')
|
||||
traitslist.append("sqlpp::tag::can_be_null")
|
||||
requireInsert = False
|
||||
if column.hasDefaultValue:
|
||||
requireInsert = False
|
||||
if requireInsert:
|
||||
traitslist.append(NAMESPACE + '::tag::require_insert')
|
||||
print(' using _traits = ' + NAMESPACE + '::make_traits<' + ', '.join(traitslist) + '>;', file=header)
|
||||
print(' };', file=header)
|
||||
print(' } // namespace ' + tableNamespace, file=header)
|
||||
print('', file=header)
|
||||
traitslist.append("sqlpp::tag::require_insert")
|
||||
print(
|
||||
" using _traits = sqlpp::make_traits<"
|
||||
+ ", ".join(traitslist)
|
||||
+ ">;",
|
||||
file=header,
|
||||
)
|
||||
print(" };", file=header)
|
||||
print(" } // namespace " + tableNamespace, file=header)
|
||||
print("", file=header)
|
||||
|
||||
print(' struct ' + tableClass + ': ' + NAMESPACE + '::table_t<' + tableTemplateParameters + '>', file=header)
|
||||
print(' {', file=header)
|
||||
print(' struct _alias_t', file=header)
|
||||
print(' {', file=header)
|
||||
print(' static constexpr const char _literal[] = "' + sqlTableName + '";', file=header)
|
||||
print(' using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;', file=header)
|
||||
print(' template<typename T>', file=header)
|
||||
print(' struct _member_t', file=header)
|
||||
print(' {', file=header)
|
||||
print(' T ' + tableMember + ';', file=header)
|
||||
print(' T& operator()() { return ' + tableMember + '; }', file=header)
|
||||
print(' const T& operator()() const { return ' + tableMember + '; }', file=header)
|
||||
print(' };', file=header)
|
||||
print(' };', file=header)
|
||||
print(' };', file=header)
|
||||
print(
|
||||
" struct "
|
||||
+ tableClass
|
||||
+ ": sqlpp::table_t<"
|
||||
+ tableTemplateParameters
|
||||
+ ">",
|
||||
file=header,
|
||||
)
|
||||
print(" {", file=header)
|
||||
print(" struct _alias_t", file=header)
|
||||
print(" {", file=header)
|
||||
print(
|
||||
' static constexpr const char _literal[] = "' + sqlTableName + '";',
|
||||
file=header,
|
||||
)
|
||||
print(
|
||||
" using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;",
|
||||
file=header,
|
||||
)
|
||||
print(" template<typename T>", file=header)
|
||||
print(" struct _member_t", file=header)
|
||||
print(" {", file=header)
|
||||
print(" T " + tableMember + ";", file=header)
|
||||
print(" T& operator()() { return " + tableMember + "; }", file=header)
|
||||
print(
|
||||
" const T& operator()() const { return " + tableMember + "; }",
|
||||
file=header,
|
||||
)
|
||||
print(" };", file=header)
|
||||
print(" };", file=header)
|
||||
print(" };", file=header)
|
||||
if splitTables:
|
||||
endHeader(header, nsList)
|
||||
|
||||
if not splitTables:
|
||||
endHeader(header, nsList)
|
||||
if (DataTypeError):
|
||||
if DataTypeError:
|
||||
print("Error: unsupported datatypes.")
|
||||
print("Possible solutions:")
|
||||
print("A) Implement this datatype (examples: sqlpp11/data_types)")
|
||||
print("B) Extend/upgrade ddl2cpp (edit types map)")
|
||||
print("C) Raise an issue on github")
|
||||
sys.exit(10) # return non-zero error code, we might need it for automation
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if "--help" in sys.argv:
|
||||
help_message()
|
||||
sys.exit()
|
||||
elif "--test" in sys.argv:
|
||||
testParser()
|
||||
sys.exit()
|
||||
else:
|
||||
createHeader()
|
||||
|
Loading…
Reference in New Issue
Block a user