Model Construction

This project provides three baseline methods: A3CLSTM-E2E [1], D-VAT [2], and R-VAT (Ours). These baselines respectively call the three environments implemented in this project (see Environment Wrapping), providing templates for different algorithm implementation needs:

- The A3CLSTM-E2E algorithm calls the base environment class (1. Base Environment Class)
- The D-VAT algorithm calls the Gym environment class (2. Gym Environment Class)
- The R-VAT algorithm calls the parallel environment class (3. Parallel Environment Classes)
Baseline1 A3CLSTM-E2E
This project reproduces the A3CLSTM-E2E method following the paper AD-VAT+ [3], based on the repository rl_a3c_pytorch. The code is located in the `Alg_Base/DAT_Benchmark/models/A3CLSTM_E2E` folder.

Quick-Start
```shell
cd Alg_Base/DAT_Benchmark/
# Test mode
# Testing with Cumulative Reward (CR)
python ./models/A3CLSTM_E2E/main.py --Mode 0 --Scene "citystreet" --Weather "day" --delay 20 --Test_Param "CityStreet-d" --Test_Mode CR
# Testing with Tracking Success Rate (TSR)
python ./models/A3CLSTM_E2E/main.py --Mode 0 --Scene "citystreet" --Weather "day" --delay 20 --Test_Param "CityStreet-d" --Test_Mode TSR
# New training mode
python ./models/A3CLSTM_E2E/main.py --Mode 1 --workers 35 --Scene "citystreet" --Weather "day" --delay 20 --Freq 125 --New_Train
# Resumed training mode
python ./models/A3CLSTM_E2E/main.py --Mode 1 --workers 35 --Scene "citystreet" --Weather "day" --delay 20 --Freq 125
```
Program entry point

`main.py` is the entry point of the entire program, providing two operating modes: `train` and `test`.

In `train` mode, one process is reserved for testing (e.g., if the program runs with 24 processes, the last process, with ID 23, is used for testing). This testing process is mainly used for `tensorboard` visualization during training.

The `test` mode is used to evaluate model weights after training is completed. It runs in a single process and tests the weights of an individual model.

The mode is selected through the `MODE` variable.

In addition, `main.py` accepts user parameter configurations, with the following meanings and default values:

1. Runtime parameters: configure system mode, runtime frequency, device settings, etc.
- `--Mode(int)=1`: Configure the running mode (`--Mode=1` for training mode, `--Mode=0` for testing mode)
- `--workers(int)=35`: The number of parallel training processes (should be chosen based on the actual memory/GPU capacity of the machine)
- `--gpu-ids(list)=[-1]`: The GPU IDs to use; the default of -1 means no GPU is used
- `--Freq(int)=125`: The running frequency of the algorithm side (the environment side runs at a fixed 500 Hz), so the default `--Freq=125` means the environment side transmits data back every 4 steps
- `--delay(int)=20`: Waiting time for the Webots map to load (training can only start after the Webots map has fully loaded)
- `--New_Train(bool)=False`: Whether to start a new training session; the default `--New_Train=False` loads the pre-trained weights from `A3CLSTM_E2E/trained_models/Benchmark.dat` (if available) for training, and the tensorboard curves continue from the last training session
- `--Port(int)=-1`: The communication port between the environment side and the algorithm side; the default `--Port=-1` randomly uses an available port, and manual modification is not recommended

2. Environment parameters: configure the selected environment
- `--map(str)="citystreet-day.wbt"`: Configure the training/testing environment scene, chosen from the 24 combinations `[citystreet|downtown|lake|village|desert|farmland]-[day|night|foggy|snow].wbt`
3. Model parameters: configure model weight loading and saving

- `--load(bool)=True`: Whether to load an existing model for further training (the model is loaded only when `--load=True` and the weights `A3CLSTM_E2E/trained_models/Benchmark.dat` exist)
- `--save-max(bool)=False`: Whether to save the model weights whenever the testing process reaches a new highest reward; the default `False` means only the most recent model weights are saved
- `--model_type(str)="E2E"`: The type of model currently in use; users implementing their own model can add its configuration here
- `--save-model-dir(str)="./models/A3CLSTM_E2E/trained_models/"`: The path for saving the model
- `--Test_Param(str)="CityStreet-d"`: Specifies which weights to load for testing (only enabled when `--Mode=0`); the default `--Test_Param="CityStreet-d"` loads the weights `A3CLSTM_E2E/trained_models/CityStreet-d.dat` for testing

4. Visualization parameters: configure visualization-related files
- `--tensorboard-logger(bool)=True`: Whether to enable `tensorboard` for model visualization
- `--log-dir(str)="./models/A3CLSTM_E2E/logs/"`: If `tensorboard` is enabled, the location where log files are stored
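The flag handling described above can be mirrored with a minimal `argparse` sketch. This is illustrative only: it reproduces a subset of the documented flags and defaults, not the project's actual parser.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Illustrative parser mirroring a subset of the documented main.py flags.
    parser = argparse.ArgumentParser(description="A3CLSTM-E2E runner (sketch)")
    parser.add_argument("--Mode", type=int, default=1,
                        help="1 = training mode, 0 = testing mode")
    parser.add_argument("--workers", type=int, default=35,
                        help="number of parallel training processes")
    parser.add_argument("--gpu-ids", type=int, nargs="+", default=[-1],
                        help="GPU ids; -1 means CPU only")
    parser.add_argument("--Freq", type=int, default=125,
                        help="algorithm-side frequency (environment side is fixed at 500 Hz)")
    parser.add_argument("--delay", type=int, default=20,
                        help="seconds to wait for the Webots map to load")
    parser.add_argument("--New_Train", action="store_true",
                        help="start a fresh training session")
    parser.add_argument("--Port", type=int, default=-1,
                        help="-1 picks a random available port")
    return parser

args = build_parser().parse_args(["--Mode", "0", "--workers", "1"])
# The environment side is fixed at 500 Hz, so each algorithm step
# spans 500 // Freq environment steps (4 with the default Freq=125).
env_steps_per_action = 500 // args.Freq
```

The same ratio explains the default `--Freq=125`: the environment transmits data back every `500 / 125 = 4` steps.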
log files

This project provides two log recording methods. The first directly outputs log files (configured through `config.json` in `["Benchmark"]["verbose"]`); the second uses TensorBoard to record performance changes during training.

Mode 1: Directly output log files

Log files can be found in the folder `Alg_Base/DAT_Benchmark/logs`. This mode is mainly used for program debugging and data transmission verification, essentially replacing the terminal. `Agent${n}.log` mainly saves the data transmitted from the environment; for example, to observe the custom reward parameters `RewardParams` received from the environment, view them in this file.

Mode 2: Output tensorboard-logger files

TensorBoard-logger files can be found in the folder `Alg_Base/DAT_Benchmark/models/A3CLSTM_E2E/runs/Benchmark_training`. `TensorBoard` is a commonly used visualization platform for neural network training, so this project provides corresponding support.

During multi-process training, this project reserves one process for testing. For example, if the user selects 24 parallel agents, the actual training uses 23 training processes and 1 testing process. The testing process synchronizes the weights of the shared model at the beginning of each episode and then runs tests.
The data recorded in the `tensorboard-logger` file comes from the testing process, mainly the `value` during training, which is used to evaluate the training status of the agents.

Additionally, if the program is interrupted for external reasons, this project provides an uninterrupted `tensorboard-logger` function: it uses `num_test.txt` to store the current data entry count and continues recording when a new training session starts.

Note: users who wish to re-record the training curve need to add `--New_Train` to the command that starts the training.

After completing all configurations, run the following commands to view the `TensorBoard` visualization records on `localhost:xxxx`:

```shell
cd Alg_Base/DAT_Benchmark/models/A3CLSTM_E2E/runs
tensorboard --logdir Benchmark_training
```
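The resume-able logging counter can be sketched as follows. This is a minimal illustration of the `num_test.txt` idea, not the project's actual implementation:

```python
import os

COUNTER_FILE = "num_test.txt"  # stores the number of entries already logged

def load_step(path: str = COUNTER_FILE) -> int:
    # Resume from the stored count if the file exists, otherwise start at zero.
    if os.path.exists(path):
        with open(path) as f:
            return int(f.read().strip() or 0)
    return 0

def save_step(step: int, path: str = COUNTER_FILE) -> None:
    # Persist the count so an interrupted run can continue the same curve.
    with open(path, "w") as f:
        f.write(str(step))

step = load_step()          # continues from the previous session, if any
for _ in range(3):          # pretend we log three new test entries
    step += 1
    save_step(step)
```

Starting a run with `--New_Train` would correspond to resetting this counter so the curve starts over.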
Baseline2 D-VAT
The code is located in the `Alg_Base/DAT_Benchmark/models/D_VAT` folder.
Quick-Start
```shell
cd Alg_Base/DAT_Benchmark/
# Test mode
# Test using Cumulative Reward (CR)
python ./models/D_VAT/DVAT_main.py -w 1 -m citystreet-day.wbt --train_mode 0 --Test_Mode AR
# Test using Tracking Success Rate (TSR)
python ./models/D_VAT/DVAT_main.py -w 1 -m citystreet-day.wbt --train_mode 0 --Test_Mode TSR
# New training mode
python ./models/D_VAT/DVAT_main.py -w 35 -m citystreet-day.wbt --train_mode 1 --New_Train
# Resume training mode after interruption
python ./models/D_VAT/DVAT_main.py -w 35 -m citystreet-day.wbt --train_mode 1
```
Program Entry

`DVAT_main.py` is the entry point of the entire program, providing both training and testing modes (specific configurations are listed in the parameters below).

1. Run parameters: configure system mode, run frequency, device, and other settings

- `--workers(int)=35`: The number of parallel training environments (should be chosen based on the actual memory/GPU of the machine).
- `--train_mode(int)=1`: Configure the operation mode (`--train_mode=1` is training mode, `--train_mode=0` is testing mode).
- `--port(int)=-1`: Communication port between the environment and algorithm sides. The default `--port=-1` uses a randomly available port; manual modification is not recommended.
- `--New_Train(bool)=False`: Whether to start a new training session. The default `--New_Train=False` loads the `params.pth` pretrained weights (if available) for training, and the tensorboard curve also continues from the last training session.

2. Environment parameters: configure the selected environment

- `--map(str)="citystreet-day.wbt"`: Configure the training/testing environment scene, selected from `./Webots_Simulation/traffic_project/worlds/*.wbt`.

3. Model parameters: configure model weight loading and saving

- `--savepath(str)="params.pth"`: Path where the model is saved.

4. Visualization parameters: configure visualization-related files

- `--tensorboard_port(int)=1`: Whether to use tensorboard-logger. If `--tensorboard_port!=-1`, a random available port is assigned; otherwise it is not enabled (it is never enabled in testing mode).
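Picking a random available port, as the `--port=-1` and `--tensorboard_port` defaults describe, is commonly done by binding to port 0 and reading back the port the OS assigned. A sketch of that idea (not the project's actual code):

```python
import socket

def find_free_port() -> int:
    # Binding to port 0 asks the OS for an ephemeral free port,
    # which we read back before releasing the socket.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]

port = find_free_port()
```

Note the small race inherent to this pattern: the port is free when probed but could be taken before it is reused, which is one reason manual port configuration is discouraged here.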
Code Encapsulation and Modification Details
D-VAT adopts an Actor-Critic asymmetric framework. The schematic diagrams of the Actor-Critic symmetric and asymmetric architectures are as follows:
The D-VAT code mainly implements the following:
Custom Environment Class

The environment class is modeled after the Gym environment class to independently implement the `DVAT_ENV` environment class; see `Alg_Base/DAT_Benchmark/models/D_VAT/DVAT_envs.py` for the code.

To support the asymmetric Actor-Critic architecture, the state space of the `DVAT_ENV` environment class includes two parts, `actor_obs` and `critic_obs`:

```python
self.observation_space = gymnasium.spaces.Dict({
    "actor_obs": gymnasium.spaces.Box(
        low=0, high=1.0,
        shape=(obs_buffer_len,) + image_shape,
        dtype=np.float32),
    "critic_obs": gymnasium.spaces.Box(
        low=-np.inf, high=np.inf,
        shape=(9,),
        dtype=np.float32),
})
```

In the original gymnasium environment, the actor and critic share the same state space:

```python
if obs_buffer_len == 0:
    self.observation_space = spaces.Box(
        low=0, high=1.0,
        shape=(env_conf["State_channel"],
               env_conf["State_size"],
               env_conf["State_size"]),
        dtype=np.float32)
else:
    self.observation_space = spaces.Box(
        low=0, high=1.0,
        shape=(obs_buffer_len,
               env_conf["State_channel"],
               env_conf["State_size"],
               env_conf["State_size"]),
        dtype=np.float32)
```
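A sample observation under this asymmetric layout can be illustrated with plain numpy. The concrete shapes here are assumptions for illustration (a 4-frame buffer of 84x84 single-channel images for the actor); only the 9-dimensional critic state is taken from the code above:

```python
import numpy as np

obs_buffer_len = 4            # assumed frame-stack depth
image_shape = (1, 84, 84)     # assumed (channels, height, width)

# The actor only sees stacked images; the critic additionally receives
# a low-dimensional privileged state (9 values in DVAT_ENV).
observation = {
    "actor_obs": np.zeros((obs_buffer_len,) + image_shape, dtype=np.float32),
    "critic_obs": np.zeros((9,), dtype=np.float32),
}
```

The asymmetry is the point: the critic can exploit privileged simulator state during training, while the deployed actor needs only images.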
Custom Implemented Policy
To support the asymmetric Actor-Critic architecture, `DVAT_SACDPolicy` is implemented by inheriting from Tianshou's `DiscreteSACPolicy`.
Custom Implemented Parallel Class and Collector
Custom parallel environment classes and collectors are implemented by referring to Async_SubprocVecEnv & SubprocVecEnv_TS.
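As a rough illustration of what a vectorized environment class provides (a batched reset/step interface over N environment instances), here is a simplified sequential stand-in; the real `Async_SubprocVecEnv` runs each environment in its own process, and the toy environment below is purely hypothetical:

```python
import numpy as np

class ToyEnv:
    """Hypothetical stand-in environment: counts steps, ends after 3 steps."""
    def reset(self):
        self.t = 0
        return np.array([0.0], dtype=np.float32)
    def step(self, action):
        self.t += 1
        obs = np.array([float(self.t)], dtype=np.float32)
        reward, done = 1.0, self.t >= 3
        return obs, reward, done, {}

class SimpleVecEnv:
    """Sequential stand-in for a subprocess vectorized env: it exposes the
    same batched reset/step interface, but steps each env in-process."""
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
    def reset(self):
        return np.stack([env.reset() for env in self.envs])
    def step(self, actions):
        obs, rews, dones, infos = zip(*(env.step(a)
                                        for env, a in zip(self.envs, actions)))
        return np.stack(obs), np.array(rews), np.array(dones), list(infos)

vec = SimpleVecEnv([ToyEnv for _ in range(4)])
batch_obs = vec.reset()                      # one row per environment
obs, rewards, dones, infos = vec.step([0, 0, 0, 0])
```

A collector then drives this batched interface, gathering transitions from all environments into a shared replay buffer.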
log file

See the tensorboard-logger files in the folder `Alg_Base/DAT_Benchmark/models/D_VAT/DVAT_logs`.

During model training, as long as the parameter `--tensorboard_port!=-1` is set, the tensorboard-logger starts automatically when the program is launched. To start tensorboard manually instead, use:

```shell
cd Alg_Base/DAT_Benchmark/models/D_VAT/
tensorboard --logdir DVAT_logs
```
Baseline3 R-VAT
The code is located in the `Alg_Base/DAT_Benchmark/models/R_VAT` folder.
Quick-Start
```shell
cd Alg_Base/DAT_Benchmark/
# Test mode
# Test with Cumulative Reward (CR)
python ./models/R_VAT/RVAT.py -w 1 -m citystreet-day.wbt --train_mode 0 --Test_Mode AR
# Test with Tracking Success Rate (TSR)
python ./models/R_VAT/RVAT.py -w 1 -m citystreet-day.wbt --train_mode 0 --Test_Mode TSR
# New training mode
python ./models/R_VAT/RVAT.py -w 35 -m citystreet-day.wbt --train_mode 1 --New_Train
# Resume training mode
python ./models/R_VAT/RVAT.py -w 35 -m citystreet-day.wbt --train_mode 1
```
Program Entry

`RVAT.py` is the entry point for the entire program, providing both training and testing modes (specific configurations are listed in the parameters below).

1. Runtime parameters: configure system mode, execution frequency, device, and other settings

- `--workers(int)=35`: The number of parallel training environments (should be chosen based on the actual memory/VRAM of the machine).
- `--train_mode(int)=1`: Configure the operation mode (`--train_mode=1` is training mode, `--train_mode=0` is testing mode).
- `--port(int)=-1`: Communication port between the environment and algorithm sides; the default `--port=-1` randomly uses an available port, and manual modification is not recommended.
- `--New_Train(bool)=False`: Whether to start a completely new training session; the default `--New_Train=False` loads the `params.pth` pre-trained weights (if available) for training, and the tensorboard curves continue from the last training session.

2. Environment parameters: configure the selected environment

- `--map(str)="citystreet-day.wbt"`: Configure the training/testing environment scene, selected from `./Webots_Simulation/traffic_project/worlds/*.wbt`.

3. Model parameters: configure model weight loading and saving

- `--savepath(str)="./models/R_VAT/params.pth"`: Path where the model is saved.

4. Visualization parameters: configure visualization-related files

- `--tensorboard_port(int)=1`: Whether to use tensorboard-logger; if `--tensorboard_port!=-1`, a random available port is assigned, otherwise it is not enabled (it is never enabled in testing mode).
Curriculum Learning

Our method adds curriculum learning on top of R-VAT. In the first and second stages of curriculum learning, the agent tracks monochrome and colored cars respectively under simple settings (no occlusion, and the car only travels straight).

For convenience, we extracted these simple settings into dedicated maps, i.e., the `./Webots_Simulation/traffic_project/worlds/simpleway-*.wbt` maps, so the first two stages are trained on the corresponding `simpleway-*.wbt` map. After task understanding is complete, the third stage of training is conducted on the visually challenging original maps to learn complex visual features.

For example, for the map `citystreet-day.wbt`, the commands for the later stages are as follows:

```shell
cd Alg_Base/DAT_Benchmark/
# Stage2: Task Understanding
python ./models/R_VAT/RVAT.py -w 35 -m simpleway-grass.wbt --train_mode 1 --New_Train
# Stage3: Visual Generalization
python ./models/R_VAT/RVAT.py -w 35 -m citystreet-day.wbt --train_mode 1
# Testing Mode
# Test with Cumulative Reward (CR)
python ./models/R_VAT/RVAT.py -w 1 -m citystreet-night.wbt --train_mode 0 --Test_Mode AR
# Test with Tracking Success Rate (TSR)
python ./models/R_VAT/RVAT.py -w 1 -m citystreet-night.wbt --train_mode 0 --Test_Mode TSR
```
The comparison table of the two-stage maps is as follows:

| Origin Map | Simple Map |
| --- | --- |
| citystreet-day | simpleway-grass_day |
| citystreet-night | simpleway-grass_night |
| citystreet-foggy | simpleway-grass_foggy |
| citystreet-snow | simpleway-city_snow |
| desert-day | simpleway-desert_day |
| desert-night | simpleway-desert_night |
| desert-foggy | simpleway-desert_foggy |
| desert-snow | simpleway-desert_snow |
| downtown-day | simpleway-clinker_day |
| downtown-night | simpleway-clinker_night |
| downtown-foggy | simpleway-clinker_foggy |
| downtown-snow | simpleway-city_snow |
| farmland-day | simpleway-farm_day |
| farmland-night | simpleway-farm_night |
| farmland-foggy | simpleway-farm_foggy |
| farmland-snow | simpleway-farm_snow |
| lake-day | simpleway-lake_day |
| lake-night | simpleway-lake_night |
| lake-foggy | simpleway-lake_foggy |
| lake-snow | simpleway-lake_snow |
| village-day | simpleway-village_day |
| village-night | simpleway-village_night |
| village-foggy | simpleway-village_foggy |
| village-snow | simpleway-village_snow |
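For scripting the curriculum, the table above can be captured in a small lookup. This is a hypothetical helper, not part of the project; it encodes the per-scene pattern plus the two `city_snow` special cases from the table:

```python
# Hypothetical mapping from each origin scene to its simple-map base name,
# following the comparison table above.
SCENE_TO_SIMPLE = {
    "citystreet": "grass",
    "desert": "desert",
    "downtown": "clinker",
    "farmland": "farm",
    "lake": "lake",
    "village": "village",
}

def simple_map_for(origin: str) -> str:
    """Return the simple curriculum map for an origin map file name."""
    scene, weather = origin.removesuffix(".wbt").split("-")
    # citystreet-snow and downtown-snow both map to simpleway-city_snow
    if weather == "snow" and scene in ("citystreet", "downtown"):
        return "simpleway-city_snow.wbt"
    return f"simpleway-{SCENE_TO_SIMPLE[scene]}_{weather}.wbt"

simple_map_for("citystreet-day.wbt")   # "simpleway-grass_day.wbt"
```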
log file

The tensorboard-logger files can be found in the folder `Alg_Base/DAT_Benchmark/models/${model_name}_logs`.

During model training, as long as the parameter `--tensorboard_port!=-1` is set, the tensorboard-logger starts automatically when the program is launched. To start tensorboard manually instead, use:

```shell
cd Alg_Base/DAT_Benchmark/
tensorboard --logdir models/${model_name}_logs
```
Reference: