#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2026 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import io
import json
import struct
from concurrent.futures import ThreadPoolExecutor
from decimal import Decimal
try:
import pyarrow as pa
except (AttributeError, ImportError):
pa = None
try:
import pyarrow.compute as pac
except (AttributeError, ImportError):
pac = None
try:
import numpy as np
except ImportError:
np = None
try:
import pandas as pd
except (ImportError, ValueError):
pd = None
import time
from ... import options, types, utils
from ..base import TunnelMetrics
from ..checksum import Checksum
from ..errors import TunnelError
from ..pb.encoder import Encoder
from ..pb.wire_format import (
WIRETYPE_FIXED32,
WIRETYPE_FIXED64,
WIRETYPE_LENGTH_DELIMITED,
WIRETYPE_VARINT,
)
from ..wireconstants import ProtoWireConstants
from .stream import RequestsIO, get_compress_stream
from .types import _is_timestamp_struct_type, odps_schema_to_arrow_schema
try:
if not options.force_py:
from ..hasher_c import RecordHasher
from .writer_c import BaseRecordWriter
else:
from ..hasher import RecordHasher
BaseRecordWriter = None
except ImportError as e:
if options.force_c:
raise e
from ..hasher import RecordHasher
BaseRecordWriter = None
# Factor that turns an elapsed ``time.monotonic()`` difference (seconds) into
# microseconds.
# NOTE(review): this factor is added into attributes named ``*_wall_time_ms``,
# which therefore accumulate microseconds rather than milliseconds — confirm
# the unit that TunnelMetrics expects.
MICRO_SEC_PER_SEC = 10**6

# Column types whose protobuf field tag uses the VARINT wire type.
varint_tag_types = (
    *types.integer_types,
    types.boolean,
    types.datetime,
    types.date,
    types.interval_year_month,
)

