You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
1.8 KiB
60 lines
1.8 KiB
#!/bin/bash
#
# Launches soloist_train.py on one (few-shot) data split.
#
# Required option:
#   -d <data-split-name>   name of the data split to train on
#
# Hyperparameter ranges noted for tuning:
#   lr                 1e-5 to 5e-5
#   mc_loss_efficient  0.1 to 1

usage="$(basename "$0") [-d <data-split-name>]
Argument -d takes (few-shot) data split names.
Possible valid names : 50-dpd|100-dpd|125-dpd|250-dpd"

# Start from a known-empty value so a `data_split` inherited from the caller's
# environment can never satisfy the mandatory -d check performed later.
data_split=""

#######################################
# Parses the command-line options of this script.
# Globals:   data_split (written), usage (read)
# Arguments: the script's positional parameters ("$@")
# Outputs:   on a usage error, a diagnostic plus the usage text on stderr
# Returns:   exits the shell with status 1 on a missing option argument or
#            an unknown option; otherwise returns normally
#######################################
parse_args() {
  local flag
  # Localised so the function can be invoked more than once in one shell.
  local OPTIND=1
  local OPTARG

  # The leading ':' selects getopts' silent mode: a missing argument yields
  # flag=':' and an unknown option yields flag='?', with the offending
  # option letter in OPTARG in both cases.
  while getopts ':d:' flag; do
    case "${flag}" in
      d)
        data_split=${OPTARG}
        ;;
      :)
        printf "missing argument for -%s\n" "${OPTARG}" >&2
        echo "${usage}" >&2
        exit 1
        ;;
      \?)
        # Previously unknown options were silently ignored.
        printf "unknown option -%s\n" "${OPTARG}" >&2
        echo "${usage}" >&2
        exit 1
        ;;
    esac
  done
}

parse_args "$@"

# Validate the mandatory -d argument.
# Diagnostics are written to stderr (the message used to go to stdout while
# the usage text went to stderr), so they never mix with captured stdout.
# The ':-' defaults keep the checks safe if the script is run with `set -u`.
if [[ -z "${data_split:-}" ]]; then
  echo "arguments -d must be provided" >&2
  echo "${usage:-}" >&2
  exit 1
fi

# Check whether the required environment vars are set.
# SAVED_MODELS_BASELINE is the root under which the experiment folder is
# created; PRE_TRAINED_SOLOIST is handed to the trainer as
# --model_name_or_path.
if [[ -z "${SAVED_MODELS_BASELINE:-}" || -z "${PRE_TRAINED_SOLOIST:-}" ]]; then
  echo "Required Environment variables not set. First run set_env.sh" >&2
  exit 1
fi

# Build a per-run output location:
#   ${SAVED_MODELS_BASELINE}/experiment-<timestamp>/<data_split>
# The ':?' guards abort the script if either component is empty, so the
# destructive cleanup below can never be pointed at a path rooted at '/'.
datetime_now=$(date +"%Y%m%dT%H%M%S")
experiment_folder=experiment-${datetime_now}/"${data_split:?data_split must be set (pass -d)}"
SAVE_DIR="${SAVED_MODELS_BASELINE:?SAVED_MODELS_BASELINE must be set}"/"${experiment_folder}"

echo "Trained Model checkpoints are saved in ${SAVE_DIR}"
echo "Cleaning the experiment directory (${experiment_folder}) to make sure it's empty!"

# Remove any stale directory FIRST, then (re)create it. The previous order
# (mkdir followed by rm -rf) deleted the directory that had just been
# created, leaving no output directory at all instead of an empty one.
rm -rf -- "${SAVE_DIR:?}"
if ! mkdir -p -- "${SAVE_DIR}"; then
  echo "Could not create ${SAVE_DIR}" >&2
  exit 1
fi

# Arguments for soloist_train.py, kept in an array so every flag sits on its
# own line without backslash continuations and each expansion stays one word.
# Data paths are relative, so the script is expected to be started from the
# directory that contains soloist_train.py.
train_args=(
  --output_dir="${SAVE_DIR}"
  --model_type=gpt2
  --model_name_or_path="${PRE_TRAINED_SOLOIST}"
  --do_train
  --train_data_file="../data/baseline/${data_split}/train.soloist.json"
  --eval_data_file=../data/baseline/valid/valid.soloist.json
  --add_special_action_tokens=../data/resource/special_tokens.txt
  --per_gpu_train_batch_size 1
  --num_train_epochs 50
  --learning_rate 5e-5
  --overwrite_cache
  --save_steps 5000
  --max_seq 100
  --overwrite_output_dir
  --max_turn 15
  --num_candidates 1
  --mc_loss_efficient 0.33
  --add_belief_prediction
)

# Last command of the script: its exit status becomes the script's status.
# (The stray trailing '|' after the final flag has been dropped; it left the
# pipeline without a right-hand command.)
python soloist_train.py "${train_args[@]}"