Source code for odps.tunnel.io.writer

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2026 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import io
import json
import struct
from concurrent.futures import ThreadPoolExecutor
from decimal import Decimal

try:
    import pyarrow as pa
except (AttributeError, ImportError):
    pa = None
try:
    import pyarrow.compute as pac
except (AttributeError, ImportError):
    pac = None
try:
    import numpy as np
except ImportError:
    np = None
try:
    import pandas as pd
except (ImportError, ValueError):
    pd = None

import time

from ... import options, types, utils
from ..base import TunnelMetrics
from ..checksum import Checksum
from ..errors import TunnelError
from ..pb.encoder import Encoder
from ..pb.wire_format import (
    WIRETYPE_FIXED32,
    WIRETYPE_FIXED64,
    WIRETYPE_LENGTH_DELIMITED,
    WIRETYPE_VARINT,
)
from ..wireconstants import ProtoWireConstants
from .stream import RequestsIO, get_compress_stream
from .types import _is_timestamp_struct_type, odps_schema_to_arrow_schema

try:
    if not options.force_py:
        from ..hasher_c import RecordHasher
        from .writer_c import BaseRecordWriter
    else:
        from ..hasher import RecordHasher

        BaseRecordWriter = None
except ImportError as e:
    if options.force_c:
        raise e

    from ..hasher import RecordHasher

    BaseRecordWriter = None

# Multiplier turning elapsed seconds (``time.monotonic()`` deltas) into
# microseconds for the client-side wall-time counters.
MICRO_SEC_PER_SEC = 1_000_000

# Column types whose protobuf field tag uses the varint wire type.
varint_tag_types = (
    *types.integer_types,
    types.boolean,
    types.datetime,
    types.date,
    types.interval_year_month,
)

# Column types whose protobuf field tag uses the length-delimited wire type.
length_delim_tag_types = (
    types.string,
    types.binary,
    types.timestamp,
    types.timestamp_ntz,
    types.interval_day_time,
    types.json,
)


