# Source code for odps.df.expr.groupby

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
import random

from ...models import TableSchema
from .expressions import Expr, CollectionExpr, BooleanSequenceExpr, \
    Column, SequenceExpr, Scalar, BooleanScalar, repr_obj
from .collections import SortedExpr, ReshuffledCollectionExpr
from .errors import ExpressionError
from . import utils
from ...compat import reduce, six
from .. import types
from ..utils import is_constant_scalar
from ...utils import object_getattr, camel_to_underline

class BaseGroupBy(Expr):
    """Common base of grouped collection expressions.

    ``_input`` is the collection being grouped, ``_by`` the list of grouping
    key expressions and ``_to_agg`` the schema of the fields that may be
    aggregated (the whole input schema unless narrowed by ``__getitem__``).
    """

    __slots__ = '_to_agg', '_by_names'
    _args = '_input', '_by'

    def _init(self, *args, **kwargs):
        """Resolve the grouping keys against the input and fill defaults.

        ``_by`` may be a single key or a list of keys; afterwards it is always
        a list of field expressions obtained from ``_input``.
        """
        self._init_attr('_to_agg', None)
        super(BaseGroupBy, self)._init(*args, **kwargs)
        if isinstance(self._by, list):
            self._by = self._input._get_fields(self._by)
        else:
            self._by = [self._input._get_field(self._by)]
        for idx, by_field in enumerate(self._by):
            if by_field.name is None:
                # An unnamed key (e.g. a computed expression) still needs an
                # output column name: derive one from the expression class
                # plus a random suffix.
                new_field_name = '_%s_%d' % (camel_to_underline(type(by_field).__name__),
                                             random.randint(10000, 99999))
                self._by[idx] = by_field.rename(new_field_name)
        if self._to_agg is None:
            self._to_agg = self._input.schema

    def __getitem__(self, item):
        """Select the field(s) to aggregate.

        :param item: a field name, or an iterable of field names / columns
        :return: :class:`SequenceGroupBy` for a single name, otherwise a
                 :class:`GroupBy` whose aggregatable schema is restricted to
                 the selected fields (unknown names are skipped)
        :raises KeyError: if a single name is not an aggregatable field
        :raises TypeError: if an element is neither a string nor a Column
        :raises ValueError: if a selected Column has been renamed
        """
        if isinstance(item, six.string_types):
            if item in self._to_agg:
                return SequenceGroupBy(_input=self, _name=item,
                                       _data_type=self._to_agg[item].type)
            raise KeyError('Fail to get group by field, unknown field: %s' % repr_obj(item))

        def is_field(it):
            return isinstance(it, six.string_types) or isinstance(it, Column)

        def get_name(it):
            return it if isinstance(it, six.string_types) else it.source_name

        if not all(is_field(it) for it in item):
            raise TypeError('Fail to get group by fields, unknown type: %s' % type(item))
        if any(col.is_renamed() for col in item if isinstance(col, Column)):
            raise ValueError('Fail to get group by fields, column cannot be renamed')

        _to_agg = type(self._input.schema)(
            columns=self._input.schema[[get_name(field) for field in item
                                        if get_name(field) in self._to_agg]])

        return GroupBy(_input=self._input, _by=self._by,
                       _by_names=getattr(self, '_by_names', None), _to_agg=_to_agg)

    def __getattr__(self, attr):
        """Fall back to aggregatable field names for attribute access.

        ``grouped.fieldname`` is equivalent to ``grouped['fieldname']`` when
        ``fieldname`` is in ``_to_agg``; dunder names are never redirected.
        """
        try:
            return object.__getattribute__(self, attr)
        except AttributeError as e:
            if attr.startswith('__'):
                raise e

            agg = object.__getattribute__(self, '_to_agg')
            if agg is not None and attr in agg:
                return self[attr]

            raise e

    def sort_values(self, by, ascending=True):
        """Return a :class:`SortedGroupBy` ordered by ``by`` within groups.

        :param by: a sort key or a list of sort keys
        :param ascending: sort direction passed through to the sorted expr
        :raises ExpressionError: if a ``having`` filter has been applied
        """
        if hasattr(self, '_having') and self._having is not None:
            raise ExpressionError('Cannot sort GroupBy with `having`')

        if not isinstance(by, list):
            by = [by, ]
        by = [self._defunc(it) for it in by]

        attr_values = dict((attr, object_getattr(self, attr, None))
                           for attr in utils.get_attrs(self))
        attr_values['_sorted_fields'] = by
        attr_values['_ascending'] = ascending
        attr_values.pop('_having', None)

        return SortedGroupBy(**attr_values)

    def sort(self, *args, **kwargs):
        """Alias of :meth:`sort_values`."""
        return self.sort_values(*args, **kwargs)

    def mutate(self, *windows, **kw):
        """Append window-function columns computed per group.

        :param windows: window expressions, or a single list of them
        :param kw: extra window expressions, each renamed to its keyword
        :return: :class:`MutateCollectionExpr` whose columns are the Column
                 grouping keys followed by the window columns
        :raises ExpressionError: if a ``having`` filter has been applied
        :raises ValueError: if no window is given, or a window refers to a
                            field that is not aggregatable
        :raises TypeError: if an argument is not a window function
        """
        if hasattr(self, '_having') and self._having is not None:
            raise ExpressionError('Cannot mutate GroupBy with `having`')

        if len(windows) == 1 and isinstance(windows[0], list):
            windows = windows[0]
        else:
            windows = list(windows)
        windows = [self._defunc(win) for win in windows]
        if kw:
            windows.extend([self._defunc(win).rename(new_name)
                            for new_name, win in six.iteritems(kw)])

        from .window import Window

        if not windows:
            raise ValueError('Cannot mutate on grouped data')
        if not all(isinstance(win, Window) for win in windows):
            raise TypeError('Only window functions can be provided')

        # Must be a real list, not a lazy ``filter`` object: it is iterated
        # twice below, and on Python 3 the ``frozenset`` call would exhaust a
        # ``filter`` iterator, silently skipping the per-field check.
        win_field_names = [name for name in (win.source_name for win in windows)
                           if name is not None]
        if not frozenset(win_field_names).issubset(self._to_agg.names):
            for agg_field_name in win_field_names:
                if agg_field_name not in self._to_agg:
                    raise ValueError('Unknown field to aggregate: %s' % repr_obj(agg_field_name))

        # Only grouping keys that are Columns become output columns.
        names = [by.name for by in self._by if isinstance(by, Column)] + \
                [win.name for win in windows]
        field_types = [by._data_type for by in self._by if isinstance(by, Column)] + \
                      [win._data_type for win in windows]

        return MutateCollectionExpr(_input=self, _window_fields=windows,
                                    _schema=TableSchema.from_lists(names, field_types))

    def apply(self, func, names=None, types=None, resources=None, args=(), **kwargs):
        """Run a row-wise ``apply`` (``axis=1``) over the input reshuffled by this grouping.

        All arguments are forwarded to the reshuffled collection's ``apply``.
        """
        reshuffled = ReshuffledCollectionExpr(_input=self, _schema=self._input._schema)
        return reshuffled.apply(axis=1, func=func, names=names, types=types,
                                resources=resources, args=args, **kwargs)

