361 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from dataclasses import dataclass, field
from typing import Optional, ClassVar, List, Tuple
import numpy as np
from numpy.typing import NDArray
from PIL import Image
import librosa
import librosa.util
import cv2
import os
from openai import OpenAI
import base64
# import ray # 注释掉Ray导入
# # 使用本地模式初始化Ray避免分布式通信问题
# ray.init(local_mode=True, ignore_reinit_error=True)
# Disable online checks for the Hugging Face libraries by forcing offline mode.
# NOTE(review): transformers / huggingface_hub may read these variables when they
# are first imported, and both are imported above this point — confirm the
# setting actually takes effect, or move these lines before the imports.
os.environ.update(
    dict.fromkeys(("HF_DATASETS_OFFLINE", "TRANSFORMERS_OFFLINE", "HF_HUB_OFFLINE"), "1")
)
@dataclass(frozen=True)
class VideoAsset:
    """A named demo video whose file is located through ``download_video_asset``.

    Attributes:
        name: Key into ``_NAME_TO_FILE`` selecting the video.
        num_frames: Frame budget forwarded to the frame-extraction helpers;
            values <= 0 (default -1) extract every frame.

    NOTE(review): this class shadows ``vllm.assets.video.VideoAsset``, which is
    imported at the top of the file and is therefore unreachable by that name.
    """

    name: str
    num_frames: int = -1

    # Asset name -> file name handed to download_video_asset.
    _NAME_TO_FILE: ClassVar[dict[str, str]] = {
        "baby_reading": "sample_demo_1.mp4",
    }

    def _video_path(self) -> str:
        """Resolve this asset's file name to a path via download_video_asset."""
        return download_video_asset(self.filename)

    @property
    def filename(self) -> str:
        """File name registered for ``name`` (KeyError for an unknown name)."""
        return self._NAME_TO_FILE[self.name]

    @property
    def pil_images(self) -> list[Image.Image]:
        """Frames of the video as PIL images."""
        return video_to_pil_images_list(self._video_path(), self.num_frames)

    @property
    def np_ndarrays(self) -> NDArray:
        """Frames of the video stacked into one NumPy array."""
        return video_to_ndarrays(self._video_path(), self.num_frames)

    def get_audio(self, sampling_rate: Optional[float] = None) -> NDArray:
        """Load the video's audio track with librosa and return only the samples.

        ``sampling_rate`` is passed to librosa as ``sr``; the rate librosa
        returns alongside the samples is discarded.
        """
        samples = librosa.load(self._video_path(), sr=sampling_rate)[0]
        return samples
@dataclass(frozen=True)
class LocalVideoAsset:
    """A video that already lives on local disk; no download step is involved.

    Attributes:
        local_path: Path of the video file; also exposed as ``filename``.
        name: Display name for the asset.
        num_frames: Frame budget forwarded to the frame-extraction helpers;
            values <= 0 (default -1) extract every frame.
    """

    local_path: str
    name: str = "local_video"
    num_frames: int = -1

    @property
    def filename(self) -> str:
        """The local file path (identical to ``local_path``)."""
        return self.local_path

    @property
    def pil_images(self) -> list[Image.Image]:
        """Frames of the local video as PIL images."""
        return video_to_pil_images_list(self.filename, self.num_frames)

    @property
    def np_ndarrays(self) -> NDArray:
        """Frames of the local video stacked into one NumPy array."""
        return video_to_ndarrays(self.filename, self.num_frames)

    def get_audio(self, sampling_rate: Optional[float] = None) -> NDArray:
        """Load the file's audio samples with librosa (``sr=sampling_rate``).

        Best effort: if the file is missing, or librosa raises, a message is
        printed and ``np.zeros(1)`` is returned. Note that this is a single zero
        sample, not an empty array.
        """
        path = self.filename
        try:
            if os.path.exists(path):
                return librosa.load(path, sr=sampling_rate)[0]
            print(f"音频文件不存在: {path}")
        except Exception as e:
            print(f"加载音频时出错: {e}")
        # Shared fallback for both the missing-file and the error path.
        return np.zeros(1)
