强化学习是一种基于用户自定义反馈信号,对模型进行效果优化的技术。与有监督微调类似,该技术可针对特定任务定制模型能力;二者的核心区别在于,强化学习无需依赖固定的 “标准答案” 开展训练,而是通过可编程评分器对模型生成的所有候选回复进行打分,再由训练算法调整模型权重,最终实现高评分输出的概率提升、低评分输出的概率降低。
火山方舟强化学习将以上环节进行了一定程度的封装以降低强化学习复杂度,推出 精调 SDK 为 AgentRL 场景提供一站式支持,开发者只需聚焦核心业务逻辑开发 —— 包括 Rollout 流程中的 Agent 交互逻辑设计、奖励函数构建、训练样本筛选与格式转换,即可高效完成定制化强化学习精调任务。
使用精调SDK进行强化学习流程分为以下9步:
意义
一句话总结:对效果和泛化能力要求高的场景,RL相比PE、SFT等有更高的上限。
日益复杂的业务场景对模型的效果提出了越来越高的要求。同时,真实业务对模型的诉求不仅是在评估集的效果提升,更需要端到端服务于业务指标的提升。相比其他效果调优方式,RL有以下优势:
时机
当任务有以下特点时,强化学习(RL)能起到优化作用,建议优先考虑使用:
强化学习是机器学习的一个分支,模型通过 "行动 - 反馈 - 调整" 的循环过程,最大化未来获得的奖励信号。与记忆单个样本的 "正确答案" 不同,模型会探索多个可能答案,接收每个答案的数值奖励,并逐步调整行为模式,使高奖励答案的生成概率提升,低奖励答案逐渐淘汰。经过多轮迭代后,模型会收敛至最优策略 —— 即符合奖励信号定义的输出选择规则。
在强化学习中,奖励信号由用户为特定任务自定义的评分器提供。对于数据集中的每个提示词,平台会采样多个候选答案,通过评分器完成量化评分,并执行策略梯度更新,使模型向高分答案方向优化。这种 "采样 - 评分 - 更新" 的循环会在数据集上反复执行,直至模型能够稳定输出符合评分器质量标准的结果。评分器可编码用户关注的任意指标(准确性、风格、安全性等),因此最终的微调模型会体现这些核心诉求,且无需用户管理强化学习底层基础设施。
目前精调SDK支持的模型及训练方式如下:
foundation_model | customization_type(训练方式) | ||||||
|---|---|---|---|---|---|---|---|
name
| model_version
| FinetuneSft
| FinetuneLoRA
| DPO
| DPOLoRA
| GRPO
| GRPOLoRA
|
doubao-seed-1-6-flash | 250828【推荐】 | ✅已支持 | ✅已支持 | ✅已支持 | ✅已支持 | ✅已支持 | ✅已支持 |
doubao-seed-1-6 | 250615【推荐】 | ✅已支持 | ✅已支持 | ✅已支持 | ✅已支持 | ✅已支持 | ✅已支持 |
创建强化学习精调任务前请备好精调数据集,任务参数配置时,需在data字段中指定训练集与验证集。
采用JSONL 格式,训练数据文件的每一行包含一个完整的 JSON 结构。数据集格式要求详细说明请参见模型精调数据集格式说明
建议从小规模数据集起步,先使用几十到几百条样本验证强化学习精调方案的有效性,再投入资源扩充数据规模。只要保证数据质量,几十条样本即可产生有效的精调效果。在维持高质量的前提下,数据量越大精调效果越优;且大规模数据集支持配置更大的批次大小(batch size),有助于提升训练稳定性。
目前提供以下两种示例,你可以根据需求来初始化对应的示例。
template 模版名 | 简介 |
|---|---|
rl_search_mcp_demo | 模板通过强化学习微调大型语言模型,优化其在深度搜索(Deep Search)场景下的性能。经训练后,模型增强了对复杂搜索意图的理解能力,可高效准确调用MCP/外部搜索API,进而生成高质量且精准的搜索结果与答案。 |
rl_demo | 该模板通过强化学习,使模型精准掌握自定义函数调用天气工具的时机与方式,实现更精准流畅的天气问答功能。可按需扩展至强化学习微调场景:通过强化学习微调大型语言模型,使其通过对话(Chat)API智能结合自定义工具完成特定功能。 |
本 Demo 核心目标:通过强化学习精调,使模型掌握智能调用外部搜索工具(API)的能力 —— 面对自身无法直接解答的复杂问题时,可连续调用工具获取信息并整合,最终生成完整答案。
开通模型服务、精调算力计费项 | 开通云产品及授权 |
|---|---|
您可以通过以下 pip 命令安装精调 SDK。
运行环境:python>=3.10
pip install https://ark-public-example-cn-beijing.tos-cn-beijing.volces.com/ark-sdk/ark_sdk-0.2.11.tar.gz
您需要授权将SDK终端关联到指定的账号和项目,具体操作如下:
ark login命令开启授权过程通过以下命令,可在指定文件夹下快速初始化一个包含基础结构和配置、可立即使用的精调项目。ark init workspace <文件夹名> --template <模版名>
说明:<文件夹名> 替换为实际项目文件夹名称,例如:
ark_rl_project,本demo模版名为rl_search_mcp_demo。
ark init workspace ark_rl_project --template rl_search_mcp_demo #初始化项目命令执行成功后工作区结构如下: #<ark_rl_project> #├── data #│ └── search_dataset_dev_100.jsonl #│ └── search_dataset.jsonl #├── plugins #│ ├── draft_rollout_arkitect.py #│ └── rollout.py #│ └── llm_grader.py #│ └── test_utils.py #├── job.py #├── README.md #└── requirements.txt #└── test_faas.py #└── ......
成功执行后将显示:
初始化项目后,通过cd ark_rl_project命令进入项目。为了让模型学会使用外部工具,我们引入了“插件(Plugins)”机制主要包含两类核心组件:Rollout Plugin(执行者)与Grader Plugin(评分者)。本章节将结合rl_search_mcp_demo中的示例,详解两类插件的实现逻辑、修改方法及开发注意事项。
强化学习的本质是“模型通过反馈持续优化”,而Plugin函数正是实现这一闭环的关键:
在 rl_search_mcp_demo/plugins/ 目录下,我们为您提供了这两个核心组件的示例实现。接下来,我们将深入解析本demo中的这两个插件函数。为进一步规范开发,可按需查看Rollout签名要求和Grader签名要求。同时方舟还提供了多种Rollout函数模版和常用Grader函数,能帮您拓展思路。
Demo中plugins/rollout.py已实现“复杂网络搜索场景”的完整交互逻辑。
Rollout函数实现代码如下:
MAX_STEPS = 6 TOOL_CALL_MAX_RETRY = 10 TOOL_CALL_TIMEOUT = 20.0 # 单次工具调用最大时长(秒 rate_limiter = AsyncLimiter(max_rate=20, time_period=1) LOCAL_TEST_MODEL = "doubao-seed-1-6-flash-250615" search_tool = { "type": "function", "function": { "name": "web_search", "description": "互联网搜索", "parameters": { "type": "object", "properties": { "Query": { "type": "string", "description": "搜索的关键词,1~100个字符(过长会截断),不支持多词搜索", }, "Count": { "type": "integer", "description": "搜索结果数量", "default": 10, }, }, "required": ["Query"], }, }, } @rollout( name="code_sandbox_grader_single_grader", description="code_sandbox_grader", runtime=Runtime( instance=PluginInstance.CPU2MEM4, timeout=900, ), ) @coze_monitor async def demo_rollout( context: PluginContext, proxy: RolloutInferenceProxy, sample: ChatCompletionSample, ) -> Optional[RolloutResult]: """ 训练时,每个样本会调用 n_sample 次rollout函数,每次调用rollout函数当前仅支持返回一条轨迹(wrap_client_inplace的client最后一次调用上下文) """ # NOTE: 由于search mcp的封装,无法修改tool定义,使用直接调用api的方式 # from arkitect.core.component.tool.tool_pool import ToolPool, build_tool_pool # from arkitect.core.component.tool.builder import build_mcp_clients_from_config # mcp_clients, cleanup = build_mcp_clients_from_config(CONFIG_FILE_PATH) # tool_pool = build_tool_pool(list(mcp_clients.values())) # tools = [tool.model_dump() for tool in await tool_pool.list_tools()] tools = [search_tool] schema = search_tool["function"]["parameters"] search_api_key = os.getenv("SEARCH_API_KEY") assert search_api_key, "env SEARCH_API_KEY not set" async with httpx.AsyncClient( base_url="https://open.feedcoopapi.com", headers={ "Authorization": f"Bearer {search_api_key}", "Content-Type": "application/json", }, ) as tool_client: req = sample.model_dump() messages = req.pop("messages") client = AsyncArk(base_url=proxy.url, api_key=proxy.jwt_token) wrap_inplace_trace(client) wrap_async_client_inplace(client, proxy=proxy) step = 0 while True: completion = await client.chat.completions.create( # model 字段仅在本地测试时生效 model=LOCAL_TEST_MODEL, messages=messages, tools=tools, # NOTE: 为了避免训练效果受影响,被wrap后的client无法在此处指定采样参数或thinking type, 需要在样本或任务超参中指定 ) messages.append(completion.choices[0].message.model_dump()) step += 1 if step >= MAX_STEPS: break if completion.choices[0].finish_reason != "tool_calls": break tool_calls = completion.choices[0].message.tool_calls for tool_call in tool_calls or []: try: name = tool_call.function.name params = json.loads(tool_call.function.arguments) assert name == "web_search", f"tool name not web_search: {name}" jsonschema.validate(instance=params, schema=schema) except ( json.JSONDecodeError, AssertionError, jsonschema.ValidationError, ) as e: # NOTE: extra字段用于传递额外的信息给reward函数,此处通知reward需要惩罚为-1分 logger.error(f"tool call format check failed: {repr(e)}") return RolloutResult( status=PluginStatus.SUCCESS, extra={ "reward": -1, }, ) res = await exec_tool_call( tool_client=tool_client, name=name, arguments=params, ) messages.append( { "role": "tool", "content": res, "tool_call_id": tool_call.id, } ) return RolloutResult( status=PluginStatus.SUCCESS, ) @backoff.on_exception( backoff.expo, RolloutRetryException, max_time=60, jitter=backoff.full_jitter, # full jitter base=2, factor=0.4, max_value=20, ) @observe( process_inputs=lambda d: {key: d["kwargs"][key] for key in ["name", "arguments"]}, ) # type: ignore async def exec_tool_call( tool_client: httpx.AsyncClient, name: str, arguments: dict, ) -> str | list[ChatCompletionContentPartParam]: # from arkitect.core.component.tool.utils import convert_to_chat_completion_content_part_param # if isinstance(tool_client, ToolPool): # res = await tool_client.execute_tool(name, arguments) # content = convert_to_chat_completion_content_part_param(res) # return content async with rate_limiter: try: resp = await asyncio.wait_for( tool_client.post( "search_api/web_search", json={**arguments, "SearchType": "web"}, ), timeout=TOOL_CALL_TIMEOUT, ) data = resp.json() resp.raise_for_status() if data.get("Result"): return json.dumps(data, ensure_ascii=False) # NOTE: 可以选择一些错误返回RETRY进行rollout重试,重试达到上线后导致任务失败,默认exception为Failure会重试后拉挂任务 raise RolloutRetryException( f"web_search tool call got resp without Result: {data}" ) except asyncio.TimeoutError: raise RolloutRetryException( f"tool call timeout after {TOOL_CALL_TIMEOUT}s" ) except httpx.HTTPStatusError as e: raise RolloutRetryException( f"tool call HTTPStatusError: {e.response.status_code} {e.response.text}" )
Demo中plugins/llm_grader.py已实现LLM-as-a-Judge。
如果您觉得当前的 llm_grader 提示词不完全符合您的评分标准,您可以修改其内部的 Prompt,以更精确地定义您的业务场景下“好”与“坏”的标准。同时方舟提供多种常用Grader函数,您可以根据任务的复杂度和对精度的要求来选择或修改评分逻辑。
Grader函数实现代码如下:
# NOTE: 避免thinking模型因为重复输出等问题影响训练 MAX_COMPLETION_TOKENS = 4096 SYSTEM_PROMPT = """你是一位严谨的内容评估专家。 你的核心任务是:依据提供的【问题】和一份【正确答案列表】(列表中的每个答案都被视为该问题的有效解答),来判断我提供的【预测答案】是否正确回答了该【问题】。你应该首先给出你的判断依据,然后给出你的判断结果(即“正确”或“错误”)。 判断标准如下: 1. 【预测答案】不需要与【正确答案列表】中的任何答案在字面上完全相同,但应该在语义上相同。 2. 【正确答案列表】中的每个答案都可以被视为问题的正确答案,【预测答案】应该至少与其中之一在语义上是相同的。 3. 你必须仔细阅读【问题】和【预测答案】,并仔细考虑【预测答案】是否真的正确回答了【问题】,不可以因为【预测答案】中包含了正确答案的字眼而认为【预测答案】正确。 4. 对于模型没有认真回答问题,只是通过列举很多答案骗分的情况,请你判断为错误。 """ USER_PROMPT_TEMPLATE = """ ###问题### {} ###正确答案列表### {} ###预测答案### {} 现在请开始你的判断: """ @group_grader( name="llm_group_grader", description="llm_group_grader", runtime=Runtime( instance=PluginInstance.CPU1MEM2, max_concurrency=32, timeout=900, ), ) async def llm_grader( context: PluginContext, sample: ChatCompletionSample, trajectories: List[Trajectory], ): if ( not sample.extra or "answer" not in sample.extra or sample.messages[-1].role != "user" ): return GroupGraderResult( rewards=[], metrics={}, status=PluginStatus.DISCARD, error="sample is not valid", ) ark = AsyncArk() reward_list = [] tasks = [] @backoff.on_exception( wait_gen=backoff.expo, exception=( ArkAPIConnectionError, ArkRateLimitError, ArkInternalServerError, json.JSONDecodeError, ), max_time=100, jitter=backoff.full_jitter, # full jitter base=2, factor=0.4, max_value=20, # 退避重试最大时长(秒) ) async def single_grader(traj: Trajectory) -> float: # 返回rollout过程中提前设置的reward if traj.extra and "reward" in traj.extra: return traj.extra["reward"] format_score = validate_tool_calls(traj) if format_score is not None: logger.info( f"prompt_id: {sample.extra.get('prompt_id')} eval_res: {format_score} format check failed. {traj.messages}" ) return format_score resp = await ark.chat.completions.create( model="doubao-seed-1-6-250615", max_completion_tokens=MAX_COMPLETION_TOKENS, temperature=0, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, { "role": "user", # NOTE: 仅reward最后的总结,并未考虑历史tool_calls "content": USER_PROMPT_TEMPLATE.format( sample.messages[-1].content, sample.extra["answer"], traj.messages[-1].content, ), }, ], response_format={ "type": "json_schema", "json_schema": { "name": "评估结果", "description": "根据你作为内容评估专家的角色和评估标准进行评估的结果", "schema": { "type": "object", "properties": { "rationale": { "type": "string", "description": "详细判断理由", }, "judgement": { "type": "string", "enum": ["正确", "错误"], "description": "判断结果", }, }, "required": ["rationale", "judgement"], }, "strict": True, }, }, ) eval_res = json.loads(resp.choices[0].message.content) logger.info(f"prompt_id: {sample.extra.get('prompt_id')} eval_res: {eval_res}") if eval_res["judgement"] == "正确": return 1.0 else: return 0.0 for traj in trajectories: tasks.append(asyncio.create_task(single_grader(traj))) results = await asyncio.gather(*tasks, return_exceptions=True) for res in results: if isinstance(res, Exception): return GroupGraderResult( rewards=[], metrics={}, status=PluginStatus.FAILURE, error=str(res) ) reward_list.append(res) return GroupGraderResult( rewards=reward_list, metrics={}, status=PluginStatus.SUCCESS, error="" ) def validate_tool_calls(traj: Trajectory) -> Optional[float]: # 如果最后一轮是tool_calls则视为参数不对或者超出轮次,0分惩罚 if traj.messages[-1].tool_calls: for tool_call in traj.messages[-1].tool_calls: try: params = json.loads(tool_call.function.arguments) assert tool_call.function.name == "web_search", ( f"tool name not web_search: {tool_call.function.name}" ) jsonschema.validate( instance=params, schema={ "type": "object", "properties": { "Query": { "type": "string", }, "Count": { "type": "integer", }, }, "required": ["Query"], }, ) except Exception as e: return -1.0 return 0.0 if not traj.messages[-1].content: return -1.0 # 如果模型没有执行工具直接回答,-1分惩罚 if all([not msg.tool_calls for msg in traj.messages]): return -1.0 return None def main(): import asyncio sample = ChatCompletionSample( **{ "messages": [ { "role": "user", "content": "通过景栗科技的私域运营服务和与薪勤科技的产品共创,哪两个公司在各自的领域实现了用户增长或应用上架?", } ], "thinking": {"type": "enabled"}, "extra": {"answer": "景栗科技和薪勤科技"}, } ) trajectories = [ Trajectory( **{ "messages": [ { "role": "assistant", "content": "景栗科技的私域运营服务和薪勤科技的产品共创实现了用户增长或应用上架。", } ], "finish_reason": "stop", "usage": {}, } ), Trajectory( **{ "messages": [ { "role": "assistant", "content": "景栗科技", } ], "finish_reason": "stop", "usage": {}, } ), ] res = asyncio.run(llm_grader({}, sample, trajectories)) logger.info(f"grader done with result: {res}") if __name__ == "__main__": main()
async定义协程并发函数时,需避免同步长阻塞函数调用,防止并发效率下降。Ark(),应该使用AsyncArk()。若使用@rollout()装饰器定义异步函数,示例如下:@rollout() async def demo_rollout(): # 函数逻辑(需调用AsyncArk()等异步函数)
async关键字,此时会自动启用线程并发,定义示例如下:@rollout() def demo_rollout(): # 函数逻辑(可调用同步长耗时函数)
在强化学习微调任务中,可观测性配置至关重要,其能支撑训练过程分析、效果量化及问题定位排查。因此,在测试自定义 plugin 函数或提交强化学习精调任务前,需优先完成可观测性配置。关于以下三种可观测性功能的详细开启方法和具体配置步骤详见可观测性配置。
通过在方舟控制台将模型的完整决策轨迹进行可视化,让您能直观地分析模型的每一步行为以及得分。具体效果如下图所示。
允许您使用 logger 在自己的 Rollout 和 Reward 函数中记录关键信息,用于代码级别的精细化调试。
系统内置效果指标:平台默认提供强化学习核心评估指标,无需额外配置,系统会自动采集并计算,直观反映模型的基础训练效果与泛化能力。
用户自定义指标:支持用户在 Grader 函数返回RewardFunctionResult对象时,在其metrics字段中补充自定义评估指标。
两类指标均会统一展示在方舟控制台模型精调的训练观测功能模块中。具体效果如下图所示。
利用专业的追踪系统,端到端地追踪从模型推理到工具调用的每一环节的输入、输出及耗时。以Cozeloop Tracing系统为例,效果如下图所示。
在提交资源消耗较高的强化学习精调任务前,需依次完成本地测试与在线 FaaS 测试,验证自定义 Plugin 函数(Rollout、Grader)的功能、交互逻辑及性能,这是避免代码错误导致训练失败、节约时间与计算资源的关键步骤。
支持使用单样本调试进行功能验证与多样本批量测试进行并发压测,确保函数在真实训练场景的并发压力下可稳定运行。
示例代码中plugins/rollout.py末尾提供了开箱即用的main测试函数,支持单样本、多样本调试模式。测试时可按需选择调试模式(单样本、多样本或两者并行),通过注释代码灵活控制,按照下述操作步骤执行测试并记录调用过程。
test_with_dataset函数中的max_concurrent参数,设置为与训练配置的batchsize一致,模拟真实训练场景的并发压力;LOCAL_TEST_MODEL变量,更换成被训练的基础模型名称;完成本地测试后,建议进行在线 FaaS 测试,模拟训练任务的真实运行环境,步骤如下:
requirements.txt:
uv pip freeze > requirements.txt固定间接依赖版本,过滤冗余依赖,精简依赖列表。test_faas.py文件,拉起在线运行环境对 Plugin 函数进行测试。在rl_search_mcp_demo示例中,job.py是强化学习精调任务的核心配置入口,可集中设置模型选择、精调类型、算法超参、数据及Plugin定义等关键参数,直接控制训练全流程。
提交训练前,需根据实际需求调整该文件配置,修改前请仔细阅读完整参数说明精调参数的配置。部分重要配置项简述如下:
enable_trajectory=True,该配置对强化学习至关重要,强烈建议开启。ark命令行工具查询指定模型版本的超参信息:
#命令语法 ark get foundation-model --model <模型名> --version <模型版本> --fields hyperparameters <指定查询超参维度> ark get foundation-model --model doubao-seed-1-6 --version 250615 --fields hyperparameters
确认以上各项无误后,执行以下命令即可提交任务:
python job.py
提交后的精调任务,您可通过 精调控制台 或 CLI命令 两种方式查看精调任务详情,包括任务概览、训练观测、轨迹分析、日志、时间线、模型产出、精调安全审计信息。也可以对处于不同阶段的精调任务执行终止、停止、复制、删除等操作。
对于强化学习精调任务,核心监控指标为奖励得分,该指标反映模型reward得分随训练步数的变化,由任务配置中定义的评分器(Grader)计算得出。主要包含两个奖励指标:
train/reward->final_reward.mean:当前训练步骤中,最终奖励平均值。由于每个训练步骤的批次数据会动态变化,因此不同步骤间的 train/reward->final_reward.mean不具备直接可比性,数值可能出现大幅波动。test/reward->final_reward.mean:配置验证集后将展示验证数据样本的平均奖励值,该指标的数值表现更稳定,是评估模型泛化能力的核心依据。若要查看所有训练指标的说明和展示,在 模型精调 页面点击精调任务名称,进入精调任务详情页,在 训练观测 页签查看。
奖励指标趋势图
若发现异常,可在指标上手动选中step区间,查看该区间的均值、峰值、谷值。点击 “轨迹分析” ,可跳转至 轨迹分析 页签查看选中步骤范围的轨迹详情。
可自由选择step范围进行轨迹分析,Step列表展开状态时,Reward 分布图可直观了解各step的Reward分布情况, 点击具体step右侧列表会展示该step中所有轨迹。点击轨迹条目即可查看该轨迹的详情。
轨迹分析功能效果图
精调产物是训练过程中生成的可调用模型版本,通过实际使用产物,能更直观地验证模型在业务场景中的表现。
选择精调产物导出至模型仓库 ,进行精调后模型的使用,支持量化、体验、评测、在线推理、批量推理,操作见管理自定义模型。
本 Demo 核心目标:通过强化学习微调 doubao-seed-1-6-flash 模型,使其能精准调用天气工具回答用户问题。
若需基于本 Demo 完成完整的强化学习精调任务,整体流程可完全参考前文 Demo1 Deep Search RL 精调(基于 rl_search_mcp_demo 模板)的步骤(包括 SDK 安装与环境准备、项目初始化、可观测性配置、Plugin 函数测试、训练参数配置、任务提交、监控评估及产物使用等全环节)。
本 Demo 与 Deep Search Demo 的核心差异在于 plugins 目录下Rollout 函数与 Grader 函数的功能实现逻辑,以下重点介绍本 Demo 中 Plugin 函数的核心功能与实现特点:
本 Demo 提供 weather_rollout.py 和 raw_rollout.py 两个模版,均用于实现模型与天气工具的交互流程,生成训练轨迹。
利用精调SDK提供的封装,通过 Ark() 获取客户端,训练逻辑(包括状态更新和完成处理)由SDK自动管理。以下是示例代码:
@rollout( name="demo_rollout", runtime=Runtime( instance=PluginInstance.CPU1MEM2, max_concurrency=100, min_replicas=1, max_replicas=10, timeout=900, ), ) @coze_monitor def demo_rollout( # sync的函数,会由@rollout转换成线程并发的async函数来提升运行效率 context: PluginContext, proxy: RolloutInferenceProxy, sample: ChatCompletionSample, ) -> Optional[RolloutResult]: """ 训练时,每个样本会调用 n_sample 次rollout函数,每次调用rollout函数当前仅支持返回一条轨迹(wrap_client_inplace的client最后一次调用上下文) """ # 使用训练感知客户端 - 仅此一行不同! client = Ark(base_url=proxy.url, api_key=proxy.jwt_token) # 可以使用openai的client # from openai import OpenAI # from ark_sdk.core.plugin.rollout.proxy import wrap_client_inplace # client = OpenAI(base_url=proxy.url, api_key=proxy.jwt_token) wrap_inplace_trace(client) wrap_client_inplace(client, proxy=proxy) req = sample.model_dump() messages = req.pop("messages") tools = (req.pop("tools") or []) + rollout_tools step = 0 while True: # 步骤2: 发起模型请求,由于模型在收到工具执行结果后仍然可能有工具调用意愿,因此需要多次请求 completion = client.chat.completions.create( # model 字段仅在本地测试时生效 model=LOCAL_TEST_MODEL, messages=messages, tools=tools, ) # 注意:训练逻辑已经自动处理,无需手动判断模式或调用process_completion! messages.append(completion.choices[0].message.model_dump()) step += 1 logger.info(f"completion: {completion.choices[0].message.model_dump()}") if step >= 30: break if completion.choices[0].finish_reason != "tool_calls": # 模型最终总结,没有调用工具意愿 break tool_calls = completion.choices[0].message.tool_calls for tool_call in tool_calls or []: tool_name = tool_call.function.name if tool_name == "get_current_weather": # 步骤 3:调用外部工具 try: args = json.loads(tool_call.function.arguments) tool_result = get_current_weather(**args) except Exception as e: logger.error(f"get_current_weather error: {e}") # 重试,rollout请求,超过重试次数会DISCARD掉本条数据(包裹所有n smaple) # return RolloutResult(status=PluginStatus.RETRY, error=str(e)) # 丢弃,丢弃本样本(包裹所有n smaple) # return RolloutResult(status=PluginStatus.DISCARD, error="") # 失败,返回FAILURE重试本样本,超过重试次数会拉挂任务。如果不处理exception,默认会返回FAILURE # return RolloutResult(status=PluginStatus.FAILURE, error=str(e)) # extra字段用于传递额外的信息给reward函数,此处通知reward需要惩罚为-1分 return RolloutResult( status=PluginStatus.SUCCESS, extra={ "reward": -1, }, ) # 步骤 4:回填工具结果,并获取模型总结回复 messages.append( { "role": "tool", "content": tool_result, "tool_call_id": tool_call.id, } ) # 失败,丢弃本样本(包裹所有n smaple) # return RolloutResult(status=PluginStatus.DISCARD, error="") # 默认return None则视为rollout成功 return None
通过手动初始化 Ark 客户端,并需要显式调用 proxy.update_state_from_messages 和 proxy.process_completion 来处理训练过程中的状态更新和完成回调,其实现是同步的,这赋予了开发者更底层的控制能力。以下是示例代码:
@rollout( runtime=Runtime( instance=PluginInstance.CPU1MEM2, max_concurrency=100, min_replicas=1, max_replicas=10, ), ) @coze_monitor def demo_rollout( context: PluginContext, proxy: RolloutInferenceProxy, sample: ChatCompletionSample, ) -> Optional[RolloutResult]: client = Ark(base_url=proxy.url, api_key=proxy.jwt_token) # 可选:添加cozeloop trace,用于debug监控模型调用输入输出 wrap_inplace_trace(client) req = sample.model_dump() messages = req.pop("messages") tools = (req.pop("tools") or []) + rollout_tools while True: # NOTE:强化学习场景特殊逻辑 # 请求前发给 proxy 检查历史信息、更新内部状态(当 message 更改历史时,proxy 会自动截断内部维护的 token ids)。 proxy.update_state_from_messages([ChatCompletionMessage(**msg) for msg in messages]) # 步骤2: 发起模型请求,由于模型在收到工具执行结果后仍然可能有工具调用意愿,因此需要多次请求 completion = client.chat.completions.create( # model 字段仅在本地测试时生效 model=LOCAL_TEST_MODEL, messages=messages, tools=tools, # NOTE:强化学习场景特殊逻辑 extra_body=proxy.get_extra_body(), extra_headers=proxy.headers, ) assert isinstance(completion, ChatCompletion) # NOTE:强化学习场景特殊逻辑 proxy.process_completion(ChatCompletionResponse(**completion.model_dump())) logger.info(f"completion: {completion.choices[0].message.model_dump()}") messages.append(completion.choices[0].message.model_dump()) if completion.choices[0].finish_reason != "tool_calls": # 模型最终总结,没有调用工具意愿 break tool_calls = completion.choices[0].message.tool_calls for tool_call in tool_calls or []: tool_name = tool_call.function.name if tool_name == "get_current_weather": # 步骤 3:调用外部工具 args = json.loads(tool_call.function.arguments) tool_result = get_current_weather(**args) # 步骤 4:回填工具结果,并获取模型总结回复 messages.append( { "role": "tool", "content": tool_result, "tool_call_id": tool_call.id, } ) return
比较rollout结果是否包含数据集中给定的标准答案。以下是示例代码:
@group_grader( name="randaom_reward", runtime=Runtime( instance=PluginInstance.CPU1MEM2, max_concurrency=100, timeout=300, ), ) def random_reward_fn( context: PluginContext, sample: ChatCompletionSample, trajectories: List[Trajectory], ) -> GroupGraderResult: """ 奖励函数:返回随机奖励 参数: - trajectories: 完整的对话历史 - sample: 样本数据,包含标准答案的字典 返回: - list[float]: 奖励分数列表,每个分数对应一个候选回复(1.0表示完全匹配,0.0表示不匹配) 依赖: - 数据集里的字典字段 extra 内需要携带 answer 字段。 """ rewards = [ t.extra["reward"] if (t.extra and "reward" in t.extra) else random.random() for t in trajectories ] return GroupGraderResult( rewards=rewards, status=PluginStatus.SUCCESS, error="", metrics={} )