Source code for odps.tunnel.io.writer

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2026 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import struct

try:
    import pyarrow as pa
except (AttributeError, ImportError):
    pa = None
try:
    import pyarrow.compute as pac
except (AttributeError, ImportError):
    pac = None
try:
    import numpy as np
except ImportError:
    np = None
try:
    import pandas as pd
except (ImportError, ValueError):
    pd = None

from ... import compat, options, types, utils
from ...compat import Decimal, Enum, futures, six
from ...lib.monotonic import monotonic
from ..base import TunnelMetrics
from ..checksum import Checksum
from ..errors import TunnelError
from ..pb.encoder import Encoder
from ..pb.wire_format import (
    WIRETYPE_FIXED32,
    WIRETYPE_FIXED64,
    WIRETYPE_LENGTH_DELIMITED,
    WIRETYPE_VARINT,
)
from ..wireconstants import ProtoWireConstants
from .stream import RequestsIO, get_compress_stream
from .types import odps_schema_to_arrow_schema

try:
    if not options.force_py:
        from ..hasher_c import RecordHasher
        from .writer_c import BaseRecordWriter
    else:
        from ..hasher import RecordHasher

        BaseRecordWriter = None
except ImportError as e:
    if options.force_c:
        raise e

    from ..hasher import RecordHasher

    BaseRecordWriter = None

# Number of microseconds in one second; used to scale ``monotonic()`` deltas
# when client-side timing metrics are accumulated.
MICRO_SEC_PER_SEC = 1000000

# ODPS column types that BaseRecordWriter.write tags with WIRETYPE_VARINT.
varint_tag_types = types.integer_types + (
    types.boolean,
    types.datetime,
    types.date,
    types.interval_year_month,
)
# ODPS column types that BaseRecordWriter.write tags with
# WIRETYPE_LENGTH_DELIMITED (only the tag wire type; the payload encoding is
# chosen separately per type when the value is written).
length_delim_tag_types = (
    types.string,
    types.binary,
    types.timestamp,
    types.timestamp_ntz,
    types.interval_day_time,
    types.json,
)


