From 59312ed23ad0a134091d7bb1f50571f9029f1275 Mon Sep 17 00:00:00 2001 From: jinye_huang Date: Tue, 22 Apr 2025 16:29:51 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E8=BE=93=E5=87=BA=E7=9A=84ai=5Fagent=20work=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- core/__pycache__/ai_agent.cpython-312.pyc | Bin 7254 -> 13042 bytes core/ai_agent.py | 125 +++++++++++++++++++++- 2 files changed, 123 insertions(+), 2 deletions(-) diff --git a/core/__pycache__/ai_agent.cpython-312.pyc b/core/__pycache__/ai_agent.cpython-312.pyc index 858b74277d321fe5977ff8981f7d2a619d013de1..8b29a681aa83bda17b807a0897e499ffb52ff1f2 100644 GIT binary patch delta 5902 zcmb_g3v3(5nV#j66yFjhQX(meTt$*4>J`6KSr5ym--;zmk>nVWA~Q61WmDoqU9Mt_ zp+Yy!-MKj9>fGl1h~q}Up#XDfuPTDx$%u<6XoJW|n=EBs#A6hkaggl<1zJ@$&?0Gp z`)5f?v`)QVfeyfbhX4HYKfd{Q_n-N6*D9m?NUK#LXbpXNXPBl}b*AV5{JVl^L=!$V zhvMBV zm)Gyj2GsN@!w`q*b)!6=G0|JsXtKQlN5L z1~iYB169y@K$Wyo0i9}6=}`ap=XWI~deD<1-Q(Vr${(Z|pDW<@Gi)LMF4$nNe0zR9Vd1^`l)4V9s>h^(e=Xlh zRPdkX?;=Y1{DNBIiFaBHcp0&c|JdjQ@wVc{#WK7 z5@q}~%kT~{R++~a3^5KuR30K=9mV3+>^P9HZf&tB_B!t|9%f1-8rphYN|)WpTlbW+^DBAQXhgHldKTW&Q|g!0 zmxzm(a-6lTOVD=gI8EIkFH3Ql*%g!0^vbgY_e9n-1ikaIpv%Reb7JPIP$s+3=oTDJ z6tek=KWE`AI!8W>&JqA{P1e>r)Ld_oD}iiJ!~b};=jPgz8*>e^J=4~ceJd6cAEymoW({PM{fFdP!m(ole_9{B_U8Qy zm+oKwyZdka;NIW<&AorU`I}o8@7?^j<#)fIp1OA}cK`Ctd$+zzO?nw0O=U!HTnXFz zIVb}B-oR<IK4hZr_H z5!Uv)S+_sbN`dtF?y#bs;)3HY7!}o1<17s*mGcVxEKJJY^sdlh&SL3ondGyG~t1=CNwhP z^G!B}E3)$$ML1Xn7nfAf^HF8W^C8k$8hXq#z>~_q-xj}Jg0WlQ)k^! zvw;ijF3gdJun*h2-RHnNj`cDj=-nRdhI<4mEec8?Ksy`7n(@YP+fe^NU-Y1}+jX$7 zucxn-`qMf#dfrNXMGS{)Rbsh|!w{pO=m{1(ma#a_1~Cf2X{EyX*jSFOuw8}2^hVvM zd<+Hk^ExCMM8ZXDIN)iI+vlTd!WEl2IRaR&sR%WOsUi6He?&DLE>dB6NVzB!?UWb> zI7w0sl=l%-JYkz7iyp(i03eF{wGMr_dNVgNc;~EW4Da|eu0%JnLxt_ff`JBcJJ;X| ziRX*7BaPwmT(yr2X9NM@$N6h1m&I2%bjUdnb#`Te3O8pH70TS~Q#s=H`qom3Id{fA z5fYOR??xuyus4|uI+QD&BQoqTv*H25iL-p|guw1bQ2nqSn7IcN@#rF+TK3|<`!K;3 zF)PCa$~45{V-$glyV$WkaIETJS;9lSX^JTd00rD2FB%q zLlb)xI1U{Y>)8JKQ6T3OfY)_m@)Ramp6$j2yOUC-Q``lZvd1vD2S`WK4G}gEeV)z36{ybV%Xa(Jq3B}eWJh0sg&F#Lk=V@W! z#KN8v)7lsv3klSY8D@cM6NX061F46U8*i=u_vpVlBM>fy&`F)-t4;8m85pg z(1N{f#fYqX9wN21@RDN1O6g0N3}(TyFJai9EO)FT(pY;{#{Y}HiudbPM4UgTFFW#t z<9hc>9?U!{L`I8XX-gQ|lLpJ9a|BZA1l^H@qBE)6c1N-8_m9pJ$W*(EGg#E{bzk@;s_B-5so`@IwQMRAY<&q+f07y!o*NN{P79~c2sZD6$-88-eZ5>lts-eU z8tdVC)m%id3?vN4mkec#hU&PXI$>}mt#z^Hr(*3h-xMq@3x<{@gXQZLB`o^-kp>wn z6IIQ(Y(mkIRixIIE)^IB)2>V zT9@?2g1I5){*nJJf6O(ze~$Q2^}cF0B9wF`^xa8=@k;lr-P8RGhMJ_2icUM@#;Qw0 zNmIo%w^-dAuWnv2H7`}w#PqSTxTEFPq4&CPcE8&bcO0ERd8g{RtL7!MO(@?vGc?mZ zr=Q!oxW7NXzhBrlAPlfVD6)dkON3JbU5MC+yM2;FxT6kB9wB5oB3Uw*y=l2-neI%O z8$UOfFGoR}{Yb*xnJga=jz6<_{8aq-DPh1Pcp$Nu1v9f`wl711l(r{~yOU)dv)_Dg z>gLpJaQAITNd?|aeZY%za?p|juF$L zSh-MAzo4&Q(wE?E(wd@w2>mxajW0W!dv_7pbEjp%23>%pz)XN@MHKhLYcv~DNR z5A1pn`e3WFvyl9t#o2bGko<6`6c&8AO9B%gwiW?rUW+;NdIC7}g&Gi@x7BuP$=eFd zzpcUi+gcg$Zx@oC#wfaNH1sx+w;j&5qaEZ&m2#N>s5T#uTV#0Lp@HdNsP>ijG?5?I z;pvaJmBIKE9iINANE~k!$4yvNAZmNJl0qHk2*3wU6A5Dvep6-}9^3dId#D zb|9y2CGd5?W5*aYO5&q^5t*%M;W{9Xt&B)UW%yjBfFe^8$xHvS!dh}l63EZdNQxH> z(i!*xx^R{S_EMmHFU17oZaT z&8^qJc<;4)*T0vpHrFcAa9JPYrZXiKehE_}P+C$WK_7ezTB!;AMN+968XBm3-+lRu z^WXmBC$HSUe2oHzDqL{F%?7-IC{!_rM0qM~gW@vyf|rI5jDzO*DBd>P5LvV#uroB zuZv;r#wmvjH>4=^9awlC{Z@`NB}*E;VAwWOH&Z4Q?n!9&CW|eL#arUVTc%&UQ@rEC z;s4ApzFS_uSiU`8zCC)Ud?&m^g{B^%_rzlFaJ+Z;POnSod0w!&UsW&T_lYA`7b_FK zF(8y2Pw1a{Bp2V(_@kA~FIvp6iRahE%4XGbLqgvP;aTsU{82$M%HpepV-CreHt4^^ z>E=+U6fW^n)8$I3TrRleobWL&K3Hn*u)@4}iOM9vaNUMG(7LP8HhtVDyF ziEUgim^4aWsIyveu8i@&v0q8{CV<%>j9AnmmgG?IRUK(vNVHFBq|9`4jDC{}8dODc zsqRK?Ro`KrT`N@?*Tt|h){QydiO_+dBXo(1Ql@RqH;S#EHl;}6zQ>|e*{_AH7Dmcy z!1$s|2EnzHvpNgi(DiFJ?6c&Bjpg_6aO}*Q!qMibE?#L(x+W_6X712IpH^FOK5Hbq zkn9#K*#W4%pEYQ|c$OU<%>x~2K_duR)vj~DS`VWPcHR&85!4MM92I}EMUc(rrf5Oj z%#F~Y=98RAP*a@Bdo(Ye=FdRlNB$VS-fS;CkSs<+a%{oz8u 100 else system_prompt) # Print truncated prompts for logs + print("User Prompt:", user_prompt[:100] + "..." if len(user_prompt) > 100 else user_prompt) + print(f"Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}") + + retry_count = 0 + max_retry_wait = 10 # Max backoff wait time + + while retry_count <= self.max_retries: + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=[{"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}], + temperature=temperature, + top_p=top_p, + presence_penalty=presence_penalty, + stream=True, + max_tokens=8192, # Or make configurable? + timeout=self.timeout, # Use configured timeout for initial connect/response + extra_body={"repetition_penalty": 1.05}, # Keep if needed + ) + + # Inner try-except specifically for handling errors during the stream + try: + print("Stream connected, receiving content...") + for chunk in response: + if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + yield content # Yield the content chunk + # Check for finish reason if needed, but loop termination handles it + # if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop": + # print("\nStream finished (stop reason).") + # break # Exit inner loop + + print("\nStream finished successfully.") + return # Generator successfully exhausted + + except APIConnectionError as stream_err: # Catch connection errors during stream + print(f"\nStream connection error occurred: {stream_err}") + # Decide if retryable based on type or context + retry_count += 1 + if retry_count <= self.max_retries: + wait_time = min(2 ** retry_count + random.random(), max_retry_wait) + print(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...") + time.sleep(wait_time) + continue # Continue outer loop to retry the whole API call + else: + print("Max retries reached after stream connection error.") + yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]" # Yield error info + return # Stop generator + + except Exception as stream_err: # Catch other errors during stream processing + print(f"\nError occurred during stream processing: {stream_err}") + traceback.print_exc() + yield f"[STREAM_ERROR: {stream_err}]" # Yield error info + return # Stop generator + + except (APITimeoutError, APIConnectionError, RateLimitError) as e: # Catch specific retriable API errors + print(f"\nRetriable API error occurred: {e}") + retry_count += 1 + if retry_count <= self.max_retries: + wait_time = min(2 ** retry_count + random.random(), max_retry_wait) + print(f"Retrying API call ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...") + time.sleep(wait_time) + continue # Continue outer loop + else: + print("Max retries reached for API errors.") + yield "[API_ERROR: Max retries reached]" + return # Stop generator + + except APIStatusError as e: # Handle 5xx server errors specifically if possible + print(f"\nAPI Status Error: {e.status_code} - {e.response}") + if e.status_code >= 500: # Typically retry on 5xx + retry_count += 1 + if retry_count <= self.max_retries: + wait_time = min(2 ** retry_count + random.random(), max_retry_wait) + print(f"Retrying API call ({retry_count}/{self.max_retries}) after server error, waiting {wait_time:.2f}s...") + time.sleep(wait_time) + continue + else: + print("Max retries reached after server error.") + yield f"[API_ERROR: Max retries reached after server error {e.status_code}]" + return + else: # Don't retry on non-5xx status errors (like 4xx) + print("Non-retriable API status error.") + yield f"[API_ERROR: Non-retriable status {e.status_code}]" + return + + except Exception as e: # Catch other non-retriable errors during setup/call + print(f"\nNon-retriable error occurred: {e}") + traceback.print_exc() + yield f"[FATAL_ERROR: {e}]" + return # Stop generator + + # This part is reached only if all retries failed without returning/yielding error + print("\nStream generation failed after exhausting all retries.") + yield "[ERROR: Failed after all retries]" + + + def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty): + """工作流程的流式版本:返回文本生成器""" + # 如果提供了参考文件夹,则读取其内容 + if file_folder: + print(f"Reading context from folder: {file_folder}") + context = self.read_folder(file_folder) + if context: + # Append context carefully + user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}" + else: + print(f"Warning: Folder {file_folder} provided but no content read.") + + # 直接返回 generate_text_stream 的生成器 + print("Calling generate_text_stream...") + return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty) + # --- End Added Streaming Methods --- \ No newline at end of file