# 辅助函数实现
def download_video_asset(filename: str) -> str:
    """Resolve *filename* to a local video path.

    Explicit paths are returned unchanged. These are absolute paths (leading
    ``/`` or anything ``os.path.isabs`` accepts) and paths relative to the
    current or parent directory (``./...``, ``../...``). Any other value is
    treated as an asset name and mapped to a placeholder download location.

    NOTE(review): the download branch is a stub. It only builds the string
    ``/path/to/downloaded/<filename>``; nothing is fetched or checked.
    """
    # Tuple startswith covers "/", "./" and "../"; os.path.isabs additionally
    # accepts platform-specific absolute forms such as Windows drive paths.
    if filename.startswith(("/", "./", "../")) or os.path.isabs(filename):
        return filename
    return f"/path/to/downloaded/{filename}"
def video_to_pil_images_list(video_path: str, num_frames: int) -> list[Image.Image]:
"""将视频转换为PIL图像列表"""
if not os.path.exists(video_path):
print(f"视频文件不存在: {video_path}")
return []
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"无法打开视频: {video_path}")
return []
# 获取视频帧数
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames / fps if fps > 0 else 0
print(f"视频信息: 总帧数={total_frames}, FPS={fps:.2f}, 时长={duration:.2f}")
# 如果指定了帧数,设置采样间隔;否则读取所有帧
if num_frames > 0 and num_frames < total_frames:
frame_interval = total_frames / num_frames
print(f"将提取 {num_frames} 帧,采样间隔为每 {frame_interval:.2f}")
else:
frame_interval = 1
num_frames = total_frames
print(f"将提取所有 {total_frames}")
pil_images = []
frame_count = 0
success = True
last_progress = -1
while success and len(pil_images) < num_frames:
# 读取下一帧
success, frame = cap.read()
if not success:
break
# 按间隔采样帧
if frame_count % max(1, int(frame_interval)) == 0:
# OpenCV使用BGR转为RGB
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# 转为PIL图像
pil_image = Image.fromarray(rgb_frame)
pil_images.append(pil_image)
# 显示进度每10%显示一次)
progress = int(len(pil_images) / num_frames * 10)
if progress > last_progress:
print(f"提取进度: {len(pil_images)}/{num_frames} ({len(pil_images)/num_frames*100:.1f}%)")
last_progress = progress
frame_count += 1
cap.release()
print(f"从视频中共提取了 {len(pil_images)}")
return pil_images
def video_to_ndarrays(
    video_path: str, num_frames: int, size: tuple[int, int] = (224, 224)
) -> NDArray:
    """Decode a video into one stacked array of resized RGB frames.

    Args:
        video_path: Path of the video file.
        num_frames: Forwarded to ``video_to_pil_images_list``; values <= 0
            extract every frame.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.
            Defaults to 224x224, the previously hard-coded size.

    Returns:
        Array of shape ``(frames, height, width, 3)`` built from the RGB frames.
        If no frame could be extracted, a single all-zero frame of shape
        ``(1, height, width, 3)`` is returned with the same uint8 dtype as
        real frames, so callers see a consistent dtype either way.
    """
    pil_images = video_to_pil_images_list(video_path, num_frames)
    if not pil_images:
        print(f"未能从视频中提取帧: {video_path}")
        width, height = size
        return np.zeros((1, height, width, 3), dtype=np.uint8)
    # Resize to a common size so the frames can be stacked along axis 0.
    stacked_array = np.stack([np.array(img.resize(size)) for img in pil_images], axis=0)
    print(f"NumPy数组形状: {stacked_array.shape}")
    return stacked_array
