diff --git a/README.md b/README.md
index a040faf..15cafb3 100644
--- a/README.md
+++ b/README.md
@@ -129,12 +129,13 @@ pip install numpy pandas opencv-python pillow openai
 - `content_temperature`, `content_top_p`, `content_presence_penalty`: content-generation API parameters (defaults: 0.3, 0.4, 1.5)
 - `request_timeout`: timeout in seconds for AI API requests (default 30)
 - `max_retries`: maximum number of retries on request timeouts or retryable network errors (default 3)
-- `camera_image_subdir`: subdirectory holding the original photos (relative to `image_base_dir`, default "相机") - **Note: this key is no longer used to locate description files.**
-- `modify_image_subdir`: subdirectory holding processed/collage-ready images (relative to `image_base_dir`, default "modify")
-- `output_collage_subdir`: subdirectory for collages inside each variant's output directory (default "collage_img")
-- `output_poster_subdir`: subdirectory for final posters inside each variant's output directory (default "poster")
-- `output_poster_filename`: filename of the final poster (default "poster.jpg")
-- `poster_target_size`: target poster size `[width, height]` (default `[900, 1200]`)
+- `stream_chunk_timeout`: maximum wait in seconds between two chunks of a streaming response, to keep a stalled stream from hanging indefinitely (default 60)
+- `camera_image_subdir`: subdirectory holding the original photos (relative to `image_base_dir`, default "相机")
+- `modify_image_subdir`: subdirectory holding processed/collage-ready images (relative to `image_base_dir`, default "modify")
+- `output_collage_subdir`: subdirectory for collages inside each variant's output directory (default "collage_img")
+- `output_poster_subdir`: subdirectory for final posters inside each variant's output directory (default "poster")
+- `output_poster_filename`: filename of the final poster (default "poster.jpg")
+- `poster_target_size`: target poster size `[width, height]` (default `[900, 1200]`)
 - `text_possibility`: probability that the second block of additional text appears on the poster (default 0.3)
 
 An example configuration file `example_config.json` is provided; be sure to copy and modify it:
@@ -178,17 +179,20 @@
 except Exception as e:
     print(f"Error loading config: {e}")
     sys.exit(1)
 
-# 2. Initialize the AI Agent (reads timeout/retry settings)
+# 2. Initialize the AI Agent (reads timeout/retry/stream-chunk-timeout settings)
 ai_agent = None
 try:
     request_timeout = config.get("request_timeout", 30)
     max_retries = config.get("max_retries", 3)
+    stream_chunk_timeout = config.get("stream_chunk_timeout", 60)  # New: read the stream chunk timeout
+
     ai_agent = AI_Agent(
         config["api_url"],
         config["model"],
         config["api_key"],
         timeout=request_timeout,
-        max_retries=max_retries
+        max_retries=max_retries,
+        stream_chunk_timeout=stream_chunk_timeout  # New: pass the stream chunk timeout
     )
 
 # 3. Define prompts and parameters
@@ -272,4 +276,68 @@ This refactoring makes it straightforward to add new output handlers in the future.
 ### Configuration
 
-The main configuration file is `poster_gen_config.json` (you can copy `example_config.json` and modify it). It contains the following sections:
\ No newline at end of file
+The main configuration file is `poster_gen_config.json` (you can copy `example_config.json` and modify it). It contains the following sections:
+
+#### 1. Basic Settings
+
+* `api_url` (required): LLM API address, or a preset name such as 'vllm', 'ali', 'kimi', 'doubao', 'deepseek'
+* `api_key` (required): API key
+* `model` (required): name of the model to use
+* `topic_system_prompt` (required): path to the topic-generation system prompt file (should be the version that requires JSON output)
+* `topic_user_prompt` (required): path to the base topic-generation user prompt file
+* `content_system_prompt` (required): path to the content-generation system prompt file
+* `resource_dir` (required): a list of **resource file entries** (see the sketch after this list). Each element is a dict containing:
+    * `type`: resource type; currently `"Object"` (attraction/object info), `"Description"` (the matching description file), and `"Product"` (related product info) are supported.
+    * `file_path`: a list of **full paths** to all resource files of that type.
+* `prompts_dir` (required): directory holding the Demand/Style/Refer prompt fragments
+* `output_dir` (required): directory where results are saved
+* `image_base_dir` (required): **absolute or relative path to the image resource root** (used to locate source images)
+* `poster_assets_base_dir` (required): **absolute or relative path to the poster assets root** (used to locate fonts, borders, stickers, text backgrounds, etc.)
+* `num` (required): number of topics to generate (topic stage)
+* `variants` (required): number of variants to generate per topic (content stage)
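+A minimal sketch of the expected `resource_dir` structure, written as the equivalent Python value (the file names and paths here are hypothetical):
+
+```python
+resource_dir = [
+    {"type": "Object",      "file_path": ["/data/resources/objects/west_lake.txt"]},
+    {"type": "Description", "file_path": ["/data/resources/descriptions/west_lake.txt"]},
+    {"type": "Product",     "file_path": ["/data/resources/products/west_lake_tickets.txt"]},
+]
+```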
+
+#### 2. General Options
+
+* `date` (optional, default empty): date marker (used in the topic-generation prompt)
+* `request_timeout` (optional, default 30): timeout in seconds for a single HTTP request
+* `max_retries` (optional, default 3): maximum number of retries when an API request fails
+* `stream_chunk_timeout` (optional, default 60): maximum wait in seconds between two chunks of a streaming response, to keep a stalled stream from hanging indefinitely
+* `description_filename` (optional, default "description.txt"): filename of the description file (see `example_config.json`; description files are no longer located via `camera_image_subdir`)
+
+#### 3. Topic & Content Generation
+
+* `topic_temperature` (optional, default 0.2): topic-generation API temperature
+* `topic_top_p` (optional, default 0.5): topic-generation API top-p
+* `topic_presence_penalty` (optional, default 1.5): topic-generation API presence penalty
+* `content_temperature` (optional, default 0.3): content-generation API temperature
+* `content_top_p` (optional, default 0.4): content-generation API top-p
+* `content_presence_penalty` (optional, default 1.5): content-generation API presence penalty
+
+#### 4. Image Processing
+
+* `camera_image_subdir` (optional, default "相机"): subdirectory holding the original photos (relative to `image_base_dir`)
+* `modify_image_subdir` (optional, default "modify"): subdirectory holding processed/collage-ready images (relative to `image_base_dir`)
+* `output_collage_subdir` (optional, default "collage_img"): subdirectory for collages inside each variant's output directory
+
+#### 5. Poster Generation
+
+* `output_poster_subdir` (optional, default "poster"): subdirectory for final posters inside each variant's output directory
+* `output_poster_filename` (optional, default "poster.jpg"): filename of the final poster
+* `poster_target_size` (optional, default [900, 1200]): target poster size `[width, height]`
+* `text_possibility` (optional, default 0.3): probability that the second block of additional text appears on the poster
+
+An example configuration file `example_config.json` is provided; be sure to copy it and adjust the values before running:
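+
+For example, a minimal sketch of the copy step in Python (a plain shell `cp` works just as well):
+
+```python
+import shutil
+
+# Create the working config from the template, then edit its values by hand.
+shutil.copyfile("example_config.json", "poster_gen_config.json")
+```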
diff --git a/core/__pycache__/ai_agent.cpython-312.pyc b/core/__pycache__/ai_agent.cpython-312.pyc
index 54aa19a..6900778 100644
Binary files a/core/__pycache__/ai_agent.cpython-312.pyc and b/core/__pycache__/ai_agent.cpython-312.pyc differ
diff --git a/core/ai_agent.py b/core/ai_agent.py
index 6f1b897..e3dd8c3 100644
--- a/core/ai_agent.py
+++ b/core/ai_agent.py
@@ -4,15 +4,34 @@
 import time
 import random
 import traceback
 import logging
+import tiktoken
 
 # Configure basic logging for this module (or rely on root logger config)
 # logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 # logger = logging.getLogger(__name__) # Alternative: use named logger
 
+# Constants (module-level defaults; per-instance values come from __init__)
+MAX_RETRIES = 3             # Maximum number of retries for API calls
+INITIAL_BACKOFF = 1         # Initial backoff time in seconds
+MAX_BACKOFF = 16            # Maximum backoff time in seconds
+STREAM_CHUNK_TIMEOUT = 10   # Timeout in seconds for receiving a chunk in a stream
+
 class AI_Agent():
     """AI agent class: talks to the AI model to generate text content"""
-    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3):
+    def __init__(self, base_url, model_name, api, timeout=30, max_retries=MAX_RETRIES, stream_chunk_timeout=STREAM_CHUNK_TIMEOUT):
+        """
+        Initialize the AI agent.
+
+        Args:
+            base_url (str): Base URL of the model service, or a preset name
+                ('ali', 'kimi', 'doubao', 'deepseek', 'vllm').
+            model_name (str): Name of the model to use.
+            api (str): API key.
+            timeout (int, optional): Timeout in seconds for a single API request. Defaults to 30.
+            max_retries (int, optional): Maximum number of retries when an API request fails. Defaults to 3.
+            stream_chunk_timeout (int, optional): Maximum wait in seconds between two chunks of a streaming response. Defaults to 10.
+        """
+        logging.info("Initializing AI Agent")
         self.url_list = {
             "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
             "kimi": "https://api.moonshot.cn/v1",
@@ -26,8 +45,9 @@ class AI_Agent():
         self.model_name = model_name
         self.timeout = timeout
         self.max_retries = max_retries
+        self.stream_chunk_timeout = stream_chunk_timeout
 
-        print(f"Initializing AI Agent with base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}")
+        logging.info(f"AI Agent Settings: base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}, stream_chunk_timeout={self.stream_chunk_timeout}s")
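 
+        # Note: `timeout` bounds each HTTP request issued by the client below, while
+        # `stream_chunk_timeout` is enforced manually in generate_text_stream() as the
+        # maximum allowed gap between two streamed chunks.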
+ """ + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + logging.info(f"Generating text stream with model: {self.model_name}") + + retries = 0 + backoff_time = INITIAL_BACKOFF + last_exception = None + + while retries < self.max_retries: try: - logging.info(f"Attempting API stream call (try {retry_count + 1}/{self.max_retries + 1})") - response = self.client.chat.completions.create( + logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream.") + stream = self.client.chat.completions.create( model=self.model_name, - messages=[{"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}], + messages=messages, temperature=temperature, top_p=top_p, - presence_penalty=presence_penalty, stream=True, - max_tokens=8192, - timeout=self.timeout, - extra_body={"repetition_penalty": 1.05}, + timeout=self.timeout # Overall request timeout ) - try: - logging.info("Stream connected, receiving content...") - yielded_something = False - for chunk in response: - if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: + chunk_iterator = iter(stream) + last_chunk_time = time.time() + + while True: + try: + # Check for timeout since last received chunk + if time.time() - last_chunk_time > self.stream_chunk_timeout: + raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.") + + chunk = next(chunk_iterator) + last_chunk_time = time.time() # Reset timer on successful chunk receipt + + if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: content = chunk.choices[0].delta.content + # logging.debug(f"Received chunk: {content}") # Potentially very verbose yield content - yielded_something = True + elif chunk.choices and chunk.choices[0].finish_reason == 'stop': + logging.info("Stream finished.") + return # Successful completion + # Handle other finish reasons if needed, e.g., 'length' - if yielded_something: - logging.info("Stream finished successfully.") - else: - logging.warning("Stream finished, but no content was yielded.") - return + except StopIteration: + logging.info("Stream iterator exhausted.") + return # End of stream normally + except Timeout as e: + logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).") + last_exception = e + break # Break inner loop to retry the stream creation + except (APITimeoutError, APIConnectionError, RateLimitError) as e: + logging.warning(f"API error during streaming: {type(e).__name__} - {e}. 
+                        if time.time() - last_chunk_time > self.stream_chunk_timeout:
+                            raise TimeoutError(f"No chunk received for {self.stream_chunk_timeout} seconds.")
+
+                        chunk = next(chunk_iterator)
+                        last_chunk_time = time.time()  # Reset the timer on each received chunk
+
+                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                             content = chunk.choices[0].delta.content
+                            # logging.debug(f"Received chunk: {content}")  # Potentially very verbose
                             yield content
-                            yielded_something = True
+                        elif chunk.choices and chunk.choices[0].finish_reason == 'stop':
+                            logging.info("Stream finished.")
+                            return  # Successful completion
+                        # Handle other finish reasons if needed, e.g., 'length'
 
-                    if yielded_something:
-                        logging.info("Stream finished successfully.")
-                    else:
-                        logging.warning("Stream finished, but no content was yielded.")
-                    return
+                    except StopIteration:
+                        logging.info("Stream iterator exhausted.")
+                        return  # End of stream reached normally
+                    except TimeoutError as e:
+                        logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
+                        last_exception = e
+                        break  # Break the inner loop to retry stream creation
+                    except (APITimeoutError, APIConnectionError, RateLimitError) as e:
+                        logging.warning(f"API error during streaming: {type(e).__name__} - {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
+                        last_exception = e
+                        break  # Break the inner loop to retry stream creation
+                    except Exception as e:
+                        logging.error(f"Unexpected error during streaming: {traceback.format_exc()}")
+                        # Treat unexpected mid-stream errors as retryable; they surface
+                        # via `last_exception` once the retries are exhausted.
+                        last_exception = e
+                        break  # Break the inner loop to retry
 
-                except APIConnectionError as stream_err:
-                    logging.warning(f"Stream connection error occurred: {stream_err}")
-                    retry_count += 1
-                    if retry_count <= self.max_retries:
-                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
-                        logging.warning(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
-                        time.sleep(wait_time)
-                        continue
-                    else:
-                        logging.error("Max retries reached after stream connection error.")
-                        yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"
-                        return
-
-                except Exception as stream_err:
-                    logging.exception("Error occurred during stream processing:")
-                    yield f"[STREAM_ERROR: {stream_err}]"
-                    return
-
-            except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
-                logging.warning(f"API Error occurred: {e}")
-                should_retry = False
-                if isinstance(e, (APITimeoutError, APIConnectionError, RateLimitError)):
-                    should_retry = True
-                elif isinstance(e, APIStatusError) and e.status_code >= 500:
-                    should_retry = True
-
-                if should_retry:
-                    retry_count += 1
-                    if retry_count <= self.max_retries:
-                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
-                        logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
-                        time.sleep(wait_time)
-                        continue
-                    else:
-                        logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting stream.")
-                        yield "[API_ERROR: Max retries reached]"
-                        return
+                # We broke out of the inner loop because of a retryable error.
+                retries += 1
+                if retries < self.max_retries:
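+                    # Exponential backoff with jitter. With INITIAL_BACKOFF=1 and
+                    # MAX_BACKOFF=16, the waits are ~1, 2, 4, 8, 16, 16, ... seconds,
+                    # plus up to one extra second of random jitter per attempt.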
Aborting stream.") - yield f"[API_ERROR: Non-retriable status {e.status_code if isinstance(e, APIStatusError) else 'Unknown'}]" - return - except Exception as e: - logging.exception("Non-retriable error occurred during API call setup:") - yield f"[FATAL_ERROR: {e}]" - return + logging.error(f"Stream generation failed after {self.max_retries} retries.") + raise last_exception or Exception("Stream generation failed after max retries.") - logging.error("Stream generation failed after exhausting all retries.") - yield "[ERROR: Failed after all retries]" + + except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e: + retries += 1 + last_exception = e + logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}") + if retries < self.max_retries: + logging.info(f"Retrying in {backoff_time} seconds...") + time.sleep(backoff_time + random.uniform(0, 1)) # Add jitter + backoff_time = min(backoff_time * 2, MAX_BACKOFF) + else: + logging.error(f"API call failed after {self.max_retries} retries.") + raise last_exception + except Exception as e: + # Catch unexpected errors during stream setup + logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}") + raise e # Re-raise unexpected errors immediately + + # Should not be reached if logic is correct, but as a safeguard: + logging.error("Exited stream generation loop unexpectedly.") + raise last_exception or Exception("Stream generation failed.") def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty): - """工作流程的流式版本:返回文本生成器""" + """完整的工作流程(流式):读取文件夹(如果提供),然后生成文本流""" logging.info(f"Starting 'work_stream' process. File folder: {file_folder}") if file_folder: logging.info(f"Reading context from folder: {file_folder}") diff --git a/example_config.json b/example_config.json index 98a3b23..98f4e65 100644 --- a/example_config.json +++ b/example_config.json @@ -48,9 +48,11 @@ "content_presence_penalty": 1.5, "request_timeout": 30, "max_retries": 3, + "stream_chunk_timeout": 60, "output_collage_subdir": "collage_img", "output_poster_subdir": "poster", "output_poster_filename": "poster.jpg", "poster_target_size": [900, 1200], - "text_possibility": 0.3 + "text_possibility": 0.3, + "description_filename": "description.txt" } \ No newline at end of file diff --git a/examples/test_pipeline_steps.py b/examples/test_pipeline_steps.py index 0256a8f..5347c90 100644 --- a/examples/test_pipeline_steps.py +++ b/examples/test_pipeline_steps.py @@ -117,6 +117,7 @@ def main_test(): ai_api_key = config.get("api_key") request_timeout = config.get("request_timeout", 30) max_retries = config.get("max_retries", 3) + stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Read stream timeout if not all([ai_api_url, ai_model, ai_api_key]): raise ValueError("Missing required AI configuration (api_url, model, api_key)") logging.info("Initializing AI Agent for content generation test...") @@ -125,7 +126,8 @@ def main_test(): model_name=ai_model, api=ai_api_key, timeout=request_timeout, - max_retries=max_retries + max_retries=max_retries, + stream_chunk_timeout=stream_chunk_timeout # Pass stream timeout ) total_topics = len(topics_list) diff --git a/examples/test_stream.py b/examples/test_stream.py index 1f231ca..6beba94 100644 --- a/examples/test_stream.py +++ b/examples/test_stream.py @@ -66,6 +66,7 @@ def main(): ai_api_key = config.get("api_key") request_timeout = config.get("request_timeout", 30) max_retries = config.get("max_retries", 3) + stream_chunk_timeout = 
config.get("stream_chunk_timeout", 60) # Get stream chunk timeout # Check for required AI params if not all([ai_api_url, ai_model, ai_api_key]): @@ -80,7 +81,8 @@ def main(): model=ai_model, # Use extracted var api_key=ai_api_key, # Use extracted var timeout=request_timeout, - max_retries=max_retries + max_retries=max_retries, + stream_chunk_timeout=stream_chunk_timeout # Pass it here ) # Example call to work_stream diff --git a/main.py b/main.py index ed7fbbc..819748f 100644 --- a/main.py +++ b/main.py @@ -91,14 +91,16 @@ def generate_content_and_posters_step(config, run_id, topics_list, output_handle ai_agent = None try: # --- Initialize AI Agent for Content Generation --- - request_timeout = config.get("request_timeout", 30) # Get timeout from config - max_retries = config.get("max_retries", 3) # Get max_retries from config + request_timeout = config.get("request_timeout", 30) # Default 30 seconds + max_retries = config.get("max_retries", 3) # Default 3 retries + stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Default 60 seconds for stream chunk ai_agent = AI_Agent( config["api_url"], config["model"], config["api_key"], timeout=request_timeout, - max_retries=max_retries + max_retries=max_retries, + stream_chunk_timeout=stream_chunk_timeout ) logging.info("AI Agent for content generation initialized.")