From 0bdc8d9ae9403c885b5a7392e231e6fadf2ccdaa Mon Sep 17 00:00:00 2001 From: jinye_huang Date: Tue, 22 Apr 2025 20:12:27 +0800 Subject: [PATCH] =?UTF-8?q?timeout=E6=9C=BA=E5=88=B6=EF=BC=8C=E4=BD=86?= =?UTF-8?q?=E6=97=A0=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 86 +++++++++- core/__pycache__/ai_agent.cpython-312.pyc | Bin 16362 -> 18376 bytes core/ai_agent.py | 198 +++++++++++++--------- example_config.json | 4 +- examples/test_pipeline_steps.py | 4 +- examples/test_stream.py | 4 +- main.py | 8 +- 7 files changed, 213 insertions(+), 91 deletions(-) diff --git a/README.md b/README.md index a040faf..15cafb3 100644 --- a/README.md +++ b/README.md @@ -129,12 +129,13 @@ pip install numpy pandas opencv-python pillow openai - `content_temperature`, `content_top_p`, `content_presence_penalty`: 内容生成 API 相关参数 (默认为 0.3, 0.4, 1.5) - `request_timeout`: AI API 请求的超时时间(秒,默认 30) - `max_retries`: 请求超时或可重试网络错误时的最大重试次数(默认 3) -- `camera_image_subdir`: 存放原始照片的子目录名(相对于 `image_base_dir`,默认 "相机") - **注意:此项不再用于查找描述文件。** -- `modify_image_subdir`: 存放处理后/用于拼贴的图片的子目录名(相对于 `image_base_dir`,默认 "modify") -- `output_collage_subdir`: 在每个变体输出目录中存放拼贴图的子目录名(默认 "collage_img") -- `output_poster_subdir`: 在每个变体输出目录中存放最终海报的子目录名(默认 "poster") -- `output_poster_filename`: 输出的最终海报文件名(默认 "poster.jpg") -- `poster_target_size`: 海报目标尺寸 `[宽, 高]`(默认 `[900, 1200]`) +- `stream_chunk_timeout`: 处理流式响应时,允许的两个数据块之间的最大等待时间(秒),用于防止流长时间挂起。 +- `camera_image_subdir`: 存放原始照片的子目录名(相对于 `image_base_dir`,默认 "相机") +- `modify_image_subdir`: 存放处理后/用于拼贴的图片的子目录名(相对于 `image_base_dir`,默认 "modify") +- `output_collage_subdir`: 在每个变体输出目录中存放拼贴图的子目录名(默认 "collage_img") +- `output_poster_subdir`: 在每个变体输出目录中存放最终海报的子目录名(默认 "poster") +- `output_poster_filename`: 输出的最终海报文件名(默认 "poster.jpg") +- `poster_target_size`: 海报目标尺寸 `[宽, 高]`(默认 `[900, 1200]`) - `text_possibility`: 海报中第二段附加文字出现的概率 (默认 0.3) 项目提供了一个示例配置文件 `example_config.json`,请务必复制并修改: @@ -178,17 +179,20 @@ except Exception as e: print(f"Error loading config: {e}") sys.exit(1) -# 2. 初始化 AI Agent (读取超时/重试配置) +# 2. 初始化 AI Agent (读取超时/重试/流式块超时配置) ai_agent = None try: request_timeout = config.get("request_timeout", 30) max_retries = config.get("max_retries", 3) + stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # 新增:读取流式块超时 + ai_agent = AI_Agent( config["api_url"], config["model"], config["api_key"], timeout=request_timeout, - max_retries=max_retries + max_retries=max_retries, + stream_chunk_timeout=stream_chunk_timeout # 新增:传递流式块超时 ) # 3. 定义提示词和参数 @@ -272,4 +276,68 @@ This refactoring makes it straightforward to add new output handlers in the futu ### 配置文件说明 (Configuration) -主配置文件为 `poster_gen_config.json` (可以复制 `example_config.json` 并修改)。主要包含以下部分: \ No newline at end of file +主配置文件为 `poster_gen_config.json` (可以复制 `example_config.json` 并修改)。主要包含以下部分: + +#### 1. 基本配置 (Basic) + +* `api_url` (必须): 大语言模型 API 地址 (或预设名称如 'vllm', 'ali', 'kimi', 'doubao', 'deepseek') +* `api_key` (必须): API 密钥 +* `model` (必须): 使用的模型名称 +* `topic_system_prompt` (必须): 选题生成系统提示词文件路径 (应为要求JSON输出的版本) +* `topic_user_prompt` (必须): 选题生成基础用户提示词文件路径 +* `content_system_prompt` (必须): 内容生成系统提示词文件路径 +* `resource_dir` (必须): 包含**资源文件信息**的列表。列表中的每个元素是一个字典,包含: + * `type`: 资源类型,目前支持 `"Object"` (景点/对象信息), `"Description"` (对应的描述文件), `"Product"` (关联产品信息)。 + * `file_path`: 一个包含该类型所有资源文件**完整路径**的列表。 +* `prompts_dir` (必须): 存放 Demand/Style/Refer 等提示词片段的目录路径 +* `output_dir` (必须): 输出结果保存目录路径 +* `image_base_dir` (必须): **图片资源根目录绝对路径或相对路径** (用于查找源图片) +* `poster_assets_base_dir` (必须): **海报素材根目录绝对路径或相对路径** (用于查找字体、边框、贴纸、文本背景等) +* `num` (必须): (选题阶段)生成选题数量 +* `variants` (必须): (内容生成阶段)每个选题生成的变体数量 + +#### 2. 可选配置 (Optional) + +* `date` (可选, 默认空): 日期标记(用于选题生成提示词) +* `topic_temperature` (可选, 默认 0.2): 选题生成 API 温度参数 +* `topic_top_p` (可选, 默认 0.5): 选题生成 API top-p 参数 +* `topic_presence_penalty` (可选, 默认 1.5): 选题生成 API presence penalty 参数 +* `content_temperature` (可选, 默认 0.3): 内容生成 API 温度参数 +* `content_top_p` (可选, 默认 0.4): 内容生成 API top-p 参数 +* `content_presence_penalty` (可选, 默认 1.5): 内容生成 API presence penalty 参数 +* `request_timeout` (可选, 默认 30): 单个 HTTP 请求的超时时间(秒)。 +* `max_retries` (可选, 默认 3): API 请求失败时的最大重试次数。 +* `stream_chunk_timeout` (可选, 默认 60): 处理流式响应时,允许的两个数据块之间的最大等待时间(秒),用于防止流长时间挂起。 +* `camera_image_subdir` (可选, 默认 "相机"): 存放原始照片的子目录名(相对于 `image_base_dir`) +* `modify_image_subdir` (可选, 默认 "modify"): 存放处理后/用于拼贴的图片的子目录名(相对于 `image_base_dir`) +* `output_collage_subdir` (可选, 默认 "collage_img"): 在每个变体输出目录中存放拼贴图的子目录名 + +#### 3. 选题与内容生成参数 (Topic & Content Generation) + +* `topic_temperature` (可选, 默认 0.2): 选题生成 API 温度参数 +* `topic_top_p` (可选, 默认 0.5): 选题生成 API top-p 参数 +* `topic_presence_penalty` (可选, 默认 1.5): 选题生成 API presence penalty 参数 +* `content_temperature` (可选, 默认 0.3): 内容生成 API 温度参数 +* `content_top_p` (可选, 默认 0.4): 内容生成 API top-p 参数 +* `content_presence_penalty` (可选, 默认 1.5): 内容生成 API presence penalty 参数 + +#### 4. 图片处理参数 (Image Processing) + +* `camera_image_subdir` (可选, 默认 "相机"): 存放原始照片的子目录名(相对于 `image_base_dir`) +* `modify_image_subdir` (可选, 默认 "modify"): 存放处理后/用于拼贴的图片的子目录名(相对于 `image_base_dir`) +* `output_collage_subdir` (可选, 默认 "collage_img"): 在每个变体输出目录中存放拼贴图的子目录名 + +#### 5. 海报生成参数 (Poster Generation) + +* `output_poster_subdir` (可选, 默认 "poster"): 在每个变体输出目录中存放最终海报的子目录名 +* `output_poster_filename` (可选, 默认 "poster.jpg"): 输出的最终海报文件名 +* `poster_target_size` (可选, 默认 [900, 1200]): 海报目标尺寸 `[宽, 高]` +* `text_possibility` (可选, 默认 0.3): 海报中第二段附加文字出现的概率 + +#### 6. 其他参数 (Miscellaneous) + +* `request_timeout` (可选, 默认 30): 单个 HTTP 请求的超时时间(秒)。 +* `max_retries` (可选, 默认 3): API 请求失败时的最大重试次数。 +* `stream_chunk_timeout` (可选, 默认 60): 处理流式响应时,允许的两个数据块之间的最大等待时间(秒),用于防止流长时间挂起。 + +项目提供了一个示例配置文件 `example_config.json`,请务必复制并修改: diff --git a/core/__pycache__/ai_agent.cpython-312.pyc b/core/__pycache__/ai_agent.cpython-312.pyc index 54aa19af800b4a38996623e6418fee1b5a99ed7b..6900778c0828a52ce7ebfed2845e1304af20ada9 100644 GIT binary patch delta 7209 zcmc&ZYfu|UnlmE_3D5(OKnR31JOT)VFpgirCYU#vhaW&3KV?NT!U7?Y(GY`HL|&I~ zLpF8{nZ%FS4$d2&?KL^?CC(<7I2&JuY@E%=a0OAgI$s@N`>O7iIIh}MQdf6hk0ijv zxm4Ysn^Jw<{q@)1qx{6e9K!QdM`B;EIS8wQohI6IhAx17Q#+!?zg zmJ;odQ{o*8O0q*qMeR@#SPe!=4`Gz-Jb{K`*p9ef+!Q_7->SJIg616@fj9&Z-@$(# z(n_I>d|!z+mcOkn+XFh=Qb23;xXK}3Y{d@hqcF{ zwnPrK1XNQg&crUL$R=_Vpwbp13FwnjE-23 zicFdT6v_xiZ$c)<8X~PMjM;!=EY77y>!T=@FbYYIJ0n^oS;sZy>-|)G1htsocC0sH zKh`7Kj_t=y>hAThFfnBKo9RC@^Y(KyhmVnEb>#F1uiqN@&A%Nu7&G5r#1Lc{u7_nz#pLYpt23`WH+%ZX%y0d#-g^1a%+Mchy?TIb-P%CT_8+_R+M(OO z`)Fq1$gQ_W$lM%?rn@{e-I0?|<~-$ecIKFnw1_gO>~vGK)8evr(h&i--}?3RKYV=a z#bIC`5*A{&bfmS*?}&^k+gwM^T>8zO7v5Y-#X34^cQ;GsI$Uf%>F#13ZkN>wcFY`k zY4*aaBpP}9(tES-90YN<-~F4}lSYST$#yhuv|0ZxoMx8z|^TykEo;h;+5>jFIrIDXVStJBxz{Jzm zPK#|{x2wal$RPjy7R~Q~A`NNCTS1wLpd1Ly@-vam9BfpzDbA0;!=fJ9Y z@wYBMJM$5;_Ge8kT4bs>vCidS9ag8q>u|Lp-zjUOT{z3diN@9XaFZpQX_ke4Pcir3 zq9VOloBu%7OMKDVu#D_4$8oBEl!7r4T{=ek&`cJU@U&+RgVUyhfTEJ9CH{(v7UJ7b0murX74@|ePE>RfPa={| zqFMwbr~ssss6j=INLh=@q&m?(IQKS+xGyC}|KZptYY=Ok#%4<0ql$EVA)Dig75iDgk$2GKvJ;Q(#(SHR`}lDrfX>y_oh?(MR4m3X-` zaV5By`z9`rU}8CGyn$e%xtw?t?%=k>C&#pXc^!Ror6$f3--v6uPvS>6do^e$3oA*w zZMJTPp{ZihaSztPtKJHZ)r))-(b%j`I4Qka6KS=xG(*y8_4CR7R@gm2;uU+$X0w<3 zI>Cl}FPEqv!R3q&Okh$FFmTUk$S7gi)47WpBS`;Jqr+Elw=`)ICL2gNMPd`4%k4@e zW7nV&YY}*adnU0NsJ=;T#Z6pMQXanH@~))AvcwYT4dRSf(#_fn)|zl83$Y-KPecha zv$=L-j-mmv*bKlVV{*AyjGKVRe;Tv#QaFbns3J;V^lA!_mA6mc422_J#V zHli{Q6y5If4v;Z`GiTv0$GJd;UP(!P zQcC6p$)_+X8fbL!F=r$Tyo6gjAXRg(#i|JiBA|+^xj?L%`xB8=9|bXsm_p%&1~@GF zF38V^mFSCMwUJI=OoVw~EQ_W^_}Y^pph!vRlSd0-#{FoiPws_~%f|)s64f1*Fn>B0 z9WKI0DFplqP+z$aSRBMz3+dWhyn58w+2LqWfw(|>||MIV0*+!;UE>=&o;NqMcYu0U?yv=F3JgGBhmeS2fL4?d%D~%xU3vjzY`eRI^1@<`F~5F zJiQ*4?zDtN7LzUeXflKfYav9wEMUv(?hpFAJv4LQFoak%$WK0K6v*mAHg_|0G1+YG z>~bQr7GdufBE&?+WK*}zNq2ilPh>!sm4S>50+>a%ZFkU4Y9Tp9$&Zkl6q6MwG4YUY zyPy)><-)8}!mKZ1u+{4D&ka z*t`rn@ECNWc{9Tc9x@qLJQT^3!paG;c`?bLaG}X7D`U~5uq1KMPOWK>%7l`^D36(W6w$>s z3&|X5ceos$eKckEri3vLm~JZzc@W*R&)N-M1AR&1z7Xed>Q_v9lg(u5)=1pxu#;VG zkH_(Z6LM%TVKRI5XcC3clcc&CG#w#Bpaz3QWC7{T1B4wOi`&nFfnWwp9s zKf-~85Zg7^t;_%WcNH1XRq0-`#F`(rmV%GI!r}S zRo7J4QdibsDKD$2-&9>4%Gl6l#H1sDZW{(&FiZ^s$WI_Ig=G!}%_Igs6qL4vUMPZM z(1T@a5#>e%Rv=K1#weO$pVkTP0n;4{yTzy%J$o@IS`3QW-Yz;Qvj{IK7Ry4Y9E@fe ztBpoSRWJr})J`j;pFK9F0uA*;R0tY$yBw&8X+*8q#93TGXthW0~H z8eZ?9JbWMuZxlhL)9Mjqg*LdbSAxpWBSzTQ*D^tsKo!Cw|7La|zJl#SbRP#z;iOmu z7}E=@*ZQwRkLNX*huhfq2e7#eEXgoj5Rk8(OiUe)4air{Nkp2Me%U=)Bu&;9Us?Ha z>HDR8aXp{dFfMDnnV{~MPHD6I<&(O!e&tPN;tBCYa&90wcTAZ(M>OIJ-Bg_RxMEOo zRyM31DLR)rtQgrc!cMHI39Rwg@T+V2S}SjOVm!_U42DuqrcUTg0i9_=w<@4pHKAJ{ z(5;_LF$@)+ESyYxXp&qunPQwwHBRcup@x$Ua|$e_upd zq1`X<9`3%>cdqZsmOx?YL}6o~u#w-g?MC6YskE%2T_<<(dF#j1N+#1Xzcpgo)O(c# znta8x3O=b~Tvj=$)Zb86zY)#nRgIC=ysY~Bdo~=?o8Sc?(KwZC9Fm`u^ADAcCzpSn zY?w;TxQmGr4O7W}<}kZ3x9V%`XxXRxCu+6@YPRv!kMoZM z8DOT<2U@wp_)>%;f=N`M9QHDl_}6Z^RnNS@U7)SZ38=!!)pH9BkyX z1V3sCRJQOHt$gbqKEpDu**lq%d&8eXU8v(%+s5)KUPFB=#S+uMzbD5M^`|Z4<_*`( zyn4%BEJk6PjMIj7OwRyEhOMI-RGrNoUNdYODH>^+SXC8RRmHEYzUmm?+`_j$Hnw@& z#OA$$&3pMxR^CeSw4JBgcyix(oZ~(#{ShRo)* zZr-?OJl-;?O&-wxI41?lALh0Zm?|l9dWjOL@>(E=I5tTNZ-iayVj*rDU|TrHqJFD(<- z%S2y@B!K!tCI{3Pn*0Wd=sJcl*Ci6bT#qTwfX1Km31IbSv#vQ+^p&(CVXHxOLni^$ z4TBs|H}dkElSE@l5~z=*CIFthl%90CbDb5(cXLCTDb@X<2j#zf`0QUkeEIgJ4`z-Y zn?3f-^at)`Awr)Q52On-DrpgDZ&(7Q7O&)<4+XtsXOtH*wq2Mk*n&rU$ocHO(zs ztI8TJ6}7ER^_G^p#;Q%NEg{uTA)*5aAa#F`t%snu+etsl3_%kzBg}(@&53Xv{|1Zx x7i`ZzVSBhQO63LLV5&JKriwcdJDH$9pqf;onpVeSn!k!!G2VgS#e_m&{yz=A4XpqG delta 5180 zcmbtXeNbE1m48o9pGX1$5(r74$2a23#$VXj#E}E|gJ5hhHZ~@z(0gEkkjT>`7)79D zC!IF4Ze#bRcAdnvJJZfMb<%O`>D0UHZro1V;5gkpyqYo9(B18fKkA*?F2SB@m+T*V z&J#l5PMY1>H}ih)ymQYv_uO;Oz4zRiH|T$lW!^R#Gbnf_y_yrhZ9J1{rH`Mf+-p@( zyD5%&jpCG76xd8r$Mlo5OLgW_57E&VJJd<0jHu7R`}ibHQRgU*<`f*mDLECVen%Ch zJS(D6YIucC6T=MgGP!+plumU8I?J6>RFr`~v@SnU1^g*j{*>^iTKU7iR{RpPQhZKB zUeOoOTJl*#fkGE0>kJn1uL=$MC6in5Hg!RJ1-! zAXhV-ns4YILr2uL7X8Hx{gSqp6+>g5vF0W6M@uukL~8S@6^}AVZ(fn&5lv3!xk?{V zEDS0~;Kz6oA{GjI#`vH&AWlB~%_sOQ?0lDenAbt)kevLP150N2fxT?2AcO=q9o zOV_?>-9|SiG=b325Hu!~{@`Fpj%h;0@dM+)5&fflLKP11{8&OQc!OMMRF1MxK>n-1 z>6WKV<2w=Z`(V1503Srnfan#fAsPwepg-sjpMd%FhC{)GuJxpkAH%^WG!tGS2z`+g3D63lY`0I9RvUMJKd; z&==y+PHWkK`Z-iSNZv22G#)^*2f(EjY~=T4onWY;yn=r9T5q{eNpB}NT!rnjd<#lW z0>CgK>&AqG$Y?|00bEAGhIB8&Q3L^Dm9zL8CS}#Tp^HGZl|DzlTKyjiq41hmL}}bq zY}DHnr@FwbTC}U>3OIFC-hBQ-judCHRhlRR>e{FhP#0Cf>Q|=?aT!x;PCuoIs*|h# zO~W?w3%!~A-e4t98gj{u*-Xxw%%rcx!cfB|^1b3*;xQIMvI3QD;qv_Sd-+$Ew}&Y$ zVY7Y;UrgTwl3)MePQBxVw48j`CrZ9cVB;o>=pCps~Q)V{fO42qi10y(WU6xq& z`BtV`{R(}m<9{QVH>Hp2&obv|n&NV&bf+eiR8%MAkyjg>1-7({lgvucsbL2WtHAyR ztyS2%aZ|dL+kBQr`;$y5B^IUpETe=QQ--IlY8+7m#9l!batk2Dxy}*DHo-3zr_{K@ zs5%01IigSa_rkmuK0o`?VngqE@e5& zqQ*4qD4cLphH1lT#VP$Nqw>)9mZaslQgoLqyTF`P;t*4MlT1>7f);V%%MT2=a?OBi zF;KDcRn{uL%-RHXk~vCE&=AP_l-+bl#q}#ZT#xcIs4SyS(?a`H=BjZxQmt?@Tq)MY zDx;Yg^i`e|Ol66iQ&Q2)zuwy*lRBOS29hW%SiDAe#g`SS>}cljmX!6R500vl<1o_- z)Q@Hf??%;}F=B-Epk7{d|IbWtJ(UU7|Cdb2oX$MWoa$K>K{*o?HF9FrLk85Y-gmgV z!m2O$%%E|Z^O zISHFLteX0bvTK;qUtlhkgzqyGv}?r(!P&pNf8z~y z($5Dt_QBb&J$UE6`xnkdoZTWEe4}{m?&gEM;Kh@V?S{iqjNklSfLDM%MmINDiNv>Y5)!iwMf<>;L;P;#Q-A}Nf%gOC>Z2@FpgSQ;C;ORB^)I% zMJUl=Xl;?gWx-?xz;kSP+y^y6cyK%rn5>V~KGj^y4vdRzFqEN4cpKCCChMl~ za=4P?L@P15$NI?jzV1U^t?mv_YgboiR}=fiHdeVH>o&1Vbbq9(bxj$^a^q0s4yER2 zObB6^fYTJoTGLZu6D#aM*P9WZLD-4l28dLXCTmG-*+_ka?Sm=*D-QMxyHG2qo^=m| z1UW~N1Fnx0uU6e33f9T-@eTxdu4&!{p$lOTf?NsiMheTxWC3>=DZG0L zc-as%2;B&W09>Z0D@HsJmMdGNjv@3RzzIMlvgLil&2ye`Xp|REz|;#KY{ipQ_zHp- zVF1Ag5HA9>r!fmS2DQ<;$bj zx87{M*(mMnmtF`;{UOOR7T3P`Z)AtVu_1rnlz+}}#vnO&-qg={e~~5aJRlu72z}n| z8kp@Gkodvbu0iSGkTi5cD)h%q!{l{`%TQT&y=JzkZGoYS_K<&c~e(w1*&r8L7Nb4ackck+lLp`J~?bf{ySFxytU}_ zd-?1l#mF(aYc8KPmrE4~;^u>Mj+)E+F7LSdqGa24$Goj*-fa7PK?8E1KgpmhW%2UH z8%3WKNx8j`sXU{d{7ca%u{B*@HeTiKZfTp{(k5-*E$x0$su1GV@LbWecdXAsxA$P+ z9cMbEvgYgF>)q0`T~bMR+2Wm=E+@C@vNy>DmfC5nR@4}6_;MT%ZAW z9ju{$$?Q_JTc{7!RlA+khwH3>ztWZ>*18q&`E@-FQomkj0s14Q9q5moy7nCAqpDq5 z`<%?JMkQ$9+OB|xTg`>Q`Phb>kBefc_p#FevLDwJwi}t-T5Py&K>lr`3i!8k801)N z2dkLd>z>Owu$lRkRzvfrno6YWRY-4a23lH=w6xg(qO)}2!BS?n2syK*D&$l#K>Kjt zNIjZs(k?&Ui1*+7Kl{PGpTG0qr>{P|{I?I^dJPDuE*`%Abto(T^Tu0$_~6a^Z+|yg zWvo;nBkHg^hs@p6pZFS*=;NEoglD|-&?HAh-2 zXd{_!eZQ~}8uq&skMN@Nhyifv6Izca=pE%f9#h`%3rL_R=E}{RH~here0dL z;qglfgKC6MT0$c(gD?yq^2fH4UGgrD`*}hq3SJ*S;Ps6pPYt=dZ$nGkg(dBd(GWKt f;GYq`0}V91;VLq{yD;}ph7Kh?LN8MCN7ncspWjng diff --git a/core/ai_agent.py b/core/ai_agent.py index 6f1b897..e3dd8c3 100644 --- a/core/ai_agent.py +++ b/core/ai_agent.py @@ -4,15 +4,34 @@ import time import random import traceback import logging +import tiktoken # Configure basic logging for this module (or rely on root logger config) # logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') # logger = logging.getLogger(__name__) # Alternative: use named logger +# Constants +MAX_RETRIES = 3 # Maximum number of retries for API calls +INITIAL_BACKOFF = 1 # Initial backoff time in seconds +MAX_BACKOFF = 16 # Maximum backoff time in seconds +STREAM_CHUNK_TIMEOUT = 10 # Timeout in seconds for receiving a chunk in stream + class AI_Agent(): """AI代理类,负责与AI模型交互生成文本内容""" - def __init__(self, base_url, model_name, api, timeout=30, max_retries=3): + def __init__(self, base_url, model_name, api, timeout=30, max_retries=3, stream_chunk_timeout=10): + """ + 初始化 AI 代理。 + + Args: + base_url (str): 模型服务的基础 URL 或预设名称 ('deepseek', 'vllm')。 + model_name (str): 要使用的模型名称。 + api (str): API 密钥。 + timeout (int, optional): 单次 API 请求的超时时间 (秒)。 Defaults to 30. + max_retries (int, optional): API 请求失败时的最大重试次数。 Defaults to 3. + stream_chunk_timeout (int, optional): 流式响应中两个数据块之间的最大等待时间 (秒)。 Defaults to 10. + """ + logging.info("Initializing AI Agent") self.url_list = { "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1", "kimi": "https://api.moonshot.cn/v1", @@ -26,8 +45,9 @@ class AI_Agent(): self.model_name = model_name self.timeout = timeout self.max_retries = max_retries + self.stream_chunk_timeout = stream_chunk_timeout - print(f"Initializing AI Agent with base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}") + logging.info(f"AI Agent Settings: base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}, stream_chunk_timeout={self.stream_chunk_timeout}s") self.client = OpenAI( api_key=self.api, @@ -35,6 +55,12 @@ class AI_Agent(): timeout=self.timeout ) + try: + self.encoding = tiktoken.encoding_for_model(self.model_name) + except KeyError: + logging.warning(f"Encoding for model '{self.model_name}' not found. Using 'cl100k_base' encoding.") + self.encoding = tiktoken.get_encoding("cl100k_base") + def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty): """生成文本内容,并返回完整响应和token估计值""" logging.info(f"Generating text with model: {self.model_name}, temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}") @@ -176,98 +202,118 @@ class AI_Agent(): # --- Streaming Methods --- def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty): - """生成文本内容,并以生成器方式 yield 文本块""" - logging.info("Streaming Generation Started...") - logging.debug(f"Streaming System Prompt (first 100 chars): {system_prompt[:100]}...") - logging.debug(f"Streaming User Prompt (first 100 chars): {user_prompt[:100]}...") - logging.info(f"Streaming Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}") + """ + Generates text based on prompts using a streaming connection. Handles retries with exponential backoff. - retry_count = 0 - max_retry_wait = 10 + Args: + system_prompt: The system prompt for the AI. + user_prompt: The user prompt for the AI. + temperature: Sampling temperature. + top_p: Nucleus sampling parameter. - while retry_count <= self.max_retries: + Yields: + str: Chunks of the generated text. + + Raises: + Exception: If the API call fails after all retries. + """ + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + logging.info(f"Generating text stream with model: {self.model_name}") + + retries = 0 + backoff_time = INITIAL_BACKOFF + last_exception = None + + while retries < self.max_retries: try: - logging.info(f"Attempting API stream call (try {retry_count + 1}/{self.max_retries + 1})") - response = self.client.chat.completions.create( + logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream.") + stream = self.client.chat.completions.create( model=self.model_name, - messages=[{"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}], + messages=messages, temperature=temperature, top_p=top_p, - presence_penalty=presence_penalty, stream=True, - max_tokens=8192, - timeout=self.timeout, - extra_body={"repetition_penalty": 1.05}, + timeout=self.timeout # Overall request timeout ) - try: - logging.info("Stream connected, receiving content...") - yielded_something = False - for chunk in response: - if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: + chunk_iterator = iter(stream) + last_chunk_time = time.time() + + while True: + try: + # Check for timeout since last received chunk + if time.time() - last_chunk_time > self.stream_chunk_timeout: + raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.") + + chunk = next(chunk_iterator) + last_chunk_time = time.time() # Reset timer on successful chunk receipt + + if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: content = chunk.choices[0].delta.content + # logging.debug(f"Received chunk: {content}") # Potentially very verbose yield content - yielded_something = True + elif chunk.choices and chunk.choices[0].finish_reason == 'stop': + logging.info("Stream finished.") + return # Successful completion + # Handle other finish reasons if needed, e.g., 'length' - if yielded_something: - logging.info("Stream finished successfully.") - else: - logging.warning("Stream finished, but no content was yielded.") - return + except StopIteration: + logging.info("Stream iterator exhausted.") + return # End of stream normally + except Timeout as e: + logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).") + last_exception = e + break # Break inner loop to retry the stream creation + except (APITimeoutError, APIConnectionError, RateLimitError) as e: + logging.warning(f"API error during streaming: {type(e).__name__} - {e}. Retrying if possible ({retries + 1}/{self.max_retries}).") + last_exception = e + break # Break inner loop to retry the stream creation + except Exception as e: + logging.error(f"Unexpected error during streaming: {traceback.format_exc()}") + # Decide if this unexpected error should be retried or raised immediately + last_exception = e + # Option 1: Raise immediately + # raise e + # Option 2: Treat as retryable (use with caution) + break # Break inner loop to retry - except APIConnectionError as stream_err: - logging.warning(f"Stream connection error occurred: {stream_err}") - retry_count += 1 - if retry_count <= self.max_retries: - wait_time = min(2 ** retry_count + random.random(), max_retry_wait) - logging.warning(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...") - time.sleep(wait_time) - continue - else: - logging.error("Max retries reached after stream connection error.") - yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]" - return - - except Exception as stream_err: - logging.exception("Error occurred during stream processing:") - yield f"[STREAM_ERROR: {stream_err}]" - return - - except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e: - logging.warning(f"API Error occurred: {e}") - should_retry = False - if isinstance(e, (APITimeoutError, APIConnectionError, RateLimitError)): - should_retry = True - elif isinstance(e, APIStatusError) and e.status_code >= 500: - should_retry = True - - if should_retry: - retry_count += 1 - if retry_count <= self.max_retries: - wait_time = min(2 ** retry_count + random.random(), max_retry_wait) - logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...") - time.sleep(wait_time) - continue - else: - logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting stream.") - yield "[API_ERROR: Max retries reached]" - return + # If we broke from the inner loop due to an error that needs retry + retries += 1 + if retries < self.max_retries: + logging.info(f"Retrying stream in {backoff_time} seconds...") + time.sleep(backoff_time + random.uniform(0, 1)) # Add jitter + backoff_time = min(backoff_time * 2, MAX_BACKOFF) else: - logging.error(f"Non-retriable API error: {e}. Aborting stream.") - yield f"[API_ERROR: Non-retriable status {e.status_code if isinstance(e, APIStatusError) else 'Unknown'}]" - return - except Exception as e: - logging.exception("Non-retriable error occurred during API call setup:") - yield f"[FATAL_ERROR: {e}]" - return + logging.error(f"Stream generation failed after {self.max_retries} retries.") + raise last_exception or Exception("Stream generation failed after max retries.") - logging.error("Stream generation failed after exhausting all retries.") - yield "[ERROR: Failed after all retries]" + + except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e: + retries += 1 + last_exception = e + logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}") + if retries < self.max_retries: + logging.info(f"Retrying in {backoff_time} seconds...") + time.sleep(backoff_time + random.uniform(0, 1)) # Add jitter + backoff_time = min(backoff_time * 2, MAX_BACKOFF) + else: + logging.error(f"API call failed after {self.max_retries} retries.") + raise last_exception + except Exception as e: + # Catch unexpected errors during stream setup + logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}") + raise e # Re-raise unexpected errors immediately + + # Should not be reached if logic is correct, but as a safeguard: + logging.error("Exited stream generation loop unexpectedly.") + raise last_exception or Exception("Stream generation failed.") def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty): - """工作流程的流式版本:返回文本生成器""" + """完整的工作流程(流式):读取文件夹(如果提供),然后生成文本流""" logging.info(f"Starting 'work_stream' process. File folder: {file_folder}") if file_folder: logging.info(f"Reading context from folder: {file_folder}") diff --git a/example_config.json b/example_config.json index 98a3b23..98f4e65 100644 --- a/example_config.json +++ b/example_config.json @@ -48,9 +48,11 @@ "content_presence_penalty": 1.5, "request_timeout": 30, "max_retries": 3, + "stream_chunk_timeout": 60, "output_collage_subdir": "collage_img", "output_poster_subdir": "poster", "output_poster_filename": "poster.jpg", "poster_target_size": [900, 1200], - "text_possibility": 0.3 + "text_possibility": 0.3, + "description_filename": "description.txt" } \ No newline at end of file diff --git a/examples/test_pipeline_steps.py b/examples/test_pipeline_steps.py index 0256a8f..5347c90 100644 --- a/examples/test_pipeline_steps.py +++ b/examples/test_pipeline_steps.py @@ -117,6 +117,7 @@ def main_test(): ai_api_key = config.get("api_key") request_timeout = config.get("request_timeout", 30) max_retries = config.get("max_retries", 3) + stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Read stream timeout if not all([ai_api_url, ai_model, ai_api_key]): raise ValueError("Missing required AI configuration (api_url, model, api_key)") logging.info("Initializing AI Agent for content generation test...") @@ -125,7 +126,8 @@ def main_test(): model_name=ai_model, api=ai_api_key, timeout=request_timeout, - max_retries=max_retries + max_retries=max_retries, + stream_chunk_timeout=stream_chunk_timeout # Pass stream timeout ) total_topics = len(topics_list) diff --git a/examples/test_stream.py b/examples/test_stream.py index 1f231ca..6beba94 100644 --- a/examples/test_stream.py +++ b/examples/test_stream.py @@ -66,6 +66,7 @@ def main(): ai_api_key = config.get("api_key") request_timeout = config.get("request_timeout", 30) max_retries = config.get("max_retries", 3) + stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Get stream chunk timeout # Check for required AI params if not all([ai_api_url, ai_model, ai_api_key]): @@ -80,7 +81,8 @@ def main(): model=ai_model, # Use extracted var api_key=ai_api_key, # Use extracted var timeout=request_timeout, - max_retries=max_retries + max_retries=max_retries, + stream_chunk_timeout=stream_chunk_timeout # Pass it here ) # Example call to work_stream diff --git a/main.py b/main.py index ed7fbbc..819748f 100644 --- a/main.py +++ b/main.py @@ -91,14 +91,16 @@ def generate_content_and_posters_step(config, run_id, topics_list, output_handle ai_agent = None try: # --- Initialize AI Agent for Content Generation --- - request_timeout = config.get("request_timeout", 30) # Get timeout from config - max_retries = config.get("max_retries", 3) # Get max_retries from config + request_timeout = config.get("request_timeout", 30) # Default 30 seconds + max_retries = config.get("max_retries", 3) # Default 3 retries + stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Default 60 seconds for stream chunk ai_agent = AI_Agent( config["api_url"], config["model"], config["api_key"], timeout=request_timeout, - max_retries=max_retries + max_retries=max_retries, + stream_chunk_timeout=stream_chunk_timeout ) logging.info("AI Agent for content generation initialized.")