if BaseRecordWriter is None:

    class ProtobufWriter(object):
        """
        Buffered, stream-style front end for the pure-Python protobuf
        ``Encoder``.

        Values are appended to an in-memory encoder. Whenever the encoder
        holds more than ``buffer_size`` bytes, its content is written to
        ``output`` and a fresh encoder is started.

        :param output: writable stream that receives the encoded bytes
        :param buffer_size: flush threshold in bytes; falls back to
            ``DEFAULT_BUFFER_SIZE`` when not given
        """

        DEFAULT_BUFFER_SIZE = 4096

        def __init__(self, output, buffer_size=None):
            self._output = output
            self._encoder = Encoder()
            self._n_total = 0
            self._buffer_size = buffer_size or self.DEFAULT_BUFFER_SIZE

        def _re_init(self, output):
            """Redirect the writer to a new stream and reset the byte count."""
            self._output = output
            self._encoder = Encoder()
            self._n_total = 0

        def _mode(self):
            return "py"

        def flush(self):
            """Move whatever the encoder currently holds to the output stream."""
            if len(self._encoder) <= 0:
                return
            self._output.write(self._encoder.tostring())
            self._n_total += len(self._encoder)
            self._encoder = Encoder()

        def close(self):
            self.flush_all()

        def flush_all(self):
            """Flush the encoder, then flush the underlying stream itself."""
            self.flush()
            self._output.flush()

        def _refresh_buffer(self):
            """Flush once the encoder grows past the configured buffer size."""
            if len(self._encoder) > self._buffer_size:
                self.flush()

        @property
        def n_bytes(self):
            """Bytes already flushed plus bytes still pending in the encoder."""
            return self._n_total + len(self._encoder)

        def __len__(self):
            return self.n_bytes

        def _append_and_refresh(self, encode, val):
            # common tail of the _write_raw_* helpers: encode, then maybe flush
            encode(val)
            self._refresh_buffer()

        def _write_tag(self, field_num, wire_type):
            self._encoder.append_tag(field_num, wire_type)
            self._refresh_buffer()

        def _write_raw_long(self, val):
            self._append_and_refresh(self._encoder.append_sint64, val)

        def _write_raw_int(self, val):
            self._append_and_refresh(self._encoder.append_sint32, val)

        def _write_raw_uint(self, val):
            self._append_and_refresh(self._encoder.append_uint32, val)

        def _write_raw_bool(self, val):
            self._append_and_refresh(self._encoder.append_bool, val)

        def _write_raw_float(self, val):
            self._append_and_refresh(self._encoder.append_float, val)

        def _write_raw_double(self, val):
            self._append_and_refresh(self._encoder.append_double, val)

        def _write_raw_string(self, val):
            self._append_and_refresh(self._encoder.append_string, val)

    class BaseRecordWriter(ProtobufWriter):
        """
        Pure-Python writer that encodes records into the tunnel protobuf
        record format.

        For each record, every non-null, non-partition field is written as a
        tag (field number = column index + 1) followed by its value. The
        record ends with a ``TUNNEL_END_RECORD`` tag that carries a CRC of
        the record. :meth:`close` appends the record count and a CRC computed
        over all per-record CRCs.

        :param schema: table schema used to look up column types
        :param out: stream that receives the encoded bytes
        :param encoding: encoding applied to text values before writing
        """

        def __init__(self, schema, out, encoding="utf-8"):
            self._encoding = encoding
            self._schema = schema
            self._columns = self._schema.columns
            # _crc covers the record being written; _crccrc accumulates the
            # checksums of all finished records
            self._crc = Checksum()
            self._crccrc = Checksum()
            self._curr_cursor = 0
            self._to_milliseconds = utils.MillisecondsConverter().to_milliseconds
            self._to_milliseconds_utc = utils.MillisecondsConverter(
                local_tz=False
            ).to_milliseconds
            self._to_days = utils.to_days

            self._enable_client_metrics = options.tunnel.enable_client_metrics
            # NOTE(review): increments are scaled by MICRO_SEC_PER_SEC, i.e. the
            # value is in microseconds despite the ``_ms`` suffix — confirm the
            # unit TunnelMetrics expects
            self._local_wall_time_ms = 0

            super(BaseRecordWriter, self).__init__(out)

        def write(self, record):
            """
            Encode one record.

            :param record: record whose values are indexed like schema columns
            :raises IOError: if the record has more fields than the schema, or
                a column has an unsupported type
            """
            n_record_fields = len(record)
            n_columns = len(self._columns)

            if self._enable_client_metrics:
                ts = monotonic()

            if n_record_fields > n_columns:
                raise IOError("record fields count is more than schema.")

            for i in range(min(n_record_fields, n_columns)):
                if self._schema.is_partition(self._columns[i]):
                    continue

                val = record[i]
                if val is None:
                    # nulls are encoded by omitting the field entirely
                    continue

                pb_index = i + 1
                self._crc.update_int(pb_index)

                data_type = self._columns[i].type
                if data_type in varint_tag_types:
                    self._write_tag(pb_index, WIRETYPE_VARINT)
                elif data_type == types.float_:
                    self._write_tag(pb_index, WIRETYPE_FIXED32)
                elif data_type == types.double:
                    self._write_tag(pb_index, WIRETYPE_FIXED64)
                elif data_type in length_delim_tag_types:
                    self._write_tag(pb_index, WIRETYPE_LENGTH_DELIMITED)
                elif isinstance(
                    data_type,
                    (
                        types.Char,
                        types.Varchar,
                        types.Decimal,
                        types.Array,
                        types.Map,
                        types.Struct,
                    ),
                ):
                    self._write_tag(pb_index, WIRETYPE_LENGTH_DELIMITED)
                else:
                    raise IOError("Invalid data type: %s" % data_type)
                self._write_field(val, data_type)

            checksum = utils.long_to_int(self._crc.getvalue())
            self._write_tag(ProtoWireConstants.TUNNEL_END_RECORD, WIRETYPE_VARINT)
            self._write_raw_uint(utils.long_to_uint(checksum))
            self._crc.reset()
            self._crccrc.update_int(checksum)
            self._curr_cursor += 1

            if self._enable_client_metrics:
                self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))

        def _write_bool(self, data):
            self._crc.update_bool(data)
            self._write_raw_bool(data)

        def _write_long(self, data):
            self._crc.update_long(data)
            self._write_raw_long(data)

        def _write_float(self, data):
            self._crc.update_float(data)
            self._write_raw_float(data)

        def _write_double(self, data):
            self._crc.update_double(data)
            self._write_raw_double(data)

        def _write_string(self, data):
            if isinstance(data, six.text_type):
                data = data.encode(self._encoding)
            self._crc.update(data)
            self._write_raw_string(data)

        def _write_timestamp(self, data, ntz=False):
            """
            Write a timestamp as (epoch seconds, nanoseconds within the second).

            The nanosecond part is always non-negative, so the seconds part must
            be the floor of the epoch time. Converting the datetime with its
            sub-second part removed gives that floor directly. Dividing the full
            millisecond value and truncating toward zero would put pre-1970
            instants with a fractional second one second too late.
            """
            to_mills = self._to_milliseconds_utc if ntz else self._to_milliseconds
            whole_second = data.to_pydatetime(warn=False).replace(microsecond=0)
            t_val = int(to_mills(whole_second)) // 1000
            nano_val = data.microsecond * 1000 + data.nanosecond
            self._crc.update_long(t_val)
            self._write_raw_long(t_val)
            self._crc.update_int(nano_val)
            self._write_raw_int(nano_val)

        def _write_interval_day_time(self, data):
            t_val = data.days * 3600 * 24 + data.seconds
            nano_val = data.microseconds * 1000 + data.nanoseconds
            self._crc.update_long(t_val)
            self._write_raw_long(t_val)
            self._crc.update_int(nano_val)
            self._write_raw_int(nano_val)

        def _write_array(self, data, data_type):
            # every element is prefixed with a null flag; only non-null
            # elements carry a value
            for value in data:
                if value is None:
                    self._write_raw_bool(True)
                else:
                    self._write_raw_bool(False)
                    self._write_field(value, data_type)

        def _write_struct(self, data, data_type):
            if isinstance(data, dict):
                # reorder dict values into the struct's declared field order
                data = tuple(data[key] for key in data_type.field_types.keys())
            for value, typ in zip(data, data_type.field_types.values()):
                if value is None:
                    self._write_raw_bool(True)
                else:
                    self._write_raw_bool(False)
                    self._write_field(value, typ)

        def _write_field(self, val, data_type):
            """
            Write the payload of one non-null value of ``data_type``.

            :raises IOError: if ``data_type`` is not supported
            """
            if data_type == types.boolean:
                self._write_bool(val)
            elif data_type == types.datetime:
                val = self._to_milliseconds(val)
                self._write_long(val)
            elif data_type == types.date:
                val = self._to_days(val)
                self._write_long(val)
            elif data_type == types.float_:
                self._write_float(val)
            elif data_type == types.double:
                self._write_double(val)
            elif data_type in types.integer_types:
                self._write_long(val)
            elif data_type == types.string:
                self._write_string(val)
            elif data_type == types.binary:
                self._write_string(val)
            elif data_type == types.timestamp or data_type == types.timestamp_ntz:
                self._write_timestamp(val, ntz=data_type == types.timestamp_ntz)
            elif data_type == types.interval_day_time:
                self._write_interval_day_time(val)
            elif data_type == types.interval_year_month:
                self._write_long(val.total_months())
            elif isinstance(data_type, (types.Char, types.Varchar)):
                self._write_string(val)
            elif isinstance(data_type, types.Decimal):
                self._write_string(str(val))
            elif isinstance(data_type, types.Json):
                self._write_string(json.dumps(val))
            elif isinstance(data_type, types.Array):
                self._write_raw_uint(len(val))
                self._write_array(val, data_type.value_type)
            elif isinstance(data_type, types.Map):
                # a map is written as a key array followed by a value array
                self._write_raw_uint(len(val))
                self._write_array(compat.lkeys(val), data_type.key_type)
                self._write_raw_uint(len(val))
                self._write_array(compat.lvalues(val), data_type.value_type)
            elif isinstance(data_type, types.Struct):
                self._write_struct(val, data_type)
            else:
                raise IOError("Invalid data type: %s" % data_type)

        @property
        def count(self):
            """Number of records written since the last reset."""
            return self._curr_cursor

        def _write_finish_tags(self):
            self._write_tag(ProtoWireConstants.TUNNEL_META_COUNT, WIRETYPE_VARINT)
            self._write_raw_long(self.count)
            self._write_tag(ProtoWireConstants.TUNNEL_META_CHECKSUM, WIRETYPE_VARINT)
            self._write_raw_uint(utils.long_to_uint(self._crccrc.getvalue()))

        def close(self):
            self._write_finish_tags()
            super(BaseRecordWriter, self).close()
            self._curr_cursor = 0

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # if an error occurs inside the with block, we do not commit
            if exc_val is not None:
                return
            self.close()