if BaseRecordWriter is None:

    class ProtobufWriter:
        """
        Buffered, stream-style front end for the protobuf ``Encoder``.

        Values are appended to an in-memory ``Encoder``. Once the encoder
        holds more than ``buffer_size`` bytes, its content is written to
        ``output`` and a fresh encoder is started.
        """

        DEFAULT_BUFFER_SIZE = 4096

        def __init__(self, output, buffer_size=None):
            """
            :param output: file-like object receiving encoded bytes
                (its ``write`` and ``flush`` methods are used)
            :param buffer_size: encoder size in bytes above which pending
                data is pushed to ``output``; ``DEFAULT_BUFFER_SIZE`` when
                not given
            """
            self._encoder = Encoder()
            self._output = output
            self._buffer_size = buffer_size or self.DEFAULT_BUFFER_SIZE
            # number of bytes already handed over to ``output``
            self._n_total = 0

        def _re_init(self, output):
            """Retarget the writer at ``output`` and drop all encoder state."""
            self._n_total = 0
            self._output = output
            self._encoder = Encoder()

        def _mode(self):
            return "py"

        def flush(self):
            """Hand pending encoded bytes to the output and reset the encoder."""
            pending = len(self._encoder)
            if pending <= 0:
                return
            self._output.write(self._encoder.tostring())
            self._n_total += pending
            self._encoder = Encoder()

        def close(self):
            self.flush_all()

        def flush_all(self):
            """Flush the encoder, then flush the underlying output as well."""
            self.flush()
            self._output.flush()

        def _refresh_buffer(self):
            """Flush once the encoder has grown past the configured size."""
            if self._buffer_size < len(self._encoder):
                self.flush()

        @property
        def n_bytes(self):
            """Bytes produced so far: already flushed plus still pending."""
            return len(self._encoder) + self._n_total

        def __len__(self):
            return self.n_bytes

        # -- raw writers: append one value to the encoder, then maybe flush --

        def _write_tag(self, field_num, wire_type):
            self._encoder.append_tag(field_num, wire_type)
            self._refresh_buffer()

        def _write_raw_long(self, val):
            self._encoder.append_sint64(val)
            self._refresh_buffer()

        def _write_raw_int(self, val):
            self._encoder.append_sint32(val)
            self._refresh_buffer()

        def _write_raw_uint(self, val):
            self._encoder.append_uint32(val)
            self._refresh_buffer()

        def _write_raw_bool(self, val):
            self._encoder.append_bool(val)
            self._refresh_buffer()

        def _write_raw_float(self, val):
            self._encoder.append_float(val)
            self._refresh_buffer()

        def _write_raw_double(self, val):
            self._encoder.append_double(val)
            self._refresh_buffer()

        def _write_raw_string(self, val):
            self._encoder.append_string(val)
            self._refresh_buffer()

    class BaseRecordWriter(ProtobufWriter):
        """
        Pure-Python writer that encodes ODPS records into the tunnel
        protobuf stream.

        For each record, every non-partition, non-null field is written as a
        tag followed by its value. The record is then terminated by a
        ``TUNNEL_END_RECORD`` tag carrying a checksum of that record.
        :meth:`close` appends the record count and a checksum computed over
        all per-record checksums.
        """

        def __init__(self, schema, out, encoding="utf-8"):
            """
            :param schema: table schema; its ``columns`` define field order
            :param out: file-like object receiving the encoded stream
            :param encoding: encoding used to turn ``str`` values into bytes
            """
            self._encoding = encoding
            self._schema = schema
            self._columns = self._schema.columns
            # checksum of the record being written; reset after each record
            self._crc = Checksum()
            # checksum accumulated over every per-record checksum
            self._crccrc = Checksum()
            self._curr_cursor = 0
            self._to_milliseconds = utils.MillisecondsConverter().to_milliseconds
            self._to_milliseconds_utc = utils.MillisecondsConverter(
                local_tz=False
            ).to_milliseconds
            self._to_days = utils.to_days

            self._enable_client_metrics = options.tunnel.enable_client_metrics
            # NOTE(review): accumulated as MICRO_SEC_PER_SEC * elapsed seconds
            # despite the ``_ms`` suffix; kept because all writers in this
            # module share the convention.
            self._local_wall_time_ms = 0

            super(BaseRecordWriter, self).__init__(out)

        def write(self, record):
            """
            Encode one record.

            :param record: indexable record whose length must not exceed the
                number of schema columns
            :raises IOError: if the record has more fields than the schema,
                or a column has an unsupported type
            """
            n_record_fields = len(record)
            n_columns = len(self._columns)

            if self._enable_client_metrics:
                ts = time.monotonic()

            if n_record_fields > n_columns:
                raise IOError("record fields count is more than schema.")

            for i in range(min(n_record_fields, n_columns)):
                # partition values are not part of the data stream
                if self._schema.is_partition(self._columns[i]):
                    continue

                val = record[i]
                # a null value is represented by omitting the field
                if val is None:
                    continue

                # protobuf field numbers are 1-based column positions
                pb_index = i + 1
                self._crc.update_int(pb_index)

                data_type = self._columns[i].type
                if data_type in varint_tag_types:
                    self._write_tag(pb_index, WIRETYPE_VARINT)
                elif data_type == types.float_:
                    self._write_tag(pb_index, WIRETYPE_FIXED32)
                elif data_type == types.double:
                    self._write_tag(pb_index, WIRETYPE_FIXED64)
                elif data_type in length_delim_tag_types:
                    self._write_tag(pb_index, WIRETYPE_LENGTH_DELIMITED)
                elif isinstance(
                    data_type,
                    (
                        types.Char,
                        types.Varchar,
                        types.Decimal,
                        types.Array,
                        types.Map,
                        types.Struct,
                        types.Vector,
                    ),
                ):
                    self._write_tag(pb_index, WIRETYPE_LENGTH_DELIMITED)
                else:
                    raise IOError(f"Invalid data type: {data_type}")
                self._write_field(val, data_type)

            checksum = utils.long_to_int(self._crc.getvalue())
            self._write_tag(ProtoWireConstants.TUNNEL_END_RECORD, WIRETYPE_VARINT)
            self._write_raw_uint(utils.long_to_uint(checksum))
            self._crc.reset()
            self._crccrc.update_int(checksum)
            self._curr_cursor += 1

            if self._enable_client_metrics:
                self._local_wall_time_ms += int(
                    MICRO_SEC_PER_SEC * (time.monotonic() - ts)
                )

        def _write_bool(self, data):
            self._crc.update_bool(data)
            self._write_raw_bool(data)

        def _write_long(self, data):
            self._crc.update_long(data)
            self._write_raw_long(data)

        def _write_float(self, data):
            self._crc.update_float(data)
            self._write_raw_float(data)

        def _write_double(self, data):
            self._crc.update_double(data)
            self._write_raw_double(data)

        def _write_string(self, data):
            # text is encoded with the writer encoding; bytes pass through
            if isinstance(data, str):
                data = data.encode(self._encoding)
            self._crc.update(data)
            self._write_raw_string(data)

        def _write_timestamp(self, data, ntz=False):
            """
            Write a timestamp as (seconds, nanoseconds).

            ``data`` is presumably a pandas ``Timestamp`` (it needs
            ``to_pydatetime`` and ``nanosecond``).
            """
            to_mills = self._to_milliseconds_utc if ntz else self._to_milliseconds
            dt = data.to_pydatetime(warn=False)
            # The nanosecond part below is always non-negative, so the seconds
            # must be floored. Clearing the microseconds gives the floored
            # whole second exactly. The previous ``int(ms / 1000)`` truncated
            # toward zero and was off by one second for pre-1970 instants with
            # a fractional part.
            t_val = int(to_mills(dt.replace(microsecond=0)) // 1000)
            nano_val = data.microsecond * 1000 + data.nanosecond
            self._crc.update_long(t_val)
            self._write_raw_long(t_val)
            self._crc.update_int(nano_val)
            self._write_raw_int(nano_val)

        def _write_interval_day_time(self, data):
            """Write a day-time interval as (seconds, nanoseconds)."""
            # assumes a pandas Timedelta (``nanoseconds`` attribute) -- confirm
            t_val = data.days * 3600 * 24 + data.seconds
            nano_val = data.microseconds * 1000 + data.nanoseconds
            self._crc.update_long(t_val)
            self._write_raw_long(t_val)
            self._crc.update_int(nano_val)
            self._write_raw_int(nano_val)

        def _write_array(self, data, data_type):
            """Write elements, each preceded by an is-null boolean flag."""
            for value in data:
                if value is None:
                    self._write_raw_bool(True)
                else:
                    self._write_raw_bool(False)
                    self._write_field(value, data_type)

        def _write_vector(self, val, data_type):
            """Write the dimension followed by every element."""
            if val is None:
                return  # Null vector
            dim = len(val)
            self._write_raw_uint(dim)
            for elem in val:
                self._write_field(elem, data_type.element_type)

        def _write_struct(self, data, data_type):
            """
            Write struct fields in declaration order, each preceded by an
            is-null boolean flag.

            ``data`` is either a sequence in field order or a dict keyed by
            field name; a dict missing a declared field raises ``KeyError``.
            """
            if isinstance(data, dict):
                # reorder the dict values to follow the declared field order
                data = tuple(data[key] for key in data_type.field_types.keys())
            for value, typ in zip(data, data_type.field_types.values()):
                if value is None:
                    self._write_raw_bool(True)
                else:
                    self._write_raw_bool(False)
                    self._write_field(value, typ)

        def _write_field(self, val, data_type):
            """
            Write a non-null value of ``data_type`` (without its tag).

            :raises IOError: if ``data_type`` is not supported
            """
            if data_type == types.boolean:
                self._write_bool(val)
            elif data_type == types.datetime:
                val = self._to_milliseconds(val)
                self._write_long(val)
            elif data_type == types.date:
                val = self._to_days(val)
                self._write_long(val)
            elif data_type == types.float_:
                self._write_float(val)
            elif data_type == types.double:
                self._write_double(val)
            elif data_type in types.integer_types:
                self._write_long(val)
            elif data_type == types.string:
                self._write_string(val)
            elif data_type == types.binary:
                self._write_string(val)
            elif data_type == types.timestamp or data_type == types.timestamp_ntz:
                self._write_timestamp(val, ntz=data_type == types.timestamp_ntz)
            elif data_type == types.interval_day_time:
                self._write_interval_day_time(val)
            elif data_type == types.interval_year_month:
                self._write_long(val.total_months())
            elif isinstance(data_type, (types.Char, types.Varchar)):
                self._write_string(val)
            elif isinstance(data_type, types.Decimal):
                self._write_string(str(val))
            elif isinstance(data_type, types.Json):
                self._write_string(json.dumps(val))
            elif isinstance(data_type, types.Array):
                self._write_raw_uint(len(val))
                self._write_array(val, data_type.value_type)
            elif isinstance(data_type, types.Map):
                # a map is written as a key array followed by a value array,
                # each prefixed with the element count
                self._write_raw_uint(len(val))
                self._write_array(list(val.keys()), data_type.key_type)
                self._write_raw_uint(len(val))
                self._write_array(list(val.values()), data_type.value_type)
            elif isinstance(data_type, types.Struct):
                self._write_struct(val, data_type)
            elif isinstance(data_type, types.Vector):
                self._write_vector(val, data_type)
            else:
                raise IOError(f"Invalid data type: {data_type}")

        @property
        def count(self):
            """Number of records written since the last reset."""
            return self._curr_cursor

        def _write_finish_tags(self):
            """Append the record count and the checksum of record checksums."""
            self._write_tag(ProtoWireConstants.TUNNEL_META_COUNT, WIRETYPE_VARINT)
            self._write_raw_long(self.count)
            self._write_tag(ProtoWireConstants.TUNNEL_META_CHECKSUM, WIRETYPE_VARINT)
            self._write_raw_uint(utils.long_to_uint(self._crccrc.getvalue()))

        def close(self):
            self._write_finish_tags()
            super(BaseRecordWriter, self).close()
            self._curr_cursor = 0

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # if an error occurs inside the with block, we do not commit
            if exc_val is not None:
                return
            self.close()


class RecordWriter(BaseRecordWriter):
    """
    Writer object to write data to ODPS with records. Should be created
    with :meth:`TableUploadSession.open_record_writer` with ``block_id``
    specified.

    :Example:

    Here we show an example of writing data to ODPS with two records created
    in different ways.

    .. code-block:: python

        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates a RecordWriter instance for block 0
        with upload_session.open_record_writer(0) as writer:
            record = upload_session.new_record()
            record[0] = 'test1'
            record[1] = 'id1'
            writer.write(record)

            record = upload_session.new_record(['test2', 'id2'])
            writer.write(record)

        # commit block 0
        upload_session.commit([0])

    :Note:

    ``RecordWriter`` holds long HTTP connection which might be closed at
    server end when the duration is over 3 minutes. Please avoid opening
    ``RecordWriter`` for a long period. Details can be found
    :ref:`here <tunnel>`.
    """

    def __init__(
        self, schema, request_callback, compress_option=None, encoding="utf-8"
    ):
        self._enable_client_metrics = options.tunnel.enable_client_metrics
        self._server_metrics_string = None
        if self._enable_client_metrics:
            started = time.monotonic()

        # records are streamed straight into the HTTP request body
        req_io = RequestsIO(
            request_callback,
            chunk_size=options.chunk_size,
            record_io_time=self._enable_client_metrics,
        )
        self._req_io = req_io
        stream = get_compress_stream(req_io, compress_option)
        super(RecordWriter, self).__init__(schema, stream, encoding=encoding)
        req_io.start()

        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)

    @property
    def metrics(self):
        """
        Tunnel metrics for this upload, or ``None`` while no server metrics
        header has been stored (it is read in :meth:`close`).
        """
        raw_metrics = self._server_metrics_string
        if raw_metrics is None:
            return None
        return TunnelMetrics.from_server_json(
            type(self).__name__,
            raw_metrics,
            self._local_wall_time_ms,
            self._req_io.io_time_ms,
        )

    def write(self, record):
        """
        Write a record to the tunnel.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        # surface failures of the background request before writing more
        pending_error = self._req_io._async_err
        if pending_error:
            _, err_value, err_tb = pending_error
            raise err_value.with_traceback(err_tb)
        super(RecordWriter, self).write(record)

    def close(self):
        """
        Close the writer and flush all data to server.
        """
        if self._enable_client_metrics:
            started = time.monotonic()

        super(RecordWriter, self).close()
        resp = self._req_io.finish()

        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)
        self._server_metrics_string = resp.headers.get("odps-tunnel-metrics")

    def get_total_bytes(self):
        return self.n_bytes
class BufferedRecordWriter(BaseRecordWriter):
    """
    Writer object to write data to ODPS with records. Should be created
    with :meth:`TableUploadSession.open_record_writer` without ``block_id``.
    Results should be submitted with :meth:`TableUploadSession.commit` with
    returned value from :meth:`get_blocks_written`.

    :Example:

    Here we show an example of writing data to ODPS with two records created
    in different ways.

    .. code-block:: python

        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates a BufferedRecordWriter instance
        with upload_session.open_record_writer() as writer:
            record = upload_session.new_record()
            record[0] = 'test1'
            record[1] = 'id1'
            writer.write(record)

            record = upload_session.new_record(['test2', 'id2'])
            writer.write(record)

        # commit blocks
        upload_session.commit(writer.get_blocks_written())
    """

    def __init__(
        self,
        schema,
        request_callback,
        compress_option=None,
        encoding="utf-8",
        buffer_size=None,
        block_id=None,
        block_id_gen=None,
    ):
        self._request_callback = request_callback
        self._block_id = block_id or 0
        self._blocks_written = []
        self._buffer = io.BytesIO()
        self._n_bytes_written = 0
        self._compress_option = compress_option
        self._block_id_gen = block_id_gen

        self._enable_client_metrics = options.tunnel.enable_client_metrics
        self._server_metrics_string = None
        self._network_wall_time_ms = 0
        if self._enable_client_metrics:
            self._accumulated_metrics = TunnelMetrics(type(self).__name__)
            started = time.monotonic()
        else:
            self._accumulated_metrics = None

        stream = get_compress_stream(self._buffer, compress_option)
        super(BufferedRecordWriter, self).__init__(schema, stream, encoding=encoding)
        # the base initializer installs its own default buffer size, so the
        # block-level size has to be assigned afterwards to take effect
        self._buffer_size = buffer_size or options.tunnel.block_buffer_size

        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)

    @property
    def cur_block_id(self):
        """Id of the block currently being filled."""
        return self._block_id

    def _get_next_block_id(self):
        """Ask ``block_id_gen`` for the next id, else use the successor."""
        id_gen = self._block_id_gen
        return id_gen() if callable(id_gen) else self._block_id + 1

    def write(self, record):
        """
        Write a record to the tunnel.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        super(BufferedRecordWriter, self).write(record)
        # send the block once it exceeds the (positive) buffer limit
        if 0 < self._buffer_size < self._n_raw_bytes:
            self._flush()

    def close(self):
        """
        Close the writer and flush all data to server.
        """
        if self._n_raw_bytes > 0:
            self._flush()
        self.flush_all()
        self._buffer.close()

    def _collect_metrics(self):
        """Fold the finished block's metrics in and reset block counters."""
        if not self._enable_client_metrics:
            return
        server_metrics = self._server_metrics_string
        if server_metrics is not None:
            self._accumulated_metrics += TunnelMetrics.from_server_json(
                type(self).__name__,
                server_metrics,
                self._local_wall_time_ms,
                self._network_wall_time_ms,
            )
        self._server_metrics_string = None
        self._local_wall_time_ms = 0
        self._network_wall_time_ms = 0

    def _reset_writer(self, write_response):
        """Start a fresh, empty block after one has been sent."""
        self._collect_metrics()
        if self._enable_client_metrics:
            started = time.monotonic()

        self._buffer = io.BytesIO()
        self._re_init(get_compress_stream(self._buffer, self._compress_option))
        self._curr_cursor = 0
        self._crccrc.reset()
        self._crc.reset()

        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)

    def _send_buffer(self):
        """Upload the buffered block and return the response."""
        if self._enable_client_metrics:
            started = time.monotonic()
        resp = self._request_callback(self._block_id, self._buffer.getvalue())
        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._network_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)
        return resp

    def _flush(self):
        """Finish the current block, upload it and move to the next id."""
        if self._enable_client_metrics:
            started = time.monotonic()

        self._write_finish_tags()
        self._n_bytes_written += self._n_raw_bytes
        self.flush_all()

        resp = self._send_buffer()
        self._server_metrics_string = resp.headers.get("odps-tunnel-metrics")
        self._blocks_written.append(self._block_id)
        self._block_id = self._get_next_block_id()

        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)
        self._reset_writer(resp)

    @property
    def metrics(self):
        return self._accumulated_metrics

    @property
    def _n_raw_bytes(self):
        """Encoded bytes in the current, not yet uploaded, block."""
        return super(BufferedRecordWriter, self).n_bytes

    @property
    def n_bytes(self):
        return self._n_bytes_written + self._n_raw_bytes

    def get_total_bytes(self):
        return self.n_bytes

    def get_blocks_written(self):
        """
        Get block ids created during writing. Should be provided as the argument to
        :meth:`TableUploadSession.commit`.
        """
        return self._blocks_written