# Primitive column types whose protobuf field tag uses the LENGTH_DELIMITED
# wire type (parameterized and nested types are matched by class instead).
length_delim_tag_types = (
    types.string,
    types.binary,
    types.timestamp,
    types.timestamp_ntz,
    types.interval_day_time,
    types.json,
)
if BaseRecordWriter is None:
    # Pure-Python implementation, only defined when the compiled
    # ``BaseRecordWriter`` from ``writer_c`` was not imported.

    class ProtobufWriter:
        """
        Buffered, stream-style writer built on top of ``Encoder``.

        Values are appended to an in-memory ``Encoder``; once more than
        ``buffer_size`` bytes are pending they are written to ``output``.
        """

        DEFAULT_BUFFER_SIZE = 4096

        def __init__(self, output, buffer_size=None):
            self._encoder = Encoder()
            self._output = output
            self._buffer_size = buffer_size or self.DEFAULT_BUFFER_SIZE
            # number of bytes already written to ``self._output``
            self._n_total = 0

        def _re_init(self, output):
            """Switch to a new output stream, discarding pending bytes and counters."""
            self._encoder = Encoder()
            self._output = output
            self._n_total = 0

        def _mode(self):
            """Return ``"py"`` to identify the pure-Python implementation."""
            return "py"

        def flush(self):
            """Write pending encoded bytes to the output and start a new encoder."""
            if len(self._encoder) > 0:
                data = self._encoder.tostring()
                self._output.write(data)
                self._n_total += len(self._encoder)
                self._encoder = Encoder()

        def close(self):
            """Flush pending bytes and the underlying output stream."""
            self.flush_all()

        def flush_all(self):
            """Flush pending bytes, then flush the underlying output stream."""
            self.flush()
            self._output.flush()

        def _refresh_buffer(self):
            """Control the buffer size of _encoder. Flush if necessary"""
            if len(self._encoder) > self._buffer_size:
                self.flush()

        @property
        def n_bytes(self):
            """Bytes already written to the output plus bytes still pending."""
            return self._n_total + len(self._encoder)

        def __len__(self):
            return self.n_bytes

        # The ``_write_raw_*`` helpers append one encoded value and flush when
        # the pending buffer grows too large; they never update a checksum.
        def _write_tag(self, field_num, wire_type):
            self._encoder.append_tag(field_num, wire_type)
            self._refresh_buffer()

        def _write_raw_long(self, val):
            self._encoder.append_sint64(val)
            self._refresh_buffer()

        def _write_raw_int(self, val):
            self._encoder.append_sint32(val)
            self._refresh_buffer()

        def _write_raw_uint(self, val):
            self._encoder.append_uint32(val)
            self._refresh_buffer()

        def _write_raw_bool(self, val):
            self._encoder.append_bool(val)
            self._refresh_buffer()

        def _write_raw_float(self, val):
            self._encoder.append_float(val)
            self._refresh_buffer()

        def _write_raw_double(self, val):
            self._encoder.append_double(val)
            self._refresh_buffer()

        def _write_raw_string(self, val):
            self._encoder.append_string(val)
            self._refresh_buffer()

    class BaseRecordWriter(ProtobufWriter):
        """
        Encode records in the tunnel protobuf record format.

        Each record is written as tagged fields (tag number = column index + 1)
        followed by a ``TUNNEL_END_RECORD`` tag holding a per-record checksum.
        :meth:`close` appends the record count and a checksum computed over
        all per-record checksums.
        """

        def __init__(self, schema, out, encoding="utf-8"):
            self._encoding = encoding
            self._schema = schema
            self._columns = self._schema.columns
            # checksum of the record currently being written
            self._crc = Checksum()
            # checksum over the per-record checksums
            self._crccrc = Checksum()
            # number of records written since the last close / reset
            self._curr_cursor = 0
            self._to_milliseconds = utils.MillisecondsConverter().to_milliseconds
            self._to_milliseconds_utc = utils.MillisecondsConverter(
                local_tz=False
            ).to_milliseconds
            self._to_days = utils.to_days
            self._enable_client_metrics = options.tunnel.enable_client_metrics
            self._local_wall_time_ms = 0
            super(BaseRecordWriter, self).__init__(out)

        def write(self, record):
            """
            Encode one record into the output.

            Partition columns and ``None`` values are skipped. Each written
            field is tagged with its column index + 1 (the index also feeds
            the record checksum), and the record ends with a
            ``TUNNEL_END_RECORD`` tag carrying that checksum.

            :raises IOError: if the record has more fields than the schema has
                columns, or a column has an unsupported type
            """
            n_record_fields = len(record)
            n_columns = len(self._columns)
            if self._enable_client_metrics:
                ts = time.monotonic()
            if n_record_fields > n_columns:
                raise IOError("record fields count is more than schema.")
            for i in range(min(n_record_fields, n_columns)):
                if self._schema.is_partition(self._columns[i]):
                    continue
                val = record[i]
                if val is None:
                    continue
                pb_index = i + 1
                self._crc.update_int(pb_index)
                data_type = self._columns[i].type
                # pick the protobuf wire type of the field tag from the column type
                if data_type in varint_tag_types:
                    self._write_tag(pb_index, WIRETYPE_VARINT)
                elif data_type == types.float_:
                    self._write_tag(pb_index, WIRETYPE_FIXED32)
                elif data_type == types.double:
                    self._write_tag(pb_index, WIRETYPE_FIXED64)
                elif data_type in length_delim_tag_types:
                    self._write_tag(pb_index, WIRETYPE_LENGTH_DELIMITED)
                elif isinstance(
                    data_type,
                    (
                        types.Char,
                        types.Varchar,
                        types.Decimal,
                        types.Array,
                        types.Map,
                        types.Struct,
                        types.Vector,
                    ),
                ):
                    self._write_tag(pb_index, WIRETYPE_LENGTH_DELIMITED)
                else:
                    raise IOError(f"Invalid data type: {data_type}")
                self._write_field(val, data_type)
            checksum = utils.long_to_int(self._crc.getvalue())
            self._write_tag(ProtoWireConstants.TUNNEL_END_RECORD, WIRETYPE_VARINT)
            self._write_raw_uint(utils.long_to_uint(checksum))
            self._crc.reset()
            # fold the record checksum into the checksum written by close()
            self._crccrc.update_int(checksum)
            self._curr_cursor += 1
            if self._enable_client_metrics:
                # NOTE(review): MICRO_SEC_PER_SEC * seconds yields microseconds
                # although the attribute is named *_ms — confirm the unit.
                self._local_wall_time_ms += int(
                    MICRO_SEC_PER_SEC * (time.monotonic() - ts)
                )

        # The ``_write_<type>`` helpers below update the record checksum and
        # then append the raw encoded value.
        def _write_bool(self, data):
            self._crc.update_bool(data)
            self._write_raw_bool(data)

        def _write_long(self, data):
            self._crc.update_long(data)
            self._write_raw_long(data)

        def _write_float(self, data):
            self._crc.update_float(data)
            self._write_raw_float(data)

        def _write_double(self, data):
            self._crc.update_double(data)
            self._write_raw_double(data)

        def _write_string(self, data):
            """Write ``str`` (encoded with the writer encoding) or bytes data."""
            if isinstance(data, str):
                data = data.encode(self._encoding)
            self._crc.update(data)
            self._write_raw_string(data)

        def _write_timestamp(self, data, ntz=False):
            """
            Write a timestamp as whole seconds followed by nanoseconds-of-second.

            ``data`` is expected to be pandas-Timestamp-like: it must provide
            ``to_pydatetime(warn=False)``, ``microsecond`` and ``nanosecond``.
            With ``ntz=True`` the converter created with ``local_tz=False`` is
            used instead of the default one.
            """
            to_mills = self._to_milliseconds_utc if ntz else self._to_milliseconds
            # NOTE(review): int(ms / 1000) truncates toward zero, so pre-1970
            # values with a fractional second may be one second off relative to
            # the non-negative nanosecond part — confirm the expected encoding.
            t_val = int(to_mills(data.to_pydatetime(warn=False)) / 1000)
            nano_val = data.microsecond * 1000 + data.nanosecond
            self._crc.update_long(t_val)
            self._write_raw_long(t_val)
            self._crc.update_int(nano_val)
            self._write_raw_int(nano_val)

        def _write_interval_day_time(self, data):
            """
            Write a day-time interval as total seconds plus nanoseconds.

            ``data`` presumably is a pandas ``Timedelta``-like object (uses
            ``days``, ``seconds``, ``microseconds`` and ``nanoseconds``).
            """
            t_val = data.days * 3600 * 24 + data.seconds
            nano_val = data.microseconds * 1000 + data.nanoseconds
            self._crc.update_long(t_val)
            self._write_raw_long(t_val)
            self._crc.update_int(nano_val)
            self._write_raw_int(nano_val)

        def _write_array(self, data, data_type):
            """Write each element preceded by a bool null flag (True for None)."""
            for value in data:
                if value is None:
                    self._write_raw_bool(True)
                else:
                    self._write_raw_bool(False)
                    self._write_field(value, data_type)

        def _write_vector(self, val, data_type):
            """Write the vector length followed by its elements (no null flags)."""
            if val is None:
                return  # Null vector
            dim = len(val)
            self._write_raw_uint(dim)
            for elem in val:
                self._write_field(elem, data_type.element_type)

        def _write_struct(self, data, data_type):
            """
            Write struct fields in declaration order, each preceded by a bool
            null flag. A dict input is reordered by the struct field names.
            """
            if isinstance(data, dict):
                vals = [None] * len(data)
                for idx, key in enumerate(data_type.field_types.keys()):
                    vals[idx] = data[key]
                data = tuple(vals)
            for value, typ in zip(data, data_type.field_types.values()):
                if value is None:
                    self._write_raw_bool(True)
                else:
                    self._write_raw_bool(False)
                    self._write_field(value, typ)

        def _write_field(self, val, data_type):
            """
            Encode the payload of one non-null value of ``data_type``.

            :raises IOError: if ``data_type`` is not supported
            """
            if data_type == types.boolean:
                self._write_bool(val)
            elif data_type == types.datetime:
                val = self._to_milliseconds(val)
                self._write_long(val)
            elif data_type == types.date:
                val = self._to_days(val)
                self._write_long(val)
            elif data_type == types.float_:
                self._write_float(val)
            elif data_type == types.double:
                self._write_double(val)
            elif data_type in types.integer_types:
                self._write_long(val)
            elif data_type == types.string:
                self._write_string(val)
            elif data_type == types.binary:
                self._write_string(val)
            elif data_type == types.timestamp or data_type == types.timestamp_ntz:
                self._write_timestamp(val, ntz=data_type == types.timestamp_ntz)
            elif data_type == types.interval_day_time:
                self._write_interval_day_time(val)
            elif data_type == types.interval_year_month:
                self._write_long(val.total_months())
            elif isinstance(data_type, (types.Char, types.Varchar)):
                self._write_string(val)
            elif isinstance(data_type, types.Decimal):
                self._write_string(str(val))
            elif isinstance(data_type, types.Json):
                self._write_string(json.dumps(val))
            elif isinstance(data_type, types.Array):
                # element count, then the elements
                self._write_raw_uint(len(val))
                self._write_array(val, data_type.value_type)
            elif isinstance(data_type, types.Map):
                # maps are written as a key array followed by a value array,
                # each prefixed with the entry count
                self._write_raw_uint(len(val))
                self._write_array(list(val.keys()), data_type.key_type)
                self._write_raw_uint(len(val))
                self._write_array(list(val.values()), data_type.value_type)
            elif isinstance(data_type, types.Struct):
                self._write_struct(val, data_type)
            elif isinstance(data_type, types.Vector):
                self._write_vector(val, data_type)
            else:
                raise IOError(f"Invalid data type: {data_type}")

        @property
        def count(self):
            """Number of records written since the last close / reset."""
            return self._curr_cursor

        def _write_finish_tags(self):
            """Append the record count and the checksum of record checksums."""
            self._write_tag(ProtoWireConstants.TUNNEL_META_COUNT, WIRETYPE_VARINT)
            self._write_raw_long(self.count)
            self._write_tag(ProtoWireConstants.TUNNEL_META_CHECKSUM, WIRETYPE_VARINT)
            self._write_raw_uint(utils.long_to_uint(self._crccrc.getvalue()))

        def close(self):
            """Write the trailer tags, flush everything and reset the record count."""
            self._write_finish_tags()
            super(BaseRecordWriter, self).close()
            self._curr_cursor = 0

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # if an error occurs inside the with block, we do not commit
            if exc_val is not None:
                return
            self.close()
