#!/usr/bin/env python3

##
# Copyright (c) 2013-2022, Roland Bock
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
#  * Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
# OF THE POSSIBILITY OF SUCH DAMAGE.
##

import pyparsing as pp
import sys
import re
import os
import csv

# error codes, we should refactor this later
ERROR_BAD_ARGS = 1
ERROR_DATA_TYPE = 10
ERROR_STRANGE_PARSING = 20

# Rather crude SQL expression parser.
# This is not geared at correctly interpreting SQL, but at identifying (and
# ignoring) expressions, for instance in DEFAULT expressions.
ddlLeft, ddlRight = map(pp.Suppress, "()")
ddlNumber = pp.Word(pp.nums + "+-.", pp.nums + "+-.Ee")
ddlString = (
    pp.QuotedString("'") | pp.QuotedString('"', escQuote='""') | pp.QuotedString("`")
)
ddlTerm = pp.Word(pp.alphas + "_", pp.alphanums + "_.$")
ddlName = pp.Or(
    [
        ddlTerm,
        ddlString,
        pp.Combine(ddlString + "." + ddlString),
        pp.Combine(ddlTerm + ddlString),
    ]
)
# All alternatives go into ONE list: the second positional parameter of pp.Or
# is the `savelist` flag, so a keyword passed there is silently not an
# alternative at all.
ddlOperator = pp.Or(
    [pp.CaselessLiteral(op) for op in ["+", "-", "*", "/", "<", "<=", ">", ">=", "=", "%"]]
    + [pp.CaselessKeyword("DIV")]
)
ddlBracedExpression = pp.Forward()
ddlFunctionCall = pp.Forward()
ddlCastEnd = "::" + ddlTerm
ddlCast = ddlString + ddlCastEnd
ddlBracedArguments = pp.Forward()
ddlExpression = pp.OneOrMore(
    ddlBracedExpression
    | ddlFunctionCall
    | ddlCastEnd
    | ddlCast
    | ddlOperator
    | ddlString
    | ddlTerm
    | ddlNumber
    | ddlBracedArguments
)

ddlBracedArguments << ddlLeft + pp.delimitedList(ddlExpression) + ddlRight
ddlBracedExpression << ddlLeft + ddlExpression + ddlRight

ddlArguments = pp.Suppress(pp.Group(pp.delimitedList(ddlExpression)))
ddlFunctionCall << ddlName + ddlLeft + pp.Optional(ddlArguments) + ddlRight

# Data types
# These lists are read by initDllParser(); loadExtendedTypesFile() may extend
# them beforehand (it looks them up by name as "ddl<BaseType>Types").
ddlBooleanTypes = [
    "bool",
    "boolean",
]

ddlIntegerTypes = [
    "bigint",
    "int",
    "int2",  # PostgreSQL
    "int4",  # PostgreSQL
    "int8",  # PostgreSQL
    "integer",
    "mediumint",
    "smallint",
    "tinyint",
]

ddlSerialTypes = [
    "bigserial",  # PostgreSQL
    "serial",  # PostgreSQL
    "serial2",  # PostgreSQL
    "serial4",  # PostgreSQL
    "serial8",  # PostgreSQL
    "smallserial",  # PostgreSQL
]

ddlFloatingPointTypes = [
    "decimal",  # MYSQL
    "double",
    "float8",  # PostgreSQL
    "float",
    "float4",  # PostgreSQL
    "numeric",  # PostgreSQL
    "real",
]

ddlTextTypes = [
    "char",
    "varchar",
    "character varying",  # PostgreSQL
    "text",
    "clob",
    "enum",  # MYSQL
    "set",
    "longtext",  # MYSQL
    "jsonb",  # PostgreSQL
    "json",  # PostgreSQL
    "tinytext",  # MYSQL
    "mediumtext",  # MYSQL
    "rational",  # PostgreSQL pg_rationale extension
]

ddlBlobTypes = [
    "bytea",
    "tinyblob",
    "blob",
    "mediumblob",
    "longblob",
    "binary",  # MYSQL
    "varbinary",  # MYSQL
]

ddlDateTypes = [
    "date",
]

ddlDateTimeTypes = [
    "datetime",
    "timestamp",
    "timestamp without time zone",  # PostgreSQL
    "timestamp with time zone",  # PostgreSQL
    "timestamptz",  # PostgreSQL
]