# make sure original typo class also referable
BufferredRecordWriter = BufferedRecordWriter


class StreamRecordWriter(BufferedRecordWriter):
    """
    Buffered record writer for stream upload sessions.

    After each block is sent, the routing returned in the
    ``odps-tunnel-routed-server`` / ``odps-tunnel-slot-num`` response headers
    is passed to ``session.reload_slots`` before the writer is reset.
    """

    def __init__(
        self,
        schema,
        request_callback,
        session,
        slot,
        compress_option=None,
        encoding="utf-8",
        buffer_size=None,
    ):
        self.session = session
        self.slot = slot
        self._record_count = 0
        super(StreamRecordWriter, self).__init__(
            schema,
            request_callback,
            compress_option=compress_option,
            encoding=encoding,
            buffer_size=buffer_size,
        )

    def write(self, record):
        super(StreamRecordWriter, self).write(record)
        self._record_count += 1

    def _reset_writer(self, write_response):
        self._record_count = 0
        slot_server = write_response.headers["odps-tunnel-routed-server"]
        slot_num = int(write_response.headers["odps-tunnel-slot-num"])
        if self._enable_client_metrics:
            ts = time.monotonic()
        self.session.reload_slots(self.slot, slot_server, slot_num)
        if self._enable_client_metrics:
            time_cost = int(MICRO_SEC_PER_SEC * (time.monotonic() - ts))
            self._local_wall_time_ms += time_cost
            self._network_wall_time_ms += time_cost
        super(StreamRecordWriter, self)._reset_writer(write_response)

    def _send_buffer(self):
        if self._enable_client_metrics:
            ts = time.monotonic()

        def gen():
            # synchronize chunk upload
            data = self._buffer.getvalue()
            chunk_size = options.chunk_size
            # Step through offsets. The previous ``data = data[chunk_size:]``
            # copied the whole remainder for every chunk, which is quadratic
            # in the block size.
            for offset in range(0, len(data), chunk_size):
                yield data[offset : offset + chunk_size]

        try:
            return self._request_callback(gen())
        finally:
            if self._enable_client_metrics:
                self._network_wall_time_ms += int(
                    MICRO_SEC_PER_SEC * (time.monotonic() - ts)
                )