class RecordWriter(BaseRecordWriter):
    """
    Writer that streams records to ODPS through a single upload request.

    Obtain one via :meth:`TableUploadSession.open_record_writer` with a
    ``block_id``. The written data takes effect once that block is committed.

    :Example:

    The example below writes two records, built in two different ways, into
    block 0 and then commits that block.

    .. code-block:: python

        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates a RecordWriter instance for block 0
        with upload_session.open_record_writer(0) as writer:
            record = upload_session.new_record()
            record[0] = 'test1'
            record[1] = 'id1'
            writer.write(record)

            record = upload_session.new_record(['test2', 'id2'])
            writer.write(record)

        # commit block 0
        upload_session.commit([0])

    :Note:

    ``RecordWriter`` keeps a long-lived HTTP connection open. The server may
    close it once it has been open for more than 3 minutes, so do not keep a
    ``RecordWriter`` open for long periods. Details can be found
    :ref:`here <tunnel>`.
    """

    def __init__(
        self, schema, request_callback, compress_option=None, encoding="utf-8"
    ):
        self._enable_client_metrics = options.tunnel.enable_client_metrics
        self._server_metrics_string = None
        if self._enable_client_metrics:
            start_ts = monotonic()

        req_io = RequestsIO(
            request_callback,
            chunk_size=options.chunk_size,
            record_io_time=self._enable_client_metrics,
        )
        self._req_io = req_io
        super(RecordWriter, self).__init__(
            schema, get_compress_stream(req_io, compress_option), encoding=encoding
        )
        req_io.start()

        if self._enable_client_metrics:
            self._add_local_time(start_ts)

    def _add_local_time(self, start_ts):
        # NOTE(review): MICRO_SEC_PER_SEC scaling yields microseconds despite
        # the ``_ms`` suffix — confirm the unit TunnelMetrics expects
        self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - start_ts))

    @property
    def metrics(self):
        """Metrics reported by the server on close, or None if unavailable."""
        metrics_json = self._server_metrics_string
        if metrics_json is None:
            return None
        return TunnelMetrics.from_server_json(
            type(self).__name__,
            metrics_json,
            self._local_wall_time_ms,
            self._req_io.io_time_ms,
        )

    def write(self, record):
        """
        Write a record to the tunnel.

        If the background upload has already failed, its exception is
        re-raised here.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        pending_err = self._req_io._async_err
        if pending_err:
            exc_type, exc_value, exc_tb = pending_err
            six.reraise(exc_type, exc_value, exc_tb)
        super(RecordWriter, self).write(record)

    def close(self):
        """
        Close the writer and flush all data to server.
        """
        if self._enable_client_metrics:
            start_ts = monotonic()
        super(RecordWriter, self).close()
        response = self._req_io.finish()
        if self._enable_client_metrics:
            self._add_local_time(start_ts)
        self._server_metrics_string = response.headers.get("odps-tunnel-metrics")

    def get_total_bytes(self):
        return self.n_bytes
class BufferedRecordWriter(BaseRecordWriter):
    """
    Writer that buffers records in memory and uploads them as numbered blocks.

    Obtain one via :meth:`TableUploadSession.open_record_writer` without a
    ``block_id``. Pass the value returned by :meth:`get_blocks_written` to
    :meth:`TableUploadSession.commit` to submit the result.

    :Example:

    The example below writes two records, built in two different ways, and
    then commits every block the writer produced.

    .. code-block:: python

        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates a BufferedRecordWriter instance
        with upload_session.open_record_writer() as writer:
            record = upload_session.new_record()
            record[0] = 'test1'
            record[1] = 'id1'
            writer.write(record)

            record = upload_session.new_record(['test2', 'id2'])
            writer.write(record)

        # commit blocks
        upload_session.commit(writer.get_blocks_written())
    """

    def __init__(
        self,
        schema,
        request_callback,
        compress_option=None,
        encoding="utf-8",
        buffer_size=None,
        block_id=None,
        block_id_gen=None,
    ):
        self._request_callback = request_callback
        self._compress_option = compress_option
        self._block_id = block_id or 0
        self._block_id_gen = block_id_gen
        self._blocks_written = []
        self._n_bytes_written = 0
        self._buffer = compat.BytesIO()

        self._enable_client_metrics = options.tunnel.enable_client_metrics
        self._server_metrics_string = None
        self._network_wall_time_ms = 0
        if self._enable_client_metrics:
            self._accumulated_metrics = TunnelMetrics(type(self).__name__)
            ts = monotonic()
        else:
            self._accumulated_metrics = None

        super(BufferedRecordWriter, self).__init__(
            schema, get_compress_stream(self._buffer, compress_option), encoding=encoding
        )
        # the base writer installs its own default threshold; replace it with
        # the block buffer size here
        self._buffer_size = buffer_size or options.tunnel.block_buffer_size

        if self._enable_client_metrics:
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))

    @property
    def cur_block_id(self):
        """Id of the block currently being filled."""
        return self._block_id

    def _get_next_block_id(self):
        gen = self._block_id_gen
        return gen() if callable(gen) else self._block_id + 1

    def write(self, record):
        """
        Write a record to the tunnel.

        Once the buffered data exceeds the buffer size, the current block is
        uploaded.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        super(BufferedRecordWriter, self).write(record)
        limit = self._buffer_size
        if limit > 0 and self._n_raw_bytes > limit:
            self._flush()

    def close(self):
        """
        Close the writer and flush all data to server.
        """
        has_pending = self._n_raw_bytes > 0
        if has_pending:
            self._flush()
        self.flush_all()
        self._buffer.close()

    def _collect_metrics(self):
        # fold the finished block's timings into the accumulated metrics, then
        # start counting afresh for the next block
        if not self._enable_client_metrics:
            return
        if self._server_metrics_string is not None:
            self._accumulated_metrics += TunnelMetrics.from_server_json(
                type(self).__name__,
                self._server_metrics_string,
                self._local_wall_time_ms,
                self._network_wall_time_ms,
            )
        self._server_metrics_string = None
        self._local_wall_time_ms = 0
        self._network_wall_time_ms = 0

    def _reset_writer(self, write_response):
        self._collect_metrics()
        if self._enable_client_metrics:
            ts = monotonic()
        self._buffer = compat.BytesIO()
        self._re_init(get_compress_stream(self._buffer, self._compress_option))
        self._curr_cursor = 0
        self._crccrc.reset()
        self._crc.reset()
        if self._enable_client_metrics:
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))

    def _send_buffer(self):
        if self._enable_client_metrics:
            ts = monotonic()
        response = self._request_callback(self._block_id, self._buffer.getvalue())
        if self._enable_client_metrics:
            self._network_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
        return response

    def _flush(self):
        # finish the current block, upload it and move on to the next block id
        if self._enable_client_metrics:
            ts = monotonic()
        self._write_finish_tags()
        self._n_bytes_written += self._n_raw_bytes
        self.flush_all()
        response = self._send_buffer()
        self._server_metrics_string = response.headers.get("odps-tunnel-metrics")
        self._blocks_written.append(self._block_id)
        self._block_id = self._get_next_block_id()
        if self._enable_client_metrics:
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
        self._reset_writer(response)

    @property
    def metrics(self):
        return self._accumulated_metrics

    @property
    def _n_raw_bytes(self):
        # encoded bytes of the current block, counted before compression
        return super(BufferedRecordWriter, self).n_bytes

    @property
    def n_bytes(self):
        return self._n_bytes_written + self._n_raw_bytes

    def get_total_bytes(self):
        return self.n_bytes

    def get_blocks_written(self):
        """
        Get block ids created during writing. Should be provided as the argument to
        :meth:`TableUploadSession.commit`.
        """
        return self._blocks_written
