littlebird13 commited on
Commit
502c533
·
verified ·
1 Parent(s): 973f869

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +767 -0
app.py ADDED
@@ -0,0 +1,767 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+
4
+ import torch
5
+
6
+ os.environ['VLLM_USE_V1'] = '0'
7
+ os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
8
+ from argparse import ArgumentParser
9
+
10
+ import gradio as gr
11
+ import gradio.processing_utils as processing_utils
12
+ import modelscope_studio.components.antd as antd
13
+ import modelscope_studio.components.base as ms
14
+ import numpy as np
15
+ import soundfile as sf
16
+ from gradio_client import utils as client_utils
17
+ from qwen_omni_utils import process_mm_info
18
+
19
+ import base64
20
+ import numpy as np
21
+ from scipy.io import wavfile # 使用 scipy 保存 wav 文件,更简单支持 int16
22
+
23
+ import soundfile as sf
24
+ from openai import OpenAI
25
+
26
+ import base64
27
+
28
+ import os
29
+ import oss2
30
+ import json
31
+ import time
32
+ import subprocess
33
+ import numpy as np
34
+
35
OSS_RETRY = 10        # max attempts for any OSS operation
OSS_RETRY_DELAY = 3   # seconds to sleep between retry attempts
WAV_BIT_RATE = 16     # PCM bit depth (bits/sample) of the cached .ar16k.bin files
# BUG FIX: os.environ.get returns a *string* when the variable is set, which
# silently breaks downstream arithmetic (e.g. WAV_SAMPLE_RATE * 2 becomes
# string repetition in read_wavdata_from_oss). Cast to int explicitly.
WAV_SAMPLE_RATE = int(os.environ.get("WAV_SAMPLE_RATE", 16000))
39
+
40
+ # OSS_CONFIG_PATH = "/mnt/workspace/feizi.wx/.oss_config.json"
41
+
42
# OSS connection settings come from the environment. The three os.environ[...]
# lookups fail fast with KeyError at import time if the required credentials
# are missing; endpoint/region/bucket fall back to None when unset.
endpoint = os.getenv("OSS_ENDPOINT")
region = os.getenv("OSS_REGION")
bucket_name = os.getenv("OSS_BUCKET_NAME")
API_KEY = os.environ['API_KEY']  # DashScope (Model Studio) API key
OSS_ACCESS_KEY_ID = os.environ['OSS_ACCESS_KEY_ID']
OSS_ACCESS_KEY_SECRET = os.environ['OSS_ACCESS_KEY_SECRET']
OSS_CONFIG_PATH = {}  # legacy placeholder; only referenced in a startup log line
49
+
50
class OSSReader:
    """Helper around ``oss2`` buckets used by the demo to stage media files.

    Provides retried upload/download, signed public URLs, and audio decoding
    (via ffmpeg) for files exchanged with the DashScope API. All object keys
    are sandboxed under the ``studio-temp/Qwen3-Omni-Demo/`` prefix.
    """

    def __init__(self):
        # One authenticated client per supported bucket (bucket name -> oss2.Bucket).
        self.bucket2object = {
            bucket_name: oss2.Bucket(oss2.Auth(OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET), endpoint, bucket_name),
        }
        print(f"Loaded OSS config from: {OSS_CONFIG_PATH}\nSupported buckets: {list(self.bucket2object.keys())}")

    def _parse_oss_path(self, oss_path):
        """Split ``oss://bucket/key`` into ``(bucket_name, object_key)``.

        The key is rewritten under the demo's temp prefix so every read/write
        stays inside ``studio-temp/Qwen3-Omni-Demo/``.
        """
        assert oss_path.startswith("oss://"), f"Invalid oss path {oss_path}"
        bucket_name, object_key = oss_path.split("oss://")[-1].split("/", 1)
        object_key = f"studio-temp/Qwen3-Omni-Demo/{object_key}"
        return bucket_name, object_key

    def _retry_operation(self, func, *args, retries=OSS_RETRY, delay=OSS_RETRY_DELAY, **kwargs):
        """Run ``func(*args, **kwargs)``, retrying up to ``retries`` times.

        Sleeps ``delay`` seconds between attempts and re-raises the final
        exception. BUG FIX: ``delay`` previously defaulted to ``OSS_RETRY``
        (the attempt count, 10) instead of ``OSS_RETRY_DELAY`` (3 seconds).
        """
        for attempt in range(retries):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                print(f"Retry: {attempt} Error: {str(e)}")
                if attempt == retries - 1:
                    raise e
                time.sleep(delay)

    def get_public_url(self, oss_path):
        """Return a 600-second signed HTTPS URL for ``oss_path``.

        The ``-internal`` endpoint suffix is stripped so the URL is reachable
        from the public internet (e.g. by the DashScope backend).
        """
        bucket_name, object_key = self._parse_oss_path(oss_path)
        url = self._retry_operation(self.bucket2object[bucket_name].sign_url, 'GET', object_key, 600,
                                    slash_safe=True).replace('http://', 'https://')
        return url.replace("-internal", '')

    def file_exists(self, oss_path):
        """Return True if the object exists on OSS."""
        bucket_name, object_key = self._parse_oss_path(oss_path)
        return self._retry_operation(self.bucket2object[bucket_name].object_exists, object_key)

    def download_file(self, oss_path, local_path):
        """Download an OSS object to a local file path."""
        bucket_name, object_key = self._parse_oss_path(oss_path)
        self._retry_operation(self.bucket2object[bucket_name].get_object_to_file, object_key, local_path)

    def upload_file(self, local_path, oss_path, overwrite=True):
        """Upload a local file to OSS; return True on success, False otherwise.

        Raises FileNotFoundError if ``local_path`` does not exist. When
        ``overwrite`` is False and the target exists, skips and returns False.
        """
        bucket_name, object_key = self._parse_oss_path(oss_path)
        if not os.path.exists(local_path):
            raise FileNotFoundError(f"Local file {local_path} does not exist")
        # Skip if the target exists and overwriting is disabled.
        if not overwrite and self.file_exists(oss_path):
            print(f"File {oss_path} already exists, skip upload")
            return False
        try:
            self._retry_operation(
                self.bucket2object[bucket_name].put_object_from_file,
                object_key,
                local_path
            )
            return True
        except Exception as e:
            print(f"Upload failed: {str(e)}")
            return False

    def upload_audio_from_array(self, data, sample_rate, oss_path, overwrite=True):
        """Encode a float audio array as 16-bit mono WAV and upload it to OSS.

        ``data`` is assumed to be float samples in [-1, 1] — TODO confirm with
        callers. Returns True on success, False on failure/skip.
        """
        bucket_name, object_key = self._parse_oss_path(oss_path)

        # Skip if the target exists and overwriting is disabled.
        if not overwrite and self.file_exists(oss_path):
            print(f"File {oss_path} already exists, skip upload")
            return False

        try:
            # Build the WAV container in memory — no temp file needed.
            import wave
            from io import BytesIO

            byte_io = BytesIO()
            with wave.open(byte_io, 'wb') as wf:
                wf.setnchannels(1)           # mono
                wf.setsampwidth(2)           # 16-bit PCM
                wf.setframerate(sample_rate)
                # Clip then scale float samples to int16 range.
                data_int16 = np.clip(data, -1, 1) * 32767
                data_int16 = data_int16.astype(np.int16)
                wf.writeframes(data_int16.tobytes())

            # Upload the in-memory WAV bytes.
            self._retry_operation(
                self.bucket2object[bucket_name].put_object,
                object_key,
                byte_io.getvalue()
            )
            return True
        except Exception as e:
            print(f"Upload failed: {str(e)}")
            return False

    def get_object(self, oss_path):
        """Return the raw oss2 result object (stream) for ``oss_path``."""
        bucket_name, object_key = self._parse_oss_path(oss_path)
        return self._retry_operation(self.bucket2object[bucket_name].get_object, object_key)

    def read_text_file(self, oss_path):
        """Read an OSS object and decode it as UTF-8 text."""
        bucket_name, object_key = self._parse_oss_path(oss_path)
        result = self._retry_operation(self.bucket2object[bucket_name].get_object, object_key)
        return result.read().decode('utf-8')

    def read_audio_file(self, oss_path):
        """Decode an OSS audio object via ffmpeg.

        Returns ``(samples, WAV_SAMPLE_RATE)`` where ``samples`` is a mono
        float32 numpy array resampled to WAV_SAMPLE_RATE.
        """
        bucket_name, object_key = self._parse_oss_path(oss_path)
        result = self._retry_operation(self.bucket2object[bucket_name].get_object, object_key)
        # ffmpeg: read audio from stdin, emit mono float32 PCM on stdout.
        command = [
            'ffmpeg',
            '-i', '-',                    # input from pipe
            '-ar', str(WAV_SAMPLE_RATE),  # output sample rate
            '-ac', '1',                   # mono
            '-f', 'f32le',                # raw little-endian float32
            '-'                           # output to pipe
        ]
        process = subprocess.Popen(
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )
        # Feed the encoded bytes and collect the decoded PCM.
        stdout_data, stderr_data = process.communicate(input=result.read())
        if process.returncode != 0:
            raise RuntimeError(f"FFmpeg error: {stderr_data.decode('utf-8')}")
        wav_data = np.frombuffer(stdout_data, dtype=np.float32)
        return wav_data, WAV_SAMPLE_RATE

    def get_wav_duration_from_bin(self, oss_path):
        """Duration (seconds) of the cached ``.ar16k.bin`` sibling of ``oss_path``.

        NOTE(review): hard-codes 16 kHz * 2 bytes/sample rather than using
        WAV_SAMPLE_RATE/WAV_BIT_RATE — consistent only while those stay 16000/16.
        """
        oss_bin_path = oss_path + ".ar16k.bin"
        bucket_name, object_key = self._parse_oss_path(oss_bin_path)
        metadata = self._retry_operation(self.bucket2object[bucket_name].get_object_meta, object_key)
        duration = float(metadata.headers['Content-Length']) / (16000 * 2)
        return duration

    def read_wavdata_from_oss(self, oss_path, start=None, end=None, force_bin=False):
        """Read (a slice of) audio samples from OSS as float32 in [-1, 1).

        Without ``start``/``end`` the cached ``.ar16k.bin`` is preferred,
        falling back to ffmpeg decoding unless ``force_bin`` is set. With a
        time range, only the needed byte range of the bin file is fetched.
        """
        bucket_name, object_key = self._parse_oss_path(oss_path)
        oss_bin_key = object_key + ".ar16k.bin"
        if start is None or end is None:
            if self.bucket2object[bucket_name].object_exists(oss_bin_key):
                wav_data = self._retry_operation(self.bucket2object[bucket_name].get_object, oss_bin_key).read()
            elif not force_bin:
                wav_data, _ = self.read_audio_file(oss_path)
            else:
                raise ValueError(f"Cannot find bin file for {oss_path}")
        else:
            bytes_per_second = WAV_SAMPLE_RATE * (WAV_BIT_RATE // 8)
            # Convert the time range to byte offsets in the int16 bin file.
            start_offset = round(start * bytes_per_second)
            end_offset = round(end * bytes_per_second)
            # The HTTP byte range is inclusive: an even offset difference means
            # an odd byte count, which would split an int16 sample — trim one.
            if not (end_offset - start_offset) % 2:
                end_offset -= 1
            # Range request: fetch only the requested slice.
            wav_data = self._retry_operation(self.bucket2object[bucket_name].get_object,
                                             oss_bin_key,
                                             byte_range=(start_offset, end_offset),
                                             headers={'x-oss-range-behavior': 'standard'}).read()
        if not isinstance(wav_data, np.ndarray):
            wav_data = np.frombuffer(wav_data, np.int16).flatten() / 32768.0
        return wav_data.astype(np.float32)

    def _list_files_by_suffix(self, oss_dir, suffix):
        """Recursively list objects under ``oss_dir`` whose key ends with ``suffix``."""
        bucket_name, dir_key = self._parse_oss_path(oss_dir)
        file_list = []

        def _recursive_list(prefix):
            for obj in oss2.ObjectIterator(self.bucket2object[bucket_name], prefix=prefix, delimiter='/'):
                if obj.is_prefix():  # a "directory": recurse into it
                    _recursive_list(obj.key)
                elif obj.key.endswith(suffix):
                    file_list.append(f"oss://{bucket_name}/{obj.key}")

        _recursive_list(dir_key)
        return file_list

    def list_files_by_suffix(self, oss_dir, suffix):
        """Retried wrapper around :meth:`_list_files_by_suffix`."""
        return self._retry_operation(self._list_files_by_suffix, oss_dir, suffix)

    def _list_files_by_prefix(self, oss_dir, file_prefix):
        """Recursively list objects under ``oss_dir`` whose *basename* starts
        with ``file_prefix``. (Docstring fixed: it previously claimed to match
        by suffix, copy-pasted from the method above.)"""
        bucket_name, dir_key = self._parse_oss_path(oss_dir)
        file_list = []

        def _recursive_list(prefix):
            for obj in oss2.ObjectIterator(self.bucket2object[bucket_name], prefix=prefix, delimiter='/'):
                if obj.is_prefix():  # a "directory": recurse into it
                    _recursive_list(obj.key)
                elif os.path.basename(obj.key).startswith(file_prefix):
                    file_list.append(f"oss://{bucket_name}/{obj.key}")

        _recursive_list(dir_key)
        return file_list

    def list_files_by_prefix(self, oss_dir, file_prefix):
        """Retried wrapper around :meth:`_list_files_by_prefix`."""
        return self._retry_operation(self._list_files_by_prefix, oss_dir, file_prefix)
256
+
257
+
258
def encode_base64(base64_path):
    """Return the base64 text encoding of the file at *base64_path*."""
    with open(base64_path, "rb") as fh:
        raw = fh.read()
    encoded = base64.b64encode(raw)
    return encoded.decode("utf-8")
261
+
262
+
263
def _load_model_processor(args):
    """Create the OpenAI-compatible DashScope client.

    Returns ``(client, None)``: inference happens server-side, so there is no
    local processor. ``args.cpu_only`` is accepted for CLI compatibility but
    has no effect in API mode (the previously computed ``device_map`` local
    was dead code left over from a local-inference variant and was removed).
    """
    # If the environment variable is not configured, replace the api_key line
    # with your Alibaba Cloud Model Studio key: api_key="sk-xxx".
    model = OpenAI(
        api_key=API_KEY,
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    return model, None
276
+
277
+
278
+ oss_reader = OSSReader()
279
+
280
+
281
def _launch_demo(args, model, processor):
    """Build and launch the Gradio demo UI for Qwen3-Omni.

    Args:
        args: Parsed CLI namespace (share / inbrowser / server_port /
            server_name / ui_language).
        model: OpenAI-compatible client from ``_load_model_processor``.
        processor: Unused in API mode; kept for signature compatibility.
    """
    # Voice settings: UI display label -> DashScope voice id.
    VOICE_OPTIONS = {
        "芊悦 Cherry": "Cherry",
        "晨煦 Ethan": "Ethan",
        "詹妮弗 Jennifer": "Jennifer",
        "甜茶 Ryan": "Ryan",
        "卡捷琳娜 Katerina": "Katerina",
        "不吃鱼 Nofish": "Nofish",
        "墨讲师 Elias": "Elias",
        "南京-老李 Li": "Li",
        "陕西-秦川 Marcus": "Marcus",
        "闽南-阿杰 Roy": "Roy",
        "天津-李彼得 Peter": "Peter",
        "四川-程川 Eric": "Eric",
        "粤语-阿强 Rocky": "Rocky",
        "粤语-阿清 Kiki": "Kiki",
        "四川-晴儿 Sunny": "Sunny",
        "上海-阿珍 Jada": "Jada",
        "北京-晓东 Dylan": "Dylan",
    }
    DEFAULT_VOICE = '芊悦 Cherry'

    default_system_prompt = ''

    language = args.ui_language

    def get_text(text: str, cn_text: str):
        # Pick the English or Chinese variant based on the configured UI language.
        if language == 'en':
            return text
        if language == 'zh':
            return cn_text
        return text

    def to_mp4(path):
        """Transcode a .webm recording to .mp4 (H.264/AAC); pass through otherwise."""
        import subprocess
        if path and path.endswith(".webm"):
            mp4_path = path.replace(".webm", ".mp4")
            subprocess.run([
                "ffmpeg", "-y",
                "-i", path,
                "-c:v", "libx264",       # H.264 video
                "-preset", "ultrafast",  # fastest encode
                "-tune", "fastdecode",   # optimize for fast decoding downstream
                "-pix_fmt", "yuv420p",   # widely compatible pixel format
                "-c:a", "aac",           # audio codec
                "-b:a", "128k",          # cap audio bitrate to speed up encode
                "-threads", "0",         # use all CPU threads
                "-f", "mp4",
                mp4_path
            ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            return mp4_path
        return path  # already mp4, or None

    def format_history(history: list, system_prompt: str):
        """Convert Gradio chat ``history`` into OpenAI-style ``messages``.

        Consecutive user turns (text + uploaded media tuples) are merged into a
        single user message; media files are uploaded to OSS and referenced by
        signed URL. Media items are ordered before text items in each message.
        Non-string assistant content (e.g. gr.Audio components) is dropped.
        """
        print(history)
        messages = []
        if system_prompt != "":
            messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})

        current_user_content = []

        for item in history:
            role = item['role']
            content = item['content']

            if role != "user":
                # Flush any accumulated user content before the assistant turn.
                if current_user_content:
                    messages.append({"role": "user", "content": current_user_content})
                    current_user_content = []

                if isinstance(content, str):
                    messages.append({
                        "role": role,
                        "content": [{"type": "text", "text": content}]
                    })
                else:
                    # Non-text assistant content (audio player etc.) is not sent back.
                    pass
                continue

            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, (list, tuple)):
                # Tuples hold uploaded file paths; classify each by MIME type.
                for file_path in content:
                    mime_type = client_utils.get_mimetype(file_path)
                    media_type = None

                    if mime_type.startswith("image"):
                        media_type = "image_url"
                    elif mime_type.startswith("video"):
                        media_type = "video_url"
                        file_path = to_mp4(file_path)  # API expects mp4, not webm
                    elif mime_type.startswith("audio"):
                        media_type = "input_audio"

                    if media_type:
                        # Upload under a random key and pass a signed URL to the API.
                        # base64_media = encode_base64(file_path)
                        import uuid
                        request_id = str(uuid.uuid4())
                        oss_path = f"oss://{bucket_name}//studio-temp/Qwen3-Omni-Demo/" + request_id
                        oss_reader.upload_file(file_path, oss_path)
                        media_url = oss_reader.get_public_url(oss_path)
                        if media_type == "input_audio":
                            current_user_content.append({
                                "type": "input_audio",
                                "input_audio": {
                                    "data": media_url,
                                    "format": "wav",
                                },
                            })
                        if media_type == "image_url":
                            current_user_content.append({
                                "type": "image_url",
                                "image_url": {
                                    "url": media_url
                                },
                            })
                        if media_type == "video_url":
                            current_user_content.append({
                                "type": "video_url",
                                "video_url": {
                                    "url": media_url
                                },
                            })
                    else:
                        # Unknown MIME type: fall back to sending the path as text.
                        current_user_content.append({
                            "type": "text",
                            "text": file_path
                        })

        # Flush the trailing user turn, media items first, then text items.
        if current_user_content:
            media_items = []
            text_items = []

            for item in current_user_content:
                if item["type"] == "text":
                    text_items.append(item)
                else:
                    media_items.append(item)

            messages.append({
                "role": "user",
                "content": media_items + text_items
            })

        return messages

    def predict(messages, voice_choice=DEFAULT_VOICE, temperature=0.7, top_p=0.8, top_k=20, return_audio=False,
                enable_thinking=False):
        """Stream a chat completion from the API.

        Yields ``{"type": "text", "data": str}`` chunks as text arrives and a
        final ``{"type": "audio", "data": path}`` chunk when audio output is
        enabled. Thinking mode forces text-only output (mutually exclusive
        with audio).
        """
        # print('predict history: ', messages)
        if enable_thinking:
            return_audio = False
        if return_audio:
            completion = model.chat.completions.create(
                model="qwen3-omni-flash",
                messages=messages,
                modalities=["text", "audio"],
                audio={"voice": VOICE_OPTIONS[voice_choice], "format": "wav"},
                extra_body={'enable_thinking': False, "top_k": top_k},
                stream_options={"include_usage": True},
                stream=True,
                temperature=temperature,
                top_p=top_p,
            )
        else:
            completion = model.chat.completions.create(
                model="qwen3-omni-flash",
                messages=messages,
                modalities=["text"],
                extra_body={'enable_thinking': enable_thinking, "top_k": top_k},
                stream_options={"include_usage": True},
                stream=True,
                temperature=temperature,
                top_p=top_p,
            )
        audio_string = ""
        output_text = ""
        reasoning_content = "<think>\n\n"  # accumulated chain-of-thought
        answer_content = ""                # accumulated final answer
        is_answering = False               # True once the answer phase starts
        print(return_audio, enable_thinking)
        for chunk in completion:
            if chunk.choices:
                if hasattr(chunk.choices[0].delta, "audio"):
                    try:
                        # Audio deltas normally carry base64 wav bytes in "data" ...
                        audio_string += chunk.choices[0].delta.audio["data"]
                    except Exception as e:
                        # ... otherwise they carry the text transcript instead.
                        output_text += chunk.choices[0].delta.audio["transcript"]
                        yield {"type": "text", "data": output_text}
                else:
                    delta = chunk.choices[0].delta
                    if enable_thinking:
                        if hasattr(delta, "reasoning_content") and delta.reasoning_content is not None:
                            if not is_answering:
                                print(delta.reasoning_content, end="", flush=True)
                            reasoning_content += delta.reasoning_content
                            yield {"type": "text", "data": reasoning_content}
                        if hasattr(delta, "content") and delta.content:
                            if not is_answering:
                                # First answer token: close the <think> block.
                                reasoning_content += "\n\n</think>\n\n"
                                is_answering = True
                            answer_content += delta.content
                            yield {"type": "text", "data": reasoning_content + answer_content}
                    else:
                        if hasattr(delta, "content") and delta.content:
                            output_text += chunk.choices[0].delta.content
                            yield {"type": "text", "data": output_text}
            else:
                # Final usage-only chunk (stream_options include_usage).
                print(chunk.usage)

        wav_bytes = base64.b64decode(audio_string)
        audio_np = np.frombuffer(wav_bytes, dtype=np.int16)

        if audio_string != "":
            # Wrap the raw 24 kHz int16 PCM in a WAV container and cache it for Gradio.
            wav_io = io.BytesIO()
            sf.write(wav_io, audio_np, samplerate=24000, format="WAV")
            wav_io.seek(0)
            wav_bytes = wav_io.getvalue()
            audio_path = processing_utils.save_bytes_to_cache(
                wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
            yield {"type": "audio", "data": audio_path}

    def media_predict(audio, video, history, system_prompt, voice_choice, temperature, top_p, top_k, return_audio=False,
                      enable_thinking=False):
        """Online-tab handler: stream a reply for a mic/webcam recording.

        Yields (microphone, webcam, chatbot, submit_btn, stop_btn) updates,
        toggling the submit/stop buttons while streaming.
        """
        # First yield: clear inputs and switch to the "stop" button.
        yield (
            None,  # microphone
            None,  # webcam
            history,  # media_chatbot
            gr.update(visible=False),  # submit_btn
            gr.update(visible=True),  # stop_btn
        )

        files = [audio, video]

        for f in files:
            if f:
                history.append({"role": "user", "content": (f,)})

        yield (
            None,  # microphone
            None,  # webcam
            history,  # media_chatbot
            gr.update(visible=True),  # submit_btn
            gr.update(visible=False),  # stop_btn
        )

        formatted_history = format_history(history=history,
                                           system_prompt=system_prompt, )

        history.append({"role": "assistant", "content": ""})

        for chunk in predict(formatted_history, voice_choice, temperature, top_p, top_k, return_audio, enable_thinking):
            print('chunk', chunk)
            if chunk["type"] == "text":
                # Overwrite the placeholder assistant message with the growing text.
                history[-1]["content"] = chunk["data"]
                yield (
                    None,  # microphone
                    None,  # webcam
                    history,  # media_chatbot
                    gr.update(visible=False),  # submit_btn
                    gr.update(visible=True),  # stop_btn
                )
            if chunk["type"] == "audio":
                # Audio arrives once, after text: append it as its own message.
                history.append({
                    "role": "assistant",
                    "content": gr.Audio(chunk["data"])
                })

        # Final yield: restore the submit button.
        yield (
            None,  # microphone
            None,  # webcam
            history,  # media_chatbot
            gr.update(visible=True),  # submit_btn
            gr.update(visible=False),  # stop_btn
        )

    def chat_predict(text, audio, image, video, history, system_prompt, voice_choice, temperature, top_p, top_k,
                     return_audio=False, enable_thinking=False):
        """Offline-tab handler: stream a reply for text/audio/image/video input.

        Yields (text, audio, image, video, chatbot) updates; the input
        components are cleared on the first yield and skipped afterwards.
        """

        # Process audio input
        if audio:
            history.append({"role": "user", "content": (audio,)})

        # Process text input
        if text:
            history.append({"role": "user", "content": text})

        # Process image input
        if image:
            history.append({"role": "user", "content": (image,)})

        # Process video input
        if video:
            history.append({"role": "user", "content": (video,)})

        formatted_history = format_history(history=history,
                                           system_prompt=system_prompt)

        yield None, None, None, None, history

        history.append({"role": "assistant", "content": ""})
        for chunk in predict(formatted_history, voice_choice, temperature, top_p, top_k, return_audio, enable_thinking):
            print('chat_predict chunk', chunk)

            if chunk["type"] == "text":
                history[-1]["content"] = chunk["data"]
                yield gr.skip(), gr.skip(), gr.skip(), gr.skip(
                ), history
            if chunk["type"] == "audio":
                history.append({
                    "role": "assistant",
                    "content": gr.Audio(chunk["data"])
                })
        yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history

    # --- CORRECTED UI LAYOUT ---
    with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]),
                   css=".gradio-container {max-width: none !important;}") as demo:
        gr.Markdown("# Qwen3-Omni Demo")
        gr.Markdown(
            "**Instructions**: Interact with the model through text, audio, images, or video. Use the tabs to switch between Online and Offline chat modes.")
        gr.Markdown(
            "**使用说明**:1️⃣ 点击音频录制按钮,或摄像头-录制按钮 2️⃣ 输入音频或者视频 3️⃣ 点击提交并等待模型的回答")

        with gr.Row(equal_height=False):
            # Left column: generation parameters shared by both tabs.
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Parameters (参数)")
                system_prompt_textbox = gr.Textbox(label="System Prompt", value=default_system_prompt, lines=4,
                                                   max_lines=8)
                # NOTE(review): a dict is passed as `choices`; Gradio iterates its
                # keys, so the display labels are also the returned values.
                voice_choice = gr.Dropdown(label="Voice Choice", choices=VOICE_OPTIONS, value=DEFAULT_VOICE,
                                           visible=True)
                return_audio = gr.Checkbox(
                    label="Return Audio (返回语音)",
                    value=True,
                    interactive=True,
                    elem_classes="checkbox-large"
                )
                enable_thinking = gr.Checkbox(
                    label="Enable Thinking (启用思维链)",
                    value=False,
                    interactive=True,
                    elem_classes="checkbox-large"
                )
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.6, step=0.1)
                top_p = gr.Slider(label="Top P", minimum=0.05, maximum=1.0, value=0.95, step=0.05)
                top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=20, step=1)

            # Right column: the two chat tabs.
            with gr.Column(scale=3):
                with gr.Tabs():
                    with gr.TabItem("Online"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Markdown("### Audio-Video Input (音视频输入)")
                                microphone = gr.Audio(sources=['microphone'], type="filepath",
                                                      label="Record Audio (录制音频)")
                                webcam = gr.Video(sources=['webcam', "upload"],
                                                  label="Record/Upload Video (录制/上传视频)",
                                                  elem_classes="media-upload")
                                with gr.Row():
                                    submit_btn_online = gr.Button("Submit (提交)", variant="primary", scale=2)
                                    stop_btn_online = gr.Button("Stop (停止)", visible=False, scale=1)
                                clear_btn_online = gr.Button("Clear History (清除历史)")
                            with gr.Column(scale=2):
                                # FIX: Re-added type="messages"
                                media_chatbot = gr.Chatbot(label="Chat History (对话历史)", type="messages", height=650,
                                                           layout="panel", bubble_full_width=False,
                                                           allow_tags=["think"], render=False)
                                media_chatbot.render()

                        def clear_history_online():
                            # Reset chat history and both input widgets.
                            return [], None, None

                        submit_event_online = submit_btn_online.click(
                            fn=media_predict,
                            inputs=[microphone, webcam, media_chatbot, system_prompt_textbox, voice_choice, temperature,
                                    top_p, top_k, return_audio, enable_thinking],
                            outputs=[microphone, webcam, media_chatbot, submit_btn_online, stop_btn_online]
                        )
                        # Stop cancels the running stream and restores the submit button.
                        stop_btn_online.click(fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
                                              outputs=[submit_btn_online, stop_btn_online],
                                              cancels=[submit_event_online], queue=False)
                        clear_btn_online.click(fn=clear_history_online, outputs=[media_chatbot, microphone, webcam])

                    with gr.TabItem("Offline"):
                        # FIX: Re-added type="messages"
                        chatbot = gr.Chatbot(label="Chat History (对话历史)", type="messages", height=550,
                                             layout="panel", bubble_full_width=False, allow_tags=["think"],
                                             render=False)
                        chatbot.render()

                        with gr.Accordion("📎 Click to upload multimodal files (点击上传多模态文件)", open=False):
                            with gr.Row():
                                audio_input = gr.Audio(sources=["upload", 'microphone'], type="filepath", label="Audio",
                                                       elem_classes="media-upload")
                                image_input = gr.Image(sources=["upload", 'webcam'], type="filepath", label="Image",
                                                       elem_classes="media-upload")
                                video_input = gr.Video(sources=["upload", 'webcam'], label="Video",
                                                       elem_classes="media-upload")

                        with gr.Row():
                            text_input = gr.Textbox(show_label=False,
                                                    placeholder="Enter text or upload files and press Submit... (输入文本或者上传文件并点击提交)",
                                                    scale=7)
                            submit_btn_offline = gr.Button("Submit (提交)", variant="primary", scale=1)
                            stop_btn_offline = gr.Button("Stop (停止)", visible=False, scale=1)
                            clear_btn_offline = gr.Button("Clear (清空) ", scale=1)

                        def clear_history_offline():
                            # Reset chat history and all four input widgets.
                            return [], None, None, None, None

                        submit_event_offline = gr.on(
                            triggers=[submit_btn_offline.click, text_input.submit],
                            fn=chat_predict,
                            inputs=[text_input, audio_input, image_input, video_input, chatbot, system_prompt_textbox,
                                    voice_choice, temperature, top_p, top_k, return_audio, enable_thinking],
                            outputs=[text_input, audio_input, image_input, video_input, chatbot]
                        )
                        stop_btn_offline.click(fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
                                               outputs=[submit_btn_offline, stop_btn_offline],
                                               cancels=[submit_event_offline], queue=False)
                        clear_btn_offline.click(fn=clear_history_offline,
                                                outputs=[chatbot, text_input, audio_input, image_input, video_input])

        # Dashed-border styling for the media upload widgets.
        gr.HTML("""
            <style>
            .media-upload { min-height: 160px; border: 2px dashed #ccc; border-radius: 8px; display: flex; align-items: center; justify-content: center; }
            .media-upload:hover { border-color: #666; }
            </style>
        """)

    demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100,
                                                                   ssr_mode=False,
                                                                   share=args.share,
                                                                   inbrowser=args.inbrowser,
                                                                   # ssl_certfile="examples/offline_inference/qwen3_omni_moe/cert.pem",
                                                                   # ssl_keyfile="examples/offline_inference/qwen3_omni_moe/key.pem",
                                                                   # ssl_verify=False,
                                                                   server_port=args.server_port,
                                                                   server_name=args.server_name, )
