Qwen3_8B_sort/sql_vllm.py
2025-05-22 11:38:05 +08:00

from openai import OpenAI
import os
import json
import pandas as pd
import re
import sys
import sqlite3
import time
from sql_prompt import create_prompt_json
from sql_prompt import read_json_file
# Use ModelScope for automatic model downloads; otherwise models are pulled
# from HuggingFace
os.environ['VLLM_USE_MODELSCOPE'] = 'True'
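# Note: VLLM_USE_MODELSCOPE is read by the vLLM library itself, so it only
# matters to a vLLM instance launched from this process; with the external
# server used below, setting it here has no effect.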
if __name__ == "__main__":
    sys.stdout.reconfigure(encoding='utf-8')
    # Connect to the database
    sort_json_path = "sort.json"  # category taxonomy file
    category_data = read_json_file(sort_json_path)
    # Database file path
    table_name = 'data'
    conn = sqlite3.connect(f'{table_name}.db')  # replace with your database file name
    cursor = conn.cursor()
    sort_name = '产品类型'
    sort_value = ('住宿', '门票', '抢购')
    num_limit = 500
    # SQL query: randomly sample 500 records of the specified product types
    query = f"""
        SELECT * FROM data
        WHERE {sort_name} IN {sort_value}
        ORDER BY RANDOM()
        LIMIT {num_limit}
    """
    cursor.execute(query)
    result = cursor.fetchall()
    # Point the OpenAI client at vLLM's OpenAI-compatible API server
    # instead of the real OpenAI endpoint
    openai_api_key = "EMPTY"
    openai_api_base = "http://localhost:8000/v1"
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )
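    # Assumes a vLLM server is already serving the model locally, started
    # with something like (exact command depends on your deployment):
    #   vllm serve Qwen3-8B --port 8000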
    results = []
    for i, row in enumerate(result):
        # Record the start time of this iteration
        loop_start_time = time.time()
        # Build the classification prompt for this record
        prompt = create_prompt_json(row, category_data)
        # Serialize to a JSON string, preserving Chinese characters
        prompt_str = json.dumps(prompt, ensure_ascii=False)
        # Call the chat completions endpoint
        messages = [
            {"role": "user", "content": prompt_str}
        ]
        completion = client.chat.completions.create(
            model="Qwen3-8B",
            messages=messages
        )
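        # Qwen3 emits <think>...</think> reasoning by default. Recent
        # vLLM/Qwen3 builds reportedly let you disable it per request via
        # extra_body={"chat_template_kwargs": {"enable_thinking": False}};
        # verify this against your versions before relying on it.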
        # Get the generated text (a single choice is expected per request)
        generated_text = completion.choices[0].message.content
        # Repair possible mojibake: a latin-1 round-trip recovers text that
        # was decoded with the wrong codec; genuine UTF-8 strings containing
        # non-latin-1 characters fail to encode and are kept as-is
        if isinstance(generated_text, bytes):
            generated_text = generated_text.decode('utf-8')
        elif isinstance(generated_text, str):
            try:
                generated_text = generated_text.encode('latin-1').decode('utf-8')
            except UnicodeError:
                # Conversion failed; keep the text as-is
                pass
        result_entry = []
        result_entry.extend(row)  # start the output row with all database columns
        # Qwen3 wraps its reasoning in <think>...</think>; keep only the
        # content after </think> as the final answer
        think_parts = generated_text.split("</think>")
        if len(think_parts) > 1:
            processing_text = think_parts[1].strip()
        else:
            # No </think> marker: fall back to the raw text
            processing_text = generated_text
        # Extract the classification fields from the answer text
        try:
            # Regex extraction tolerates output that is not strict JSON
            patterns = {
                "primary_category": r'primary_category["\s]*[:]\s*["]*([^",\n}]+)',
                "secondary_category": r'secondary_category["\s]*[:]\s*["]*([^",\n}]+)',
                "tertiary_category": r'tertiary_category["\s]*[:]\s*["]*([^",\n}]+)',
                "confidence": r'confidence["\s]*[:]\s*([0-9.]+)',
                "reasoning": r'reasoning["\s]*[:]\s*["]*([^"}\n]+(?:\n[^"}\n]+)*)'
            }
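            # Illustrative shape of the expected model output (the actual
            # schema is defined in sql_prompt.create_prompt_json; this
            # example is an assumption):
            # {"primary_category": "...", "secondary_category": "...",
            #  "tertiary_category": "...", "confidence": 0.9, "reasoning": "..."}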
            # Extract each field in order and append it to result_entry
            for field, pattern in patterns.items():
                match = re.search(pattern, processing_text, re.IGNORECASE)
                if match:
                    # Apply the same mojibake repair to the matched text
                    match_text = match.group(1).strip()
                    if isinstance(match_text, bytes):
                        match_text = match_text.decode('utf-8')
                    elif isinstance(match_text, str):
                        try:
                            match_text = match_text.encode('latin-1').decode('utf-8', errors='replace')
                        except UnicodeError:
                            pass
                    result_entry.append(match_text)
                else:
                    result_entry.append("")
        except Exception as e:
            print(f"Error while processing sample {i}: {e}")
            # Pad to len(row) + 5 fields so every row has the same length
            # when the DataFrame is built below
            result_entry.extend([""] * (len(row) + 5 - len(result_entry)))
        # Collect the processed row
        results.append(result_entry)
        # Report the wall-clock time of this iteration
        loop_duration = time.time() - loop_start_time
        print(f"Inference {i+1} took {loop_duration:.2f}s")
    df = pd.DataFrame(results)
    excel_file = '500_Qwen3_8B_sort.xlsx'
    # Use the database table's column names as the DataFrame header,
    # followed by the five fields extracted from the model output
    columns = [description[0] for description in cursor.description]
    columns.extend(['primary_category', 'secondary_category',
                    'tertiary_category', 'confidence', 'reasoning'])
    df.columns = columns
    # Move the third column to the front
    cols = df.columns.tolist()
    cols.insert(0, cols.pop(2))  # the third column has index 2
    df = df[cols]
    df.to_excel(excel_file, index=False)
    conn.close()
    print(f"\nAll generated text has been saved to {excel_file}")
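# Usage (assumed layout): run `python sql_vllm.py` from a directory that
# contains data.db and sort.json, with the vLLM server already running.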