timeout机制,但无效
This commit is contained in:
parent
4a4b37cba7
commit
0bdc8d9ae9
86
README.md
86
README.md
@ -129,12 +129,13 @@ pip install numpy pandas opencv-python pillow openai
|
||||
- `content_temperature`, `content_top_p`, `content_presence_penalty`: 内容生成 API 相关参数 (默认为 0.3, 0.4, 1.5)
|
||||
- `request_timeout`: AI API 请求的超时时间(秒,默认 30)
|
||||
- `max_retries`: 请求超时或可重试网络错误时的最大重试次数(默认 3)
|
||||
- `camera_image_subdir`: 存放原始照片的子目录名(相对于 `image_base_dir`,默认 "相机") - **注意:此项不再用于查找描述文件。**
|
||||
- `modify_image_subdir`: 存放处理后/用于拼贴的图片的子目录名(相对于 `image_base_dir`,默认 "modify")
|
||||
- `output_collage_subdir`: 在每个变体输出目录中存放拼贴图的子目录名(默认 "collage_img")
|
||||
- `output_poster_subdir`: 在每个变体输出目录中存放最终海报的子目录名(默认 "poster")
|
||||
- `output_poster_filename`: 输出的最终海报文件名(默认 "poster.jpg")
|
||||
- `poster_target_size`: 海报目标尺寸 `[宽, 高]`(默认 `[900, 1200]`)
|
||||
- `stream_chunk_timeout`: 处理流式响应时,允许的两个数据块之间的最大等待时间(秒),用于防止流长时间挂起。
|
||||
- `camera_image_subdir`: 存放原始照片的子目录名(相对于 `image_base_dir`,默认 "相机")
|
||||
- `modify_image_subdir`: 存放处理后/用于拼贴的图片的子目录名(相对于 `image_base_dir`,默认 "modify")
|
||||
- `output_collage_subdir`: 在每个变体输出目录中存放拼贴图的子目录名(默认 "collage_img")
|
||||
- `output_poster_subdir`: 在每个变体输出目录中存放最终海报的子目录名(默认 "poster")
|
||||
- `output_poster_filename`: 输出的最终海报文件名(默认 "poster.jpg")
|
||||
- `poster_target_size`: 海报目标尺寸 `[宽, 高]`(默认 `[900, 1200]`)
|
||||
- `text_possibility`: 海报中第二段附加文字出现的概率 (默认 0.3)
|
||||
|
||||
项目提供了一个示例配置文件 `example_config.json`,请务必复制并修改:
|
||||
@ -178,17 +179,20 @@ except Exception as e:
|
||||
print(f"Error loading config: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# 2. 初始化 AI Agent (读取超时/重试配置)
|
||||
# 2. 初始化 AI Agent (读取超时/重试/流式块超时配置)
|
||||
ai_agent = None
|
||||
try:
|
||||
request_timeout = config.get("request_timeout", 30)
|
||||
max_retries = config.get("max_retries", 3)
|
||||
stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # 新增:读取流式块超时
|
||||
|
||||
ai_agent = AI_Agent(
|
||||
config["api_url"],
|
||||
config["model"],
|
||||
config["api_key"],
|
||||
timeout=request_timeout,
|
||||
max_retries=max_retries
|
||||
max_retries=max_retries,
|
||||
stream_chunk_timeout=stream_chunk_timeout # 新增:传递流式块超时
|
||||
)
|
||||
|
||||
# 3. 定义提示词和参数
|
||||
@ -272,4 +276,68 @@ This refactoring makes it straightforward to add new output handlers in the futu
|
||||
|
||||
### 配置文件说明 (Configuration)
|
||||
|
||||
主配置文件为 `poster_gen_config.json` (可以复制 `example_config.json` 并修改)。主要包含以下部分:
|
||||
主配置文件为 `poster_gen_config.json` (可以复制 `example_config.json` 并修改)。主要包含以下部分:
|
||||
|
||||
#### 1. 基本配置 (Basic)
|
||||
|
||||
* `api_url` (必须): 大语言模型 API 地址 (或预设名称如 'vllm', 'ali', 'kimi', 'doubao', 'deepseek')
|
||||
* `api_key` (必须): API 密钥
|
||||
* `model` (必须): 使用的模型名称
|
||||
* `topic_system_prompt` (必须): 选题生成系统提示词文件路径 (应为要求JSON输出的版本)
|
||||
* `topic_user_prompt` (必须): 选题生成基础用户提示词文件路径
|
||||
* `content_system_prompt` (必须): 内容生成系统提示词文件路径
|
||||
* `resource_dir` (必须): 包含**资源文件信息**的列表。列表中的每个元素是一个字典,包含:
|
||||
* `type`: 资源类型,目前支持 `"Object"` (景点/对象信息), `"Description"` (对应的描述文件), `"Product"` (关联产品信息)。
|
||||
* `file_path`: 一个包含该类型所有资源文件**完整路径**的列表。
|
||||
* `prompts_dir` (必须): 存放 Demand/Style/Refer 等提示词片段的目录路径
|
||||
* `output_dir` (必须): 输出结果保存目录路径
|
||||
* `image_base_dir` (必须): **图片资源根目录绝对路径或相对路径** (用于查找源图片)
|
||||
* `poster_assets_base_dir` (必须): **海报素材根目录绝对路径或相对路径** (用于查找字体、边框、贴纸、文本背景等)
|
||||
* `num` (必须): (选题阶段)生成选题数量
|
||||
* `variants` (必须): (内容生成阶段)每个选题生成的变体数量
|
||||
|
||||
#### 2. 可选配置 (Optional)
|
||||
|
||||
* `date` (可选, 默认空): 日期标记(用于选题生成提示词)
|
||||
* `topic_temperature` (可选, 默认 0.2): 选题生成 API 温度参数
|
||||
* `topic_top_p` (可选, 默认 0.5): 选题生成 API top-p 参数
|
||||
* `topic_presence_penalty` (可选, 默认 1.5): 选题生成 API presence penalty 参数
|
||||
* `content_temperature` (可选, 默认 0.3): 内容生成 API 温度参数
|
||||
* `content_top_p` (可选, 默认 0.4): 内容生成 API top-p 参数
|
||||
* `content_presence_penalty` (可选, 默认 1.5): 内容生成 API presence penalty 参数
|
||||
* `request_timeout` (可选, 默认 30): 单个 HTTP 请求的超时时间(秒)。
|
||||
* `max_retries` (可选, 默认 3): API 请求失败时的最大重试次数。
|
||||
* `stream_chunk_timeout` (可选, 默认 60): 处理流式响应时,允许的两个数据块之间的最大等待时间(秒),用于防止流长时间挂起。
|
||||
* `camera_image_subdir` (可选, 默认 "相机"): 存放原始照片的子目录名(相对于 `image_base_dir`)
|
||||
* `modify_image_subdir` (可选, 默认 "modify"): 存放处理后/用于拼贴的图片的子目录名(相对于 `image_base_dir`)
|
||||
* `output_collage_subdir` (可选, 默认 "collage_img"): 在每个变体输出目录中存放拼贴图的子目录名
|
||||
|
||||
#### 3. 选题与内容生成参数 (Topic & Content Generation)
|
||||
|
||||
* `topic_temperature` (可选, 默认 0.2): 选题生成 API 温度参数
|
||||
* `topic_top_p` (可选, 默认 0.5): 选题生成 API top-p 参数
|
||||
* `topic_presence_penalty` (可选, 默认 1.5): 选题生成 API presence penalty 参数
|
||||
* `content_temperature` (可选, 默认 0.3): 内容生成 API 温度参数
|
||||
* `content_top_p` (可选, 默认 0.4): 内容生成 API top-p 参数
|
||||
* `content_presence_penalty` (可选, 默认 1.5): 内容生成 API presence penalty 参数
|
||||
|
||||
#### 4. 图片处理参数 (Image Processing)
|
||||
|
||||
* `camera_image_subdir` (可选, 默认 "相机"): 存放原始照片的子目录名(相对于 `image_base_dir`)
|
||||
* `modify_image_subdir` (可选, 默认 "modify"): 存放处理后/用于拼贴的图片的子目录名(相对于 `image_base_dir`)
|
||||
* `output_collage_subdir` (可选, 默认 "collage_img"): 在每个变体输出目录中存放拼贴图的子目录名
|
||||
|
||||
#### 5. 海报生成参数 (Poster Generation)
|
||||
|
||||
* `output_poster_subdir` (可选, 默认 "poster"): 在每个变体输出目录中存放最终海报的子目录名
|
||||
* `output_poster_filename` (可选, 默认 "poster.jpg"): 输出的最终海报文件名
|
||||
* `poster_target_size` (可选, 默认 [900, 1200]): 海报目标尺寸 `[宽, 高]`
|
||||
* `text_possibility` (可选, 默认 0.3): 海报中第二段附加文字出现的概率
|
||||
|
||||
#### 6. 其他参数 (Miscellaneous)
|
||||
|
||||
* `request_timeout` (可选, 默认 30): 单个 HTTP 请求的超时时间(秒)。
|
||||
* `max_retries` (可选, 默认 3): API 请求失败时的最大重试次数。
|
||||
* `stream_chunk_timeout` (可选, 默认 60): 处理流式响应时,允许的两个数据块之间的最大等待时间(秒),用于防止流长时间挂起。
|
||||
|
||||
项目提供了一个示例配置文件 `example_config.json`,请务必复制并修改:
|
||||
|
||||
Binary file not shown.
198
core/ai_agent.py
198
core/ai_agent.py
@ -4,15 +4,34 @@ import time
|
||||
import random
|
||||
import traceback
|
||||
import logging
|
||||
import tiktoken
|
||||
|
||||
# Configure basic logging for this module (or rely on root logger config)
|
||||
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
# logger = logging.getLogger(__name__) # Alternative: use named logger
|
||||
|
||||
# Constants
|
||||
MAX_RETRIES = 3 # Maximum number of retries for API calls
|
||||
INITIAL_BACKOFF = 1 # Initial backoff time in seconds
|
||||
MAX_BACKOFF = 16 # Maximum backoff time in seconds
|
||||
STREAM_CHUNK_TIMEOUT = 10 # Timeout in seconds for receiving a chunk in stream
|
||||
|
||||
class AI_Agent():
|
||||
"""AI代理类,负责与AI模型交互生成文本内容"""
|
||||
|
||||
def __init__(self, base_url, model_name, api, timeout=30, max_retries=3):
|
||||
def __init__(self, base_url, model_name, api, timeout=30, max_retries=3, stream_chunk_timeout=10):
|
||||
"""
|
||||
初始化 AI 代理。
|
||||
|
||||
Args:
|
||||
base_url (str): 模型服务的基础 URL 或预设名称 ('deepseek', 'vllm')。
|
||||
model_name (str): 要使用的模型名称。
|
||||
api (str): API 密钥。
|
||||
timeout (int, optional): 单次 API 请求的超时时间 (秒)。 Defaults to 30.
|
||||
max_retries (int, optional): API 请求失败时的最大重试次数。 Defaults to 3.
|
||||
stream_chunk_timeout (int, optional): 流式响应中两个数据块之间的最大等待时间 (秒)。 Defaults to 10.
|
||||
"""
|
||||
logging.info("Initializing AI Agent")
|
||||
self.url_list = {
|
||||
"ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"kimi": "https://api.moonshot.cn/v1",
|
||||
@ -26,8 +45,9 @@ class AI_Agent():
|
||||
self.model_name = model_name
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.stream_chunk_timeout = stream_chunk_timeout
|
||||
|
||||
print(f"Initializing AI Agent with base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}")
|
||||
logging.info(f"AI Agent Settings: base_url={self.base_url}, model={self.model_name}, timeout={self.timeout}s, max_retries={self.max_retries}, stream_chunk_timeout={self.stream_chunk_timeout}s")
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api,
|
||||
@ -35,6 +55,12 @@ class AI_Agent():
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
try:
|
||||
self.encoding = tiktoken.encoding_for_model(self.model_name)
|
||||
except KeyError:
|
||||
logging.warning(f"Encoding for model '{self.model_name}' not found. Using 'cl100k_base' encoding.")
|
||||
self.encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
|
||||
"""生成文本内容,并返回完整响应和token估计值"""
|
||||
logging.info(f"Generating text with model: {self.model_name}, temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
|
||||
@ -176,98 +202,118 @@ class AI_Agent():
|
||||
|
||||
# --- Streaming Methods ---
|
||||
def generate_text_stream(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
|
||||
"""生成文本内容,并以生成器方式 yield 文本块"""
|
||||
logging.info("Streaming Generation Started...")
|
||||
logging.debug(f"Streaming System Prompt (first 100 chars): {system_prompt[:100]}...")
|
||||
logging.debug(f"Streaming User Prompt (first 100 chars): {user_prompt[:100]}...")
|
||||
logging.info(f"Streaming Params: temp={temperature}, top_p={top_p}, presence_penalty={presence_penalty}")
|
||||
"""
|
||||
Generates text based on prompts using a streaming connection. Handles retries with exponential backoff.
|
||||
|
||||
retry_count = 0
|
||||
max_retry_wait = 10
|
||||
Args:
|
||||
system_prompt: The system prompt for the AI.
|
||||
user_prompt: The user prompt for the AI.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling parameter.
|
||||
|
||||
while retry_count <= self.max_retries:
|
||||
Yields:
|
||||
str: Chunks of the generated text.
|
||||
|
||||
Raises:
|
||||
Exception: If the API call fails after all retries.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
logging.info(f"Generating text stream with model: {self.model_name}")
|
||||
|
||||
retries = 0
|
||||
backoff_time = INITIAL_BACKOFF
|
||||
last_exception = None
|
||||
|
||||
while retries < self.max_retries:
|
||||
try:
|
||||
logging.info(f"Attempting API stream call (try {retry_count + 1}/{self.max_retries + 1})")
|
||||
response = self.client.chat.completions.create(
|
||||
logging.debug(f"Attempt {retries + 1}/{self.max_retries} to generate text stream.")
|
||||
stream = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}],
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
presence_penalty=presence_penalty,
|
||||
stream=True,
|
||||
max_tokens=8192,
|
||||
timeout=self.timeout,
|
||||
extra_body={"repetition_penalty": 1.05},
|
||||
timeout=self.timeout # Overall request timeout
|
||||
)
|
||||
|
||||
try:
|
||||
logging.info("Stream connected, receiving content...")
|
||||
yielded_something = False
|
||||
for chunk in response:
|
||||
if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
|
||||
chunk_iterator = iter(stream)
|
||||
last_chunk_time = time.time()
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Check for timeout since last received chunk
|
||||
if time.time() - last_chunk_time > self.stream_chunk_timeout:
|
||||
raise Timeout(f"No chunk received for {self.stream_chunk_timeout} seconds.")
|
||||
|
||||
chunk = next(chunk_iterator)
|
||||
last_chunk_time = time.time() # Reset timer on successful chunk receipt
|
||||
|
||||
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
# logging.debug(f"Received chunk: {content}") # Potentially very verbose
|
||||
yield content
|
||||
yielded_something = True
|
||||
elif chunk.choices and chunk.choices[0].finish_reason == 'stop':
|
||||
logging.info("Stream finished.")
|
||||
return # Successful completion
|
||||
# Handle other finish reasons if needed, e.g., 'length'
|
||||
|
||||
if yielded_something:
|
||||
logging.info("Stream finished successfully.")
|
||||
else:
|
||||
logging.warning("Stream finished, but no content was yielded.")
|
||||
return
|
||||
except StopIteration:
|
||||
logging.info("Stream iterator exhausted.")
|
||||
return # End of stream normally
|
||||
except Timeout as e:
|
||||
logging.warning(f"Stream chunk timeout: {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
|
||||
last_exception = e
|
||||
break # Break inner loop to retry the stream creation
|
||||
except (APITimeoutError, APIConnectionError, RateLimitError) as e:
|
||||
logging.warning(f"API error during streaming: {type(e).__name__} - {e}. Retrying if possible ({retries + 1}/{self.max_retries}).")
|
||||
last_exception = e
|
||||
break # Break inner loop to retry the stream creation
|
||||
except Exception as e:
|
||||
logging.error(f"Unexpected error during streaming: {traceback.format_exc()}")
|
||||
# Decide if this unexpected error should be retried or raised immediately
|
||||
last_exception = e
|
||||
# Option 1: Raise immediately
|
||||
# raise e
|
||||
# Option 2: Treat as retryable (use with caution)
|
||||
break # Break inner loop to retry
|
||||
|
||||
except APIConnectionError as stream_err:
|
||||
logging.warning(f"Stream connection error occurred: {stream_err}")
|
||||
retry_count += 1
|
||||
if retry_count <= self.max_retries:
|
||||
wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
|
||||
logging.warning(f"Retrying connection ({retry_count}/{self.max_retries}), waiting {wait_time:.2f}s...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
logging.error("Max retries reached after stream connection error.")
|
||||
yield f"[STREAM_ERROR: Max retries reached after connection error: {stream_err}]"
|
||||
return
|
||||
|
||||
except Exception as stream_err:
|
||||
logging.exception("Error occurred during stream processing:")
|
||||
yield f"[STREAM_ERROR: {stream_err}]"
|
||||
return
|
||||
|
||||
except (APITimeoutError, APIConnectionError, RateLimitError, APIStatusError) as e:
|
||||
logging.warning(f"API Error occurred: {e}")
|
||||
should_retry = False
|
||||
if isinstance(e, (APITimeoutError, APIConnectionError, RateLimitError)):
|
||||
should_retry = True
|
||||
elif isinstance(e, APIStatusError) and e.status_code >= 500:
|
||||
should_retry = True
|
||||
|
||||
if should_retry:
|
||||
retry_count += 1
|
||||
if retry_count <= self.max_retries:
|
||||
wait_time = min(2 ** retry_count + random.random(), max_retry_wait)
|
||||
logging.warning(f"Retrying API call ({retry_count}/{self.max_retries}) after error, waiting {wait_time:.2f}s...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
logging.error(f"Max retries ({self.max_retries}) reached for API errors. Aborting stream.")
|
||||
yield "[API_ERROR: Max retries reached]"
|
||||
return
|
||||
# If we broke from the inner loop due to an error that needs retry
|
||||
retries += 1
|
||||
if retries < self.max_retries:
|
||||
logging.info(f"Retrying stream in {backoff_time} seconds...")
|
||||
time.sleep(backoff_time + random.uniform(0, 1)) # Add jitter
|
||||
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
|
||||
else:
|
||||
logging.error(f"Non-retriable API error: {e}. Aborting stream.")
|
||||
yield f"[API_ERROR: Non-retriable status {e.status_code if isinstance(e, APIStatusError) else 'Unknown'}]"
|
||||
return
|
||||
except Exception as e:
|
||||
logging.exception("Non-retriable error occurred during API call setup:")
|
||||
yield f"[FATAL_ERROR: {e}]"
|
||||
return
|
||||
logging.error(f"Stream generation failed after {self.max_retries} retries.")
|
||||
raise last_exception or Exception("Stream generation failed after max retries.")
|
||||
|
||||
logging.error("Stream generation failed after exhausting all retries.")
|
||||
yield "[ERROR: Failed after all retries]"
|
||||
|
||||
except (Timeout, APITimeoutError, APIConnectionError, RateLimitError) as e:
|
||||
retries += 1
|
||||
last_exception = e
|
||||
logging.warning(f"Attempt {retries}/{self.max_retries} failed: {type(e).__name__} - {e}")
|
||||
if retries < self.max_retries:
|
||||
logging.info(f"Retrying in {backoff_time} seconds...")
|
||||
time.sleep(backoff_time + random.uniform(0, 1)) # Add jitter
|
||||
backoff_time = min(backoff_time * 2, MAX_BACKOFF)
|
||||
else:
|
||||
logging.error(f"API call failed after {self.max_retries} retries.")
|
||||
raise last_exception
|
||||
except Exception as e:
|
||||
# Catch unexpected errors during stream setup
|
||||
logging.error(f"Unexpected error setting up stream: {traceback.format_exc()}")
|
||||
raise e # Re-raise unexpected errors immediately
|
||||
|
||||
# Should not be reached if logic is correct, but as a safeguard:
|
||||
logging.error("Exited stream generation loop unexpectedly.")
|
||||
raise last_exception or Exception("Stream generation failed.")
|
||||
|
||||
|
||||
def work_stream(self, system_prompt, user_prompt, file_folder, temperature, top_p, presence_penalty):
|
||||
"""工作流程的流式版本:返回文本生成器"""
|
||||
"""完整的工作流程(流式):读取文件夹(如果提供),然后生成文本流"""
|
||||
logging.info(f"Starting 'work_stream' process. File folder: {file_folder}")
|
||||
if file_folder:
|
||||
logging.info(f"Reading context from folder: {file_folder}")
|
||||
|
||||
@ -48,9 +48,11 @@
|
||||
"content_presence_penalty": 1.5,
|
||||
"request_timeout": 30,
|
||||
"max_retries": 3,
|
||||
"stream_chunk_timeout": 60,
|
||||
"output_collage_subdir": "collage_img",
|
||||
"output_poster_subdir": "poster",
|
||||
"output_poster_filename": "poster.jpg",
|
||||
"poster_target_size": [900, 1200],
|
||||
"text_possibility": 0.3
|
||||
"text_possibility": 0.3,
|
||||
"description_filename": "description.txt"
|
||||
}
|
||||
@ -117,6 +117,7 @@ def main_test():
|
||||
ai_api_key = config.get("api_key")
|
||||
request_timeout = config.get("request_timeout", 30)
|
||||
max_retries = config.get("max_retries", 3)
|
||||
stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Read stream timeout
|
||||
if not all([ai_api_url, ai_model, ai_api_key]):
|
||||
raise ValueError("Missing required AI configuration (api_url, model, api_key)")
|
||||
logging.info("Initializing AI Agent for content generation test...")
|
||||
@ -125,7 +126,8 @@ def main_test():
|
||||
model_name=ai_model,
|
||||
api=ai_api_key,
|
||||
timeout=request_timeout,
|
||||
max_retries=max_retries
|
||||
max_retries=max_retries,
|
||||
stream_chunk_timeout=stream_chunk_timeout # Pass stream timeout
|
||||
)
|
||||
|
||||
total_topics = len(topics_list)
|
||||
|
||||
@ -66,6 +66,7 @@ def main():
|
||||
ai_api_key = config.get("api_key")
|
||||
request_timeout = config.get("request_timeout", 30)
|
||||
max_retries = config.get("max_retries", 3)
|
||||
stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Get stream chunk timeout
|
||||
|
||||
# Check for required AI params
|
||||
if not all([ai_api_url, ai_model, ai_api_key]):
|
||||
@ -80,7 +81,8 @@ def main():
|
||||
model=ai_model, # Use extracted var
|
||||
api_key=ai_api_key, # Use extracted var
|
||||
timeout=request_timeout,
|
||||
max_retries=max_retries
|
||||
max_retries=max_retries,
|
||||
stream_chunk_timeout=stream_chunk_timeout # Pass it here
|
||||
)
|
||||
|
||||
# Example call to work_stream
|
||||
|
||||
8
main.py
8
main.py
@ -91,14 +91,16 @@ def generate_content_and_posters_step(config, run_id, topics_list, output_handle
|
||||
ai_agent = None
|
||||
try:
|
||||
# --- Initialize AI Agent for Content Generation ---
|
||||
request_timeout = config.get("request_timeout", 30) # Get timeout from config
|
||||
max_retries = config.get("max_retries", 3) # Get max_retries from config
|
||||
request_timeout = config.get("request_timeout", 30) # Default 30 seconds
|
||||
max_retries = config.get("max_retries", 3) # Default 3 retries
|
||||
stream_chunk_timeout = config.get("stream_chunk_timeout", 60) # Default 60 seconds for stream chunk
|
||||
ai_agent = AI_Agent(
|
||||
config["api_url"],
|
||||
config["model"],
|
||||
config["api_key"],
|
||||
timeout=request_timeout,
|
||||
max_retries=max_retries
|
||||
max_retries=max_retries,
|
||||
stream_chunk_timeout=stream_chunk_timeout
|
||||
)
|
||||
logging.info("AI Agent for content generation initialized.")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user