ddlTimeTypes = [
    "time",
    "time without time zone",  # PostgreSQL
    "time with time zone",  # PostgreSQL
]

# Text handed to the top-level `ddl` parse action; createHeader() slices the
# original CREATE TABLE statements out of it.
parsedContent = ""


def SetContent(text, loc, token):
    """Parse action for `ddl`: remember the parsed input text in `parsedContent`."""
    global parsedContent
    parsedContent = text


def _keywordChoice(keywords):
    """Return a pp.Or over caseless keywords built from `keywords`."""
    return pp.Or([pp.CaselessKeyword(k) for k in sorted(keywords, reverse=True)])


# Init the DDL parser
def initDllParser():
    """Build the type/column/constraint/CREATE TABLE parsers as module globals.

    Must run after any loadExtendedTypesFile() call, because the ddl*Types
    lists are read here.
    """
    global ddl
    global ddlType
    global ddlColumn
    global ddlConstraint
    global ddlCreateTable
    global ddlAutoValue

    # Column and constraint parsers
    ddlBoolean = _keywordChoice(ddlBooleanTypes).setParseAction(
        pp.replaceWith("boolean")
    )
    ddlInteger = _keywordChoice(ddlIntegerTypes).setParseAction(
        pp.replaceWith("integral")
    )
    ddlSerial = (
        _keywordChoice(ddlSerialTypes)
        .setParseAction(pp.replaceWith("integral"))
        .setResultsName("hasAutoValue")
    )
    ddlFloatingPoint = _keywordChoice(ddlFloatingPointTypes).setParseAction(
        pp.replaceWith("floating_point")
    )
    ddlText = _keywordChoice(ddlTextTypes).setParseAction(pp.replaceWith("text"))
    ddlBlob = _keywordChoice(ddlBlobTypes).setParseAction(pp.replaceWith("blob"))
    ddlDate = (
        _keywordChoice(ddlDateTypes)
        .setParseAction(pp.replaceWith("day_point"))
        .setResultsName("warnTimezone")
    )
    ddlDateTime = _keywordChoice(ddlDateTimeTypes).setParseAction(
        pp.replaceWith("time_point")
    )
    ddlTime = _keywordChoice(ddlTimeTypes).setParseAction(
        pp.replaceWith("time_of_day")
    )
    ddlUnknown = pp.Word(pp.alphanums).setParseAction(pp.replaceWith("UNKNOWN"))

    ddlType = (
        ddlBoolean
        | ddlInteger
        | ddlSerial
        | ddlFloatingPoint
        | ddlText
        | ddlBlob
        | ddlDateTime
        | ddlDate
        | ddlTime
        | ddlUnknown
    )

    ddlUnsigned = pp.CaselessKeyword("UNSIGNED").setResultsName("isUnsigned")
    ddlDigits = "," + pp.Word(pp.nums)
    ddlWidth = ddlLeft + pp.Word(pp.nums) + pp.Optional(ddlDigits) + ddlRight
    ddlTimezone = (
        (pp.CaselessKeyword("with") | pp.CaselessKeyword("without"))
        + pp.CaselessKeyword("time")
        + pp.CaselessKeyword("zone")
    )

    ddlNotNull = pp.Group(
        pp.CaselessKeyword("NOT") + pp.CaselessKeyword("NULL")
    ).setResultsName("notNull")
    ddlDefaultValue = pp.CaselessKeyword("DEFAULT").setResultsName("hasDefaultValue")

    ddlAutoKeywords = [
        "AUTO_INCREMENT",
        "AUTOINCREMENT",
        "SMALLSERIAL",
        "SERIAL",
        "SERIAL2",
        "SERIAL4",
        "SERIAL8",
        "BIGSERIAL",
        "GENERATED",
    ]
    ddlAutoValue = _keywordChoice(ddlAutoKeywords)

    ddlPrimaryKey = pp.Group(
        pp.CaselessKeyword("PRIMARY") + pp.CaselessKeyword("KEY")
    ).setResultsName("isPrimaryKey")

    ddlIgnoredKeywords = [
        "CONSTRAINT",
        "FOREIGN",
        "KEY",
        "FULLTEXT",
        "INDEX",
        "UNIQUE",
        "CHECK",
        "PERIOD",
    ]

    ddlConstraint = pp.Group(
        _keywordChoice(ddlIgnoredKeywords + ["PRIMARY"]) + ddlExpression
    ).setResultsName("isConstraint")

    ddlColumn = pp.Group(
        ddlName("name")
        + ddlType("type")
        + pp.Suppress(pp.Optional(ddlWidth))
        + pp.Suppress(pp.Optional(ddlTimezone))
        + pp.ZeroOrMore(
            ddlUnsigned("isUnsigned")
            | ddlNotNull("notNull")
            | pp.CaselessKeyword("null")
            | ddlAutoValue("hasAutoValue")
            | ddlDefaultValue("hasDefaultValue")
            | ddlPrimaryKey("isPrimaryKey")
            | pp.Suppress(pp.OneOrMore(_keywordChoice(ddlIgnoredKeywords)))
            | pp.Suppress(ddlExpression)
        )
    )

    # CREATE TABLE parser
    ddlIfNotExists = pp.Group(
        pp.CaselessKeyword("IF") + pp.CaselessKeyword("NOT") + pp.CaselessKeyword("EXISTS")
    ).setResultsName("ifNotExists")
    ddlOrReplace = pp.Group(
        pp.CaselessKeyword("OR") + pp.CaselessKeyword("REPLACE")
    ).setResultsName("orReplace")
    ddlCreateTable = pp.Group(
        pp.CaselessKeyword("CREATE")
        + pp.Suppress(pp.Optional(ddlOrReplace))
        + pp.CaselessKeyword("TABLE")
        + pp.Suppress(pp.Optional(ddlIfNotExists))
        + ddlName.setResultsName("tableName")
        + ddlLeft
        + pp.Group(pp.delimitedList(pp.Suppress(ddlConstraint) | ddlColumn)).setResultsName(
            "columns"
        )
        + ddlRight
    ).setResultsName("create")
    # ddlString.setDebug(True) #uncomment to debug pyparsing

    # Located() records where each CREATE TABLE starts/ends so that the
    # original statement text can be copied into the table creation helper.
    ddl = (
        pp.OneOrMore(
            pp.Group(
                pp.Suppress(pp.SkipTo(ddlCreateTable, False))
                + pp.Located(ddlCreateTable)
            )
        )
        .setResultsName("tables")
        .setParseAction(SetContent)
    )

    ddlComment = pp.oneOf(["--", "#"]) + pp.restOfLine
    ddl.ignore(ddlComment)


