diff --git a/convergence-arrow/src/table.rs b/convergence-arrow/src/table.rs index 992d373..cd89418 100644 --- a/convergence-arrow/src/table.rs +++ b/convergence-arrow/src/table.rs @@ -1,12 +1,15 @@ //! Utilities for converting between Arrow and Postgres formats. +use std::str::FromStr; + use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlState}; use convergence::protocol_ext::DataRowBatch; +use datafusion::arrow::array::timezone::Tz; use datafusion::arrow::array::{ - BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array, - Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, - StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array + BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, StringArray, StringViewArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, }; use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit}; use datafusion::arrow::record_batch::RecordBatch; @@ -61,23 +64,52 @@ pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBat ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type") })?) 
} - DataType::Timestamp(unit, None) => row.write_timestamp( - match unit { - TimeUnit::Second => array_val!(TimestampSecondArray, col, row_idx, value_as_datetime), - TimeUnit::Millisecond => { - array_val!(TimestampMillisecondArray, col, row_idx, value_as_datetime) - } - TimeUnit::Microsecond => { - array_val!(TimestampMicrosecondArray, col, row_idx, value_as_datetime) + DataType::Timestamp(unit, tz) => { + match tz { + Some(tz) => { + let tz = Tz::from_str(tz.as_ref()).map_err(|_| { + ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported timezone") + })?; + let dt = match unit { + TimeUnit::Second => array_cast!(TimestampSecondArray, col) + .value_as_datetime_with_tz(row_idx, tz) + .map(|d| d.fixed_offset()), + TimeUnit::Millisecond => array_cast!(TimestampMillisecondArray, col) + .value_as_datetime_with_tz(row_idx, tz) + .map(|d| d.fixed_offset()), + TimeUnit::Microsecond => array_cast!(TimestampMicrosecondArray, col) + .value_as_datetime_with_tz(row_idx, tz) + .map(|d| d.fixed_offset()), + TimeUnit::Nanosecond => array_cast!(TimestampNanosecondArray, col) + .value_as_datetime_with_tz(row_idx, tz) + .map(|d| d.fixed_offset()), + } + .ok_or_else(|| { + ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported timestamp type") + })?; + row.write_timestamp_with_tz(dt) } - TimeUnit::Nanosecond => { - array_val!(TimestampNanosecondArray, col, row_idx, value_as_datetime) - } - } - .ok_or_else(|| { - ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported timestamp type") - })?, - ), + None => row.write_timestamp( + match unit { + TimeUnit::Second => { + array_val!(TimestampSecondArray, col, row_idx, value_as_datetime) + } + TimeUnit::Millisecond => { + array_val!(TimestampMillisecondArray, col, row_idx, value_as_datetime) + } + TimeUnit::Microsecond => { + array_val!(TimestampMicrosecondArray, col, row_idx, value_as_datetime) + } + TimeUnit::Nanosecond => { + array_val!(TimestampNanosecondArray, col, row_idx, value_as_datetime) + } + } 
+ .ok_or_else(|| { + ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported timestamp type") + })?, + ), + }; + } other => { return Err(ErrorResponse::error( SqlState::FeatureNotSupported, @@ -108,7 +140,10 @@ pub fn data_type_to_oid(ty: &DataType) -> Result { DataType::Decimal128(_, _) => DataTypeOid::Numeric, DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text, DataType::Date32 | DataType::Date64 => DataTypeOid::Date, - DataType::Timestamp(_, None) => DataTypeOid::Timestamp, + DataType::Timestamp(_, tz) => match tz { + Some(_) => DataTypeOid::Timestamptz, + None => DataTypeOid::Timestamp, + }, other => { return Err(ErrorResponse::error( SqlState::FeatureNotSupported, diff --git a/convergence-arrow/tests/test_arrow.rs b/convergence-arrow/tests/test_arrow.rs index aecf492..0862a27 100644 --- a/convergence-arrow/tests/test_arrow.rs +++ b/convergence-arrow/tests/test_arrow.rs @@ -1,16 +1,19 @@ use async_trait::async_trait; -use chrono::{NaiveDate, NaiveDateTime}; +use chrono::{DateTime, NaiveDate, NaiveDateTime}; use convergence::engine::{Engine, Portal}; use convergence::protocol::{ErrorResponse, FieldDescription}; use convergence::protocol_ext::DataRowBatch; use convergence::server::{self, BindOptions}; use convergence::sqlparser::ast::Statement; use convergence_arrow::table::{record_batch_to_rows, schema_to_field_desc}; -use datafusion::arrow::array::{ArrayRef, Date32Array, Decimal128Array, Float32Array, Int32Array, StringArray, StringViewArray, TimestampSecondArray}; +use datafusion::arrow::array::{ + ArrayRef, Date32Array, Decimal128Array, Float32Array, Int32Array, StringArray, StringViewArray, + TimestampSecondArray, +}; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::arrow::record_batch::RecordBatch; -use std::sync::Arc; use rust_decimal::Decimal; +use std::sync::Arc; use tokio_postgres::{connect, NoTls}; struct ArrowPortal { @@ -32,10 +35,17 @@ impl ArrowEngine { fn new() -> Self { let int_col = 
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; let float_col = Arc::new(Float32Array::from(vec![1.5, 2.5, 3.5])) as ArrayRef; - let decimal_col = Arc::new(Decimal128Array::from(vec![11, 22, 33]).with_precision_and_scale(2, 0).unwrap()) as ArrayRef; + let decimal_col = Arc::new( + Decimal128Array::from(vec![11, 22, 33]) + .with_precision_and_scale(2, 0) + .unwrap(), + ) as ArrayRef; let string_col = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef; let string_view_col = Arc::new(StringViewArray::from(vec!["aa", "bb", "cc"])) as ArrayRef; let ts_col = Arc::new(TimestampSecondArray::from(vec![1577836800, 1580515200, 1583020800])) as ArrayRef; + let ts_tz_col = + Arc::new(TimestampSecondArray::from(vec![1577854800, 1580533200, 1583038800]).with_timezone("+05:00")) + as ArrayRef; let date_col = Arc::new(Date32Array::from(vec![0, 1, 2])) as ArrayRef; let schema = Schema::new(vec![ @@ -45,12 +55,29 @@ impl ArrowEngine { Field::new("string_col", DataType::Utf8, true), Field::new("string_view_col", DataType::Utf8View, true), Field::new("ts_col", DataType::Timestamp(TimeUnit::Second, None), true), + Field::new( + "ts_tz_col", + DataType::Timestamp(TimeUnit::Second, Some("+05:00".into())), + true, + ), Field::new("date_col", DataType::Date32, true), ]); Self { - batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, decimal_col, string_col, string_view_col, ts_col, date_col]) - .expect("failed to create batch"), + batch: RecordBatch::try_new( + Arc::new(schema), + vec![ + int_col, + float_col, + decimal_col, + string_col, + string_view_col, + ts_col, + ts_tz_col, + date_col, + ], + ) + .expect("failed to create batch"), } } } @@ -94,8 +121,16 @@ async fn basic_data_types() { let rows = client.query("select 1", &[]).await.unwrap(); let get_row = |idx: usize| { let row = &rows[idx]; - let cols: (i32, f32, Decimal, &str, &str, NaiveDateTime, NaiveDate) = - (row.get(0), row.get(1), row.get(2), row.get(3), row.get(4), row.get(5), row.get(6)); + 
let cols: (i32, f32, Decimal, &str, &str, NaiveDateTime, DateTime<_>, NaiveDate) = ( + row.get(0), + row.get(1), + row.get(2), + row.get(3), + row.get(4), + row.get(5), + row.get(6), + row.get(7), + ); cols }; @@ -111,6 +146,7 @@ async fn basic_data_types() { .unwrap() .and_hms_opt(0, 0, 0) .unwrap(), + DateTime::from_timestamp_millis(1577854800000).unwrap(), NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(), ) ); @@ -126,6 +162,7 @@ async fn basic_data_types() { .unwrap() .and_hms_opt(0, 0, 0) .unwrap(), + DateTime::from_timestamp_millis(1580533200000).unwrap(), NaiveDate::from_ymd_opt(1970, 1, 2).unwrap() ) ); @@ -141,6 +178,7 @@ async fn basic_data_types() { .unwrap() .and_hms_opt(0, 0, 0) .unwrap(), + DateTime::from_timestamp_millis(1583038800000).unwrap(), NaiveDate::from_ymd_opt(1970, 1, 3).unwrap() ) ); diff --git a/convergence/src/protocol.rs b/convergence/src/protocol.rs index bfae419..5b6b517 100644 --- a/convergence/src/protocol.rs +++ b/convergence/src/protocol.rs @@ -79,6 +79,7 @@ data_types! { Date = 1082, 4 Timestamp = 1114, 8 + Timestamptz = 1184, 8 Text = 25, -1 } diff --git a/convergence/src/protocol_ext.rs b/convergence/src/protocol_ext.rs index f2db117..c47037b 100644 --- a/convergence/src/protocol_ext.rs +++ b/convergence/src/protocol_ext.rs @@ -2,7 +2,7 @@ use crate::protocol::{ConnectionCodec, FormatCode, ProtocolError, RowDescription}; use bytes::{BufMut, BytesMut}; -use chrono::{NaiveDate, NaiveDateTime}; +use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime}; use rust_decimal::Decimal; use tokio_postgres::types::{ToSql, Type}; use tokio_util::codec::Encoder; @@ -133,17 +133,29 @@ impl<'a> DataRowWriter<'a> { } } + /// Writes a timestamp with timezone value for the next column. 
+ pub fn write_timestamp_with_tz(&mut self, val: DateTime<FixedOffset>) { + match self.parent.format_code { + FormatCode::Binary => { + let ts_tz_type = Type::from_oid(1184).expect("failed to create timestamptz type"); + let mut buf = BytesMut::new(); + val.to_sql(&ts_tz_type, &mut buf).expect("failed to write timestamptz"); + self.write_value(&buf.freeze()) + } + FormatCode::Text => self.write_string(&val.to_string()), + } + } + /// Writes a numeric value for the next column. pub fn write_numeric_16(&mut self, val: i128, _p: &u8, s: &i8) { let decimal = Decimal::from_i128_with_scale(val, *s as u32); match self.parent.format_code { - FormatCode::Text => { - self.write_string(&decimal.to_string()) - } + FormatCode::Text => self.write_string(&decimal.to_string()), FormatCode::Binary => { let numeric_type = Type::from_oid(1700).expect("failed to create numeric type"); let mut buf = BytesMut::new(); - decimal.to_sql(&numeric_type, &mut buf) + decimal + .to_sql(&numeric_type, &mut buf) .expect("failed to write numeric"); self.write_value(&buf.freeze())