This repository orchestrates an end-to-end generative-AI workflow (pipeline) that automates Chain-of-Thought (CoT) prompting for tabular binary classification by generating intermediate reasoning from global and local feature importances.
Note: The XAI attributes are obtained by training and tuning a tree-based explainable model and then extracting its feature importances (using sklearn's `feature_importances_` attribute) and SHAP values (using `TreeExplainer`).
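For intuition, here is a minimal sketch of the global-importance extraction step on a toy dataset (all feature values below are made up for illustration). The local, per-row attributions would come from `shap.TreeExplainer` applied to the same fitted model:

```python
from sklearn.ensemble import RandomForestClassifier

# Toy tabular binary classification data (hypothetical values).
X = [[22, 7.25, 0], [38, 71.3, 1], [26, 7.9, 1], [35, 53.1, 1],
     [28, 8.05, 0], [54, 51.9, 0], [2, 21.1, 0], [27, 11.1, 1]]
y = [0, 1, 1, 1, 0, 0, 0, 1]
feature_names = ["age", "fare", "sex"]

model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# Global importances: one non-negative weight per feature, normalized to sum to 1.
importances = dict(zip(feature_names, model.feature_importances_))

# Local attributions would come from shap.TreeExplainer(model).shap_values(X),
# giving a per-row, per-feature contribution to each prediction.
```

These per-feature weights (global) and per-row contributions (local) are what the pipeline verbalizes into intermediate reasoning for the CoT prompt.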
- Clone this repository.

  ```bash
  git clone https://github.com/Gaurav0502/xai-guided-cot.git
  ```
- Create a virtual environment to isolate all the dependencies from any global dependencies on your local system.

  ```bash
  python3 -m venv venv
  ```
- Activate the virtual environment.

  ```bash
  source venv/bin/activate  # macOS/Linux
  venv\Scripts\activate     # Windows
  ```
- Install all packages in the `requirements.txt` file.

  ```bash
  pip3 install -r requirements.txt
  ```
- You may have to install `toon_format` from `toon-python` separately (though it is also included in `requirements.txt`).

  ```bash
  pip3 install git+https://github.com/toon-format/toon-python.git
  ```
- Run `setup.sh` to ensure the required directory structure is created.

  ```bash
  chmod +x setup.sh
  ./setup.sh
  ```
- After setting up the dependencies, install the Google Cloud SDK and authenticate your environment to access Google Cloud Storage (GCS) and Vertex AI.

  ```bash
  gcloud auth application-default login
  ```
- Get API keys from Together AI (`TOGETHER_API_KEY`) and the Anthropic Developer Platform (`CLAUDE_API_KEY`). Add them to the `.env` file created by `setup.sh`.
- You may require a `WANDB_API_KEY` and `WANDB_PROJECT_NAME` since the tree-based `ExplanableModel` is trained and tuned using `wandb sweep`.
- Now, set up Google Cloud Platform (GCP) with the following steps:
  - Create a project in GCP and record the `PROJECT_ID`.
  - Ensure billing is enabled for the required APIs (Vertex AI and Google Cloud Storage).
  - Inside the project, create a GCS bucket to store the batch inference job JSONL files (inputs and outputs) for Vertex AI.
  - Record your `BUCKET_NAME` and `LOCATION`.
  - Ensure that the `LOCATION` you choose offers the model you want to use, because the code in this repository requires the `LOCATION` for the GCS bucket and for Vertex AI batch inference to be the same.
- Finally, your `.env` file must look as follows:

  ```bash
  # wandb config
  WANDB_API_KEY=<YOUR-API-KEY>
  WANDB_PROJECT_NAME=<YOUR-WANDB-PROJECT-NAME>

  # gcp config
  PROJECT_ID=<YOUR-PROJECT-ID>
  BUCKET_NAME=<YOUR-GCP-BUCKET-NAME>
  LOCATION=<YOUR-LOCATION> # same for storage and batch inference

  # together ai config
  TOGETHER_API_KEY=<YOUR-TOGETHERAI-API-KEY>

  # anthropic config
  CLAUDE_API_KEY=<YOUR-CLAUDE-API-KEY>
  ```
Notes:

- Only the API keys are secrets. The other variables are kept inside `.env` because they define the environment for the different SDKs.
- If you set up `WANDB_API_KEY` from the CLI, you can omit that variable. However, `WANDB_PROJECT_NAME` is required.
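For reference, this is a minimal stdlib-only sketch of how a `.env` file like the one above can be read into environment variables. The repository itself may rely on a library such as `python-dotenv`; `load_env` here is a hypothetical helper, not the repository's actual code:

```python
import os
from pathlib import Path


def load_env(path: str = ".env") -> None:
    """Minimal .env loader: KEY=VALUE lines, inline '#' comments stripped."""
    for line in Path(path).read_text().splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and whitespace
        if "=" in line:
            key, value = line.split("=", 1)
            # Do not overwrite variables already set in the shell.
            os.environ.setdefault(key.strip(), value.strip())
```

Variables exported in the shell take precedence over the file, which matches the note above about setting `WANDB_API_KEY` from the CLI.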
- If you wish to check whether your environment setup is complete, you can run the unit tests inside `test/` using `pytest`.

  ```bash
  pytest -v
  ```
Notes:

- All tests are expected to pass if the environment is correctly configured.
- The unit test for overall pipeline execution (`tests/test_pipeline.py`) hits the real APIs with a small dataset, so you will be billed for running it.
- The repository has a `tutorial.ipynb` notebook that explains the environment setup procedure in much more detail.
- It also explains how to use the overall pipeline and its individual components. If you want to run the pipeline on a new dataset, follow this notebook.
- We tested the pipeline functionality and prediction performance on four datasets and attempted to answer three research questions:
  - Can large language models (LLMs) effectively generate natural language reasoning from numerical XAI attributes? (RQ1)
  - Does providing this natural language reasoning improve the performance of standard prompt engineering techniques on tabular binary classification? (RQ2)
  - How does XAI-Guided-CoT perform in comparison to the tree-based explainable model? (RQ3)
- Additionally, we performed two ablation studies to further diagnose the improvement in prediction performance:
  - Does the improvement of XAI-Guided-CoT over the zero-shot baseline come from CoT alone? (AB-1)
  - Is the performance driven by the semantic context provided by the dataset metadata (dataset name, column names, and class names), or by the XAI attributes? (AB-2)
- To run these experiments, use the `main.py` file in the root of this repository. Preferably run it from the terminal, because batch inference jobs can take a significant amount of time to complete.

  ```bash
  python3 main.py --dataset <dataset-name>          # without masking dataset metadata
  python3 main.py --dataset <dataset-name> --masked # with masking dataset metadata
  ```

  Note: The dataset name must be one of `titanic`, `loan`, `diabetes`, or `mushroom`. Any other dataset name will raise a `ValueError`.
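The guard behind that error is straightforward; here is an illustrative sketch of the kind of validation involved (the dataset names match the list above, but `validate_dataset` is a hypothetical helper, not necessarily the repository's actual function):

```python
SUPPORTED_DATASETS = {"titanic", "loan", "diabetes", "mushroom"}


def validate_dataset(name: str) -> str:
    """Raise ValueError for any dataset outside the supported list."""
    if name not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of {sorted(SUPPORTED_DATASETS)}"
        )
    return name
```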
- The `experiments/` directory contains `metrics.json`, `results.ipynb`, and `stat_sig.py`. The `metrics.json` file holds the classification metrics for our experiments along with the objective judge evaluation. The `results.ipynb` notebook discusses our results through tabulations and visualizations. The `stat_sig.py` script performs McNemar's test for the statistical significance of the difference in performance between two models.
- All batch inference jobs were executed from the terminal since they can take a long time to run; a notebook kernel is not the ideal choice for them.
Note: The full experimental setup is documented inside `experiments/results.ipynb`. Even with a fixed seed value, it is very difficult for the LLM to produce identical reasoning; however, the overall trend of the results is reproducible.
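For intuition about the significance testing, an exact McNemar's test can be computed from the discordant predictions of two models. The sketch below is a self-contained stdlib implementation; the actual `stat_sig.py` may instead use a library implementation such as the one in `statsmodels`:

```python
from math import comb


def mcnemar_exact(y_true, pred_a, pred_b):
    """Exact (binomial) McNemar's test on the discordant pairs of two classifiers."""
    a_correct = [p == y for p, y in zip(pred_a, y_true)]
    b_correct = [p == y for p, y in zip(pred_b, y_true)]
    # Discordant counts: rows where exactly one of the two models is correct.
    only_a = sum(1 for a, b in zip(a_correct, b_correct) if a and not b)
    only_b = sum(1 for a, b in zip(a_correct, b_correct) if not a and b)
    n = only_a + only_b
    if n == 0:
        return 1.0  # no discordant pairs: no evidence of a difference
    k = min(only_a, only_b)
    # Two-sided exact p-value under H0: discordant pairs split 50/50.
    p = 2 * sum(comb(n, i) for i in range(k + 1)) * 0.5 ** n
    return min(p, 1.0)
```

Only the rows where the two models disagree in correctness carry information, which is why the test depends on the discordant counts alone.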
- Large language models can effectively generate natural language reasoning from numerical XAI attributes. Moreover, this reasoning is consistent with the numerical values and has an acceptable grammatical structure. (RQ1)
- XAI-Guided-CoT shows promising results in comparison to standard prompt engineering techniques, but a more exhaustive evaluation is needed to provide stronger statistical evidence for the lift in performance. (RQ2)
- The knowledge learned by the tree-based explainable model can be transferred to the LLM through natural language reasoning, making the performance more deterministic than standard LLM behaviour. (RQ3)
- From the results, it is evident that the improvements are not due to CoT alone but are driven by the XAI attributes. Moreover, standard CoT applies generic domain patterns for the prediction, while XAI-Guided-CoT incorporates dataset-specific patterns as well. (AB-1)
- Being text-based models, the standard approaches struggle when the metadata is masked since they lack context. Interestingly, in some cases the performance is still decent, but that is an artifact of binary classification, where there are only two possible options and hence less room for misclassification. (AB-2)
- [Vertex AI Batch Inference](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini)
- [Together AI Batch Job](https://docs.together.ai/docs/batch-inference)
- [Anthropic API Batch Processing](https://platform.claude.com/docs/en/build-with-claude/batch-processing)