class BaseArrowWriter:
    """
    Common implementation of writers uploading Arrow record batches.

    Serialized batches are cut into chunks of ``chunk_size`` bytes. The stream
    starts with the chunk size as a big-endian uint32. Each completed chunk
    is followed by its checksum, and the finish tag is a checksum over all
    bytes written since the last finish tag.
    """

    def __init__(self, schema, out=None, chunk_size=None):
        if pa is None:
            raise ValueError("To use arrow writer you need to install pyarrow")

        self._schema = schema
        self._arrow_schema = odps_schema_to_arrow_schema(schema)
        self._chunk_size = chunk_size or options.chunk_size
        # checksum of the chunk currently being filled
        self._crc = Checksum()
        # checksum of every byte since the last finish tag
        self._crccrc = Checksum()
        self._cur_chunk_size = 0
        self._output = out
        self._chunk_size_written = False
        self._pd_mappers = self._build_pd_mappers()

    def _re_init(self, output):
        self._output = output
        self._chunk_size_written = False
        self._cur_chunk_size = 0

    def _write_chunk_size(self):
        self._write_uint32(self._chunk_size)

    def _write_uint32(self, val):
        data = struct.pack("!I", utils.long_to_uint(val))
        self._output.write(data)

    def _write_chunk(self, buf):
        """Write bytes that never cross a chunk boundary."""
        if not self._chunk_size_written:
            self._write_chunk_size()
            self._chunk_size_written = True
        self._output.write(buf)
        self._crc.update(buf)
        self._crccrc.update(buf)
        self._cur_chunk_size += len(buf)
        if self._cur_chunk_size >= self._chunk_size:
            checksum = self._crc.getvalue()
            self._write_uint32(checksum)
            self._crc.reset()
            self._cur_chunk_size = 0

    @classmethod
    def _localize_timezone(cls, col, tz=None):
        """
        Attach a timezone to a naive timestamp column.

        Uses ``tz`` when given, otherwise the zone derived from
        ``options.local_timezone``. Already tz-aware columns are returned
        unchanged.
        """
        from ...lib import tzlocal

        if tz is None:
            if options.local_timezone is True or options.local_timezone is None:
                tz = str(tzlocal.get_localzone())
            elif options.local_timezone is False:
                tz = "UTC"
            else:
                tz = str(options.local_timezone)

        if col.type.tz is not None:
            return col
        if hasattr(pac, "assume_timezone") and isinstance(tz, str):
            # pyarrow.compute.assume_timezone only accepts
            # string-represented zones
            col = pac.assume_timezone(col, timezone=tz)
            return col
        else:
            pd_col = col.to_pandas().dt.tz_localize(tz)
            return pa.Array.from_pandas(pd_col)

    @classmethod
    def _str_to_decimal_array(cls, col, dec_type):
        dec_col = col.to_pandas().map(Decimal)
        return pa.Array.from_pandas(dec_col, type=dec_type)

    @staticmethod
    def _convert_struct_to_timestamp(col):
        """Convert a struct-based timestamp column (sec + nano) to pa.timestamp("ns")."""
        sec = col.field("sec")
        nano = col.field("nano")
        ns = pac.add(pac.multiply(sec, 1_000_000_000), nano)
        return ns.cast(pa.timestamp("ns"))

    def _build_pd_mappers(self):
        """
        Build per-column converters for pandas values Arrow cannot take as-is.

        Returns a dict mapping lower-cased column names to a function applied
        to every cell (maps, structs, decimals and lists of those).
        """
        pa_dec_types = (pa.Decimal128Type,)
        if hasattr(pa, "Decimal256Type"):
            pa_dec_types += (pa.Decimal256Type,)

        def _need_cast(arrow_type):
            if isinstance(arrow_type, (pa.MapType, pa.StructType) + pa_dec_types):
                return True
            elif isinstance(arrow_type, pa.ListType):
                return _need_cast(arrow_type.value_type)
            else:
                return False

        def _build_mapper(cur_type):
            if isinstance(cur_type, pa.MapType):
                key_mapper = _build_mapper(cur_type.key_type)
                value_mapper = _build_mapper(cur_type.item_type)

                def mapper(data):
                    # Arrow expects maps as a list of (key, value) pairs
                    if isinstance(data, dict):
                        return [
                            (key_mapper(k), value_mapper(v)) for k, v in data.items()
                        ]
                    else:
                        return data

            elif isinstance(cur_type, pa.ListType):
                item_mapper = _build_mapper(cur_type.value_type)

                def mapper(data):
                    if data is None:
                        return data
                    return [item_mapper(element) for element in data]

            elif isinstance(cur_type, pa.StructType):
                val_mappers = {}
                for fid in range(cur_type.num_fields):
                    field = cur_type[fid]
                    val_mappers[field.name.lower()] = _build_mapper(field.type)

                def mapper(data):
                    if isinstance(data, (list, tuple)):
                        field_names = getattr(data, "_fields", None) or [
                            cur_type[fid].name for fid in range(cur_type.num_fields)
                        ]
                        # Field names are the keys. The previous
                        # ``zip(data, field_names)`` swapped keys and values.
                        data = dict(zip(field_names, data))
                    if isinstance(data, dict):
                        fields = {}
                        for key, val in data.items():
                            fields[key] = val_mappers[key.lower()](val)
                        data = fields
                    return data

            elif isinstance(cur_type, pa_dec_types):

                def mapper(data):
                    if data is None:
                        return None
                    return Decimal(data)

            else:

                def mapper(data):
                    return data

            return mapper

        mappers = {}
        for name, typ in zip(self._arrow_schema.names, self._arrow_schema.types):
            if _need_cast(typ):
                mappers[name.lower()] = _build_mapper(typ)
        return mappers

    def _convert_df_types(self, df):
        """
        Convert a pandas DataFrame into an Arrow Table in schema column order.

        Columns are matched case-insensitively. Schema columns absent from
        ``df`` are filled with nulls.
        """
        lower_to_df_name = {utils.to_lower_str(s): s for s in df.columns}
        new_cols = []
        for name, typ in zip(self._arrow_schema.names, self._arrow_schema.types):
            lower_name = name.lower()
            if lower_name not in lower_to_df_name:
                # Previously a missing column raised KeyError on lookup, which
                # made this null-filling branch unreachable.
                new_cols.append(pa.array([None] * len(df), type=typ))
                continue

            df_name = lower_to_df_name[lower_name]
            if lower_name not in self._pd_mappers:
                new_cols.append(pa.Array.from_pandas(df[df_name]))
            else:
                new_cols.append(
                    pa.Array.from_pandas(
                        df[df_name].map(self._pd_mappers[lower_name]), type=typ
                    )
                )
        return pa.Table.from_arrays(new_cols, names=self._arrow_schema.names)

    def write(self, data):
        """
        Write an Arrow RecordBatch, an Arrow Table or a pandas DataFrame.

        :raises TypeError: if ``data`` is none of the supported types
        :raises ValueError: if a schema column is missing from an Arrow input
            or cannot be cast to the schema type
        """
        if pd is not None and isinstance(data, pd.DataFrame):
            arrow_data = self._convert_df_types(data)
        elif isinstance(data, (pa.Table, pa.RecordBatch)):
            arrow_data = data
        else:
            # previously passed ``%s`` and the type as separate exception args
            raise TypeError(f"Cannot support writing data type {type(data)}")

        arrow_decimal_types = (pa.Decimal128Type,)
        if hasattr(pa, "Decimal256Type"):
            arrow_decimal_types += (pa.Decimal256Type,)

        assert isinstance(arrow_data, (pa.RecordBatch, pa.Table))
        if arrow_data.schema != self._arrow_schema or any(
            isinstance(tp, pa.TimestampType) or _is_timestamp_struct_type(tp)
            for tp in arrow_data.schema.types
        ):
            # realign columns by case-insensitive name and coerce their types
            lower_names = [n.lower() for n in arrow_data.schema.names]
            type_dict = dict(zip(lower_names, arrow_data.schema.types))
            column_dict = dict(zip(lower_names, arrow_data.columns))
            arrays = []
            for name, tp in zip(self._arrow_schema.names, self._arrow_schema.types):
                lower_name = name.lower()
                if lower_name not in column_dict:
                    raise ValueError(
                        f"Input record batch does not contain column {name}"
                    )

                if isinstance(tp, pa.TimestampType):
                    col = column_dict[lower_name]
                    if _is_timestamp_struct_type(col.type):
                        col = self._convert_struct_to_timestamp(col)
                    elif not isinstance(col.type, pa.TimestampType):
                        col = col.cast(pa.timestamp(tp.unit))

                    if self._schema[lower_name].type == types.timestamp_ntz:
                        col = self._localize_timezone(col, "UTC")
                    else:
                        col = self._localize_timezone(col)
                    column_dict[lower_name] = col.cast(
                        pa.timestamp(tp.unit, col.type.tz)
                    )
                elif (
                    isinstance(tp, arrow_decimal_types)
                    and isinstance(column_dict[lower_name], (pa.Array, pa.ChunkedArray))
                    and column_dict[lower_name].type in (pa.binary(), pa.string())
                ):
                    column_dict[lower_name] = self._str_to_decimal_array(
                        column_dict[lower_name], tp
                    )

                if tp == type_dict[lower_name]:
                    arrays.append(column_dict[lower_name])
                else:
                    try:
                        arrays.append(column_dict[lower_name].cast(tp, safe=False))
                    except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as ex:
                        raise ValueError(
                            f"Failed to cast column {name} to type {tp}"
                        ) from ex

            pa_type = type(arrow_data)
            arrow_data = pa_type.from_arrays(arrays, names=self._arrow_schema.names)

        if isinstance(arrow_data, pa.RecordBatch):
            batches = [arrow_data]
        else:  # pa.Table
            batches = arrow_data.to_batches()

        for batch in batches:
            data = batch.serialize().to_pybytes()
            written_bytes = 0
            while written_bytes < len(data):
                # never write past the end of the current chunk, so that the
                # per-chunk checksum lands exactly every ``chunk_size`` bytes
                length = min(
                    self._chunk_size - self._cur_chunk_size, len(data) - written_bytes
                )
                chunk_data = data[written_bytes : written_bytes + length]
                self._write_chunk(chunk_data)
                written_bytes += length

    def _write_finish_tags(self):
        checksum = self._crccrc.getvalue()
        self._write_uint32(checksum)
        self._crccrc.reset()

    def flush(self):
        self._output.flush()

    def _finish(self):
        self._write_finish_tags()
        self._output.flush()

    def close(self):
        """
        Closes the writer and flush all data to server.
        """
        self._finish()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # if an error occurs inside the with block, we do not commit
        if exc_val is not None:
            return
        self.close()
