#!/usr/bin/env python3

##
# Copyright (c) 2013-2022, Roland Bock
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
# OF THE POSSIBILITY OF SUCH DAMAGE.
##

"""Generate C++ table-definition headers from SQL CREATE TABLE statements.

The script parses a DDL file with pyparsing, then writes one header (or one
header per table with ``-split-tables``) containing ``sqlpp::table_t``
definitions. Run with ``--help`` for usage or ``--test`` for the parser
self-test.
"""

import os
import re
import sys

import pyparsing as pp

# error codes, we should refactor this later
ERROR_BAD_ARGS = 1
ERROR_DATA_TYPE = 10
ERROR_STRANGE_PARSING = 20

# Rather crude SQL expression parser.
# This is not geared at correctly interpreting SQL, but at identifying (and
# ignoring) expressions, for instance in DEFAULT clauses.
ddlLeft, ddlRight = map(pp.Suppress, "()")
ddlNumber = pp.Word(pp.nums + "+-.", pp.nums + "+-.Ee")
ddlString = (
    pp.QuotedString("'") | pp.QuotedString('"', escQuote='""') | pp.QuotedString("`")
)
ddlTerm = pp.Word(pp.alphas + "_", pp.alphanums + "_$")
# A name is either a bare identifier or a quoted string (quotes are stripped).
ddlName = ddlTerm | ddlString
ddlOperator = pp.Or(
    map(pp.CaselessLiteral, ["+", "-", "*", "/", "<", "<=", ">", ">=", "=", "%"])
)

# Forward declarations: expressions, braced expressions, function calls and
# braced argument lists refer to each other recursively.
ddlBracedExpression = pp.Forward()
ddlFunctionCall = pp.Forward()
ddlCastEnd = "::" + ddlTerm
ddlCast = ddlString + ddlCastEnd
ddlBracedArguments = pp.Forward()
ddlExpression = pp.OneOrMore(
    ddlBracedExpression
    | ddlFunctionCall
    | ddlCastEnd
    | ddlCast
    | ddlOperator
    | ddlString
    | ddlTerm
    | ddlNumber
    | ddlBracedArguments
)

ddlBracedArguments << ddlLeft + pp.delimitedList(ddlExpression) + ddlRight
ddlBracedExpression << ddlLeft + ddlExpression + ddlRight

ddlArguments = pp.Suppress(pp.Group(pp.delimitedList(ddlExpression)))
ddlFunctionCall << pp.Optional(ddlName + ".") + ddlName + ddlLeft + pp.Optional(ddlArguments) + ddlRight

# Column and constraint parsers
#
# Each group of SQL type names below is mapped by a parse action onto one
# generic type name, which createHeader() later prefixes with "sqlpp::".

ddlBooleanTypes = [
    "bool",
    "boolean",
]

ddlBoolean = pp.Or(
    map(pp.CaselessLiteral, sorted(ddlBooleanTypes, reverse=True))
).setParseAction(pp.replaceWith("boolean"))

ddlIntegerTypes = [
    "bigint",
    "int",
    "int2",  # PostgreSQL
    "int4",  # PostgreSQL
    "int8",  # PostgreSQL
    "integer",
    "mediumint",
    "smallint",
    "tinyint",
]

ddlInteger = pp.Or(
    map(pp.CaselessLiteral, sorted(ddlIntegerTypes, reverse=True))
).setParseAction(pp.replaceWith("integer"))

ddlSerialTypes = [
    "bigserial",  # PostgreSQL
    "serial",  # PostgreSQL
    "smallserial",  # PostgreSQL
]

# Serial types map to "integer" and additionally carry the "hasAutoValue"
# results name (checked by testSerial()).
ddlSerial = (
    pp.Or(map(pp.CaselessLiteral, sorted(ddlSerialTypes, reverse=True)))
    .setParseAction(pp.replaceWith("integer"))
    .setResultsName("hasAutoValue")
)

ddlFloatingPointTypes = [
    "decimal",  # MYSQL
    "double",
    "float8",  # PostgreSQL
    "float",
    "float4",  # PostgreSQL
    "numeric",  # PostgreSQL
    "real",
]

ddlFloatingPoint = pp.Or(
    map(pp.CaselessLiteral, sorted(ddlFloatingPointTypes, reverse=True))
).setParseAction(pp.replaceWith("floating_point"))