class GroupBy(BaseGroupBy):
    """Grouped collection supporting aggregation and ``having`` filters."""

    __slots__ = '_having',

    def _init(self, *args, **kwargs):
        self._init_attr('_having', None)
        super(GroupBy, self)._init(*args, **kwargs)

    def _same_by(self, other):
        """True if ``other`` groups the very same input by the very same key objects."""
        return (other._input is self._input
                and len(self._by) == len(other._by)
                and all(x is y for x, y in zip(self._by, other._by)))

    def _validate_agg(self, agg):
        """Check that ``agg`` is a valid aggregation of this GroupBy.

        :raises ExpressionError: if a grouped reduction belongs to a different
            GroupBy, a column or rank op refers to another collection, or no
            grouped reduction is present at all
        """
        from .reduction import GroupedSequenceReduction
        from .window import RankOp

        has_reduction = False
        for node in agg.traverse(top_down=True, unique=True,
                                 stop_cond=lambda x: x is self._input):
            if isinstance(node, GroupedSequenceReduction):
                has_reduction = True
                if not self._same_by(node._grouped):
                    raise ExpressionError(
                        'Aggregation has not been applied to the right GroupBy, got: %s'
                        % repr_obj(agg))
            elif isinstance(node, Column):
                if node._input is not self._input:
                    raise ExpressionError(
                        'Aggregation should be applied to the column of %s'
                        % repr_obj(self._input))
            elif isinstance(node, RankOp) and node._input is not self._input:
                raise ExpressionError(
                    'Aggregation should be applied to the column of %s'
                    % repr_obj(self._input))

        if not has_reduction:
            raise ExpressionError('No aggregation found in %s' % repr_obj(agg))

    def _transform(self, reduction_expr):
        """Rewrite a scalar expression so that its reductions run per group.

        Sequence reductions become grouped reductions of this GroupBy,
        non-constant scalar sub-expressions become sequences, and columns /
        rank ops pointing at another collection are re-pointed at ``_input``.
        Non-scalar expressions are returned unchanged.
        """
        if isinstance(reduction_expr, Scalar):
            from .reduction import SequenceReduction
            from .window import RankOp

            dag = reduction_expr.to_dag(copy=False, validate=False)
            for node in dag.traverse(
                    stop_cond=lambda x: isinstance(x, (Column, RankOp)) or x is self._input):
                if isinstance(node, SequenceReduction):
                    to_sub = node.to_grouped_reduction(self)
                    dag.substitute(node, to_sub)
                elif isinstance(node, Scalar) and not is_constant_scalar(node) \
                        and len(node.children()) > 0:
                    to_sub = node.to_sequence()
                    dag.substitute(node, to_sub)
                elif isinstance(node, Column) and node._input is not self._input:
                    field = self._input._get_field(node)
                    if field:
                        dag.substitute(node, field)
                elif isinstance(node, RankOp) and node._input is not self._input:
                    dag.substitute(node._input, self._input, parents=[node])

            return dag.root

        return reduction_expr

    def __repr__(self):
        return object.__repr__(self)

    def __getitem__(self, item):
        """Filter groups with a boolean predicate, or select fields.

        A boolean sequence/scalar is AND-ed with any existing ``having``
        filter and yields a new GroupBy; anything else is delegated to
        :meth:`BaseGroupBy.__getitem__`.
        """
        item = self._defunc(item)
        if isinstance(item, (BooleanSequenceExpr, BooleanScalar)):
            having = item if isinstance(item, BooleanSequenceExpr) \
                else self._transform(item)
            if self._having is not None:
                having = having & self._having
            return GroupBy(_input=self._input, _by=self._by,
                           _to_agg=self._to_agg, _having=having)
        return super(GroupBy, self).__getitem__(item)

    def filter(self, *predicates):
        """Keep only groups satisfying all ``predicates`` (combined with AND)."""
        predicates = [self._defunc(it) for it in predicates]
        predicate = reduce(operator.and_, predicates)
        return self[predicate]

    def aggregate(self, *aggregations, **kw):
        """Aggregate each group.

        :param aggregations: aggregation expressions, or a single list of them
        :param kw: extra aggregations, each renamed to its keyword; the
                   option ``sort_by_name`` (default True) is consumed first
                   and orders the aggregation columns by name
        :return: :class:`GroupByCollectionExpr` whose columns are the named
                 grouping keys followed by the aggregations
        :raises ValueError: if no aggregation is given
        :raises ExpressionError: if an aggregation is invalid for this GroupBy
        """
        sort_by_name = kw.pop('sort_by_name', True)

        if len(aggregations) == 1 and isinstance(aggregations[0], list):
            aggregations = aggregations[0]
        else:
            aggregations = list(aggregations)
        aggregations = [self._defunc(it) for it in aggregations]
        if kw:
            aggregations.extend([self._defunc(agg).rename(new_name)
                                 for new_name, agg in six.iteritems(kw)])

        aggregations = [self._transform(agg) for agg in aggregations]
        if sort_by_name:
            # Sort by name for a deterministic column order (test cases rely on it).
            aggregations = sorted(aggregations, key=lambda it: it.name)
        if not aggregations:
            raise ValueError('Cannot aggregate on grouped data')
        for agg in aggregations:
            self._validate_agg(agg)

        named_bys = [by for by in self._by
                     if isinstance(by, (Scalar, SequenceExpr)) and by.name is not None]
        names = [by.name for by in named_bys] + [agg.name for agg in aggregations]
        field_types = [by.dtype for by in named_bys] + \
                      [agg._data_type for agg in aggregations]

        return GroupByCollectionExpr(
            _input=self, _aggregations=aggregations,
            _schema=TableSchema.from_lists(names, field_types)
        )

    def agg(self, *args, **kwargs):
        """Alias of :meth:`aggregate`."""
        return self.aggregate(*args, **kwargs)
