Dataframe#

Dataframe methods#

class pystarburst.dataframe.DataFrame(session: Session | None = None, plan: TrinoPlan | None = None, is_cached: bool = False)#

Bases: object

Represents a lazily-evaluated relational dataset that contains a collection of Row objects with columns defined by a schema (column name and type).

A DataFrame is considered lazy because it encapsulates the computation or query required to produce a relational dataset. The computation is not performed until you call a method that performs an action (e.g. collect()).

Creating a DataFrame

You can create a DataFrame in a number of different ways, as shown in the examples below.

Creating tables and data to run the sample code:
>>> session.sql("create table prices(product_id varchar, amount decimal(10, 2))").collect()
[]
>>> session.sql("insert into prices values ('id1', 10.0), ('id2', 20.0)").collect()
[]
>>> session.sql("create table product_catalog(id varchar, name varchar)").collect()
[]
>>> session.sql("insert into prices values ('id1', 'Product A'), ('id2', 'Product B')").collect()
[]
Example 1

Creating a DataFrame by reading a table in Trino:

>>> df_prices = session.table("prices")
>>> df_catalog = session.table("product_catalog")
Example 2

Creating a DataFrame by specifying a sequence or a range:

>>> session.create_dataframe([(1, "one"), (2, "two")], schema=["col_a", "col_b"]).show()
---------------------
|"COL_A"  |"COL_B"  |
---------------------
|1        |one      |
|2        |two      |
---------------------

>>> session.range(1, 10, 2).to_df("col1").show()
----------
|"COL1"  |
----------
|1       |
|3       |
|5       |
|7       |
|9       |
----------
Example 3

Create a new DataFrame by applying transformations to other existing DataFrames:

>>> df_merged_data = df_catalog.join(df_prices, df_catalog["id"] == df_prices["product_id"])

Performing operations on a DataFrame

Broadly, the operations on DataFrame can be divided into two types:

  • Transformations produce a new DataFrame from one or more existing DataFrames. Note that transformations are lazy and don’t cause the DataFrame to be evaluated. If the API does not provide a method to express the SQL that you want to use, you can use functions.sqlExpr() as a workaround (see the sketch after this list).

  • Actions cause the DataFrame to be evaluated. When you call a method that performs an action, PyStarburst sends the SQL query for the DataFrame to the server for evaluation.
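
For example, a predicate that is easier to write directly in SQL can be wrapped with functions.sqlExpr(). A minimal sketch against the prices table created above, assuming sqlExpr is importable from pystarburst.functions:

>>> from pystarburst.functions import sqlExpr
>>> df_prices.filter(sqlExpr("amount > 15")).collect()
[Row(PRODUCT_ID='id2', AMOUNT=Decimal('20.00'))]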

Transforming a DataFrame

The following examples demonstrate how you can transform a DataFrame.

Example 4

Using the select() method to select the columns that should be in the DataFrame (similar to adding a SELECT clause):

>>> # Return a new DataFrame containing the product_id and amount columns of the prices table.
>>> # This is equivalent to: SELECT PRODUCT_ID, AMOUNT FROM PRICES;
>>> df_price_ids_and_amounts = df_prices.select(col("product_id"), col("amount"))
Example 5

Using the Column.as_() method to rename a column in a DataFrame (similar to using SELECT col AS alias):

>>> # Return a new DataFrame containing the product_id column of the prices table as a column named
>>> # item_id. This is equivalent to: SELECT PRODUCT_ID AS ITEM_ID FROM PRICES;
>>> df_price_item_ids = df_prices.select(col("product_id").as_("item_id"))
Example 6

Using the filter() method to filter data (similar to adding a WHERE clause):

>>> # Return a new DataFrame containing the row from the prices table with the ID 'id1'.
>>> # This is equivalent to:
>>> # SELECT * FROM PRICES WHERE PRODUCT_ID = 'id1';
>>> df_price1 = df_prices.filter(col("product_id") == "id1")
Example 7

Using the sort() method to specify the sort order of the data (similar to adding an ORDER BY clause):

>>> # Return a new DataFrame for the prices table with the rows sorted by product_id.
>>> # This is equivalent to: SELECT * FROM PRICES ORDER BY PRODUCT_ID;
>>> df_sorted_prices = df_prices.sort(col("product_id"))
Example 8

Using the agg() method to aggregate results:

>>> import pystarburst.functions as f
>>> df_prices.agg(("amount", "sum")).collect()
[Row(SUM(AMOUNT)=Decimal('30.00'))]
>>> df_prices.agg(f.sum("amount")).collect()
[Row(SUM(AMOUNT)=Decimal('30.00'))]
>>> # rename the aggregation column name
>>> df_prices.agg(f.sum("amount").alias("total_amount"), f.max("amount").alias("max_amount")).collect()
[Row(TOTAL_AMOUNT=Decimal('30.00'), MAX_AMOUNT=Decimal('20.00'))]
Example 9

Using the group_by() method to return a RelationalGroupedDataFrame that you can use to group and aggregate results (similar to adding a GROUP BY clause).

RelationalGroupedDataFrame provides methods for aggregating results, including:

  • RelationalGroupedDataFrame.avg() (equivalent to AVG(column))

  • RelationalGroupedDataFrame.count() (equivalent to COUNT())

  • RelationalGroupedDataFrame.max() (equivalent to MAX(column))

  • RelationalGroupedDataFrame.min() (equivalent to MIN(column))

  • RelationalGroupedDataFrame.sum() (equivalent to SUM(column))

>>> # Return a new DataFrame for the prices table that computes the sum of the prices by
>>> # product ID. This is equivalent to:
>>> #  SELECT PRODUCT_ID, SUM(AMOUNT) FROM PRICES GROUP BY PRODUCT_ID
>>> df_total_price_per_product = df_prices.group_by(col("product_id")).sum(col("amount"))
>>> # Have multiple aggregation values with the group by
>>> import pystarburst.functions as f
>>> df_summary = df_prices.group_by(col("product_id")).agg(f.sum(col("amount")).alias("total_amount"), f.avg("amount"))
>>> df_summary.show()
-------------------------------------------------
|"PRODUCT_ID"  |"TOTAL_AMOUNT"  |"AVG(AMOUNT)"  |
-------------------------------------------------
|id1           |10.00           |10.00000000    |
|id2           |20.00           |20.00000000    |
-------------------------------------------------
Example 10

Using windowing functions. Refer to Window for more details.

>>> from pystarburst import Window
>>> from pystarburst.functions import row_number
>>> df_prices.with_column("price_rank",  row_number().over(Window.order_by(col("amount").desc()))).show()
------------------------------------------
|"PRODUCT_ID"  |"AMOUNT"  |"PRICE_RANK"  |
------------------------------------------
|id2           |20.00     |1             |
|id1           |10.00     |2             |
------------------------------------------
Example 11

Handling missing values. Refer to DataFrameNaFunctions for more details.

>>> df = session.create_dataframe([[1, None, 3], [4, 5, None]], schema=["a", "b", "c"])
>>> df.na.fill({"b": 2, "c": 6}).show()
-------------------
|"A"  |"B"  |"C"  |
-------------------
|1    |2    |3    |
|4    |5    |6    |
-------------------

Performing an action on a DataFrame

The following examples demonstrate how you can perform an action on a DataFrame.

Example 12

Performing a query and returning an array of Rows:

>>> df_prices.collect()
[Row(PRODUCT_ID='id1', AMOUNT=Decimal('10.00')), Row(PRODUCT_ID='id2', AMOUNT=Decimal('20.00'))]
Example 13

Performing a query and printing the results:

>>> df_prices.show()
---------------------------
|"PRODUCT_ID"  |"AMOUNT"  |
---------------------------
|id1           |10.00     |
|id2           |20.00     |
---------------------------
Example 14

Calculating statistics values. Refer to DataFrameStatFunctions for more details.

>>> df = session.create_dataframe([[1, 2], [3, 4], [5, -1]], schema=["a", "b"])
>>> df.stat.corr("a", "b")
-0.5960395606792697
agg(*exprs: Column | Tuple[ColumnOrName, str] | Dict[str, str]) DataFrame#

Aggregate the data in the DataFrame. Use this method if you don’t need to group the data (group_by()).

Parameters:

exprs

A variable length arguments list where every element is

  • a Column object

  • a tuple where the first element is a column object or a column name and the second element is the name of the aggregate function

  • a list of the above

  • a dict that maps column names to aggregate function names.

Examples

>>> from pystarburst.functions import col, stddev, stddev_pop

>>> df = session.create_dataframe([[1, 2], [3, 4], [1, 4]], schema=["A", "B"])
>>> df.agg(stddev(col("a"))).show()
----------------------
|"STDDEV(A)"         |
----------------------
|1.1547003940416753  |
----------------------

>>> df.agg(stddev(col("a")), stddev_pop(col("a"))).show()
-------------------------------------------
|"STDDEV(A)"         |"STDDEV_POP(A)"     |
-------------------------------------------
|1.1547003940416753  |0.9428091005076267  |
-------------------------------------------

>>> df.agg(("a", "min"), ("b", "max")).show()
-----------------------
|"MIN(A)"  |"MAX(B)"  |
-----------------------
|1         |4         |
-----------------------

>>> df.agg({"a": "count", "b": "sum"}).show()
-------------------------
|"COUNT(A)"  |"SUM(B)"  |
-------------------------
|3           |10        |
-------------------------

Note

The name of the aggregate function to compute must be a valid Trino aggregate function.

alias(alias: str) DataFrame#

Returns a new DataFrame with an alias set.

Parameters:

alias (str) – an alias name to be set for the DataFrame.

Examples

>>> from pystarburst.functions import *
>>> df = session.create_dataframe([[2, "Alice"], [5, "Bob"]], schema=["age", "name"])
>>> df_as1 = df.alias("df_as1")
>>> df_as2 = df.alias("df_as2")
>>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
>>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").sort(desc("df_as1.name")).collect()
[Row(name='Bob', name='Bob', age=5), Row(name='Alice', name='Alice', age=2)]
approxQuantile(col: ColumnOrName | Iterable[ColumnOrName], percentile: Iterable[float], *, statement_properties: Dict[str, str] | None = None) List[float] | List[List[float]]#

For a specified numeric column and a list of desired quantiles, returns an approximate value for the column at each of the desired quantiles. This function uses the t-Digest algorithm.

approxQuantile() is an alias of approx_quantile().

Parameters:
  • col – The name of the numeric column.

  • percentile – A list of float values greater than or equal to 0.0 and less than or equal to 1.0.

Returns:

A list of approximate percentile values if col is a single column name, or a matrix with dimensions (len(col) * len(percentile)) containing the approximate percentile values if col is a list of column names.

Examples

>>> df = session.create_dataframe([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], schema=["a"])
>>> df.stat.approx_quantile("a", [0, 0.1, 0.4, 0.6, 1])
[-0.5, 0.5, 3.5, 5.5, 9.5]

>>> df2 = session.create_dataframe([[0.1, 0.5], [0.2, 0.6], [0.3, 0.7]], schema=["a", "b"])
>>> df2.stat.approx_quantile(["a", "b"], [0, 0.1, 0.6])
[[0.05, 0.15000000000000002, 0.25], [0.45, 0.55, 0.6499999999999999]]
approx_quantile(col: ColumnOrName | Iterable[ColumnOrName], percentile: Iterable[float], *, statement_properties: Dict[str, str] | None = None) List[float] | List[List[float]]#

For a specified numeric column and a list of desired quantiles, returns an approximate value for the column at each of the desired quantiles. This function uses the t-Digest algorithm.

approxQuantile() is an alias of approx_quantile().

Parameters:
  • col – The name of the numeric column.

  • percentile – A list of float values greater than or equal to 0.0 and less than or equal to 1.0.

Returns:

A list of approximate percentile values if col is a single column name, or a matrix with dimensions (len(col) * len(percentile)) containing the approximate percentile values if col is a list of column names.

Examples

>>> df = session.create_dataframe([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], schema=["a"])
>>> df.stat.approx_quantile("a", [0, 0.1, 0.4, 0.6, 1])
[-0.5, 0.5, 3.5, 5.5, 9.5]

>>> df2 = session.create_dataframe([[0.1, 0.5], [0.2, 0.6], [0.3, 0.7]], schema=["a", "b"])
>>> df2.stat.approx_quantile(["a", "b"], [0, 0.1, 0.6])
[[0.05, 0.15000000000000002, 0.25], [0.45, 0.55, 0.6499999999999999]]
cache_result(*, statement_properties: Dict[str, str] | None = None) Table#

Caches the content of this DataFrame to create a new cached Table DataFrame.

All subsequent operations on the returned cached DataFrame are performed on the cached data and have no effect on the original DataFrame.

You can use Table.drop_table() or the with statement to clean up the cached result when it’s not needed. Refer to the example code below.

Note

An error is thrown if a cached result is cleaned up and then used again, or if any other DataFrame derived from the cached result is used again.

Examples

>>> create_result = session.sql("create temp table RESULT (NUM int)").collect()
>>> insert_result = session.sql("insert into RESULT values(1),(2)").collect()
>>> df = session.table("RESULT")
>>> df.collect()
[Row(NUM=1), Row(NUM=2)]
>>> # Run cache_result and then insert into the original table to see
>>> # that the cached result is not affected
>>> df1 = df.cache_result()
>>> insert_again_result = session.sql("insert into RESULT values (3)").collect()
>>> df1.collect()
[Row(NUM=1), Row(NUM=2)]
>>> df.collect()
[Row(NUM=1), Row(NUM=2), Row(NUM=3)]
>>> # You can run cache_result on a result that has already been cached
>>> df2 = df1.cache_result()
>>> df2.collect()
[Row(NUM=1), Row(NUM=2)]
>>> df3 = df.cache_result()
>>> # Drop RESULT and see that the cached results still exist
>>> drop_table_result = session.sql(f"drop table RESULT").collect()
>>> df1.collect()
[Row(NUM=1), Row(NUM=2)]
>>> df2.collect()
[Row(NUM=1), Row(NUM=2)]
>>> df3.collect()
[Row(NUM=1), Row(NUM=2), Row(NUM=3)]
>>> # Clean up the cached result
>>> df3.drop_table()
>>> # Use a context manager to clean up the cached result after its use.
>>> with df2.cache_result() as df4:
...     df4.collect()
[Row(NUM=1), Row(NUM=2)]
Returns:

A Table object that holds the cached result in a temporary table. All operations on this new DataFrame have no effect on the original.

col(col_name: str) Column#

Returns a reference to a column in the DataFrame.
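
Examples

A minimal usage sketch with illustrative data:

>>> df = session.create_dataframe([[1, 2]], schema=["a", "b"])
>>> df.select(df.col("a")).collect()
[Row(A=1)]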

colRegex(regex: str) DataFrame#

Selects the columns whose names match the specified regular expression and returns them as a new DataFrame.

colRegex() is an alias of col_regex().

Parameters:

regex – the regular expression to match column names against.

col_regex(regex: str) DataFrame#

Selects the columns whose names match the specified regular expression and returns them as a new DataFrame.

colRegex() is an alias of col_regex().

Parameters:

regex – the regular expression to match column names against.
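
Examples

A minimal sketch; the exact pattern-matching semantics (full versus partial match) follow the underlying engine, so the hypothetical pattern below assumes a match against the lower-cased column names:

>>> df = session.create_dataframe([[1, 2, 3]], schema=["col1", "col2", "val"])
>>> df.col_regex("col.*").show()
-------------------
|"COL1"  |"COL2"  |
-------------------
|1       |2       |
-------------------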

collect(*, statement_properties: Dict[str, str] | None = None) List[Row]#

Executes the query representing this DataFrame and returns the result as a list of Row objects.
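
Examples

A minimal sketch; collect() triggers execution and materializes the rows:

>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df.collect()
[Row(A=1, B=2), Row(A=3, B=4)]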

property columns: List[str]#

Returns all column names as a list.

The returned column names are consistent with the Trino identifier syntax.

--------------------------------------------------------------------
|Column name used to create a table  |Column name returned in str  |
--------------------------------------------------------------------
|a                                   |'a'                          |
|A                                   |'a'                          |
|"a"                                 |'a'                          |
|"a b"                               |'"a b"'                      |
|"a""b"                              |'"a""b"'                     |
--------------------------------------------------------------------
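
Examples

A minimal sketch of the normalization described in the table above:

>>> df = session.create_dataframe([[1, 2]], schema=["a", "B"])
>>> df.columns
['a', 'b']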

corr(col1: ColumnOrName, col2: ColumnOrName, *, statement_properties: Dict[str, str] | None = None) float | None#

Calculates the correlation coefficient for non-null pairs in two numeric columns.

Parameters:
  • col1 – The name of the first numeric column to use.

  • col2 – The name of the second numeric column to use.

Returns:

The correlation of the two numeric columns. If there is not enough data to generate the correlation, the method returns None.

Examples

>>> df = session.create_dataframe([[0.1, 0.5], [0.2, 0.6], [0.3, 0.7]], schema=["a", "b"])
>>> df.stat.corr("a", "b")
0.9999999999999991
count(*, statement_properties: Dict[str, str] | None = None) int#

Executes the query representing this DataFrame and returns the number of rows in the result (similar to the COUNT function in SQL).
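
Examples

A minimal sketch; count() triggers execution and returns a Python int:

>>> df = session.create_dataframe([[1], [2], [3]], schema=["a"])
>>> df.count()
3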

cov(col1: ColumnOrName, col2: ColumnOrName, *, statement_properties: Dict[str, str] | None = None) float | None#

Calculates the sample covariance for non-null pairs in two numeric columns.

Parameters:
  • col1 – The name of the first numeric column to use.

  • col2 – The name of the second numeric column to use.

Returns:

The sample covariance of the two numeric columns. If there is not enough data to generate the covariance, the method returns None.

Examples

>>> df = session.create_dataframe([[0.1, 0.5], [0.2, 0.6], [0.3, 0.7]], schema=["a", "b"])
>>> df.stat.cov("a", "b")
0.010000000000000037
createOrReplaceView(name: str | Iterable[str]) List[Row]#

Creates a view that captures the computation expressed by this DataFrame.

For name, you can include the database and schema name (i.e. specify a fully-qualified name). If no database or schema name is specified, the view is created in the current database and schema.

name must be a valid Trino identifier.

createOrReplaceView() is an alias of create_or_replace_view().

Parameters:

name – The name of the view to create or replace. Can be a list of strings that specifies the database name, schema name, and view name.

create_or_replace_view(name: str | Iterable[str]) List[Row]#

Creates a view that captures the computation expressed by this DataFrame.

For name, you can include the database and schema name (i.e. specify a fully-qualified name). If no database or schema name is specified, the view is created in the current database and schema.

name must be a valid Trino identifier.

createOrReplaceView() is an alias of create_or_replace_view().

Parameters:

name – The name of the view to create or replace. Can be a list of strings that specifies the database name, schema name, and view name.
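
Examples

A minimal sketch; prices_view is a hypothetical view name, and the empty result list mirrors the DDL examples above:

>>> df = session.table("prices")
>>> df.create_or_replace_view("prices_view")
[]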

crossJoin(right: DataFrame, *, lsuffix: str = '', rsuffix: str = '') DataFrame#

Performs a cross join, which returns the Cartesian product of the current DataFrame and another DataFrame (right).

If the current and right DataFrames have columns with the same name, and you need to refer to one of these columns in the returned DataFrame, use the col() function on the current or right DataFrame to disambiguate references to these columns.

crossJoin() is an alias of cross_join().

Parameters:
  • right – the right DataFrame to join.

  • lsuffix – Suffix to add to the overlapping columns of the left DataFrame.

  • rsuffix – Suffix to add to the overlapping columns of the right DataFrame.

Note

If both lsuffix and rsuffix are empty, the overlapping columns will have random column names in the result DataFrame. If either one is not empty, the overlapping columns won’t have random names.

Examples

>>> df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[5, 6], [7, 8]], schema=["c", "d"])
>>> df1.cross_join(df2).sort("a", "b", "c", "d").show()
-------------------------
|"A"  |"B"  |"C"  |"D"  |
-------------------------
|1    |2    |5    |6    |
|1    |2    |7    |8    |
|3    |4    |5    |6    |
|3    |4    |7    |8    |
-------------------------

>>> df3 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df4 = session.create_dataframe([[5, 6], [7, 8]], schema=["a", "b"])
>>> df3.cross_join(df4, lsuffix="_l", rsuffix="_r").sort("a_l", "b_l", "a_r", "b_r").show()
---------------------------------
|"A_L"  |"B_L"  |"A_R"  |"B_R"  |
---------------------------------
|1      |2      |5      |6      |
|1      |2      |7      |8      |
|3      |4      |5      |6      |
|3      |4      |7      |8      |
---------------------------------
cross_join(right: DataFrame, *, lsuffix: str = '', rsuffix: str = '') DataFrame#

Performs a cross join, which returns the Cartesian product of the current DataFrame and another DataFrame (right).

If the current and right DataFrames have columns with the same name, and you need to refer to one of these columns in the returned DataFrame, use the col() function on the current or right DataFrame to disambiguate references to these columns.

crossJoin() is an alias of cross_join().

Parameters:
  • right – the right DataFrame to join.

  • lsuffix – Suffix to add to the overlapping columns of the left DataFrame.

  • rsuffix – Suffix to add to the overlapping columns of the right DataFrame.

Note

If both lsuffix and rsuffix are empty, the overlapping columns will have random column names in the result DataFrame. If either one is not empty, the overlapping columns won’t have random names.

Examples

>>> df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[5, 6], [7, 8]], schema=["c", "d"])
>>> df1.cross_join(df2).sort("a", "b", "c", "d").show()
-------------------------
|"A"  |"B"  |"C"  |"D"  |
-------------------------
|1    |2    |5    |6    |
|1    |2    |7    |8    |
|3    |4    |5    |6    |
|3    |4    |7    |8    |
-------------------------

>>> df3 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df4 = session.create_dataframe([[5, 6], [7, 8]], schema=["a", "b"])
>>> df3.cross_join(df4, lsuffix="_l", rsuffix="_r").sort("a_l", "b_l", "a_r", "b_r").show()
---------------------------------
|"A_L"  |"B_L"  |"A_R"  |"B_R"  |
---------------------------------
|1      |2      |5      |6      |
|1      |2      |7      |8      |
|3      |4      |5      |6      |
|3      |4      |7      |8      |
---------------------------------
cube(*cols: ColumnOrName | Iterable[ColumnOrName]) RelationalGroupedDataFrame#

Performs a SQL GROUP BY CUBE on the DataFrame.

Parameters:

cols – The columns to group by cube.
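
Examples

A minimal sketch; with two input rows and two grouping columns, GROUP BY CUBE produces one output row per grouping combination ((x, y), (x), (y), and the grand total):

>>> df = session.create_dataframe([[1, "a", 10], [1, "b", 20]], schema=["x", "y", "z"])
>>> len(df.cube("x", "y").sum("z").collect())  # 2 + 1 + 2 + 1 grouping rows
6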

describe(*cols: str | List[str]) DataFrame#

Computes basic statistics for numeric columns, which includes count, mean, stddev, min, and max. If no columns are provided, this function computes statistics for all numerical or string columns. Non-numeric and non-string columns will be ignored when calling this method.

Parameters:

cols – The names of columns whose basic statistics are computed.

Examples

>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> desc_result = df.describe().sort("SUMMARY").show()
-------------------------------------------------------
|"SUMMARY"  |"A"                 |"B"                 |
-------------------------------------------------------
|count      |2.0                 |2.0                 |
|max        |3.0                 |4.0                 |
|mean       |2.0                 |3.0                 |
|min        |1.0                 |2.0                 |
|stddev     |1.4142135623730951  |1.4142135623730951  |
-------------------------------------------------------
distinct() DataFrame#

Returns a new DataFrame that contains only the rows with distinct values from the current DataFrame.

This is equivalent to performing a SELECT DISTINCT in SQL.
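
Examples

A minimal sketch with a duplicated row:

>>> df = session.create_dataframe([[1, 2], [1, 2], [3, 4]], schema=["a", "b"])
>>> df.distinct().sort("a").show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|3    |4    |
-------------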

drop(*cols: ColumnOrName | Iterable[ColumnOrName]) DataFrame#

Returns a new DataFrame that excludes the columns with the specified names from the output.

This is functionally equivalent to calling select() and passing in all columns except the ones to exclude. This is a no-op if the schema does not contain the given column name(s).

Parameters:

*cols – the columns to exclude, as str, Column or a list of those.

Raises:

PyStarburstClientException – if the resulting DataFrame contains no output columns.

Examples

>>> df = session.create_dataframe([[1, 2, 3]], schema=["a", "b", "c"])
>>> df.drop("a", "b").show()
-------
|"C"  |
-------
|3    |
-------
dropDuplicates(*subset: str | Iterable[str]) DataFrame#

Creates a new DataFrame by removing duplicated rows on given subset of columns.

If no subset of columns is specified, this function is the same as the distinct() function. The result is non-deterministic when removing duplicated rows from the subset of columns but not all columns.

For example, if we have a DataFrame df with columns ("a", "b", "c") that contains three rows (1, 1, 1), (1, 1, 2), (1, 2, 3), the result of df.dropDuplicates("a", "b") can be either (1, 1, 1), (1, 2, 3) or (1, 1, 2), (1, 2, 3).

Parameters:

subset – The column names on which duplicates are dropped.

dropDuplicates() is an alias of drop_duplicates().

drop_duplicates(*subset: str | Iterable[str]) DataFrame#

Creates a new DataFrame by removing duplicated rows on given subset of columns.

If no subset of columns is specified, this function is the same as the distinct() function. The result is non-deterministic when removing duplicated rows from the subset of columns but not all columns.

For example, if we have a DataFrame df with columns ("a", "b", "c") that contains three rows (1, 1, 1), (1, 1, 2), (1, 2, 3), the result of df.dropDuplicates("a", "b") can be either (1, 1, 1), (1, 2, 3) or (1, 1, 2), (1, 2, 3).

Parameters:

subset – The column names on which duplicates are dropped.

dropDuplicates() is an alias of drop_duplicates().
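
Examples

A minimal sketch of the three-row example above; which duplicate survives is non-deterministic, so only the row count is asserted:

>>> df = session.create_dataframe([[1, 1, 1], [1, 1, 2], [1, 2, 3]], schema=["a", "b", "c"])
>>> df.drop_duplicates("a", "b").count()
2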

dropna(how: str = 'any', thresh: int | None = None, subset: Iterable[str] | None = None) DataFrame#

Returns a new DataFrame that excludes all rows containing fewer than a specified number of non-null and non-NaN values in the specified columns.

Parameters:
  • how – A str with value either 'any' or 'all'. If 'any', drop a row if it contains any nulls. If 'all', drop a row only if all its values are null. The default value is 'any'. If thresh is provided, how is ignored.

  • thresh

    The minimum number of non-null and non-NaN values that should be in the specified columns in order for the row to be included. It overwrites how. In each case:

    • If thresh is not provided or None, the length of subset will be used when how is ‘any’ and 1 will be used when how is ‘all’.

    • If thresh is greater than the number of the specified columns, the method returns an empty DataFrame.

    • If thresh is less than 1, the method returns the original DataFrame.

  • subset

    A list of the names of columns to check for null and NaN values. In each case:

    • If subset is not provided or None, all columns will be included.

    • If subset is empty, the method returns the original DataFrame.

Examples

>>> df = session.create_dataframe([[1.0, 1], [float('nan'), 2], [None, 3], [4.0, None], [float('nan'), None]]).to_df("a", "b")
>>> # drop a row if it contains any nulls, with checking all columns
>>> df.na.drop().show()
-------------
|"A"  |"B"  |
-------------
|1.0  |1    |
-------------

>>> # drop a row only if all its values are null, with checking all columns
>>> df.na.drop(how='all').show()
---------------
|"A"   |"B"   |
---------------
|1.0   |1     |
|nan   |2     |
|NULL  |3     |
|4.0   |NULL  |
---------------

>>> # keep a row only if it contains at least one non-null and non-NaN value, checking all columns
>>> df.na.drop(thresh=1).show()
---------------
|"A"   |"B"   |
---------------
|1.0   |1     |
|nan   |2     |
|NULL  |3     |
|4.0   |NULL  |
---------------

>>> # drop a row if it contains any nulls, with checking column "a"
>>> df.na.drop(subset=["a"]).show()
--------------
|"A"  |"B"   |
--------------
|1.0  |1     |
|4.0  |NULL  |
--------------
exceptAll(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows from the current DataFrame except for the rows that also appear in the other DataFrame. Duplicate rows are eliminated.

exceptAll(), minus() and subtract() are aliases of except_().

Parameters:

other – The DataFrame that contains the rows to exclude.

Examples

>>> df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[1, 2], [5, 6]], schema=["c", "d"])
>>> df1.subtract(df2).show()
-------------
|"A"  |"B"  |
-------------
|3    |4    |
-------------
except_(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows from the current DataFrame except for the rows that also appear in the other DataFrame. Duplicate rows are eliminated.

exceptAll(), minus() and subtract() are aliases of except_().

Parameters:

other – The DataFrame that contains the rows to exclude.

Examples

>>> df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[1, 2], [5, 6]], schema=["c", "d"])
>>> df1.subtract(df2).show()
-------------
|"A"  |"B"  |
-------------
|3    |4    |
-------------
explain() None#

Prints the list of queries that will be executed to evaluate this DataFrame. Prints the query execution plan if only one SELECT/DML/DDL statement will be executed.

For more information about the query execution plan, see the EXPLAIN ANALYZE command.

explode(explode_col: ColumnOrName) DataFrame#

Adds new column(s) to DataFrame with expanded ARRAY or MAP, creating a new row for each element in the given array or map. Uses the default column name col for elements in the array and key and value for elements in the map.

Parameters:

explode_col – target column to work on.

Examples

>>> from pystarburst import Row
>>> df = session.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"}), Row(a=2, intlist=[4,5,6], mapfield={"a": "b", "c": "d"})])
>>> df.show()
------------------------------------------
|"a"  |"intlist"  |"mapfield"            |
------------------------------------------
|1    |[1, 2, 3]  |{'a': 'b'}            |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |
------------------------------------------

>>> df.explode(df.intlist).show()
--------------------------------------------------
|"a"  |"intlist"  |"mapfield"            |"col"  |
--------------------------------------------------
|1    |[1, 2, 3]  |{'a': 'b'}            |1      |
|1    |[1, 2, 3]  |{'a': 'b'}            |2      |
|1    |[1, 2, 3]  |{'a': 'b'}            |3      |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |4      |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |5      |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |6      |
--------------------------------------------------

>>> df.explode(df.mapfield).show()
------------------------------------------------------------
|"a"  |"intlist"  |"mapfield"            |"key"  |"value"  |
------------------------------------------------------------
|1    |[1, 2, 3]  |{'a': 'b'}            |a      |b        |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |a      |b        |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |c      |d        |
------------------------------------------------------------
explode_outer(explode_col: ColumnOrName) DataFrame#

Adds new column(s) to DataFrame with expanded ARRAY or MAP, creating a new row for each element in the given array or map. Unlike explode, if the array/map is null or empty then null is produced. Uses the default column name col for elements in the array and key and value for elements in the map.

Parameters:

explode_col – target column to work on.

Examples

>>> df = session.createDataFrame(
...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
...     ["id", "an_array", "a_map"])
>>> df.show()
--------------------------------------
|"id"  |"an_array"      |"a_map"     |
--------------------------------------
|1     |['foo', 'bar']  |{'x': 1.0}  |
|2     |[]              |{}          |
|3     |NULL            |NULL        |
--------------------------------------

>>> df.explode_outer(df.an_array).show()
----------------------------------------------
|"id"  |"an_array"      |"a_map"     |"col"  |
----------------------------------------------
|1     |['foo', 'bar']  |{'x': 1.0}  |foo    |
|1     |['foo', 'bar']  |{'x': 1.0}  |bar    |
|2     |[]              |{}          |NULL   |
|3     |NULL            |NULL        |NULL   |
----------------------------------------------

>>> df.explode_outer(df.a_map).show()
--------------------------------------------------------
|"id"  |"an_array"      |"a_map"     |"key"  |"value"  |
--------------------------------------------------------
|1     |['foo', 'bar']  |{'x': 1.0}  |x      |1.0      |
|2     |[]              |{}          |NULL   |NULL     |
|3     |NULL            |NULL        |NULL   |NULL     |
--------------------------------------------------------
fillna(value: LiteralType | Dict[str, LiteralType], subset: Iterable[str] | None = None) DataFrame#

Returns a new DataFrame that replaces all null and NaN values in the specified columns with the values provided.

Parameters:
  • value – A scalar value or a dict that associates the names of columns with the values that should be used to replace null and NaN values in those columns. If value is a dict, subset is ignored. If value is an empty dict, the method returns the original DataFrame.

  • subset

    A list of the names of columns to check for null and NaN values. In each case:

    • If subset is not provided or None, all columns will be included.

    • If subset is empty, the method returns the original DataFrame.

Examples

>>> df = session.create_dataframe([[1.0, 1], [float('nan'), 2], [None, 3], [4.0, None], [float('nan'), None]]).to_df("a", "b")
>>> # fill null and NaN values in all columns
>>> df.na.fill(3.14).show()
---------------
|"A"   |"B"   |
---------------
|1.0   |1     |
|3.14  |2     |
|3.14  |3     |
|4.0   |NULL  |
|3.14  |NULL  |
---------------

>>> # fill null and NaN values in column "a"
>>> df.na.fill({"a": 3.14}).show()
---------------
|"A"   |"B"   |
---------------
|1.0   |1     |
|3.14  |2     |
|3.14  |3     |
|4.0   |NULL  |
|3.14  |NULL  |
---------------

>>> # fill null and NaN values in column "a" and "b"
>>> df.na.fill({"a": 3.14, "b": 15}).show()
--------------
|"A"   |"B"  |
--------------
|1.0   |1    |
|3.14  |2    |
|3.14  |3    |
|4.0   |15   |
|3.14  |15   |
--------------

Note

If the type of a given value in value doesn’t match the column data type (e.g. a float for a StringType column), this replacement will be skipped in this column. In particular:

  • an int can be filled in a column with FloatType or DoubleType, but a float cannot be filled in a column with IntegerType or LongType.

filter(expr: ColumnOrSqlExpr) DataFrame#

Filters rows based on the specified conditional expression (similar to WHERE in SQL).

Parameters:

expr – a Column expression or SQL text.

where() is an alias of filter().

Examples

>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["A", "B"])
>>> df_filtered = df.filter((col("A") > 1) & (col("B") < 100))  # Parentheses are required around each condition combined with &.
>>> # The following two result in the same SQL query:
>>> df.filter(col("a") > 1).collect()
[Row(A=3, B=4)]
>>> df.filter("a > 1").collect()  # use SQL expression
[Row(A=3, B=4)]
first(n: int | None = None, *, statement_properties: Dict[str, str] | None = None) Row | None | List[Row]#

Executes the query representing this DataFrame and returns the first n rows of the results.

Parameters:

n – The number of rows to return.

Returns:

A list of the first n Row objects if n is not None. If n is negative or larger than the number of rows in the result, returns all rows in the results. If n is None, it returns the first Row of the results, or None if the result is empty.
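
Examples

A minimal sketch; the sort makes the returned rows deterministic:

>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df.sort("a").first()
Row(A=1, B=2)
>>> df.sort("a").first(2)
[Row(A=1, B=2), Row(A=3, B=4)]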

groupBy(*cols: ColumnOrName | Iterable[ColumnOrName]) RelationalGroupedDataFrame#

Groups rows by the columns specified by expressions (similar to GROUP BY in SQL).

This method returns a RelationalGroupedDataFrame that you can use to perform aggregations on each group of data.

groupBy() is an alias of group_by().

Parameters:

*cols – The columns to group by.

Valid inputs are:

  • Empty input

  • One or multiple Column object(s) or column name(s) (str)

  • A list of Column objects or column names (str)

Examples

>>> from pystarburst.functions import col, lit, sum as sum_, max as max_
>>> df = session.create_dataframe([(1, 1),(1, 2),(2, 1),(2, 2),(3, 1),(3, 2)], schema=["a", "b"])
>>> df.group_by().agg(sum_("b")).collect()
[Row(SUM(B)=9)]

>>> df.group_by("a").agg(sum_("b")).collect()
[Row(A=1, SUM(B)=3), Row(A=2, SUM(B)=3), Row(A=3, SUM(B)=3)]

>>> df.group_by(["a", lit("pystarburst")]).agg(sum_("b")).collect()
[Row(A=1, LITERAL()='pystarburst', SUM(B)=3), Row(A=2, LITERAL()='pystarburst', SUM(B)=3), Row(A=3, LITERAL()='pystarburst', SUM(B)=3)]

>>> df.group_by("a").agg((col("*"), "count"), max_("b")).collect()
[Row(A=1, COUNT(LITERAL())=2, MAX(B)=2), Row(A=2, COUNT(LITERAL())=2, MAX(B)=2), Row(A=3, COUNT(LITERAL())=2, MAX(B)=2)]

>>> df.group_by("a").function("avg")("b").collect()
[Row(A=1, AVG(B)=Decimal('1.500000')), Row(A=2, AVG(B)=Decimal('1.500000')), Row(A=3, AVG(B)=Decimal('1.500000'))]
groupByGroupingSets(*grouping_sets: GroupingSets | Iterable[GroupingSets]) RelationalGroupedDataFrame#

Performs a SQL GROUP BY GROUPING SETS on the DataFrame.

GROUP BY GROUPING SETS is an extension of the GROUP BY clause that allows computing multiple GROUP BY clauses in a single statement. The group set is a set of dimension columns.

GROUP BY GROUPING SETS is equivalent to the UNION of two or more GROUP BY operations in the same result set.

groupByGroupingSets() is an alias of group_by_grouping_sets().

Parameters:

grouping_sets – The list of GroupingSets to group by.

Examples

>>> from pystarburst import GroupingSets
>>> df = session.create_dataframe([[1, 2, 10], [3, 4, 20], [1, 4, 30]], schema=["A", "B", "C"])
>>> df.group_by_grouping_sets(GroupingSets([col("a")])).count().collect()
[Row(A=1, COUNT=2), Row(A=3, COUNT=1)]

>>> df.group_by_grouping_sets(GroupingSets(col("a"))).count().collect()
[Row(A=1, COUNT=2), Row(A=3, COUNT=1)]

>>> df.group_by_grouping_sets(GroupingSets([col("a")], [col("b")])).count().collect()
[Row(A=1, B=None, COUNT=2), Row(A=3, B=None, COUNT=1), Row(A=None, B=2, COUNT=1), Row(A=None, B=4, COUNT=2)]

>>> df.group_by_grouping_sets(GroupingSets([col("a"), col("b")], [col("c")])).count().collect()
[Row(A=None, B=None, C=10, COUNT=1), Row(A=None, B=None, C=20, COUNT=1), Row(A=None, B=None, C=30, COUNT=1), Row(A=1, B=2, C=None, COUNT=1), Row(A=3, B=4, C=None, COUNT=1), Row(A=1, B=4, C=None, COUNT=1)]
group_by(*cols: ColumnOrName | Iterable[ColumnOrName]) RelationalGroupedDataFrame#

Groups rows by the columns specified by expressions (similar to GROUP BY in SQL).

This method returns a RelationalGroupedDataFrame that you can use to perform aggregations on each group of data.

groupBy() is an alias of group_by().

Parameters:

*cols – The columns to group by.

Valid inputs are:

  • Empty input

  • One or multiple Column object(s) or column name(s) (str)

  • A list of Column objects or column names (str)

Examples

>>> from pystarburst.functions import col, lit, sum as sum_, max as max_
>>> df = session.create_dataframe([(1, 1),(1, 2),(2, 1),(2, 2),(3, 1),(3, 2)], schema=["a", "b"])
>>> df.group_by().agg(sum_("b")).collect()
[Row(SUM(B)=9)]

>>> df.group_by("a").agg(sum_("b")).collect()
[Row(A=1, SUM(B)=3), Row(A=2, SUM(B)=3), Row(A=3, SUM(B)=3)]

>>> df.group_by(["a", lit("pystarburst")]).agg(sum_("b")).collect()
[Row(A=1, LITERAL()='pystarburst', SUM(B)=3), Row(A=2, LITERAL()='pystarburst', SUM(B)=3), Row(A=3, LITERAL()='pystarburst', SUM(B)=3)]

>>> df.group_by("a").agg((col("*"), "count"), max_("b")).collect()
[Row(A=1, COUNT(LITERAL())=2, MAX(B)=2), Row(A=2, COUNT(LITERAL())=2, MAX(B)=2), Row(A=3, COUNT(LITERAL())=2, MAX(B)=2)]

>>> df.group_by("a").function("avg")("b").collect()
[Row(A=1, AVG(B)=Decimal('1.500000')), Row(A=2, AVG(B)=Decimal('1.500000')), Row(A=3, AVG(B)=Decimal('1.500000'))]
group_by_grouping_sets(*grouping_sets: GroupingSets | Iterable[GroupingSets]) RelationalGroupedDataFrame#

Performs a SQL GROUP BY GROUPING SETS on the DataFrame.

GROUP BY GROUPING SETS is an extension of the GROUP BY clause that allows computing multiple GROUP BY clauses in a single statement. The group set is a set of dimension columns.

GROUP BY GROUPING SETS is equivalent to the UNION of two or more GROUP BY operations in the same result set.

groupByGroupingSets() is an alias of group_by_grouping_sets().

Parameters:

grouping_sets – The list of GroupingSets to group by.

Examples

>>> from pystarburst import GroupingSets
>>> df = session.create_dataframe([[1, 2, 10], [3, 4, 20], [1, 4, 30]], schema=["A", "B", "C"])
>>> df.group_by_grouping_sets(GroupingSets([col("a")])).count().collect()
[Row(A=1, COUNT=2), Row(A=3, COUNT=1)]

>>> df.group_by_grouping_sets(GroupingSets(col("a"))).count().collect()
[Row(A=1, COUNT=2), Row(A=3, COUNT=1)]

>>> df.group_by_grouping_sets(GroupingSets([col("a")], [col("b")])).count().collect()
[Row(A=1, B=None, COUNT=2), Row(A=3, B=None, COUNT=1), Row(A=None, B=2, COUNT=1), Row(A=None, B=4, COUNT=2)]

>>> df.group_by_grouping_sets(GroupingSets([col("a"), col("b")], [col("c")])).count().collect()
[Row(A=None, B=None, C=10, COUNT=1), Row(A=None, B=None, C=20, COUNT=1), Row(A=None, B=None, C=30, COUNT=1), Row(A=1, B=2, C=None, COUNT=1), Row(A=3, B=4, C=None, COUNT=1), Row(A=1, B=4, C=None, COUNT=1)]
head(n: int | None = None) Row | None | List[Row]#

Returns the first n rows.

Parameters:

n (int, optional) – default None. Number of rows to return.

Returns:

  • If n is number, return a list of n Row.

  • If n is None, return a single Row.

Examples

>>> df = session.create_dataframe([[2, "Alice"], [5, "Bob"]], schema=["age", "name"])
>>> df.head()
Row(age=2, name='Alice')
>>> df.head(1)
[Row(age=2, name='Alice')]
inline(explode_col: ColumnOrName) DataFrame#

Explodes an array of structs into a table.

Parameters:

explode_col – input column of values to explode.

Examples: none yet; pending support for creating structs with the struct function.

inline_outer(explode_col: ColumnOrName) DataFrame#

Explodes an array of structs into a table. Unlike inline, if the array is null or empty then null is produced for each nested column.

Parameters:

explode_col – input column of values to explode.

Examples: none yet; pending support for creating structs with the struct function.

intersect(other: DataFrame) DataFrame#

Returns a new DataFrame that contains the intersection of rows from the current DataFrame and another DataFrame (other). Duplicate rows are eliminated.

Parameters:

other – the other DataFrame that contains the rows to use for the intersection.

Examples

>>> df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[1, 2], [5, 6]], schema=["c", "d"])
>>> df1.intersect(df2).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
-------------
intersectAll(other: DataFrame) DataFrame#

Returns a new DataFrame that contains the intersection of rows from the current DataFrame and another DataFrame (other). Duplicate rows are persisted.

intersectAll() is an alias of intersect_all().

Parameters:

other – the other DataFrame that contains the rows to use for the intersection.

Examples

>>> df1 = session.create_dataframe([("id1", 1), ("id1", 1), ("id", 1), ("id1", 3)]).to_df("id", "value")
>>> df2 = session.create_dataframe([("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)]).to_df("id", "value")
>>> df1.intersect_all(df2).show()
------------------
|"id"  |"value"  |
------------------
|id1   |1        |
|id1   |1        |
|id    |1        |
------------------
intersect_all(other: DataFrame) DataFrame#

Returns a new DataFrame that contains the intersection of rows from the current DataFrame and another DataFrame (other). Duplicate rows are persisted.

intersectAll() is an alias of intersect_all().

Parameters:

other – the other DataFrame that contains the rows to use for the intersection.

Examples

>>> df1 = session.create_dataframe([("id1", 1), ("id1", 1), ("id", 1), ("id1", 3)]).to_df("id", "value")
>>> df2 = session.create_dataframe([("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)]).to_df("id", "value")
>>> df1.intersect_all(df2).show()
------------------
|"id"  |"value"  |
------------------
|id1   |1        |
|id1   |1        |
|id    |1        |
------------------
isEmpty() bool#

Checks if the DataFrame is empty and returns a boolean value.

isEmpty() is an alias of is_empty().

Examples

>>> from pystarburst.types import *
>>> df_empty = session.createDataFrame([], schema=StructType([StructField('a', StringType(), True)]))
>>> df_empty.isEmpty()
True

>>> df_non_empty = session.createDataFrame(["a"], schema=["a"])
>>> df_non_empty.isEmpty()
False

>>> df_nulls = session.createDataFrame([(None, None)], schema=StructType([StructField("a", StringType(), True), StructField("b", IntegerType(), True)]))
>>> df_nulls.isEmpty()
False

>>> df_no_rows = session.createDataFrame([], schema=StructType([StructField('id', IntegerType(), True), StructField('value', StringType(), True)]))
>>> df_no_rows.isEmpty()
True
is_cached: bool#

Whether the dataframe is cached.

is_empty() bool#

Checks if the DataFrame is empty and returns a boolean value.

isEmpty() is an alias of is_empty().

Examples

>>> from pystarburst.types import *
>>> df_empty = session.createDataFrame([], schema=StructType([StructField('a', StringType(), True)]))
>>> df_empty.isEmpty()
True

>>> df_non_empty = session.createDataFrame(["a"], schema=["a"])
>>> df_non_empty.isEmpty()
False

>>> df_nulls = session.createDataFrame([(None, None)], schema=StructType([StructField("a", StringType(), True), StructField("b", IntegerType(), True)]))
>>> df_nulls.isEmpty()
False

>>> df_no_rows = session.createDataFrame([], schema=StructType([StructField('id', IntegerType(), True), StructField('value', StringType(), True)]))
>>> df_no_rows.isEmpty()
True
join(right: DataFrame, on: ColumnOrName | Iterable[ColumnOrName] | None = None, how: str | None = None, *, lsuffix: str = '', rsuffix: str = '', **kwargs) DataFrame#

Performs a join of the specified type (how) with the current DataFrame and another DataFrame (right) on a list of columns (on).

Parameters:
  • right – The other DataFrame to join.

  • on – A column name or a Column object or a list of them to be used for the join. When a list of column names is specified, this method assumes the named columns are present in both dataframes. You can use the keyword using_columns to specify this condition. Note that to avoid breaking changes, when using_columns is specified, it overrides on.

  • how

    We support the following join types:

    • Inner join: “inner” (the default value)

    • Left outer join: “left”, “leftouter”

    • Right outer join: “right”, “rightouter”

    • Full outer join: “full”, “outer”, “fullouter”

    • Left semi join: “semi”, “leftsemi”

    • Left anti join: “anti”, “leftanti”

    • Cross join: “cross”

    You can also use join_type keyword to specify this condition. Note that to avoid breaking changes, currently when join_type is specified, it overrides how.

  • lsuffix – Suffix to add to the overlapping columns of the left DataFrame.

  • rsuffix – Suffix to add to the overlapping columns of the right DataFrame.

Note

When both lsuffix and rsuffix are empty, the overlapping columns will have random column names in the resulting DataFrame. You can refer to these randomly named columns using Column.alias() (see the first usage in Examples).

Examples

>>> from pystarburst.functions import col
>>> df1 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[1, 7], [3, 8]], schema=["a", "c"])
>>> df1.join(df2, df1.a == df2.a).select(df1.a.alias("a_1"), df2.a.alias("a_2"), df1.b, df2.c).show()
-----------------------------
|"A_1"  |"A_2"  |"B"  |"C"  |
-----------------------------
|1      |1      |2    |7    |
|3      |3      |4    |8    |
-----------------------------

>>> # refer a single column "a"
>>> df1.join(df2, "a").select(df1.a.alias("a"), df1.b, df2.c).show()
-------------------
|"A"  |"B"  |"C"  |
-------------------
|1    |2    |7    |
|3    |4    |8    |
-------------------

>>> # rename the ambiguous columns
>>> df3 = df1.to_df("df1_a", "b")
>>> df4 = df2.to_df("df2_a", "c")
>>> df3.join(df4, col("df1_a") == col("df2_a")).select(col("df1_a").alias("a"), "b", "c").show()
-------------------
|"A"  |"B"  |"C"  |
-------------------
|1    |2    |7    |
|3    |4    |8    |
-------------------
>>> # join multiple columns
>>> mdf1 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"])
>>> mdf2 = session.create_dataframe([[1, 2], [3, 4], [7, 6]], schema=["a", "b"])
>>> mdf1.join(mdf2, ["a", "b"]).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|3    |4    |
-------------

>>> mdf1.join(mdf2, (mdf1["a"] < mdf2["a"]) & (mdf1["b"] == mdf2["b"])).select(mdf1["a"].as_("new_a"), mdf1["b"].as_("new_b")).show()
---------------------
|"NEW_A"  |"NEW_B"  |
---------------------
|5        |6        |
---------------------

>>> # use lsuffix and rsuffix to resolve duplicating column names
>>> mdf1.join(mdf2, (mdf1["a"] < mdf2["a"]) & (mdf1["b"] == mdf2["b"]), lsuffix="_left", rsuffix="_right").show()
-----------------------------------------------
|"A_LEFT"  |"B_LEFT"  |"A_RIGHT"  |"B_RIGHT"  |
-----------------------------------------------
|5         |6         |7          |6          |
-----------------------------------------------

>>> mdf1.join(mdf2, (mdf1["a"] < mdf2["a"]) & (mdf1["b"] == mdf2["b"]), rsuffix="_right").show()
-------------------------------------
|"A"  |"B"  |"A_RIGHT"  |"B_RIGHT"  |
-------------------------------------
|5    |6    |7          |6          |
-------------------------------------

Note

When performing chained operations, this method will not work if there are ambiguous column names. For example,

>>> df1.filter(df1.a == 1).join(df2, df1.a == df2.a).select(df1.a.alias("a"), df1.b, df2.c) 

will not work because df1.filter(df1.a == 1) has produced a new dataframe and you cannot refer to df1.a anymore. Instead, you can do either

>>> df1.join(df2, (df1.a == 1) & (df1.a == df2.a)).select(df1.a.alias("a"), df1.b, df2.c).show()
-------------------
|"A"  |"B"  |"C"  |
-------------------
|1    |2    |7    |
-------------------

or

>>> df3 = df1.filter(df1.a == 1)
>>> df3.join(df2, df3.a == df2.a).select(df3.a.alias("a"), df3.b, df2.c).show()
-------------------
|"A"  |"B"  |"C"  |
-------------------
|1    |2    |7    |
-------------------
limit(n: int, offset: int = 0) DataFrame#

Returns a new DataFrame that contains at most n rows from the current DataFrame, skipping offset rows from the beginning (similar to LIMIT and OFFSET in SQL).

Note that this is a transformation method and not an action method.

Parameters:
  • n – Number of rows to return.

  • offset – Number of rows to skip before the start of the result set. The default value is 0.

Examples

>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df.limit(1).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
-------------

>>> df.limit(1, offset=1).show()
-------------
|"A"  |"B"  |
-------------
|3    |4    |
-------------
melt(ids_column_list: ColumnOrName | Iterable[ColumnOrName], unpivot_column_list: List[ColumnOrName], name_column: str, value_column: str) DataFrame#

Unpivots a DataFrame from wide format to long format, optionally keeping the specified identifier columns. Note that UNPIVOT is not exactly the reverse of PIVOT, as it cannot undo aggregations made by PIVOT.

melt() is an alias of unpivot().

Parameters:
  • ids_column_list – The names of the columns in the source table or subquery that will be used as identifiers.

  • unpivot_column_list – The names of the columns in the source table or subquery that will be narrowed into a single pivot column. The column names will populate name_column, and the column values will populate value_column.

  • name_column – The name to assign to the generated column that will be populated with the names of the columns in the column list.

  • value_column – The name to assign to the generated column that will be populated with the values from the columns in the column list.

Examples

>>> df = session.create_dataframe([
...     (1, 'electronics', 100, 200),
...     (2, 'clothes', 100, 300)
... ], schema=["empid", "dept", "jan", "feb"])
>>> df = df.unpivot(["empid", "dept"], ["jan", "feb"], "month", "sales").sort("empid")
>>> df.show()
---------------------------------------------
|"empid"  |"dept"       |"month"  |"sales"  |
---------------------------------------------
|1        |electronics  |JAN      |100      |
|1        |electronics  |FEB      |200      |
|2        |clothes      |JAN      |100      |
|2        |clothes      |FEB      |300      |
---------------------------------------------
minus(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows from the current DataFrame except for the rows that also appear in the other DataFrame. Duplicate rows are eliminated.

exceptAll(), minus() and subtract() are aliases of except_().

Parameters:

other – The DataFrame that contains the rows to exclude.

Examples

>>> df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[1, 2], [5, 6]], schema=["c", "d"])
>>> df1.subtract(df2).show()
-------------
|"A"  |"B"  |
-------------
|3    |4    |
-------------
property na: DataFrameNaFunctions#

Returns a DataFrameNaFunctions object that provides functions for handling missing values in the DataFrame.

orderBy(*cols: ColumnOrName | Iterable[ColumnOrName], ascending: bool | int | List[bool | int] | None = None) DataFrame#

Sorts a DataFrame by the specified expressions (similar to ORDER BY in SQL).

orderBy() and order_by() are aliases of sort().

Parameters:
  • *cols – A column name as str or Column, or a list of columns to sort by.

  • ascending – A bool or a list of bool for sorting the DataFrame, where True sorts a column in ascending order and False sorts a column in descending order. If you specify a list of multiple sort orders, the length of the list must equal the number of columns.

Examples

>>> from pystarburst.functions import col
>>> df = session.create_dataframe([[1, 2], [3, 4], [1, 4]], schema=["A", "B"])
>>> df.sort(col("A"), col("B").asc()).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|1    |4    |
|3    |4    |
-------------

>>> df.sort(col("a"), ascending=False).show()
-------------
|"A"  |"B"  |
-------------
|3    |4    |
|1    |2    |
|1    |4    |
-------------

>>> # The values from the list overwrite the column ordering.
>>> df.sort(["a", col("b").desc()], ascending=[1, 1]).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|1    |4    |
|3    |4    |
-------------
order_by(*cols: ColumnOrName | Iterable[ColumnOrName], ascending: bool | int | List[bool | int] | None = None) DataFrame#

Sorts a DataFrame by the specified expressions (similar to ORDER BY in SQL).

orderBy() and order_by() are aliases of sort().

Parameters:
  • *cols – A column name as str or Column, or a list of columns to sort by.

  • ascending – A bool or a list of bool for sorting the DataFrame, where True sorts a column in ascending order and False sorts a column in descending order. If you specify a list of multiple sort orders, the length of the list must equal the number of columns.

Examples

>>> from pystarburst.functions import col
>>> df = session.create_dataframe([[1, 2], [3, 4], [1, 4]], schema=["A", "B"])
>>> df.sort(col("A"), col("B").asc()).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|1    |4    |
|3    |4    |
-------------

>>> df.sort(col("a"), ascending=False).show()
-------------
|"A"  |"B"  |
-------------
|3    |4    |
|1    |2    |
|1    |4    |
-------------

>>> # The values from the list overwrite the column ordering.
>>> df.sort(["a", col("b").desc()], ascending=[1, 1]).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|1    |4    |
|3    |4    |
-------------
pivot(pivot_col: ColumnOrName, values: Iterable[LiteralType]) RelationalGroupedDataFrame#

Rotates this DataFrame by turning the unique values from one column in the input expression into multiple columns and aggregating results where required on any remaining column values.

Only one aggregate is supported with pivot.

Parameters:
  • pivot_col – The column or name of the column to use.

  • values – A list of values in the column.

Examples

>>> create_result = session.sql('''create table monthly_sales(empid int, amount int, month varchar)
... as select * from values
... (1, 10000, 'JAN'),
... (1, 400, 'JAN'),
... (2, 4500, 'JAN'),
... (2, 35000, 'JAN'),
... (1, 5000, 'FEB'),
... (1, 3000, 'FEB'),
... (2, 200, 'FEB') ''').collect()
>>> df = session.table("monthly_sales")
>>> df.pivot("month", ['JAN', 'FEB']).sum("amount").show()
-------------------------------
|"EMPID"  |"'JAN'"  |"'FEB'"  |
-------------------------------
|1        |10400    |8000     |
|2        |39500    |200      |
-------------------------------
posexplode(explode_col: ColumnOrName) DataFrame#

Adds new columns to DataFrame with expanded ARRAY or MAP, creating a new row for each element with position in the given array or map. The position starts at 1. Uses the default column name pos for position, col for elements in the array and key and value for elements in the map.

Parameters:

explode_col – target column to work on.

Examples

>>> from pystarburst import Row
>>> df = session.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"}), Row(a=2, intlist=[4,5,6], mapfield={"a": "b", "c": "d"})])
>>> df.show()
------------------------------------------
|"a"  |"intlist"  |"mapfield"            |
------------------------------------------
|1    |[1, 2, 3]  |{'a': 'b'}            |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |
------------------------------------------

>>> df.posexplode(df.intlist).show()
----------------------------------------------------------
|"a"  |"intlist"  |"mapfield"            |"pos"  |"col"  |
----------------------------------------------------------
|1    |[1, 2, 3]  |{'a': 'b'}            |1      |1      |
|1    |[1, 2, 3]  |{'a': 'b'}            |2      |2      |
|1    |[1, 2, 3]  |{'a': 'b'}            |3      |3      |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |1      |4      |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |2      |5      |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |3      |6      |
----------------------------------------------------------

>>> df.posexplode(df.mapfield).show()
--------------------------------------------------------------------
|"a"  |"intlist"  |"mapfield"            |"pos"  |"key"  |"value"  |
--------------------------------------------------------------------
|1    |[1, 2, 3]  |{'a': 'b'}            |1      |a      |b        |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |1      |a      |b        |
|2    |[4, 5, 6]  |{'a': 'b', 'c': 'd'}  |2      |c      |d        |
--------------------------------------------------------------------
posexplode_outer(explode_col: ColumnOrName) DataFrame#

Adds new columns to the DataFrame by expanding an ARRAY or MAP column, creating a new row for each element together with its position in the given array or map. Positions start at 1. Unlike posexplode(), if the array or map is null or empty, the row (null, null) is produced. Uses the default column name pos for the position, col for array elements, and key and value for map entries.

Parameters:

explode_col – target column to work on.

Examples

>>> df = session.createDataFrame(
...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
...     ["id", "an_array", "a_map"])
>>> df.show()
--------------------------------------
|"id"  |"an_array"      |"a_map"     |
--------------------------------------
|1     |['foo', 'bar']  |{'x': 1.0}  |
|2     |[]              |{}          |
|3     |NULL            |NULL        |
--------------------------------------

>>> df.posexplode_outer(df.an_array).show()
------------------------------------------------------
|"id"  |"an_array"      |"a_map"     |"pos"  |"col"  |
------------------------------------------------------
|1     |['foo', 'bar']  |{'x': 1.0}  |1      |foo    |
|1     |['foo', 'bar']  |{'x': 1.0}  |2      |bar    |
|2     |[]              |{}          |NULL   |NULL   |
|3     |NULL            |NULL        |NULL   |NULL   |
------------------------------------------------------

>>> df.posexplode_outer(df.a_map).show()
----------------------------------------------------------------
|"id"  |"an_array"      |"a_map"     |"pos"  |"key"  |"value"  |
----------------------------------------------------------------
|1     |['foo', 'bar']  |{'x': 1.0}  |1      |x      |1.0      |
|2     |[]              |{}          |NULL   |NULL   |NULL     |
|3     |NULL            |NULL        |NULL   |NULL   |NULL     |
----------------------------------------------------------------
property queries: Dict[str, List[str]]#

Returns a dict that contains the list of queries that will be executed to evaluate this DataFrame (under the key queries) and a list of post-execution actions, such as queries that clean up temporary objects (under the key post_actions).
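
Examples

An illustrative sketch (the exact SQL text depends on the session and catalog):

>>> df = session.table("prices").filter(col("amount") > 10)
>>> df.queries["queries"]       # SQL statements that evaluate this DataFrame
>>> df.queries["post_actions"]  # cleanup statements; often an empty list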

randomSplit(weights: List[float], seed: int | None = None, *, statement_properties: Dict[str, str] | None = None) List[DataFrame]#

Randomly splits the current DataFrame into separate DataFrames, using the specified weights.

randomSplit() is an alias of random_split().

Parameters:
  • weights – Weights to use for splitting the DataFrame. If the weights don’t add up to 1, the weights will be normalized. Every number in weights has to be positive. If only one weight is specified, the returned DataFrame list only includes the current DataFrame.

  • seed – The seed for sampling.

Examples

>>> df = session.range(10000)
>>> weights = [0.1, 0.2, 0.3]
>>> df_parts = df.random_split(weights)
>>> len(df_parts) == len(weights)
True

Note

1. When multiple weights are specified, the current DataFrame will be cached before being split.

2. When a weight or a normalized weight is less than 1e-6, the corresponding split DataFrame will be empty.

random_split(weights: List[float], seed: int | None = None, *, statement_properties: Dict[str, str] | None = None) List[DataFrame]#

Randomly splits the current DataFrame into separate DataFrames, using the specified weights.

randomSplit() is an alias of random_split().

Parameters:
  • weights – Weights to use for splitting the DataFrame. If the weights don’t add up to 1, the weights will be normalized. Every number in weights has to be positive. If only one weight is specified, the returned DataFrame list only includes the current DataFrame.

  • seed – The seed for sampling.

Examples

>>> df = session.range(10000)
>>> weights = [0.1, 0.2, 0.3]
>>> df_parts = df.random_split(weights)
>>> len(df_parts) == len(weights)
True

Note

1. When multiple weights are specified, the current DataFrame will be cached before being split.

2. When a weight or a normalized weight is less than 1e-6, the corresponding split DataFrame will be empty.

rename(existing: ColumnOrName, new: str) DataFrame#

Returns a DataFrame with the specified column existing renamed as new.

with_column_renamed() is an alias of rename().

Parameters:
  • existing – The old column instance or column name to be renamed.

  • new – The new column name.

Examples

>>> # This example renames the column `A` as `NEW_A` in the DataFrame.
>>> df = session.sql("select 1 as A, 2 as B")
>>> df_renamed = df.with_column_renamed(col("A"), "NEW_A")
>>> df_renamed.show()
-----------------
|"NEW_A"  |"B"  |
-----------------
|1        |2    |
-----------------
replace(to_replace: LiteralType | Iterable[LiteralType] | Dict[LiteralType, LiteralType], value: Iterable[LiteralType] | None = None, subset: Iterable[str] | None = None) DataFrame#

Returns a new DataFrame that replaces values in the specified columns.

Parameters:
  • to_replace – A scalar value, or a list of values or a dict that associates the original values with the replacement values. If to_replace is a dict, value and subset are ignored. To replace a null value, use None in to_replace. To replace a NaN value, use float("nan") in to_replace. If to_replace is empty, the method returns the original DataFrame.

  • value – A scalar value, or a list of values for the replacement. If value is a list, value should be of the same length as to_replace. If value is a scalar and to_replace is a list, then value is used as a replacement for each item in to_replace.

  • subset – A list of the names of columns in which the values should be replaced. If subset is not provided or None, the replacement is applied to all columns. If subset is empty, the method returns the original DataFrame.

Examples

>>> df = session.create_dataframe([[1, 1.0, "1.0"], [2, 2.0, "2.0"]], schema=["a", "b", "c"])
>>> # replace 1 with 3 in all columns
>>> df.na.replace(1, 3).show()
-------------------
|"A"  |"B"  |"C"  |
-------------------
|3    |3.0  |1.0  |
|2    |2.0  |2.0  |
-------------------

>>> # replace 1 with 3 and 2 with 4 in all columns
>>> df.na.replace([1, 2], [3, 4]).show()
-------------------
|"A"  |"B"  |"C"  |
-------------------
|3    |3.0  |1.0  |
|4    |4.0  |2.0  |
-------------------

>>> # replace 1 with 3 and 2 with 3 in all columns
>>> df.na.replace([1, 2], 3).show()
-------------------
|"A"  |"B"  |"C"  |
-------------------
|3    |3.0  |1.0  |
|3    |3.0  |2.0  |
-------------------

>>> # the following line replaces 1 with 3 and 2 with 4 in all columns
>>> # and gives [Row(3, 3.0, "1.0"), Row(4, 4.0, "2.0")]
>>> df.na.replace({1: 3, 2: 4}).show()
-------------------
|"A"  |"B"  |"C"  |
-------------------
|3    |3.0  |1.0  |
|4    |4.0  |2.0  |
-------------------

>>> # the following line intends to replace 1 with "3" in column "a",
>>> # but will be ignored since "3" (str) doesn't match the original data type
>>> df.na.replace({1: "3"}, ["a"]).show()
-------------------
|"A"  |"B"  |"C"  |
-------------------
|1    |1.0  |1.0  |
|2    |2.0  |2.0  |
-------------------

Note

If the type of a given value in to_replace or value doesn’t match the column data type (e.g. a float for a StringType column), this replacement is skipped in that column. In particular:

  • int can replace or be replaced in a column with FloatType or DoubleType, but float cannot replace or be replaced in a column with IntegerType or LongType.

  • None can replace or be replaced in a column with any data type.

rollup(*cols: ColumnOrName | Iterable[ColumnOrName]) RelationalGroupedDataFrame#

Performs a SQL GROUP BY ROLLUP on the DataFrame.

Parameters:

cols – The columns to group by with ROLLUP.
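
Examples

A minimal sketch with hypothetical data; the returned RelationalGroupedDataFrame still needs an aggregation such as sum():

>>> df = session.create_dataframe([("a", 1), ("a", 2), ("b", 3)], schema=["key", "value"])
>>> df.rollup("key").sum("value").show()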

sample(frac: float) DataFrame#

Samples rows based on a percentage of rows to be returned.

Parameters:

frac – the percentage of rows to be sampled.

Returns:

a DataFrame containing the sample of rows.
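
Examples

An illustrative sketch; the number of returned rows is non-deterministic and only approximates the requested fraction:

>>> df = session.range(100)
>>> sampled = df.sample(0.5)  # roughly half of the rows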

sampleBy(col: ColumnOrName, fractions: Dict[LiteralType, float]) DataFrame#

Returns a DataFrame containing a stratified sample without replacement, based on a dict that specifies the fraction for each stratum.

sampleBy() is an alias of sample_by().

Parameters:
  • col – The name of the column that defines the strata.

  • fractions – A dict that specifies the fraction to use for the sample for each stratum. If a stratum is not specified in the dict, the method uses 0 as the fraction.

Examples

>>> df = session.create_dataframe([("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 12)], schema=["name", "age"])
>>> fractions = {"Bob": 0.5, "Nico": 1.0}
>>> sample_df = df.stat.sample_by("name", fractions)  # non-deterministic result
sample_by(col: ColumnOrName, fractions: Dict[LiteralType, float]) DataFrame#

Returns a DataFrame containing a stratified sample without replacement, based on a dict that specifies the fraction for each stratum.

sampleBy() is an alias of sample_by().

Parameters:
  • col – The name of the column that defines the strata.

  • fractions – A dict that specifies the fraction to use for the sample for each stratum. If a stratum is not specified in the dict, the method uses 0 as the fraction.

Examples

>>> df = session.create_dataframe([("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 12)], schema=["name", "age"])
>>> fractions = {"Bob": 0.5, "Nico": 1.0}
>>> sample_df = df.stat.sample_by("name", fractions)  # non-deterministic result
property schema: StructType#

The definition of the columns in this DataFrame (the “relational schema” for the DataFrame).
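
Examples

A short sketch; the exact field types depend on how the input data is inferred:

>>> df = session.create_dataframe([[1, "a"]], schema=["x", "y"])
>>> df.schema  # e.g. StructType([StructField('x', LongType(), True), StructField('y', StringType(), True)])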

select(*cols: ColumnOrName | Iterable[ColumnOrName]) DataFrame#

Returns a new DataFrame with the specified Column expressions as output (similar to SELECT in SQL). Only the Columns specified as arguments will be present in the resulting DataFrame.

You can use any Column expression or strings for named columns.

Parameters:

*cols – A Column, str, or a list of those.

Examples

>>> df = session.create_dataframe([[1, "some string value", 3, 4]], schema=["col1", "col2", "col3", "col4"])
>>> df_selected = df.select(col("col1"), col("col2").substr(0, 10), df["col3"] + df["col4"])

>>> df_selected = df.select("col1", "col2", "col3")

>>> df_selected = df.select(["col1", "col2", "col3"])

>>> df_selected = df.select(df["col1"], df.col2, df.col("col3"))
selectExpr(*exprs: str | Iterable[str]) DataFrame#

Projects a set of SQL expressions and returns a new DataFrame. This method is equivalent to select(sql_expr(...)), combining select() with functions.sql_expr().

selectExpr() is an alias of select_expr().

Parameters:

exprs – The SQL expressions.

Examples

>>> df = session.create_dataframe([-1, 2, 3], schema=["a"])  # with one pair of [], the dataframe has a single column and 3 rows.
>>> df.select_expr("abs(a)", "a + 2", "cast(a as string)").show()
--------------------------------------------
|"ABS(A)"  |"A + 2"  |"CAST(A AS STRING)"  |
--------------------------------------------
|1         |1        |-1                   |
|2         |4        |2                    |
|3         |5        |3                    |
--------------------------------------------
select_expr(*exprs: str | Iterable[str]) DataFrame#

Projects a set of SQL expressions and returns a new DataFrame. This method is equivalent to select(sql_expr(...)), combining select() with functions.sql_expr().

selectExpr() is an alias of select_expr().

Parameters:

exprs – The SQL expressions.

Examples

>>> df = session.create_dataframe([-1, 2, 3], schema=["a"])  # with one pair of [], the dataframe has a single column and 3 rows.
>>> df.select_expr("abs(a)", "a + 2", "cast(a as string)").show()
--------------------------------------------
|"ABS(A)"  |"A + 2"  |"CAST(A AS STRING)"  |
--------------------------------------------
|1         |1        |-1                   |
|2         |4        |2                    |
|3         |5        |3                    |
--------------------------------------------
show(n: int = 10, max_width: int = 50, *, statement_properties: Dict[str, str] | None = None) None#

Evaluates this DataFrame and prints out the first n rows with the specified maximum number of characters per column.

Parameters:
  • n – The number of rows to print out.

  • max_width – The maximum number of characters to print out for each column. If the number of characters exceeds the maximum, the method prints out an ellipsis (…) at the end of the column.
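
Examples

A short sketch showing how to print only the first row:

>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df.show(1)
-------------
|"A"  |"B"  |
-------------
|1    |2    |
-------------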

sort(*cols: ColumnOrName | Iterable[ColumnOrName], ascending: bool | int | List[bool | int] | None = None) DataFrame#

Sorts a DataFrame by the specified expressions (similar to ORDER BY in SQL).

orderBy() and order_by() are aliases of sort().

Parameters:
  • *cols – A column name as str or Column, or a list of columns to sort by.

  • ascending – A bool or a list of bool for sorting the DataFrame, where True sorts a column in ascending order and False sorts a column in descending order. If you specify a list of multiple sort orders, the length of the list must equal the number of columns.

Examples

>>> from pystarburst.functions import col
>>> df = session.create_dataframe([[1, 2], [3, 4], [1, 4]], schema=["A", "B"])
>>> df.sort(col("A"), col("B").asc()).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|1    |4    |
|3    |4    |
-------------

>>> df.sort(col("a"), ascending=False).show()
-------------
|"A"  |"B"  |
-------------
|3    |4    |
|1    |2    |
|1    |4    |
-------------

>>> # The values from the list overwrite the column ordering.
>>> df.sort(["a", col("b").desc()], ascending=[1, 1]).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|1    |4    |
|3    |4    |
-------------
stack(row_count: Column, *cols: ColumnOrName, ids_column_list: List[ColumnOrName] = []) DataFrame#

Separates col1, …, colk into row_count rows. Uses column names _1, _2, etc. by default unless specified otherwise.

Parameters:
  • row_count – number of rows to be separated

  • cols – Input elements to be separated

  • ids_column_list – (Keyword-only argument) The names of the columns in the source table or subquery that are used as identifiers.

Examples

>>> df = session.createDataFrame([(1, 2, 3)], ["a", "b", "c"])
>>> df.stack(lit(2), df.a, df.b, df.c).show()
---------------
|"_1"  |"_2"  |
---------------
|1     |2     |
|3     |NULL  |
---------------

>>> df.stack(lit(2), df.a, df.b, df.c, ids_column_list=["a", "b", "c"]).show()
---------------------------------
|"a"  |"b"  |"c"  |"_4"  |"_5"  |
---------------------------------
|1    |2    |3    |1     |2     |
|1    |2    |3    |3     |NULL  |
---------------------------------
property stat: DataFrameStatFunctions#
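
Returns a DataFrameStatFunctions object that provides statistic functions for this DataFrame. For example, sample_by() is accessed as df.stat.sample_by() in the sampleBy() examples above.
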
subtract(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows from the current DataFrame except for the rows that also appear in the other DataFrame. Duplicate rows are eliminated.

exceptAll(), minus() and subtract() are aliases of except_().

Parameters:

other – The DataFrame that contains the rows to exclude.

Examples

>>> df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[1, 2], [5, 6]], schema=["c", "d"])
>>> df1.subtract(df2).show()
-------------
|"A"  |"B"  |
-------------
|3    |4    |
-------------
summary(*statistics: str | List[str]) DataFrame#

Computes specified statistics for numeric and string columns. Available statistics are:

  • count

  • mean

  • stddev

  • min

  • max

  • arbitrary approximate percentiles specified as a percentage (e.g., 75%)

If no statistics are given, this function computes count, mean, stddev, min, approximate quartiles (percentiles at 25%, 50%, and 75%), and max.

Parameters:

statistics – The names of statistics whose basic statistics are computed.
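
Examples

A sketch of typical calls (output omitted because it depends on the data):

>>> df = session.create_dataframe([[1, "a"], [2, "b"], [3, "c"]], schema=["num", "str"])
>>> df.summary().show()                              # default statistics
>>> df.summary("count", "min", "max", "75%").show()  # a custom subset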

take(n: int | None = None, *, statement_properties: Dict[str, str] | None = None) Row | None | List[Row]#

Executes the query representing this DataFrame and returns the first n rows of the results.

Parameters:

n – The number of rows to return.

Returns:

A list of the first n Row objects if n is not None. If n is negative or larger than the number of rows in the result, returns all rows in the results. If n is None, it returns the first Row of results, or None if it does not exist.
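
Examples

A minimal sketch with hypothetical data; the comments show the typical shape of the results:

>>> df = session.create_dataframe([[1], [2], [3]], schema=["a"])
>>> df.take()   # the first Row, e.g. Row(A=1)
>>> df.take(2)  # a list of Rows, e.g. [Row(A=1), Row(A=2)]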

to(schema: StructType) DataFrame#

Returns a new DataFrame where each row is reconciled to match the specified schema.

Parameters:

schema (StructType) – the new schema

Examples

>>> from pystarburst.types import *
>>> df = session.createDataFrame([("a", 1)], ["i", "j"])
>>> df.schema
StructType([StructField('i', StringType(), True), StructField('j', LongType(), True)])

>>> schema = StructType([StructField("j", StringType()), StructField("i", StringType())])
>>> df2 = df.to(schema)
>>> df2.schema
StructType([StructField('j', StringType(), True), StructField('i', StringType(), True)])

>>> df2.show()
-------------
|"j"  |"i"  |
-------------
|1    |a    |
-------------
toDF(*names: str | Iterable[str]) DataFrame#

Creates a new DataFrame containing columns with the specified names.

The number of column names that you pass in must match the number of columns in the existing DataFrame.

toDF() is an alias of to_df().

Parameters:

names – list of new column names

Examples

>>> df1 = session.range(1, 10, 2).to_df("col1")
>>> df2 = session.range(1, 10, 2).to_df(["col1"])
toLocalIterator(*, statement_properties: Dict[str, str] | None = None) Iterator[Row]#

Executes the query representing this DataFrame and returns an iterator of Row objects that you can use to retrieve the results.

Unlike collect(), this method does not load all data into memory at once.

toLocalIterator() is an alias of to_local_iterator().

Examples

>>> df = session.table("prices")
>>> for row in df.to_local_iterator():
...     print(row)
Row(PRODUCT_ID='id1', AMOUNT=Decimal('10.00'))
Row(PRODUCT_ID='id2', AMOUNT=Decimal('20.00'))
to_df(*names: str | Iterable[str]) DataFrame#

Creates a new DataFrame containing columns with the specified names.

The number of column names that you pass in must match the number of columns in the existing DataFrame.

toDF() is an alias of to_df().

Parameters:

names – list of new column names

Examples

>>> df1 = session.range(1, 10, 2).to_df("col1")
>>> df2 = session.range(1, 10, 2).to_df(["col1"])
to_local_iterator(*, statement_properties: Dict[str, str] | None = None) Iterator[Row]#

Executes the query representing this DataFrame and returns an iterator of Row objects that you can use to retrieve the results.

Unlike collect(), this method does not load all data into memory at once.

toLocalIterator() is an alias of to_local_iterator().

Examples

>>> df = session.table("prices")
>>> for row in df.to_local_iterator():
...     print(row)
Row(PRODUCT_ID='id1', AMOUNT=Decimal('10.00'))
Row(PRODUCT_ID='id2', AMOUNT=Decimal('20.00'))
to_pandas()#

Returns a pandas DataFrame containing the results of this PyStarburst DataFrame.

Examples

>>> df = session.create_dataframe([[1, "a", 1.0], [2, "b", 2.0]]).to_df("id", "value1", "value2").to_pandas()
transform(func: Callable[[...], DataFrame], *args: Any, **kwargs: Any) DataFrame#

Returns a new DataFrame. Concise syntax for chaining custom transformations.

Parameters:
  • func (function) – a function that takes and returns a DataFrame.

  • *args – Positional arguments to pass to func.

  • **kwargs – Keyword arguments to pass to func.

Examples

>>> from pystarburst.functions import col
>>> df = session.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"])
>>> def cast_all_to_int(input_df):
...     return input_df.select([col(col_name).cast("int") for col_name in input_df.columns])
>>> def sort_columns_asc(input_df):
...     return input_df.select(*sorted(input_df.columns)).toDF("float", "int")
>>> df.transform(cast_all_to_int).transform(sort_columns_asc).show()
-------------------
|"float"  |"int"  |
-------------------
|1        |1      |
|2        |2      |
-------------------

>>> def add_n(input_df, n):
...     return input_df.select([(col(col_name) + n).alias(col_name)
...                             for col_name in input_df.columns])
>>> df.transform(add_n, 1).transform(add_n, n=10).toDF("int", "float").show()
-------------------
|"int"  |"float"  |
-------------------
|12     |12.0     |
|13     |13.0     |
-------------------
union(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (other), excluding any duplicate rows. Both input DataFrames must contain the same number of columns.

Parameters:

other – the other DataFrame that contains the rows to include.

Examples

>>> df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[0, 1], [3, 4]], schema=["c", "d"])
>>> df1.union(df2).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|3    |4    |
|0    |1    |
-------------
unionAll(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (other), including any duplicate rows. Both input DataFrames must contain the same number of columns.

unionAll() is an alias of union_all().

Parameters:

other – the other DataFrame that contains the rows to include.

Examples

>>> df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[0, 1], [3, 4]], schema=["c", "d"])
>>> df1.union_all(df2).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|3    |4    |
|0    |1    |
|3    |4    |
-------------
unionAllByName(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (other), including any duplicate rows.

This method matches the columns in the two DataFrames by their names, not by their positions. The columns in the other DataFrame are rearranged to match the order of columns in the current DataFrame.

unionAllByName() is an alias of union_all_by_name().

Parameters:

other – the other DataFrame that contains the rows to include.

Examples

>>> df1 = session.create_dataframe([[1, 2]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[2, 1]], schema=["b", "a"])
>>> df1.union_all_by_name(df2).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|1    |2    |
-------------
unionByName(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (other), excluding any duplicate rows.

This method matches the columns in the two DataFrames by their names, not by their positions. The columns in the other DataFrame are rearranged to match the order of columns in the current DataFrame.

unionByName() is an alias of union_by_name().

Parameters:

other – the other DataFrame that contains the rows to include.

Examples

>>> df1 = session.create_dataframe([[1, 2]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[2, 1]], schema=["b", "a"])
>>> df1.union_by_name(df2).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
-------------
union_all(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (other), including any duplicate rows. Both input DataFrames must contain the same number of columns.

unionAll() is an alias of union_all().

Parameters:

other – the other DataFrame that contains the rows to include.

Examples

>>> df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[0, 1], [3, 4]], schema=["c", "d"])
>>> df1.union_all(df2).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|3    |4    |
|0    |1    |
|3    |4    |
-------------
union_all_by_name(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (other), including any duplicate rows.

This method matches the columns in the two DataFrames by their names, not by their positions. The columns in the other DataFrame are rearranged to match the order of columns in the current DataFrame.

unionAllByName() is an alias of union_all_by_name().

Parameters:

other – the other DataFrame that contains the rows to include.

Examples

>>> df1 = session.create_dataframe([[1, 2]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[2, 1]], schema=["b", "a"])
>>> df1.union_all_by_name(df2).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|1    |2    |
-------------
union_by_name(other: DataFrame) DataFrame#

Returns a new DataFrame that contains all the rows in the current DataFrame and another DataFrame (other), excluding any duplicate rows.

This method matches the columns in the two DataFrames by their names, not by their positions. The columns in the other DataFrame are rearranged to match the order of columns in the current DataFrame.

unionByName() is an alias of union_by_name().

Parameters:

other – the other DataFrame that contains the rows to include.

Examples

>>> df1 = session.create_dataframe([[1, 2]], schema=["a", "b"])
>>> df2 = session.create_dataframe([[2, 1]], schema=["b", "a"])
>>> df1.union_by_name(df2).show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
-------------
unpivot(ids_column_list: ColumnOrName | Iterable[ColumnOrName], unpivot_column_list: List[ColumnOrName], name_column: str, value_column: str) DataFrame#

Unpivots a DataFrame from wide format to long format, optionally leaving identifier columns in place. Note that UNPIVOT is not exactly the reverse of PIVOT, because it cannot undo aggregations performed by PIVOT.

melt() is an alias of unpivot().

Parameters:
  • ids_column_list – The names of the columns in the source table or subquery that are used as identifiers.

  • unpivot_column_list – The names of the columns in the source table or subquery that are narrowed into a single pivot column. The column names populate name_column, and the column values populate value_column.

  • name_column – The name to assign to the generated column that will be populated with the names of the columns in the column list.

  • value_column – The name to assign to the generated column that will be populated with the values from the columns in the column list.

Examples

>>> df = session.create_dataframe([
...     (1, 'electronics', 100, 200),
...     (2, 'clothes', 100, 300)
... ], schema=["empid", "dept", "jan", "feb"])
>>> df = df.unpivot(["empid", "dept"], ["jan", "feb"], "month", "sales").sort("empid")
>>> df.show()
---------------------------------------------
|"empid"  |"dept"       |"month"  |"sales"  |
---------------------------------------------
|1        |electronics  |JAN      |100      |
|1        |electronics  |FEB      |200      |
|2        |clothes      |JAN      |100      |
|2        |clothes      |FEB      |300      |
---------------------------------------------
where(expr: ColumnOrSqlExpr) DataFrame#

Filters rows based on the specified conditional expression (similar to WHERE in SQL).

Parameters:

expr – a Column expression or SQL text.

where() is an alias of filter().

Examples

>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["A", "B"])
>>> df_filtered = df.filter((col("A") > 1) & (col("B") < 100))  # Parentheses are required around each condition combined with the & operator.
>>> # The following two result in the same SQL query:
>>> df.filter(col("a") > 1).collect()
[Row(A=3, B=4)]
>>> df.filter("a > 1").collect()  # use SQL expression
[Row(A=3, B=4)]
withColumn(col_name: str, col: Column | TableFunctionCall) DataFrame#

Returns a DataFrame with an additional column with the specified name col_name. The column is computed by using the specified expression col.

If a column with the same name already exists in the DataFrame, that column is replaced by the new column.

withColumn() is an alias of with_column().

Parameters:
  • col_name – The name of the column to add or replace.

  • col – The Column or table_function.TableFunctionCall with single column output to add or replace.

Examples

>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df.with_column("mean", (df["a"] + df["b"]) / 2).show()
------------------------
|"A"  |"B"  |"MEAN"    |
------------------------
|1    |2    |1.500000  |
|3    |4    |3.500000  |
------------------------
withColumnRenamed(existing: ColumnOrName, new: str) DataFrame#

Returns a DataFrame with the specified column existing renamed as new.

with_column_renamed() is an alias of rename().

Parameters:
  • existing – The old column instance or column name to be renamed.

  • new – The new column name.

Examples

>>> # This example renames the column `A` as `NEW_A` in the DataFrame.
>>> df = session.sql("select 1 as A, 2 as B")
>>> df_renamed = df.with_column_renamed(col("A"), "NEW_A")
>>> df_renamed.show()
-----------------
|"NEW_A"  |"B"  |
-----------------
|1        |2    |
-----------------
withColumns(col_names: List[str], values: List[Column]) DataFrame#

Returns a DataFrame with additional columns with the specified names col_names. The columns are computed by using the specified expressions values.

If columns with the same names already exist in the DataFrame, those columns are removed and the new columns are appended at the end.

withColumns() is an alias of with_columns().

Parameters:
  • col_names – A list of the names of the columns to add or replace.

  • values – A list of the Column objects to add or replace.

Examples

>>> df = session.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
>>> df.with_columns(['age2', 'age3'], [df.age + 2, df.age + 3]).show()
------------------------------------
|"age"  |"name"  |"age2"  |"age3"  |
------------------------------------
|2      |Alice   |4       |5       |
|5      |Bob     |7       |8       |
------------------------------------
withColumnsRenamed(cols_map: dict) DataFrame#

Returns a new DataFrame by renaming multiple columns.

withColumnsRenamed() is an alias of with_columns_renamed().

Parameters:

cols_map – a dict of existing column names and corresponding desired column names.

Examples

>>> # This example renames the columns `A` as `NEW_A` and `B` as `NEW_B`
>>> df = session.sql("select 1 as A, 2 as B")
>>> df_renamed = df.with_columns_renamed({"A": "NEW_A", "B": "NEW_B"})
>>> df_renamed.show()
---------------------
|"NEW_A"  |"NEW_B"  |
---------------------
|1        |2        |
---------------------
with_column(col_name: str, col: Column | TableFunctionCall) DataFrame#

Returns a DataFrame with an additional column with the specified name col_name. The column is computed by using the specified expression col.

If a column with the same name already exists in the DataFrame, that column is replaced by the new column.

withColumn() is an alias of with_column().

Parameters:
  • col_name – The name of the column to add or replace.

  • col – The Column or table_function.TableFunctionCall with single column output to add or replace.

Examples

>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df.with_column("mean", (df["a"] + df["b"]) / 2).show()
------------------------
|"A"  |"B"  |"MEAN"    |
------------------------
|1    |2    |1.500000  |
|3    |4    |3.500000  |
------------------------
with_column_renamed(existing: ColumnOrName, new: str) DataFrame#

Returns a DataFrame with the specified column existing renamed as new.

with_column_renamed() is an alias of rename().

Parameters:
  • existing – The old column instance or column name to be renamed.

  • new – The new column name.

Examples

>>> # This example renames the column `A` as `NEW_A` in the DataFrame.
>>> df = session.sql("select 1 as A, 2 as B")
>>> df_renamed = df.with_column_renamed(col("A"), "NEW_A")
>>> df_renamed.show()
-----------------
|"NEW_A"  |"B"  |
-----------------
|1        |2    |
-----------------
with_columns(col_names: List[str], values: List[Column]) DataFrame#

Returns a DataFrame with additional columns with the specified names col_names. The columns are computed by using the specified expressions values.

If columns with the same names already exist in the DataFrame, those columns are removed and the new columns are appended at the end.

withColumns() is an alias of with_columns().

Parameters:
  • col_names – A list of the names of the columns to add or replace.

  • values – A list of the Column objects to add or replace.

Examples

>>> df = session.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
>>> df.with_columns(['age2', 'age3'], [df.age + 2, df.age + 3]).show()
------------------------------------
|"age"  |"name"  |"age2"  |"age3"  |
------------------------------------
|2      |Alice   |4       |5       |
|5      |Bob     |7       |8       |
------------------------------------
with_columns_renamed(cols_map: dict) DataFrame#

Returns a new DataFrame by renaming multiple columns.

withColumnsRenamed() is an alias of with_columns_renamed().

Parameters:

cols_map – a dict of existing column names and corresponding desired column names.

Examples

>>> # This example renames the columns `A` as `NEW_A` and `B` as `NEW_B`
>>> df = session.sql("select 1 as A, 2 as B")
>>> df_renamed = df.with_columns_renamed({"A": "NEW_A", "B": "NEW_B"})
>>> df_renamed.show()
---------------------
|"NEW_A"  |"NEW_B"  |
---------------------
|1        |2        |
---------------------
property write: DataFrameWriter#

Returns a new DataFrameWriter object that you can use to write the data in the DataFrame to a Trino cluster.

Examples

>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> df.write.mode("overwrite").save_as_table("saved_table")
>>> session.table("saved_table").show()
-------------
|"A"  |"B"  |
-------------
|1    |2    |
|3    |4    |
-------------