ddlTextTypes = [
    "char",
    "varchar",
    "character varying",  # PostgreSQL
    "text",
    "clob",
    "enum",  # MYSQL
    "set",
    "longtext",  # MYSQL
    "jsonb",  # PostgreSQL
    "json",  # PostgreSQL
    "tinytext",  # MYSQL
    "mediumtext",  # MYSQL
    "rational",  # PostgreSQL pg_rationale extension
]

ddlText = pp.Or(
    map(pp.CaselessLiteral, sorted(ddlTextTypes, reverse=True))
).setParseAction(pp.replaceWith("text"))

ddlBlobTypes = [
    "bytea",
    "tinyblob",
    "blob",
    "mediumblob",
    "longblob",
    "binary",  # MYSQL
    "varbinary",  # MYSQL
]

ddlBlob = pp.Or(
    map(pp.CaselessLiteral, sorted(ddlBlobTypes, reverse=True))
).setParseAction(pp.replaceWith("blob"))

ddlDateTypes = [
    "date",
]

# The "warnTimezone" results name is attached here; nothing in this file
# reads it.
ddlDate = (
    pp.Or(map(pp.CaselessLiteral, sorted(ddlDateTypes, reverse=True)))
    .setParseAction(pp.replaceWith("day_point"))
    .setResultsName("warnTimezone")
)

ddlDateTimeTypes = [
    "datetime",
    "timestamp",
    "timestamp without time zone",  # PostgreSQL
    "timestamp with time zone",  # PostgreSQL
    "timestamptz",  # PostgreSQL
]

ddlDateTime = pp.Or(
    map(pp.CaselessLiteral, sorted(ddlDateTimeTypes, reverse=True))
).setParseAction(pp.replaceWith("time_point"))

ddlTimeTypes = [
    "time",
    "time without time zone",  # PostgreSQL
    "time with time zone",  # PostgreSQL
]

ddlTime = pp.Or(
    map(pp.CaselessLiteral, sorted(ddlTimeTypes, reverse=True))
).setParseAction(pp.replaceWith("time_of_day"))

# Fallback: any other alphanumeric word is reported as "UNKNOWN", which
# createHeader() turns into an error.
ddlUnknown = pp.Word(pp.alphanums).setParseAction(pp.replaceWith("UNKNOWN"))

ddlType = (
    ddlBoolean
    | ddlInteger
    | ddlSerial
    | ddlFloatingPoint
    | ddlText
    | ddlBlob
    | ddlDateTime
    | ddlDate
    | ddlTime
    | ddlUnknown
)

ddlUnsigned = pp.CaselessLiteral("UNSIGNED").setResultsName("isUnsigned")
ddlDigits = "," + pp.Word(pp.nums)
ddlWidth = ddlLeft + pp.Word(pp.nums) + pp.Optional(ddlDigits) + ddlRight
ddlTimezone = (
    (pp.CaselessLiteral("with") | pp.CaselessLiteral("without"))
    + pp.CaselessLiteral("time")
    + pp.CaselessLiteral("zone")
)

ddlNotNull = pp.Group(
    pp.CaselessLiteral("NOT") + pp.CaselessLiteral("NULL")
).setResultsName("notNull")
ddlDefaultValue = pp.CaselessLiteral("DEFAULT").setResultsName("hasDefaultValue")

ddlAutoKeywords = [
    "AUTO_INCREMENT",
    "AUTOINCREMENT",
    "SMALLSERIAL",
    "SERIAL",
    "BIGSERIAL",
    "GENERATED",
]
ddlAutoValue = pp.Or(map(pp.CaselessLiteral, sorted(ddlAutoKeywords, reverse=True)))

# One column definition: name, type, optional width / timezone suffix (both
# suppressed), then any number of attributes. Anything that is not a known
# attribute is consumed as an expression and dropped.
ddlColumn = pp.Group(
    ddlName("name")
    + ddlType("type")
    + pp.Suppress(pp.Optional(ddlWidth))
    + pp.Suppress(pp.Optional(ddlTimezone))
    + pp.ZeroOrMore(
        ddlUnsigned("isUnsigned")
        | ddlNotNull("notNull")
        | pp.CaselessLiteral("null")
        | ddlAutoValue("hasAutoValue")
        | ddlDefaultValue("hasDefaultValue")
        | pp.Suppress(ddlExpression)
    )
)

