mirror of
https://github.com/rbock/sqlpp11.git
synced 2024-11-16 04:47:18 +08:00
796 lines
25 KiB
Python
Executable File
796 lines
25 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
##
|
|
# Copyright (c) 2013-2022, Roland Bock
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without modification,
|
|
# are permitted provided that the following conditions are met:
|
|
#
|
|
# * Redistributions of source code must retain the above copyright notice,
|
|
# this list of conditions and the following disclaimer.
|
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
|
|
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
|
|
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
|
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
|
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
|
|
# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
|
|
# OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
##
|
|
|
|
import pyparsing as pp
|
|
import sys
|
|
import re
|
|
import os
|
|
|
|
# Process exit statuses handed to sys.exit(); we should refactor this later.
ERROR_BAD_ARGS = 1  # the command line could not be interpreted
ERROR_DATA_TYPE = 10  # a column uses a data type the generator does not know
ERROR_STRANGE_PARSING = 20  # the DDL file could not be parsed
|
|
|
|
# Rather crude SQL expression parser.
# This is not geared at correctly interpreting SQL, but at identifying (and
# ignoring) expressions, for instance in DEFAULT expressions.
ddlLeft, ddlRight = map(pp.Suppress, "()")
ddlNumber = pp.Word(pp.nums + "+-.", pp.nums + "+-.Ee")
ddlString = (
    pp.QuotedString("'") | pp.QuotedString('"', escQuote='""') | pp.QuotedString("`")
)
ddlTerm = pp.Word(pp.alphas + "_", pp.alphanums + "_.$")
ddlName = pp.Or([ddlTerm, ddlString, pp.Combine(ddlString + "." + ddlString), pp.Combine(ddlTerm + ddlString)])
# pp.Or takes ONE iterable of alternatives; its second positional parameter is
# `savelist`. The DIV keyword therefore has to be part of the same list,
# otherwise it is silently dropped from the set of operators.
ddlOperator = pp.Or(
    [pp.CaselessLiteral(op) for op in ["+", "-", "*", "/", "<", "<=", ">", ">=", "=", "%"]]
    + [pp.CaselessKeyword("DIV")]
)

# Forward declarations: the expression grammar is recursive.
ddlBracedExpression = pp.Forward()
ddlFunctionCall = pp.Forward()
ddlCastEnd = "::" + ddlTerm
ddlCast = ddlString + ddlCastEnd
ddlBracedArguments = pp.Forward()
ddlExpression = pp.OneOrMore(
    ddlBracedExpression
    | ddlFunctionCall
    | ddlCastEnd
    | ddlCast
    | ddlOperator
    | ddlString
    | ddlTerm
    | ddlNumber
    | ddlBracedArguments
)

# The parentheses only spell out what operator precedence already does
# ("+" binds tighter than "<<").
ddlBracedArguments << (ddlLeft + pp.delimitedList(ddlExpression) + ddlRight)
ddlBracedExpression << (ddlLeft + ddlExpression + ddlRight)

ddlArguments = pp.Suppress(pp.Group(pp.delimitedList(ddlExpression)))
ddlFunctionCall << (ddlName + ddlLeft + pp.Optional(ddlArguments) + ddlRight)
|
|
|
|
# Data types
#
# Each list holds the SQL type names that initDllParser() folds into one
# sqlpp11 value type. The lists are module-level and mutable on purpose:
# loadExtendedTypesFile() appends user-defined names to them.

# -> "boolean"
ddlBooleanTypes = [
    "bool",
    "boolean",
]

# -> "integer"
ddlIntegerTypes = [
    "bigint",
    "int",
    "int2",  # PostgreSQL
    "int4",  # PostgreSQL
    "int8",  # PostgreSQL
    "integer",
    "mediumint",
    "smallint",
    "tinyint",
]

# -> "integer", additionally flagged as having an automatic value
ddlSerialTypes = [
    "bigserial",  # PostgreSQL
    "serial",  # PostgreSQL
    "serial2",  # PostgreSQL
    "serial4",  # PostgreSQL
    "serial8",  # PostgreSQL
    "smallserial",  # PostgreSQL
]

