Description
Describe the bug
Fine-tuning a model with GRPO in async mode fails with a data format error: dacite.exceptions.WrongTypeError: wrong value type for field "tools" - should be "typing.Optional[typing.List[typing.Dict[str, typing.Union[str, typing.Dict]]]]" instead of value "<generator object _build_value_for_collection.. at 0x7f88ab487740>" of type "str"
Colocate mode works fine.
Following the agent fine-tuning setup in the docs, the candidate tool list is stored as a JSON str that parses into a list[dict].
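For reference, the expected shape of the `tools` field can be checked with the standard `json` module. This is a minimal sketch only: `TOOLS_STR` is a shortened, one-tool example in the same function-calling format as the data below, and `parse_tools` is a hypothetical helper, not part of swift.

```python
import json
from typing import Any, Dict, List

# Hypothetical one-tool example in the same format as the issue's tool list.
TOOLS_STR = json.dumps([
    {
        "type": "function",
        "function": {
            "name": "currency.converter.download_rates",
            "description": "Download exchange rate data for offline use during travel",
            "parameters": {
                "type": "object",
                "properties": {
                    "currencies": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["currencies"],
            },
        },
    }
])


def parse_tools(tools_str: str) -> List[Dict[str, Any]]:
    """Decode a tools JSON string once and check that the result is a list of dicts."""
    tools = json.loads(tools_str)
    if not isinstance(tools, list) or not all(isinstance(t, dict) for t in tools):
        raise TypeError(f"expected list[dict], got {type(tools).__name__}")
    return tools


tools = parse_tools(TOOLS_STR)
print(type(tools).__name__, tools[0]["function"]["name"])
# list currency.converter.download_rates
```

A single `json.loads` is enough when the string was encoded exactly once; a string that was encoded twice decodes to another `str` and fails the check.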
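While debugging, I found that the tools field is first serialized into a JSON str and then serialized a second time as part of the enclosing payload. This adds escape characters, and the error is raised when the payload is parsed into RequestData.
I tried leaving tools as a native list (not a JSON str) and serializing it together with the rest of the data, but then an error is raised already at the dataset preprocessing stage.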
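The double serialization described above can be reproduced with the standard `json` module alone. In this sketch, `RequestSketch` is a hypothetical dataclass that only mirrors the annotated type of `tools` from the error message (it is not swift's `RequestData`, and a plain dataclass does not enforce types the way dacite does), and `normalize_tools` is a hypothetical workaround that decodes the field once more if it still arrives as a string.

```python
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union


@dataclass
class RequestSketch:
    """Hypothetical stand-in for RequestData; only the tools field is modeled."""
    tools: Optional[List[Dict[str, Union[str, Dict]]]] = None


tools = [{"type": "function", "function": {"name": "price.tracker.get_history"}}]

# Double encoding: tools is dumped to a str, then the whole payload is dumped again.
payload = json.dumps({"tools": json.dumps(tools)})
print(payload)
# {"tools": "[{\"type\": \"function\", \"function\": {\"name\": \"price.tracker.get_history\"}}]"}

decoded = json.loads(payload)
print(type(decoded["tools"]).__name__)  # str, not list -> dacite rejects it


def normalize_tools(value: Any) -> Optional[List[Dict[str, Any]]]:
    """Workaround sketch: decode tools a second time if it is still a JSON string."""
    if isinstance(value, str):
        value = json.loads(value)
    return value


req = RequestSketch(tools=normalize_tools(decoded["tools"]))
print(type(req.tools).__name__)  # list
```

Decoding twice only hides the symptom; serializing `tools` exactly once when the request payload is built would avoid the escaped string in the first place.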
The tool str is as follows:
```
[{"type": "function", "function": {"name": "currency.converter.get_rates", "description": "Retrieve current exchange rates for multiple currencies relative to a base currency", "parameters": {"type": "object", "properties": {"base_currency": {"type": "string", "description": "The 3-letter currency code to use as the reference (e.g., USD)"}, "target_currencies": {"type": "array", "items": {"type": "string"}, "description": "List of 3-letter currency codes to get rates for (e.g., [\\"EUR\\", \\"JPY\\"])"}, "use_offline_data": {"type": "boolean", "description": "Whether to use offline exchange rate data when internet is unavailable"}}, "required": ["base_currency", "target_currencies"]}}}, {"type": "function", "function": {"name": "currency.converter.download_rates", "description": "Download exchange rate data for offline use during travel", "parameters": {"type": "object", "properties": {"currencies": {"type": "array", "items": {"type": "string"}, "description": "List of 3-letter currency codes to download rates for"}}, "required": ["currencies"]}}}, {"type": "function", "function": {"name": "currency.converter.generate_chart", "description": "Generate a chart image URL showing historical exchange rate trends", "parameters": {"type": "object", "properties": {"from_currency": {"type": "string", "description": "The 3-letter currency code of the original currency"}, "to_currency": {"type": "string", "description": "The 3-letter currency code of the target currency"}, "period": {"type": "string", "enum": ["7d", "30d", "90d", "1y"], "description": "Time period to display in the chart"}}, "required": ["from_currency", "to_currency", "period"]}}}, {"type": "function", "function": {"name": "currency.converter.get_history", "description": "Retrieve historical exchange rates for a currency pair over a specified period", "parameters": {"type": "object", "properties": {"from_currency": {"type": "string", "description": "The 3-letter currency code of the original currency"}, "to_currency": {"type": "string", "description": "The 3-letter currency code of the target currency"}, "start_date": {"type": "string", "description": "Start date for the historical data range"}, "end_date": {"type": "string", "description": "End date for the historical data range"}, "interval": {"type": "string", "enum": ["daily", "weekly", "monthly"], "description": "Time interval between data points"}}, "required": ["from_currency", "to_currency", "start_date", "end_date"]}}}, {"type": "function", "function": {"name": "price.tracker.get_history", "description": "Retrieve the historical price data for a tracked product within a specified time period", "parameters": {"type": "object", "properties": {"product_url": {"type": "string", "description": "URL of the tracked product"}, "time_period": {"type": "string", "description": "Time range to retrieve price history (e.g., \'past 30 days\', \'last month\')"}}, "required": ["product_url", "time_period"]}}}]
```
Your hardware and system info
8 x H20
Additional context
Experiment command:
```shell
half_gpu=$((n_gpu / 2))
rollout_gpus=$(seq -s, 0 $((half_gpu - 1)))
training_gpus=$(seq -s,
echo "Detected $n_gpu GPUs."
echo "Rollout GPUs: $rollout_gpus"
echo "Training GPUs: $training_gpus"
CUDA_VISIBLE_DEVICES=$rollout_gpus \
swift rollout \
--model $pretrain_model_path --use_hf false \
--vllm_tensor_parallel_size $half_gpu \
--vllm_data_parallel_size 1 --vllm_mm_processor_cache_gb 0 > rollout.log 2>&1 &
export NPROC_PER_NODE=$half_gpu
CUDA_VISIBLE_DEVICES=$training_gpus \
swift rlhf \
--rlhf_type grpo --use_vllm true \
--vllm_mode server \
--vllm_server_host 127.0.0.1 \
--vllm_server_port 8000 \
--train_type full \
--torch_dtype bfloat16 \
--vllm_server_timeout 360 --async_generate true \
--model $pretrain_model_path --use_hf false \
--reward_funcs st_fc \
--reward_weights 1 --response_prefix "\n\n\n\n" \
--train_type full --gradient_checkpointing true --deepspeed zero2 \
--torch_dtype bfloat16 \
--dataset /mnt/bn/llm3d/agent_data/data_synth/gen_data/weisiyuan.buaa/20251103.zzs_llm_synth_tool_metas.20251015_clustered_0.95.grpo.jsonl \
--load_from_cache_file false \
--external_plugins ./rewards.py \
--max_completion_length 1024 \
--num_train_epochs 1 \
--per_device_train_batch_size $bs \
--per_device_eval_batch_size $bs \
--learning_rate 5e-6 --beta 0.001 --temperature 0.3 \
--gradient_accumulation_steps $grad_acc_step \
--eval_steps 100 \
--save_steps 5000 \
--save_total_limit 2 \
--logging_steps 2 \
--max_length 16384 \
--output_dir $model_dir --add_version false \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--num_generations 8 \
--log_completions true --report_to wandb --run_name $exp_name
```