# Prompt-based methods for Dialog State Tracking
Repository for my master thesis at the University of Stuttgart (IMS).
Refer to this thesis proposal document for a detailed explanation of the thesis experiments.
## Dataset
The MultiWOZ 2.1 dataset is used for training and evaluating the baseline and prompt-based methods. MultiWOZ is a fully-labeled collection of human-human written conversations spanning multiple domains and topics. Only single-domain dialogues are used in this setup for training and testing. Each dialogue contains multiple turns and may also contain a booking sub-domain. Five domains (Hotel, Train, Restaurant, Attraction, Taxi) are used in the experiments; the other two domains are excluded because they appear only in the training set. Under few-shot settings, only a portion of the training data is used, to measure performance on the DST task in a low-resource scenario; dialogues are randomly picked for each domain. The table below shows statistics of the dataset and the data splits used in the few-shot experiments.
| Data Split | # Dialogues | # Total Turns |
|---|---|---|
| 50-dpd | 250 | 1114 |
| 100-dpd | 500 | 2292 |
| 125-dpd | 625 | 2831 |
| 250-dpd | 1125 | 5187 |
| valid | 190 | 900 |
| test | 193 | 894 |
In the table above, "dpd" stands for "dialogues per domain"; for example, 50-dpd means 50 dialogues per domain.
All the training and testing data can be found under the `data/baseline/` folder.
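For illustration, a k-dpd split could be drawn along the lines of the sketch below. This is a minimal sketch, assuming dialogues are represented as dicts with a "domain" key; the `sample_dpd_split` helper is hypothetical and not the repository's actual preprocessing code.
```python
import random
from collections import defaultdict

def sample_dpd_split(dialogues, dialogues_per_domain, seed=42):
    """Randomly pick up to `dialogues_per_domain` dialogues per domain.

    `dialogues` is assumed to be a list of dicts with a "domain" key --
    a simplified, hypothetical stand-in for the preprocessed MultiWOZ data.
    """
    rng = random.Random(seed)
    by_domain = defaultdict(list)
    for dialogue in dialogues:
        by_domain[dialogue["domain"]].append(dialogue)
    split = []
    for domain in sorted(by_domain):
        pool = by_domain[domain]
        # A domain may have fewer single-domain dialogues than requested,
        # hence the min() guard.
        split.extend(rng.sample(pool, min(dialogues_per_domain, len(pool))))
    return split
```
The min() guard reflects a detail visible in the table: the 250-dpd split has 1125 dialogues rather than 250 × 5 = 1250, presumably because at least one domain has fewer than 250 single-domain dialogues.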
## Environment Setup
Python 3.6 is required for training the baseline model. `conda` is used for creating the environments.
### Create conda environment (for baseline model)
Create an environment for baseline training with the required Python version (3.6):
```bash
conda create -n <baseline-env-name> python=3.6
```
### Create conda environment (for prompt learning)
Create an environment for the prompt-based methods:
```bash
# TODO
```
### Activate the conda environment
To activate the conda environment, run:
```bash
conda activate <env-name>
```
### Deactivate the conda environment
To deactivate the conda environment (only after running all the experiments), run:
```bash
conda deactivate
```
### Download and extract SOLOIST pre-trained model
Download and extract the pre-trained model; it is used for fine-tuning both the baseline and the prompt-based methods. For more details about the pre-trained SOLOIST model, refer to the GitHub repo.
Download the archive, replacing /path/to/folder in the command below with a folder of your choice.
```bash
wget https://bapengstorage.blob.core.windows.net/soloist/gtg_pretrained.tar.gz \
  -P /path/to/folder/
```
Extract the downloaded pre-trained model archive.
```bash
tar -xvf /path/to/folder/gtg_pretrained.tar.gz
```
### Clone the repository
Clone the repository for the source code:
```bash
git clone https://git.pavanmandava.com/pavan/master-thesis.git
```
Pull the changes from remote (if local is behind the remote):
```bash
git pull
```
Change into the repository directory:
```bash
cd master-thesis
```
### Set environment variables
The next step is to set environment variables that contain the paths to the pre-trained model, saved models, and output directories.
Edit the set_env.sh file and set the paths for the following variables (nano or vim can be used):
- `PRE_TRAINED_SOLOIST` - path to the extracted pre-trained SOLOIST model
- `SAVED_MODELS_BASELINE` - path for saving the trained models at checkpoints
- `OUTPUTS_DIR_BASELINE` - path for storing the outputs of belief state predictions
```bash
nano set_env.sh
```
Save the edited file and source it:
```bash
source set_env.sh
```
Run the command below to unset the environment variables:
```bash
sh unset_env.sh
```
## Baseline Experiments
SOLOIST (Peng et al., 2021), the baseline model for this thesis, is a task-oriented dialog system that uses transfer learning and machine teaching to build task bots at scale. SOLOIST follows the pre-train, fine-tune paradigm for building end-to-end dialog systems with GPT-2, a transformer-based auto-regressive language model. In the pre-training stage, SOLOIST is initialized with the 12-layer GPT-2 (117M parameters) and further trained on two task-oriented dialog corpora to solve the belief state prediction task. In the fine-tuning stage, the pre-trained SOLOIST is fine-tuned on the MultiWOZ 2.1 dataset to perform the belief state prediction task.
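To make the fine-tuning setup concrete: an auto-regressive model like SOLOIST learns to generate the belief state as a continuation of the serialized dialog history. The sketch below illustrates this idea only; the delimiters, special tokens, and slot names are assumptions, not SOLOIST's actual serialization format.
```python
def build_training_sequence(history, belief_state):
    """Serialize one dialogue turn into a single training string:
    the dialog history is the prefix, the belief state the target
    continuation for the auto-regressive LM."""
    context = " ".join(f"user: {u} system: {s}" for u, s in history)
    belief = ", ".join(f"{slot} = {value}"
                       for slot, value in sorted(belief_state.items()))
    # "=>" and "<EOB>" are illustrative delimiters, not SOLOIST's tokens.
    return f"{context} => belief: {belief} <EOB>"

# Example turn from a single-domain restaurant dialogue:
history = [("i need a cheap italian restaurant", "sure, any area preference?")]
belief_state = {"restaurant-food": "italian", "restaurant-pricerange": "cheap"}
print(build_training_sequence(history, belief_state))
# user: i need a cheap italian restaurant system: sure, any area preference?
#   => belief: restaurant-food = italian, restaurant-pricerange = cheap <EOB>
```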
### Install the requirements
After following the environment setup steps in the previous section, install the Python packages required for baseline model training.
Change directory to baseline and install the requirements. Make sure the correct baseline conda environment is activated before installing:
```bash
cd baseline
pip install -r requirements.txt
```
### Train the baseline model
Train a separate model for each data split. Edit the `train_baseline.sh` file to modify the training hyperparameters (learning rate, epochs). Use `CUDA_VISIBLE_DEVICES` to specify a CUDA device (GPU) for training the model.
```bash
sh train_baseline.sh -d <data-split-name>
```
Pass the data split name to the `-d` flag. Possible values are: `50-dpd`, `100-dpd`, `125-dpd`, `250-dpd`.
Example training command: `sh train_baseline.sh -d 50-dpd`
### Belief State Prediction
Choose a checkpoint of the saved baseline model to generate belief state predictions.
Set the `MODEL_CHECKPOINT` environment variable to the path of the chosen model checkpoint. The path should start from the `experiment-{datetime}` folder:
```bash
export MODEL_CHECKPOINT=<experiment-folder>/<data-split-name>/<checkpoint-folder>
```
Example: `export MODEL_CHECKPOINT=experiment-20220831/100-dpd/checkpoint-90000`
Generate belief states by running the decode script:
```bash
sh decode_baseline.sh
```
The generated predictions are saved under the `OUTPUTS_DIR_BASELINE` folder. Some of the generated belief state predictions are uploaded to this repository and can be found under the `outputs` folder.
## Baseline Evaluation
The standard Joint Goal Accuracy (JGA) metric is used to evaluate the belief state predictions. At every turn, the predicted belief state is compared to the ground-truth state; the prediction is considered correct only if all predicted slot-value pairs match the ground truth, i.e. both slots and values must match.
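For reference, a minimal JGA computation might look like the sketch below. This is an illustration of the metric, not the code in `evaluate.py`; belief states are assumed to be dicts mapping slot names to values.
```python
def joint_goal_accuracy(predictions, ground_truths):
    """Fraction of turns whose predicted belief state matches the
    ground-truth state exactly (every slot and every value)."""
    correct = sum(pred == gold for pred, gold in zip(predictions, ground_truths))
    return correct / len(ground_truths)

# A single mismatched value makes the whole turn incorrect:
gold = [{"hotel-area": "north", "hotel-stars": "4"}]
pred = [{"hotel-area": "north", "hotel-stars": "3"}]
print(joint_goal_accuracy(pred, gold))  # 0.0
```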
Edit `evaluate.py` to set the predictions output file before running the evaluation:
```bash
python evaluate.py
```
### Preliminary results of baseline evaluation
| data-split | JGA (%) |
|---|---|
| 50-dpd | 28.64 |
| 100-dpd | 33.11 |
| 125-dpd | 35.79 |
| 250-dpd | 40.38 |
Note: The above preliminary results may change based on further experiments.