class RecordWriter(BaseRecordWriter):
    """
    Record writer that streams encoded data to ODPS over a single HTTP request.
    Obtain it from :meth:`TableUploadSession.open_record_writer` with a
    ``block_id`` argument.

    :Example:

    The example below writes two records, built in two different ways, into
    block 0 and then commits that block.

    .. code-block:: python

        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates a RecordWriter instance for block 0
        with upload_session.open_record_writer(0) as writer:
            record = upload_session.new_record()
            record[0] = 'test1'
            record[1] = 'id1'
            writer.write(record)

            record = upload_session.new_record(['test2', 'id2'])
            writer.write(record)

        # commit block 0
        upload_session.commit([0])

    :Note:

    The server may drop the long-lived HTTP connection held by a
    ``RecordWriter`` once it has been open for more than 3 minutes, so keep
    the writer open only as long as needed. See :ref:`here <tunnel>` for
    details.
    """

    def __init__(
        self, schema, request_callback, compress_option=None, encoding="utf-8"
    ):
        self._enable_client_metrics = options.tunnel.enable_client_metrics
        self._server_metrics_string = None
        started = time.monotonic() if self._enable_client_metrics else None

        self._req_io = RequestsIO(
            request_callback,
            chunk_size=options.chunk_size,
            record_io_time=self._enable_client_metrics,
        )
        stream = get_compress_stream(self._req_io, compress_option)
        super(RecordWriter, self).__init__(schema, stream, encoding=encoding)
        self._req_io.start()

        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)

    @property
    def metrics(self):
        """Tunnel metrics reported by the server, or None before :meth:`close`."""
        raw_metrics = self._server_metrics_string
        if raw_metrics is None:
            return None
        return TunnelMetrics.from_server_json(
            type(self).__name__,
            raw_metrics,
            self._local_wall_time_ms,
            self._req_io.io_time_ms,
        )

    def write(self, record):
        """
        Write a record to the tunnel.

        An error raised earlier by the background request is re-raised here
        with its original traceback.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        pending_error = self._req_io._async_err
        if pending_error:
            _, exc_value, exc_tb = pending_error
            raise exc_value.with_traceback(exc_tb)
        super(RecordWriter, self).write(record)

    def close(self):
        """
        Close the writer and flush all data to server.
        """
        started = time.monotonic() if self._enable_client_metrics else None
        super(RecordWriter, self).close()
        response = self._req_io.finish()
        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)
        self._server_metrics_string = response.headers.get("odps-tunnel-metrics")

    def get_total_bytes(self):
        """Total number of bytes produced by this writer."""
        return self.n_bytes
class BufferedRecordWriter(BaseRecordWriter):
    """
    Record writer that buffers encoded data in memory and uploads it as
    numbered blocks. Obtain it from :meth:`TableUploadSession.open_record_writer`
    without a ``block_id``. Commit the result by passing the value returned by
    :meth:`get_blocks_written` to :meth:`TableUploadSession.commit`.

    :Example:

    The example below writes two records, built in two different ways, and
    commits every block the writer produced.

    .. code-block:: python

        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates a BufferedRecordWriter instance
        with upload_session.open_record_writer() as writer:
            record = upload_session.new_record()
            record[0] = 'test1'
            record[1] = 'id1'
            writer.write(record)

            record = upload_session.new_record(['test2', 'id2'])
            writer.write(record)

        # commit blocks
        upload_session.commit(writer.get_blocks_written())
    """

    def __init__(
        self,
        schema,
        request_callback,
        compress_option=None,
        encoding="utf-8",
        buffer_size=None,
        block_id=None,
        block_id_gen=None,
    ):
        self._request_callback = request_callback
        self._block_id = block_id or 0
        self._blocks_written = []
        self._buffer = io.BytesIO()
        self._n_bytes_written = 0
        self._compress_option = compress_option
        self._block_id_gen = block_id_gen

        self._enable_client_metrics = options.tunnel.enable_client_metrics
        self._server_metrics_string = None
        self._network_wall_time_ms = 0
        self._accumulated_metrics = (
            TunnelMetrics(type(self).__name__)
            if self._enable_client_metrics
            else None
        )
        started = time.monotonic() if self._enable_client_metrics else None

        stream = get_compress_stream(self._buffer, compress_option)
        super(BufferedRecordWriter, self).__init__(schema, stream, encoding=encoding)
        # the base initializer sets its own buffer size; override it with the
        # block buffer size used to decide when a block is uploaded
        self._buffer_size = buffer_size or options.tunnel.block_buffer_size

        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)

    @property
    def cur_block_id(self):
        """Id of the block currently being filled."""
        return self._block_id

    def _get_next_block_id(self):
        """Id for the next block: from ``block_id_gen`` if callable, else +1."""
        generator = self._block_id_gen
        return generator() if callable(generator) else self._block_id + 1

    def write(self, record):
        """
        Write a record to the tunnel. The current block is uploaded once its
        buffered size exceeds a positive buffer size.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        super(BufferedRecordWriter, self).write(record)
        limit = self._buffer_size
        if limit > 0 and self._n_raw_bytes > limit:
            self._flush()

    def close(self):
        """
        Close the writer and flush all data to server.
        """
        if self._n_raw_bytes > 0:
            self._flush()
        self.flush_all()
        self._buffer.close()

    def _collect_metrics(self):
        """Fold the last block's server metrics into the accumulated metrics."""
        if not self._enable_client_metrics:
            return
        raw_metrics = self._server_metrics_string
        if raw_metrics is not None:
            self._accumulated_metrics += TunnelMetrics.from_server_json(
                type(self).__name__,
                raw_metrics,
                self._local_wall_time_ms,
                self._network_wall_time_ms,
            )
        self._server_metrics_string = None
        self._local_wall_time_ms = 0
        self._network_wall_time_ms = 0

    def _reset_writer(self, write_response):
        """Start a fresh, empty block after a block has been uploaded."""
        self._collect_metrics()
        started = time.monotonic() if self._enable_client_metrics else None
        self._buffer = io.BytesIO()
        self._re_init(get_compress_stream(self._buffer, self._compress_option))
        self._curr_cursor = 0
        self._crccrc.reset()
        self._crc.reset()
        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)

    def _send_buffer(self):
        """Upload the buffered block and return the response."""
        started = time.monotonic() if self._enable_client_metrics else None
        response = self._request_callback(self._block_id, self._buffer.getvalue())
        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._network_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)
        return response

    def _flush(self):
        """Finish the current block, upload it and move to the next block id."""
        started = time.monotonic() if self._enable_client_metrics else None
        self._write_finish_tags()
        self._n_bytes_written += self._n_raw_bytes
        self.flush_all()
        response = self._send_buffer()
        self._server_metrics_string = response.headers.get("odps-tunnel-metrics")
        self._blocks_written.append(self._block_id)
        self._block_id = self._get_next_block_id()
        if self._enable_client_metrics:
            elapsed = time.monotonic() - started
            self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)
        self._reset_writer(response)

    @property
    def metrics(self):
        """Accumulated tunnel metrics, or None when client metrics are disabled."""
        return self._accumulated_metrics

    @property
    def _n_raw_bytes(self):
        """Bytes produced for the current (not yet uploaded) block."""
        return super(BufferedRecordWriter, self).n_bytes

    @property
    def n_bytes(self):
        """Bytes of all uploaded blocks plus the current block."""
        return self._n_bytes_written + self._n_raw_bytes

    def get_total_bytes(self):
        """Total number of bytes produced by this writer."""
        return self.n_bytes

    def get_blocks_written(self):
        """
        Get block ids created during writing. Should be provided as the argument to
        :meth:`TableUploadSession.commit`.
        """
        return self._blocks_written


