#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
from collections import OrderedDict
import requests
from .. import serializers, types
from ..compat import Enum, six
from ..config import options
from ..models import Projects, TableSchema
from .base import TUNNEL_VERSION, BaseTunnel
from .errors import TunnelError
from .io.reader import BufferedRecordReader, TunnelArrowReader, TunnelRecordReader
from .io.stream import CompressOption, get_decompress_stream
# numpy is an optional dependency: ``np`` is bound to None when it cannot
# be imported, so callers must check for None before using it.
try:
    import numpy as np
except ImportError:
    np = None

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# [docs]
class InstanceDownloadSession(serializers.JSONSerializableModel):
    """
    Tunnel session for downloading data from instance results. Instances of
    this class should be created by :meth:`InstanceTunnel.create_download_session`.
    You may get the id of the session for reuse by attribute ``id`` of
    the session object.
    """

    __slots__ = (
        "_client",
        "_instance",
        "_limit_enabled",
        "_compress_option",
        "_sessional",
        "_session_task_name",
        "_session_subquery_id",
        "_quota_name",
        "_timeout",
        "_tags",
    )

    class Status(Enum):
        Unknown = "UNKNOWN"
        Normal = "NORMAL"
        Closes = "CLOSES"
        Expired = "EXPIRED"
        Failed = "FAILED"
        Initiating = "INITIATING"

    # Fields populated from the JSON body of the tunnel service response.
    id = serializers.JSONNodeField("DownloadID")
    status = serializers.JSONNodeField(
        "Status", parse_callback=lambda s: InstanceDownloadSession.Status(s.upper())
    )
    count = serializers.JSONNodeField("RecordCount")
    schema = serializers.JSONNodeReferenceField(TableSchema, "Schema")
    quota_name = serializers.JSONNodeField("QuotaName")
    support_read_by_raw_size = serializers.JSONNodeField(
        "SupportReadByRawSize", default=None
    )

    def __init__(
        self,
        client,
        instance,
        download_id=None,
        limit=None,
        compress_option=None,
        quota_name=None,
        timeout=None,
        tags=None,
        **kw
    ):
        """
        Create a new download session or attach to an existing one.

        :param client: REST client used to talk to the tunnel service
        :param instance: instance object whose results are downloaded
        :param str download_id: id of an existing session. When ``None``
            a new session is created, otherwise the session is reloaded.
        :param limit: flag enabling the instance tunnel limit. When ``None``
            the ``limit_enabled`` keyword argument is used instead.
        :param compress_option: compress option used when reading data
        :param str quota_name: name of tunnel quota
        :param timeout: timeout passed to the session-creating request
        :param tags: tags sent with every request, as a list or a
            comma-separated string. Defaults to ``options.tunnel.tags``.
        :type tags: str | list
        :raises TunnelError: if ``sessional`` is set without both
            ``session_task_name`` and ``session_subquery_id``.
        """
        super(InstanceDownloadSession, self).__init__()

        self._client = client
        self._instance = instance
        self._limit_enabled = (
            limit if limit is not None else kw.get("limit_enabled", False)
        )
        self._quota_name = quota_name
        self._sessional = kw.pop("sessional", False)
        self._session_task_name = kw.pop("session_task_name", "")
        self._session_subquery_id = int(kw.pop("session_subquery_id", -1))
        self._timeout = timeout

        if self._sessional and (
            (not self._session_task_name) or (self._session_subquery_id == -1)
        ):
            raise TunnelError(
                "Taskname('session_task_name') and Subquery ID ('session_subquery_id') "
                "keyword argument must be provided for session instance tunnels."
            )

        self._tags = tags or options.tunnel.tags
        if isinstance(self._tags, six.string_types):
            self._tags = self._tags.split(",")

        if download_id is None:
            self._init()
        else:
            self.id = download_id
            self.reload()
        self._compress_option = compress_option

        logger.info("Tunnel session created: %r", self)
        if options.tunnel_session_create_callback:
            options.tunnel_session_create_callback(self)

    def __repr__(self):
        # Entries whose value is None are left out of the representation.
        fields = [
            ("id", self.id),
            ("project_name", self._instance.project.name),
            ("instance_id", self._instance.id),
            ("subquery_id", self._session_subquery_id if self._sessional else None),
            ("limited", self._limit_enabled or None),
        ]
        return "<InstanceDownloadSession %s>" % " ".join(
            "%s=%s" % (k, v) for k, v in fields if v is not None
        )

    def _init(self):
        """
        Create the download session on the server side.

        Nothing is sent for sessional instances, see the comment below.

        :raises TunnelError: if the service returns an error response
        """
        params = {}
        headers = {
            "Content-Length": 0,
            "x-odps-tunnel-version": TUNNEL_VERSION,
        }
        if self._tags:
            headers["odps-tunnel-tags"] = ",".join(self._tags)
        if self._quota_name is not None:
            params["quotaName"] = self._quota_name

        # Session results are fetched with DirectDownloadMode (any other
        # method is removed). That mode needs only one request, so no
        # request is sent here for sessional instances.
        if not self._sessional:
            if self._limit_enabled:
                params["instance_tunnel_limit_enabled"] = ""
            url = self._instance.resource()
            try:
                resp = self._client.post(
                    url,
                    {},
                    action="downloads",
                    params=params,
                    headers=headers,
                    timeout=self._timeout,
                )
            except requests.exceptions.ReadTimeout:
                # give the user-configured hook a chance to observe the
                # timeout, then re-raise the original exception
                if callable(options.tunnel_session_create_timeout_callback):
                    options.tunnel_session_create_timeout_callback(*sys.exc_info())
                raise

            if self._client.is_ok(resp):
                self.parse(resp, obj=self)
                if self.schema is not None:
                    self.schema.build_snapshot()
            else:
                raise TunnelError.parse(resp)

    def reload(self):
        """
        Refresh the session from the server using its download id.

        For sessional instances no request is made and the status is
        simply set to ``Status.Normal``.

        :raises TunnelError: if the service returns an error response
        """
        if not self._sessional:
            params = {"downloadid": self.id}
            if self._quota_name is not None:
                params["quotaName"] = self._quota_name
            headers = {
                "Content-Length": 0,
                "x-odps-tunnel-version": TUNNEL_VERSION,
            }
            if self._tags:
                headers["odps-tunnel-tags"] = ",".join(self._tags)

            url = self._instance.resource()
            resp = self._client.get(url, params=params, headers=headers)
            if self._client.is_ok(resp):
                self.parse(resp, obj=self)
                if self.schema is not None:
                    self.schema.build_snapshot()
            else:
                raise TunnelError.parse(resp)
        else:
            self.status = InstanceDownloadSession.Status.Normal

    def _build_input_stream(
        self, start, count, compress=False, columns=None, arrow=False, raw_size=None
    ):
        """
        Send a data request and return the (optionally decompressing) stream.

        :param start: start row index, used for non-sessional instances
        :param count: number of rows to read, used for non-sessional instances
        :param bool compress: whether to ask the server for compressed data
        :param columns: column names or :class:`types.Column` objects to read
        :param bool arrow: whether to request data in arrow format
        :param raw_size: if truthy, sent as the ``raw_size`` request parameter
        :return: result of :func:`get_decompress_stream` on the response
        :raises TunnelError: if the service returns an error response
        """
        compress_option = self._compress_option or CompressOption()

        params = {}
        headers = {"x-odps-tunnel-version": TUNNEL_VERSION}
        if self._quota_name is not None:
            params["quotaName"] = self._quota_name
        if self._tags:
            headers["odps-tunnel-tags"] = ",".join(self._tags)

        if self._sessional:
            params["cached"] = ""
            params["taskname"] = self._session_task_name
            params["queryid"] = str(self._session_subquery_id)
        else:
            params["downloadid"] = self.id
            params["rowrange"] = "(%s,%s)" % (start, count)
        headers["Content-Length"] = 0

        if compress:
            encoding = compress_option.algorithm.get_encoding()
            if encoding is not None:
                headers["Accept-Encoding"] = encoding

        params["data"] = ""
        if columns is not None and len(columns) > 0:
            # columns may be given either as Column objects or plain names
            params["columns"] = ",".join(
                col.name if isinstance(col, types.Column) else col for col in columns
            )
        if arrow:
            params["arrow"] = ""
        if raw_size:
            params["raw_size"] = str(raw_size)

        url = self._instance.resource()
        resp = self._client.get(url, stream=True, params=params, headers=headers)
        if not self._client.is_ok(resp):
            raise TunnelError.parse(resp)

        if self._sessional:
            # in DirectDownloadMode, the schema is brought back in HEADER.
            # handle this.
            schema_json = resp.headers.get("odps-tunnel-schema")
            self.schema = TableSchema()
            self.schema = self.schema.deserial(schema_json)

        # Trust what the server actually sent rather than what was asked for:
        # the response's Content-Encoding decides whether and how to decompress.
        content_encoding = resp.headers.get("Content-Encoding")
        if content_encoding is not None:
            compress_algo = CompressOption.CompressAlgorithm.from_encoding(
                content_encoding
            )
            if compress_algo != compress_option.algorithm:
                compress_option = self._compress_option = CompressOption(
                    compress_algo, -1, 0
                )
            compress = True
        else:
            compress = False

        option = compress_option if compress else None
        return get_decompress_stream(resp, option)

    def _open_reader(
        self,
        start,
        count,
        compress=False,
        columns=None,
        arrow=False,
        on_exception=None,
        reader_cls=None,
        **kw
    ):
        """
        Build a reader of type ``reader_cls`` over rows ``[start, start + count)``.

        The reader receives a ``stream_creator`` callable that opens a new
        input stream starting at a given cursor (row offset relative to
        ``start``).

        :param start: start row index
        :param count: number of rows to read, or ``None`` for no explicit bound
        :param bool compress: whether to compress data
        :param columns: list of column names to read
        :param bool arrow: whether to request data in arrow format
        :param on_exception: error handler forwarded to the reader
        :param reader_cls: reader class to instantiate
        :param kw: extra keyword arguments forwarded to ``reader_cls``
        :return: an instance of ``reader_cls``
        """
        stream_kw = dict(compress=compress, columns=columns, arrow=arrow)
        # single-slot holder so the nested function can rebind the cached stream
        initial_stream_cache = [None]

        def stream_creator(cursor, cache=False, row_number=None, raw_size=None):
            # hand out the pre-created first stream exactly once
            if cursor == 0 and initial_stream_cache[0] is not None:
                initial_stream_cache[0], stream = None, initial_stream_cache[0]
                return stream

            attempt_count = count - cursor if count is not None else None
            # ``count`` may be None; comparing None with an int raises
            # TypeError on Python 3, so only test a known row count.
            if attempt_count is not None and attempt_count <= 0:
                return None
            if row_number is not None:
                attempt_count = (
                    row_number
                    if attempt_count is None
                    else min(attempt_count, row_number)
                )

            stream = self._build_input_stream(
                start + cursor, attempt_count, raw_size=raw_size, **stream_kw
            )
            if cache:
                initial_stream_cache[0] = stream
            return stream

        # for MCQA we must obtain schema from the first stream, hence the first reader
        # must be created beforehand and then cached for the reader class
        stream_creator(0, True)

        return reader_cls(
            self.schema,
            stream_creator,
            columns=columns,
            on_exception=on_exception,
            **kw
        )

    def open_record_reader(
        self,
        start,
        count,
        compress=False,
        columns=None,
        buffered=False,
        buffer_size=None,
        row_batch_size=None,
        on_exception=None,
        **_
    ):
        """
        Open a reader to read data as records from the tunnel.

        :param int start: start row index
        :param int count: number of rows to read
        :param bool compress: whether to compress data
        :param columns: list of column names to read
        :param bool buffered: whether to use buffered reader
        :param int buffer_size: download buffer size in bytes. Num of rows read in every batch
            will be limited by this parameter as well as `row_batch_size`.
        :param int row_batch_size: number of rows to read per batch. Num of rows read in every
            batch will be limited by this parameter as well as `buffer_size`.
        :param on_exception: custom error handling function accepting
            an Exception instance as input. If return value is True,
            error will be raised. Otherwise retry will continue.
        :return: a record reader
        :rtype: :class:`TunnelRecordReader`
        """
        reader_cls = BufferedRecordReader if buffered else TunnelRecordReader
        kw = {}
        if buffer_size:
            kw["buffer_size"] = buffer_size
        if row_batch_size:
            kw["row_batch_size"] = row_batch_size
        if buffered:
            kw["session_with_byte_size_limit"] = self.support_read_by_raw_size
        return self._open_reader(
            start,
            count,
            compress=compress,
            columns=columns,
            on_exception=on_exception,
            reader_cls=reader_cls,
            **kw
        )

    def open_arrow_reader(
        self, start, count, compress=False, columns=None, on_exception=None, **_
    ):
        """
        Open a reader to read data as arrow format from the tunnel.

        :param int start: start row index
        :param int count: number of rows to read
        :param bool compress: whether to compress data
        :param columns: list of column names to read
        :param on_exception: custom error handling function accepting
            an Exception instance as input. If return value is True,
            error will be raised. Otherwise retry will continue.
        :return: an arrow reader
        :rtype: :class:`TunnelArrowReader`
        """
        return self._open_reader(
            start,
            count,
            compress=compress,
            columns=columns,
            arrow=True,
            on_exception=on_exception,
            reader_cls=TunnelArrowReader,
        )