class ArrowWriter(BaseArrowWriter):
    """
    Writer object to write data to ODPS using Arrow format. Should be created
    with :meth:`TableUploadSession.open_arrow_writer` with ``block_id``
    specified.

    :Example:

    Here we show an example of writing a pandas DataFrame to ODPS.

    .. code-block:: python

        import pandas as pd
        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates an ArrowWriter instance for block 0
        with upload_session.open_arrow_writer(0) as writer:
            df = pd.DataFrame({'col1': ['test1', 'test2'], 'col2': ['id1', 'id2']})
            writer.write(df)

        # commit block 0
        upload_session.commit([0])

    :Note:

    ``ArrowWriter`` holds long HTTP connection which might be closed at
    server end when the duration is over 3 minutes. Please avoid opening
    ``ArrowWriter`` for a long period. Details can be found
    :ref:`here <tunnel>`.
    """

    def __init__(self, schema, request_callback, compress_option=None, chunk_size=None):
        # Arrow bytes are streamed straight into the HTTP request body
        req_io = RequestsIO(request_callback, chunk_size=chunk_size)
        self._req_io = req_io
        stream = get_compress_stream(req_io, compress_option)
        super(ArrowWriter, self).__init__(schema, stream, chunk_size)
        req_io.start()

    def _finish(self):
        """Terminate the Arrow stream, then complete the HTTP request."""
        super(ArrowWriter, self)._finish()
        self._req_io.finish()