class SequenceGroupBy(Expr):
    """A single field of grouped data, ready to be reduced.

    Constructing ``SequenceGroupBy(_data_type=t, ...)`` returns an instance of
    the typed subclass named ``<class of t>SequenceGroupBy`` from this module.
    """

    __slots__ = '_name', '_data_type', '_source_data_type'
    _args = '_input',

    def _init(self, *args, **kwargs):
        self._init_attr('_data_type', None)
        self._init_attr('_source_data_type', None)
        super(SequenceGroupBy, self)._init(*args, **kwargs)

    def __new__(cls, *args, **kwargs):
        data_type = kwargs.get('_data_type')
        if data_type:
            # Dispatch to e.g. Int64SequenceGroupBy for an Int64 data type.
            cls_name = data_type.__class__.__name__ + SequenceGroupBy.__name__
            clazz = globals()[cls_name]
            return super(SequenceGroupBy, clazz).__new__(clazz)
        else:
            return super(SequenceGroupBy, cls).__new__(cls)

    @property
    def name(self):
        return self._name

    @property
    def dtype(self):
        return self._data_type

    @property
    def input(self):
        return self._input

    def astype(self, data_type):
        """Return a copy of this grouped field cast to ``data_type``.

        The type of the underlying column is remembered in
        ``_source_data_type`` so that :meth:`is_astyped` and :meth:`to_column`
        can apply the cast; returns ``self`` if the type is unchanged.
        """
        data_type = types.validate_data_type(data_type)

        if data_type == self._data_type:
            return self

        attr_dict = dict((k, getattr(self, k, None)) for k in utils.get_attrs(self))
        attr_dict['_data_type'] = data_type
        # Keep the *original* column type across (possibly chained) casts;
        # leaving it None would make is_astyped() always False.
        attr_dict['_source_data_type'] = self._source_data_type \
            if self._source_data_type is not None else self._data_type
        # Go through __new__ so the subclass is chosen from the data type's
        # class, the same way as for a freshly built SequenceGroupBy.
        return SequenceGroupBy(**attr_dict)

    def is_astyped(self):
        """True if this field has been cast away from its source column type."""
        if self._source_data_type is None:
            return False
        return self._data_type != self._source_data_type

    def to_column(self):
        """Return the underlying column of the grouped collection, cast if needed."""
        collection = self.input.input
        column = collection[self.name]
        if self.is_astyped():
            column = column.astype(self._data_type)
        return column
class BooleanSequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to boolean."""

    def _init(self, *a, **kw):
        super(BooleanSequenceGroupBy, self)._init(*a, **kw)
        self._data_type = types.boolean


class Int8SequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to int8."""

    def _init(self, *a, **kw):
        super(Int8SequenceGroupBy, self)._init(*a, **kw)
        self._data_type = types.int8


class Int16SequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to int16."""

    def _init(self, *a, **kw):
        super(Int16SequenceGroupBy, self)._init(*a, **kw)
        self._data_type = types.int16


class Int32SequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to int32."""

    def _init(self, *a, **kw):
        super(Int32SequenceGroupBy, self)._init(*a, **kw)
        self._data_type = types.int32
class Int64SequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to int64."""

    def _init(self, *a, **kw):
        super(Int64SequenceGroupBy, self)._init(*a, **kw)
        self._data_type = types.int64
class Float32SequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to float32."""

    def _init(self, *args, **kwargs):
        super(Float32SequenceGroupBy, self)._init(*args, **kwargs)
        self._data_type = types.float32


class Float64SequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to float64."""

    def _init(self, *args, **kwargs):
        super(Float64SequenceGroupBy, self)._init(*args, **kwargs)
        self._data_type = types.float64


class DecimalSequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to decimal."""

    def _init(self, *args, **kwargs):
        super(DecimalSequenceGroupBy, self)._init(*args, **kwargs)
        self._data_type = types.decimal


class StringSequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to string."""

    def _init(self, *args, **kwargs):
        super(StringSequenceGroupBy, self)._init(*args, **kwargs)
        self._data_type = types.string


class BinarySequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to binary."""

    def _init(self, *args, **kwargs):
        super(BinarySequenceGroupBy, self)._init(*args, **kwargs)
        self._data_type = types.binary


class DatetimeSequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is forced to datetime."""

    def _init(self, *args, **kwargs):
        super(DatetimeSequenceGroupBy, self)._init(*args, **kwargs)
        self._data_type = types.datetime


class DateSequenceGroupBy(SequenceGroupBy):
    """Grouped field selected for a Date data type."""

    def _init(self, *args, **kwargs):
        super(DateSequenceGroupBy, self)._init(*args, **kwargs)
        # NOTE(review): assigns types.datetime, not a date type -- confirm
        # this is intended (TimestampSequenceGroupBy does the same).
        self._data_type = types.datetime


class TimestampSequenceGroupBy(SequenceGroupBy):
    """Grouped field selected for a Timestamp data type."""

    def _init(self, *args, **kwargs):
        super(TimestampSequenceGroupBy, self)._init(*args, **kwargs)
        # NOTE(review): assigns types.datetime, not a timestamp type -- confirm.
        self._data_type = types.datetime


class UnknownSequenceGroupBy(SequenceGroupBy):
    """Grouped field whose data type is unknown."""

    def _init(self, *args, **kwargs):
        super(UnknownSequenceGroupBy, self)._init(*args, **kwargs)
        self._data_type = types.Unknown()


class SortedGroupBy(BaseGroupBy, SortedExpr):
    """Grouped data with a sort order (built by ``BaseGroupBy.sort_values``)."""

    __slots__ = '_sorted_fields', '_ascending'


class GroupByCollectionExpr(CollectionExpr):
    """Collection produced by aggregating a :class:`GroupBy`."""

    _args = '_input', '_by', '_having', '_aggregations', '_fields'
    node_name = 'GroupBy'

    def _init(self, *args, **kwargs):
        self._init_attr('_fields', None)
        super(GroupByCollectionExpr, self)._init(*args, **kwargs)

        if isinstance(self._input, GroupBy):
            # Unwrap the GroupBy: copy its keys and having-filter and point
            # _input at the underlying collection.
            self._by = self._input._by
            self._having = self._input._having
            self._input = self._input._input

    def iter_args(self):
        arg_names = ['collection', 'bys', 'having', 'aggregations']
        for it in zip(arg_names, self.args):
            yield it
        if self._fields is not None:
            yield ('selections', self._fields)

    def _name_to_exprs(self):
        if hasattr(self, '_fields') and self._fields is not None:
            exprs = self._fields
        else:
            # args[1] is _by and args[3] is _aggregations (see _args).
            exprs = self.args[1] + self.args[3]
        return dict((expr.name, expr) for expr in exprs if hasattr(expr, 'name'))

    @property
    def input(self):
        return self._input

    def accept(self, visitor):
        return visitor.visit_groupby(self)

    @property
    def fields(self):
        if self._fields is not None:
            return self._fields
        return self._by + self._aggregations


class MutateCollectionExpr(CollectionExpr):
    """Collection produced by ``BaseGroupBy.mutate`` (group keys + window columns)."""

    _args = '_input', '_by', '_window_fields', '_fields'
    node_name = 'Mutate'

    def _init(self, *args, **kwargs):
        self._init_attr('_fields', None)
        super(MutateCollectionExpr, self)._init(*args, **kwargs)

        if isinstance(self._input, GroupBy):
            self._by = self._input._by
            self._input = self._input._input

    @property
    def _project_fields(self):
        return self._window_fields

    def iter_args(self):
        for it in zip(['collection', 'bys', 'mutates'], self.args):
            yield it

    @property
    def input(self):
        return self._input

    @property
    def fields(self):
        if self._fields is not None:
            return self._fields
        return self._by + self._window_fields

    def accept(self, visitor):
        return visitor.visit_mutate(self)


def groupby(expr, by, *bys):
    """
    Group collection by a series of sequences.

    :param expr: collection
    :param by: columns to group
    :param bys: columns to group
    :return: GroupBy instance
    :rtype: :class:`odps.df.expr.groupby.GroupBy`
    """

    if not isinstance(by, list):
        by = [by, ]
    if len(bys) > 0:
        by = by + list(bys)
    return GroupBy(_input=expr, _by=by)


CollectionExpr.groupby = groupby


class ValueCounts(CollectionExpr):
    """Collection of distinct values of a sequence with their counts."""

    _args = '_input', '_by', '_sort', '_ascending', '_dropna'
    node_name = 'ValueCounts'

    def _init(self, *args, **kwargs):
        super(ValueCounts, self)._init(*args, **kwargs)

        if isinstance(self._input, SequenceExpr):
            # The sequence becomes the key; _input becomes the first
            # collection found when traversing the sequence top-down.
            self._by = self._input
            self._input = next(it for it in self._input.traverse(top_down=True)
                               if isinstance(it, CollectionExpr))

        # Wrap plain bool options as Scalars.
        if isinstance(self._sort, bool):
            self._sort = Scalar(_value=self._sort)
        if isinstance(self._ascending, bool):
            self._ascending = Scalar(_value=self._ascending)
        if isinstance(self._dropna, bool):
            self._dropna = Scalar(_value=self._dropna)

    def iter_args(self):
        for it in zip(['collection', 'by', 'sort', 'ascending', 'dropna'], self.args):
            yield it

    @property
    def input(self):
        return self._input

    def accept(self, visitor):
        return visitor.visit_value_counts(self)


def value_counts(expr, sort=True, ascending=False, dropna=False):
    """
    Return object containing counts of unique values.

    With the default ``sort=True, ascending=False`` the result is in
    descending order of count, so the first row is the most frequently
    occurring value. None values are counted too unless ``dropna`` is True.

    :param expr: sequence
    :param sort: if sort
    :type sort: bool
    :param ascending: sort counts in ascending order, default False
    :type ascending: bool
    :param dropna: Don't include counts of None, default False
    :type dropna: bool
    :return: collection with two columns
    :rtype: :class:`odps.df.expr.expressions.CollectionExpr`
    """

    names = [expr.name, 'count']
    col_types = [expr.dtype, types.int64]
    return ValueCounts(_input=expr, _schema=TableSchema.from_lists(names, col_types),
                       _sort=sort, _ascending=ascending, _dropna=dropna)


def topk(expr, k):
    """Return the first ``k`` rows of ``expr.value_counts()`` (the ``k`` most frequent values)."""
    return expr.value_counts().limit(k)


SequenceExpr.value_counts = value_counts
SequenceExpr.topk = topk