Model Construction

This project provides implementations of three baseline methods: A3CLSTM-E2E [1], D-VAT [2], and R-VAT (ours).

These three baselines respectively use the three environment classes implemented in this project (see Environment Wrapping), providing templates for different algorithm implementation needs.

Figure: relationship diagram of the three environment classes




Baseline 1: A3CLSTM-E2E

Following the paper AD-VAT+ [3], this project reproduces the A3CLSTM-E2E method based on the rl_a3c_pytorch repository.

The specific code can be found in the Alg_Base/DAT_Benchmark/models/A3CLSTM_E2E folder

Quick-Start

cd Alg_Base/DAT_Benchmark/
# Test mode
# Testing with Cumulative Reward (CR)
python ./models/A3CLSTM_E2E/main.py --Mode 0 --Scene "citystreet" --Weather "day" --delay 20 --Test_Param "CityStreet-d" --Test_Mode CR
# Testing with Tracking Success Rate (TSR)
python ./models/A3CLSTM_E2E/main.py --Mode 0 --Scene "citystreet" --Weather "day" --delay 20 --Test_Param "CityStreet-d" --Test_Mode TSR
# New training mode
python ./models/A3CLSTM_E2E/main.py --Mode 1 --workers 35 --Scene "citystreet" --Weather "day" --delay 20 --Freq 125 --New_Train
# Resumed training mode
python ./models/A3CLSTM_E2E/main.py --Mode 1 --workers 35 --Scene "citystreet" --Weather "day" --delay 20 --Freq 125

Program entry point

main.py is the entry point of the entire program, primarily providing two operating modes: train and test.

  • In train mode, a testing process is also provided (e.g., if the program runs with 24 processes, the last process with ID 23 will be used for testing). This process is mainly used for tensorboard visualization during the training process.

  • The test mode is mainly used to test the model weights after training is completed. This mode runs in a single process and tests the weights of individual models.

  • The mode is selected via the --Mode argument described below.
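
As a rough illustration of this dispatch (only the --Mode and --workers argument names come from the CLI documented below; the train() and test() functions are hypothetical placeholders, not the repository's actual code):

import argparse

# Sketch of the --Mode dispatch described above.
parser = argparse.ArgumentParser()
parser.add_argument("--Mode", type=int, default=1, help="1 = train, 0 = test")
parser.add_argument("--workers", type=int, default=35)
args = parser.parse_args()

def train(n_workers: int) -> None:
    # In train mode the last worker (ID n_workers - 1) is reserved as the
    # testing process used for TensorBoard visualization.
    print(f"training with {n_workers - 1} workers + 1 test process")

def test() -> None:
    # Single-process evaluation of one set of trained weights.
    print("testing trained weights in a single process")

if args.Mode == 1:
    train(args.workers)
else:
    test()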

In addition, the main.py file also accepts user parameter configurations, with the meanings and default values of the parameters as follows:

1. Runtime parameters: Configure system mode, runtime frequency, device settings, etc.

  • --Mode(int)=1 : Configure whether the running mode is training mode (--Mode=1 for training mode, --Mode=0 for testing mode)

  • --workers(int)=35 : The number of parallel training processes in the environment (should be determined based on the actual memory/GPU capacity of the computer)

  • --gpu-ids(list)=[-1] : Used to set the GPU IDs, default is -1, which means no GPU will be used

  • --Freq(int)=125 : The running frequency of the algorithm side (the environment side runs at 500Hz and is unchangeable), so the default --Freq(int)=125 means the environment side transmits data back every 4 steps

  • --delay(int)=20 : Waiting time for Webots map to load (training can only start normally after the Webots map has been fully loaded)

  • --New_Train(bool)=False : Whether to start a new training session. By default (--New_Train=False), the pre-trained weights A3CLSTM_E2E/trained_models/Benchmark.dat are loaded (if available) and the tensorboard curves continue from the last training session

  • --Port(int)=-1 : The communication port between the environment side and the algorithm side, default --Port=-1 will randomly use an available port, and manual modification is not recommended

2. Environment parameters: Configure the selected environment

  • --map(str)="citystreet-day.wbt" : Configure the training/testing environment scene type, choose from [citystreet-day.wbt,downtown-day.wbt,lake-day.wbt,village-day.wbt,desert-day.wbt,farmland-day.wbt,citystreet-night.wbt,downtown-night.wbt,lake-night.wbt,village-night.wbt,desert-night.wbt,farmland-night.wbt,citystreet-foggy.wbt,downtown-foggy.wbt,lake-foggy.wbt,village-foggy.wbt,desert-foggy.wbt,farmland-foggy.wbt,citystreet-snow.wbt,downtown-snow.wbt,lake-snow.wbt,village-snow.wbt,desert-snow.wbt,farmland-snow.wbt]

3. Model parameters: Configure parameters related to model weight loading and saving

  • --load(bool)=True : Whether to load an existing model for further training (only when --load=True and the model weights A3CLSTM_E2E/trained_models/Benchmark.dat exist will the model be loaded)

  • --save-max(bool)=False : Whether to save the model weights whenever the testing process reaches a new highest reward; default is False, meaning only the latest model weights are saved

  • --model_type(str)="E2E" : Specify the type of model currently being used, if the user implements their own model, they can add the configuration here

  • --save-model-dir(str)="./models/A3CLSTM_E2E/trained_models/" : The path for saving the model

  • --Test_Param(str)="CityStreet-d" : Specifies which weights to load for testing (only used when --Mode=0); the default --Test_Param="CityStreet-d" loads the weights A3CLSTM_E2E/trained_models/CityStreet-d.dat for testing

4. Visualization parameters: Configure parameters related to visualization files

  • --tensorboard-logger(bool)=True : Whether to enable tensorboard for model visualization

  • --log-dir(str)="./models/A3CLSTM_E2E/logs/" : If tensorboard is enabled, the location to store the log files

Log files

This project provides two types of log recording methods. The first is to directly output log files (configured through config.json in ["Benchmark"]["verbose"]), and the second is to use TensorBoard to record performance changes during the training process.
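
For reference, the flag can be read as follows (a minimal sketch: only the ["Benchmark"]["verbose"] key comes from the documentation above; the rest of the config.json layout is assumed):

import json

# Read the flag that enables direct log-file output.
with open("config.json") as f:
    config = json.load(f)
verbose = config["Benchmark"]["verbose"]
print("Direct log-file output enabled:", verbose)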

Mode 1: Directly output log files

Log files can be found in the folder Alg_Base/DAT_Benchmark/logs

  • This mode is mainly used for program debugging and data transmission verification, essentially replacing the terminal print function.

  • Agent${n}.log is mainly used to save the data transmitted from the environment. For example, if you want to observe the custom reward parameters RewardParams obtained from the environment, you can view them in this file.

Mode 2: Directly output tensorboard-logger files

TensorBoard-logger files can be found in the folder Alg_Base/DAT_Benchmark/models/A3CLSTM_E2E/runs/Benchmark_training

  • TensorBoard is a commonly used visualization platform during the training process of neural networks. Therefore, this project also provides corresponding support.

  • During multi-process training, this project reserves one process for testing. For example, if the user selects 24 parallel agents, then during the actual training, there will be 23 training processes and 1 testing process. The testing process synchronizes the weights of the shared_model at the beginning of each episode and conducts tests.

  • The data recorded in the tensorboard-logger file is from the testing process, mainly the value during training, which is used to evaluate the training status of the agents.

  • Additionally, if the program is interrupted for external reasons, this project provides an uninterrupted tensorboard-logger function: num_test.txt stores the current data entry count, and recording continues from it when a new training session starts (see the sketch after the TensorBoard commands below).

  • Note: If the user wishes to re-record the training curve from scratch, they need to add --New_Train to the command that starts the training.

  • After completing all configurations, simply run the following code to view the TensorBoard visualization records on localhost:xxxx.

cd Alg_Base/DAT_Benchmark/models/A3CLSTM_E2E/runs
tensorboard --logdir Benchmark_training
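
The resume bookkeeping described above can be sketched as follows. This is only an illustration of the idea: the project uses tensorboard-logger inside its own training loop, whereas the sketch below uses torch.utils.tensorboard and placeholder data; only the num_test.txt file name and the runs/Benchmark_training directory come from the documentation.

import os
from torch.utils.tensorboard import SummaryWriter

COUNTER_FILE = "num_test.txt"        # stores the current data entry count
start_step = 0
if os.path.exists(COUNTER_FILE):     # resume from the previous session
    with open(COUNTER_FILE) as f:
        start_step = int(f.read().strip())

writer = SummaryWriter("runs/Benchmark_training")
test_rewards = [10.0, 12.5, 11.8]    # placeholder values from the test process
for i, reward in enumerate(test_rewards):
    writer.add_scalar("test/reward", reward, start_step + i)
writer.close()

with open(COUNTER_FILE, "w") as f:   # persist the counter for the next run
    f.write(str(start_step + len(test_rewards)))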




Baseline 2: D-VAT

The specific code can be found in the Alg_Base/DAT_Benchmark/models/D_VAT folder.

Quick-Start

cd Alg_Base/DAT_Benchmark/
# Test mode
# Test using Cumulative Reward (CR)
python ./models/D_VAT/DVAT_main.py -w 1 -m citystreet-day.wbt --train_mode 0 --Test_Mode AR
# Test using Tracking Success Rate (TSR)
python ./models/D_VAT/DVAT_main.py -w 1 -m citystreet-day.wbt --train_mode 0 --Test_Mode TSR
# New training mode
python ./models/D_VAT/DVAT_main.py -w 35 -m citystreet-day.wbt --train_mode 1 --New_Train
# Resume training mode after interruption
python ./models/D_VAT/DVAT_main.py -w 35 -m citystreet-day.wbt --train_mode 1

Program Entry

DVAT_main.py is the entry point of the entire program, providing both training and testing modes (specific configurations can be found in the parameters section below).

1. Run Parameters: Configure system mode, run frequency, device, and other parameters.

  • --workers(int)=35 : The number of parallel training environments (this should be decided based on the actual memory/GPU of the computer).

  • --train_mode(int)=1 : Configure whether the operation mode is training mode (--train_mode=1 is training mode, --train_mode=0 is testing mode).

  • --port(int)=-1 : Communication port between the environment and algorithm end. By default, --port=-1 uses a randomly available port. Manual modification is not recommended.

  • --New_Train(bool)=False : Whether to start a new training session. By default (--New_Train=False), the pre-trained weights params.pth are loaded (if available) and the tensorboard curves continue from the last training session.

2. Environment Parameters: Configure the selected environment.

  • --map(str)="citystreet-day.wbt" : Configure the training/testing environment scene type, selected from ./Webots_Simulation/traffic_project/worlds/*.wbt.

3. Model Parameters: Configure parameters related to model weight import and saving.

  • --savepath(str)="params.pth" : Path where the model will be saved.

4. Visualization Parameters: Configure parameters related to visualization files.

  • --tensorboard_port(int)=1 : Whether to use tensorboard-logger. Any value other than -1 assigns a random available port; --tensorboard_port=-1 disables it (it is always disabled in testing mode).

Code Encapsulation and Modification Details

D-VAT adopts an Actor-Critic asymmetric framework. The schematic diagrams of the Actor-Critic symmetric and asymmetric architectures are as follows:

Figure: schematic diagrams of the Actor-Critic symmetric and asymmetric architectures

The D-VAT code mainly implements the following:

  1. Custom Environment Class

The DVAT_ENV environment class is implemented independently, modeled on the Gym environment interface; the specific code is in Alg_Base/DAT_Benchmark/models/D_VAT/DVAT_envs.py.
To support the asymmetric Actor-Critic architecture, the observation space of DVAT_ENV consists of two parts, actor_obs and critic_obs:

self.observation_space = gymnasium.spaces.Dict({
    "actor_obs": gymnasium.spaces.Box(low=0, high=1.0, shape=(obs_buffer_len,)+image_shape, dtype=np.float32),
    "critic_obs": gymnasium.spaces.Box(low=-np.inf, high=np.inf, shape=(9,), dtype=np.float32),
})

In the original gymnasium environment, the actor and critic share the same state space:

if obs_buffer_len == 0:
    self.observation_space = spaces.Box(low=0, high=1.0, shape=(env_conf["State_channel"], env_conf["State_size"], env_conf["State_size"]), dtype=np.float32)
else:
    self.observation_space = spaces.Box(low=0, high=1.0, shape=(obs_buffer_len, env_conf["State_channel"], env_conf["State_size"], env_conf["State_size"]), dtype=np.float32)
  2. Custom Implemented Policy

To support the asymmetric Actor-Critic architecture, DVAT_SACDPolicy is implemented by inheriting from Tianshou's DiscreteSACPolicy (an illustrative sketch of this inheritance appears after this list).

  3. Custom Implemented Parallel Class and Collector

Custom parallel environment classes and collectors are implemented by referring to Async_SubprocVecEnv & SubprocVecEnv_TS.
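
As a rough sketch of the asymmetric routing mentioned in item 2 (this is not the repository's DVAT_SACDPolicy; the subclass name and the forward override are illustrative assumptions built on Tianshou's DiscreteSACPolicy and the actor_obs/critic_obs keys shown above):

from tianshou.data import Batch
from tianshou.policy import DiscreteSACPolicy

class AsymmetricDiscreteSACPolicy(DiscreteSACPolicy):
    """Illustrative subclass: the actor sees only the image branch of the
    dict observation; the privileged critic_obs branch would be consumed by
    the critics in the loss computation (omitted here)."""
    def forward(self, batch, state=None, **kwargs):
        # Swap the dict observation for its actor branch before running the
        # standard discrete-SAC actor forward pass.
        actor_batch = Batch(batch, copy=True)
        actor_batch.obs = batch.obs.actor_obs
        return super().forward(actor_batch, state=state, **kwargs)

Consult DVAT_envs.py and the actual policy implementation for how the critic branch is used during training.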

Log file

See the tensorboard-logger file in the folder Alg_Base/DAT_Benchmark/models/D_VAT/DVAT_logs

  • During the model training process, as long as the parameter --tensorboard_port(int)!=-1 is set, the tensorboard-logger will automatically start when the program is launched.

  • However, if you want to manually start tensorboard, you can also use the following command:

cd Alg_Base/DAT_Benchmark/models/D_VAT/
tensorboard --logdir DVAT_logs




Baseline 3: R-VAT

For specific code, see Alg_Base/DAT_Benchmark/models/R_VAT folder

Quick-Start

cd Alg_Base/DAT_Benchmark/
# Test mode
# Test with Cumulative Reward (CR)
python ./models/R_VAT/RVAT.py -w 1 -m citystreet-day.wbt --train_mode 0 --Test_Mode AR
# Test with Tracking Success Rate (TSR)
python ./models/R_VAT/RVAT.py -w 1 -m citystreet-day.wbt --train_mode 0 --Test_Mode TSR
# New training mode
python ./models/R_VAT/RVAT.py -w 35 -m citystreet-day.wbt --train_mode 1 --New_Train
# Resume training mode
python ./models/R_VAT/RVAT.py -w 35 -m citystreet-day.wbt --train_mode 1

Program Entry

RVAT.py is the entry point for the entire program, providing both training and testing modes (for specific configurations, refer to the parameter settings below).

1. Runtime Parameters: Configure system mode, execution frequency, device, and other parameters

  • --workers(int)=35 : The number of parallel environments for training (should be determined by the actual memory/VRAM of the computer).

  • --train_mode(int)=1 : Configure whether the mode is training mode ( --train_mode=1 is training mode, --train_mode=0 is testing mode).

  • --port(int)=-1 : Communication port between environment and algorithm, default --port=-1 will randomly use an available port, manual modification is not recommended.

  • --New_Train(bool)=False : Whether to start a completely new training session. By default (--New_Train=False), the pre-trained weights params.pth are loaded (if available) and the tensorboard curves continue from the last training session.

2. Environment Parameters: Configure the selected environment

  • --map(str)="citystreet-day.wbt" : Configure the training/testing environment scene type, selected from ./Webots_Simulation/traffic_project/worlds/*.wbt.

3. Model Parameters: Configure model weight import and save related parameters

  • --savepath(str)="./models/R_VAT/params.pth" : Path where the model will be saved.

4. Visualization Parameters: Configure parameters related to visualization files

  • --tensorboard_port(int)=1 : Whether to use tensorboard-logger. Any value other than -1 assigns a random available port; --tensorboard_port=-1 disables it (it is always disabled in testing mode).

Curriculum Learning

  • Our method (R-VAT) incorporates curriculum learning: in the first and second stages, the agent tracks a monochrome car and a colored car, respectively, under simple settings (no occlusion, and the car drives only in a straight line).

  • For convenience, these simple settings are extracted into dedicated maps, i.e. the ./Webots_Simulation/traffic_project/worlds/simpleway-*.wbt maps.

  • The first two stages are therefore trained on the corresponding simpleway-*.wbt maps.

  • Once task understanding is complete, the third stage of training is conducted on visually challenging maps to learn complex visual features.

  • For example, for the map citystreet-day.wbt, the commands for the second and third stages are as follows:

cd Alg_Base/DAT_Benchmark/
# Stage2: Task Understanding
python ./models/R_VAT/RVAT.py -w 35 -m simpleway-grass.wbt --train_mode 1 --New_Train
# Stage3: Visual Generalization
python ./models/R_VAT/RVAT.py -w 35 -m citystreet-day.wbt --train_mode 1
# Testing Mode
# Test with Cumulative Reward (CR)
python ./models/R_VAT/RVAT.py -w 1 -m citystreet-night.wbt --train_mode 0 --Test_Mode AR
# Test with Tracking Success Rate (TSR)
python ./models/R_VAT/RVAT.py -w 1 -m citystreet-night.wbt --train_mode 0 --Test_Mode TSR
  • The correspondence between the original maps and the simplified maps used in the first two stages is as follows:

| Origin Map | Simple Map |
| --- | --- |
| citystreet-day | simpleway-grass_day |
| citystreet-night | simpleway-grass_night |
| citystreet-foggy | simpleway-grass_foggy |
| citystreet-snow | simpleway-city_snow |
| desert-day | simpleway-desert_day |
| desert-night | simpleway-desert_night |
| desert-foggy | simpleway-desert_foggy |
| desert-snow | simpleway-desert_snow |
| downtown-day | simpleway-clinker_day |
| downtown-night | simpleway-clinker_night |
| downtown-foggy | simpleway-clinker_foggy |
| downtown-snow | simpleway-city_snow |
| farmland-day | simpleway-farm_day |
| farmland-night | simpleway-farm_night |
| farmland-foggy | simpleway-farm_foggy |
| farmland-snow | simpleway-farm_snow |
| lake-day | simpleway-lake_day |
| lake-night | simpleway-lake_night |
| lake-foggy | simpleway-lake_foggy |
| lake-snow | simpleway-lake_snow |
| village-day | simpleway-village_day |
| village-night | simpleway-village_night |
| village-foggy | simpleway-village_foggy |
| village-snow | simpleway-village_snow |

Log file

The tensorboard-logger file can be found in the folder Alg_Base/DAT_Benchmark/models/${model_name}_logs

  • During the model training process, as long as the parameter --tensorboard_port(int)!=-1 is set, the tensorboard-logger will automatically start when the program is launched.

  • However, if you wish to start tensorboard manually, you can use the following command:

cd Alg_Base/DAT_Benchmark/
tensorboard --logdir models/${model_name}_logs




References: