From 9415c9c20cdc7ee702a780f4b1cb05b1e4835993 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=AD=A6=E5=B3=B0?= Date: Fri, 22 Jul 2022 14:41:00 +0800 Subject: [PATCH] feature: filter by field --- models/__init__.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/models/__init__.py b/models/__init__.py index 29d5caa..8ed2b4b 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -28,6 +28,11 @@ class BaseModel(Base): res[i.name] = datetime_toString(v) return res + def fields_to_dict(self, fields): + for field in fields: + v = getattr(self, field) + print(v) + def to_obj(self, data): return self.autoset(data) @@ -73,6 +78,18 @@ class BaseModel(Base): result = (await session.execute(select(cls).where(*condition))).scalars() return [i for i in result] + @classmethod + async def query_obj_all_by_fields(cls, fields=None, desc=None, *condition): + data_orm = select(cls) + if fields: + data_orm = select(fields) + data_orm = data_orm.where(*condition) + if desc: + data_orm = data_orm.order_by(desc.desc()) + async with db.conn() as session: + result = await session.execute(data_orm) + return [dict(zip(i.keys(), i)) for i in result] + @classmethod async def query_dict_all(cls, *condition): async with db.conn() as session: @@ -123,7 +140,7 @@ class CommonModel(BaseModel): return res @classmethod - async def query_page(cls, page_num, page_size, search=None, match=None, desc=False): + async def query_page(cls, page_num, page_size, search=None, match=None, desc=False, fields=None, *args): if match is None: match = dict() if search is None: @@ -131,6 +148,8 @@ class CommonModel(BaseModel): condition_list = [] total_orm = select(func.count(cls.id)) data_orm = select(cls) + if fields: + data_orm = select(fields) search_set = set(search) match_set = set(match) table_set = set([i.name for i in cls.__table__.columns]) @@ -156,7 +175,6 @@ class CommonModel(BaseModel): condition_list.append(getattr(cls, i).in_(v.split(','))) else: condition_list.append(getattr(cls, i) == v) - if match: query_list = match_set & table_set for i in query_list: @@ -164,6 +182,8 @@ class CommonModel(BaseModel): if v: condition_list.append(getattr(cls, i).contains(v)) + if len(args) > 0: + condition_list.extend(args) if len(condition_list) > 0: total_orm = total_orm.where(*condition_list) data_orm = data_orm.where(*condition_list) @@ -173,8 +193,12 @@ class CommonModel(BaseModel): async with db.conn() as session: total = (await session.execute(total_orm)).scalar() - res = (await session.execute(data_orm)).scalars() - data = [i.to_dict() for i in res] + if fields: + res = await session.execute(data_orm) + data = [dict(zip(i.keys(), i)) for i in res] + else: + res = (await session.execute(data_orm)).scalars() + data = [i.to_dict() for i in res] if total % page_size == 0: total_page = total // page_size -- Gitee