class BufferedArrowWriter(BaseArrowWriter):
    """
    Writer object to write data to ODPS using Arrow format. Should be created
    with :meth:`TableUploadSession.open_arrow_writer` without ``block_id``.
    Results should be submitted with :meth:`TableUploadSession.commit` with
    returned value from :meth:`get_blocks_written`.

    :Example:

    Here we show an example of writing a pandas DataFrame to ODPS.

    .. code-block:: python

        import pandas as pd
        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates a BufferedArrowWriter instance
        with upload_session.open_arrow_writer() as writer:
            df = pd.DataFrame({'col1': ['test1', 'test2'], 'col2': ['id1', 'id2']})
            writer.write(df)

        # commit blocks
        upload_session.commit(writer.get_blocks_written())
    """

    def __init__(
        self,
        schema,
        request_callback,
        compress_option=None,
        buffer_size=None,
        chunk_size=None,
        block_id=None,
        block_id_gen=None,
    ):
        self._buffer_size = buffer_size or options.tunnel.block_buffer_size
        self._request_callback = request_callback
        self._block_id = block_id or 0
        self._blocks_written = []
        self._buffer = io.BytesIO()
        self._compress_option = compress_option
        self._n_bytes_written = 0
        self._block_id_gen = block_id_gen

        out = get_compress_stream(self._buffer, compress_option)
        super(BufferedArrowWriter, self).__init__(schema, out, chunk_size=chunk_size)

    @property
    def cur_block_id(self):
        """Id of the block currently being filled."""
        return self._block_id

    def _get_next_block_id(self):
        if callable(self._block_id_gen):
            return self._block_id_gen()
        return self._block_id + 1

    def write(self, data):
        """
        Write data and upload the current block once it exceeds the
        (positive) buffer size.
        """
        super(BufferedArrowWriter, self).write(data)
        if 0 < self._buffer_size < self._n_raw_bytes:
            self._flush()

    def close(self):
        """Upload any pending block and release the buffer."""
        # ``_chunk_size_written`` becomes True as soon as anything is written
        # to the current block. The buffer position alone may still be 0
        # while a compressing stream holds the data internally, and that
        # data would then be silently dropped.
        if self._chunk_size_written or self._n_raw_bytes > 0:
            self._flush()
        self._finish()
        self._buffer.close()

    def _reset_writer(self):
        """Start a fresh, empty block after one has been sent."""
        self._buffer = io.BytesIO()
        out = get_compress_stream(self._buffer, self._compress_option)
        self._re_init(out)
        self._crccrc.reset()
        self._crc.reset()

    def _send_buffer(self):
        return self._request_callback(self._block_id, self._buffer.getvalue())

    def _flush(self):
        """Finish the current block, upload it and move to the next id."""
        self._write_finish_tags()
        # Flush the (possibly compressing) output so that the buffer holds
        # the complete block before it is read. BufferedRecordWriter does
        # the same via ``flush_all()``; previously this step was missing.
        self._output.flush()
        self._n_bytes_written += self._n_raw_bytes
        self._send_buffer()
        self._blocks_written.append(self._block_id)
        self._block_id = self._get_next_block_id()
        self._reset_writer()

    @property
    def _n_raw_bytes(self):
        """Bytes currently held in the block buffer."""
        return self._buffer.tell()

    @property
    def n_bytes(self):
        return self._n_bytes_written + self._n_raw_bytes

    def get_total_bytes(self):
        return self.n_bytes

    def get_blocks_written(self):
        """
        Get block ids created during writing. Should be provided as the argument to
        :meth:`TableUploadSession.commit`.
        """
        return self._blocks_written
