aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--db.py34
-rw-r--r--models.py5
2 files changed, 32 insertions, 7 deletions
diff --git a/db.py b/db.py
index df9ceda..e75adca 100644
--- a/db.py
+++ b/db.py
@@ -1,18 +1,40 @@
+from django.db import connections
from django.db.models import Manager
from django.db.models.query import QuerySet
+from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.query import Query
+from django.db.models.sql.where import WhereNode
+
+
+class GroupedCompiler(SQLCompiler):
+ def get_group_by(self, select, order_by):
+ result = super(GroupedCompiler, self).get_group_by(select, order_by)
+ expressions = []
+ for expr in self.query.real_group_by:
+ ref = expr if hasattr(expr, "as_sql") else self.query.resolve_ref(expr)
+ sql, params = self.compile(ref)
+ result.append((sql, params))
+
+ return result
class GroupedQuery(Query):
- def add_grouping(self, *grouping):
- if self.group_by is None:
- self.group_by = []
+ def __init__(self, model, where=WhereNode):
+ super(GroupedQuery, self).__init__(model, where)
+ self.real_group_by = []
- if isinstance(self.group_by, list):
- self.group_by.extend(grouping)
+ def add_grouping(self, *grouping):
+ self.real_group_by.extend(grouping)
def clear_grouping(self):
- self.group_by = None
+ self.real_group_by = []
+
+ def get_compiler(self, using=None, connection=None):
+ if using is None and connection is None:
+ raise ValueError("Need either using or connection")
+ if using:
+ connection = connections[using]
+ return GroupedCompiler(self, connection, using)
class GroupedQuerySet(QuerySet):
diff --git a/models.py b/models.py
index 64fb675..26bb63d 100644
--- a/models.py
+++ b/models.py
@@ -4,6 +4,8 @@ from django.db.models.expressions import RawSQL
from django.db.models.functions import Extract, ExtractYear
from django.utils.text import slugify
+from .db import GroupedManager
+
import hashlib
import os
@@ -111,7 +113,7 @@ class Room(models.Model):
verbose_name_plural = "salles"
-class CourseManager(models.Manager):
+class CourseManager(GroupedManager):
def __get_weeks(self, qs):
extractYear = ExtractYear("begin")
@@ -130,6 +132,7 @@ class CourseManager(models.Manager):
qs = self.get_courses_for_group(group, **filters)
return self.__get_weeks(qs)
+
class Course(models.Model):
objects = CourseManager()