ddlConstraintKeywords = [
    "CONSTRAINT",
    "PRIMARY",
    "FOREIGN",
    "KEY",
    "FULLTEXT",
    "INDEX",
    "UNIQUE",
    "CHECK",
    "PERIOD",
]

ddlConstraint = pp.Group(
    pp.Or(map(pp.CaselessLiteral, sorted(ddlConstraintKeywords, reverse=True)))
    + ddlExpression
).setResultsName("isConstraint")

# CREATE TABLE parser
ddlIfNotExists = pp.Group(
    pp.CaselessLiteral("IF") + pp.CaselessLiteral("NOT") + pp.CaselessLiteral("EXISTS")
).setResultsName("ifNotExists")
ddlOrReplace = pp.Group(
    pp.CaselessLiteral("OR") + pp.CaselessLiteral("REPLACE")
).setResultsName("orReplace")
ddlCreateTable = pp.Group(
    pp.CaselessLiteral("CREATE")
    + pp.Suppress(pp.Optional(ddlOrReplace))
    + pp.CaselessLiteral("TABLE")
    + pp.Suppress(pp.Optional(ddlIfNotExists))
    + pp.Optional(ddlName.setResultsName("schema") + pp.Suppress('.'))
    + ddlName.setResultsName("tableName")
    + ddlLeft
    + pp.Group(pp.delimitedList(pp.Suppress(ddlConstraint) | ddlColumn)).setResultsName(
        "columns"
    )
    + ddlRight
).setResultsName("create")
# ddlString.setDebug(True) #uncomment to debug pyparsing

# A DDL file is one or more CREATE TABLE statements; everything between them
# is skipped.
ddl = pp.OneOrMore(pp.Suppress(pp.SkipTo(ddlCreateTable, False)) + ddlCreateTable)

ddlComment = pp.oneOf(["--", "#"]) + pp.restOfLine
ddl.ignore(ddlComment)


# PARSER SELF-TESTS (run via --test)


def testBoolean():
    """Every boolean type name must parse to "boolean"."""
    for t in ddlBooleanTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "boolean"


def testInteger():
    """Every integer type name must parse to "integer"."""
    for t in ddlIntegerTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "integer"


def testSerial():
    """Every serial type name must parse to "integer" and set hasAutoValue."""
    for t in ddlSerialTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "integer"
        assert result.hasAutoValue


def testFloatingPoint():
    """Every floating point type name must parse to "floating_point"."""
    for t in ddlFloatingPointTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "floating_point"


def testText():
    """Every text type name must parse to "text"."""
    for t in ddlTextTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "text"


def testBlob():
    """Every blob type name must parse to "blob"."""
    for t in ddlBlobTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "blob"


def testDate():
    """Every date type name must parse to "day_point"."""
    for t in ddlDateTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "day_point"


def testDateTime():
    """Every date-time type name must parse to "time_point"."""
    for t in ddlDateTimeTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "time_point"


def testTime():
    """Every time type name must parse to "time_of_day"."""
    for t in ddlTimeTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "time_of_day"


def testUnknown():
    """Unrecognized type names must parse to "UNKNOWN"."""
    for t in ["cheesecake", "blueberry"]:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "UNKNOWN"


def testAutoValue():
    """Define a helper that parses a string with ddlAutoValue and prints it."""

    # NOTE(review): `test` is never called and `expected` is never compared,
    # so this function currently verifies nothing; it is also not part of
    # testParser().
    def test(s, expected):
        results = ddlAutoValue.parseString(s, parseAll=True)
        print(results)


def testColumn():
    """A column with width, UNSIGNED, NOT NULL and a DEFAULT expression parses."""
    text = "\"id\" int(8) unsigned NOT NULL DEFAULT nextval('dk_id_seq'::regclass)"
    result = ddlColumn.parseString(text, parseAll=True)
    column = result[0]
    assert column.name == "id"
    assert column.type == "integer"
    assert column.isUnsigned
    assert column.notNull
    assert not column.hasAutoValue


def testConstraint():
    """Table constraints are recognized and tagged with isConstraint."""
    for text in [
        "CONSTRAINT unique_person UNIQUE (first_name, last_name)",
        "UNIQUE (id)",
        "UNIQUE (first_name,last_name)"
    ]:
        result = ddlConstraint.parseString(text, parseAll=True)
        assert result.isConstraint