def testBoolean():
    for t in ddlBooleanTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "boolean"


def testInteger():
    # The integer parser replaces its match with "integral" (see initDllParser).
    for t in ddlIntegerTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "integral"


def testSerial():
    for t in ddlSerialTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "integral"
        assert result.hasAutoValue


def testFloatingPoint():
    for t in ddlFloatingPointTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "floating_point"


def testText():
    for t in ddlTextTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "text"


def testBlob():
    for t in ddlBlobTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "blob"


def testDate():
    for t in ddlDateTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "day_point"


def testDateTime():
    for t in ddlDateTimeTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "time_point"


def testTime():
    for t in ddlTimeTypes:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "time_of_day"


def testUnknown():
    for t in ["cheesecake", "blueberry"]:
        result = ddlType.parseString(t, parseAll=True)
        assert result[0] == "UNKNOWN"


def testAutoValue():
    # Keywords are caseless; compare case-insensitively.
    for keyword in ["AUTO_INCREMENT", "autoincrement", "Serial", "BIGSERIAL", "generated"]:
        result = ddlAutoValue.parseString(keyword, parseAll=True)
        assert result[0].upper() == keyword.upper()


def testColumn():
    text = "\"id\" int(8) unsigned NOT NULL DEFAULT nextval('dk_id_seq'::regclass)"
    result = ddlColumn.parseString(text, parseAll=True)
    column = result[0]
    assert column.name == "id"
    assert column.type == "integral"
    assert column.isUnsigned
    assert column.notNull
    assert not column.hasAutoValue
    assert not column.isPrimaryKey