# -> "floating_point"
ddlFloatingPointTypes = [
    "decimal",  # MYSQL
    "double",
    "float8",  # PostgreSQL
    "float",
    "float4",  # PostgreSQL
    "numeric",  # PostgreSQL
    "real",
]

# -> "text"
ddlTextTypes = [
    "char",
    "varchar",
    "character varying",  # PostgreSQL
    "text",
    "clob",
    "enum",  # MYSQL
    "set",
    "longtext",  # MYSQL
    "jsonb",  # PostgreSQL
    "json",  # PostgreSQL
    "tinytext",  # MYSQL
    "mediumtext",  # MYSQL
    "rational",  # PostgreSQL pg_rationale extension
]

# -> "blob"
ddlBlobTypes = [
    "bytea",
    "tinyblob",
    "blob",
    "mediumblob",
    "longblob",
    "binary",  # MYSQL
    "varbinary",  # MYSQL
]

# -> "day_point"
ddlDateTypes = [
    "date",
]

# -> "time_point"
ddlDateTimeTypes = [
    "datetime",
    "timestamp",
    "timestamp without time zone",  # PostgreSQL
    "timestamp with time zone",  # PostgreSQL
    "timestamptz",  # PostgreSQL
]

# -> "time_of_day"
ddlTimeTypes = [
    "time",
    "time without time zone",  # PostgreSQL
    "time with time zone",  # PostgreSQL
]
|
|
|
|
# Init the DDL parser
def initDllParser():
    """Build the CREATE TABLE grammar and publish it through module globals.

    The grammar is assembled at call time (not import time) from the
    module-level ddl*Types lists, so names appended to those lists before
    this call become part of the grammar.

    Sets the globals ddl, ddlType, ddlColumn, ddlConstraint and
    ddlCreateTable.
    """
    global ddl
    global ddlType
    global ddlColumn
    global ddlConstraint
    global ddlCreateTable

    def anyKeyword(words):
        # One caseless keyword alternative per word, in reverse lexical order.
        return pp.Or([pp.CaselessKeyword(word) for word in sorted(words, reverse=True)])

    # Column and constraint parsers

    # Every family of type names is replaced by a single normalised type name.
    ddlBoolean = anyKeyword(ddlBooleanTypes).setParseAction(pp.replaceWith("boolean"))
    ddlInteger = anyKeyword(ddlIntegerTypes).setParseAction(pp.replaceWith("integer"))
    ddlSerial = (
        anyKeyword(ddlSerialTypes)
        .setParseAction(pp.replaceWith("integer"))
        .setResultsName("hasAutoValue")
    )
    ddlFloatingPoint = anyKeyword(ddlFloatingPointTypes).setParseAction(
        pp.replaceWith("floating_point")
    )
    ddlText = anyKeyword(ddlTextTypes).setParseAction(pp.replaceWith("text"))
    ddlBlob = anyKeyword(ddlBlobTypes).setParseAction(pp.replaceWith("blob"))
    ddlDate = (
        anyKeyword(ddlDateTypes)
        .setParseAction(pp.replaceWith("day_point"))
        .setResultsName("warnTimezone")
    )
    ddlDateTime = anyKeyword(ddlDateTimeTypes).setParseAction(pp.replaceWith("time_point"))
    ddlTime = anyKeyword(ddlTimeTypes).setParseAction(pp.replaceWith("time_of_day"))

    # Fallback: any other alphanumeric word is reported as "UNKNOWN".
    ddlUnknown = pp.Word(pp.alphanums).setParseAction(pp.replaceWith("UNKNOWN"))

    ddlType = (
        ddlBoolean
        | ddlInteger
        | ddlSerial
        | ddlFloatingPoint
        | ddlText
        | ddlBlob
        | ddlDateTime
        | ddlDate
        | ddlTime
        | ddlUnknown
    )

    ddlUnsigned = pp.CaselessKeyword("UNSIGNED").setResultsName("isUnsigned")
    ddlDigits = "," + pp.Word(pp.nums)
    ddlWidth = ddlLeft + pp.Word(pp.nums) + pp.Optional(ddlDigits) + ddlRight
    ddlTimezone = (
        (pp.CaselessKeyword("with") | pp.CaselessKeyword("without"))
        + pp.CaselessKeyword("time")
        + pp.CaselessKeyword("zone")
    )

    ddlNotNull = pp.Group(
        pp.CaselessKeyword("NOT") + pp.CaselessKeyword("NULL")
    ).setResultsName("notNull")
    ddlDefaultValue = pp.CaselessKeyword("DEFAULT").setResultsName("hasDefaultValue")

    autoValueKeywords = [
        "AUTO_INCREMENT",
        "AUTOINCREMENT",
        "SMALLSERIAL",
        "SERIAL",
        "SERIAL2",
        "SERIAL4",
        "SERIAL8",
        "BIGSERIAL",
        "GENERATED",
    ]
    ddlAutoValue = anyKeyword(autoValueKeywords)

    ddlPrimaryKey = pp.Group(
        pp.CaselessKeyword("PRIMARY") + pp.CaselessKeyword("KEY")
    ).setResultsName("isPrimaryKey")

    ignoredKeywords = [
        "CONSTRAINT",
        "FOREIGN",
        "KEY",
        "FULLTEXT",
        "INDEX",
        "UNIQUE",
        "CHECK",
        "PERIOD",
    ]
    # A table-level constraint: one of the keywords followed by an expression.
    ddlConstraint = pp.Group(
        anyKeyword(ignoredKeywords + ["PRIMARY"]) + ddlExpression
    ).setResultsName("isConstraint")

    # A column: name, type, optional width / time zone, then any attributes.
    columnAttribute = (
        ddlUnsigned("isUnsigned")
        | ddlNotNull("notNull")
        | pp.CaselessKeyword("null")
        | ddlAutoValue("hasAutoValue")
        | ddlDefaultValue("hasDefaultValue")
        | ddlPrimaryKey("isPrimaryKey")
        | pp.Suppress(pp.OneOrMore(anyKeyword(ignoredKeywords)))
        | pp.Suppress(ddlExpression)
    )
    ddlColumn = pp.Group(
        ddlName("name")
        + ddlType("type")
        + pp.Suppress(pp.Optional(ddlWidth))
        + pp.Suppress(pp.Optional(ddlTimezone))
        + pp.ZeroOrMore(columnAttribute)
    )

    # CREATE TABLE parser
    ddlIfNotExists = pp.Group(
        pp.CaselessKeyword("IF") + pp.CaselessKeyword("NOT") + pp.CaselessKeyword("EXISTS")
    ).setResultsName("ifNotExists")
    ddlOrReplace = pp.Group(
        pp.CaselessKeyword("OR") + pp.CaselessKeyword("REPLACE")
    ).setResultsName("orReplace")
    tableBody = pp.Group(
        pp.delimitedList(pp.Suppress(ddlConstraint) | ddlColumn)
    ).setResultsName("columns")
    ddlCreateTable = pp.Group(
        pp.CaselessKeyword("CREATE")
        + pp.Suppress(pp.Optional(ddlOrReplace))
        + pp.CaselessKeyword("TABLE")
        + pp.Suppress(pp.Optional(ddlIfNotExists))
        + ddlName.setResultsName("tableName")
        + ddlLeft
        + tableBody
        + ddlRight
    ).setResultsName("create")
    # ddlString.setDebug(True) #uncomment to debug pyparsing

    # Skip whatever precedes each CREATE TABLE statement.
    ddl = pp.OneOrMore(pp.Suppress(pp.SkipTo(ddlCreateTable, False)) + ddlCreateTable)

    # "--" and "#" start a comment that runs to the end of the line.
    ddlComment = pp.oneOf(["--", "#"]) + pp.restOfLine
    ddl.ignore(ddlComment)