def testRational():
    """A RATIONAL column is treated as text; a trailing cast is tolerated."""
    for text in [
        "pos RATIONAL NOT NULL DEFAULT nextval('rational_seq')::integer",
    ]:
        result = ddlColumn.parseString(text, parseAll=True)
        column = result[0]
        assert column.name == "pos"
        assert column.type == "text"
        assert column.notNull


def testTable():
    """A schema-qualified CREATE TABLE yields schema and table name."""
    text = """
  CREATE TABLE 'public'.'dk' (
    "id" int8 NOT NULL DEFAULT nextval('dk_id_seq'::regclass),
    "strange" NUMERIC(314, 15),
    "last_update" timestamp(6) DEFAULT now(),
    PRIMARY KEY (id)
  )
  """
    result = ddlCreateTable.parseString(text, parseAll=True)
    table = result[0]
    assert table.schema == "public"
    assert table.tableName == "dk"


def testParser():
    """Run the parser self-tests (entry point for --test)."""
    testBoolean()
    testInteger()
    testSerial()
    testFloatingPoint()
    testText()
    testBlob()
    testDate()
    testTime()
    testUnknown()
    testDateTime()
    testColumn()
    testConstraint()
    testRational()
    testTable()


# CODE GENERATOR
# HELPERS


def get_include_guard_name(namespace, inputfile):
    """Return an upper-case include-guard macro name.

    The name is built from ``namespace`` and the base name of ``inputfile``;
    every run of characters other than ASCII letters and digits becomes a
    single underscore.
    """
    val = re.sub("[^A-Za-z0-9]+", "_", namespace + "_" + os.path.basename(inputfile))
    return val.upper()


def identity_naming_func(name, schema=None):
    """Return ``name`` unchanged, prefixed with ``schema + "__"`` if a schema is given."""
    if schema:
        return schema + "__" + name
    return name


def repl_camel_case_func(m):
    """Regex callback: upper-case group 2, dropping group 1 when it is "_"."""
    if m.group(1) == "_":
        return m.group(2).upper()
    else:
        return m.group(1) + m.group(2).upper()


def class_name_naming_func(name, schema=None):
    """Convert an SQL name to UpperCamelCase.

    A schema, if given, is joined with "__"; dots are treated as underscores.
    """
    if schema:
        name = schema + "__" + name
    name = name.replace(".", "_")
    return re.sub(r"(^|\s|[_0-9])(\S)", repl_camel_case_func, name)


def member_name_naming_func(name, schema=None):
    """Convert an SQL name to lowerCamelCase.

    A schema, if given, is joined with "_"; dots are treated as underscores.
    The first character is left as it is.
    """
    if schema:
        name = schema + "_" + name
    name = name.replace(".", "_")
    return re.sub(r"(\s|_|[0-9])(\S)", repl_camel_case_func, name)


def repl_func_for_args(m):
    """Regex callback: upper-case group 2, dropping group 1 when it is "-"."""
    if m.group(1) == "-":
        return m.group(2).upper()
    else:
        return m.group(1) + m.group(2).upper()


def setArgumentBool(s, bool_value):
    """Set the module-level flag named after command line option ``s``.

    "-some-key" is converted to the variable name "someKey", which is then
    assigned ``bool_value`` in this module's globals.
    """
    camel = re.sub(r"(\s|-|[0-9])(\S)", repl_func_for_args, s)
    # Lower-case only the first character; slicing keeps "" as "".
    var_name = camel[:1].lower() + camel[1:]
    globals()[var_name] = bool_value


def escape_if_reserved(name):
    """Return ``name`` prefixed with "!" if it is one of the reserved names.

    The comparison is case-insensitive; other names are returned unchanged.
    """
    reserved_names = [
        "BEGIN",
        "END",
        "GROUP",
        "ORDER",
    ]
    if name.upper() in reserved_names:
        return "!{}".format(name)
    return name


def _emit(header, *lines):
    """Write each of ``lines`` to ``header``, one per line."""
    for line in lines:
        print(line, file=header)


