diff options
author | Alban Gruin | 2017-01-28 13:27:32 +0100 |
---|---|---|
committer | Alban Gruin | 2017-01-28 13:27:32 +0100 |
commit | d5c846ac3214efcbd2b3cb8d562b58006c3a5eb5 (patch) | |
tree | 4e4721c0cdec315471379692f2f540b7a6c24851 /db.py | |
parent | cfd2969cd9fa18e6148e97c1066341b1c1add6cf (diff) | |
parent | ebb7c3bf0dc3eef2efa3f4add60ccfc7dc063248 (diff) |
Merge branch 'dev/db-groupby' into dev/dbv0.2.0
Diffstat (limited to 'db.py')
-rw-r--r-- | db.py | 60 |
1 files changed, 60 insertions, 0 deletions
@@ -0,0 +1,60 @@ +from django.db import connections +from django.db.models import Manager +from django.db.models.query import QuerySet +from django.db.models.sql.compiler import SQLCompiler +from django.db.models.sql.query import Query +from django.db.models.sql.where import WhereNode + + +class GroupedCompiler(SQLCompiler): + def get_group_by(self, select, order_by): + result = super(GroupedCompiler, self).get_group_by(select, order_by) + expressions = [] + for expr in self.query.real_group_by: + ref = expr if hasattr(expr, "as_sql") else self.query.resolve_ref(expr) + sql, params = self.compile(ref) + result.append((sql, params)) + + return result + + +class GroupedQuery(Query): + def __init__(self, model, where=WhereNode): + super(GroupedQuery, self).__init__(model, where) + self.real_group_by = [] + + def clone(self, klass=None, memo=None, **kwargs): + obj = super(GroupedQuery, self).clone(klass, memo, **kwargs) + obj.real_group_by = self.real_group_by[:] + return obj + + def add_grouping(self, *grouping): + self.real_group_by.extend(grouping) + + def clear_grouping(self): + self.real_group_by = [] + + def get_compiler(self, using=None, connection=None): + if using is None and connection is None: + raise ValueError("Need either using or connection") + if using: + connection = connections[using] + return GroupedCompiler(self, connection, using) + + +class GroupedQuerySet(QuerySet): + def __init__(self, model=None, query=None, using=None, hints=None): + super(GroupedQuerySet, self).__init__(model, query, using, hints) + self.query = query or GroupedQuery(self.model) + + def group_by(self, *field_names): + obj = self._clone() + obj.query.clear_grouping() + obj.query.add_grouping(*field_names) + return obj + + +class GroupedManager(Manager): + def __init__(self): + super(GroupedManager, self).__init__() + self._queryset_class = GroupedQuerySet |