# keep the historical misspelled class name importable
BufferredRecordWriter = BufferedRecordWriter
class StreamRecordWriter(BufferedRecordWriter):
    """
    Buffered record writer for stream upload sessions.

    After each uploaded block, the session slots are reloaded from the routing
    headers of the response. The block is sent through ``request_callback`` as
    an iterator of pieces of at most ``options.chunk_size`` bytes.
    """

    def __init__(
        self,
        schema,
        request_callback,
        session,
        slot,
        compress_option=None,
        encoding="utf-8",
        buffer_size=None,
    ):
        self.session, self.slot = session, slot
        self._record_count = 0
        super(StreamRecordWriter, self).__init__(
            schema,
            request_callback,
            compress_option=compress_option,
            encoding=encoding,
            buffer_size=buffer_size,
        )

    def write(self, record):
        super(StreamRecordWriter, self).write(record)
        self._record_count += 1

    def _reset_writer(self, write_response):
        self._record_count = 0
        headers = write_response.headers
        routed_server = headers["odps-tunnel-routed-server"]
        n_slots = int(headers["odps-tunnel-slot-num"])

        started = time.monotonic() if self._enable_client_metrics else None
        self.session.reload_slots(self.slot, routed_server, n_slots)
        if self._enable_client_metrics:
            spent = int(MICRO_SEC_PER_SEC * (time.monotonic() - started))
            # reloading slots counts both as local and as network time
            self._local_wall_time_ms += spent
            self._network_wall_time_ms += spent

        super(StreamRecordWriter, self)._reset_writer(write_response)

    def _send_buffer(self):
        started = time.monotonic() if self._enable_client_metrics else None

        def gen():  # synchronize chunk upload
            payload = self._buffer.getvalue()
            step = options.chunk_size
            offset = 0
            while offset < len(payload):
                yield payload[offset : offset + step]
                offset += step

        try:
            return self._request_callback(gen())
        finally:
            if self._enable_client_metrics:
                elapsed = time.monotonic() - started
                self._network_wall_time_ms += int(MICRO_SEC_PER_SEC * elapsed)