# [docs]
class InstanceTunnel(BaseTunnel):
    """
    Instance tunnel API Entry.

    :param odps: ODPS Entry object
    :param str project: project name
    :param str endpoint: tunnel endpoint
    :param str quota_name: name of tunnel quota
    """

    def create_download_session(
        self,
        instance,
        download_id=None,
        limit=None,
        compress_option=None,
        compress_algo=None,
        compress_level=None,
        compress_strategy=None,
        timeout=None,
        tags=None,
        **kw
    ):
        """
        Create a download session for instance results.

        :param instance: instance object to read
        :type instance: str | :class:`odps.models.Instance`
        :param str download_id: existing download id
        :param limit: passed to the session as its ``limit`` argument. When
            ``None``, the ``limit_enabled`` keyword argument (default
            ``False``) is used instead.
        :param compress_option: compress option. Takes precedence over
            ``compress_algo``, ``compress_level`` and ``compress_strategy``.
        :type compress_option: :class:`odps.tunnel.CompressOption`
        :param str compress_algo: compress algorithm, used to build a
            compress option when ``compress_option`` is not given
        :param int compress_level: compress level
        :param compress_strategy: compress strategy
        :param timeout: timeout forwarded to the download session
        :param tags: tags of the download session
        :type tags: str | list
        :return: :class:`InstanceDownloadSession`
        """
        # Accept either an instance id or an instance object, then re-resolve
        # it through the tunnel REST client.
        if not isinstance(instance, six.string_types):
            instance = instance.id
        instance = Projects(client=self.tunnel_rest)[self._project.name].instances[
            instance
        ]

        if compress_option is None and compress_algo is not None:
            compress_option = CompressOption(
                compress_algo=compress_algo,
                level=compress_level,
                strategy=compress_strategy,
            )

        if limit is None:
            limit = kw.get("limit_enabled", False)

        return InstanceDownloadSession(
            self.tunnel_rest,
            instance,
            download_id=download_id,
            limit=limit,
            compress_option=compress_option,
            quota_name=self._quota_name,
            timeout=timeout,
            tags=tags,
            **kw
        )