Timeout mechanism added, but ineffective
This commit is contained in:
parent 4a4b37cba7
commit 0bdc8d9ae9

README.md (84 lines changed)
@@ -129,12 +129,13 @@ pip install numpy pandas opencv-python pillow openai
 - `content_temperature`, `content_top_p`, `content_presence_penalty`: parameters for the content-generation API (defaults 0.3, 0.4, 1.5)
 - `request_timeout`: timeout for AI API requests, in seconds (default 30)
 - `max_retries`: maximum number of retries on request timeout or retryable network errors (default 3)
-- `camera_image_subdir`: name of the subdirectory holding original photos (relative to `image_base_dir`, default "相机") - **Note: this key is no longer used to locate description files.**
+- `stream_chunk_timeout`: maximum wait, in seconds, between two chunks of a streaming response; prevents a stream from hanging indefinitely
+- `camera_image_subdir`: name of the subdirectory holding original photos (relative to `image_base_dir`, default "相机")
 - `modify_image_subdir`: name of the subdirectory holding processed/collage-ready images (relative to `image_base_dir`, default "modify")
 - `output_collage_subdir`: name of the subdirectory, inside each variant output directory, that holds collage images (default "collage_img")
 - `output_poster_subdir`: name of the subdirectory, inside each variant output directory, that holds the final poster (default "poster")
 - `output_poster_filename`: filename of the final output poster (default "poster.jpg")
 - `poster_target_size`: target poster size as `[width, height]` (default `[900, 1200]`)
 - `text_possibility`: probability that a second block of additional text appears on the poster (default 0.3)

 An example configuration file, `example_config.json`, is provided; be sure to copy and modify it:
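The optional keys above are read with `dict.get` fallbacks, so any of them may be omitted from the config file. A minimal sketch (the inline `config` dict is illustrative; in the project it is parsed from `poster_gen_config.json`):

```python
import json

# Illustrative config; in the project this is parsed from poster_gen_config.json
config = json.loads('{"request_timeout": 45}')

request_timeout = config.get("request_timeout", 30)            # explicit value wins
max_retries = config.get("max_retries", 3)                     # missing key falls back to the default
stream_chunk_timeout = config.get("stream_chunk_timeout", 60)  # same for the stream chunk timeout

print(request_timeout, max_retries, stream_chunk_timeout)  # 45 3 60
```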
@@ -178,17 +179,20 @@ except Exception as e:
     print(f"Error loading config: {e}")
     sys.exit(1)

-# 2. Initialize the AI agent (reads timeout/retry settings)
+# 2. Initialize the AI agent (reads timeout/retry/stream-chunk-timeout settings)
 ai_agent = None
 try:
     request_timeout = config.get("request_timeout", 30)
     max_retries = config.get("max_retries", 3)
+    stream_chunk_timeout = config.get("stream_chunk_timeout", 60)  # new: read the stream chunk timeout

     ai_agent = AI_Agent(
         config["api_url"],
         config["model"],
         config["api_key"],
         timeout=request_timeout,
-        max_retries=max_retries
+        max_retries=max_retries,
+        stream_chunk_timeout=stream_chunk_timeout  # new: pass the stream chunk timeout
     )

 # 3. Define prompts and parameters
@@ -273,3 +277,67 @@ This refactoring makes it straightforward to add new output handlers in the future
 ### Configuration File Reference

 The main configuration file is `poster_gen_config.json` (copy `example_config.json` and modify it). It contains the following sections:

+#### 1. Basic Settings
+
+* `api_url` (required): LLM API address, or a preset name such as 'vllm', 'ali', 'kimi', 'doubao', 'deepseek'
+* `api_key` (required): API key
+* `model` (required): name of the model to use
+* `topic_system_prompt` (required): path to the system prompt file for topic generation (should be the version that requires JSON output)
+* `topic_user_prompt` (required): path to the base user prompt file for topic generation
+* `content_system_prompt` (required): path to the system prompt file for content generation
+* `resource_dir` (required): a list of **resource file entries**. Each element is a dict containing:
+    * `type`: resource type; currently `"Object"` (attraction/object info), `"Description"` (the matching description file), and `"Product"` (related product info) are supported.
+    * `file_path`: a list of the **full paths** of all resource files of that type.
+* `prompts_dir` (required): path to the directory holding the Demand/Style/Refer prompt fragments
+* `output_dir` (required): path to the directory where results are saved
+* `image_base_dir` (required): **absolute or relative path to the image resource root** (used to locate source images)
+* `poster_assets_base_dir` (required): **absolute or relative path to the poster asset root** (used to locate fonts, borders, stickers, text backgrounds, etc.)
+* `num` (required): number of topics to generate (topic stage)
+* `variants` (required): number of variants generated per topic (content stage)
+
+#### 2. Optional Settings
+
+* `date` (optional, default empty): date marker (used in the topic-generation prompt)
+* `topic_temperature` (optional, default 0.2): temperature for the topic-generation API
+* `topic_top_p` (optional, default 0.5): top-p for the topic-generation API
+* `topic_presence_penalty` (optional, default 1.5): presence penalty for the topic-generation API
+* `content_temperature` (optional, default 0.3): temperature for the content-generation API
+* `content_top_p` (optional, default 0.4): top-p for the content-generation API
+* `content_presence_penalty` (optional, default 1.5): presence penalty for the content-generation API
+* `request_timeout` (optional, default 30): timeout in seconds for a single HTTP request
+* `max_retries` (optional, default 3): maximum number of retries when an API request fails
+* `stream_chunk_timeout` (optional, default 60): maximum wait in seconds between two chunks of a streaming response; prevents a stream from hanging indefinitely
+* `camera_image_subdir` (optional, default "相机"): subdirectory for original photos (relative to `image_base_dir`)
+* `modify_image_subdir` (optional, default "modify"): subdirectory for processed/collage-ready images (relative to `image_base_dir`)
+* `output_collage_subdir` (optional, default "collage_img"): subdirectory, inside each variant output directory, for collage images
+
+#### 3. Topic & Content Generation
+
+* `topic_temperature` (optional, default 0.2): temperature for the topic-generation API
+* `topic_top_p` (optional, default 0.5): top-p for the topic-generation API
+* `topic_presence_penalty` (optional, default 1.5): presence penalty for the topic-generation API
+* `content_temperature` (optional, default 0.3): temperature for the content-generation API
+* `content_top_p` (optional, default 0.4): top-p for the content-generation API
+* `content_presence_penalty` (optional, default 1.5): presence penalty for the content-generation API
+
+#### 4. Image Processing
+
+* `camera_image_subdir` (optional, default "相机"): subdirectory for original photos (relative to `image_base_dir`)
+* `modify_image_subdir` (optional, default "modify"): subdirectory for processed/collage-ready images (relative to `image_base_dir`)
+* `output_collage_subdir` (optional, default "collage_img"): subdirectory, inside each variant output directory, for collage images
+
+#### 5. Poster Generation
+
+* `output_poster_subdir` (optional, default "poster"): subdirectory, inside each variant output directory, for the final poster
+* `output_poster_filename` (optional, default "poster.jpg"): filename of the final output poster
+* `poster_target_size` (optional, default [900, 1200]): target poster size `[width, height]`
+* `text_possibility` (optional, default 0.3): probability that a second block of additional text appears on the poster
+
+#### 6. Miscellaneous
+
+* `request_timeout` (optional, default 30): timeout in seconds for a single HTTP request
+* `max_retries` (optional, default 3): maximum number of retries when an API request fails
+* `stream_chunk_timeout` (optional, default 60): maximum wait in seconds between two chunks of a streaming response; prevents a stream from hanging indefinitely
+
 An example configuration file, `example_config.json`, is provided; be sure to copy and modify it:
Binary file not shown.

core/ai_agent.py (198 lines changed)
@@ -4,15 +4,34 @@ import time
 import random
 import traceback
 import logging
+import tiktoken

 # Configure basic logging for this module (or rely on root logger config)
 # logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 # logger = logging.getLogger(__name__) # Alternative: use named logger

+# Constants
+MAX_RETRIES = 3  # Maximum number of retries for API calls
+INITIAL_BACKOFF = 1  # Initial backoff time in seconds
+MAX_BACKOFF = 16  # Maximum backoff time in seconds
+STREAM_CHUNK_TIMEOUT = 10  # Timeout in seconds for receiving a chunk in a stream
+
 class AI_Agent():
     """AI agent class: handles interaction with the AI model to generate text content"""

-    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3):
+    def __init__(self, base_url, model_name, api, timeout=30, max_retries=3, stream_chunk_timeout=10):
+        """
+        Initialize the AI agent.
+
+        Args:
+            base_url (str): Base URL of the model service, or a preset name ('deepseek', 'vllm').
+            model_name (str): Name of the model to use.
+            api (str): API key.
+            timeout (int, optional): Timeout in seconds for a single API request. Defaults to 30.
+            max_retries (int, optional): Maximum number of retries on API request failure. Defaults to 3.
+            stream_chunk_timeout (int, optional): Maximum wait in seconds between two chunks of a streaming response. Defaults to 10.
+        """
+        logging.info("Initializing AI Agent")
         self.url_list = {
             "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
             "kimi": "https://api.moonshot.cn/v1",
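The `INITIAL_BACKOFF` and `MAX_BACKOFF` constants introduced above define a capped exponential backoff. A quick sketch of the schedule they produce (`backoff_schedule` is an illustrative helper, not part of the commit; the real code also adds `random.uniform(0, 1)` jitter to each wait):

```python
# Mirrors the backoff constants added in core/ai_agent.py
INITIAL_BACKOFF = 1  # seconds
MAX_BACKOFF = 16     # seconds

def backoff_schedule(retries):
    """Deterministic part of the wait times; the real code adds jitter."""
    waits, backoff = [], INITIAL_BACKOFF
    for _ in range(retries):
        waits.append(backoff)
        backoff = min(backoff * 2, MAX_BACKOFF)
    return waits

print(backoff_schedule(6))  # [1, 2, 4, 8, 16, 16]
```

Doubling with a cap keeps the first retry fast while bounding the worst-case wait per attempt at 16 seconds.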
@@ -26,8 +45,9 @@ class AI_Agent():
         self.model_name = model_name
         self.timeout = timeout
         self.max_retries = max_retries
+        self.stream_chunk_timeout = stream_chunk_timeout

-        print(f"Initializing AI Agent with base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}")
+        logging.info(f"AI Agent Settings: base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}, stream_chunk_timeout={self.stream_chunk_timeout}s")

         self.client = OpenAI(
             api_key=self.api,
@@ -35,6 +55,12 @@ class AI_Agent():
             timeout=self.timeout
         )

+        try:
+            self.encoding = tiktoken.encoding_for_model(self.model_name)
+        except KeyError:
+            logging.warning(f"Encoding for model '{self.model_name}' not found. Using 'cl100k_base' encoding.")
+            self.encoding = tiktoken.get_encoding("cl100k_base")
+
     def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
         """Generate text content and return the full response together with a token-count estimate"""
         logging.info(f"Generating text with model: {self.model_name}, temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
@@ -176,98 +202,118 @@ class AI_Agent():

     # --- Streaming Methods ---
     def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
-        """Generate text content, yielding text chunks as a generator"""
-        logging.info("Streaming Generation Started...")
-        logging.debug(f"Streaming System Prompt (first 100 chars): {system_prompt[:100]}...")
-        logging.debug(f"Streaming User Prompt (first 100 chars): {user_prompt[:100]}...")
-        logging.info(f"Streaming Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
-
-        retry_count = 0
-        max_retry_wait = 10
-
-        while retry_count <= self.max_retries:
-            try:
-                logging.info(f"Attempting API stream call (try {retry_count + 1}/{self.max_retries + 1})")
-                response = self.client.chat.completions.create(
-                    model=self.model_name,
-                    messages=[{"role": "system", "content": system_prompt},
-                              {"role": "user", "content": user_prompt}],
-                    temperature=temperature,
-                    top_p=top_p,
-                    presence_penalty=presence_penalty,
-                    stream=True,
-                    max_tokens=8192,
-                    timeout=self.timeout,
-                    extra_body={"repetition_penalty": 1.05},
-                )
-
-                try:
-                    logging.info("Stream connected, receiving content...")
-                    yielded_something = False
-                    for chunk in response:
-                        if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
-                            content = chunk.choices[0].delta.content
-                            yield content
-                            yielded_something = True
-
-                    if yielded_something:
-                        logging.info("Stream finished successfully.")
-                    else:
-                        logging.warning("Stream finished, but no content was yielded.")
-                    return
-
-                except APIConnectionError as stream_err:
-                    logging.warning(f"Stream connection error occurred: {stream_err}")
-                    retry_count += 1
-                    if retry_count <= self.max_retries:
-                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
-                        logging.warning(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
-                        time.sleep(wait_time)
-                        continue
-                    else:
-                        logging.error("Max retries reached after stream connection error.")
-                        yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"
-                        return
-
-                except Exception as stream_err:
-                    logging.exception("Error occurred during stream processing:")
-                    yield f"[STREAM_ERROR: {stream_err}]"
-                    return
-
-            except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
-                logging.warning(f"API Error occurred: {e}")
-                should_retry = False
-                if isinstance(e, (APITimeoutError, APIConnectionError, RateLimitError)):
-                    should_retry = True
-                elif isinstance(e, APIStatusError) and e.status_code >= 500:
-                    should_retry = True
-
-                if should_retry:
-                    retry_count += 1
-                    if retry_count <= self.max_retries:
-                        wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
-                        logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
-                        time.sleep(wait_time)
-                        continue
-                    else:
-                        logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting stream.")
-                        yield "[API_ERROR: Max retries reached]"
-                        return
-                else:
-                    logging.error(f"Non-retriable API error: {e}. Aborting stream.")
-                    yield f"[API_ERROR: Non-retriable status {e.status_code if isinstance(e, APIStatusError) else 'Unknown'}]"
-                    return
-
-            except Exception as e:
-                logging.exception("Non-retriable error occurred during API call setup:")
-                yield f"[FATAL_ERROR: {e}]"
-                return
-
-        logging.error("Stream generation failed after exhausting all retries.")
-        yield "[ERROR: Failed after all retries]"
+        """
+        Generates text based on prompts using a streaming connection. Handles retries with exponential backoff.
+
+        Args:
+            system_prompt: The system prompt for the AI.
+            user_prompt: The user prompt for the AI.
+            temperature: Sampling temperature.
+            top_p: Nucleus sampling parameter.
+
+        Yields:
+            str: Chunks of the generated text.
+
+        Raises:
+            Exception: If the API call fails after all retries.
+        """
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+        logging.info(f"Generating text stream with model: {self.model_name}")
+
+        retries = 0
+        backoff_time = INITIAL_BACKOFF
+        last_exception = None
+
+        while retries < self.max_retries:
+            try:
+                logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream.")
+                stream = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=messages,
+                    temperature=temperature,
+                    top_p=top_p,
+                    stream=True,
+                    timeout=self.timeout  # Overall request timeout
+                )
+
+                chunk_iterator = iter(stream)
+                last_chunk_time = time.time()
+
+                while True:
+                    try:
+                        # Check for timeout since last received chunk
+                        if time.time() - last_chunk_time > self.stream_chunk_timeout:
+                            raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")
+
+                        chunk = next(chunk_iterator)
+                        last_chunk_time = time.time()  # Reset timer on successful chunk receipt
+
+                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
+                            content = chunk.choices[0].delta.content
+                            # logging.debug(f"Received chunk: {content}")  # Potentially very verbose
+                            yield content
+                        elif chunk.choices and chunk.choices[0].finish_reason == 'stop':
+                            logging.info("Stream finished.")
+                            return  # Successful completion
+                        # Handle other finish reasons if needed, e.g. 'length'
+
+                    except StopIteration:
+                        logging.info("Stream iterator exhausted.")
+                        return  # End of stream normally
+                    except Timeout as e:
+                        logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
+                        last_exception = e
+                        break  # Break inner loop to retry the stream creation
+                    except (APITimeoutError, APIConnectionError, RateLimitError) as e:
+                        logging.warning(f"API error during streaming: {type(e).__name__} - {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
+                        last_exception = e
+                        break  # Break inner loop to retry the stream creation
+                    except Exception as e:
+                        logging.error(f"Unexpected error during streaming: {traceback.format_exc()}")
+                        # Decide if this unexpected error should be retried or raised immediately
+                        last_exception = e
+                        # Option 1: Raise immediately
+                        # raise e
+                        # Option 2: Treat as retryable (use with caution)
+                        break  # Break inner loop to retry
+
+                # If we broke from the inner loop due to an error that needs retry
+                retries += 1
+                if retries < self.max_retries:
+                    logging.info(f"Retrying stream in {backoff_time} seconds...")
+                    time.sleep(backoff_time + random.uniform(0, 1))  # Add jitter
+                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
+                else:
+                    logging.error(f"Stream generation failed after {self.max_retries} retries.")
+                    raise last_exception or Exception("Stream generation failed after max retries.")
+
+            except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
+                retries += 1
+                last_exception = e
+                logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
+                if retries < self.max_retries:
+                    logging.info(f"Retrying in {backoff_time} seconds...")
+                    time.sleep(backoff_time + random.uniform(0, 1))  # Add jitter
+                    backoff_time = min(backoff_time * 2, MAX_BACKOFF)
+                else:
+                    logging.error(f"API call failed after {self.max_retries} retries.")
+                    raise last_exception
+            except Exception as e:
+                # Catch unexpected errors during stream setup
+                logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}")
+                raise e  # Re-raise unexpected errors immediately
+
+        # Should not be reached if logic is correct, but as a safeguard:
+        logging.error("Exited stream generation loop unexpectedly.")
+        raise last_exception or Exception("Stream generation failed.")

     def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
-        """Streaming version of the workflow: returns a text generator"""
+        """Full workflow (streaming): read the file folder (if provided), then generate the text stream"""
         logging.info(f"Starting 'work_stream' process. File folder: {file_folder}")
         if file_folder:
             logging.info(f"Reading context from folder: {file_folder}")
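The commit title concedes the new timeout is ineffective, and the diff above shows why: the `time.time() - last_chunk_time` check runs only between chunks, so it can never fire while `next(chunk_iterator)` itself is blocked on the network. One way to enforce a real per-chunk deadline (a sketch under that assumption; `iter_with_chunk_timeout` is a hypothetical helper, not part of this commit) is to run the blocking `next()` in a worker thread:

```python
import concurrent.futures
import time

def iter_with_chunk_timeout(iterable, chunk_timeout):
    """Yield items from `iterable`, raising TimeoutError if the next item
    takes more than `chunk_timeout` seconds to arrive.

    The blocking next() runs in a worker thread, so a stalled stream
    unblocks the consumer instead of hanging it.
    """
    it = iter(iterable)
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        while True:
            future = pool.submit(next, it)
            try:
                yield future.result(timeout=chunk_timeout)
            except StopIteration:
                return  # stream ended normally
            except concurrent.futures.TimeoutError:
                raise TimeoutError(f"No chunk received for {chunk_timeout}s")
```

Caveat: the worker thread blocked in `next()` cannot actually be cancelled; the timeout unblocks the consumer, but the executor's shutdown will still wait on the stalled call unless the underlying connection is closed as well.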
@ -48,9 +48,11 @@
|
|||||||
"content_presence_penalty": 1.5,
|
"content_presence_penalty": 1.5,
|
||||||
"request_timeout": 30,
|
"request_timeout": 30,
|
||||||
"max_retries": 3,
|
"max_retries": 3,
|
||||||
|
"stream_chunk_timeout": 60,
|
||||||
"output_collage_subdir": "collage_img",
|
"output_collage_subdir": "collage_img",
|
||||||
"output_poster_subdir": "poster",
|
"output_poster_subdir": "poster",
|
||||||
"output_poster_filename": "poster.jpg",
|
"output_poster_filename": "poster.jpg",
|
||||||
"poster_target_size": [900, 1200],
|
"poster_target_size": [900, 1200],
|
||||||
"text_possibility": 0.3
|
"text_possibility": 0.3,
|
||||||
|
"description_filename": "description.txt"
|
||||||
}
|
}
|
||||||
@@ -117,6 +117,7 @@ def main_test():
     ai_api_key = config.get("api_key")
     request_timeout = config.get("request_timeout", 30)
     max_retries = config.get("max_retries", 3)
+    stream_chunk_timeout = config.get("stream_chunk_timeout", 60)  # Read stream timeout
     if not all([ai_api_url, ai_model, ai_api_key]):
         raise ValueError("Missing required AI configuration (api_url, model, api_key)")
     logging.info("Initializing AI Agent for content generation test...")
@@ -125,7 +126,8 @@ def main_test():
         model_name=ai_model,
         api=ai_api_key,
         timeout=request_timeout,
-        max_retries=max_retries
+        max_retries=max_retries,
+        stream_chunk_timeout=stream_chunk_timeout  # Pass stream timeout
     )

     total_topics = len(topics_list)
@@ -66,6 +66,7 @@ def main():
     ai_api_key = config.get("api_key")
     request_timeout = config.get("request_timeout", 30)
     max_retries = config.get("max_retries", 3)
+    stream_chunk_timeout = config.get("stream_chunk_timeout", 60)  # Get stream chunk timeout

     # Check for required AI params
     if not all([ai_api_url, ai_model, ai_api_key]):
@@ -80,7 +81,8 @@ def main():
         model=ai_model,  # Use extracted var
         api_key=ai_api_key,  # Use extracted var
         timeout=request_timeout,
-        max_retries=max_retries
+        max_retries=max_retries,
+        stream_chunk_timeout=stream_chunk_timeout  # Pass it here
     )

     # Example call to work_stream
main.py (8 lines changed)
@@ -91,14 +91,16 @@ def generate_content_and_posters_step(config, run_id, topics_list, output_handle
     ai_agent = None
     try:
         # --- Initialize AI Agent for Content Generation ---
-        request_timeout = config.get("request_timeout", 30)  # Get timeout from config
-        max_retries = config.get("max_retries", 3)  # Get max_retries from config
+        request_timeout = config.get("request_timeout", 30)  # Default 30 seconds
+        max_retries = config.get("max_retries", 3)  # Default 3 retries
+        stream_chunk_timeout = config.get("stream_chunk_timeout", 60)  # Default 60 seconds for stream chunk
         ai_agent = AI_Agent(
             config["api_url"],
             config["model"],
             config["api_key"],
             timeout=request_timeout,
-            max_retries=max_retries
+            max_retries=max_retries,
+            stream_chunk_timeout=stream_chunk_timeout
         )
         logging.info("AI Agent for content generation initialized.")