Fixed #17785 -- Preferred column names in get_relations introspection

Thanks Thomas Güttler for the report and the initial patch, and
Tim Graham for the review.
This commit is contained in:
Claude Paroz 2015-01-10 20:27:30 +01:00
parent b75c707943
commit 4c413e231c
7 changed files with 48 additions and 44 deletions

View File

@ -73,11 +73,11 @@ class Command(BaseCommand):
except NotImplementedError: except NotImplementedError:
constraints = {} constraints = {}
used_column_names = [] # Holds column names used in the table so far used_column_names = [] # Holds column names used in the table so far
for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)): for row in connection.introspection.get_table_description(cursor, table_name):
comment_notes = [] # Holds Field notes, to be displayed in a Python comment. comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
extra_params = OrderedDict() # Holds Field parameters such as 'db_column'. extra_params = OrderedDict() # Holds Field parameters such as 'db_column'.
column_name = row[0] column_name = row[0]
is_relation = i in relations is_relation = column_name in relations
att_name, params, notes = self.normalize_col_name( att_name, params, notes = self.normalize_col_name(
column_name, used_column_names, is_relation) column_name, used_column_names, is_relation)
@ -94,7 +94,7 @@ class Command(BaseCommand):
extra_params['unique'] = True extra_params['unique'] = True
if is_relation: if is_relation:
rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1]) rel_to = "self" if relations[column_name][1] == table_name else table2model(relations[column_name][1])
if rel_to in known_models: if rel_to in known_models:
field_type = 'ForeignKey(%s' % rel_to field_type = 'ForeignKey(%s' % rel_to
else: else:

View File

@ -80,25 +80,15 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
) )
return fields return fields
def _name_to_index(self, cursor, table_name):
"""
Returns a dictionary of {field_name: field_index} for the given table.
Indexes are 0-based.
"""
return {d[0]: i for i, d in enumerate(self.get_table_description(cursor, table_name))}
def get_relations(self, cursor, table_name): def get_relations(self, cursor, table_name):
""" """
Returns a dictionary of {field_index: (field_index_other_table, other_table)} Returns a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based. representing all relationships to the given table.
""" """
my_field_dict = self._name_to_index(cursor, table_name)
constraints = self.get_key_columns(cursor, table_name) constraints = self.get_key_columns(cursor, table_name)
relations = {} relations = {}
for my_fieldname, other_table, other_field in constraints: for my_fieldname, other_table, other_field in constraints:
other_field_index = self._name_to_index(cursor, other_table)[other_field] relations[my_fieldname] = (other_field, other_table)
my_field_index = my_field_dict[my_fieldname]
relations[my_field_index] = (other_field_index, other_table)
return relations return relations
def get_key_columns(self, cursor, table_name): def get_key_columns(self, cursor, table_name):

View File

@ -78,12 +78,12 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_relations(self, cursor, table_name): def get_relations(self, cursor, table_name):
""" """
Returns a dictionary of {field_index: (field_index_other_table, other_table)} Returns a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based. representing all relationships to the given table.
""" """
table_name = table_name.upper() table_name = table_name.upper()
cursor.execute(""" cursor.execute("""
SELECT ta.column_id - 1, tb.table_name, tb.column_id - 1 SELECT ta.column_name, tb.table_name, tb.column_name
FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb, FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb,
user_tab_cols ta, user_tab_cols tb user_tab_cols ta, user_tab_cols tb
WHERE user_constraints.table_name = %s AND WHERE user_constraints.table_name = %s AND

View File

@ -69,20 +69,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_relations(self, cursor, table_name): def get_relations(self, cursor, table_name):
""" """
Returns a dictionary of {field_index: (field_index_other_table, other_table)} Returns a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based. representing all relationships to the given table.
""" """
cursor.execute(""" cursor.execute("""
SELECT con.conkey, con.confkey, c2.relname SELECT c2.relname, a1.attname, a2.attname
FROM pg_constraint con, pg_class c1, pg_class c2 FROM pg_constraint con
WHERE c1.oid = con.conrelid LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
AND c2.oid = con.confrelid LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
AND c1.relname = %s LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
WHERE c1.relname = %s
AND con.contype = 'f'""", [table_name]) AND con.contype = 'f'""", [table_name])
relations = {} relations = {}
for row in cursor.fetchall(): for row in cursor.fetchall():
# row[0] and row[1] are single-item lists, so grab the single item. relations[row[1]] = (row[2], row[0])
relations[row[0][0] - 1] = (row[1][0] - 1, row[2])
return relations return relations
def get_key_columns(self, cursor, table_name): def get_key_columns(self, cursor, table_name):

View File

@ -106,23 +106,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Walk through and look for references to other tables. SQLite doesn't # Walk through and look for references to other tables. SQLite doesn't
# really have enforced references, but since it echoes out the SQL used # really have enforced references, but since it echoes out the SQL used
# to create the table we can look for REFERENCES statements used there. # to create the table we can look for REFERENCES statements used there.
field_names = [] for field_desc in results.split(','):
for field_index, field_desc in enumerate(results.split(',')):
field_desc = field_desc.strip() field_desc = field_desc.strip()
if field_desc.startswith("UNIQUE"): if field_desc.startswith("UNIQUE"):
continue continue
field_names.append(field_desc.split()[0].strip('"'))
m = re.search('references (\S*) ?\(["|]?(.*)["|]?\)', field_desc, re.I) m = re.search('references (\S*) ?\(["|]?(.*)["|]?\)', field_desc, re.I)
if not m: if not m:
continue continue
table, column = [s.strip('"') for s in m.groups()] table, column = [s.strip('"') for s in m.groups()]
if field_desc.startswith("FOREIGN KEY"): if field_desc.startswith("FOREIGN KEY"):
# Find index of the target FK field # Find name of the target FK field
m = re.match('FOREIGN KEY\(([^\)]*)\).*', field_desc, re.I) m = re.match('FOREIGN KEY\(([^\)]*)\).*', field_desc, re.I)
fkey_field = m.groups()[0].strip('"') field_name = m.groups()[0].strip('"')
field_index = field_names.index(fkey_field) else:
field_name = field_desc.split()[0].strip('"')
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table]) cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table])
result = cursor.fetchall()[0] result = cursor.fetchall()[0]
@ -130,14 +129,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
li, ri = other_table_results.index('('), other_table_results.rindex(')') li, ri = other_table_results.index('('), other_table_results.rindex(')')
other_table_results = other_table_results[li + 1:ri] other_table_results = other_table_results[li + 1:ri]
for other_index, other_desc in enumerate(other_table_results.split(',')): for other_desc in other_table_results.split(','):
other_desc = other_desc.strip() other_desc = other_desc.strip()
if other_desc.startswith('UNIQUE'): if other_desc.startswith('UNIQUE'):
continue continue
name = other_desc.split(' ', 1)[0].strip('"') other_name = other_desc.split(' ', 1)[0].strip('"')
if name == column: if other_name == column:
relations[field_index] = (other_index, table) relations[field_name] = (other_name, table)
break break
return relations return relations

