django/django/db/backends/sqlite3/_functions.py
Adam Johnson deec15a9a6 Refs #33355 -- Made trunc functions raise ValueError on invalid lookups on SQLite.
Co-Authored-By: Nick Pope <nick@nickpope.me.uk>
2021-12-23 11:47:13 +01:00

307 lines
12 KiB
Python

"""
Implementations of SQL functions for SQLite.
"""
import functools
import operator
import random
import statistics
from datetime import timedelta
from hashlib import sha1, sha224, sha256, sha384, sha512
from math import (
acos, asin, atan, atan2, ceil, cos, degrees, exp, floor, fmod, log, pi,
radians, sin, sqrt, tan,
)
from re import search as re_search
from django.db.backends.base.base import timezone_constructor
from django.db.backends.utils import (
split_tzname_delta, typecast_time, typecast_timestamp,
)
from django.utils import timezone
from django.utils.crypto import md5
from django.utils.duration import duration_microseconds
def none_guard(func):
    """
    Wrap ``func`` so that it returns None as soon as any positional argument
    is None, mirroring SQL's NULL-in/NULL-out convention for the custom
    functions registered in this module.

    Only positional arguments are inspected; keyword arguments are passed
    through unchecked.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if None in args:
            return None
        return func(*args, **kwargs)
    return wrapper
def list_aggregate(function):
    """
    Build a SQLite aggregate class backed by a list.

    Each ``step(value)`` call appends ``value`` to the instance (which is a
    list). ``function`` is installed as the ``finalize`` method, so a plain
    Python function such as ``statistics.pstdev`` is called with the
    accumulated list.
    """
    namespace = {
        'finalize': function,
        'step': list.append,
    }
    return type('ListAggregate', (list,), namespace)
def register(connection):
    """
    Register Django's custom SQL functions and aggregates on a SQLite
    ``connection``.

    Scalar functions are registered with ``deterministic=True``, except
    RAND, which returns a new random value on every call. Aggregates collect
    their inputs in a list and reduce them with a ``statistics`` function.
    """
    def hexdigest_function(hash_constructor):
        # NULL-propagating function returning the hex digest of a text value.
        return none_guard(
            lambda value: hash_constructor(value.encode()).hexdigest()
        )

    deterministic_functions = (
        # Helpers emitted by Django's SQLite backend SQL.
        ('django_date_extract', 2, _sqlite_datetime_extract),
        ('django_date_trunc', 4, _sqlite_date_trunc),
        ('django_datetime_cast_date', 3, _sqlite_datetime_cast_date),
        ('django_datetime_cast_time', 3, _sqlite_datetime_cast_time),
        ('django_datetime_extract', 4, _sqlite_datetime_extract),
        ('django_datetime_trunc', 4, _sqlite_datetime_trunc),
        ('django_time_extract', 2, _sqlite_time_extract),
        ('django_time_trunc', 4, _sqlite_time_trunc),
        ('django_time_diff', 2, _sqlite_time_diff),
        ('django_timestamp_diff', 2, _sqlite_timestamp_diff),
        ('django_format_dtdelta', 3, _sqlite_format_dtdelta),
        ('regexp', 2, _sqlite_regexp),
        # Math, string and hash functions (NULL in -> NULL out).
        ('ACOS', 1, none_guard(acos)),
        ('ASIN', 1, none_guard(asin)),
        ('ATAN', 1, none_guard(atan)),
        ('ATAN2', 2, none_guard(atan2)),
        ('BITXOR', 2, none_guard(operator.xor)),
        ('CEILING', 1, none_guard(ceil)),
        ('COS', 1, none_guard(cos)),
        ('COT', 1, none_guard(lambda x: 1 / tan(x))),
        ('DEGREES', 1, none_guard(degrees)),
        ('EXP', 1, none_guard(exp)),
        ('FLOOR', 1, none_guard(floor)),
        ('LN', 1, none_guard(log)),
        # SQL LOG(base, x) -> Python log(x, base).
        ('LOG', 2, none_guard(lambda x, y: log(y, x))),
        ('LPAD', 3, _sqlite_lpad),
        ('MD5', 1, hexdigest_function(md5)),
        ('MOD', 2, none_guard(fmod)),
        ('PI', 0, lambda: pi),
        ('POWER', 2, none_guard(operator.pow)),
        ('RADIANS', 1, none_guard(radians)),
        ('REPEAT', 2, none_guard(operator.mul)),
        ('REVERSE', 1, none_guard(lambda x: x[::-1])),
        ('RPAD', 3, _sqlite_rpad),
        ('SHA1', 1, hexdigest_function(sha1)),
        ('SHA224', 1, hexdigest_function(sha224)),
        ('SHA256', 1, hexdigest_function(sha256)),
        ('SHA384', 1, hexdigest_function(sha384)),
        ('SHA512', 1, hexdigest_function(sha512)),
        ('SIGN', 1, none_guard(lambda x: (x > 0) - (x < 0))),
        ('SIN', 1, none_guard(sin)),
        ('SQRT', 1, none_guard(sqrt)),
        ('TAN', 1, none_guard(tan)),
    )
    for name, num_params, implementation in deterministic_functions:
        connection.create_function(
            name, num_params, implementation, deterministic=True,
        )

    # Don't use the built-in RANDOM() function because it returns a value
    # in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
    connection.create_function('RAND', 0, random.random)

    aggregates = (
        ('STDDEV_POP', statistics.pstdev),
        ('STDDEV_SAMP', statistics.stdev),
        ('VAR_POP', statistics.pvariance),
        ('VAR_SAMP', statistics.variance),
    )
    for name, reducer in aggregates:
        connection.create_aggregate(name, 1, list_aggregate(reducer))
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
    """
    Parse a datetime value stored in SQLite and convert it to a time zone.

    Return None when ``dt`` is None or cannot be parsed.

    ``conn_tzname``, when set, is attached as the tzinfo of the parsed value
    (typecast_timestamp() yields values without a time zone). ``tzname`` is
    the target zone name, optionally carrying a "+HH:MM"/"-HH:MM" offset.
    The offset is applied to the value before it is converted with
    timezone.localtime().
    """
    if dt is None:
        return None
    try:
        dt = typecast_timestamp(dt)
    except (TypeError, ValueError):
        return None
    if dt is None:
        # NOTE(review): typecast_timestamp() can return None (believed to
        # happen for empty input). Bail out instead of calling .replace()
        # on None below.
        return None
    if conn_tzname:
        dt = dt.replace(tzinfo=timezone_constructor(conn_tzname))
    if tzname is not None and tzname != conn_tzname:
        tzname, sign, offset = split_tzname_delta(tzname)
        if offset:
            hours, minutes = offset.split(':')
            offset_delta = timedelta(hours=int(hours), minutes=int(minutes))
            dt += offset_delta if sign == '+' else -offset_delta
        # A tzname that was only an offset (e.g. "+03:00") leaves an empty
        # zone name after splitting. Fall back to the connection's zone
        # rather than passing '' to timezone_constructor().
        dt = timezone.localtime(dt, timezone_constructor(tzname or conn_tzname))
    return dt
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
    """
    Truncate a datetime to the start of its ``lookup_type`` period.

    The result is returned as a "YYYY-MM-DD" string. Supported lookups are
    'year', 'quarter', 'month', 'week' (weeks start on Monday) and 'day'.
    Return None when the value is NULL or unparsable. Raise ValueError for
    any other lookup type.
    """
    value = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if value is None:
        return None
    if lookup_type == 'week':
        # Step back to the Monday of the same week.
        value = value - timedelta(days=value.weekday())
    year, month, day = value.year, value.month, value.day
    if lookup_type == 'year':
        month = day = 1
    elif lookup_type == 'quarter':
        # First month of the quarter: 1, 4, 7 or 10.
        month, day = month - (month - 1) % 3, 1
    elif lookup_type == 'month':
        day = 1
    elif lookup_type not in ('week', 'day'):
        raise ValueError(f'Unsupported lookup type: {lookup_type!r}')
    return f'{year:04d}-{month:02d}-{day:02d}'
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
if dt is None:
return None
dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt_parsed is None:
try:
dt = typecast_time(dt)
except (ValueError, TypeError):
return None
else:
dt = dt_parsed
if lookup_type == 'hour':
return f'{dt.hour:02d}:00:00'
elif lookup_type == 'minute':
return f'{dt.hour:02d}:{dt.minute:02d}:00'
elif lookup_type == 'second':
return f'{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}'
raise ValueError(f'Unsupported lookup type: {lookup_type!r}')
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
    """
    Return the date part of ``dt`` as an ISO 8601 "YYYY-MM-DD" string.

    The time zone conversion of _sqlite_datetime_parse() is applied first.
    Return None when the value is NULL or unparsable.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    return None if parsed is None else parsed.date().isoformat()
