aboutsummaryrefslogtreecommitdiff
path: root/db.py
blob: 94df3765593035dfe4acb00889254c18d0112c93 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#    Copyright (C) 2017  Alban Gruin
#
#    celcatsanitizer is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    celcatsanitizer is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with celcatsanitizer; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

from django.db import connections
from django.db.models import Manager
from django.db.models.functions import Extract
from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.query import Query
from django.db.models.sql.where import WhereNode
from django.db.models.query import QuerySet

class ExtractWeek(Extract):
    """``Extract`` transform identified by the lookup name ``"week"``.

    Nothing is overridden besides ``lookup_name``; all SQL generation is
    inherited from ``Extract``.

    NOTE(review): the exact week numbering (ISO or otherwise) is decided by
    the database backend's date-extract SQL, not here. Confirm it against
    the backend in use.
    """

    lookup_name = "week"


class GroupedCompiler(SQLCompiler):
    """SQL compiler that also emits the explicit grouping of a GroupedQuery.

    The GROUP BY terms computed by Django are kept. Every entry of
    ``self.query.real_group_by`` is compiled and appended after them.
    """

    def get_group_by(self, select, order_by):
        """Return Django's GROUP BY terms followed by the explicit ones.

        :param select: selected columns, forwarded to Django's implementation.
        :param order_by: ORDER BY terms, forwarded to Django's implementation.
        :return: a list of ``(sql, params)`` pairs, without duplicates among
            the explicitly added terms.
        """
        result = super(GroupedCompiler, self).get_group_by(select, order_by)

        # Track the terms already present, so that a column grouped both by
        # Django and explicitly by the caller appears only once. This uses
        # the same (sql, tuple(params)) key as Django's own deduplication.
        seen = {(sql, tuple(params)) for sql, params in result}

        for expr in self.query.real_group_by:
            # Entries are either compilable expressions or field names. Field
            # names must first be resolved against the query.
            if hasattr(expr, "as_sql"):
                ref = expr
            else:
                ref = self.query.resolve_ref(expr)

            sql, params = self.compile(ref)
            key = (sql, tuple(params))
            if key not in seen:
                seen.add(key)
                result.append((sql, params))

        return result


class GroupedQuery(Query):
    """Query carrying an explicit, caller-chosen list of GROUP BY terms.

    The terms (field names or expressions) are kept in ``real_group_by``.
    Queries of this class are always compiled with GroupedCompiler.
    """

    def __init__(self, model, where=WhereNode):
        super(GroupedQuery, self).__init__(model, where)
        # Explicit grouping terms, in the order they were added.
        self.real_group_by = []

    def clone(self, klass=None, memo=None, **kwargs):
        """Clone the query, giving the copy its own grouping list."""
        duplicate = super(GroupedQuery, self).clone(klass, memo, **kwargs)
        # A fresh list, so that regrouping the clone leaves self untouched.
        duplicate.real_group_by = list(self.real_group_by)
        return duplicate

    def add_grouping(self, *grouping):
        """Append the given field names or expressions to the grouping."""
        self.real_group_by += grouping

    def clear_grouping(self):
        """Forget every explicit grouping term."""
        self.real_group_by = []

    def get_compiler(self, using=None, connection=None):
        """Return a GroupedCompiler for this query.

        :param using: database alias. When truthy, its connection replaces
            ``connection``.
        :param connection: connection used when ``using`` is falsy.
        :raises ValueError: if both ``using`` and ``connection`` are None.

        NOTE(review): this always instantiates GroupedCompiler directly, so
        any backend-specific compiler class is bypassed. Confirm this is
        acceptable for the backends in use.
        """
        if using is None and connection is None:
            raise ValueError("Need either using or connection")
        if using:
            connection = connections[using]
        return GroupedCompiler(self, connection, using)


class GroupedQuerySet(QuerySet):
    """QuerySet backed by a GroupedQuery, adding an explicit ``group_by()``."""

    def __init__(self, model=None, query=None, using=None, hints=None):
        # Give QuerySet the GroupedQuery up front. Otherwise it builds a plain
        # Query, which would be discarded immediately afterwards.
        query = query or GroupedQuery(model)
        super(GroupedQuerySet, self).__init__(model, query, using, hints)

    def group_by(self, *field_names):
        """Return a copy of this queryset grouped by the given terms.

        Any grouping set by an earlier ``group_by()`` call is replaced, not
        extended.

        :param field_names: field names (resolved against the query) or
            expressions to add to the GROUP BY clause.
        :return: a new queryset; ``self`` and its query are left untouched.
        """
        obj = self._clone()
        obj.query.clear_grouping()
        obj.query.add_grouping(*field_names)
        return obj


class GroupedManager(Manager.from_queryset(GroupedQuerySet)):
    """Manager whose querysets are GroupedQuerySets.

    ``Manager.from_queryset`` sets ``_queryset_class`` to GroupedQuerySet. It
    also copies the queryset's public methods onto the manager, so
    ``group_by()`` can be called directly (``Model.objects.group_by(...)``)
    as well as on any queryset (``Model.objects.all().group_by(...)``).
    """