class BaseArrowWriter:
    """
    Shared implementation of writers that upload data in Arrow format.

    Data given to :meth:`write` is aligned to the target ODPS schema,
    serialized as Arrow record batches and written in chunks of
    ``chunk_size`` bytes. The stream starts with the chunk size as a
    big-endian uint32. Each completed chunk is followed by its checksum, and
    :meth:`_write_finish_tags` appends a checksum over all chunk bytes.
    """

    def __init__(self, schema, out=None, chunk_size=None):
        if pa is None:
            raise ValueError("To use arrow writer you need to install pyarrow")

        self._schema = schema
        self._arrow_schema = odps_schema_to_arrow_schema(schema)
        self._chunk_size = chunk_size or options.chunk_size
        # checksum of the chunk currently being filled
        self._crc = Checksum()
        # checksum of all chunk bytes since the last finish tag
        self._crccrc = Checksum()
        self._cur_chunk_size = 0
        self._output = out
        self._chunk_size_written = False
        # per-column converters applied to pandas values before Arrow conversion
        self._pd_mappers = self._build_pd_mappers()

    def _re_init(self, output):
        """Switch to a new output stream and restart chunking."""
        self._output = output
        self._chunk_size_written = False
        self._cur_chunk_size = 0

    def _write_chunk_size(self):
        self._write_uint32(self._chunk_size)

    def _write_uint32(self, val):
        """Write ``val`` as a big-endian unsigned 32-bit integer."""
        data = struct.pack("!I", utils.long_to_uint(val))
        self._output.write(data)

    def _write_chunk(self, buf):
        """Write ``buf`` and emit a chunk checksum whenever a chunk is complete."""
        if not self._chunk_size_written:
            # the chunk size header precedes the first chunk of a stream
            self._write_chunk_size()
            self._chunk_size_written = True
        self._output.write(buf)
        self._crc.update(buf)
        self._crccrc.update(buf)
        self._cur_chunk_size += len(buf)
        if self._cur_chunk_size >= self._chunk_size:
            checksum = self._crc.getvalue()
            self._write_uint32(checksum)
            self._crc.reset()
            self._cur_chunk_size = 0

    @classmethod
    def _localize_timezone(cls, col, tz=None):
        """
        Attach a time zone to a naive timestamp column.

        Columns that already carry a time zone are returned unchanged. When
        ``tz`` is None it is derived from ``options.local_timezone``.
        """
        from ...lib import tzlocal

        if tz is None:
            if options.local_timezone is True or options.local_timezone is None:
                tz = str(tzlocal.get_localzone())
            elif options.local_timezone is False:
                tz = "UTC"
            else:
                tz = str(options.local_timezone)

        if col.type.tz is not None:
            return col
        if hasattr(pac, "assume_timezone") and isinstance(tz, str):
            # pyarrow.compute.assume_timezone only accepts
            # string-represented zones
            col = pac.assume_timezone(col, timezone=tz)
            return col
        else:
            pd_col = col.to_pandas().dt.tz_localize(tz)
            return pa.Array.from_pandas(pd_col)

    @classmethod
    def _str_to_decimal_array(cls, col, dec_type):
        """
        Convert a string or binary Arrow column into a decimal column.

        Nulls stay null; binary values are decoded as UTF-8 first because
        ``Decimal`` does not accept bytes.
        """

        def _to_decimal(val):
            if val is None:
                return None
            if isinstance(val, (bytes, bytearray)):
                val = val.decode("utf-8")
            return Decimal(val)

        dec_col = col.to_pandas().map(_to_decimal)
        return pa.Array.from_pandas(dec_col, type=dec_type)

    @staticmethod
    def _convert_struct_to_timestamp(col):
        """Convert a struct-based timestamp column (sec + nano) to pa.timestamp("ns")."""
        sec = col.field("sec")
        nano = col.field("nano")
        ns = pac.add(pac.multiply(sec, 1_000_000_000), nano)
        return ns.cast(pa.timestamp("ns"))

    def _build_pd_mappers(self):
        """
        Build element converters for columns whose pandas values need
        preprocessing (maps, structs, decimals, and lists of those) before
        ``pa.Array.from_pandas``. Keys are lower-cased column names.
        """
        pa_dec_types = (pa.Decimal128Type,)
        if hasattr(pa, "Decimal256Type"):
            pa_dec_types += (pa.Decimal256Type,)

        def _need_cast(arrow_type):
            if isinstance(arrow_type, (pa.MapType, pa.StructType) + pa_dec_types):
                return True
            elif isinstance(arrow_type, pa.ListType):
                return _need_cast(arrow_type.value_type)
            else:
                return False

        def _build_mapper(cur_type):
            if isinstance(cur_type, pa.MapType):
                key_mapper = _build_mapper(cur_type.key_type)
                value_mapper = _build_mapper(cur_type.item_type)

                def mapper(data):
                    # Arrow expects map values as a list of (key, value) pairs
                    if isinstance(data, dict):
                        return [
                            (key_mapper(k), value_mapper(v)) for k, v in data.items()
                        ]
                    else:
                        return data

            elif isinstance(cur_type, pa.ListType):
                item_mapper = _build_mapper(cur_type.value_type)

                def mapper(data):
                    if data is None:
                        return data
                    return [item_mapper(element) for element in data]

            elif isinstance(cur_type, pa.StructType):
                val_mappers = dict()
                for fid in range(cur_type.num_fields):
                    field = cur_type[fid]
                    val_mappers[field.name.lower()] = _build_mapper(field.type)

                def mapper(data):
                    if isinstance(data, (list, tuple)):
                        # positional values: use namedtuple field names when
                        # available, else the struct field order
                        field_names = getattr(data, "_fields", None) or [
                            cur_type[fid].name for fid in range(cur_type.num_fields)
                        ]
                        data = dict(zip(field_names, data))
                    if isinstance(data, dict):
                        fields = dict()
                        for key, val in data.items():
                            fields[key] = val_mappers[key.lower()](val)
                        data = fields
                    return data

            elif isinstance(cur_type, pa_dec_types):

                def mapper(data):
                    if data is None:
                        return None
                    return Decimal(data)

            else:
                mapper = lambda x: x
            return mapper

        mappers = dict()
        for name, typ in zip(self._arrow_schema.names, self._arrow_schema.types):
            if _need_cast(typ):
                mappers[name.lower()] = _build_mapper(typ)
        return mappers

    def _convert_df_types(self, df):
        """
        Convert a pandas DataFrame into an Arrow table following the target
        schema. Columns are matched case-insensitively; schema columns absent
        from ``df`` are filled with nulls.
        """
        lower_to_df_name = {utils.to_lower_str(s): s for s in df.columns}
        new_cols = []
        for name, typ in zip(self._arrow_schema.names, self._arrow_schema.types):
            lower_name = name.lower()
            if lower_name not in lower_to_df_name:
                new_cols.append(pa.array([None] * len(df), type=typ))
                continue
            df_name = lower_to_df_name[lower_name]
            if lower_name not in self._pd_mappers:
                new_cols.append(pa.Array.from_pandas(df[df_name]))
            else:
                new_cols.append(
                    pa.Array.from_pandas(
                        df[df_name].map(self._pd_mappers[lower_name]), type=typ
                    )
                )
        return pa.Table.from_arrays(new_cols, names=self._arrow_schema.names)

    def write(self, data):
        """
        Write an Arrow RecordBatch, an Arrow Table or a pandas DataFrame.

        Columns are matched to the target schema case-insensitively and cast
        when their types differ. Timestamp columns get a time zone attached
        (UTC for ``timestamp_ntz`` columns).

        :raises TypeError: if ``data`` is not one of the supported types
        :raises ValueError: if a schema column is missing from the input or
            cannot be cast to its target type
        """
        if pd is not None and isinstance(data, pd.DataFrame):
            arrow_data = self._convert_df_types(data)
        elif isinstance(data, (pa.Table, pa.RecordBatch)):
            arrow_data = data
        else:
            raise TypeError(f"Cannot support writing data type {type(data)}")

        arrow_decimal_types = (pa.Decimal128Type,)
        if hasattr(pa, "Decimal256Type"):
            arrow_decimal_types += (pa.Decimal256Type,)

        # timestamp columns always take the conversion path so that their time
        # zone is normalized even when the schema already matches
        if arrow_data.schema != self._arrow_schema or any(
            isinstance(tp, pa.TimestampType) or _is_timestamp_struct_type(tp)
            for tp in arrow_data.schema.types
        ):
            lower_names = [n.lower() for n in arrow_data.schema.names]
            type_dict = dict(zip(lower_names, arrow_data.schema.types))
            column_dict = dict(zip(lower_names, arrow_data.columns))

            arrays = []
            for name, tp in zip(self._arrow_schema.names, self._arrow_schema.types):
                lower_name = name.lower()
                if lower_name not in column_dict:
                    raise ValueError(
                        f"Input record batch does not contain column {name}"
                    )

                if isinstance(tp, pa.TimestampType):
                    col = column_dict[lower_name]
                    if _is_timestamp_struct_type(col.type):
                        col = self._convert_struct_to_timestamp(col)
                    elif not isinstance(col.type, pa.TimestampType):
                        col = col.cast(pa.timestamp(tp.unit))
                    if self._schema[lower_name].type == types.timestamp_ntz:
                        col = self._localize_timezone(col, "UTC")
                    else:
                        col = self._localize_timezone(col)
                    column_dict[lower_name] = col.cast(
                        pa.timestamp(tp.unit, col.type.tz)
                    )
                elif (
                    isinstance(tp, arrow_decimal_types)
                    and isinstance(column_dict[lower_name], (pa.Array, pa.ChunkedArray))
                    and column_dict[lower_name].type in (pa.binary(), pa.string())
                ):
                    column_dict[lower_name] = self._str_to_decimal_array(
                        column_dict[lower_name], tp
                    )

                if tp == type_dict[lower_name]:
                    arrays.append(column_dict[lower_name])
                else:
                    try:
                        arrays.append(column_dict[lower_name].cast(tp, safe=False))
                    except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as ex:
                        raise ValueError(
                            f"Failed to cast column {name} to type {tp}"
                        ) from ex
            pa_type = type(arrow_data)
            arrow_data = pa_type.from_arrays(arrays, names=self._arrow_schema.names)

        if isinstance(arrow_data, pa.RecordBatch):
            batches = [arrow_data]
        else:  # pa.Table
            batches = arrow_data.to_batches()

        for batch in batches:
            payload = batch.serialize().to_pybytes()
            written_bytes = 0
            while written_bytes < len(payload):
                # never let one write cross a chunk boundary
                length = min(
                    self._chunk_size - self._cur_chunk_size,
                    len(payload) - written_bytes,
                )
                self._write_chunk(payload[written_bytes : written_bytes + length])
                written_bytes += length

    def _write_finish_tags(self):
        """Append the checksum of all chunk bytes and reset it."""
        checksum = self._crccrc.getvalue()
        self._write_uint32(checksum)
        self._crccrc.reset()

    def flush(self):
        self._output.flush()

    def _finish(self):
        self._write_finish_tags()
        self._output.flush()

    def close(self):
        """
        Closes the writer and flush all data to server.
        """
        self._finish()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # if an error occurs inside the with block, we do not commit
        if exc_val is not None:
            return
        self.close()