def testConstraint():
    for text in [
        "CONSTRAINT unique_person UNIQUE (first_name, last_name)",
        "UNIQUE (id)",
        "UNIQUE (first_name,last_name)",
    ]:
        result = ddlConstraint.parseString(text, parseAll=True)
        assert result.isConstraint


def testMathExpression():
    text = "2 DIV 2"
    result = ddlExpression.parseString(text, parseAll=True)
    assert len(result) == 3
    assert result[0] == "2"
    assert result[1] == "DIV"
    assert result[2] == "2"


def testRational():
    for text in [
        "pos RATIONAL NOT NULL DEFAULT nextval('rational_seq')::integer",
    ]:
        result = ddlColumn.parseString(text, parseAll=True)
        column = result[0]
        assert column.name == "pos"
        assert column.type == "text"
        assert column.notNull


def testTable():
    text = """
  CREATE TABLE "public"."dk" (
  "id" int8 NOT NULL DEFAULT nextval('dk_id_seq'::regclass),
  "strange" NUMERIC(314, 15),
  "last_update" timestamp(6) DEFAULT now(),
  PRIMARY KEY (id)
)
"""
    # Only checks that the statement parses without raising.
    ddlCreateTable.parseString(text, parseAll=True)


def testPrimaryKeyAutoIncrement():
    for text in [
        "CREATE TABLE tab (col INTEGER NOT NULL AUTO_INCREMENT PRIMARY KEY)",  # mysql
        "CREATE TABLE tab (col INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT)",  # mysql
        "CREATE TABLE tab (col INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT)",  # sqlite
    ]:
        result = ddlCreateTable.parseString(text, parseAll=True)
        assert len(result) == 1
        table = result[0]
        assert table.tableName == "tab"
        assert len(table.columns) == 1
        column = table.columns[0]
        assert not column.isConstraint
        assert column.name == "col"
        assert column.type == "integral"
        assert column.notNull
        assert column.hasAutoValue
        assert column.isPrimaryKey


def testParser():
    """Run the parser self-tests (selected with --test)."""
    initDllParser()
    testBoolean()
    testInteger()
    testSerial()
    testFloatingPoint()
    testText()
    testBlob()
    testDate()
    testTime()
    testUnknown()
    testDateTime()
    testAutoValue()
    testColumn()
    testConstraint()
    testMathExpression()
    testRational()
    testTable()
    testPrimaryKeyAutoIncrement()


# HELPERS
def get_include_guard_name(namespace, inputfile):
    """Return an upper-case identifier built from namespace and file base name."""
    val = re.sub("[^A-Za-z0-9]+", "_", namespace + "_" + os.path.basename(inputfile))
    return val.upper()


def identity_naming_func(s):
    """Return the name unchanged (used with -identity-naming)."""
    return s


def repl_camel_case_func(m):
    """re.sub callback: upper-case group 2, dropping group 1 if it is '_'."""
    if m.group(1) == "_":
        return m.group(2).upper()
    else:
        return m.group(1) + m.group(2).upper()


def class_name_naming_func(s):
    """Convert an SQL name to UpperCamelCase ('.' is treated like '_')."""
    s = s.replace(".", "_")
    return re.sub(r"(^|\s|[_0-9])(\S)", repl_camel_case_func, s)


def member_name_naming_func(s):
    """Convert an SQL name to lowerCamelCase ('.' is treated like '_')."""
    s = s.replace(".", "_")
    return re.sub(r"(\s|_|[0-9])(\S)", repl_camel_case_func, s)


def repl_func_for_args(m):
    """re.sub callback: upper-case group 2, dropping group 1 if it is '-'."""
    if m.group(1) == "-":
        return m.group(2).upper()
    else:
        return m.group(1) + m.group(2).upper()


def setArgumentBool(s, bool_value):
    """Set the module global derived from option `s` to `bool_value`.

    "-some-key" is converted to the variable name "someKey".
    """
    first_lower = (
        lambda s: s[:1].lower() + s[1:] if s else ""
    )  # http://stackoverflow.com/a/3847369/5006740
    var_name = first_lower(re.sub(r"(\s|-|[0-9])(\S)", repl_func_for_args, s))
    globals()[var_name] = bool_value


