You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
37 lines
1.0 KiB
37 lines
1.0 KiB
#!/usr/bin/env bash
#
# Decode the baseline test set with soloist_decode.py using a saved GPT-2
# checkpoint.
#
# Required environment variables:
#   SAVED_MODELS_BASELINE  root directory that holds saved model checkpoints
#   MODEL_CHECKPOINT       checkpoint directory name under SAVED_MODELS_BASELINE
#   OUTPUTS_DIR_BASELINE   root directory under which decoding output is written
#
# Sampling ranges (presumably the ranges explored / recommended for each
# knob — TODO confirm):
#   temperature  0.7 - 1.5
#   top_p        0.2 - 0.8

# Abort on the first failing command (e.g. a failed mkdir) instead of carrying
# on into the decoder. -u is intentionally not enabled: the explicit
# environment-variable checks that follow print friendlier messages than
# bash's "unbound variable" error would.
set -eo pipefail

# Fail fast when SAVED_MODELS_BASELINE (root of the saved checkpoints) is
# missing; the message points the user at set_env.sh, which is expected to
# export it. Diagnostics go to stderr so they don't mix with normal output.
if [[ -z "${SAVED_MODELS_BASELINE}" ]]; then
  echo "SAVED_MODELS_BASELINE Environment variable not set. Run set_env.sh bash script" >&2
  exit 1
fi

# Fail fast when MODEL_CHECKPOINT (the checkpoint directory name under
# SAVED_MODELS_BASELINE) is missing. Diagnostics go to stderr.
if [[ -z "${MODEL_CHECKPOINT}" ]]; then
  echo "MODEL_CHECKPOINT Environment variable not set. Run \"export MODEL_CHECKPOINT=<path>\"" >&2
  exit 1
fi

# The checkpoint directory must already exist; it is the path later handed to
# the decoder as --model_name_or_path. Report the full path that was actually
# tested (not just MODEL_CHECKPOINT) so a wrong SAVED_MODELS_BASELINE is
# obvious.
checkpoint_dir="${SAVED_MODELS_BASELINE}/${MODEL_CHECKPOINT}"
if [[ ! -d "${checkpoint_dir}" ]]; then
  echo "Directory ${checkpoint_dir} doesn't exist! Provide a valid Saved Model Checkpoint dir." >&2
  exit 1
fi

# Decoding settings forwarded to soloist_decode.py.
NS=5       # --num_samples
TEMP=1     # --temperature
TOP_P=0.5  # --top_p

# Guard the output root: if OUTPUTS_DIR_BASELINE were unset, OUTPUT_DIR would
# collapse to "/${MODEL_CHECKPOINT}" and mkdir would try to create a directory
# at the filesystem root.
if [[ -z "${OUTPUTS_DIR_BASELINE}" ]]; then
  echo "OUTPUTS_DIR_BASELINE Environment variable not set." >&2
  exit 1
fi

# Per-checkpoint output directory; output_test.json is written here.
OUTPUT_DIR="${OUTPUTS_DIR_BASELINE}/${MODEL_CHECKPOINT}"
if ! mkdir -p "${OUTPUT_DIR}"; then
  echo "Could not create output directory ${OUTPUT_DIR}" >&2
  exit 1
fi

# Run the decoder on the baseline test split with the checkpoint and sampling
# settings configured above; results go to ${OUTPUT_DIR}/output_test.json.
# NOTE(review): both soloist_decode.py and ../data/... are resolved relative to
# the current working directory, so this must be launched from the directory
# containing soloist_decode.py — confirm against the repo layout.
python soloist_decode.py \
  --model_type=gpt2 \
  --model_name_or_path="${SAVED_MODELS_BASELINE}/${MODEL_CHECKPOINT}" \
  --num_samples "${NS}" \
  --input_file=../data/baseline/test/test.soloist.json \
  --top_p "${TOP_P}" \
  --temperature "${TEMP}" \
  --output_file="${OUTPUT_DIR}/output_test.json" \
  --max_turn 15