class ArrowWriter(BaseArrowWriter):
    """
    Writer that streams Arrow-format data to ODPS over a single HTTP request.
    Obtain it from :meth:`TableUploadSession.open_arrow_writer` with a
    ``block_id`` argument.

    :Example:

    The example below writes a pandas DataFrame into block 0 and commits it.

    .. code-block:: python

        import pandas as pd
        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates an ArrowWriter instance for block 0
        with upload_session.open_arrow_writer(0) as writer:
            df = pd.DataFrame({'col1': ['test1', 'test2'], 'col2': ['id1', 'id2']})
            writer.write(df)

        # commit block 0
        upload_session.commit([0])

    :Note:

    The server may drop the long-lived HTTP connection held by an
    ``ArrowWriter`` once it has been open for more than 3 minutes, so keep the
    writer open only as long as needed. See :ref:`here <tunnel>` for details.
    """

    def __init__(self, schema, request_callback, compress_option=None, chunk_size=None):
        req_io = RequestsIO(request_callback, chunk_size=chunk_size)
        self._req_io = req_io
        stream = get_compress_stream(req_io, compress_option)
        super(ArrowWriter, self).__init__(schema, stream, chunk_size)
        req_io.start()

    def _finish(self):
        """Write the trailer, then complete the underlying HTTP request."""
        super(ArrowWriter, self)._finish()
        self._req_io.finish()
