diff --git a/README.md b/README.md index b78cdf3..e634a23 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,20 @@ - **模块化设计**:核心功能(配置加载、提示词管理、AI交互、选题、内容生成、海报制作)分离,方便维护和扩展 - **配置驱动**:通过配置文件集中管理所有运行参数 +## 新功能: 流式输出处理 + +TravelContentCreator 现已支持三种流式输出处理方法,提供了更灵活的 AI 文本生成体验: + +- **同步流式响应**: 使用流式 API 但返回完整响应 +- **回调式流式响应**: 通过回调函数处理每个文本块 +- **异步流式响应**: 使用异步生成器返回文本流 + +这些功能大大提升了长文本生成的用户体验和系统响应性。 + +详细文档请参阅: +- [流式处理文档](docs/streaming.md) +- [流式处理演示](examples/test_stream.py) + ## 快速开始 ### 1. 环境准备 diff --git a/core/__pycache__/ai_agent.cpython-312.pyc b/core/__pycache__/ai_agent.cpython-312.pyc index d9f9d54..ce80f7e 100644 Binary files a/core/__pycache__/ai_agent.cpython-312.pyc and b/core/__pycache__/ai_agent.cpython-312.pyc differ diff --git a/core/__pycache__/contentGen.cpython-312.pyc b/core/__pycache__/contentGen.cpython-312.pyc index 53c9a07..62f1d66 100644 Binary files a/core/__pycache__/contentGen.cpython-312.pyc and b/core/__pycache__/contentGen.cpython-312.pyc differ diff --git a/core/ai_agent.py b/core/ai_agent.py index cac6a3d..fc73d21 100644 --- a/core/ai_agent.py +++ b/core/ai_agent.py @@ -5,6 +5,8 @@ import random import traceback import logging import tiktoken +import asyncio +from asyncio import TimeoutError as AsyncTimeoutError # Configure basic logging for this module (or rely on root logger config) # logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -16,6 +18,16 @@ INITIAL_BACKOFF = 1 # Initial backoff time in seconds MAX_BACKOFF = 16 # Maximum backoff time in seconds STREAM_CHUNK_TIMEOUT = 10 # Timeout in seconds for receiving a chunk in stream +# 自定义超时异常 +class Timeout(Exception): + """Raised when a stream chunk timeout occurs.""" + def __init__(self, message="Stream chunk timeout occurred."): + self.message = message + super().__init__(self.message) + + def __str__(self): + return self.message + class AI_Agent(): """AI代理类,负责与AI模型交互生成文本内容""" @@ -198,7 +210,7 @@ class AI_Agent(): time_end = time.time() time_cost = time_end - time_start - tokens = None + tokens = len(result) * 1.3 logging.info(f"'work' completed in {time_cost:.2f}s. Estimated tokens: {tokens}") return result, tokens, time_cost @@ -220,9 +232,10 @@ class AI_Agent(): user_prompt: The user prompt for the AI. temperature: Sampling temperature. top_p: Nucleus sampling parameter. + presence_penalty: Presence penalty parameter. - Yields: - str: Chunks of the generated text. + Returns: + str: The complete generated text. Raises: Exception: If the API call fails after all retries. @@ -236,6 +249,7 @@ class AI_Agent(): retries = 0 backoff_time = INITIAL_BACKOFF last_exception = None + full_response = "" while retries < self.max_retries: try: @@ -245,6 +259,7 @@ class AI_Agent(): messages=messages, temperature=temperature, top_p=top_p, + presence_penalty=presence_penalty, stream=True, timeout=self.timeout # Overall request timeout ) @@ -263,16 +278,16 @@ class AI_Agent(): if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: content = chunk.choices[0].delta.content + full_response += content # logging.debug(f"Received chunk: {content}") # Potentially very verbose - yield content elif chunk.choices and chunk.choices[0].finish_reason == 'stop': logging.info("Stream finished.") - return # Successful completion + return full_response # Return complete response # Handle other finish reasons if needed, e.g., 'length' except StopIteration: logging.info("Stream iterator exhausted.") - return # End of stream normally + return full_response # Return complete response except Timeout as e: logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).") last_exception = e @@ -300,7 +315,6 @@ class AI_Agent(): logging.error(f"Stream generation failed after {self.max_retries} retries.") raise last_exception or Exception("Stream generation failed after max retries.") - except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e: retries += 1 last_exception = e @@ -311,16 +325,246 @@ class AI_Agent(): backoff_time = min(backoff_time * 2, MAX_BACKOFF) else: logging.error(f"API call failed after {self.max_retries} retries.") - raise last_exception + return f"[生成内容失败: {str(last_exception)}]" # Return error message as string except Exception as e: # Catch unexpected errors during stream setup logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}") - raise e # Re-raise unexpected errors immediately + return f"[生成内容失败: {str(e)}]" # Return error message as string # Should not be reached if logic is correct, but as a safeguard: logging.error("Exited stream generation loop unexpectedly.") - raise last_exception or Exception("Stream generation failed.") + return f"[生成内容失败: 超过最大重试次数]" # Return error message as string + def generate_text_stream_with_callback(self, system_prompt, user_prompt, + callback_fn, temperature=0.7, top_p=0.9, + presence_penalty=0.0): + """生成文本流并通过回调函数处理每个块 + + Args: + system_prompt: 系统提示词 + user_prompt: 用户提示词 + callback_fn: 处理每个文本块的回调函数,接收(content, is_last, is_timeout, is_error, error)参数 + temperature: 温度参数 + top_p: 核采样参数 + presence_penalty: 存在惩罚参数 + + Returns: + str: 完整的响应文本 + """ + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + logging.info(f"Generating text stream with callback using model: {self.model_name}") + + retries = 0 + backoff_time = INITIAL_BACKOFF + last_exception = None + full_response = "" + + while retries < self.max_retries: + try: + logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream with callback.") + stream = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + temperature=temperature, + top_p=top_p, + presence_penalty=presence_penalty, + stream=True, + timeout=self.timeout + ) + + chunk_iterator = iter(stream) + last_chunk_time = time.time() + + try: + while True: + try: + # 检查上次接收块以来的超时 + if time.time() - last_chunk_time > self.stream_chunk_timeout: + callback_fn("", is_last=True, is_timeout=True, is_error=False, error=None) + raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.") + + chunk = next(chunk_iterator) + last_chunk_time = time.time() # 成功接收块后重置计时器 + + content = "" + is_last = False + + if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + full_response += content + + if chunk.choices and chunk.choices[0].finish_reason == 'stop': + is_last = True + + # 调用回调函数处理块 + callback_fn(content, is_last=is_last, is_timeout=False, is_error=False, error=None) + + if is_last: + logging.info("Stream with callback finished normally.") + return full_response # 成功完成,返回完整响应 + + except StopIteration: + # 正常结束流 + callback_fn("", is_last=True, is_timeout=False, is_error=False, error=None) + logging.info("Stream iterator exhausted normally.") + return full_response + + except Timeout as e: + logging.warning(f"Stream chunk timeout: {e}") + last_exception = e + # 超时信息已通过回调传递,此处不需要再次调用回调 + except (APITimeoutError, APIConnectionError, RateLimitError) as e: + logging.warning(f"API error during streaming with callback: {type(e).__name__} - {e}") + callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e)) + last_exception = e + except Exception as e: + logging.error(f"Unexpected error during streaming with callback: {traceback.format_exc()}") + callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e)) + last_exception = e + + # 执行重试逻辑 + retries += 1 + if retries < self.max_retries: + retry_msg = f"将在 {backoff_time:.2f} 秒后重试..." + logging.info(f"Retrying stream in {backoff_time} seconds...") + callback_fn(f"\n[{retry_msg}]\n", is_last=False, is_timeout=False, is_error=False, error=None) + time.sleep(backoff_time + random.uniform(0, 1)) # 添加随机抖动 + backoff_time = min(backoff_time * 2, MAX_BACKOFF) + else: + error_msg = f"API call failed after {self.max_retries} retries: {str(last_exception)}" + logging.error(error_msg) + callback_fn(f"\n[{error_msg}]\n", is_last=True, is_timeout=False, is_error=True, error=str(last_exception)) + return full_response # 返回已收集的部分响应 + + except Exception as e: + logging.error(f"Error setting up stream with callback: {traceback.format_exc()}") + callback_fn("", is_last=True, is_timeout=False, is_error=True, error=str(e)) + return f"[生成内容失败: {str(e)}]" + + # 作为安全措施 + error_msg = "超出最大重试次数" + logging.error(error_msg) + callback_fn(f"\n[{error_msg}]\n", is_last=True, is_timeout=False, is_error=True, error=error_msg) + return full_response + + async def async_generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0): + """异步生成文本流 + + Args: + system_prompt: 系统提示词 + user_prompt: 用户提示词 + temperature: 温度参数 + top_p: 核采样参数 + presence_penalty: 存在惩罚参数 + + Yields: + str: 生成的文本块 + + Raises: + Exception: 如果API调用在所有重试后失败 + """ + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + logging.info(f"Asynchronously generating text stream with model: {self.model_name}") + + retries = 0 + backoff_time = INITIAL_BACKOFF + last_exception = None + + while retries < self.max_retries: + try: + logging.debug(f"Async attempt {retries + 1}/{self.max_retries} to generate text stream.") + # 创建新的客户端用于异步操作 + async_client = OpenAI( + api_key=self.api, + base_url=self.base_url, + timeout=self.timeout + ) + + stream = await async_client.chat.completions.create( + model=self.model_name, + messages=messages, + temperature=temperature, + top_p=top_p, + presence_penalty=presence_penalty, + stream=True, + timeout=self.timeout + ) + + last_chunk_time = time.time() + try: + async for chunk in stream: + # 检查上次接收块以来的超时 + current_time = time.time() + if current_time - last_chunk_time > self.stream_chunk_timeout: + raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.") + + last_chunk_time = current_time # 成功接收块后重置计时器 + + if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + yield content + + logging.info("Async stream finished normally.") + return # 成功完成 + + except AsyncTimeoutError as e: + logging.warning(f"Async stream timeout: {e}") + last_exception = e + except Timeout as e: + logging.warning(f"Async stream chunk timeout: {e}") + last_exception = e + except Exception as e: + logging.error(f"Error during async streaming: {traceback.format_exc()}") + last_exception = e + + # 执行重试逻辑 + retries += 1 + if retries < self.max_retries: + logging.info(f"Retrying async stream in {backoff_time} seconds...") + await asyncio.sleep(backoff_time + random.uniform(0, 1)) # 使用异步睡眠 + backoff_time = min(backoff_time * 2, MAX_BACKOFF) + else: + logging.error(f"Async API call failed after {self.max_retries} retries.") + raise last_exception or Exception("Async stream generation failed after max retries.") + + except (Timeout, AsyncTimeoutError, APITimeoutError, APIConnectionError, RateLimitError) as e: + retries += 1 + last_exception = e + logging.warning(f"Async attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}") + if retries < self.max_retries: + logging.info(f"Retrying async stream in {backoff_time} seconds...") + await asyncio.sleep(backoff_time + random.uniform(0, 1)) # 使用异步睡眠 + backoff_time = min(backoff_time * 2, MAX_BACKOFF) + else: + logging.error(f"Async API call failed after {self.max_retries} retries.") + raise last_exception + except Exception as e: + logging.error(f"Unexpected error setting up async stream: {traceback.format_exc()}") + raise e # 立即重新引发意外错误 + + # 作为安全措施 + logging.error("Exited async stream generation loop unexpectedly.") + raise last_exception or Exception("Async stream generation failed.") + + async def async_work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty): + """异步完整工作流程:读取文件夹(如果提供),然后生成文本流""" + logging.info(f"Starting async 'work_stream' process. File folder: {file_folder}") + if file_folder: + logging.info(f"Reading context from folder: {file_folder}") + context = self.read_folder(file_folder) + if context: + user_prompt = f"{user_prompt.strip()}\n\n--- 参考资料 ---\n{context.strip()}" + else: + logging.warning(f"Folder {file_folder} provided but no content read.") + + logging.info("Calling async_generate_text_stream...") + return self.async_generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty) def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty): """完整的工作流程(流式):读取文件夹(如果提供),然后生成文本流""" @@ -333,6 +577,6 @@ class AI_Agent(): else: logging.warning(f"Folder {file_folder} provided but no content read.") - logging.info("Calling generate_text_stream...") - return self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty) - # --- End Added Streaming Methods --- \ No newline at end of file + full_response = self.generate_text_stream(system_prompt, user_prompt, temperature, top_p, presence_penalty) + return full_response + # --- End Streaming Methods --- \ No newline at end of file diff --git a/core/contentGen.py b/core/contentGen.py index dbd2f75..7c17e00 100644 --- a/core/contentGen.py +++ b/core/contentGen.py @@ -45,6 +45,11 @@ class ContentGenerator: self.top_p = 0.8 self.presence_penalty = 1.2 + # 设置日志 + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + self.logger = logging.getLogger(__name__) + def load_infomation(self, info_directory_path): """ @@ -193,142 +198,149 @@ class ContentGenerator: 返回: 生成的海报内容 """ - if system_prompt is None: - system_prompt = """ - 你是一名资深海报设计师,有丰富的爆款海报设计经验,你现在要为旅游景点做宣传,在小红书上发布大量宣传海报。你的主要工作目标有2个: - 1、你要根据我给你的图片描述和笔记推文内容,设计图文匹配的海报。 - 2、为海报设计文案,文案的<第一个小标题>和<第二个小标题>之间你需要检查是否逻辑关系合理,你将通过先去生成<第二个小标题>关于景区亮点的部分,再去综合判断<第一个小标题>应该如何搭配组合更符合两个小标题的逻辑再生成<第一个小标题>。 - - 其中,生成三类标题文案的通用性要求如下: - 1、生成的<大标题>字数必须小于8个字符 - 2、生成的<第一个小标题>字数和<第二个小标题>字数,两者都必须小8个字符 - 3、标题和文案都应符合中国社会主义核心价值观 - - 接下来先开始生成<大标题>部分,由于海报是用来宣传旅游景点,生成的海报<大标题>必须使用以下8种格式之一: - ①地名+景点名(例如福建厦门鼓浪屿/厦门鼓浪屿); - ②地名+景点名+plog; - ③拿捏+地名+景点名; - ④地名+景点名+攻略; - ⑤速通+地名+景点名 - ⑥推荐!+地名+景点名 - ⑦勇闯!+地名+景点名 - ⑧收藏!+地名+景点名 - 你需要随机挑选一种格式生成对应景点的文案,但是格式除了上面8种不可以有其他任何格式;同时尽量保证每一种格式出现的频率均衡。 - 接下来先去生成<第二个小标题>,<第二个小标题>文案的创作必须遵循以下原则: - 请根据笔记内容和图片识别,用极简的文字概括这篇笔记和图片中景点的特色亮点,其中你可以参考以下词汇进行创作,这段文案字数控制6-8字符以内; - - 特色亮点可能会出现的词汇不完全举例:非遗、古建、绝佳山水、祈福圣地、研学圣地、解压天堂、中国小瑞士、秘境竹筏游等等类型词汇 - - 接下来再去生成<第一个小标题>,<第一个小标题>文案的创作必须遵循以下原则: - 这部分文案创作公式有5种,分别为: - ①<受众人群画像>+<痛点词> - ②<受众人群画像> - ③<痛点词> - ④<受众人群画像>+ | +<痛点词> - ⑤<痛点词>+ | +<受众人群画像> - 请你根据实际笔记内容,结合这部分文案创作公式,需要结合<受众人群画像>和<痛点词>时,必须根据<第二个小标题>的景点特征和所对应的完整笔记推文内容主旨,特征挑选对应<受众人群画像>和<痛点词>。 - - 我给你提供受众人群画像库和痛点词库如下: - 1、受众人群画像库:情侣党、亲子游、合家游、银发族、亲子研学、学生党、打工人、周边游、本地人、穷游党、性价比、户外人、美食党、出片 - 2、痛点词库:3天2夜、必去、看了都哭了、不能错过、一定要来、问爆了、超全攻略、必打卡、强推、懒人攻略、必游榜、小众打卡、狂喜等等。 - - 你需要为每个请求至少生成{poster_num}个海报设计。请使用JSON格式输出结果,结构如下: - - ```json + full_response = "" + timeout = 60 # 请求超时时间(秒) + + if not system_prompt: + # 使用默认系统提示词 + system_prompt = f""" + 你是一个专业的文案处理专家,擅长从文章中提取关键信息并生成吸引人的标题和简短描述。 + 现在,我需要你根据提供的文章内容,生成{poster_num}个海报的文案配置。 + + 每个配置包含: + 1. main_title:主标题,简短有力,突出景点特点 + 2. texts:两句简短文本,每句不超过15字,描述景点特色或游玩体验 + + 以JSON数组格式返回配置,示例: [ - { - "index": 1, - "main_title": "主标题内容", - "texts": ["第一个小标题", "第二个小标题"] - }, - { - "index": 2, - "main_title": "主标题内容", - "texts": ["第一个小标题", "第二个小标题"] - } - // ... 更多海报 + {{ + "main_title": "泰宁古城", + "texts": ["千年古韵","匠心独运"] + }}, + ... ] - ``` - - 确保生成的数量与用户要求的数量一致。只生成上述JSON格式内容,不要有其他任何额外内容。 + + 仅返回JSON数据,不需要任何额外解释。确保生成的标题和文本能够准确反映文章提到的景点特色。 """ - user_content = f""" - 海报数量:{poster_num}; - 景区介绍:{self.add_description}; - 推文内容:{tweet_content}; - """ - - # 最终响应内容 - full_response = "" + if self.add_description: + # 创建用户内容,包括info信息和tweet_content + user_content = f""" + 以下是需要你处理的信息: + + 关于景点的描述: + {self.add_description} + + 推文内容: + {tweet_content} + + 请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。 + 确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。 + """ + else: + # 仅使用tweet_content + user_content = f""" + 以下是需要你处理的推文内容: + {tweet_content} + + 请根据这些信息,生成{poster_num}个海报文案配置,以JSON数组格式返回。 + 确保主标题(main_title)简短有力,每个text不超过15字,并能准确反映景点特色。 + """ + + self.logger.info(f"正在生成{poster_num}个海报文案配置") # 创建临时客户端 temp_client = self._create_temp_client() - # 如果创建客户端失败,直接使用备用方案 - if temp_client is None: - print("创建OpenAI客户端失败,使用备用方案生成内容") - return 404 - else: - # 添加重试机制 + if temp_client: + # 重试逻辑 for retry in range(max_retries): try: - print(f"尝试连接API (尝试 {retry+1}/{max_retries})...") + self.logger.info(f"尝试生成内容 (尝试 {retry+1}/{max_retries})") - # 计算退避时间(指数退避策略):0, 2, 4, 8, 16...秒 - if retry > 0: - backoff_time = min(2 ** (retry - 1) * 2, 30) # 最大等待30秒 - print(f"等待 {backoff_time} 秒后重试...") - time.sleep(backoff_time) + # 定义流式响应处理回调函数 + def handle_stream_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None): + nonlocal full_response + + if chunk: + full_response += chunk + # 实时输出到控制台 + print(chunk, end="", flush=True) + + if is_last: + print("\n") # 输出完成后换行 + if is_timeout: + print("警告: 响应流超时") + if is_error: + print(f"错误: {error}") - # 设置超时时间随重试次数递增 - timeout = 30 + (retry * 30) # 30, 60, 90, ...秒 - - chat_response = temp_client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_content} - ], - stream=True, - temperature=self.temperature, - top_p=self.top_p, - presence_penalty=self.presence_penalty, - timeout=timeout # 设置请求超时时间 + # 使用AI_Agent的新回调方式 + from core.ai_agent import AI_Agent + ai_agent = AI_Agent( + self.api_base_url, + self.model_name, + self.api_key, + timeout=timeout, + max_retries=max_retries, + stream_chunk_timeout=30 # 流式块超时时间 ) - # 获取响应内容 - for chunk in chat_response: - if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: - content = chunk.choices[0].delta.content - full_response += content - print(content, end="", flush=True) - if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop": - break - - print("\n") # 输出完成后换行 - - # 成功获取响应,跳出重试循环 - break + # 使用回调方式处理流式响应 + try: + full_response = ai_agent.generate_text_stream_with_callback( + system_prompt, + user_content, + handle_stream_chunk, + temperature=self.temperature, + top_p=self.top_p, + presence_penalty=self.presence_penalty + ) + + # 如果成功生成内容,跳出重试循环 + ai_agent.close() + break + + except Exception as e: + error_msg = str(e) + self.logger.error(f"AI生成错误: {error_msg}") + ai_agent.close() + + # 继续重试逻辑 + if retry + 1 >= max_retries: + self.logger.warning("已达到最大重试次数,使用备用方案...") + # 生成备用内容 + full_response = self._generate_fallback_content(poster_num) + else: + self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会") except Exception as e: error_msg = str(e) - print(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}") + self.logger.error(f"API连接错误 (尝试 {retry+1}/{max_retries}): {error_msg}") # 如果已经达到最大重试次数 if retry + 1 >= max_retries: - print("已达到最大重试次数,使用备用方案...") + self.logger.warning("已达到最大重试次数,使用备用方案...") # 生成备用内容(简单模板) full_response = self._generate_fallback_content(poster_num) else: - print(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会") + self.logger.info(f"将在稍后重试,还剩 {max_retries - retry - 1} 次重试机会") # 关闭临时客户端 self._close_client(temp_client) - # 生成时间戳 return full_response + def _generate_fallback_content(self, poster_num): + """生成备用内容,当API调用失败时使用""" + self.logger.info("生成备用内容") + default_configs = [] + for i in range(poster_num): + default_configs.append({ + "main_title": f"景点风光 {i+1}", + "texts": ["自然美景", "人文体验"] + }) + return json.dumps(default_configs, ensure_ascii=False) + def save_result(self, full_response): """ 保存生成结果到文件 diff --git a/docs/streaming.md b/docs/streaming.md new file mode 100644 index 0000000..7f1b1f2 --- /dev/null +++ b/docs/streaming.md @@ -0,0 +1,209 @@ +# 流式输出处理功能 + +TravelContentCreator 现在支持三种不同的流式输出处理方法,让您能够更灵活地处理 AI 模型生成的文本内容。这些方法都在 `AI_Agent` 类中实现,可以根据不同的使用场景进行选择。 + +## 为什么需要流式处理? + +流式处理(Streaming)相比于传统的一次性返回完整响应的方式有以下优势: + +1. **实时性**:内容生成的同时即可开始处理,无需等待完整响应 +2. **用户体验更好**:可以实现"打字机效果",让用户看到文本逐步生成的过程 +3. **更早检测错误**:可以在响应生成过程中及早发现问题 +4. **长文本处理更高效**:特别适合生成较长的内容,避免长时间等待 + +## 流式处理方法 + +`AI_Agent` 类提供了三种不同模式的流式处理方法: + +### 1. 同步流式响应 (generate_text_stream) + +这种方法虽然使用了流式 API 连接,但会将所有的输出整合后一次性返回,适合简单的 API 调用。 + +```python +def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty): + """ + 生成文本内容(使用流式API但返回完整响应) + + Args: + system_prompt: 系统提示词 + user_prompt: 用户提示词 + temperature: 温度参数 + top_p: 核采样参数 + presence_penalty: 存在惩罚参数 + + Returns: + str: 完整的生成文本 + """ +``` + +使用示例: + +```python +agent = AI_Agent(base_url, model_name, api_key, timeout=30, stream_chunk_timeout=10) +result = agent.generate_text_stream(system_prompt, user_prompt, 0.7, 0.9, 0.0) +print(result) # 输出完整的生成结果 +``` + +### 2. 回调式流式响应 (generate_text_stream_with_callback) + +这种方法使用回调函数来处理流中的每个文本块,非常适合实时显示、分析或保存过程数据,更加灵活。 + +```python +def generate_text_stream_with_callback(self, system_prompt, user_prompt, + callback_fn, temperature=0.7, top_p=0.9, + presence_penalty=0.0): + """ + 生成文本流并通过回调函数处理每个块 + + Args: + system_prompt: 系统提示词 + user_prompt: 用户提示词 + callback_fn: 处理每个文本块的回调函数,接收(content, is_last, is_timeout, is_error, error)参数 + temperature: 温度参数 + top_p: 核采样参数 + presence_penalty: 存在惩罚参数 + + Returns: + str: 完整的响应文本 + """ +``` + +回调函数应符合以下格式: + +```python +def my_callback(content, is_last=False, is_timeout=False, is_error=False, error=None): + """ + 处理流式响应的回调函数 + + Args: + content: 文本块内容 + is_last: 是否为最后一个块 + is_timeout: 是否发生超时 + is_error: 是否发生错误 + error: 错误信息 + """ + if content: + print(content, end="", flush=True) # 实时打印 + + # 处理特殊情况 + if is_last: + print("\n完成生成") + if is_timeout: + print("警告: 响应流超时") + if is_error: + print(f"错误: {error}") +``` + +使用示例: + +```python +agent = AI_Agent(base_url, model_name, api_key, timeout=30, stream_chunk_timeout=10) +result = agent.generate_text_stream_with_callback( + system_prompt, + user_prompt, + my_callback, # 传入回调函数 + temperature=0.7, + top_p=0.9, + presence_penalty=0.0 +) +``` + +### 3. 异步流式响应 (async_generate_text_stream) + +这种方法基于 `asyncio`,返回一个异步生成器,非常适合与其他异步操作集成,例如在异步网络应用中使用。 + +```python +async def async_generate_text_stream(self, system_prompt, user_prompt, temperature=0.7, top_p=0.9, presence_penalty=0.0): + """ + 异步生成文本流 + + Args: + system_prompt: 系统提示词 + user_prompt: 用户提示词 + temperature: 温度参数 + top_p: 核采样参数 + presence_penalty: 存在惩罚参数 + + Yields: + str: 生成的文本块 + + Raises: + Exception: 如果API调用在所有重试后失败 + """ +``` + +使用示例: + +```python +async def demo_async_stream(): + agent = AI_Agent(base_url, model_name, api_key, timeout=30, stream_chunk_timeout=10) + + full_response = "" + try: + # 使用异步生成器 + async for chunk in agent.async_generate_text_stream( + system_prompt, + user_prompt, + temperature=0.7, + top_p=0.9, + presence_penalty=0.0 + ): + print(chunk, end="", flush=True) # 实时显示 + full_response += chunk + + # 这里可以同时执行其他异步操作 + # await some_other_async_operation() + + except Exception as e: + print(f"错误: {e}") + finally: + agent.close() + +# 在异步环境中运行 +asyncio.run(demo_async_stream()) +``` + +## 超时处理 + +所有流式处理方法都支持两种超时设置: + +1. **全局请求超时**:控制整个API请求的最大持续时间 +2. **流块超时**:控制接收连续两个数据块之间的最大等待时间 + +在创建 `AI_Agent` 实例时设置: + +```python +agent = AI_Agent( + base_url="your_api_base_url", + model_name="your_model_name", + api="your_api_key", + timeout=60, # 整体请求超时(秒) + max_retries=3, # 最大重试次数 + stream_chunk_timeout=10 # 流块超时(秒) +) +``` + +## 完整示例 + +我们提供了一个完整的示例脚本,演示所有三种流式处理方法的使用: + +``` +examples/test_stream.py +``` + +运行方式: + +```bash +cd TravelContentCreator +python examples/test_stream.py +``` + +这将依次演示三种流式处理方法,并展示它们的输出和性能差异。 + +## 在WebUI中的应用 + +TravelContentCreator的WebUI已经集成了基于回调的流式处理,实现了生成内容的实时显示,大大提升了用户体验,特别是在生成长篇内容时。 + +## 在自定义项目中使用 + +如果您想在自己的项目中使用这些流式处理功能,只需导入 `AI_Agent` 类并按照上述示例使用相应的方法即可。所有流式处理方法都内置了完善的错误处理和重试机制,提高了生产环境中的稳定性。 \ No newline at end of file diff --git a/examples/README.md b/examples/README.md index 0cfa948..d7a8354 100644 --- a/examples/README.md +++ b/examples/README.md @@ -156,6 +156,31 @@ python examples/test_stream.py - 演示实时输出文本块的处理方式 - 说明如何收集完整响应和估计token数 +## 流式处理演示 (test_stream.py) + +`test_stream.py` 是一个演示脚本,展示了AI_Agent类中新增的三种不同流式输出处理方法: + +1. **同步流式响应** (generate_text_stream) - 使用流式API但返回完整响应 +2. **回调式流式响应** (generate_text_stream_with_callback) - 通过回调函数处理每个文本块 +3. **异步流式响应** (async_generate_text_stream) - 使用异步生成器返回文本流 + +### 运行演示 + +```bash +cd TravelContentCreator +python examples/test_stream.py +``` + +演示脚本会依次执行这三种方法,并展示它们的输出和性能差异。 + +### 更多信息 + +关于流式处理功能的详细文档,请参阅: + +``` +docs/streaming.md +``` + ## 其他示例 `generate_poster.py` 是一个简化的海报生成示例,主要用于快速测试特定图片的海报效果。 diff --git a/examples/README_streaming.md b/examples/README_streaming.md new file mode 100644 index 0000000..9953513 --- /dev/null +++ b/examples/README_streaming.md @@ -0,0 +1,103 @@ +# 流式处理功能说明 + +本文档介绍了Travel Content Creator中新增的三种流式输出处理方式,以及如何使用它们来优化内容生成体验。 + +## 新增的流式处理方式 + +我们在`AI_Agent`类中实现了三种不同的流式处理方式: + +1. **同步流式响应** (`generate_text_stream`) + - 已修改为返回完整响应,不再是生成器 + - 内部仍使用流式请求以获得更好的超时控制 + - 适用于简单的API调用场景 + +2. **基于回调的流式响应** (`generate_text_stream_with_callback`) + - 通过回调函数处理每个文本块 + - 可在回调中实现实时显示、分析或保存 + - 最灵活的选项,适合需要自定义处理流程的场景 + +3. **异步流式响应** (`async_generate_text_stream`) + - 基于`asyncio`的异步生成器 + - 适用于需要与其他异步操作集成的场景 + - 可在保持响应性的同时处理长时间运行的请求 + +## 演示脚本 + +我们提供了一个示例脚本`test_stream.py`,演示了如何使用这三种流式处理方式: + +```bash +# 从项目根目录运行 +python examples/test_stream.py +``` + +## 回调函数示例 + +以下是使用回调函数处理流式响应的例子: + +```python +def handle_chunk(chunk, is_last=False, is_timeout=False, is_error=False, error=None): + # 处理文本块 + if chunk: + print(chunk, end="", flush=True) # 实时显示 + + # 处理结束状态 + if is_last: + if is_timeout: + print("\n流式响应超时") + if is_error: + print(f"\n发生错误: {error}") + +# 使用回调函数进行流式处理 +response = agent.generate_text_stream_with_callback( + system_prompt, + user_prompt, + handle_chunk, # 传入回调函数 + temperature=0.7 +) +``` + +## 异步流式处理示例 + +以下是使用异步方式处理流式响应的例子: + +```python +async def process_stream(): + async for chunk in agent.async_generate_text_stream( + system_prompt, + user_prompt, + temperature=0.7 + ): + print(chunk, end="", flush=True) # 实时显示 + + # 可以同时执行其他异步操作 + await other_async_task() + +# 在异步环境中运行 +asyncio.run(process_stream()) +``` + +## 超时处理 + +所有流式处理方法都支持超时控制: + +1. **全局请求超时**:控制整个请求的最长等待时间 +2. **流式块超时**:控制两个连续文本块之间的最长等待时间 + +这些参数可以在创建`AI_Agent`实例时设置: + +```python +agent = AI_Agent( + api_url="http://localhost:8000/v1/", + model="qwen", + api_key="EMPTY", + timeout=30, # 全局请求超时(秒) + stream_chunk_timeout=10 # 流式块超时(秒) +) +``` + +## 注意事项 + +1. 新的流式处理方法内置了重试机制和错误处理 +2. 回调方式提供了最佳的灵活性和控制 +3. 异步方式最适合需要保持UI响应性的应用 +4. 所有方法都会在完成时自动关闭流式请求 \ No newline at end of file diff --git a/examples/test_stream.py b/examples/test_stream.py index 6beba94..8489f62 100644 --- a/examples/test_stream.py +++ b/examples/test_stream.py @@ -1,129 +1,175 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +""" +测试AI_Agent的流式处理方法 + +此脚本演示TravelContentCreator中AI_Agent类的三种流式输出处理方法: +1. 同步流式响应 (generate_text_stream) +2. 回调式流式响应 (generate_text_stream_with_callback) +3. 异步流式响应 (async_generate_text_stream) +""" + import os import sys -import json +import asyncio import time -import logging +from pathlib import Path -# Determine the project root directory (assuming examples/ is one level down) -PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -if PROJECT_ROOT not in sys.path: - sys.path.append(PROJECT_ROOT) +# 添加项目根目录到Python路径 +project_root = str(Path(__file__).parent.parent) +if project_root not in sys.path: + sys.path.insert(0, project_root) -# Now import from core -try: - from core.ai_agent import AI_Agent -except ImportError as e: - logging.critical(f"Failed to import AI_Agent. Ensure '{PROJECT_ROOT}' is in sys.path and core/ai_agent.py exists. Error: {e}") - sys.exit(1) +from core.ai_agent import AI_Agent -def load_config(config_path): - """Loads configuration from a JSON file.""" - try: - with open(config_path, 'r', encoding='utf-8') as f: - config = json.load(f) - logging.info(f"Config loaded successfully from {config_path}") - return config - except FileNotFoundError: - logging.error(f"Error: Configuration file not found at {config_path}") - return None - except json.JSONDecodeError: - logging.error(f"Error: Could not decode JSON from {config_path}") - return None - except Exception as e: - logging.exception(f"An unexpected error occurred loading config {config_path}:") - return None +# 示例提示词 +SYSTEM_PROMPT = """你是一个专业的旅游内容创作助手,请根据用户的提示生成相关内容。""" +USER_PROMPT = """请为我生成一篇关于福建泰宁古城的旅游攻略,包括著名景点、美食推荐和最佳游玩季节。字数控制在300字以内。""" -def main(): - # --- Basic Logging Setup --- - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' +def print_separator(title): + """打印分隔线和标题""" + print("\n" + "="*50) + print(f" {title} ".center(50, "=")) + print("="*50 + "\n") + +def demo_sync_stream(): + """演示同步流式响应方法""" + print_separator("同步流式响应 (generate_text_stream)") + + # 创建AI_Agent实例 + agent = AI_Agent( + base_url="vllm", # 使用本地vLLM服务 + model_name="qwen2-7b-instruct", # 或其他您配置的模型名称 + api="EMPTY", # vLLM不需要API key + timeout=60, # 整体请求超时时间(秒) + stream_chunk_timeout=10 # 流式块超时时间(秒) ) - # --- End Logging Setup --- + + print("开始生成内容...") + start_time = time.time() + + # 使用同步流式方法 + result = agent.generate_text_stream( + SYSTEM_PROMPT, + USER_PROMPT, + temperature=0.7, + top_p=0.9, + presence_penalty=0.0 + ) + + end_time = time.time() + + print(f"\n\n完整生成内容:\n{result}") + print(f"\n生成完成! 耗时: {end_time - start_time:.2f}秒") + + # 关闭agent + agent.close() - logging.info("Starting AI Agent Stream Test...") - - # Load configuration (adjust path relative to this script) - config_path = os.path.join(PROJECT_ROOT, "poster_gen_config.json") - config = load_config(config_path) - if config is None: - logging.critical("Failed to load configuration. Exiting test.") - sys.exit(1) - - # Example Prompts - system_prompt = "你是一个乐于助人的AI助手,擅长写短篇故事。" - user_prompt = "请写一个关于旅行机器人的短篇故事,它在一个充满异国情调的星球上发现了新的生命形式。" - - ai_agent = None - try: - # --- Extract AI Agent parameters from config --- - ai_api_url = config.get("api_url") - ai_model = config.get("model") - ai_api_key = config.get("api_key") - request_timeout = config.get("request_timeout", 30) - max_retries = config.get("max_retries", 3) - stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Get stream chunk timeout +def demo_callback_stream(): + """演示回调式流式响应方法""" + print_separator("回调式流式响应 (generate_text_stream_with_callback)") + + # 创建AI_Agent实例 + agent = AI_Agent( + base_url="vllm", + model_name="qwen2-7b-instruct", + api="EMPTY", + timeout=60, + stream_chunk_timeout=10 + ) + + # 定义回调函数 + def my_callback(content, is_last=False, is_timeout=False, is_error=False, error=None): + """处理流式响应的回调函数""" + if content: + # 实时打印内容,不换行 + print(content, end="", flush=True) - # Check for required AI params - if not all([ai_api_url, ai_model, ai_api_key]): - logging.critical("Missing required AI configuration (api_url, model, api_key) in config. Exiting test.") - sys.exit(1) - # --- End Extract AI Agent params --- + if is_last: + print("\n") + if is_timeout: + print("警告: 响应流超时") + if is_error: + print(f"错误: {error}") + + print("开始生成内容...") + start_time = time.time() + + # 使用回调式流式方法 + result = agent.generate_text_stream_with_callback( + SYSTEM_PROMPT, + USER_PROMPT, + my_callback, + temperature=0.7, + top_p=0.9, + presence_penalty=0.0 + ) + + end_time = time.time() + print(f"\n生成完成! 耗时: {end_time - start_time:.2f}秒") + + # 关闭agent + agent.close() - logging.info("Initializing AI Agent for stream test...") - # Initialize AI_Agent using extracted parameters - ai_agent = AI_Agent( - api_url=ai_api_url, # Use extracted var - model=ai_model, # Use extracted var - api_key=ai_api_key, # Use extracted var - timeout=request_timeout, - max_retries=max_retries, - stream_chunk_timeout=stream_chunk_timeout # Pass it here +async def demo_async_stream(): + """演示异步流式响应方法""" + print_separator("异步流式响应 (async_generate_text_stream)") + + # 创建AI_Agent实例 + agent = AI_Agent( + base_url="vllm", + model_name="qwen2-7b-instruct", + api="EMPTY", + timeout=60, + stream_chunk_timeout=10 + ) + + print("开始生成内容...") + start_time = time.time() + full_response = "" + + # 使用异步流式方法 + try: + async_stream = agent.async_generate_text_stream( + SYSTEM_PROMPT, + USER_PROMPT, + temperature=0.7, + top_p=0.9, + presence_penalty=0.0 ) - - # Example call to work_stream - logging.info("Calling ai_agent.work_stream...") - # Extract generation parameters from config - temperature = config.get("content_temperature", 0.7) # Use a relevant temperature setting - top_p = config.get("content_top_p", 0.9) - presence_penalty = config.get("content_presence_penalty", 0.0) - - start_time = time.time() - stream_generator = ai_agent.work_stream( - system_prompt=system_prompt, - user_prompt=user_prompt, - info_directory=None, # No extra context folder for this test - temperature=temperature, - top_p=top_p, - presence_penalty=presence_penalty - ) - - # Process the stream - logging.info("Processing stream response:") - full_response = "" - for chunk in stream_generator: - print(chunk, end="", flush=True) # Keep print for stream output - full_response += chunk - - end_time = time.time() - logging.info(f"\n--- Stream Finished ---") - logging.info(f"Total time: {end_time - start_time:.2f} seconds") - logging.info(f"Total characters received: {len(full_response)}") - - except KeyError as e: - logging.error(f"Configuration error: Missing key '{e}'. Please check '{config_path}'.") + + # 异步迭代流 + async for content in async_stream: + # 累积完整响应 + full_response += content + # 实时打印内容 + print(content, end="", flush=True) + except Exception as e: - logging.exception("An error occurred during the stream test:") - finally: - # Ensure the agent is closed - if ai_agent: - logging.info("Closing AI Agent...") - ai_agent.close() - logging.info("AI Agent closed.") + print(f"\n生成过程中出错: {e}") + + end_time = time.time() + print(f"\n\n生成完成! 耗时: {end_time - start_time:.2f}秒") + + # 关闭agent + agent.close() + +async def main(): + """主函数""" + print("Testing AI_Agent streaming methods...") + + # 1. 测试同步流式响应 + demo_sync_stream() + + # 2. 测试回调式流式响应 + demo_callback_stream() + + # 3. 测试异步流式响应 + await demo_async_stream() + + print("\n所有测试完成!") if __name__ == "__main__": - main() \ No newline at end of file + # 运行异步主函数 + asyncio.run(main()) \ No newline at end of file