|
|
|
|
def testBoolean():
    """Check that every boolean type name parses to "boolean"."""
    for typeName in ddlBooleanTypes:
        parsed = ddlType.parseString(typeName, parseAll=True)
        assert parsed[0] == "boolean"
|
|
|
|
|
|
def testInteger():
    """Check that every integer type name parses to "integer"."""
    for typeName in ddlIntegerTypes:
        parsed = ddlType.parseString(typeName, parseAll=True)
        assert parsed[0] == "integer"
|
|
|
|
|
|
def testSerial():
    """Check that serial type names parse to "integer" and set hasAutoValue."""
    for typeName in ddlSerialTypes:
        parsed = ddlType.parseString(typeName, parseAll=True)
        assert parsed[0] == "integer"
        assert parsed.hasAutoValue
|
|
|
|
|
|
def testFloatingPoint():
    """Check that every floating point type name parses to "floating_point"."""
    for typeName in ddlFloatingPointTypes:
        parsed = ddlType.parseString(typeName, parseAll=True)
        assert parsed[0] == "floating_point"
|
|
|
|
|
|
def testText():
    """Check that every text type name parses to "text"."""
    for typeName in ddlTextTypes:
        parsed = ddlType.parseString(typeName, parseAll=True)
        assert parsed[0] == "text"
|
|
|
|
|
|
def testBlob():
    """Check that every blob type name parses to "blob"."""
    for typeName in ddlBlobTypes:
        parsed = ddlType.parseString(typeName, parseAll=True)
        assert parsed[0] == "blob"
|
|
|
|
|
|
def testDate():
    """Check that every date type name parses to "day_point"."""
    for typeName in ddlDateTypes:
        parsed = ddlType.parseString(typeName, parseAll=True)
        assert parsed[0] == "day_point"
|
|
|
|
|
|
def testDateTime():
    """Check that every date-time type name parses to "time_point"."""
    for typeName in ddlDateTimeTypes:
        parsed = ddlType.parseString(typeName, parseAll=True)
        assert parsed[0] == "time_point"
|
|
|
|
|
|
def testTime():
    """Check that every time type name parses to "time_of_day"."""
    for typeName in ddlTimeTypes:
        parsed = ddlType.parseString(typeName, parseAll=True)
        assert parsed[0] == "time_of_day"
|
|
|
|
|
|
def testUnknown():
    """Check that unrecognised type names parse to "UNKNOWN"."""
    for typeName in ("cheesecake", "blueberry"):
        parsed = ddlType.parseString(typeName, parseAll=True)
        assert parsed[0] == "UNKNOWN"
|
|
|
|
|
|
def testAutoValue():
    """Placeholder for an auto-value keyword test; currently checks nothing.

    NOTE(review): the inner helper is defined but never invoked, and
    testParser() does not call this function either.
    """

    def test(s, expected):
        # Parses `s` and prints the result; `expected` is not compared.
        parsed = ddlAutoValue.parseString(s, parseAll=True)
        print(parsed)
|
|
|
|
|
|
def testColumn():
    """Parse one PostgreSQL-style column definition and check its attributes."""
    definition = "\"id\" int(8) unsigned NOT NULL DEFAULT nextval('dk_id_seq'::regclass)"
    parsed = ddlColumn.parseString(definition, parseAll=True)
    col = parsed[0]
    assert col.name == "id"
    assert col.type == "integer"
    assert col.isUnsigned
    assert col.notNull
    assert not col.hasAutoValue
|
|
|
|
|
|
def testConstraint():
    """Check that table-level constraints are recognised as constraints."""
    samples = (
        "CONSTRAINT unique_person UNIQUE (first_name, last_name)",
        "UNIQUE (id)",
        "UNIQUE (first_name,last_name)",
    )
    for sample in samples:
        parsed = ddlConstraint.parseString(sample, parseAll=True)
        assert parsed.isConstraint
|
|
|
|
def testMathExpression():
    """Check that "2 DIV 2" is split into exactly three expression tokens."""
    parsed = ddlExpression.parseString("2 DIV 2", parseAll=True)
    assert len(parsed) == 3
    left, operator, right = parsed[0], parsed[1], parsed[2]
    assert left == "2"
    assert operator == "DIV"
    assert right == "2"
|
|
|
|
|
|
def testRational():
    """Check that a RATIONAL column is parsed as a NOT NULL "text" column."""
    samples = (
        "pos RATIONAL NOT NULL DEFAULT nextval('rational_seq')::integer",
    )
    for sample in samples:
        parsed = ddlColumn.parseString(sample, parseAll=True)
        col = parsed[0]
        assert col.name == "pos"
        assert col.type == "text"
        assert col.notNull