# make sure the original, misspelled class name stays importable
BufferredRecordWriter = BufferedRecordWriter


class StreamRecordWriter(BufferedRecordWriter):
    """
    Buffered record writer for streaming upload sessions.

    Each buffered block is sent through ``request_callback`` as a generator
    of chunks. After every send, the routed server and slot count from the
    response headers are passed to ``session.reload_slots``.
    """

    def __init__(
        self,
        schema,
        request_callback,
        session,
        slot,
        compress_option=None,
        encoding="utf-8",
        buffer_size=None,
    ):
        self.session = session
        self.slot = slot
        self._record_count = 0
        super(StreamRecordWriter, self).__init__(
            schema,
            request_callback,
            compress_option=compress_option,
            encoding=encoding,
            buffer_size=buffer_size,
        )

    def write(self, record):
        super(StreamRecordWriter, self).write(record)
        self._record_count += 1

    def _reset_writer(self, write_response):
        self._record_count = 0
        slot_server = write_response.headers["odps-tunnel-routed-server"]
        slot_num = int(write_response.headers["odps-tunnel-slot-num"])
        if self._enable_client_metrics:
            ts = monotonic()
        self.session.reload_slots(self.slot, slot_server, slot_num)
        if self._enable_client_metrics:
            time_cost = int(MICRO_SEC_PER_SEC * (monotonic() - ts))
            # slot reloading is counted both as local and as network time
            self._local_wall_time_ms += time_cost
            self._network_wall_time_ms += time_cost
        super(StreamRecordWriter, self)._reset_writer(write_response)

    def _send_buffer(self):
        if self._enable_client_metrics:
            ts = monotonic()

        def gen():
            # synchronize chunk upload. Stepping an offset through the buffer,
            # instead of re-slicing the remainder on every chunk, keeps the
            # total copying linear in the buffer size.
            data = self._buffer.getvalue()
            chunk_size = options.chunk_size
            for offset in range(0, len(data), chunk_size):
                yield data[offset : offset + chunk_size]

        try:
            return self._request_callback(gen())
        finally:
            if self._enable_client_metrics:
                self._network_wall_time_ms += int(
                    MICRO_SEC_PER_SEC * (monotonic() - ts)
                )


