0
0
mirror of https://github.com/rbock/sqlpp11.git synced 2024-11-16 04:47:18 +08:00

Support schema in ddl #444

This commit is contained in:
Roland Bock 2022-07-03 08:19:28 +02:00
parent 648183fd64
commit 159f4be66f
2 changed files with 33 additions and 18 deletions

View File

@ -42,8 +42,8 @@ ddlNumber = pp.Word(pp.nums + "+-.", pp.nums + "+-.Ee")
ddlString = ( ddlString = (
pp.QuotedString("'") | pp.QuotedString('"', escQuote='""') | pp.QuotedString("`") pp.QuotedString("'") | pp.QuotedString('"', escQuote='""') | pp.QuotedString("`")
) )
ddlTerm = pp.Word(pp.alphas + "_", pp.alphanums + "_.$") ddlTerm = pp.Word(pp.alphas + "_", pp.alphanums + "_$")
ddlName = pp.Or([ddlTerm, ddlString, pp.Combine(ddlString + "." + ddlString)]) ddlName = ddlTerm | ddlString
ddlOperator = pp.Or( ddlOperator = pp.Or(
map(pp.CaselessLiteral, ["+", "-", "*", "/", "<", "<=", ">", ">=", "=", "%"]) map(pp.CaselessLiteral, ["+", "-", "*", "/", "<", "<=", ">", ">=", "=", "%"])
) )
@ -69,7 +69,7 @@ ddlBracedArguments << ddlLeft + pp.delimitedList(ddlExpression) + ddlRight
ddlBracedExpression << ddlLeft + ddlExpression + ddlRight ddlBracedExpression << ddlLeft + ddlExpression + ddlRight
ddlArguments = pp.Suppress(pp.Group(pp.delimitedList(ddlExpression))) ddlArguments = pp.Suppress(pp.Group(pp.delimitedList(ddlExpression)))
ddlFunctionCall << ddlName + ddlLeft + pp.Optional(ddlArguments) + ddlRight ddlFunctionCall << pp.Optional(ddlName + ".") + ddlName + ddlLeft + pp.Optional(ddlArguments) + ddlRight
# Column and constraint parsers # Column and constraint parsers
ddlBooleanTypes = [ ddlBooleanTypes = [
@ -270,6 +270,7 @@ ddlCreateTable = pp.Group(
+ pp.Suppress(pp.Optional(ddlOrReplace)) + pp.Suppress(pp.Optional(ddlOrReplace))
+ pp.CaselessLiteral("TABLE") + pp.CaselessLiteral("TABLE")
+ pp.Suppress(pp.Optional(ddlIfNotExists)) + pp.Suppress(pp.Optional(ddlIfNotExists))
+ pp.Optional(ddlName.setResultsName("schema") + pp.Suppress('.'))
+ ddlName.setResultsName("tableName") + ddlName.setResultsName("tableName")
+ ddlLeft + ddlLeft
+ pp.Group(pp.delimitedList(pp.Suppress(ddlConstraint) | ddlColumn)).setResultsName( + pp.Group(pp.delimitedList(pp.Suppress(ddlConstraint) | ddlColumn)).setResultsName(
@ -386,14 +387,17 @@ def testRational():
def testTable(): def testTable():
text = """ text = """
CREATE TABLE "public"."dk" ( CREATE TABLE 'public'.'dk' (
"id" int8 NOT NULL DEFAULT nextval('dk_id_seq'::regclass), "id" int8 NOT NULL DEFAULT nextval('dk_id_seq'::regclass),
"strange" NUMERIC(314, 15), "strange" NUMERIC(314, 15),
"last_update" timestamp(6) DEFAULT now(), "last_update" timestamp(6) DEFAULT now(),
PRIMARY KEY (id) PRIMARY KEY (id)
) )
""" """
result = ddlCreateTable.parseString(text, parseAll=True) result = ddlCreateTable.parseString(text, parseAll=True)
table = result[0]
assert table.schema == "public"
assert table.tableName == "dk"
def testParser(): def testParser():
@ -422,8 +426,10 @@ def get_include_guard_name(namespace, inputfile):
return val.upper() return val.upper()
def identity_naming_func(s): def identity_naming_func(name, schema = None):
return s if schema:
return schema + '__' + name;
return name
def repl_camel_case_func(m): def repl_camel_case_func(m):
@ -433,14 +439,18 @@ def repl_camel_case_func(m):
return m.group(1) + m.group(2).upper() return m.group(1) + m.group(2).upper()
def class_name_naming_func(s): def class_name_naming_func(name, schema = None):
s = s.replace(".", "_") if schema:
return re.sub("(^|\s|[_0-9])(\S)", repl_camel_case_func, s) name = schema + "__" + name
name = name.replace(".", "_")
return re.sub("(^|\s|[_0-9])(\S)", repl_camel_case_func, name)
def member_name_naming_func(s): def member_name_naming_func(name, schema = None):
s = s.replace(".", "_") if schema:
return re.sub("(\s|_|[0-9])(\S)", repl_camel_case_func, s) name = schema + "_" + name
name = name.replace(".", "_")
return re.sub("(\s|_|[0-9])(\S)", repl_camel_case_func, name)
def repl_func_for_args(m): def repl_func_for_args(m):
@ -579,11 +589,16 @@ def createHeader():
header = beginHeader(pathToHeader, namespace, nsList) header = beginHeader(pathToHeader, namespace, nsList)
DataTypeError = False DataTypeError = False
for create in tableCreations: for create in tableCreations:
sqlTableName = create.tableName if identityNaming:
sqlSchema = create.schema
sqlTableName = create.tableName
else:
sqlSchema = None
sqlTableName = create.schema + '.' + create.tableName
if splitTables: if splitTables:
header = beginHeader(pathToHeader + sqlTableName + ".h", namespace, nsList) header = beginHeader(pathToHeader + sqlTableName + ".h", namespace, nsList)
tableClass = toClassName(sqlTableName) tableClass = toClassName(sqlTableName, sqlSchema)
tableMember = toMemberName(sqlTableName) tableMember = toMemberName(sqlTableName, sqlSchema)
tableNamespace = tableClass + "_" tableNamespace = tableClass + "_"
tableTemplateParameters = tableClass tableTemplateParameters = tableClass
print(" namespace " + tableNamespace, file=header) print(" namespace " + tableNamespace, file=header)
@ -682,7 +697,7 @@ def createHeader():
print(" struct _alias_t", file=header) print(" struct _alias_t", file=header)
print(" {", file=header) print(" {", file=header)
print( print(
' static constexpr const char _literal[] = "' + sqlTableName + '";', ' static constexpr const char _literal[] = "' + (sqlSchema + '.' if sqlSchema else '') + sqlTableName + '";',
file=header, file=header,
) )
print( print(

View File

@ -1,4 +1,4 @@
CREATE TABLE tab_sample ( CREATE TABLE public.tab_sample (
alpha bigint(20) DEFAULT NULL AUTO_INCREMENT, alpha bigint(20) DEFAULT NULL AUTO_INCREMENT,
beta tinyint(1) DEFAULT NULL, beta tinyint(1) DEFAULT NULL,
gamma varchar(255) DEFAULT NULL gamma varchar(255) DEFAULT NULL