def beginHeader(pathToHeader, namespace, nsList):
    """Create the header file and write its preamble.

    Writes the "generated by" note, the include guard, the sqlpp11 includes
    and one opening ``namespace`` per entry of ``nsList``.

    Returns the open file object; the caller must pass it to endHeader(),
    which closes it.
    """
    header = open(pathToHeader, "w")
    guard = get_include_guard_name(namespace, pathToHeader)
    _emit(
        header,
        "// generated by " + " ".join(sys.argv),
        "#ifndef " + guard,
        "#define " + guard,
        "",
        "#include <sqlpp11/table.h>",
        "#include <sqlpp11/data_types.h>",
        "#include <sqlpp11/char_sequence.h>",
        "",
    )
    for ns in nsList:
        # Open one namespace per component so that endHeader(), which closes
        # one brace per component, balances them.
        _emit(header, "namespace " + ns, "{")
    return header


def endHeader(header, nsList):
    """Close all namespaces and the include guard, then close the file."""
    for ns in reversed(nsList):
        print("} // namespace " + ns, file=header)
    print("#endif", file=header)
    header.close()


def help_message(exit_code=0):
    """Print the usage text and exit the process with ``exit_code``."""
    # Align the descriptions by padding every option to the longest one.
    pad = max((len(argument) for argument in optionalArgs), default=0)
    arg_string = "".join(
        argument + ": " + " " * (pad - len(argument)) + description + "\n"
        for argument, description in optionalArgs.items()
    )
    print(
        "Usage:\n"
        "ddl2cpp [optional args] <path to ddl> <path to target> <namespace>\n\n"
        "OPTIONAL ARGUMENTS:\n"
        + arg_string
        + "\n"
        "<path to ddl> path to your SQL database/table definitions (SHOW CREATE TABLE SomeTable) \n"
        "<path to target> path to a generated C++ header file without extension (no *.h). \n"
        "<namespace> namespace you want. Usually a project/database name\n"
    )
    sys.exit(exit_code)


optionalArgs = {
    # if -some-key is present, it will set variable someKey to True
    "-no-timestamp-warning": "do not show warning about date / time data types",  # noTimestampWarning = True
    "-auto-id": "Assume column 'id' to have an automatic value as if AUTO_INCREMENT was specified (e.g. implicit for SQLite ROWID)",  # autoId = True
    "-identity-naming": "Use table and column names from the ddl (defaults to UpperCamelCase for tables and lowerCamelCase for columns)",  # identityNaming = True
    "-split-tables": "Make a header for each table name, using target as a directory",  # splitTables = True
    "--help": "show this help",
    "--test": "run parser self-test",
}

# Option flags; setArgumentBool() overwrites them through globals().
noTimestampWarning = False
autoId = False
identityNaming = False
splitTables = False


def _column_traits(column, columnType, hasAutoValue):
    """Return the list of sqlpp trait names for one parsed column.

    ``require_insert`` is added only if the column has no automatic value,
    is NOT NULL and has no DEFAULT.
    """
    traitslist = ["sqlpp::" + columnType]
    requireInsert = True
    if hasAutoValue:
        traitslist.append("sqlpp::tag::must_not_insert")
        traitslist.append("sqlpp::tag::must_not_update")
        requireInsert = False
    if not column.notNull:
        traitslist.append("sqlpp::tag::can_be_null")
        requireInsert = False
    if column.hasDefaultValue:
        requireInsert = False
    if requireInsert:
        traitslist.append("sqlpp::tag::require_insert")
    return traitslist