class BaseArrowWriter(object):
    """
    Serialize Arrow data into the chunked tunnel Arrow stream format.

    The stream starts with the chunk size as a big-endian uint32. Serialized
    record-batch bytes follow, with the CRC of each ``chunk_size`` bytes of
    payload written after that chunk. :meth:`close` writes a CRC of all
    payload bytes as the trailer.

    :param schema: ODPS table schema to write
    :param out: stream that receives the encoded bytes
    :param chunk_size: checksum chunk size; defaults to ``options.chunk_size``
    :raises ValueError: if pyarrow is not installed
    """

    def __init__(self, schema, out=None, chunk_size=None):
        if pa is None:
            raise ValueError("To use arrow writer you need to install pyarrow")

        self._schema = schema
        self._arrow_schema = odps_schema_to_arrow_schema(schema)
        self._chunk_size = chunk_size or options.chunk_size
        # _crc covers the current chunk; _crccrc covers every payload byte
        self._crc = Checksum()
        self._crccrc = Checksum()
        self._cur_chunk_size = 0
        self._output = out
        self._chunk_size_written = False
        self._pd_mappers = self._build_pd_mappers()

    def _re_init(self, output):
        self._output = output
        self._chunk_size_written = False
        self._cur_chunk_size = 0

    def _write_chunk_size(self):
        self._write_uint32(self._chunk_size)

    def _write_uint32(self, val):
        data = struct.pack("!I", utils.long_to_uint(val))
        self._output.write(data)

    def _write_chunk(self, buf):
        # the chunk-size header is emitted lazily, before the first payload
        if not self._chunk_size_written:
            self._write_chunk_size()
            self._chunk_size_written = True
        self._output.write(buf)
        self._crc.update(buf)
        self._crccrc.update(buf)
        self._cur_chunk_size += len(buf)
        if self._cur_chunk_size >= self._chunk_size:
            checksum = self._crc.getvalue()
            self._write_uint32(checksum)
            self._crc.reset()
            self._cur_chunk_size = 0

    @classmethod
    def _localize_timezone(cls, col, tz=None):
        """
        Attach a timezone to a tz-naive timestamp column.

        Columns that already carry a timezone are returned unchanged. When
        ``tz`` is None, it is derived from ``options.local_timezone``.
        """
        from ...lib import tzlocal

        if tz is None:
            if options.local_timezone is True or options.local_timezone is None:
                tz = str(tzlocal.get_localzone())
            elif options.local_timezone is False:
                tz = "UTC"
            else:
                tz = str(options.local_timezone)

        if col.type.tz is not None:
            return col
        if hasattr(pac, "assume_timezone") and isinstance(tz, str):
            # pyarrow.compute.assume_timezone only accepts
            # string-represented zones
            col = pac.assume_timezone(col, timezone=tz)
            return col
        else:
            pd_col = col.to_pandas().dt.tz_localize(tz)
            return pa.Array.from_pandas(pd_col)

    @classmethod
    def _str_to_decimal_array(cls, col, dec_type):
        """Parse a string/binary Arrow column into a decimal column, keeping nulls."""

        def _to_decimal(val):
            if val is None:
                return None
            if isinstance(val, bytes):
                # Decimal does not accept bytes; binary columns must be decoded
                val = val.decode("utf-8")
            return Decimal(val)

        dec_col = col.to_pandas().map(_to_decimal)
        return pa.Array.from_pandas(dec_col, type=dec_type)

    def _build_pd_mappers(self):
        """
        Build per-column converters for pandas values whose Arrow types need
        help during conversion (maps, structs, decimals, and lists of them).

        Returns a dict keyed by lower-cased column name.
        """
        pa_dec_types = (pa.Decimal128Type,)
        if hasattr(pa, "Decimal256Type"):
            pa_dec_types += (pa.Decimal256Type,)

        def _need_cast(arrow_type):
            if isinstance(arrow_type, (pa.MapType, pa.StructType) + pa_dec_types):
                return True
            elif isinstance(arrow_type, pa.ListType):
                return _need_cast(arrow_type.value_type)
            else:
                return False

        def _build_mapper(cur_type):
            if isinstance(cur_type, pa.MapType):
                key_mapper = _build_mapper(cur_type.key_type)
                value_mapper = _build_mapper(cur_type.item_type)

                def mapper(data):
                    # Arrow builds maps from lists of (key, value) pairs
                    if isinstance(data, dict):
                        return [
                            (key_mapper(k), value_mapper(v)) for k, v in data.items()
                        ]
                    else:
                        return data

            elif isinstance(cur_type, pa.ListType):
                item_mapper = _build_mapper(cur_type.value_type)

                def mapper(data):
                    if data is None:
                        return data
                    return [item_mapper(element) for element in data]

            elif isinstance(cur_type, pa.StructType):
                val_mappers = dict()
                for fid in range(cur_type.num_fields):
                    field = cur_type[fid]
                    val_mappers[field.name.lower()] = _build_mapper(field.type)

                def mapper(data):
                    if isinstance(data, (list, tuple)):
                        # namedtuples carry their own field names; plain
                        # sequences are matched to struct fields by position
                        field_names = getattr(data, "_fields", None) or [
                            cur_type[fid].name for fid in range(cur_type.num_fields)
                        ]
                        data = dict(zip(field_names, data))
                    if isinstance(data, dict):
                        fields = dict()
                        for key, val in data.items():
                            fields[key] = val_mappers[key.lower()](val)
                        data = fields
                    return data

            elif isinstance(cur_type, pa_dec_types):

                def mapper(data):
                    if data is None:
                        return None
                    return Decimal(data)

            else:
                mapper = lambda x: x
            return mapper

        mappers = dict()
        for name, typ in zip(self._arrow_schema.names, self._arrow_schema.types):
            if _need_cast(typ):
                mappers[name.lower()] = _build_mapper(typ)
        return mappers

    def _convert_df_types(self, df):
        """
        Convert a pandas DataFrame into an Arrow Table laid out like the target
        schema.

        Columns are matched case-insensitively. A schema column that is absent
        from the DataFrame is written as all nulls.
        """
        lower_to_df_name = {utils.to_lower_str(s): s for s in df.columns}
        new_cols = []
        for name, typ in zip(self._arrow_schema.names, self._arrow_schema.types):
            lower_name = name.lower()
            if lower_name not in lower_to_df_name:
                new_cols.append(pa.array([None] * len(df), type=typ))
                continue

            df_name = lower_to_df_name[lower_name]
            if lower_name not in self._pd_mappers:
                new_cols.append(pa.Array.from_pandas(df[df_name]))
            else:
                new_cols.append(
                    pa.Array.from_pandas(
                        df[df_name].map(self._pd_mappers[lower_name]), type=typ
                    )
                )
        return pa.Table.from_arrays(new_cols, names=self._arrow_schema.names)

    def write(self, data):
        """
        Write an Arrow RecordBatch, an Arrow Table or a pandas DataFrame.

        :raises TypeError: if ``data`` is none of the supported types
        :raises ValueError: if a required column is missing or cannot be cast
        """
        if pd is not None and isinstance(data, pd.DataFrame):
            arrow_data = self._convert_df_types(data)
        elif isinstance(data, (pa.Table, pa.RecordBatch)):
            arrow_data = data
        else:
            raise TypeError("Cannot support writing data type %s" % type(data))

        arrow_decimal_types = (pa.Decimal128Type,)
        if hasattr(pa, "Decimal256Type"):
            arrow_decimal_types += (pa.Decimal256Type,)

        assert isinstance(arrow_data, (pa.RecordBatch, pa.Table))

        # rebuild the data when its schema differs from the target schema, or
        # when it contains timestamps that may need timezone handling
        if arrow_data.schema != self._arrow_schema or any(
            isinstance(tp, pa.TimestampType) for tp in arrow_data.schema.types
        ):
            lower_names = [n.lower() for n in arrow_data.schema.names]
            type_dict = dict(zip(lower_names, arrow_data.schema.types))
            column_dict = dict(zip(lower_names, arrow_data.columns))

            arrays = []
            for name, tp in zip(self._arrow_schema.names, self._arrow_schema.types):
                lower_name = name.lower()
                if lower_name not in column_dict:
                    raise ValueError(
                        "Input record batch does not contain column %s" % name
                    )

                if isinstance(tp, pa.TimestampType):
                    col = column_dict[lower_name]
                    if not isinstance(col.type, pa.TimestampType):
                        col = col.cast(pa.timestamp(tp.unit))
                    if self._schema[lower_name].type == types.timestamp_ntz:
                        col = self._localize_timezone(col, "UTC")
                    else:
                        col = self._localize_timezone(col)
                    column_dict[lower_name] = col.cast(
                        pa.timestamp(tp.unit, col.type.tz)
                    )
                elif (
                    isinstance(tp, arrow_decimal_types)
                    and isinstance(
                        column_dict[lower_name], (pa.Array, pa.ChunkedArray)
                    )
                    and column_dict[lower_name].type in (pa.binary(), pa.string())
                ):
                    column_dict[lower_name] = self._str_to_decimal_array(
                        column_dict[lower_name], tp
                    )

                if tp == type_dict[lower_name]:
                    arrays.append(column_dict[lower_name])
                else:
                    try:
                        arrays.append(column_dict[lower_name].cast(tp, safe=False))
                    except (pa.ArrowInvalid, pa.ArrowNotImplementedError):
                        raise ValueError(
                            "Failed to cast column %s to type %s" % (name, tp)
                        )

            pa_type = type(arrow_data)
            arrow_data = pa_type.from_arrays(arrays, names=self._arrow_schema.names)

        if isinstance(arrow_data, pa.RecordBatch):
            batches = [arrow_data]
        else:  # pa.Table
            batches = arrow_data.to_batches()

        for batch in batches:
            data = batch.serialize().to_pybytes()
            written_bytes = 0
            while written_bytes < len(data):
                # never let one write cross a checksum-chunk boundary
                length = min(
                    self._chunk_size - self._cur_chunk_size, len(data) - written_bytes
                )
                chunk_data = data[written_bytes : written_bytes + length]
                self._write_chunk(chunk_data)
                written_bytes += length

    def _write_finish_tags(self):
        checksum = self._crccrc.getvalue()
        self._write_uint32(checksum)
        self._crccrc.reset()

    def flush(self):
        self._output.flush()

    def _finish(self):
        self._write_finish_tags()
        self._output.flush()

    def close(self):
        """
        Closes the writer and flush all data to server.
        """
        self._finish()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # if an error occurs inside the with block, we do not commit
        if exc_val is not None:
            return
        self.close()