|
|
|
|
|
|
def testTable():
    """Check that a complete CREATE TABLE statement parses without error.

    Nothing is asserted about the result; a ParseException is the failure
    signal.
    """
    statement = """
CREATE TABLE "public"."dk" (
"id" int8 NOT NULL DEFAULT nextval('dk_id_seq'::regclass),
"strange" NUMERIC(314, 15),
"last_update" timestamp(6) DEFAULT now(),
PRIMARY KEY (id)
)
"""
    result = ddlCreateTable.parseString(statement, parseAll=True)
|
|
|
|
def testPrimaryKeyAutoIncrement():
    """Check auto-increment primary keys in the MySQL and SQLite spellings."""
    statements = (
        "CREATE TABLE tab (col INTEGER NOT NULL AUTO_INCREMENT PRIMARY KEY)",  # mysql
        "CREATE TABLE tab (col INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT)",  # mysql
        "CREATE TABLE tab (col INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT)",  # sqlite
    )
    for statement in statements:
        parsed = ddlCreateTable.parseString(statement, parseAll=True)
        assert len(parsed) == 1
        table = parsed[0]
        assert table.tableName == "tab"
        assert len(table.columns) == 1
        col = table.columns[0]
        assert not col.isConstraint
        assert col.name == "col"
        assert col.type == "integer"
        assert col.notNull
        assert col.hasAutoValue
|
|
|
|
def testParser():
    """Initialise the grammar and run the parser self-tests in a fixed order."""
    initDllParser()
    selfTests = (
        testBoolean,
        testInteger,
        testSerial,
        testFloatingPoint,
        testText,
        testBlob,
        testDate,
        testTime,
        testUnknown,
        testDateTime,
        testColumn,
        testConstraint,
        testMathExpression,
        testRational,
        testTable,
        testPrimaryKeyAutoIncrement,
    )
    for runTest in selfTests:
        runTest()
|
|
|
|
|
|
|
|
# HELPERS
|
|
def get_include_guard_name(namespace, inputfile):
    """Return an upper-case include guard name for a namespace and input file.

    The namespace and the file's base name are joined with "_"; every run of
    characters other than ASCII letters and digits is replaced by one "_".
    """
    raw_name = namespace + "_" + os.path.basename(inputfile)
    return re.sub("[^A-Za-z0-9]+", "_", raw_name).upper()
|
|
|
|
|
|
def identity_naming_func(s):
    """Naming function that returns the SQL name unchanged."""
    unchanged = s
    return unchanged
|
|
|
|
|
|
def repl_camel_case_func(m):
|
|
if m.group(1) == "_":
|
|
return m.group(2).upper()
|
|
else:
|
|
return m.group(1) + m.group(2).upper()
|
|
|
|
|
|
def class_name_naming_func(s):
    """Turn an SQL name into a class name.

    Dots become underscores first; repl_camel_case_func is then applied to
    the first character and to every character that follows whitespace, an
    underscore or a digit.
    """
    undotted = s.replace(".", "_")
    boundary = re.compile(r"(^|\s|[_0-9])(\S)")
    return boundary.sub(repl_camel_case_func, undotted)
|
|
|
|
|
|
def member_name_naming_func(s):
    """Turn an SQL name into a member name.

    Same as class_name_naming_func, except that the pattern has no
    start-of-string alternative, so the first character is left as it is.
    """
    undotted = s.replace(".", "_")
    boundary = re.compile(r"(\s|_|[0-9])(\S)")
    return boundary.sub(repl_camel_case_func, undotted)
|
|
|
|
|
|
def repl_func_for_args(m):
    """Regex replacement callback used to turn option names into variable names.

    Upper-cases group 2. Group 1 is kept in front of it unless it is a
    dash, which is dropped.
    """
    separator, letter = m.group(1), m.group(2)
    kept = "" if separator == "-" else separator
    return kept + letter.upper()