def createHeader():
    """Parse the command line and the DDL file, then write the C++ header(s).

    Exits with ERROR_BAD_ARGS if fewer than three positional arguments are
    given, ERROR_STRANGE_PARSING if no CREATE TABLE statement can be parsed,
    and ERROR_DATA_TYPE (after writing the headers) if any column has an
    unsupported data type.
    """
    global noTimestampWarning

    # ARGUMENT PARSING
    # Options may appear anywhere; everything else is positional.
    positionals = []
    for arg in sys.argv[1:]:
        if arg in optionalArgs:
            setArgumentBool(arg, True)
        else:
            positionals.append(arg)
    if len(positionals) < 3:
        help_message(ERROR_BAD_ARGS)

    pathToDdl, targetBase, namespace = positionals[:3]
    pathToHeader = targetBase + ("/" if splitTables else ".h")

    if identityNaming:
        toClassName = identity_naming_func
        toMemberName = identity_naming_func
    else:
        toClassName = class_name_naming_func
        toMemberName = member_name_naming_func

    try:
        tableCreations = ddl.parseFile(pathToDdl)
    except pp.ParseException:
        print("ERROR: Could not parse any CREATE TABLE statement in " + pathToDdl)
        sys.exit(ERROR_STRANGE_PARSING)

    nsList = namespace.split("::")

    # PROCESS DDL
    header = None
    if not splitTables:
        header = beginHeader(pathToHeader, namespace, nsList)
    DataTypeError = False
    for create in tableCreations:
        if identityNaming:
            sqlSchema = create.schema
            sqlTableName = create.tableName
        else:
            sqlSchema = None
            # Only qualify the name when the statement actually had a schema;
            # otherwise the name would start with a stray ".".
            if create.schema:
                sqlTableName = create.schema + "." + create.tableName
            else:
                sqlTableName = create.tableName
        if splitTables:
            header = beginHeader(pathToHeader + sqlTableName + ".h", namespace, nsList)
        tableClass = toClassName(sqlTableName, sqlSchema)
        tableMember = toMemberName(sqlTableName, sqlSchema)
        tableNamespace = tableClass + "_"
        # Template argument list of sqlpp::table_t: the table class followed
        # by one entry per column, accumulated in the loop below.
        tableTemplateParameters = tableClass
        _emit(header, " namespace " + tableNamespace, " {")
        for column in create.columns:
            if column.isConstraint:
                continue
            sqlColumnName = column.name
            columnClass = toClassName(sqlColumnName)
            tableTemplateParameters += (
                ",\n " + tableNamespace + "::" + columnClass
            )
            columnMember = toMemberName(sqlColumnName)
            columnType = column.type
            if columnType == "UNKNOWN":
                print(
                    "Error: datatype of %s.%s is not supported."
                    % (sqlTableName, sqlColumnName)
                )
                DataTypeError = True
            if columnType == "integer" and column.isUnsigned:
                columnType = columnType + "_unsigned"
            if columnType == "time_point" and not noTimestampWarning:
                print(
                    "Warning: date and time values are assumed to be without timezone."
                )
                print(
                    "Warning: If you are using types WITH timezones, your code has to deal with that."
                )
                print("You can disable this warning using -no-timestamp-warning")
                # Warn only once per run.
                noTimestampWarning = True
            hasAutoValue = column.hasAutoValue or (autoId and sqlColumnName == "id")
            traitslist = _column_traits(column, columnType, hasAutoValue)
            _emit(
                header,
                " struct " + columnClass,
                " {",
                " struct _alias_t",
                " {",
                ' static constexpr const char _literal[] = "'
                + escape_if_reserved(sqlColumnName)
                + '";',
                " using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;",
                " template<typename T>",
                " struct _member_t",
                " {",
                " T " + columnMember + ";",
                " T& operator()() { return " + columnMember + "; }",
                " const T& operator()() const { return " + columnMember + "; }",
                " };",
                " };",
                " using _traits = sqlpp::make_traits<" + ", ".join(traitslist) + ">;",
                " };",
            )
        _emit(
            header,
            " } // namespace " + tableNamespace,
            "",
            " struct "
            + tableClass
            + ": sqlpp::table_t<"
            + tableTemplateParameters
            + ">",
            " {",
            " struct _alias_t",
            " {",
            ' static constexpr const char _literal[] = "'
            + (sqlSchema + "." if sqlSchema else "")
            + sqlTableName
            + '";',
            " using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;",
            " template<typename T>",
            " struct _member_t",
            " {",
            " T " + tableMember + ";",
            " T& operator()() { return " + tableMember + "; }",
            " const T& operator()() const { return " + tableMember + "; }",
            " };",
            " };",
            " };",
        )
        if splitTables:
            endHeader(header, nsList)

    if not splitTables:
        endHeader(header, nsList)
    if DataTypeError:
        print("Error: unsupported datatypes.")
        print("Possible solutions:")
        print("A) Implement this datatype (examples: sqlpp11/data_types)")
        print("B) Extend/upgrade ddl2cpp (edit types map)")
        print("C) Raise an issue on github")
        sys.exit(ERROR_DATA_TYPE)  # return non-zero error code, we might need it for automation


if __name__ == "__main__":
    if "--help" in sys.argv:
        help_message()
    elif "--test" in sys.argv:
        testParser()
        sys.exit()
    else:
        createHeader()