class ArrowWriter(BaseArrowWriter):
    """
    Writer object to write data to ODPS using Arrow format. Should be created
    with :meth:`TableUploadSession.open_arrow_writer` with ``block_id`` specified.

    :Example:

    Here we show an example of writing a pandas DataFrame to ODPS.

    .. code-block:: python

        import pandas as pd
        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates an ArrowWriter instance for block 0
        with upload_session.open_arrow_writer(0) as writer:
            df = pd.DataFrame({'col1': ['test1', 'test2'], 'col2': ['id1', 'id2']})
            writer.write(df)

        # commit block 0
        upload_session.commit([0])

    :Note:

    ``ArrowWriter`` holds a long HTTP connection, which might be closed at the
    server end when it has been open for more than 3 minutes. Please avoid
    opening an ``ArrowWriter`` for a long period. Details can be found
    :ref:`here <tunnel>`.
    """

    def __init__(self, schema, request_callback, compress_option=None, chunk_size=None):
        self._req_io = RequestsIO(request_callback, chunk_size=chunk_size)
        out = get_compress_stream(self._req_io, compress_option)
        super(ArrowWriter, self).__init__(schema, out, chunk_size)
        self._req_io.start()

    def write(self, data):
        """
        Write an Arrow RecordBatch, an Arrow Table or a pandas DataFrame.

        As in :meth:`RecordWriter.write`, if the background upload has already
        failed, its exception is re-raised here instead of accepting more data.
        """
        async_err = self._req_io._async_err
        if async_err:
            ex_type, ex_value, tb = async_err
            six.reraise(ex_type, ex_value, tb)
        super(ArrowWriter, self).write(data)

    def _finish(self):
        super(ArrowWriter, self)._finish()
        self._req_io.finish()
