# Reconstructed from a git-log dump (commit 8bfc191, 2025-05-22, "初始提交" /
# initial commit). This section covers sort.json plus the helper functions of
# sql_prompt.py.
#
# sort.json schema (3-level category tree used as the classification system):
#   { "<primary>": { "<secondary>": ["<tertiary>", ...], ... }, ... }
# e.g. "游玩类" -> "景点" -> ["红色景区", "观光街区", "地标建筑", ...]
# Top-level keys in the committed file: 游玩类, 住宿类, 旅行社类.

import json
import os
import pandas as pd
import sqlite3


def read_json_file(file_path):
    """Load a UTF-8 JSON file.

    Parameters
    ----------
    file_path : str
        Path of the JSON file (e.g. ``sort.json``).

    Returns
    -------
    dict
        Parsed content, or ``{}`` when the file cannot be read or parsed.
        Errors are printed rather than raised so callers keep going.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        print(f"读取文件 {file_path} 时出错: {e}")
        return {}


def create_prompt_json(data, category_data):
    """Build the classification-prompt payload sent to the LLM.

    Parameters
    ----------
    data
        The item to classify (a DB row tuple or a dict of item fields).
    category_data : dict
        The category system loaded from ``sort.json``.

    Returns
    -------
    dict
        JSON-serialisable prompt: instruction, task description, the two
        inputs, and the five required output fields.
    """
    return {
        "instruction": "将基本信息根据分类体系进行分类判断",
        "task_description": "你是一个精准的内容分类专家。请阅读提供的基本信息(item_info),并根据分类体系(category_system)为该项目确定最合适的分类。",
        "input": {
            "item_info": data,
            "category_system": category_data,
        },
        "output_requirements": {
            "fields": {
                "primary_category": "一级分类名称",
                "secondary_category": "二级分类名称(如果有)",
                "tertiary_category": "三级分类名称(如果有)",
                "confidence": "分类置信度(0-1)",
                "reasoning": "简要说明分类理由",
            }
        },
    }
# Reconstructed from the git-log dump: the __main__ driver of sql_prompt.py
# (random sampling from data.db into 111.xlsx) and the opening of sql_vllm.py
# up to the query execution. The helpers it calls (read_json_file /
# create_prompt_json) are defined earlier in sql_prompt.py.

import sqlite3

import pandas as pd


def build_random_sample_query(sort_name='产品类型',
                              sort_value=('住宿', '门票', '抢购'),
                              num_limit='500'):
    """Return SQL selecting `num_limit` random rows of table ``data`` whose
    `sort_name` column is one of `sort_value`.

    NOTE(review): built with f-string interpolation; acceptable only because
    every argument is hard-coded in this repo — never pass user input here
    (SQL injection risk).
    """
    return f"""
    SELECT * FROM data
    WHERE {sort_name} IN {sort_value}
    ORDER BY RANDOM()
    LIMIT {num_limit}
    """


def _run_sql_prompt_main():
    """Reconstructed body of sql_prompt.py's ``if __name__ == "__main__"``."""
    category_data = read_json_file("sort.json")  # classification system

    table_name = 'data'
    conn = sqlite3.connect(f'{table_name}.db')  # replace with your DB file
    cursor = conn.cursor()

    cursor.execute(build_random_sample_query())
    rows = cursor.fetchall()

    results = []
    for i, row in enumerate(rows):
        # The dump built a prompt per row (create_prompt_json) and discarded
        # it, and broke at i == 5 *before* appending — net effect: keep rows
        # 0..4 only. Rewritten as an explicit guard with the same effect.
        if i >= 5:
            break
        entry = list(row)
        entry.append(i)  # trailing loop index, mapped to the 'index' column
        results.append(entry)

    df = pd.DataFrame(results)
    # Header: DB column names plus the trailing loop index.
    columns = [description[0] for description in cursor.description]
    columns.append('index')
    df.columns = columns

    # Move the 3rd column first (presumably the item name — TODO confirm).
    cols = df.columns.tolist()
    cols.insert(0, cols.pop(2))
    df = df[cols]

    excel_file = '111.xlsx'
    df.to_excel(excel_file, index=False)
    print(f"\n所有生成文本已保存到 {excel_file}")
    print(results)


if __name__ == "__main__":
    _run_sql_prompt_main()

# ---- sql_vllm.py (opening, per the dump) ---------------------------------
# The script imports OpenAI/pandas/sqlite3/etc., sets
# os.environ['VLLM_USE_MODELSCOPE'] = 'True' (download models from
# ModelScope instead of HuggingFace), reconfigures stdout to UTF-8, loads
# sort.json, connects to data.db and executes the same random-sample query;
# its inference loop appears later in this dump.
# Reconstructed: the OpenAI-client inference loop of sql_vllm.py. One chat
# completion is sent per sampled DB row to a local vLLM server
# (http://localhost:8000/v1, api_key="EMPTY", model "Qwen3-8B"); the row plus
# five regex-extracted classification fields are collected per iteration.

import json
import re
import time