class BufferedArrowWriter(BaseArrowWriter):
    """
    Writer object to write data to ODPS using Arrow format. Should be created
    with :meth:`TableUploadSession.open_arrow_writer` without ``block_id``.
    Results should be submitted with :meth:`TableUploadSession.commit` with
    returned value from :meth:`get_blocks_written`.

    :Example:

    Here we show an example of writing a pandas DataFrame to ODPS.

    .. code-block:: python

        import pandas as pd
        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')

        # creates a BufferedArrowWriter instance
        with upload_session.open_arrow_writer() as writer:
            df = pd.DataFrame({'col1': ['test1', 'test2'], 'col2': ['id1', 'id2']})
            writer.write(df)

        # commit blocks
        upload_session.commit(writer.get_blocks_written())
    """

    def __init__(
        self,
        schema,
        request_callback,
        compress_option=None,
        buffer_size=None,
        chunk_size=None,
        block_id=None,
        block_id_gen=None,
    ):
        self._buffer_size = buffer_size or options.tunnel.block_buffer_size
        self._request_callback = request_callback
        self._block_id = block_id or 0
        self._blocks_written = []
        self._buffer = io.BytesIO()
        self._compress_option = compress_option
        self._n_bytes_written = 0
        self._block_id_gen = block_id_gen

        out = get_compress_stream(self._buffer, compress_option)
        super(BufferedArrowWriter, self).__init__(schema, out, chunk_size=chunk_size)

    @property
    def cur_block_id(self):
        """Id of the block currently being filled."""
        return self._block_id

    def _get_next_block_id(self):
        """Id for the next block: from ``block_id_gen`` if callable, else +1."""
        if callable(self._block_id_gen):
            return self._block_id_gen()
        return self._block_id + 1

    def write(self, data):
        """
        Write Arrow or pandas data; the current block is uploaded once its
        buffered size exceeds a positive buffer size.
        """
        super(BufferedArrowWriter, self).write(data)
        if 0 < self._buffer_size < self._n_raw_bytes:
            self._flush()

    def close(self):
        """
        Upload any pending data as a final block and release the buffer.
        """
        # ``_chunk_size_written`` is set by the first chunk written since the
        # last reset. Checking it catches data still held inside a compression
        # stream, for which ``_buffer.tell()`` may still be 0.
        if self._chunk_size_written or self._n_raw_bytes > 0:
            self._flush()
        self._finish()
        self._buffer.close()

    def _reset_writer(self):
        """Start a fresh, empty block after a block has been uploaded."""
        self._buffer = io.BytesIO()
        out = get_compress_stream(self._buffer, self._compress_option)
        self._re_init(out)
        self._crccrc.reset()
        self._crc.reset()

    def _send_buffer(self):
        """Upload the buffered block and return the response."""
        return self._request_callback(self._block_id, self._buffer.getvalue())

    def _flush(self):
        """Finish the current block, upload it and move to the next block id."""
        self._write_finish_tags()
        # push everything (including any compressor tail) into the buffer
        # before measuring and sending it, as BufferedRecordWriter does
        self._output.flush()
        self._n_bytes_written += self._n_raw_bytes
        self._send_buffer()
        self._blocks_written.append(self._block_id)
        self._block_id = self._get_next_block_id()
        self._reset_writer()

    @property
    def _n_raw_bytes(self):
        """Bytes currently in the block buffer."""
        return self._buffer.tell()

    @property
    def n_bytes(self):
        """Bytes of all uploaded blocks plus the current block buffer."""
        return self._n_bytes_written + self._n_raw_bytes

    def get_total_bytes(self):
        return self.n_bytes

    def get_blocks_written(self):
        """
        Get block ids created during writing. Should be provided as the argument to
        :meth:`TableUploadSession.commit`.
        """
        return self._blocks_written
