You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

37 lines
1.1 KiB

#!/bin/bash
#
# Fine-tune the pre-trained SOLOIST GPT-2 model on one baseline data split.
#
# Usage: <script> -d DATA_SPLIT
#   -d DATA_SPLIT  sub-directory of ../data/baseline containing
#                  train.soloist.json (validation always uses
#                  ../data/baseline/valid/valid.soloist.json)
#   -h             print usage and exit
#
# Required environment variables (the error message points at
# set_env_var.sh for setting them):
#   SAVED_MODELS_BASELINE  root under which experiment-YYYYMMDD/DATA_SPLIT
#                          is used as --output_dir
#   PRE_TRAINED_SOLOIST    value passed to --model_name_or_path
#
# soloist_train.py and all data paths are relative, so they are resolved
# against the current working directory.
#
# Hyperparameter ranges noted by the author (presumably for tuning):
#   learning rate       1e-5 to 5e-5
#   mc_loss_efficient   0.1 to 1
set -euo pipefail

# Print usage to stderr and exit with the given status (default 2).
usage() {
  printf 'Usage: %s -d DATA_SPLIT\n' "${0##*/}" >&2
  exit "${1:-2}"
}

# Print an error message to stderr and exit 1.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

data_split=
while getopts ':d:h' flag; do
  case "${flag}" in
    d) data_split=${OPTARG} ;;
    h) usage 0 ;;
    :) die "option -${OPTARG} requires an argument" ;;
    \?) die "unknown option -${OPTARG}" ;;
  esac
done
shift $((OPTIND - 1))

# Without -d the train path would collapse to ../data/baseline//train.soloist.json.
[[ -n "${data_split}" ]] || usage

# Check whether the required environment vars are set. Abort rather than
# continue: with SAVED_MODELS_BASELINE empty, --output_dir would resolve to
# /experiment-..., a directory at the filesystem root.
if [[ -z "${SAVED_MODELS_BASELINE:-}" ]] || [[ -z "${PRE_TRAINED_SOLOIST:-}" ]]; then
  die "Required Environment variables not set. First run set_env_var.sh"
fi

datetime_now=$(date +"%Y%m%d")
experiment_folder=experiment-${datetime_now}/"${data_split}"

python soloist_train.py \
  --output_dir="${SAVED_MODELS_BASELINE}"/"${experiment_folder}" \
  --model_type=gpt2 \
  --model_name_or_path="${PRE_TRAINED_SOLOIST}" \
  --do_train \
  --train_data_file=../data/baseline/"${data_split}"/train.soloist.json \
  --eval_data_file=../data/baseline/valid/valid.soloist.json \
  --add_special_action_tokens=../data/resource/special_tokens.txt \
  --per_gpu_train_batch_size 1 \
  --num_train_epochs 25 \
  --learning_rate 5e-5 \
  --overwrite_cache \
  --save_steps 5000 \
  --max_seq 100 \
  --overwrite_output_dir \
  --max_turn 15 \
  --num_candidates 1 \
  --mc_loss_efficient 0.33 \
  --add_belief_prediction