# Regex fallbacks used when the model's answer is not clean JSON. Each
# pattern tolerates quoted or unquoted values; [::] accepts both ASCII and
# full-width colons.
FIELD_PATTERNS = {
    "primary_category": r'primary_category["\s]*[::]\s*["]*([^",\n}]+)',
    "secondary_category": r'secondary_category["\s]*[::]\s*["]*([^",\n}]+)',
    "tertiary_category": r'tertiary_category["\s]*[::]\s*["]*([^",\n}]+)',
    "confidence": r'confidence["\s]*[::]\s*([0-9.]+)',
    "reasoning": r'reasoning["\s]*[::]\s*["]*([^"}\n]+(?:\n[^"}\n]+)*)',
}


def strip_think_block(generated_text):
    """Return the final answer after the model's reasoning block.

    BUG(fixed): the dump shows ``split("")`` — splitting on the empty string
    raises ValueError at runtime; the ``</think>`` literal was evidently lost
    when the log was captured and is restored here.
    """
    parts = generated_text.split("</think>")
    return parts[1].strip() if len(parts) > 1 else generated_text


def fix_mojibake(text):
    """Best-effort repair of latin-1-decoded UTF-8 text (kept from the
    original code). Correct strings that cannot round-trip through latin-1
    are returned unchanged."""
    if isinstance(text, bytes):
        return text.decode('utf-8')
    try:
        return text.encode('latin-1').decode('utf-8')
    except UnicodeError:
        return text


def extract_category_fields(processing_text):
    """Extract the five classification fields, in FIELD_PATTERNS order.

    Returns a 5-element list; a field that does not match becomes ''.
    """
    values = []
    for pattern in FIELD_PATTERNS.values():
        m = re.search(pattern, processing_text, re.IGNORECASE)
        values.append(m.group(1).strip() if m else "")
    return values


def run_inference_loop(rows, category_data, client, model="Qwen3-8B"):
    """Classify every row via the chat API.

    Parameters
    ----------
    rows : list[tuple]
        DB rows from the random-sample query.
    category_data : dict
        Category system loaded from sort.json.
    client
        An OpenAI-compatible client pointed at the vLLM server.
    model : str
        Served model name (default preserves the original hard-coded value).

    Returns
    -------
    list[list]
        One entry per row: the row's fields followed by the five extracted
        classification fields.
    """
    results = []
    for i, row in enumerate(rows):
        loop_start_time = time.time()

        prompt = create_prompt_json(row, category_data)  # from sql_prompt.py
        prompt_str = json.dumps(prompt, ensure_ascii=False)  # keep CJK chars

        completion = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt_str}],
        )

        entry = list(row)
        # The original iterated choices and effectively used the last one;
        # with n=1 (the default) this is equivalent.
        for choice in completion.choices:
            generated_text = fix_mojibake(choice.message.content)
            processing_text = strip_think_block(generated_text)
            entry.extend(fix_mojibake(v)
                         for v in extract_category_fields(processing_text))

        results.append(entry)
        print(f"第 {i+1} 次模型推理时间: {time.time() - loop_start_time:.2f} 秒")
    return results
# Reconstructed: the Excel-export tail of sql_vllm.py and the first helper of
# vllm_model.py (load_from_json).

import json

import pandas as pd

# Names of the five fields appended to each row by the inference loop.
EXTRACTED_FIELDS = ('primary_category', 'secondary_category',
                    'tertiary_category', 'confidence', 'reasoning')


def save_classification_results(results, description,
                                excel_file='500_Qwen3_8B_sort.xlsx'):
    """Write the collected classification rows to an Excel file.

    Parameters
    ----------
    results : list[list]
        Row fields followed by the five extracted classification fields.
    description : sequence
        DB-API ``cursor.description``; element [0] of each item is the
        column name.
    excel_file : str
        Output path (default preserves the original hard-coded name).
    """
    df = pd.DataFrame(results)
    columns = [d[0] for d in description] + list(EXTRACTED_FIELDS)
    df.columns = columns

    # Move the 3rd column first (presumably the item name — TODO confirm).
    cols = df.columns.tolist()
    cols.insert(0, cols.pop(2))
    df = df[cols]

    df.to_excel(excel_file, index=False)
    print(f"\n所有生成文本已保存到 {excel_file}")


def load_from_json(json_path):
    """Read a UTF-8 JSON file and return the parsed object.

    Unlike sql_prompt.read_json_file, this propagates I/O and parse errors
    to the caller instead of swallowing them and returning {}.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)
# Reconstructed: the remaining helpers of vllm_model.py — prompt assembly,
# the vLLM generation wrapper, and the start of the __main__ driver.
# os.environ['VLLM_USE_MODELSCOPE'] = 'True' was set at import in the
# original so model downloads come from ModelScope rather than HuggingFace.

import json
import os
import random


def build_prompt_from_json(prompt_data):
    """Assemble the full classification prompt text from a prompt dict
    (the structure produced by sql_prompt.create_prompt_json).

    Parameters
    ----------
    prompt_data : dict
        Expects keys instruction / task_description / input /
        output_requirements; missing keys fall back to "" or {}.

    Returns
    -------
    str
        instruction + task description + pretty-printed item info,
        category system, and output requirements.
    """
    instruction = prompt_data.get("instruction", "")
    task_description = prompt_data.get("task_description", "")
    input_data = prompt_data.get("input", {})
    output_requirements = prompt_data.get("output_requirements", {})

    item_json = json.dumps(input_data.get('item_info', {}),
                           ensure_ascii=False, indent=2)
    category_json = json.dumps(input_data.get('category_system', {}),
                               ensure_ascii=False, indent=2)
    requirements_json = json.dumps(output_requirements,
                                   ensure_ascii=False, indent=2)

    return (f"{instruction}\n\n{task_description}\n\n"
            f"基本信息:\n{item_json}\n\n"
            f"分类体系:\n{category_json}\n\n"
            f"输出要求:\n{requirements_json}")


def get_completion(prompts, model, tokenizer=None, temperature=0.6,
                   top_p=0.95, top_k=20, min_p=0, max_tokens=4096,
                   max_model_len=8192):
    """Run batch generation with a freshly-built vLLM engine.

    Sampling defaults follow the Qwen3 thinking-mode recommendation
    (temperature=0.6, top_p=0.95, top_k=20, min_p=0). ``max_tokens`` bounds
    the generated output; ``max_model_len`` bounds input+output together.
    stop_token_ids [151645, 151643] are presumably Qwen's
    <|im_end|>/<|endoftext|> ids — TODO confirm against the tokenizer.
    """
    from vllm import LLM, SamplingParams  # lazy: heavy optional dependency

    stop_token_ids = [151645, 151643]
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p,
                                     top_k=top_k, min_p=min_p,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids)
    llm = LLM(model=model, tokenizer=tokenizer, max_model_len=max_model_len,
              trust_remote_code=True)
    return llm.generate(prompts, sampling_params)


def _prepare_main(model_path='/root/autodl-tmp/Qwen/Qwen3-8B'):
    """Reconstructed start of the __main__ driver: fixed seed, 20 random
    sample ids in [0, 61929], and the slow (non-fast) tokenizer."""
    from transformers import AutoTokenizer  # lazy: heavy optional dependency

    os.environ['VLLM_USE_MODELSCOPE'] = 'True'
    random.seed(114514)
    random_numbers = [random.randint(0, 61929) for _ in range(20)]
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    return model_path, tokenizer, random_numbers


def _build_chat_text(tokenizer, prompt):
    """Wrap the prompt as a user message and apply the chat template with
    thinking mode enabled (enable_thinking=True, the model default)."""
    messages = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )
# Reconstructed: the generation/parsing loop and Excel export at the end of
# vllm_model.py. Each sampled id's prompt is generated (thinking-mode
# sampling: temperature=0.6, top_p=0.95, top_k=20, min_p=0), and the output
# is parsed into the five classification fields.

import re

import pandas as pd

# Same regex fallbacks as sql_vllm.py: tolerate quoted/unquoted values and
# ASCII or full-width colons.
_PATTERNS = {
    "primary_category": r'primary_category["\s]*[::]\s*["]*([^",\n}]+)',
    "secondary_category": r'secondary_category["\s]*[::]\s*["]*([^",\n}]+)',
    "tertiary_category": r'tertiary_category["\s]*[::]\s*["]*([^",\n}]+)',
    "confidence": r'confidence["\s]*[::]\s*([0-9.]+)',
    "reasoning": r'reasoning["\s]*[::]\s*["]*([^"}\n]+(?:\n[^"}\n]+)*)',
}


def parse_model_output(generated_text):
    """Return a dict with whichever of the five classification fields appear
    in the model output; the text after ``</think>`` is used when present.

    BUG(fixed): the dump shows ``split("")`` which raises ValueError; the
    ``</think>`` literal was evidently lost in the log capture and is
    restored here.
    """
    parts = generated_text.split("</think>")
    processing_text = parts[1].strip() if len(parts) > 1 else generated_text

    fields = {}
    for field, pattern in _PATTERNS.items():
        m = re.search(pattern, processing_text, re.IGNORECASE)
        if m:
            fields[field] = m.group(1).strip()
    return fields


def collect_result_entry(sample_id, data_json, generated_text):
    """Build one result row: sample id first, then the raw data fields,
    then any extracted classification fields. Parsing errors are reported
    and yield a row without classification fields (original behavior)."""
    entry = {"样本ID": sample_id}
    if isinstance(data_json, dict):
        entry.update(data_json)
    try:
        entry.update(parse_model_output(generated_text))
    except Exception as e:
        print(f"处理样本 {sample_id} 时出错: {e}")
    return entry


def save_responses(results, excel_file='20_random_Qwen3_8B_responses.xlsx'):
    """Persist the collected entries to Excel and report the output path."""
    df = pd.DataFrame(results)
    df.to_excel(excel_file, index=False)
    print(f"\n所有生成文本已保存到 {excel_file}")