def loadExtendedTypesFile(filename):
    """Extend the ddl<BaseType>Types lists from a CSV file.

    Each row is "<BaseType>, <custom type>[, <custom type>...]"; the custom
    types are appended to the module-level list "ddl<BaseType>Types".
    Rows without any custom type are skipped. An unknown base type ends the
    program with ERROR_BAD_ARGS.
    """
    with open(filename, newline='') as csvfile:
        reader = csv.DictReader(csvfile, fieldnames=["baseType"], restkey="extendedTypes", delimiter=',')
        for row in reader:
            # The restkey entry only exists when the row has more than one field.
            extended = row.get("extendedTypes") or []
            var_values = [clean_val for value in extended if (clean_val := value.strip(" \"'"))]
            if not var_values:
                continue
            var_name = f"ddl{(row['baseType'] or '').strip()}Types"
            target = globals().get(var_name)
            if not isinstance(target, list):
                print(f"ERROR: unknown base type '{row['baseType']}' in {filename}")
                sys.exit(ERROR_BAD_ARGS)
            target.extend(var_values)


def escape_if_reserved(name):
    """Prefix `name` with '!' if it is one of the reserved names."""
    reserved_names = (
        "BEGIN",
        "END",
        "GROUP",
        "ORDER",
    )
    if name.upper() in reserved_names:
        return "!{}".format(name)
    return name


def beginHeader(pathToHeader, namespace, nsList):
    """Open `pathToHeader` for writing, emit the preamble, return the file.

    The caller has to close the returned file via endHeader(). `namespace`
    is accepted for interface compatibility; the namespaces are opened from
    `nsList`.
    """
    # With -split-tables the target is a directory that may not exist yet.
    directory = os.path.dirname(pathToHeader)
    if directory:
        os.makedirs(directory, exist_ok=True)
    header = open(pathToHeader, "w")
    print('#pragma once', file=header)
    print('', file=header)
    print("// generated by " + " ".join(sys.argv), file=header)
    print("", file=header)
    # NOTE(review): the four include directives below carry no header name, so
    # the generated file cannot compile as-is. The names look like they were
    # lost (angle-bracketed text stripped) and cannot be recovered from this
    # file -- TODO restore them from the sqlpp release this script ships with.
    print("#include ", file=header)
    print("#include ", file=header)
    print("#include ", file=header)
    print("#include ", file=header)
    print("", file=header)
    for ns in nsList:
        print("namespace " + ns, file=header)
        print("{", file=header)
    return header


def endHeader(header, nsList):
    """Close the namespaces opened by beginHeader() and close the file."""
    for ns in reversed(nsList):
        print("} // namespace " + ns, file=header)
    header.close()


def help_message(exit_code=0):
    """Print the usage text and exit with `exit_code` (0 unless given)."""
    # The dataTypeFileArg is handled differently from the normal optionalArgs
    # and only added to a local copy here to make use of the formatting of the
    # help; the global dict must keep holding boolean switches only.
    arguments = dict(optionalArgs)
    arguments[dataTypeFileArg] = f"path to a csv that contains custom datatype mappings. The format is '{dataTypeFileArg}=path/to/file.csv' (See the README)."
    pad = max(len(argument) for argument in arguments)
    arg_string = "".join(
        argument + ": " + " " * (pad - len(argument)) + description + "\n"
        for argument, description in arguments.items()
    )
    # The three positional placeholders are named after their descriptions.
    print(
        "Usage:\n"
        "ddl2cpp [optional args] <path to ddl> <path to target> <namespace>\n\n"
        "OPTIONAL ARGUMENTS:\n" + arg_string + "\n"
        "<path to ddl> path to your SQL database/table definitions (SHOW CREATE TABLE SomeTable) \n"
        "<path to target> path to a generated C++ header file without extension (no *.h). \n"
        "<namespace> namespace you want. Usually a project/database name\n"
    )
    sys.exit(exit_code)


optionalArgs = {
    # if -some-key is present, it will set variable someKey to True
    "-no-timestamp-warning": "suppress the warning about date / time data types",  # noTimestampWarning = True
    "-auto-id": "Assume column 'id' to have an automatic value as if AUTO_INCREMENT was specified (e.g. implicit for SQLite ROWID)",  # autoId = True
    "-identity-naming": "Use table and column names from the ddl (defaults to UpperCamelCase for tables and lowerCamelCase for columns)",  # identityNaming = True
    "-split-tables": "Make a header for each table name, using target as a directory",  # splitTables = True
    "-with-table-creation-helper": "Create a helper function for each table that drops and creates the table",  # withTableCreationHelper
    "--help": "show this help",
    "--test": "run parser self-test",
}