class Upsert:
    """
    Object to insert or update data into an ODPS upsert table with records.
    Should be created with :meth:`TableUpsertSession.open_upsert_stream`.

    :Example:

    Here we show an example of inserting, updating and deleting data to an upsert table.

    .. code-block:: python

        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upsert_session = tunnel.create_upsert_session('my_table', partition_spec='pt=test')

        # opens an upsert stream on the session
        stream = upsert_session.open_upsert_stream(compress=True)
        rec = upsert_session.new_record(["0", "v1"])
        stream.upsert(rec)
        rec = upsert_session.new_record(["0", "v2"])
        stream.upsert(rec)
        rec = upsert_session.new_record(["1", "v1"])
        stream.upsert(rec)
        rec = upsert_session.new_record(["2", "v1"])
        stream.upsert(rec)
        stream.delete(rec)
        stream.flush()
        stream.close()

        upsert_session.commit()
    """

    DEFAULT_MAX_BUFFER_SIZE = 64 * 1024**2
    DEFAULT_SLOT_BUFFER_SIZE = 1024**2
    # number of attempts :meth:`flush` makes before re-raising the last error
    _FLUSH_MAX_ATTEMPTS = 3

    class Operation(enum.Enum):
        UPSERT = "UPSERT"
        DELETE = "DELETE"

    class Status(enum.Enum):
        NORMAL = "NORMAL"
        ERROR = "ERROR"
        CLOSED = "CLOSED"

    def __init__(
        self,
        schema,
        request_callback,
        session,
        compress_option=None,
        encoding="utf-8",
        max_buffer_size=None,
        slot_buffer_size=None,
    ):
        self._request_callback = request_callback
        self._session = session
        self._compress_option = compress_option

        self._max_buffer_size = max_buffer_size or self.DEFAULT_MAX_BUFFER_SIZE
        self._slot_buffer_size = slot_buffer_size or self.DEFAULT_SLOT_BUFFER_SIZE
        self._total_n_bytes = 0

        self._status = Upsert.Status.NORMAL
        # Record writers use the session schema; the ``schema`` argument is
        # only handed to the record hasher below.
        self._schema = session.schema
        self._encoding = encoding

        self._hash_keys = self._session.hash_keys
        self._hasher = RecordHasher(schema, self._session.hasher, self._hash_keys)
        self._buckets = self._session.buckets.copy()
        self._bucket_buffers = {}
        self._bucket_writers = {}

        # one in-memory buffer + record writer per bucket of the session
        for slot in session.buckets.keys():
            self._build_bucket_writer(slot)

    @property
    def status(self):
        return self._status

    @property
    def n_bytes(self):
        return self._total_n_bytes

    def upsert(self, record):
        """
        Insert or update a record.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        return self._write(record, Upsert.Operation.UPSERT)

    def delete(self, record):
        """
        Delete a record.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        return self._write(record, Upsert.Operation.DELETE)

    def flush(self, flush_all=True):
        """
        Flush buffered data to server.

        :param flush_all: if True, send every bucket that holds data; if
            False, only send buckets whose writer holds more than the slot
            buffer size.
        :raises TunnelError: if the session bucket map no longer matches the
            local writers, if the stream is closed or in error state, or if
            the flush is interrupted by the user.

        Failed sends are retried; after the last attempt the error is re-raised.
        """
        if len(self._session.buckets) != len(self._bucket_writers):
            raise TunnelError("session slot map is changed")
        self._buckets = self._session.buckets.copy()
        # status errors are not transient, so check once instead of retrying
        self._check_status()

        bucket_written = dict()
        bucket_to_count = dict()
        # Buckets whose writer was already closed in an earlier attempt. Their
        # buffers are complete, so a retry resends them without closing again.
        closed_buckets = set()

        def write_bucket(bucket_id):
            slot = self._buckets[bucket_id]
            sio = self._bucket_buffers[bucket_id]
            rec_count = bucket_to_count[bucket_id]

            self._request_callback(bucket_id, slot, rec_count, sio.getvalue())
            # replace the writer only after the data has been sent
            self._build_bucket_writer(bucket_id)
            bucket_written[bucket_id] = True

        # ThreadPoolExecutor rejects max_workers == 0
        n_workers = max(1, len(self._bucket_writers))
        retry = 0
        while True:
            try:
                # leaving the ``with`` waits for every submitted send to end
                with ThreadPoolExecutor(n_workers) as pool:
                    futs = []
                    # iterate a snapshot: worker threads replace entries of
                    # self._bucket_writers while this loop is running
                    for bucket, writer in list(self._bucket_writers.items()):
                        if bucket_written.get(bucket):
                            continue
                        if bucket not in closed_buckets:
                            if writer.n_bytes == 0:
                                continue
                            if (
                                not flush_all
                                and writer.n_bytes <= self._slot_buffer_size
                            ):
                                continue
                            bucket_to_count[bucket] = writer.count
                            writer.close()
                            closed_buckets.add(bucket)
                        futs.append(pool.submit(write_bucket, bucket))
                    for fut in futs:
                        fut.result()
                break
            except KeyboardInterrupt:
                raise TunnelError("flush interrupted")
            except Exception:
                retry += 1
                if retry >= self._FLUSH_MAX_ATTEMPTS:
                    raise

    def close(self):
        """
        Close the stream and write all data to server.
        """
        if self.status == Upsert.Status.NORMAL:
            self.flush()
            self._status = Upsert.Status.CLOSED

    def _build_bucket_writer(self, slot):
        # fresh buffer and record writer for ``slot``
        self._bucket_buffers[slot] = io.BytesIO()
        self._bucket_writers[slot] = BaseRecordWriter(
            self._schema,
            get_compress_stream(self._bucket_buffers[slot], self._compress_option),
            encoding=self._encoding,
        )

    def _check_status(self):
        if self._status == Upsert.Status.CLOSED:
            raise TunnelError("Stream is closed")
        elif self._status == Upsert.Status.ERROR:
            raise TunnelError("Stream has error")

    def _write(self, record, op, valid_columns=None):
        """
        Tag ``record`` with its operation and valid columns, then append it
        to its bucket writer. Flushes when buffer thresholds are exceeded.

        :param valid_columns: optional iterable of column names (or column
            objects exposing ``name``). These are recorded as schema-ordered
            column indexes.
        """
        self._check_status()

        bucket = self._hasher.hash_record(record) % len(self._bucket_writers)
        if bucket not in self._bucket_writers:
            raise TunnelError(
                "Tunnel internal error! Do not have bucket for hash key "
                + str(bucket)
            )

        record[self._session.UPSERT_OPERATION_KEY] = ord(
            b"U" if op == Upsert.Operation.UPSERT else b"D"
        )
        if valid_columns is None:
            record[self._session.UPSERT_VALUE_COLS_KEY] = []
        else:
            valid_names = set(getattr(col, "name", col) for col in valid_columns)
            record[self._session.UPSERT_VALUE_COLS_KEY] = [
                idx
                for idx, col in enumerate(self._schema.columns)
                if col.name in valid_names
            ]

        writer = self._bucket_writers[bucket]
        prev_written_size = writer.n_bytes
        writer.write(record)
        self._total_n_bytes += writer.n_bytes - prev_written_size

        # flush the full bucket first; fall back to a full flush when the
        # total buffered size across all buckets grows too large
        if writer.n_bytes > self._slot_buffer_size:
            self.flush(False)
        elif self._total_n_bytes > self._max_buffer_size:
            self.flush(True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # if an error occurs inside the with block, we do not commit
        if exc_val is not None:
            return
        self.close()