def encode_image(image_path):
    """Read *image_path* as raw bytes and return them base64-encoded as a str.

    The bytes are not decoded or validated as an image.
    """
    with open(image_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode("utf-8")
def encode_audio(audio_path):
    """Return the raw bytes of *audio_path* as a base64 text string.

    No audio decoding happens here; the file is read byte-for-byte.
    """
    with open(audio_path, "rb") as source:
        payload = source.read()
    encoded = base64.b64encode(payload)
    return encoded.decode("utf-8")
# Example: send scene-change frames plus the extracted audio track to a local
# vLLM OpenAI-compatible server and stream back a description of the video.
if __name__ == "__main__":
    source_dir = "/root/autodl-tmp/hot_video_analyse/source"
    frames_dir = os.path.join(source_dir, "scene_change", "scene_change_frames")

    # Optional: work from the source video directly instead of pre-extracted files.
    # local_video = LocalVideoAsset(
    #     local_path=os.path.join(source_dir, "sample_demo_1.mp4"),
    #     num_frames=-1,  # set a positive limit to speed up testing
    # )
    # pil_images = local_video.pil_images
    # np_arrays = local_video.np_ndarrays
    # scene_change_arrays = np.load(
    #     os.path.join(source_dir, "scene_change", "scene_change_arrays.npy")
    # )
    # audio = local_video.get_audio(sampling_rate=16000)

    # Scene-change keyframes in ascending order. The names appear to encode the
    # frame index, timestamp and SSIM score.
    # NOTE(review): frame 000034 used to be base64-encoded but never sent; it is
    # now included. If the server was launched with a per-prompt image limit
    # below 18 (--limit-mm-per-prompt), shorten this list.
    scene_frame_files = [
        "scene_change_000034_t1.13s_ssim0.5673.jpg",
        "scene_change_000117_t3.90s_ssim0.1989.jpg",
        "scene_change_000119_t3.97s_ssim0.2138.jpg",
        "scene_change_000140_t4.67s_ssim0.5160.jpg",
        "scene_change_000160_t5.33s_ssim0.4934.jpg",
        "scene_change_000180_t6.00s_ssim0.3577.jpg",
        "scene_change_000201_t6.70s_ssim0.3738.jpg",
        "scene_change_000222_t7.40s_ssim0.6104.jpg",
        "scene_change_000243_t8.10s_ssim0.5099.jpg",
        "scene_change_000261_t8.70s_ssim0.4735.jpg",
        "scene_change_000281_t9.37s_ssim0.2703.jpg",
        "scene_change_000301_t10.03s_ssim0.2772.jpg",
        "scene_change_000321_t10.70s_ssim0.3721.jpg",
        "scene_change_000341_t11.37s_ssim0.3700.jpg",
        "scene_change_000361_t12.03s_ssim0.3494.jpg",
        "scene_change_000382_t12.73s_ssim0.3423.jpg",
        "scene_change_000402_t13.40s_ssim0.3252.jpg",
        "scene_change_000421_t14.03s_ssim0.2086.jpg",
    ]

    base64_audio = encode_audio(
        os.path.join(source_dir, "transcription", "sample_demo_1_audio.wav")
    )
    client = OpenAI(
        # Placeholder key for the local server; replace it if the server
        # requires a real one.
        api_key="EMPTY",
        base_url="http://localhost:8000/v1",
    )

    # The video is sent as a sequence of still images, followed by the audio and the prompt.
    encoded_frames = [encode_image(os.path.join(frames_dir, name)) for name in scene_frame_files]
    image_parts = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}
        for b64 in encoded_frames
    ]
    audio_part = {
        "type": "input_audio",
        "input_audio": {"data": base64_audio, "format": "wav"},
    }
    text_part = {"type": "text", "text": "描述一下关于这个视频的浅层特征"}
    # Alternative prompts tried previously:
    # "形容一下这段音频。这些图像是一个视频的连续帧,请结合音频描述这个视频的具体过程和内容。并解释一下这个视频吸引人眼球的原因"
    # "识别这段音频,并描述一下这段音频的内容。"

    completion = client.chat.completions.create(
        model="/root/autodl-tmp/llm",  # model ID as registered by the vLLM server
        messages=[{"role": "user", "content": [*image_parts, audio_part, text_part]}],
        # Output modalities: either ["text", "audio"] or ["text"].
        modalities=["text"],
        # audio={"voice": "Cherry", "format": "wav"},
        # The original note says this endpoint errors unless stream=True.
        stream=True,
        # Ask for token usage; it arrives in a final chunk with no choices.
        stream_options={"include_usage": True},
    )

    response_parts = []
    usage_info = None
    print("正在生成回复...")
    for chunk in completion:
        if chunk.choices:
            text = chunk.choices[0].delta.content
            if text:
                response_parts.append(text)
                # Echo tokens as they arrive.
                print(text, end='', flush=True)
        else:
            # Usage-only chunk (produced because include_usage is set).
            usage_info = chunk.usage
    full_response = "".join(response_parts)

    print("\n" + "="*50)
    print("完整回复:")
    print("="*50)
    print(full_response)
    if usage_info:
        print("\n" + "="*50)
        print("使用情况:")
        print("="*50)
        print(usage_info)