class BufferedArrowWriter(BaseArrowWriter):
    """
    Writer object to write data to ODPS using Arrow format. Should be created
    with :meth:`TableUploadSession.open_arrow_writer` without ``block_id``.
    Results should be submitted with :meth:`TableUploadSession.commit` with
    the value returned by :meth:`get_blocks_written`.

    :Example:

    Here we show an example of writing a pandas DataFrame to ODPS.

    .. code-block:: python

        import pandas as pd
        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates a BufferedArrowWriter instance
        with upload_session.open_arrow_writer() as writer:
            df = pd.DataFrame({'col1': ['test1', 'test2'], 'col2': ['id1', 'id2']})
            writer.write(df)

        # commit blocks
        upload_session.commit(writer.get_blocks_written())
    """

    def __init__(
        self,
        schema,
        request_callback,
        compress_option=None,
        buffer_size=None,
        chunk_size=None,
        block_id=None,
        block_id_gen=None,
    ):
        self._buffer_size = buffer_size or options.tunnel.block_buffer_size
        self._request_callback = request_callback
        self._block_id = block_id or 0
        self._blocks_written = []
        self._buffer = compat.BytesIO()
        self._compress_option = compress_option
        self._n_bytes_written = 0
        self._block_id_gen = block_id_gen

        out = get_compress_stream(self._buffer, compress_option)
        super(BufferedArrowWriter, self).__init__(schema, out, chunk_size=chunk_size)

    @property
    def cur_block_id(self):
        """Id of the block currently being filled."""
        return self._block_id

    def _get_next_block_id(self):
        if callable(self._block_id_gen):
            return self._block_id_gen()
        return self._block_id + 1

    def write(self, data):
        """
        Write an Arrow RecordBatch, an Arrow Table or a pandas DataFrame.

        The current block is uploaded once the buffer exceeds the buffer size.
        """
        super(BufferedArrowWriter, self).write(data)
        if 0 < self._buffer_size < self._n_raw_bytes:
            self._flush()

    def close(self):
        """
        Upload any pending data as a final block and release the buffer.
        """
        # A compression stream may still hold everything written so far, which
        # leaves the block buffer position at 0. ``_chunk_size_written`` is set
        # as soon as any payload reaches the current block, so it also catches
        # that case.
        if self._chunk_size_written or self._n_raw_bytes > 0:
            self._flush()
        self._finish()
        self._buffer.close()

    def _reset_writer(self):
        self._buffer = compat.BytesIO()
        out = get_compress_stream(self._buffer, self._compress_option)
        self._re_init(out)
        self._crccrc.reset()
        self._crc.reset()

    def _send_buffer(self):
        return self._request_callback(self._block_id, self._buffer.getvalue())

    def _flush(self):
        # finish the current block, upload it and move on to the next block id
        self._write_finish_tags()
        # Push bytes still held by a compression stream into the block buffer
        # before measuring and sending it (BufferedRecordWriter does the same
        # via flush_all).
        self._output.flush()
        self._n_bytes_written += self._n_raw_bytes
        self._send_buffer()
        self._blocks_written.append(self._block_id)
        self._block_id = self._get_next_block_id()
        self._reset_writer()

    @property
    def _n_raw_bytes(self):
        # bytes that have reached the block buffer (post-compression)
        return self._buffer.tell()

    @property
    def n_bytes(self):
        return self._n_bytes_written + self._n_raw_bytes

    def get_total_bytes(self):
        return self.n_bytes

    def get_blocks_written(self):
        """
        Get block ids created during writing. Should be provided as the argument to
        :meth:`TableUploadSession.commit`.
        """
        return self._blocks_written