class Upsert:
    """
    Object to insert or update data into an ODPS upsert table with records. Should be
    created with :meth:`TableUpsertSession.open_upsert_stream`.

    :Example:

    Here we show an example of inserting, updating and deleting data to an upsert table.

    .. code-block:: python

        from odps.tunnel import TableTunnel

        tunnel = TableTunnel(o)
        upsert_session = tunnel.create_upsert_session('my_table', partition_spec='pt=test')

        # opens an upsert stream
        stream = upsert_session.open_upsert_stream(compress=True)
        rec = upsert_session.new_record(["0", "v1"])
        stream.upsert(rec)
        rec = upsert_session.new_record(["0", "v2"])
        stream.upsert(rec)
        rec = upsert_session.new_record(["1", "v1"])
        stream.upsert(rec)
        rec = upsert_session.new_record(["2", "v1"])
        stream.upsert(rec)
        stream.delete(rec)
        stream.flush()
        stream.close()

        upsert_session.commit()
    """

    DEFAULT_MAX_BUFFER_SIZE = 64 * 1024**2
    DEFAULT_SLOT_BUFFER_SIZE = 1024**2
    # number of attempts flush() makes before giving up
    _MAX_FLUSH_ATTEMPTS = 3

    class Operation(enum.Enum):
        UPSERT = "UPSERT"
        DELETE = "DELETE"

    class Status(enum.Enum):
        NORMAL = "NORMAL"
        ERROR = "ERROR"
        CLOSED = "CLOSED"

    def __init__(
        self,
        schema,
        request_callback,
        session,
        compress_option=None,
        encoding="utf-8",
        max_buffer_size=None,
        slot_buffer_size=None,
    ):
        self._schema = schema
        self._request_callback = request_callback
        self._session = session
        self._compress_option = compress_option
        self._max_buffer_size = max_buffer_size or self.DEFAULT_MAX_BUFFER_SIZE
        self._slot_buffer_size = slot_buffer_size or self.DEFAULT_SLOT_BUFFER_SIZE
        self._total_n_bytes = 0

        self._status = Upsert.Status.NORMAL

        self._schema = session.schema
        self._encoding = encoding
        self._hash_keys = self._session.hash_keys
        self._hasher = RecordHasher(schema, self._session.hasher, self._hash_keys)
        self._buckets = self._session.buckets.copy()
        # one in-memory buffer and one record writer per bucket
        self._bucket_buffers = {}
        self._bucket_writers = {}

        for slot in session.buckets.keys():
            self._build_bucket_writer(slot)

    @property
    def status(self):
        return self._status

    @property
    def n_bytes(self):
        """Bytes written into bucket writers through this stream."""
        return self._total_n_bytes

    def upsert(self, record):
        """
        Insert or update a record.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        return self._write(record, Upsert.Operation.UPSERT)

    def delete(self, record):
        """
        Delete a record.

        :param record: record to write
        :type record: :class:`odps.models.Record`
        """
        return self._write(record, Upsert.Operation.DELETE)

    def flush(self, flush_all=True):
        """
        Flush buffered data to server.

        Buckets are uploaded in parallel; a failed upload is retried, up to
        three attempts in total. When all attempts fail the stream enters the
        ``ERROR`` status and the last exception is re-raised.

        :param flush_all: if True, upload every non-empty bucket; otherwise
            only buckets whose buffered size exceeds the slot buffer size
        :raises TunnelError: if the session slot map changed, the stream is
            not usable, or the flush was interrupted
        """
        if len(self._session.buckets) != len(self._bucket_writers):
            raise TunnelError("session slot map is changed")
        self._buckets = self._session.buckets.copy()
        self._check_status()

        bucket_written = dict()
        bucket_to_count = dict()
        # buckets whose writer has already been closed (trailer written);
        # they must be resent on retry but never closed a second time
        closed_buckets = set()

        def write_bucket(bucket_id):
            slot = self._buckets[bucket_id]
            sio = self._bucket_buffers[bucket_id]
            rec_count = bucket_to_count[bucket_id]
            self._request_callback(bucket_id, slot, rec_count, sio.getvalue())
            self._build_bucket_writer(bucket_id)
            bucket_written[bucket_id] = True

        attempts = 0
        while True:
            try:
                n_workers = max(1, len(self._bucket_writers))
                with ThreadPoolExecutor(n_workers) as pool:
                    futs = []
                    for bucket, writer in list(self._bucket_writers.items()):
                        if bucket_written.get(bucket):
                            continue
                        if bucket not in closed_buckets:
                            if writer.n_bytes == 0:
                                continue
                            if (
                                not flush_all
                                and writer.n_bytes <= self._slot_buffer_size
                            ):
                                continue
                            # close() resets the record count, so read it first
                            bucket_to_count[bucket] = writer.count
                            writer.close()
                            closed_buckets.add(bucket)
                        futs.append(pool.submit(write_bucket, bucket))
                    for fut in futs:
                        fut.result()
                break
            except KeyboardInterrupt:
                raise TunnelError("flush interrupted")
            except Exception:
                attempts += 1
                if attempts >= self._MAX_FLUSH_ATTEMPTS:
                    self._status = Upsert.Status.ERROR
                    raise

    def close(self):
        """
        Close the stream and write all data to server.
        """
        if self.status == Upsert.Status.NORMAL:
            self.flush()
            self._status = Upsert.Status.CLOSED

    def _build_bucket_writer(self, slot):
        """Create a fresh buffer and record writer for bucket ``slot``."""
        self._bucket_buffers[slot] = io.BytesIO()
        self._bucket_writers[slot] = BaseRecordWriter(
            self._schema,
            get_compress_stream(self._bucket_buffers[slot], self._compress_option),
            encoding=self._encoding,
        )

    def _check_status(self):
        if self._status == Upsert.Status.CLOSED:
            raise TunnelError("Stream is closed")
        elif self._status == Upsert.Status.ERROR:
            raise TunnelError("Stream has error")

    def _write(self, record, op, valid_columns=None):
        """
        Tag ``record`` with the operation and write it into its hash bucket.

        :param valid_columns: optional column names (or objects with a
            ``name``) whose indices are stored in the record's value-columns
            field; when None an empty list is stored
        """
        self._check_status()

        bucket = self._hasher.hash_record(record) % len(self._bucket_writers)
        if bucket not in self._bucket_writers:
            raise TunnelError(
                f"Tunnel internal error! Do not have bucket for hash key {bucket}"
            )

        record[self._session.UPSERT_OPERATION_KEY] = ord(
            b"U" if op == Upsert.Operation.UPSERT else b"D"
        )
        if valid_columns is None:
            record[self._session.UPSERT_VALUE_COLS_KEY] = []
        else:
            valid_names = {getattr(col, "name", col) for col in valid_columns}
            col_idxes = [
                idx
                for idx, col in enumerate(self._schema.columns)
                if col.name in valid_names
            ]
            record[self._session.UPSERT_VALUE_COLS_KEY] = col_idxes

        writer = self._bucket_writers[bucket]
        prev_written_size = writer.n_bytes
        writer.write(record)
        written_size = writer.n_bytes
        self._total_n_bytes += written_size - prev_written_size

        if writer.n_bytes > self._slot_buffer_size:
            self.flush(False)
        elif self._total_n_bytes > self._max_buffer_size:
            self.flush(True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # if an error occurs inside the with block, we do not commit
        if exc_val is not None:
            return
        self.close()