aboutsummaryrefslogtreecommitdiff
path: root/db.py
diff options
context:
space:
mode:
Diffstat (limited to 'db.py')
-rw-r--r--db.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/db.py b/db.py
new file mode 100644
index 0000000..2d4d320
--- /dev/null
+++ b/db.py
@@ -0,0 +1,60 @@
+from django.db import connections
+from django.db.models import Manager
+from django.db.models.query import QuerySet
+from django.db.models.sql.compiler import SQLCompiler
+from django.db.models.sql.query import Query
+from django.db.models.sql.where import WhereNode
+
+
+class GroupedCompiler(SQLCompiler):
+ def get_group_by(self, select, order_by):
+ result = super(GroupedCompiler, self).get_group_by(select, order_by)
+ expressions = []
+ for expr in self.query.real_group_by:
+ ref = expr if hasattr(expr, "as_sql") else self.query.resolve_ref(expr)
+ sql, params = self.compile(ref)
+ result.append((sql, params))
+
+ return result
+
+
+class GroupedQuery(Query):
+ def __init__(self, model, where=WhereNode):
+ super(GroupedQuery, self).__init__(model, where)
+ self.real_group_by = []
+
+ def clone(self, klass=None, memo=None, **kwargs):
+ obj = super(GroupedQuery, self).clone(klass, memo, **kwargs)
+ obj.real_group_by = self.real_group_by[:]
+ return obj
+
+ def add_grouping(self, *grouping):
+ self.real_group_by.extend(grouping)
+
+ def clear_grouping(self):
+ self.real_group_by = []
+
+ def get_compiler(self, using=None, connection=None):
+ if using is None and connection is None:
+ raise ValueError("Need either using or connection")
+ if using:
+ connection = connections[using]
+ return GroupedCompiler(self, connection, using)
+
+
+class GroupedQuerySet(QuerySet):
+ def __init__(self, model=None, query=None, using=None, hints=None):
+ super(GroupedQuerySet, self).__init__(model, query, using, hints)
+ self.query = query or GroupedQuery(self.model)
+
+ def group_by(self, *field_names):
+ obj = self._clone()
+ obj.query.clear_grouping()
+ obj.query.add_grouping(*field_names)
+ return obj
+
+
+class GroupedManager(Manager):
+ def __init__(self):
+ super(GroupedManager, self).__init__()
+ self._queryset_class = GroupedQuerySet