722
+
723
+
724
+ DEFAULT_CKPT_PATH = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
725
+
726
+
727
def _get_args(argv=None):
    """Parse command-line arguments for the demo.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (backward-compatible addition so the parser can be driven
            programmatically, e.g. in tests).

    Returns:
        argparse.Namespace with checkpoint_path, cpu_only, flash_attn2,
        use_transformers, share, inbrowser, server_port, server_name,
        ui_language.
    """
    parser = ArgumentParser()

    parser.add_argument('-c',
                        '--checkpoint-path',
                        type=str,
                        default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')

    # NOTE: the model-loading flags below are accepted for compatibility but
    # have no effect in API mode (see _load_model_processor).
    parser.add_argument('--flash-attn2',
                        action='store_true',
                        default=False,
                        help='Enable flash_attention_2 when loading the model.')
    parser.add_argument('--use-transformers',
                        action='store_true',
                        default=False,
                        help='Use transformers for inference.')
    parser.add_argument('--share',
                        action='store_true',
                        default=False,
                        help='Create a publicly shareable link for the interface.')
    parser.add_argument('--inbrowser',
                        action='store_true',
                        default=False,
                        help='Automatically launch the interface in a new tab on the default browser.')
    parser.add_argument('--server-port', type=int, default=8905, help='Demo server port.')
    parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Demo server name.')
    parser.add_argument('--ui-language', type=str, choices=['en', 'zh'], default='zh',
                        help='Display language for the UI.')

    args = parser.parse_args(argv)
    return args
760
+
761
+
762
if __name__ == "__main__":
    # CLI entry point: parse args, build the API client, then launch the UI.
    args = _get_args()
    model, processor = _load_model_processor(args)
    _launch_demo(args, model, processor)
766
+
767
+