aboutsummaryrefslogtreecommitdiff
path: root/db.py
diff options
context:
space:
mode:
Diffstat (limited to 'db.py')
-rw-r--r--db.py34
1 files changed, 28 insertions, 6 deletions
diff --git a/db.py b/db.py
index df9ceda..e75adca 100644
--- a/db.py
+++ b/db.py
@@ -1,18 +1,40 @@
+from django.db import connections
from django.db.models import Manager
from django.db.models.query import QuerySet
+from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.query import Query
+from django.db.models.sql.where import WhereNode
+
+
+class GroupedCompiler(SQLCompiler):
+ def get_group_by(self, select, order_by):
+ result = super(GroupedCompiler, self).get_group_by(select, order_by)
+ expressions = []
+ for expr in self.query.real_group_by:
+ ref = expr if hasattr(expr, "as_sql") else self.query.resolve_ref(expr)
+ sql, params = self.compile(ref)
+ result.append((sql, params))
+
+ return result
class GroupedQuery(Query):
- def add_grouping(self, *grouping):
- if self.group_by is None:
- self.group_by = []
+ def __init__(self, model, where=WhereNode):
+ super(GroupedQuery, self).__init__(model, where)
+ self.real_group_by = []
- if isinstance(self.group_by, list):
- self.group_by.extend(grouping)
+ def add_grouping(self, *grouping):
+ self.real_group_by.extend(grouping)
def clear_grouping(self):
- self.group_by = None
+ self.real_group_by = []
+
+ def get_compiler(self, using=None, connection=None):
+ if using is None and connection is None:
+ raise ValueError("Need either using or connection")
+ if using:
+ connection = connections[using]
+ return GroupedCompiler(self, connection, using)
class GroupedQuerySet(QuerySet):