noTimestampWarning = False
autoId = False
identityNaming = False
splitTables = False
withTableCreationHelper = False
dataTypeFileArg = "--datatype-file"


def _parseCommandLine():
    """Apply the optional arguments and return [pathToDdl, target, namespace].

    Optional arguments have to precede the three positional ones. Exits with
    ERROR_BAD_ARGS when the command line is incomplete.
    """
    if len(sys.argv) < 4:
        help_message(ERROR_BAD_ARGS)

    firstPositional = 1
    for arg in sys.argv[1:]:
        if arg in optionalArgs:
            setArgumentBool(arg, True)
            firstPositional += 1
        elif arg == dataTypeFileArg or arg.startswith(dataTypeFileArg + "="):
            # partition keeps any further '=' inside the path intact
            typesFile = arg.partition("=")[2]
            if not typesFile:
                print(f"ERROR: expected {dataTypeFileArg}=path/to/file.csv")
                sys.exit(ERROR_BAD_ARGS)
            loadExtendedTypesFile(typesFile)
            firstPositional += 1

    if len(sys.argv) < firstPositional + 3:
        help_message(ERROR_BAD_ARGS)
    return sys.argv[firstPositional:firstPositional + 3]


def createHeader():
    """Parse the DDL file named on the command line and write the C++ header(s).

    Exits with ERROR_BAD_ARGS on an incomplete command line, with
    ERROR_STRANGE_PARSING if no CREATE TABLE statement can be parsed and with
    ERROR_DATA_TYPE if a column uses an unsupported data type (the header is
    still written in that case).
    """
    global noTimestampWarning

    # ARGUMENT PARSING
    pathToDdl, target, namespace = _parseCommandLine()
    pathToHeader = target + ("/" if splitTables else ".h")

    if identityNaming:
        toClassName = identity_naming_func
        toMemberName = identity_naming_func
    else:
        toClassName = class_name_naming_func
        toMemberName = member_name_naming_func

    initDllParser()

    try:
        tableCreations = ddl.parseFile(pathToDdl)
    except pp.ParseException as e:
        print("ERROR: Could not parse any CREATE TABLE statement in " + pathToDdl)
        print(e)
        sys.exit(ERROR_STRANGE_PARSING)

    nsList = namespace.split("::")

    # PROCESS DDL
    header = None
    if not splitTables:
        header = beginHeader(pathToHeader, namespace, nsList)
    DataTypeError = False
    for table in tableCreations.tables:
        create = table.value.create
        sqlTableName = create.tableName
        if splitTables:
            header = beginHeader(pathToHeader + sqlTableName + ".h", namespace, nsList)
        tableClass = toClassName(sqlTableName)
        tableMember = toMemberName(sqlTableName)
        tableSpec = tableClass + "_"
        tableColumnClasses = []
        tableRequiredInsertColumns = []

        if withTableCreationHelper:
            # The helper takes a `Db&`, so Db has to be the template parameter.
            print(" template<typename Db>", file=header)
            print(" void create" + tableClass + "(Db& db)", file=header)
            print(" {", file=header)
            print(" db.execute(R\"+++(DROP TABLE IF EXISTS " + sqlTableName + ")+++\");", file=header)
            print(" db.execute(R\"+++(" + parsedContent[table.locn_start:table.locn_end] + ")+++\");", file=header)
            print(" }", file=header)
            print("", file=header)

        print(" struct " + tableSpec + " : public ::sqlpp::name_tag_base", file=header)
        print(" {", file=header)
        for column in create.columns:
            if column.isConstraint:
                continue
            sqlColumnName = column.name
            columnClass = toClassName(sqlColumnName)
            columnMember = toMemberName(sqlColumnName)
            columnType = column.type
            if columnType == "UNKNOWN":
                print(
                    "Error: datatype of %s.%s is not supported."
                    % (sqlTableName, sqlColumnName)
                )
                DataTypeError = True
            if columnType == "integral" and column.isUnsigned:
                columnType = "unsigned_" + columnType
            if columnType == "time_point" and not noTimestampWarning:
                print(
                    "Warning: date and time values are assumed to be without timezone."
                )
                print(
                    "Warning: If you are using types WITH timezones, your code has to deal with that."
                )
                print("You can disable this warning using -no-timestamp-warning")
                # warn only once per run
                noTimestampWarning = True
            print(" struct " + columnClass + " : public ::sqlpp::name_tag_base", file=header)
            print(" {", file=header)
            print(" struct _alias_t", file=header)
            print(" {", file=header)
            print(
                ' static constexpr const char _literal[] = "'
                + escape_if_reserved(sqlColumnName)
                + '";',
                file=header,
            )
            print(
                " using _name_t = ::sqlpp::make_char_sequence<sizeof(_literal), _literal>;",
                file=header,
            )
            print(" template<typename T>", file=header)
            print(" struct _member_t", file=header)
            print(" {", file=header)
            print(" T " + columnMember + ";", file=header)
            print(
                " T& operator()() { return " + columnMember + "; }",
                file=header,
            )
            print(
                " const T& operator()() const { return "
                + columnMember
                + "; }",
                file=header,
            )
            print(" };", file=header)
            print(" };", file=header)
            columnCanBeNull = not column.notNull and not column.isPrimaryKey
            if columnCanBeNull:
                print(" using value_type = ::sqlpp::compat::optional<::sqlpp::" + columnType + ">;", file=header)
            else:
                print(" using value_type = ::sqlpp::" + columnType + ";", file=header)
            columnHasDefault = (
                column.hasDefaultValue
                or columnCanBeNull
                or column.hasAutoValue
                or (autoId and sqlColumnName == "id")
            )
            if columnHasDefault:
                print(" using has_default = std::true_type;", file=header)
            else:
                print(" using has_default = std::false_type;", file=header)
            print(" };", file=header)
            tableColumnClasses.append(columnClass)
            if not columnHasDefault:
                tableRequiredInsertColumns.append(
                    "\n sqlpp::column_t<" + tableSpec + ", " + columnClass + ">"
                )

        print(" struct _alias_t", file=header)
        print(" {", file=header)
        print(
            ' static constexpr const char _literal[] = "' + sqlTableName + '";',
            file=header,
        )
        print(
            " using _name_t = ::sqlpp::make_char_sequence<sizeof(_literal), _literal>;",
            file=header,
        )
        print(" template<typename T>", file=header)
        print(" struct _member_t", file=header)
        print(" {", file=header)
        print(" T " + tableMember + ";", file=header)
        print(" T& operator()() { return " + tableMember + "; }", file=header)
        print(
            " const T& operator()() const { return " + tableMember + "; }",
            file=header,
        )
        print(" };", file=header)
        print(" };", file=header)
        # The collected column classes are the template arguments of
        # table_columns; without them the alias is not valid C++.
        print(" template<typename T>", file=header)
        print(
            " using _table_columns = sqlpp::table_columns<T"
            + "".join(",\n " + columnClass for columnClass in tableColumnClasses)
            + ">;",
            file=header,
        )
        print(
            " using _required_insert_columns = sqlpp::detail::type_set<"
            + ",".join(tableRequiredInsertColumns)
            + ">;",
            file=header,
        )
        print(" };", file=header)
        print(
            " using " + tableClass + " = ::sqlpp::table_t<" + tableSpec + ">;", file=header)
        print("", file=header)
        if splitTables:
            endHeader(header, nsList)

    if not splitTables:
        endHeader(header, nsList)
    if DataTypeError:
        print("Error: unsupported datatypes.")
        print("Possible solutions:")
        print("A) Implement this datatype (examples: sqlpp11/data_types)")
        print(f"B) Use the '{dataTypeFileArg}' command line argument to map the type to a known type (example: README)")
        print("C) Extend/upgrade ddl2cpp (edit types map)")
        print("D) Raise an issue on github")
        sys.exit(ERROR_DATA_TYPE)  # return non-zero error code, we might need it for automation


if __name__ == "__main__":
    if "--help" in sys.argv:
        help_message()
    elif "--test" in sys.argv:
        testParser()
        sys.exit()
    else:
        createHeader()