You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
63 lines
1.9 KiB
63 lines
1.9 KiB
#!/bin/bash
#
# Train a SOLOIST-based model with prompt learning on a few-shot data split.
#
# Usage: <script> -d <data-split-name>
#
# Required environment (run set_env.sh first):
#   SAVED_MODELS_PROMPT  - root directory under which trained models are saved
#   PRE_TRAINED_SOLOIST  - path to the pre-trained SOLOIST model

usage="$(basename "$0") [-d <data-split-name>]
Argument -d takes (few-shot) data split names.
Possible data-split names : 50-dpd|100-dpd|125-dpd|250-dpd"

# Start from an empty value so a variable inherited from the caller's
# environment cannot stand in for the mandatory -d argument.
data_split=

# The leading ':' in the optstring puts getopts in silent mode: a missing
# option argument is reported via the ':' arm and an unknown option via the
# '?' arm, instead of getopts printing its own message.
while getopts :d: flag; do
  case "${flag}" in
    d) data_split=${OPTARG} ;;
    :) printf "missing argument for -%s\n" "$OPTARG" >&2; echo "$usage" >&2; exit 1 ;;
    \?) printf "illegal option: -%s\n" "$OPTARG" >&2; echo "$usage" >&2; exit 1 ;;
  esac
done

# -d is mandatory: abort with the usage text if it was not supplied (or empty).
if [[ -z "$data_split" ]]; then
  echo "arguments -d must be provided" >&2
  echo "$usage" >&2
  exit 1
fi

# Check whether the required environment vars are set (set_env.sh is expected
# to provide them). Every missing variable is reported before exiting, so the
# user sees all problems in one run; diagnostics go to stderr.
missing_env=0

if [[ -z "${SAVED_MODELS_PROMPT}" ]]; then
  echo "Must set SAVED_MODELS_PROMPT in environment, run set_env.sh first!" >&2
  missing_env=1
fi

if [[ -z "${PRE_TRAINED_SOLOIST}" ]]; then
  echo "Pre-trained SOLOIST Model path must be provided!" >&2
  echo "Must set PRE_TRAINED_SOLOIST in environment, run set_env.sh first!" >&2
  missing_env=1
fi

if (( missing_env )); then
  exit 1
fi

# Locate the training file for the selected split. The path is relative, so it
# is resolved against the current working directory; the probed path is
# included in the error to make a wrong cwd easy to spot.
TRAIN_DATA_FILE=../data/prompt-learning/"${data_split}"/train.soloist.json

if [[ ! -f "$TRAIN_DATA_FILE" ]]; then
  echo "Training File with set ${data_split} does not exist (looked for ${TRAIN_DATA_FILE})." >&2
  exit 1
fi

echo "Selected Training set :: ${data_split}/train.soloist.json"

# Compose a timestamped experiment folder path for this run's saved models.
# Only the path is computed here; the directory itself is not created by this
# script (NOTE(review): presumably prompt_train.py creates --save_model_dir — confirm).
datetime_now=$(date +"%Y%m%dT%H%M%S")
experiment_folder="${data_split}/experiment-${datetime_now}"
SAVE_DIR="${SAVED_MODELS_PROMPT}/${experiment_folder}"

echo "Trained Models (epochs) will be saved in ${SAVE_DIR}"

# Validation file used for every training split (its path does not depend on
# -d). Check it up front so a missing file fails fast, before training starts.
VALID_DATA_FILE=../data/prompt-learning/valid/valid.soloist.json
if [[ ! -f "$VALID_DATA_FILE" ]]; then
  echo "Validation File ${VALID_DATA_FILE} does not exist." >&2
  exit 1
fi

# When training with a prompt ensemble, prefer a larger number of epochs.
epochs=5

# The script's exit status is that of the training run (last command).
python prompt_train.py \
  --save_model_dir="${SAVE_DIR}" \
  --pretrained_model_path="${PRE_TRAINED_SOLOIST}" \
  --train_data_file="${TRAIN_DATA_FILE}" \
  --validation_file="${VALID_DATA_FILE}" \
  --num_epochs "${epochs}" \
  --learning_rate 5e-5 \
  --with_prompt_ensemble \
  --with_inverse_prompt \
  --inverse_prompt_weight 0.3