From 5b3abca4b1785a0ee952d6f5f0ee081b9f614848 Mon Sep 17 00:00:00 2001 From: Roland Bock Date: Sun, 6 Feb 2022 17:25:24 +0100 Subject: [PATCH] Partial rewrite of ddl2cpp Initially setting out to fix #418, this change - accepts more SQL expressions - uses slightly more idiomatic pyparsing, I believe - uses black formatter - comes with some unit tests for the parser - simplifies options Tested with all SQL files in the repo. --- scripts/ddl2cpp | 970 ++++++++++++++++++++++++++++++------------------ 1 file changed, 619 insertions(+), 351 deletions(-) diff --git a/scripts/ddl2cpp b/scripts/ddl2cpp index 976b8289..db8b676b 100755 --- a/scripts/ddl2cpp +++ b/scripts/ddl2cpp @@ -1,430 +1,698 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 ## - # Copyright (c) 2013-2015, Roland Bock - # All rights reserved. - # - # Redistribution and use in source and binary forms, with or without modification, - # are permitted provided that the following conditions are met: - # - # * Redistributions of source code must retain the above copyright notice, - # this list of conditions and the following disclaimer. - # * Redistributions in binary form must reproduce the above copyright notice, - # this list of conditions and the following disclaimer in the documentation - # and/or other materials provided with the distribution. - # - # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. - # IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, - # INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE - # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED - # OF THE POSSIBILITY OF SUCH DAMAGE. - ## +# Copyright (c) 2013-2022, Roland Bock +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED +# OF THE POSSIBILITY OF SUCH DAMAGE. 
+## -from __future__ import print_function +import pyparsing as pp import sys import re import os - # error codes, we should refactor this later ERROR_BAD_ARGS = 1 ERROR_DATA_TYPE = 10 ERROR_STRANGE_PARSING = 20 +# Rather crude SQL expression parser. +# This is not geared at correctly interpreting SQL, but at identifying (and ignoring) expressions for instance in DEFAULT expressions +ddlLeft, ddlRight = map(pp.Suppress, "()") +ddlNumber = pp.Word(pp.nums + "+-.", pp.nums + "+-.Ee") +ddlString = ( + pp.QuotedString("'") | pp.QuotedString('"', escQuote='""') | pp.QuotedString("`") +) +ddlTerm = pp.Word(pp.alphas + "_", pp.alphanums + "_.$") +ddlName = pp.Or([ddlTerm, ddlString, pp.Combine(ddlString + "." + ddlString)]) +ddlOperator = pp.Or( + map(pp.CaselessLiteral, ["+", "-", "*", "/", "<", "<=", ">", ">=", "=", "%"]) +) -from pyparsing import CaselessLiteral, Literal, SkipTo, restOfLine, oneOf, ZeroOrMore, Optional, Combine, Suppress, \ - WordStart, WordEnd, Word, alphas, alphanums, nums, QuotedString, nestedExpr, MatchFirst, OneOrMore, delimitedList, \ - Or, Group, ParseException +ddlBracedExpression = pp.Forward() +ddlFunctionCall = pp.Forward() +ddlCast = ddlString + "::" + ddlTerm +ddlExpression = pp.OneOrMore( + ddlBracedExpression + | ddlFunctionCall + | ddlCast + | ddlOperator + | ddlString + | ddlTerm + | ddlNumber +) + +ddlBracedExpression << ddlLeft + ddlExpression + ddlRight + +ddlArguments = pp.Suppress(pp.Group(pp.delimitedList(ddlExpression))) +ddlFunctionCall << ddlName + ddlLeft + pp.Optional(ddlArguments) + ddlRight + +# Column and constraint parsers +ddlBooleanTypes = [ + "bool", + "boolean", +] + +ddlBoolean = pp.Or( + map(pp.CaselessLiteral, sorted(ddlBooleanTypes, reverse=True)) +).setParseAction(pp.replaceWith("boolean")) + +ddlIntegerTypes = [ + "bigint", + "int", + "int2", # PostgreSQL + "int4", # PostgreSQL + "int8", # PostgreSQL + "integer", + "mediumint", + "smallint", + "tinyint", +] +ddlInteger = pp.Or( + map(pp.CaselessLiteral, sorted(ddlIntegerTypes, reverse=True)) +).setParseAction(pp.replaceWith("integer")) + +ddlSerialTypes = [ + "bigserial", # PostgreSQL + "serial", # PostgreSQL + "smallserial", # PostgreSQL +] +ddlSerial = ( + pp.Or(map(pp.CaselessLiteral, sorted(ddlSerialTypes, reverse=True))) + .setParseAction(pp.replaceWith("integer")) + .setResultsName("hasAutoValue") +) + +ddlFloatingPointTypes = [ + "decimal", # MYSQL + "double", + "float8", # PostgreSQL + "float", + "float4", # PostgreSQL + "numeric", # PostgreSQL + "real", +] +ddlFloatingPoint = pp.Or( + map(pp.CaselessLiteral, sorted(ddlFloatingPointTypes, reverse=True)) +).setParseAction(pp.replaceWith("floating_point")) + +ddlTextTypes = [ + "char", + "varchar", + "character varying", # PostgreSQL + "text", + "clob", + "enum", # MYSQL + "set", + "longtext", # MYSQL + "jsonb", # PostgreSQL + "json", # PostgreSQL + "tinytext", # MYSQL +] + +ddlText = pp.Or( + map(pp.CaselessLiteral, sorted(ddlTextTypes, reverse=True)) +).setParseAction(pp.replaceWith("text")) + +ddlBlobTypes = [ + "bytea", + "tinyblob", + "blob", + "mediumblob", + "longblob", + "binary", # MYSQL + "varbinary", # MYSQL +] + +ddlBlob = pp.Or( + map(pp.CaselessLiteral, sorted(ddlBlobTypes, reverse=True)) +).setParseAction(pp.replaceWith("blob")) +ddlDateTypes = [ + "date", +] + +ddlDate = ( + pp.Or(map(pp.CaselessLiteral, sorted(ddlDateTypes, reverse=True))) + .setParseAction(pp.replaceWith("day_point")) + .setResultsName("warnTimezone") +) + +ddlDateTimeTypes = [ + "datetime", + "timestamp", + "timestamp without time zone", # 
PostgreSQL + "timestamp with time zone", # PostgreSQL + "timestamptz", # PostgreSQL +] + +ddlDateTime = pp.Or( + map(pp.CaselessLiteral, sorted(ddlDateTimeTypes, reverse=True)) +).setParseAction(pp.replaceWith("time_point")) + +ddlTimeTypes = [ + "time", + "time without time zone", # PostgreSQL + "time with time zone", # PostgreSQL +] + +ddlTime = pp.Or( + map(pp.CaselessLiteral, sorted(ddlTimeTypes, reverse=True)) +).setParseAction(pp.replaceWith("time_of_day")) + + +ddlUnknown = pp.Word(pp.alphanums).setParseAction(pp.replaceWith("UNKNOWN")) + +ddlType = ( + ddlBoolean + | ddlInteger + | ddlSerial + | ddlFloatingPoint + | ddlText + | ddlBlob + | ddlDateTime + | ddlDate + | ddlTime + | ddlUnknown +) + +ddlUnsigned = pp.CaselessLiteral("UNSIGNED").setResultsName("isUnsigned") +ddlWidth = ddlLeft + pp.Word(pp.nums) + ddlRight +ddlTimezone = ( + (pp.CaselessLiteral("with") | pp.CaselessLiteral("without")) + + pp.CaselessLiteral("time") + + pp.CaselessLiteral("zone") +) + +ddlNotNull = pp.Group( + pp.CaselessLiteral("NOT") + pp.CaselessLiteral("NULL") +).setResultsName("notNull") +ddlDefaultValue = pp.CaselessLiteral("DEFAULT").setResultsName("hasDefaultValue") + +ddlAutoKeywords = [ + "AUTO_INCREMENT", + "AUTOINCREMENT", + "SMALLSERIAL", + "SERIAL", + "BIGSERIAL", +] +ddlAutoValue = pp.Or(map(pp.CaselessLiteral, sorted(ddlAutoKeywords, reverse=True))) + +ddlColumn = pp.Group( + ddlName("name") + + ddlType("type") + + pp.Suppress(pp.Optional(ddlWidth)) + + pp.Suppress(pp.Optional(ddlTimezone)) + + pp.ZeroOrMore( + ddlUnsigned("isUnsigned") + | ddlNotNull("notNull") + | pp.CaselessLiteral("null") + | ddlAutoValue("hasAutoValue") + | ddlDefaultValue("hasDefaultValue") + | pp.Suppress(ddlExpression) + ) +) + +ddlConstraintKeywords = [ + "CONSTRAINT", + "PRIMARY", + "FOREIGN", + "KEY", + "FULLTEXT", + "INDEX", + "UNIQUE", + "CHECK", +] +ddlConstraint = pp.Group( + pp.Or(map(pp.CaselessLiteral, sorted(ddlConstraintKeywords, reverse=True))) + + ddlExpression +).setResultsName("isConstraint") + +# CREATE TABLE parser +ddlIfNotExists = pp.Group( + pp.CaselessLiteral("IF") + pp.CaselessLiteral("NOT") + pp.CaselessLiteral("EXISTS") +).setResultsName("ifNotExists") +ddlCreateTable = pp.Group( + pp.CaselessLiteral("CREATE") + + pp.CaselessLiteral("TABLE") + + pp.Suppress(pp.Optional(ddlIfNotExists)) + + ddlName.setResultsName("tableName") + + ddlLeft + + pp.Group(pp.delimitedList(ddlColumn | pp.Suppress(ddlConstraint))).setResultsName( + "columns" + ) + + ddlRight +).setResultsName("create") +# ddlString.setDebug(True) #uncomment to debug pyparsing + +ddl = pp.OneOrMore(pp.Suppress(pp.SkipTo(ddlCreateTable, False)) + ddlCreateTable) + +ddlComment = pp.oneOf(["--", "#"]) + pp.restOfLine +ddl.ignore(ddlComment) + + +def testBoolean(): + for t in ddlBooleanTypes: + result = ddlType.parseString(t, parseAll=True) + assert result[0] == "boolean" + + +def testInteger(): + for t in ddlIntegerTypes: + result = ddlType.parseString(t, parseAll=True) + assert result[0] == "integer" + + +def testSerial(): + for t in ddlSerialTypes: + result = ddlType.parseString(t, parseAll=True) + assert result[0] == "integer" + assert result.hasAutoValue + + +def testFloatingPoint(): + for t in ddlFloatingPointTypes: + result = ddlType.parseString(t, parseAll=True) + assert result[0] == "floating_point" + + +def testText(): + for t in ddlTextTypes: + result = ddlType.parseString(t, parseAll=True) + assert result[0] == "text" + + +def testBlob(): + for t in ddlBlobTypes: + result = ddlType.parseString(t, parseAll=True) + assert 
result[0] == "blob" + + +def testDate(): + for t in ddlDateTypes: + result = ddlType.parseString(t, parseAll=True) + assert result[0] == "day_point" + + +def testDateTime(): + for t in ddlDateTimeTypes: + result = ddlType.parseString(t, parseAll=True) + assert result[0] == "time_point" + + +def testTime(): + for t in ddlTimeTypes: + result = ddlType.parseString(t, parseAll=True) + assert result[0] == "time_of_day" + + +def testUnknown(): + for t in ["cheesecake", "blueberry"]: + result = ddlType.parseString(t, parseAll=True) + assert result[0] == "UNKNOWN" + + +def testAutoValue(): + def test(s, expected): + results = ddlAutoValue.parseString(s, parseAll=True) + print(results) + + +def testColumn(): + text = "\"id\" int(8) unsigned NOT NULL DEFAULT nextval('dk_id_seq'::regclass)" + result = ddlColumn.parseString(text, parseAll=True) + column = result[0] + assert column.name == "id" + assert column.type == "integer" + assert column.isUnsigned + assert column.notNull + assert not column.hasAutoValue + + +def testConstraint(): + for text in [ + "CONSTRAINT unique_person UNIQUE (first_name, last_name)", + "UNIQUE (id)", + ]: + result = ddlConstraint.parseString(text, parseAll=True) + assert result.isConstraint + + +def testTable(): + text = """ + CREATE TABLE "public"."dk" ( + "id" int8 NOT NULL DEFAULT nextval('dk_id_seq'::regclass), + "last_update" timestamp(6) DEFAULT now(), + PRIMARY KEY (id) +) +""" + result = ddlCreateTable.parseString(text, parseAll=True) + + +def testParser(): + testBoolean() + testInteger() + testSerial() + testFloatingPoint() + testText() + testBlob() + testDate() + testTime() + testUnknown() + testDateTime() + testColumn() + testConstraint() + testTable() + + +# CODE GENERATOR # HELPERS + def get_include_guard_name(namespace, inputfile): - val = re.sub("[^A-Za-z0-9]+", "_", namespace + '_' + os.path.basename(inputfile)) - return val.upper() + val = re.sub("[^A-Za-z0-9]+", "_", namespace + "_" + os.path.basename(inputfile)) + return val.upper() + def identity_naming_func(s): - return s + return s + def repl_camel_case_func(m): - if m.group(1) == '_': - return m.group(2).upper() - else: - return m.group(1) + m.group(2).upper() + if m.group(1) == "_": + return m.group(2).upper() + else: + return m.group(1) + m.group(2).upper() def class_name_naming_func(s): - s = s.replace('.', '_') - return re.sub("(^|\s|[_0-9])(\S)", repl_camel_case_func, s) + s = s.replace(".", "_") + return re.sub("(^|\s|[_0-9])(\S)", repl_camel_case_func, s) def member_name_naming_func(s): - s = s.replace('.', '_') - return re.sub("(\s|_|[0-9])(\S)", repl_camel_case_func, s) + s = s.replace(".", "_") + return re.sub("(\s|_|[0-9])(\S)", repl_camel_case_func, s) -toClassName = class_name_naming_func -toMemberName = member_name_naming_func def repl_func_for_args(m): - if m.group(1) == '-': - return m.group(2).upper() - else: - return m.group(1) + m.group(2).upper() + if m.group(1) == "-": + return m.group(2).upper() + else: + return m.group(1) + m.group(2).upper() + def setArgumentBool(s, bool_value): - first_lower = lambda s: s[:1].lower() + s[1:] if s else '' # http://stackoverflow.com/a/3847369/5006740 + first_lower = ( + lambda s: s[:1].lower() + s[1:] if s else "" + ) # http://stackoverflow.com/a/3847369/5006740 var_name = first_lower(re.sub("(\s|-|[0-9])(\S)", repl_func_for_args, s)) globals()[var_name] = bool_value -def usage(optionalArgs = {}): - print('\ - Usage: ddl2cpp \n\ - ddl2cpp -help') +def escape_if_reserved(name): + reserved_names = [ + "BEGIN", + "END", + "GROUP", + "ORDER", + ] + 
if name.upper() in reserved_names: + return "!{}".format(name) + return name -def beginHeader(pathToHeader, nsList): - header = open(pathToHeader, 'w') - print('// generated by ' + ' '.join(sys.argv), file=header) - print('#ifndef '+get_include_guard_name(namespace, pathToHeader), file=header) - print('#define '+get_include_guard_name(namespace, pathToHeader), file=header) - print('', file=header) - print('#include <' + INCLUDE + '/table.h>', file=header) - print('#include <' + INCLUDE + '/data_types.h>', file=header) - print('#include <' + INCLUDE + '/char_sequence.h>', file=header) - print('', file=header) + +def beginHeader(pathToHeader, namespace, nsList): + header = open(pathToHeader, "w") + print("// generated by " + " ".join(sys.argv), file=header) + print("#ifndef " + get_include_guard_name(namespace, pathToHeader), file=header) + print("#define " + get_include_guard_name(namespace, pathToHeader), file=header) + print("", file=header) + print("#include <sqlpp11/table.h>", file=header) + print("#include <sqlpp11/data_types.h>", file=header) + print("#include <sqlpp11/char_sequence.h>", file=header) + print("", file=header) for ns in nsList: - print('namespace ' + ns, file=header) - print('{', file=header) + print("namespace " + ns, file=header) + print("{", file=header) return header + def endHeader(header, nsList): for ns in nsList: - print('} // namespace ' + ns, file=header) - print('#endif', file=header) + print("} // namespace " + ns, file=header) + print("#endif", file=header) header.close() + def help_message(): - arg_string = '\n' + arg_string = "" pad = 0 - padding = 0 for argument in list(optionalArgs.keys()): if len(argument) > pad: pad = len(argument) for argument in list(optionalArgs.keys()): - if argument == '-help': - continue if len(argument) < pad: padding = " " * (pad - len(argument)) else: - padding = '' - arg_string = arg_string + ' [-[no]'+argument+']: ' + padding + optionalArgs[argument] + '\n' - print('Usage: ddl2cpp [-help]\n\n OPTIONAL ARGUMENTS:\n' + arg_string +' \n \ - \n\ -\n\ - path to your SQL database/table definitions (SHOW CREATE TABLE SomeTable) \n\ - path to a generated C++ header file. Without extension (no *.h). \n\ - namespace you want. Usually a project/database name\n') + padding = "" + arg_string = ( + arg_string + argument + ": " + padding + optionalArgs[argument] + "\n" + ) + print( + "Usage:\n" + "ddl2cpp [optional args] <path to ddl> <path to target> <namespace>\n\n" + "OPTIONAL ARGUMENTS:\n" + arg_string + "\n" + "<path to ddl>: path to your SQL database/table definitions (SHOW CREATE TABLE SomeTable) \n" + "<path to target>: path to a generated C++ header file without extension (no *.h). \n" + "<namespace>: namespace you want. Usually a project/database name\n" + ) sys.exit(0) + optionalArgs = { # if -some-key is present, it will set variable someKey to True - # if -no-some-key is present, it will set variable someKey to False - '-timestamp-warning': "show warning about mysql timestamp data type", # timeStampWarning = True - # '-no-time-stamp-warning' # timeStampWarning = False - '-fail-on-parse': "abort instead of silent genereation of unusable headers", # failOnParse = True - '-warn-on-parse': "warn about unusable headers, but continue", # warnOnParse = True - '-auto-id': "Assume column 'id' to have an automatic value as if AUTO_INCREMENT was specified (e.g. 
implicit for SQLite ROWID)", # autoId = True - '-identity-naming': "Use table and column names from the ddl (defaults to UpperCamelCase for tables and lowerCamelCase for columns)", # identityNaming = True - '-split-tables': "Make a header for each table name, using target as a directory", # splitTables = True - '-help': "show this help" + "-no-timestamp-warning": "show warning about date / time data types", # noTimeStampWarning = True + "-auto-id": "Assume column 'id' to have an automatic value as if AUTO_INCREMENT was specified (e.g. implicit for SQLite ROWID)", # autoId = True + "-identity-naming": "Use table and column names from the ddl (defaults to UpperCamelCase for tables and lowerCamelCase for columns)", # identityNaming = True + "-split-tables": "Make a header for each table name, using target as a directory", # splitTables = True + "--help": "show this help", + "--test": "run parser self-test", } -if '-help' in sys.argv: - help_message() -# ARGUMENT PARSING -if len(sys.argv) < (4): - usage(optionalArgs) - sys.exit(ERROR_BAD_ARGS) - -firstPositional = 1 -timestampWarning = True -failOnParse = False -warnOnParse = False -parseError = "Parsing error, possible reason: can't parse default value for a field" +noTimestampWarning = False autoId = False identityNaming = False splitTables = False -if len(sys.argv) >= 4: - for arg in sys.argv: - noArg = arg.replace('-no-', '-') - if arg in list(optionalArgs.keys()): - setArgumentBool(arg, True) - firstPositional += 1 - elif noArg in optionalArgs: - setArgumentBool(noArg, False) - firstPositional += 1 - else: - pass +def createHeader(): + global noTimestampWarning + # ARGUMENT PARSING + if len(sys.argv) < (4): + help_message() + sys.exit(ERROR_BAD_ARGS) -if identityNaming: - toClassName = identity_naming_func - toMemberName = identity_naming_func + firstPositional = 1 + if len(sys.argv) >= 4: + for arg in sys.argv: + if arg in list(optionalArgs.keys()): + setArgumentBool(arg, True) + firstPositional += 1 + else: + pass -pathToDdl = sys.argv[firstPositional] + if identityNaming: + toClassName = identity_naming_func + toMemberName = identity_naming_func + else: + toClassName = class_name_naming_func + toMemberName = member_name_naming_func -pathToHeader = sys.argv[firstPositional + 1] + ('/' if splitTables else '.h') -namespace = sys.argv[firstPositional + 2] + pathToDdl = sys.argv[firstPositional] + pathToHeader = sys.argv[firstPositional + 1] + ("/" if splitTables else ".h") + namespace = sys.argv[firstPositional + 2] -INCLUDE = 'sqlpp11' -NAMESPACE = 'sqlpp' - - - -# PARSER -def ddlWord(string): - return WordStart(alphanums + "_") + CaselessLiteral(string) + WordEnd(alphanums + "_") - -# This function should be refactored if we find some database function which needs parameters -# Right now it works only for something like NOW() in MySQL default field value -def ddlFunctionWord(string): - return CaselessLiteral(string) + OneOrMore("(") + ZeroOrMore(" ") + OneOrMore(")") - -ddlString = Or([QuotedString("'"), QuotedString("\"", escQuote='""'), QuotedString("`")]) -negativeSign = Literal('-') -ddlNum = Combine(Optional(negativeSign) + Word(nums + ".")) -ddlTerm = Word(alphanums + "_$") -ddlName = Or([ddlTerm, ddlString, Combine(ddlString + "." 
+ ddlString)]) -ddlMathOp = Word("+><=-") -ddlBoolean = Or([ddlWord("AND"), ddlWord("OR"), ddlWord("NOT")]) -ddlArguments = "(" + delimitedList(Or([ddlString, ddlTerm, ddlNum])) + ")" -ddlMathCond = "(" + delimitedList( - Or([ - Group(ddlName + ddlMathOp + ddlName), - Group(ddlName + ddlWord("NOT") + ddlWord("NULL")), - ]), - delim=ddlBoolean) + ")" - -ddlUnsigned = ddlWord("unsigned").setResultsName("isUnsigned") -ddlNotNull = Group(ddlWord("NOT") + ddlWord("NULL")).setResultsName("notNull") -ddlDefaultValue = ddlWord("DEFAULT").setResultsName("hasDefaultValue") -ddlAutoValue = Or([ - ddlWord("AUTO_INCREMENT"), - ddlWord("AUTOINCREMENT"), - ddlWord("SMALLSERIAL"), - ddlWord("SERIAL"), - ddlWord("BIGSERIAL"), - ]).setResultsName("hasAutoValue") -ddlColumnComment = Group(ddlWord("COMMENT") + ddlString).setResultsName("comment") -ddlConstraint = Or([ - ddlWord("CONSTRAINT"), - ddlWord("PRIMARY"), - ddlWord("FOREIGN"), - ddlWord("KEY"), - ddlWord("FULLTEXT"), - ddlWord("INDEX"), - ddlWord("UNIQUE"), - ddlWord("CHECK") - ]) -ddlColumn = Group(Optional(ddlConstraint).setResultsName("isConstraint") + OneOrMore(MatchFirst([ddlUnsigned, ddlNotNull, ddlAutoValue, ddlDefaultValue, ddlFunctionWord("NOW"), ddlFunctionWord("current_timestamp"), ddlTerm, ddlNum, ddlColumnComment, ddlString, ddlArguments, ddlMathCond]))) -ddlIfNotExists = Optional(Group(ddlWord("IF") + ddlWord("NOT") + ddlWord("EXISTS")).setResultsName("ifNotExists")) -createTable = Group(ddlWord("CREATE") + ddlWord("TABLE") + ddlIfNotExists + ddlName.setResultsName("tableName") + "(" + Group(delimitedList(ddlColumn)).setResultsName("columns") + ")").setResultsName("create") -#ddlString.setDebug(True) #uncomment to debug pyparsing - -ddl = ZeroOrMore(Suppress(SkipTo(createTable, False)) + createTable) - -ddlComment = oneOf(["--", "#"]) + restOfLine -ddl.ignore(ddlComment) - -# MAP SQL TYPES -types = { - 'tinyint': 'tinyint', - 'smallint': 'smallint', - 'smallserial': 'smallint', # PostgreSQL - 'int2': 'smallint', #PostgreSQL - 'integer': 'integer', - 'int': 'integer', - 'serial': 'integer', # PostgreSQL - 'int4': 'integer', #PostgreSQL - 'mediumint' : 'integer', - 'bigint': 'bigint', - 'bigserial': 'bigint', # PostgreSQL - 'int8': 'bigint', #PostgreSQL - 'char': 'char_', - 'varchar': 'varchar', - 'character varying': 'varchar', #PostgreSQL - 'text': 'text', - 'clob': 'text', - 'bytea': 'blob', - 'tinyblob': 'blob', - 'blob': 'blob', - 'mediumblob': 'blob', - 'longblob': 'blob', - 'bool': 'boolean', - 'boolean': 'boolean', - 'double': 'floating_point', - 'float8': 'floating_point', # PostgreSQL - 'float': 'floating_point', - 'float4': 'floating_point', # PostgreSQL - 'real': 'floating_point', - 'numeric': 'floating_point', # PostgreSQL - 'decimal' : 'floating_point', # MYSQL - 'date': 'day_point', - 'datetime': 'time_point', - 'time': 'time_of_day', - 'time without time zone': 'time_point', # PostgreSQL - 'time with time zone': 'time_point', # PostgreSQL - 'timestamp': 'time_point', - 'timestamp without time zone': 'time_point', # PostgreSQL - 'timestamp with time zone': 'time_point', # PostgreSQL - 'timestamptz': 'time_point', # PostgreSQL - 'enum': 'text', # MYSQL - 'set': 'text', # MYSQL, - 'longtext' : 'text', #MYSQL - 'json' : 'text', # PostgreSQL - 'jsonb' : 'text', # PostgreSQL - 'tinyint unsigned': 'tinyint_unsigned', #MYSQL - 'smallint unsigned': 'smallint_unsigned', #MYSQL - 'integer unsigned': 'integer_unsigned', #MYSQL - 'int unsigned': 'integer_unsigned', #MYSQL - 'bigint unsigned': 'bigint_unsigned', #MYSQL - 'mediumint 
unsigned' : 'integer', #MYSQL - 'tinytext' : 'text', #MYSQL - 'binary' : 'blob', #MYSQL - 'varbinary' : 'blob', #MYSQL - 'float unsigned' : 'floating_point', #MYSQL - } - -if failOnParse: - ddl = OneOrMore(Suppress(SkipTo(createTable, False)) + createTable) - ddl.ignore(ddlComment) try: tableCreations = ddl.parseFile(pathToDdl) - except ParseException as e: - print(parseError + '. Exiting [-no-fail-on-parse]') + except pp.ParseException as e: + print("ERROR: Could not parse any CREATE TABLE statement in " + pathToDdl) + # print(pp.parseError) sys.exit(ERROR_STRANGE_PARSING) -else: - ddl = ZeroOrMore(Suppress(SkipTo(createTable, False)) + createTable) - ddl.ignore(ddlComment) + + nsList = namespace.split("::") + + # PROCESS DDL tableCreations = ddl.parseFile(pathToDdl) -if warnOnParse: - print(parseError + '. Continuing [-no-warn-on-parse]') + header = 0 + if not splitTables: + header = beginHeader(pathToHeader, namespace, nsList) + DataTypeError = False + for create in tableCreations: + sqlTableName = create.tableName + if splitTables: + header = beginHeader(pathToHeader + sqlTableName + ".h", namespace, nsList) + tableClass = toClassName(sqlTableName) + tableMember = toMemberName(sqlTableName) + tableNamespace = tableClass + "_" + tableTemplateParameters = tableClass + print(" namespace " + tableNamespace, file=header) + print(" {", file=header) + for column in create.columns: + if column.isConstraint: + continue + sqlColumnName = column.name + columnClass = toClassName(sqlColumnName) + tableTemplateParameters += ( + ",\n " + tableNamespace + "::" + columnClass + ) + columnMember = toMemberName(sqlColumnName) + columnType = column.type + if columnType == "UNKNOWN": + print( + "Error: datatype of %s.%s is not supported." + % (sqlTableName, sqlColumnName) + ) + DataTypeError = True + if columnType == "integer" and column.isUnsigned: + columnType = columnType + "_unsigned" + if columnType == "time_point" and not noTimestampWarning: + print( + "Warning: date and time values are assumed to be without timezone." + ) + print( + "Warning: If you are using types WITH timezones, your code has to deal with that." 
+ ) + print("You can disable this warning using -no-timestamp-warning") + noTimestampWarning = True + traitslist = ["sqlpp::" + columnType] + columnCanBeNull = not column.notNull + print(" struct " + columnClass, file=header) + print(" {", file=header) + print(" struct _alias_t", file=header) + print(" {", file=header) + print( + ' static constexpr const char _literal[] = "' + + escape_if_reserved(sqlColumnName) + + '";', + file=header, + ) + print( + " using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;", + file=header, + ) + print(" template <typename T>", file=header) + print(" struct _member_t", file=header) + print(" {", file=header) + print(" T " + columnMember + ";", file=header) + print( + " T& operator()() { return " + columnMember + "; }", + file=header, + ) + print( + " const T& operator()() const { return " + + columnMember + + "; }", + file=header, + ) + print(" };", file=header) + print(" };", file=header) + requireInsert = True + hasAutoValue = column.hasAutoValue or (autoId and sqlColumnName == "id") + if hasAutoValue: + traitslist.append("sqlpp::tag::must_not_insert") + traitslist.append("sqlpp::tag::must_not_update") + requireInsert = False + if not column.notNull: + traitslist.append("sqlpp::tag::can_be_null") + requireInsert = False + if column.hasDefaultValue: + requireInsert = False + if requireInsert: + traitslist.append("sqlpp::tag::require_insert") + print( + " using _traits = sqlpp::make_traits<" + + ", ".join(traitslist) + + ">;", + file=header, + ) + print(" };", file=header) + print(" } // namespace " + tableNamespace, file=header) + print("", file=header) -nsList = namespace.split('::') + print( + " struct " + + tableClass + + ": sqlpp::table_t<" + + tableTemplateParameters + + ">", + file=header, + ) + print(" {", file=header) + print(" struct _alias_t", file=header) + print(" {", file=header) + print( + ' static constexpr const char _literal[] = "' + sqlTableName + '";', + file=header, + ) + print( + " using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;", + file=header, + ) + print(" template <typename T>", file=header) + print(" struct _member_t", file=header) + print(" {", file=header) + print(" T " + tableMember + ";", file=header) + print(" T& operator()() { return " + tableMember + "; }", file=header) + print( + " const T& operator()() const { return " + tableMember + "; }", + file=header, + ) + print(" };", file=header) + print(" };", file=header) + print(" };", file=header) + if splitTables: + endHeader(header, nsList) -def escape_if_reserved(name): - reserved_names = [ - 'BEGIN', - 'END', - 'GROUP', - 'ORDER', - ] - if name.upper() in reserved_names: - return '!{}'.format(name) - return name - - -# PROCESS DDL -tableCreations = ddl.parseFile(pathToDdl) - -header = 0 -if not splitTables: - header = beginHeader(pathToHeader, nsList) -DataTypeError = False -for create in tableCreations: - sqlTableName = create.tableName - if splitTables: - header = beginHeader(pathToHeader + sqlTableName + '.h', nsList) - tableClass = toClassName(sqlTableName) - tableMember = toMemberName(sqlTableName) - tableNamespace = tableClass + '_' - tableTemplateParameters = tableClass - print(' namespace ' + tableNamespace, file=header) - print(' {', file=header) - for column in create.columns: - if column.isConstraint: - continue - sqlColumnName = column[0] - columnClass = toClassName(sqlColumnName) - tableTemplateParameters += ',\n ' + tableNamespace + '::' + columnClass - columnMember = toMemberName(sqlColumnName) - sqlColumnType = column[1].lower() - if column.isUnsigned: - sqlColumnType = sqlColumnType + ' unsigned'; 
- if sqlColumnType == 'timestamp' and timestampWarning: - print("Warning: timestamp is mapped to sqlpp::time_point like datetime") - print("Warning: You have to take care of timezones yourself") - print("You can disable this warning using -no-timestamp-warning") - columnCanBeNull = not column.notNull - print(' struct ' + columnClass, file=header) - print(' {', file=header) - print(' struct _alias_t', file=header) - print(' {', file=header) - print(' static constexpr const char _literal[] = "' + escape_if_reserved(sqlColumnName) + '";', file=header) - print(' using _name_t = sqlpp::make_char_sequence;', file=header) - print(' template', file=header) - print(' struct _member_t', file=header) - print(' {', file=header) - print(' T ' + columnMember + ';', file=header) - print(' T& operator()() { return ' + columnMember + '; }', file=header) - print(' const T& operator()() const { return ' + columnMember + '; }', file=header) - print(' };', file=header) - print(' };', file=header) - try: - traitslist = [NAMESPACE + '::' + types[sqlColumnType]] - except KeyError as e: - print ('Error: datatype "' + sqlColumnType + '"" is not supported.') - DataTypeError = True - requireInsert = True - hasAutoValue = column.hasAutoValue or (autoId and sqlColumnName == 'id') - if hasAutoValue: - traitslist.append(NAMESPACE + '::tag::must_not_insert') - traitslist.append(NAMESPACE + '::tag::must_not_update') - requireInsert = False - if not column.notNull: - traitslist.append(NAMESPACE + '::tag::can_be_null') - requireInsert = False - if column.hasDefaultValue: - requireInsert = False - if requireInsert: - traitslist.append(NAMESPACE + '::tag::require_insert') - print(' using _traits = ' + NAMESPACE + '::make_traits<' + ', '.join(traitslist) + '>;', file=header) - print(' };', file=header) - print(' } // namespace ' + tableNamespace, file=header) - print('', file=header) - - print(' struct ' + tableClass + ': ' + NAMESPACE + '::table_t<' + tableTemplateParameters + '>', file=header) - print(' {', file=header) - print(' struct _alias_t', file=header) - print(' {', file=header) - print(' static constexpr const char _literal[] = "' + sqlTableName + '";', file=header) - print(' using _name_t = sqlpp::make_char_sequence;', file=header) - print(' template', file=header) - print(' struct _member_t', file=header) - print(' {', file=header) - print(' T ' + tableMember + ';', file=header) - print(' T& operator()() { return ' + tableMember + '; }', file=header) - print(' const T& operator()() const { return ' + tableMember + '; }', file=header) - print(' };', file=header) - print(' };', file=header) - print(' };', file=header) - if splitTables: + if not splitTables: endHeader(header, nsList) + if DataTypeError: + print("Error: unsupported datatypes.") + print("Possible solutions:") + print("A) Implement this datatype (examples: sqlpp11/data_types)") + print("B) Extend/upgrade ddl2cpp (edit types map)") + print("C) Raise an issue on github") + sys.exit(10) # return non-zero error code, we might need it for automation -if not splitTables: - endHeader(header, nsList) -if (DataTypeError): - print("Error: unsupported datatypes." 
) - print("Possible solutions:") - print("A) Implement this datatype (examples: sqlpp11/data_types)" ) - print("B) Extend/upgrade ddl2cpp (edit types map)" ) - print("C) Raise an issue on github" ) - sys.exit(10) #return non-zero error code, we might need it for automation + +if __name__ == "__main__": + if "--help" in sys.argv: + help_message() + sys.exit() + elif "--test" in sys.argv: + testParser() + sys.exit() + else: + createHeader()
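For reference, a minimal sketch of exercising the new column grammar on a computed DEFAULT expression, the kind of input the rewrite is meant to accept and ignore. It assumes the script above has been copied to ddl2cpp.py so it can be imported as a module; the column definition itself is a made-up example.

# Sketch: parse a single column definition with the new pyparsing grammar.
# Assumes scripts/ddl2cpp has been copied/renamed to ddl2cpp.py (hypothetical setup).
from ddl2cpp import ddlColumn

result = ddlColumn.parseString(
    "created timestamp NOT NULL DEFAULT timezone('UTC'::text, now())",
    parseAll=True,
)
column = result[0]
assert column.name == "created"
assert column.type == "time_point"   # "timestamp" is mapped to sqlpp::time_point
assert column.notNull                # NOT NULL was recognised
assert column.hasDefaultValue        # the DEFAULT expression is accepted and ignored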