class Upsert(object):
    """
    Object to insert, update or delete data in an ODPS upsert table with
    records. Should be created with :meth:`TableUpsertSession.open_upsert_stream`.

    :Example:

    Here we show an example of inserting, updating and deleting data to an upsert table.

    .. code-block:: python

        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upsert_session = tunnel.create_upsert_session('my_table', partition_spec='pt=test')

        # creates an Upsert stream instance
        stream = upsert_session.open_upsert_stream(compress=True)
        rec = upsert_session.new_record(["0", "v1"])
        stream.upsert(rec)
        rec = upsert_session.new_record(["0", "v2"])
        stream.upsert(rec)
        rec = upsert_session.new_record(["1", "v1"])
        stream.upsert(rec)
        rec = upsert_session.new_record(["2", "v1"])
        stream.upsert(rec)
        stream.delete(rec)
        stream.flush()
        stream.close()

        upsert_session.commit()
    """

    DEFAULT_MAX_BUFFER_SIZE = 64 * 1024**2
    DEFAULT_SLOT_BUFFER_SIZE = 1024**2

    class Operation(Enum):
        UPSERT = "UPSERT"
        DELETE = "DELETE"

    class Status(Enum):
        NORMAL = "NORMAL"
        ERROR = "ERROR"
        CLOSED = "CLOSED"

    def __init__(
        self,
        schema,
        request_callback,
        session,
        compress_option=None,
        encoding="utf-8",
        max_buffer_size=None,
        slot_buffer_size=None,
    ):
        self._request_callback = request_callback
        self._session = session
        self._compress_option = compress_option

        # total buffered size that triggers a flush of all buckets
        self._max_buffer_size = max_buffer_size or self.DEFAULT_MAX_BUFFER_SIZE
        # per-bucket buffered size that triggers a flush of large buckets
        self._slot_buffer_size = slot_buffer_size or self.DEFAULT_SLOT_BUFFER_SIZE

        # cumulative bytes written through this stream, exposed as ``n_bytes``
        self._total_n_bytes = 0
        # bytes currently held in bucket writers and not yet sent; recomputed
        # after each successful flush
        self._buffered_n_bytes = 0
        self._status = Upsert.Status.NORMAL

        # bucket writers use the session schema, while the ``schema``
        # argument is what the record hasher is built with
        self._schema = session.schema
        self._encoding = encoding
        self._hash_keys = self._session.hash_keys
        self._hasher = RecordHasher(schema, self._session.hasher, self._hash_keys)
        self._buckets = self._session.buckets.copy()
        self._bucket_buffers = {}
        self._bucket_writers = {}

        for slot in session.buckets.keys():
            self._build_bucket_writer(slot)

    @property
    def status(self):
        """Current :class:`Upsert.Status` of the stream."""
        return self._status

    @property
    def n_bytes(self):
        """Cumulative number of bytes written into bucket writers."""
        return self._total_n_bytes

    def upsert(self, record):
        """
        Insert or update a record.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        return self._write(record, Upsert.Operation.UPSERT)

    def delete(self, record):
        """
        Delete a record.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        return self._write(record, Upsert.Operation.DELETE)

    def flush(self, flush_all=True):
        """
        Flush buffered data to server.

        :param flush_all: if True, send every bucket holding buffered data;
            if False, only send buckets whose buffer exceeds the slot buffer size.
        :raises TunnelError: if the session slot map changed, the stream is
            closed or in error state, or the flush is interrupted.
        """
        if len(self._session.buckets) != len(self._bucket_writers):
            raise TunnelError("session slot map is changed")
        self._buckets = self._session.buckets.copy()
        # a closed / failed stream cannot be cured by retrying
        self._check_status()
        if not self._bucket_writers:
            # nothing to send, and a zero-sized thread pool is invalid
            return

        written_buckets = set()
        closed_buckets = set()
        bucket_to_count = dict()

        def write_bucket(bucket_id):
            slot = self._buckets[bucket_id]
            sio = self._bucket_buffers[bucket_id]
            rec_count = bucket_to_count[bucket_id]

            self._request_callback(bucket_id, slot, rec_count, sio.getvalue())
            # replace the sent writer with an empty one for further records
            self._build_bucket_writer(bucket_id)
            written_buckets.add(bucket_id)

        retry = 0
        while True:
            futs = []
            pool = futures.ThreadPoolExecutor(len(self._bucket_writers))
            try:
                # iterate a snapshot: worker threads replace writers meanwhile
                for bucket, writer in list(self._bucket_writers.items()):
                    if bucket in written_buckets:
                        continue
                    if bucket not in closed_buckets:
                        if writer.n_bytes == 0:
                            continue
                        if not flush_all and writer.n_bytes <= self._slot_buffer_size:
                            continue
                        bucket_to_count[bucket] = writer.count
                        writer.close()
                        closed_buckets.add(bucket)
                    # a bucket closed during a failed attempt is resent as-is
                    # instead of closing (finalizing) its writer a second time
                    futs.append(pool.submit(write_bucket, bucket))
                for fut in futs:
                    fut.result()
                break
            except KeyboardInterrupt:
                raise TunnelError("flush interrupted")
            except Exception:
                retry += 1
                if retry == 3:
                    raise
            finally:
                pool.shutdown()

        # sent buckets now hold fresh writers; recount what is still buffered
        self._buffered_n_bytes = sum(
            writer.n_bytes for writer in self._bucket_writers.values()
        )

    def close(self):
        """
        Close the stream and write all data to server.
        """
        if self.status == Upsert.Status.NORMAL:
            self.flush()
            self._status = Upsert.Status.CLOSED

    def _build_bucket_writer(self, slot):
        # each bucket gets its own in-memory buffer and record writer
        self._bucket_buffers[slot] = compat.BytesIO()
        self._bucket_writers[slot] = BaseRecordWriter(
            self._schema,
            get_compress_stream(self._bucket_buffers[slot], self._compress_option),
            encoding=self._encoding,
        )

    def _check_status(self):
        if self._status == Upsert.Status.CLOSED:
            raise TunnelError("Stream is closed!")
        elif self._status == Upsert.Status.ERROR:
            raise TunnelError("Stream has error!")

    def _write(self, record, op, valid_columns=None):
        """
        Route ``record`` to its hash bucket, tag it with ``op`` and buffer it.

        :param valid_columns: optional iterable of column names (or objects
            with a ``name`` attribute) whose schema indexes are stored in the
            record's value-columns field; an empty list is stored when None.
        """
        self._check_status()

        bucket = self._hasher.hash_record(record) % len(self._bucket_writers)
        if bucket not in self._bucket_writers:
            raise TunnelError(
                "Tunnel internal error! Do not have bucket for hash key %s" % bucket
            )

        record[self._session.UPSERT_OPERATION_KEY] = ord(
            b"U" if op == Upsert.Operation.UPSERT else b"D"
        )
        if valid_columns is None:
            record[self._session.UPSERT_VALUE_COLS_KEY] = []
        else:
            valid_names = set(getattr(col, "name", col) for col in valid_columns)
            record[self._session.UPSERT_VALUE_COLS_KEY] = [
                idx
                for idx, col in enumerate(self._schema.columns)
                if col.name in valid_names
            ]

        writer = self._bucket_writers[bucket]
        prev_written_size = writer.n_bytes
        writer.write(record)
        delta = writer.n_bytes - prev_written_size
        self._total_n_bytes += delta
        self._buffered_n_bytes += delta

        if writer.n_bytes > self._slot_buffer_size:
            self.flush(False)
        elif self._buffered_n_bytes > self._max_buffer_size:
            self.flush(True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # if an error occurs inside the with block, we do not commit
        if exc_val is not None:
            return
        self.close()