Start a Google Cloud ML Engine prediction job.


Last Updated: Feb. 25, 2023

Access Instructions

Install the Google provider package into your Airflow environment.

Import the module into your DAG file and instantiate it with your desired params.


job_id (Required): A unique ID for the prediction job on Google Cloud ML Engine. (templated)
data_format (Required): The format of the input data. It defaults to ‘DATA_FORMAT_UNSPECIFIED’ if it is not provided or is not one of “TEXT”, “TF_RECORD”, or “TF_RECORD_GZIP”.
input_paths (Required): A list of GCS paths to the input data for batch prediction. The wildcard operator * is accepted, but only at the end of a path. (templated)
output_path (Required): The GCS path where the prediction results are written. (templated)
region (Required): The Google Compute Engine region to run the prediction job in. (templated)
model_name: The Google Cloud ML Engine model to use for prediction. If version_name is not provided, the model's default version is used. Must not be None if version_name is provided, and must be None if uri is provided. (templated)
version_name: The Google Cloud ML Engine model version to use for prediction. Must be None if uri is provided. (templated)
uri: The GCS path of the saved model to use for prediction. It should point to a TensorFlow SavedModel. Must be None if model_name is provided. (templated)
max_worker_count: The maximum number of workers to use for parallel processing. Defaults to 10 if not specified. Should be a string representing the worker count (“10” instead of 10, “50” instead of 50, etc.).
runtime_version: The Google Cloud ML Engine runtime version to use for batch prediction.
signature_name: The name of the signature defined in the SavedModel to use for this job.
project_id: The ID of the Google Cloud project where the prediction job is submitted. If set to None or missing, the default project_id from the Google Cloud connection is used. (templated)
gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
delegate_to: The account to impersonate using domain-wide delegation of authority, if any. For this to work, the service account making the request must have domain-wide delegation enabled.
labels: A dictionary of labels to attach to the prediction job.
impersonation_chain: An optional service account to impersonate using short-term credentials, or a chained list of accounts required to get the access_token of the last account in the list, which is impersonated in the request. If set as a string, the account must grant the originating account the Service Account Token Creator IAM role. If set as a sequence, each identity in the list must grant the Service Account Token Creator IAM role to the directly preceding identity, with the first account in the list granting this role to the originating account. (templated)
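Several of the parameters above carry rules that are easy to get wrong at DAG-authoring time: the data_format fallback, the wildcard-only-at-the-end rule for input_paths, and max_worker_count being passed as a string. The sketch below collects the required arguments and checks those rules before the operator is built. The helper names (normalize_data_format, check_input_paths, prediction_job_kwargs) are hypothetical and not part of Airflow, and reading "only at the end" as "only as the last character of the path" is an assumption.

```python
from typing import Dict, Optional, Sequence

# The three formats named in the data_format description.
SUPPORTED_DATA_FORMATS = {"TEXT", "TF_RECORD", "TF_RECORD_GZIP"}


def normalize_data_format(data_format: Optional[str]) -> str:
    """Mirror the documented fallback: missing or unknown -> DATA_FORMAT_UNSPECIFIED."""
    if data_format in SUPPORTED_DATA_FORMATS:
        return data_format
    return "DATA_FORMAT_UNSPECIFIED"


def check_input_paths(input_paths: Sequence[str]) -> None:
    """Raise ValueError unless each path starts with gs:// and '*' is only the last character."""
    for path in input_paths:
        if not path.startswith("gs://"):
            raise ValueError(f"not a GCS path: {path!r}")
        # Assumption: "only at the end" means nowhere except the final character.
        if "*" in path[:-1]:
            raise ValueError(f"wildcard '*' is only allowed at the end: {path!r}")


def prediction_job_kwargs(
    job_id: str,
    data_format: str,
    input_paths: Sequence[str],
    output_path: str,
    region: str,
    max_worker_count: Optional[int] = None,
) -> Dict[str, object]:
    """Collect the required operator arguments (hypothetical helper, not part of Airflow)."""
    check_input_paths(input_paths)
    kwargs: Dict[str, object] = {
        "job_id": job_id,
        "data_format": normalize_data_format(data_format),
        "input_paths": list(input_paths),
        "output_path": output_path,
        "region": region,
    }
    if max_worker_count is not None:
        # The parameter reference asks for a string, e.g. "50" instead of 50.
        kwargs["max_worker_count"] = str(max_worker_count)
    return kwargs
```

The resulting dict could then be unpacked into the operator alongside a model origin, for example `MLEngineStartBatchPredictionJobOperator(task_id="batch_predict", model_name="my_model", **prediction_job_kwargs(...))`, where task_id and model_name are illustrative values.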



See also

For more information on how to use this operator, take a look at the guide: Making predictions

NOTE: To specify the model origin, use exactly one of the three options below:

  1. Populate only the uri field, with a GCS location that points to a TensorFlow SavedModel directory.

  2. Populate only the model_name field, which refers to an existing model; the default version of that model is used.

  3. Populate both the model_name and version_name fields, which refer to a specific version of a specific model.
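The "exactly one of three options" rule can be checked before the task runs. The following is a minimal sketch with a hypothetical helper, model_origin_option, that returns which numbered option a given combination of uri, model_name and version_name corresponds to and raises ValueError for combinations the note rules out. It mirrors the documented rule; it is not the operator's own validation code.

```python
from typing import Optional


def model_origin_option(
    uri: Optional[str] = None,
    model_name: Optional[str] = None,
    version_name: Optional[str] = None,
) -> int:
    """Return 1, 2 or 3 for the documented option in use; raise ValueError otherwise."""
    if uri is not None:
        if model_name is not None or version_name is not None:
            raise ValueError("uri cannot be combined with model_name or version_name")
        return 1  # option 1: SavedModel directory on GCS
    if model_name is None:
        if version_name is not None:
            raise ValueError("version_name requires model_name")
        raise ValueError("provide either uri or model_name")
    if version_name is None:
        return 2  # option 2: default version of model_name
    return 3  # option 3: explicit model_name + version_name
```

Raising instead of silently picking one source keeps an ambiguous configuration from running a batch job against an unintended model.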

In options 2 and 3, model_name and version_name should contain only the minimal identifier. For instance, if the desired model version is projects/my_project/models/my_model/versions/my_version, set model_name='my_model' and version_name='my_version'.
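If a DAG receives the full resource name, such as projects/my_project/models/my_model/versions/my_version, the minimal identifiers can be split out programmatically. This is a sketch using a hypothetical helper, minimal_identifiers, which assumes the name always has exactly the projects/…/models/…/versions/… shape and that no segment contains a slash.

```python
import re
from typing import Dict

# Full model version resource name: projects/<project>/models/<model>/versions/<version>
_VERSION_RESOURCE = re.compile(r"^projects/([^/]+)/models/([^/]+)/versions/([^/]+)$")


def minimal_identifiers(resource_name: str) -> Dict[str, str]:
    """Split a full version resource name into the operator's minimal identifiers."""
    match = _VERSION_RESOURCE.match(resource_name)
    if match is None:
        raise ValueError(f"not a model version resource name: {resource_name!r}")
    project, model, version = match.groups()
    return {"project_id": project, "model_name": model, "version_name": version}
```

The keys match the operator's project_id, model_name and version_name parameters, so the result could be unpacked directly into the operator call alongside the required arguments.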

See the Google Cloud ML Engine REST API reference for jobs for further documentation on the parameters.