View File

@ -24,6 +24,7 @@ class Reporter(models.Model):
class Article(models.Model): class Article(models.Model):
headline = models.CharField(max_length=100) headline = models.CharField(max_length=100)
pub_date = models.DateField() pub_date = models.DateField()
body = models.TextField(default='')
reporter = models.ForeignKey(Reporter) reporter = models.ForeignKey(Reporter)
response_to = models.ForeignKey('self', null=True) response_to = models.ForeignKey('self', null=True)

View File

@ -117,19 +117,32 @@ class IntrospectionTests(TransactionTestCase):
with connection.cursor() as cursor: with connection.cursor() as cursor:
relations = connection.introspection.get_relations(cursor, Article._meta.db_table) relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
# That's {field_index: (field_index_other_table, other_table)} # That's {field_name: (field_name_other_table, other_table)}
self.assertEqual(relations, {3: (0, Reporter._meta.db_table), expected_relations = {
4: (0, Article._meta.db_table)}) 'reporter_id': ('id', Reporter._meta.db_table),
'response_to_id': ('id', Article._meta.db_table),
}
self.assertEqual(relations, expected_relations)
# Removing a field shouldn't disturb get_relations (#17785)
body = Article._meta.get_field('body')
with connection.schema_editor() as editor:
editor.remove_field(Article, body)
with connection.cursor() as cursor:
relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
with connection.schema_editor() as editor:
editor.add_field(Article, body)
self.assertEqual(relations, expected_relations)
@skipUnless(connection.vendor == 'sqlite', "This is an sqlite-specific issue") @skipUnless(connection.vendor == 'sqlite', "This is an sqlite-specific issue")
def test_get_relations_alt_format(self): def test_get_relations_alt_format(self):
"""With SQLite, foreign keys can be added with different syntaxes.""" """With SQLite, foreign keys can be added with different syntaxes."""
with connection.cursor() as cursor: with connection.cursor() as cursor:
cursor.fetchone = mock.Mock(return_value=[ cursor.fetchone = mock.Mock(return_value=[
"CREATE TABLE track(id, art INTEGER, FOREIGN KEY(art) REFERENCES %s(id));" % Article._meta.db_table "CREATE TABLE track(id, art_id INTEGER, FOREIGN KEY(art_id) REFERENCES %s(id));" % Article._meta.db_table
]) ])
relations = connection.introspection.get_relations(cursor, 'mocked_table') relations = connection.introspection.get_relations(cursor, 'mocked_table')
self.assertEqual(relations, {1: (0, Article._meta.db_table)}) self.assertEqual(relations, {'art_id': ('id', Article._meta.db_table)})
@skipUnlessDBFeature('can_introspect_foreign_keys') @skipUnlessDBFeature('can_introspect_foreign_keys')
def test_get_key_columns(self): def test_get_key_columns(self):