def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
    """
    Return the time part of ``dt`` as an ISO 8601 "HH:MM:SS[.ffffff]" string.

    The time zone conversion of _sqlite_datetime_parse() is applied first.
    Return None when the value is NULL or unparsable.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    return None if parsed is None else parsed.time().isoformat()
def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
    """
    Extract the ``lookup_type`` component of a datetime value.

    Computed lookups:
    - 'week_day': 1 (Sunday) .. 7 (Saturday)
    - 'iso_week_day': 1 (Monday) .. 7 (Sunday)
    - 'week': ISO week number
    - 'quarter': 1 .. 4
    - 'iso_year': ISO calendar year

    Any other lookup is read as an attribute of the parsed value (e.g.
    'year', 'hour'). Return None when the value is NULL or unparsable.
    """
    value = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if value is None:
        return None
    if lookup_type == 'week_day':
        # Shift ISO numbering (Mon=1..Sun=7) so that Sunday becomes 1.
        return (value.isoweekday() % 7) + 1
    if lookup_type == 'iso_week_day':
        return value.isoweekday()
    if lookup_type == 'week':
        return value.isocalendar()[1]
    if lookup_type == 'quarter':
        return ceil(value.month / 3)
    if lookup_type == 'iso_year':
        return value.isocalendar()[0]
    return getattr(value, lookup_type)
def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
    """
    Truncate a datetime to the start of its ``lookup_type`` period.

    The result is returned as a "YYYY-MM-DD HH:MM:SS" string. Supported
    lookups are 'year', 'quarter', 'month', 'week' (weeks start on Monday),
    'day', 'hour', 'minute' and 'second'. Return None when the value is NULL
    or unparsable. Raise ValueError for any other lookup type.
    """
    value = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if value is None:
        return None
    if lookup_type == 'week':
        # Step back to the Monday of the same week.
        value = value - timedelta(days=value.weekday())
    year, month, day = value.year, value.month, value.day
    hour = minute = second = 0
    if lookup_type == 'year':
        month = day = 1
    elif lookup_type == 'quarter':
        # First month of the quarter: 1, 4, 7 or 10.
        month, day = month - (month - 1) % 3, 1
    elif lookup_type == 'month':
        day = 1
    elif lookup_type == 'second':
        hour, minute, second = value.hour, value.minute, value.second
    elif lookup_type == 'minute':
        hour, minute = value.hour, value.minute
    elif lookup_type == 'hour':
        hour = value.hour
    elif lookup_type not in ('week', 'day'):
        raise ValueError(f'Unsupported lookup type: {lookup_type!r}')
    return f'{year:04d}-{month:02d}-{day:02d} {hour:02d}:{minute:02d}:{second:02d}'
def _sqlite_time_extract(lookup_type, dt):
if dt is None:
return None
try:
dt = typecast_time(dt)
except (ValueError, TypeError):
return None
return getattr(dt, lookup_type)
def _sqlite_prepare_dtdelta_param(conn, param):
if conn in ['+', '-']:
if isinstance(param, int):
return timedelta(0, 0, param)
else:
return typecast_timestamp(param)
return param
@none_guard
def _sqlite_format_dtdelta(connector, lhs, rhs):
    """
    Apply the arithmetic ``connector`` ('+', '-', '*', '/') to LHS and RHS.

    LHS and RHS can be either:
    - An integer number of microseconds
    - A string representing a datetime
    - A scalar value, e.g. float

    Results of '+' and '-' are returned as strings. Results of '*' and '/'
    are returned as-is. Return None if an operand cannot be converted. Any
    connector other than '+', '-' or '*' is treated as division.
    """
    op = connector.strip()
    try:
        left = _sqlite_prepare_dtdelta_param(op, lhs)
        right = _sqlite_prepare_dtdelta_param(op, rhs)
    except (ValueError, TypeError):
        return None
    if op == '+':
        # typecast_timestamp() returns a date or a datetime without timezone.
        # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
        return str(left + right)
    if op == '-':
        return str(left - right)
    if op == '*':
        return left * right
    return left / right
@none_guard
def _sqlite_time_diff(lhs, rhs):
    """
    Return ``lhs - rhs`` in microseconds, where both arguments are time
    values parsed with typecast_time().
    """
    def to_microseconds(moment):
        # Microseconds elapsed since midnight for a time-of-day value.
        seconds = (moment.hour * 60 + moment.minute) * 60 + moment.second
        return seconds * 1000000 + moment.microsecond

    left, right = typecast_time(lhs), typecast_time(rhs)
    return to_microseconds(left) - to_microseconds(right)
@none_guard
def _sqlite_timestamp_diff(lhs, rhs):
    """
    Return ``lhs - rhs`` in microseconds, where both arguments are
    timestamps parsed with typecast_timestamp().
    """
    delta = typecast_timestamp(lhs) - typecast_timestamp(rhs)
    return duration_microseconds(delta)
@none_guard
def _sqlite_regexp(pattern, string):
    """
    Return True if ``pattern`` matches anywhere in ``string``.

    Non-string values are converted with str() before matching.
    """
    subject = string if isinstance(string, str) else str(string)
    return re_search(pattern, subject) is not None
@none_guard
def _sqlite_lpad(text, length, fill_text):
    """
    Left-pad ``text`` with repetitions of ``fill_text`` to ``length``
    characters.

    If ``text`` is already at least ``length`` long, it is truncated to
    ``length`` instead.
    """
    shortfall = length - len(text)
    if shortfall > 0:
        padding = (fill_text * length)[:shortfall]
        return padding + text
    return text[:length]
@none_guard
def _sqlite_rpad(text, length, fill_text):
    """
    Right-pad ``text`` with repetitions of ``fill_text`` to ``length``
    characters.

    If ``text`` is already at least ``length`` long, it is truncated to
    ``length`` instead. NULL handling is left to ``none_guard``: any None
    argument makes the call return None before this body runs, so no
    explicit None check is needed here.
    """
    return (text + fill_text * length)[:length]