|
|
|
|
|
|
def setArgumentBool(s, bool_value):
    """Assign bool_value to the module global named after option s.

    The variable name is derived from the option by applying
    repl_func_for_args to every character that follows whitespace, a dash
    or a digit, and then lower-casing the first character
    (e.g. "-auto-id" becomes "autoId").
    """
    camel = re.sub(r"(\s|-|[0-9])(\S)", repl_func_for_args, s)
    # Lower-case only the first character (http://stackoverflow.com/a/3847369/5006740).
    if camel:
        var_name = camel[:1].lower() + camel[1:]
    else:
        var_name = ""
    globals()[var_name] = bool_value
|
|
|
|
def loadExtendedTypesFile(filename):
    """Extend the module-level ddl<Base>Types lists from a CSV file.

    Each row has the form ``<baseType>,<name>[,<name>...]``. The first
    column selects the list ``ddl<baseType>Types`` (e.g. ``Integer`` selects
    ``ddlIntegerTypes``); every further non-empty column, stripped of
    spaces and quote characters, is appended to that list.

    Rows that carry only a base type are skipped.

    Raises:
        OSError: if the file cannot be opened.
        KeyError: if no ``ddl<baseType>Types`` global exists for a row.
    """
    import csv

    with open(filename, newline='') as csvfile:
        reader = csv.DictReader(csvfile, fieldnames=["baseType"], restkey="extendedTypes", delimiter=',')
        for row in reader:
            # DictReader creates the restkey entry only when a row has more
            # columns than fieldnames, so a row with just a base type has
            # no "extendedTypes" key at all.
            extra_columns = row.get("extendedTypes") or []
            var_values = [clean_val for value in extra_columns if (clean_val := value.strip(" \"'"))]
            if var_values:
                # Tolerate stray whitespace around the base type name.
                var_name = f"ddl{row['baseType'].strip()}Types"
                globals()[var_name].extend(var_values)
|
|
|
|
def escape_if_reserved(name):
    """Return name prefixed with "!" if it is a reserved word, else unchanged.

    The comparison against the reserved words is case-insensitive.
    """
    reserved_names = ("BEGIN", "END", "GROUP", "ORDER")
    is_reserved = name.upper() in reserved_names
    return "!{}".format(name) if is_reserved else name
|
|
|
|
|
|
def beginHeader(pathToHeader, namespace, nsList):
    """Create a header file, write its preamble and return the open file.

    The preamble is the pragma, a "generated by" line built from sys.argv,
    the sqlpp11 includes and one opening "namespace X {" pair per entry of
    nsList. The file is left open for the caller; endHeader() closes it.
    The `namespace` parameter is not used.
    """
    preamble = [
        '#pragma once',
        '',
        "// generated by " + " ".join(sys.argv),
        "",
        "#include <sqlpp11/table.h>",
        "#include <sqlpp11/data_types.h>",
        "#include <sqlpp11/char_sequence.h>",
        "",
    ]
    for ns in nsList:
        preamble.append("namespace " + ns)
        preamble.append("{")

    out = open(pathToHeader, "w")
    for line in preamble:
        print(line, file=out)
    return out
|
|
|
|
|
|
def endHeader(header, nsList):
    """Write the closing namespace braces, innermost first, and close header."""
    closing_lines = ["} // namespace " + ns for ns in nsList[::-1]]
    for line in closing_lines:
        print(line, file=header)
    header.close()
|
|
|
|
|
|
def help_message():
    """Print the usage text with all optional arguments, then exit with status 0.

    Side effect: registers the data type file argument in optionalArgs so
    that it is listed together with the other options.
    """
    # The dataTypeFileArg is handled differently from the normal optionalArgs
    # and only added to the list here to make use of the formatting of the help.
    optionalArgs[dataTypeFileArg] = f"path to a csv that contains custom datatype mappings. The format is '{dataTypeFileArg}=path/to/file.csv' (See the README)."

    # Pad behind the colon so that all descriptions start in the same column.
    pad = max((len(argument) for argument in optionalArgs), default=0)
    arg_string = "".join(
        argument + ": " + " " * (pad - len(argument)) + description + "\n"
        for argument, description in optionalArgs.items()
    )

    usage = (
        "Usage:\n"
        "ddl2cpp [optional args] <path to ddl> <path to target> <namespace>\n\n"
        "OPTIONAL ARGUMENTS:\n" + arg_string + "\n"
        "<path to ddl> path to your SQL database/table definitions (SHOW CREATE TABLE SomeTable) \n"
        "<path to target> path to a generated C++ header file without extension (no *.h). \n"
        "<namespace> namespace you want. Usually a project/database name\n"
    )
    print(usage)
    sys.exit(0)
|
|
|
|
|
|
# Command line switches and their help texts.
optionalArgs = {
    # if -some-key is present, it will set variable someKey to True
    # The flag turns the warning OFF, so the help text must say "suppress".
    "-no-timestamp-warning": "suppress warning about date / time data types",  # noTimestampWarning = True
    "-auto-id": "Assume column 'id' to have an automatic value as if AUTO_INCREMENT was specified (e.g. implicit for SQLite ROWID)",  # autoId = True
    "-identity-naming": "Use table and column names from the ddl (defaults to UpperCamelCase for tables and lowerCamelCase for columns)",  # identityNaming = True
    "-split-tables": "Make a header for each table name, using target as a directory",  # splitTables = True
    "--help": "show this help",
    "--test": "run parser self-test",
}

# Defaults of the switches above; setArgumentBool() overwrites them by name.
noTimestampWarning = False
autoId = False
identityNaming = False
splitTables = False
# Takes a value ("--datatype-file=path/to/file.csv") and is therefore not
# part of optionalArgs.
dataTypeFileArg = "--datatype-file"
|
|
|
|
def createHeader():
    """Generate the sqlpp11 C++ header(s) for the DDL file named in sys.argv.

    Expected command line: [optional args] <path to ddl> <path to target>
    <namespace>. Without -split-tables one header "<target>.h" is written;
    with it, one header per table is written into the directory "<target>/".

    Exits the process with
      * ERROR_BAD_ARGS if the command line cannot be interpreted,
      * ERROR_STRANGE_PARSING if the DDL file cannot be parsed,
      * ERROR_DATA_TYPE if a column has an unsupported data type (the
        header is still written in that case).
    """
    global noTimestampWarning

    def usageError():
        # help_message() ends the process itself with status 0. Intercept
        # that exit so that misuse is reported with a non-zero status.
        try:
            help_message()
        except SystemExit:
            pass
        sys.exit(ERROR_BAD_ARGS)

    # ARGUMENT PARSING
    if len(sys.argv) < 4:
        usageError()

    firstPositional = 1
    for arg in sys.argv[1:]:
        if arg in optionalArgs:
            setArgumentBool(arg, True)
            firstPositional += 1
        elif arg.startswith(dataTypeFileArg):
            # Expected form: --datatype-file=path/to/file.csv
            _, separator, csvPath = arg.partition("=")
            if not separator or not csvPath:
                print("ERROR: " + dataTypeFileArg + " requires a value: " + dataTypeFileArg + "=path/to/file.csv")
                sys.exit(ERROR_BAD_ARGS)
            loadExtendedTypesFile(csvPath)
            firstPositional += 1

    # The three positional arguments must still be there after the options.
    if len(sys.argv) < firstPositional + 3:
        usageError()

    if identityNaming:
        toClassName = identity_naming_func
        toMemberName = identity_naming_func
    else:
        toClassName = class_name_naming_func
        toMemberName = member_name_naming_func

    pathToDdl = sys.argv[firstPositional]
    pathToHeader = sys.argv[firstPositional + 1] + ("/" if splitTables else ".h")
    namespace = sys.argv[firstPositional + 2]

    initDllParser()

    # PROCESS DDL (parsed once; the result is used for generation below)
    try:
        tableCreations = ddl.parseFile(pathToDdl)
    except pp.ParseException:
        print("ERROR: Could not parse any CREATE TABLE statement in " + pathToDdl)
        # print(pp.parseError)
        sys.exit(ERROR_STRANGE_PARSING)

    nsList = namespace.split("::")

    header = None

    def emit(line):
        # Looks up `header` at call time, so it always writes to the header
        # that is currently open (one per table with -split-tables).
        print(line, file=header)

    if not splitTables:
        header = beginHeader(pathToHeader, namespace, nsList)
    DataTypeError = False
    for create in tableCreations:
        sqlTableName = create.tableName
        if splitTables:
            header = beginHeader(pathToHeader + sqlTableName + ".h", namespace, nsList)
        tableClass = toClassName(sqlTableName)
        tableMember = toMemberName(sqlTableName)
        tableNamespace = tableClass + "_"
        tableTemplateParameters = tableClass
        emit(" namespace " + tableNamespace)
        emit(" {")
        for column in create.columns:
            if column.isConstraint:
                continue
            sqlColumnName = column.name
            columnClass = toClassName(sqlColumnName)
            tableTemplateParameters += ",\n " + tableNamespace + "::" + columnClass
            columnMember = toMemberName(sqlColumnName)
            columnType = column.type
            if columnType == "UNKNOWN":
                print(
                    "Error: datatype of %s.%s is not supported."
                    % (sqlTableName, sqlColumnName)
                )
                DataTypeError = True
            if columnType == "integer" and column.isUnsigned:
                columnType = columnType + "_unsigned"
            if columnType == "time_point" and not noTimestampWarning:
                print("Warning: date and time values are assumed to be without timezone.")
                print("Warning: If you are using types WITH timezones, your code has to deal with that.")
                print("You can disable this warning using -no-timestamp-warning")
                # Warn only once per run.
                noTimestampWarning = True
            traitslist = ["sqlpp::" + columnType]

            emit(" struct " + columnClass)
            emit(" {")
            emit(" struct _alias_t")
            emit(" {")
            emit(
                ' static constexpr const char _literal[] = "'
                + escape_if_reserved(sqlColumnName)
                + '";'
            )
            emit(" using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;")
            emit(" template<typename T>")
            emit(" struct _member_t")
            emit(" {")
            emit(" T " + columnMember + ";")
            emit(" T& operator()() { return " + columnMember + "; }")
            emit(" const T& operator()() const { return " + columnMember + "; }")
            emit(" };")
            emit(" };")

            # A value must be supplied on insert unless the column has an
            # automatic value, may be NULL, or has a default.
            requireInsert = True
            hasAutoValue = column.hasAutoValue or (autoId and sqlColumnName == "id")
            if hasAutoValue:
                traitslist.append("sqlpp::tag::must_not_insert")
                traitslist.append("sqlpp::tag::must_not_update")
                requireInsert = False
            if not column.notNull and not column.isPrimaryKey:
                traitslist.append("sqlpp::tag::can_be_null")
                requireInsert = False
            if column.hasDefaultValue:
                requireInsert = False
            if requireInsert:
                traitslist.append("sqlpp::tag::require_insert")
            emit(" using _traits = sqlpp::make_traits<" + ", ".join(traitslist) + ">;")
            emit(" };")
        emit(" } // namespace " + tableNamespace)
        emit("")

        emit(" struct " + tableClass + ": sqlpp::table_t<" + tableTemplateParameters + ">")
        emit(" {")
        emit(" struct _alias_t")
        emit(" {")
        emit(' static constexpr const char _literal[] = "' + sqlTableName + '";')
        emit(" using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;")
        emit(" template<typename T>")
        emit(" struct _member_t")
        emit(" {")
        emit(" T " + tableMember + ";")
        emit(" T& operator()() { return " + tableMember + "; }")
        emit(" const T& operator()() const { return " + tableMember + "; }")
        emit(" };")
        emit(" };")
        emit(" };")
        if splitTables:
            endHeader(header, nsList)

    if not splitTables:
        endHeader(header, nsList)
    if DataTypeError:
        print("Error: unsupported datatypes.")
        print("Possible solutions:")
        print("A) Implement this datatype (examples: sqlpp11/data_types)")
        print(f"B) Use the '{dataTypeFileArg}' command line argument to map the type to a known type (example: README)")
        print("C) Extend/upgrade ddl2cpp (edit types map)")
        print("D) Raise an issue on github")
        sys.exit(ERROR_DATA_TYPE)  # return non-zero error code, we might need it for automation
|
|
|
|
|
|
if __name__ == "__main__":
    # The two special switches take precedence over normal header generation.
    cliArgs = sys.argv
    if "--help" in cliArgs:
        help_message()
        sys.exit()
    if "--test" in cliArgs:
        testParser()
        sys.exit()
    createHeader()
|