littlebird13 committed
Commit ac25892 · verified · 1 Parent(s): d0e09ce

Update app.py

Files changed (1): app.py +499 -264
app.py CHANGED
@@ -1,6 +1,8 @@
 import io
 import os
 
 os.environ['VLLM_USE_V1'] = '0'
 os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
 from argparse import ArgumentParser
@@ -45,22 +47,33 @@ OSS_ACCESS_KEY_ID = os.environ['OSS_ACCESS_KEY_ID']
 OSS_ACCESS_KEY_SECRET = os.environ['OSS_ACCESS_KEY_SECRET']
 OSS_CONFIG_PATH = {}
 
 class OSSReader:
     def __init__(self):
         # Initialize the OSS configuration
         self.bucket2object = {
-            bucket_name: oss2.Bucket(oss2.Auth(OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET), endpoint, bucket_name),
         }
-        print(f"Loaded OSS config from: {OSS_CONFIG_PATH}\nSupported buckets: {list(self.bucket2object.keys())}")
-
     def _parse_oss_path(self, oss_path):
         """Parse an OSS path; return the bucket name and the actual object key."""
         assert oss_path.startswith("oss://"), f"Invalid oss path {oss_path}"
         bucket_name, object_key = oss_path.split("oss://")[-1].split("/", 1)
         object_key = f"studio-temp/Qwen3-Omni-Demo/{object_key}"
         return bucket_name, object_key
-
-    def _retry_operation(self, func, *args, retries=OSS_RETRY, delay=OSS_RETRY, **kwargs):
         """Generic retry helper."""
         for _ in range(retries):
             try:
@@ -70,23 +83,30 @@ class OSSReader:
                 if _ == retries - 1:
                     raise e
                 time.sleep(delay)
-
     def get_public_url(self, oss_path):
         bucket_name, object_key = self._parse_oss_path(oss_path)
-        url = self._retry_operation(self.bucket2object[bucket_name].sign_url, 'GET', object_key, 600,
-                                    slash_safe=True).replace('http://', 'https://')
         return url.replace("-internal", '')
-
     def file_exists(self, oss_path):
         """Check whether a file exists."""
         bucket_name, object_key = self._parse_oss_path(oss_path)
-        return self._retry_operation(self.bucket2object[bucket_name].object_exists, object_key)
-
     def download_file(self, oss_path, local_path):
         """Download a file from OSS to a local path."""
         bucket_name, object_key = self._parse_oss_path(oss_path)
-        self._retry_operation(self.bucket2object[bucket_name].get_object_to_file, object_key, local_path)
-
     def upload_file(self, local_path, oss_path, overwrite=True):
         """Upload a local file to OSS."""
         bucket_name, object_key = self._parse_oss_path(oss_path)
@@ -101,28 +121,30 @@ class OSSReader:
         try:
             self._retry_operation(
                 self.bucket2object[bucket_name].put_object_from_file,
-                object_key,
-                local_path
-            )
             return True
         except Exception as e:
             print(f"Upload failed: {str(e)}")
             return False
-
-    def upload_audio_from_array(self, data, sample_rate, oss_path, overwrite=True):
         """Encode the audio data as WAV and upload it to OSS."""
         bucket_name, object_key = self._parse_oss_path(oss_path)
-
         # Check whether the target file already exists (when overwrite=False)
         if not overwrite and self.file_exists(oss_path):
             print(f"File {oss_path} already exists, skip upload")
             return False
-
         try:
             # Build the WAV data in memory with BytesIO
             import wave
             from io import BytesIO
-
             byte_io = BytesIO()
             with wave.open(byte_io, 'wb') as wf:
                 wf.setnchannels(1)  # mono
@@ -132,49 +154,51 @@ class OSSReader:
                 data_int16 = np.clip(data, -1, 1) * 32767
                 data_int16 = data_int16.astype(np.int16)
                 wf.writeframes(data_int16.tobytes())
-
             # Upload to OSS
-            self._retry_operation(
-                self.bucket2object[bucket_name].put_object,
-                object_key,
-                byte_io.getvalue()
-            )
             return True
         except Exception as e:
             print(f"Upload failed: {str(e)}")
             return False
-
     def get_object(self, oss_path):
         """Fetch an object from OSS."""
         bucket_name, object_key = self._parse_oss_path(oss_path)
-        return self._retry_operation(self.bucket2object[bucket_name].get_object, object_key)
-
     def read_text_file(self, oss_path):
         """Read a text file from OSS."""
         bucket_name, object_key = self._parse_oss_path(oss_path)
-        result = self._retry_operation(self.bucket2object[bucket_name].get_object, object_key)
         return result.read().decode('utf-8')
-
     def read_audio_file(self, oss_path):
         """Read an audio file from OSS; return the audio data and sample rate."""
         bucket_name, object_key = self._parse_oss_path(oss_path)
-        result = self._retry_operation(self.bucket2object[bucket_name].get_object, object_key)
         # ffmpeg command: read audio from stdin and emit float PCM
         command = [
             'ffmpeg',
-            '-i', '-',  # input from a pipe
-            '-ar', str(WAV_SAMPLE_RATE),  # output sample rate
-            '-ac', '1',  # mono
-            '-f', 'f32le',  # output format
             '-'  # output to a pipe
         ]
         # Launch the ffmpeg subprocess
-        process = subprocess.Popen(
-            command,
-            stdin=subprocess.PIPE,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE
-        )
         # Write the audio bytes and collect the output
         stdout_data, stderr_data = process.communicate(input=result.read())
         if process.returncode != 0:
@@ -182,20 +206,27 @@ class OSSReader:
         # Convert the PCM data to a numpy array
         wav_data = np.frombuffer(stdout_data, dtype=np.float32)
         return wav_data, WAV_SAMPLE_RATE
-
     def get_wav_duration_from_bin(self, oss_path):
         oss_bin_path = oss_path + ".ar16k.bin"
         bucket_name, object_key = self._parse_oss_path(oss_bin_path)
-        metadata = self._retry_operation(self.bucket2object[bucket_name].get_object_meta, object_key)
         duration = float(metadata.headers['Content-Length']) / (16000 * 2)
         return duration
-
-    def read_wavdata_from_oss(self, oss_path, start=None, end=None, force_bin=False):
         bucket_name, object_key = self._parse_oss_path(oss_path)
         oss_bin_key = object_key + ".ar16k.bin"
         if start is None or end is None:
             if self.bucket2object[bucket_name].object_exists(oss_bin_key):
-                wav_data = self._retry_operation(self.bucket2object[bucket_name].get_object, oss_bin_key).read()
             elif not force_bin:
                 wav_data, _ = self.read_audio_file(oss_path)
             else:
@@ -208,49 +239,58 @@ class OSSReader:
             if not (end_offset - start_offset) % 2:
                 end_offset -= 1
             # Use a range request to fetch only the requested byte range
-            wav_data = self._retry_operation(self.bucket2object[bucket_name].get_object,
-                                             oss_bin_key,
-                                             byte_range=(start_offset, end_offset),
-                                             headers={'x-oss-range-behavior': 'standard'}).read()
         if not isinstance(wav_data, np.ndarray):
             wav_data = np.frombuffer(wav_data, np.int16).flatten() / 32768.0
         return wav_data.astype(np.float32)
-
     def _list_files_by_suffix(self, oss_dir, suffix):
         """Recursively find all files ending with the given suffix; return their OSS paths."""
         bucket_name, dir_key = self._parse_oss_path(oss_dir)
         file_list = []
-
         def _recursive_list(prefix):
-            for obj in oss2.ObjectIterator(self.bucket2object[bucket_name], prefix=prefix, delimiter='/'):
                 if obj.is_prefix():  # a directory: recurse into it
                     _recursive_list(obj.key)
                 elif obj.key.endswith(suffix):
                     file_list.append(f"oss://{bucket_name}/{obj.key}")
-
         _recursive_list(dir_key)
         return file_list
-
     def list_files_by_suffix(self, oss_dir, suffix):
-        return self._retry_operation(self._list_files_by_suffix, oss_dir, suffix)
-
     def _list_files_by_prefix(self, oss_dir, file_prefix):
         """Recursively find all files whose names start with the given prefix; return their OSS paths."""
         bucket_name, dir_key = self._parse_oss_path(oss_dir)
         file_list = []
-
         def _recursive_list(prefix):
-            for obj in oss2.ObjectIterator(self.bucket2object[bucket_name], prefix=prefix, delimiter='/'):
                 if obj.is_prefix():  # a directory: recurse into it
                     _recursive_list(obj.key)
                 elif os.path.basename(obj.key).startswith(file_prefix):
                     file_list.append(f"oss://{bucket_name}/{obj.key}")
-
         _recursive_list(dir_key)
         return file_list
-
     def list_files_by_prefix(self, oss_dir, file_prefix):
-        return self._retry_operation(self._list_files_by_prefix, oss_dir, file_prefix)
 
 
 def encode_base64(base64_path):
@@ -263,13 +303,13 @@ def _load_model_processor(args):
         device_map = 'cpu'
     else:
         device_map = 'auto'
-
     model = OpenAI(
         # If no environment variable is configured, replace the next line with your Alibaba Cloud Model Studio API key: api_key="sk-xxx",
         api_key=API_KEY,
         base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
     )
-
     return model, None
 
 
@@ -279,90 +319,150 @@ oss_reader = OSSReader()
 def _launch_demo(args, model, processor):
     # Voice settings
     VOICE_OPTIONS = {
-        "芊悦 Cherry": "Cherry",
-        "晨煦 Ethan": "Ethan",
-        "詹妮弗 Jennifer": "Jennifer",
-        "甜茶 Ryan": "Ryan",
-        "卡捷琳娜 Katerina": "Katerina",
-        "不吃鱼 Nofish": "Nofish",
-        "墨讲师 Elias": "Elias",
-        "南京-老李 Li": "Li",
-        "陕西-秦川 Marcus": "Marcus",
-        "闽南-阿杰 Roy": "Roy",
-        "天津-李彼得 Peter": "Peter",
-        "四川-程川 Eric": "Eric",
-        "粤语-阿强 Rocky": "Rocky",
-        "粤语-阿清 Kiki": "Kiki",
-        "四川-晴儿 Sunny": "Sunny",
-        "上海-阿珍 Jada": "Jada",
-        "北京-晓东 Dylan": "Dylan",
     }
     DEFAULT_VOICE = '芊悦 Cherry'
-
     default_system_prompt = ''
-
     language = args.ui_language
-
     def get_text(text: str, cn_text: str):
         if language == 'en':
             return text
         if language == 'zh':
             return cn_text
         return text
-
     def to_mp4(path):
         import subprocess
         if path and path.endswith(".webm"):
             mp4_path = path.replace(".webm", ".mp4")
-            subprocess.run([
-                "ffmpeg", "-y",
-                "-i", path,
-                "-c:v", "libx264",  # use H.264
-                "-preset", "ultrafast",  # fastest preset
-                "-tune", "fastdecode",  # tuned for fast decoding (helps later processing)
-                "-pix_fmt", "yuv420p",  # widely compatible pixel format
-                "-c:a", "aac",  # audio codec
-                "-b:a", "128k",  # optional: cap the audio bitrate for speed
-                "-threads", "0",  # use all threads
-                "-f", "mp4",
-                mp4_path
-            ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
             return mp4_path
         return path  # already mp4, or None
-
     def format_history(history: list, system_prompt: str):
         print(history)
         messages = []
         if system_prompt != "":
-            messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
-
         current_user_content = []
-
         for item in history:
             role = item['role']
             content = item['content']
-
             if role != "user":
                 if current_user_content:
-                    messages.append({"role": "user", "content": current_user_content})
                     current_user_content = []
-
                 if isinstance(content, str):
                     messages.append({
-                        "role": role,
-                        "content": [{"type": "text", "text": content}]
                     })
                 else:
                     pass
                 continue
-
             if isinstance(content, str):
                 current_user_content.append({"type": "text", "text": content})
             elif isinstance(content, (list, tuple)):
                 for file_path in content:
                     mime_type = client_utils.get_mimetype(file_path)
                     media_type = None
-
                     if mime_type.startswith("image"):
                         media_type = "image_url"
                     elif mime_type.startswith("video"):
@@ -370,7 +470,7 @@ def _launch_demo(args, model, processor):
                         file_path = to_mp4(file_path)
                     elif mime_type.startswith("audio"):
                         media_type = "input_audio"
-
                     if media_type:
                         # base64_media = encode_base64(file_path)
                         import uuid
@@ -405,36 +505,47 @@ def _launch_demo(args, model, processor):
                             "type": "text",
                             "text": file_path
                         })
-
         if current_user_content:
             media_items = []
             text_items = []
-
             for item in current_user_content:
                 if item["type"] == "text":
                     text_items.append(item)
                 else:
                     media_items.append(item)
-
             messages.append({
                 "role": "user",
                 "content": media_items + text_items
             })
-
         return messages
-
-    def predict(messages, voice_choice=DEFAULT_VOICE, temperature=0.7, top_p=0.8, top_k=20, return_audio=False,
                 enable_thinking=False):
        # print('predict history: ', messages)
        if enable_thinking:
-            return_audio=False
        if return_audio:
            completion = model.chat.completions.create(
-                model="qwen3-omni-flash",
                messages=messages,
                modalities=["text", "audio"],
-                audio={"voice": VOICE_OPTIONS[voice_choice], "format": "wav"},
-                extra_body={'enable_thinking': False, "top_k": top_k},
                stream_options={"include_usage": True},
                stream=True,
                temperature=temperature,
@@ -442,10 +553,13 @@ def _launch_demo(args, model, processor):
            )
        else:
            completion = model.chat.completions.create(
-                model="qwen3-omni-flash",
                messages=messages,
                modalities=["text"],
-                extra_body={'enable_thinking': enable_thinking, "top_k": top_k},
                stream_options={"include_usage": True},
                stream=True,
                temperature=temperature,
@@ -463,14 +577,18 @@ def _launch_demo(args, model, processor):
                    try:
                        audio_string += chunk.choices[0].delta.audio["data"]
                    except Exception as e:
-                        output_text += chunk.choices[0].delta.audio["transcript"]
                        yield {"type": "text", "data": output_text}
                else:
                    delta = chunk.choices[0].delta
                    if enable_thinking:
-                        if hasattr(delta, "reasoning_content") and delta.reasoning_content is not None:
                            if not is_answering:
-                                print(delta.reasoning_content, end="", flush=True)
                            reasoning_content += delta.reasoning_content
                            yield {"type": "text", "data": reasoning_content}
                        if hasattr(delta, "content") and delta.content:
@@ -478,17 +596,20 @@ def _launch_demo(args, model, processor):
                                reasoning_content += "\n\n</think>\n\n"
                                is_answering = True
                            answer_content += delta.content
-                            yield {"type": "text", "data": reasoning_content + answer_content}
                    else:
                        if hasattr(delta, "content") and delta.content:
                            output_text += chunk.choices[0].delta.content
                            yield {"type": "text", "data": output_text}
            else:
                print(chunk.usage)
-
        wav_bytes = base64.b64decode(audio_string)
        audio_np = np.frombuffer(wav_bytes, dtype=np.int16)
-
        if audio_string != "":
            wav_io = io.BytesIO()
            sf.write(wav_io, audio_np, samplerate=24000, format="WAV")
@@ -497,8 +618,16 @@ def _launch_demo(args, model, processor):
            audio_path = processing_utils.save_bytes_to_cache(
                wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
            yield {"type": "audio", "data": audio_path}
-
-    def media_predict(audio, video, history, system_prompt, voice_choice, temperature, top_p, top_k, return_audio=False,
                      enable_thinking=False):
        # First yield
        yield (
@@ -508,13 +637,13 @@ def _launch_demo(args, model, processor):
            gr.update(visible=False),  # submit_btn
            gr.update(visible=True),  # stop_btn
        )
-
        files = [audio, video]
-
        for f in files:
            if f:
-                history.append({"role": "user", "content": (f,)})
-
        yield (
            None,  # microphone
            None,  # webcam
@@ -522,13 +651,16 @@ def _launch_demo(args, model, processor):
            gr.update(visible=True),  # submit_btn
            gr.update(visible=False),  # stop_btn
        )
-
-        formatted_history = format_history(history=history,
-                                           system_prompt=system_prompt, )
-
        history.append({"role": "assistant", "content": ""})
-
-        for chunk in predict(formatted_history, voice_choice, temperature, top_p, top_k, return_audio, enable_thinking):
            print('chunk', chunk)
            if chunk["type"] == "text":
                history[-1]["content"] = chunk["data"]
@@ -544,7 +676,7 @@ def _launch_demo(args, model, processor):
                    "role": "assistant",
                    "content": gr.Audio(chunk["data"])
                })
-
        # Final yield
        yield (
            None,  # microphone
@@ -553,170 +685,259 @@ def _launch_demo(args, model, processor):
            gr.update(visible=True),  # submit_btn
            gr.update(visible=False),  # stop_btn
        )
-
-    def chat_predict(text, audio, image, video, history, system_prompt, voice_choice, temperature, top_p, top_k,
-                     return_audio=False, enable_thinking=False):
-
        # Process audio input
        if audio:
-            history.append({"role": "user", "content": (audio,)})
-
        # Process text input
        if text:
            history.append({"role": "user", "content": text})
-
        # Process image input
        if image:
-            history.append({"role": "user", "content": (image,)})
-
        # Process video input
        if video:
-            history.append({"role": "user", "content": (video,)})
-
        formatted_history = format_history(history=history,
                                           system_prompt=system_prompt)
-
        yield None, None, None, None, history
-
        history.append({"role": "assistant", "content": ""})
-        for chunk in predict(formatted_history, voice_choice, temperature, top_p, top_k, return_audio, enable_thinking):
            print('chat_predict chunk', chunk)
-
            if chunk["type"] == "text":
                history[-1]["content"] = chunk["data"]
-                yield gr.skip(), gr.skip(), gr.skip(), gr.skip(
-                ), history
            if chunk["type"] == "audio":
                history.append({
                    "role": "assistant",
                    "content": gr.Audio(chunk["data"])
                })
            yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history
-
    # --- CORRECTED UI LAYOUT ---
-    with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]),
-                   css=".gradio-container {max-width: none !important;}") as demo:
        gr.Markdown("# Qwen3-Omni Demo")
        gr.Markdown(
-            "**Instructions**: Interact with the model through text, audio, images, or video. Use the tabs to switch between Online and Offline chat modes.")
        gr.Markdown(
            "**使用说明**:1️⃣ 点击音频录制按钮,或摄像头-录制按钮 2️⃣ 输入音频或者视频 3️⃣ 点击提交并等待模型的回答")
-
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Parameters (参数)")
-                system_prompt_textbox = gr.Textbox(label="System Prompt", value=default_system_prompt, lines=4,
                                                   max_lines=8)
-                voice_choice = gr.Dropdown(label="Voice Choice", choices=VOICE_OPTIONS, value=DEFAULT_VOICE,
                                           visible=True)
-                return_audio = gr.Checkbox(
-                    label="Return Audio (返回语音)",
-                    value=True,
-                    interactive=True,
-                    elem_classes="checkbox-large"
-                )
-                enable_thinking = gr.Checkbox(
-                    label="Enable Thinking (启用思维链)",
-                    value=False,
-                    interactive=True,
-                    elem_classes="checkbox-large"
-                )
-                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.6, step=0.1)
-                top_p = gr.Slider(label="Top P", minimum=0.05, maximum=1.0, value=0.95, step=0.05)
-                top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=20, step=1)
-
            with gr.Column(scale=3):
                with gr.Tabs():
                    with gr.TabItem("Online"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Markdown("### Audio-Video Input (音视频输入)")
-                                microphone = gr.Audio(sources=['microphone'], type="filepath",
-                                                      label="Record Audio (录制音频)")
-                                webcam = gr.Video(sources=['webcam', "upload"],
-                                                  label="Record/Upload Video (录制/上传视频)",
-                                                  elem_classes="media-upload")
                                with gr.Row():
-                                    submit_btn_online = gr.Button("Submit (提交)", variant="primary", scale=2)
-                                    stop_btn_online = gr.Button("Stop (停止)", visible=False, scale=1)
-                                    clear_btn_online = gr.Button("Clear History (清除历史)")
                            with gr.Column(scale=2):
                                # FIX: Re-added type="messages"
-                                media_chatbot = gr.Chatbot(label="Chat History (对话历史)", type="messages", height=650,
-                                                           layout="panel", bubble_full_width=False,
-                                                           allow_tags=["think"], render=False)
                                media_chatbot.render()
-
                        def clear_history_online():
                            return [], None, None
-
                        submit_event_online = submit_btn_online.click(
                            fn=media_predict,
-                            inputs=[microphone, webcam, media_chatbot, system_prompt_textbox, voice_choice, temperature,
-                                    top_p, top_k, return_audio, enable_thinking],
-                            outputs=[microphone, webcam, media_chatbot, submit_btn_online, stop_btn_online]
-                        )
-                        stop_btn_online.click(fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
-                                              outputs=[submit_btn_online, stop_btn_online],
-                                              cancels=[submit_event_online], queue=False)
-                        clear_btn_online.click(fn=clear_history_online, outputs=[media_chatbot, microphone, webcam])
-
                    with gr.TabItem("Offline"):
                        # FIX: Re-added type="messages"
-                        chatbot = gr.Chatbot(label="Chat History (对话历史)", type="messages", height=550,
-                                             layout="panel", bubble_full_width=False, allow_tags=["think"],
                                             render=False)
                        chatbot.render()
-
-                        with gr.Accordion("📎 Click to upload multimodal files (点击上传多模态文件)", open=False):
                            with gr.Row():
-                                audio_input = gr.Audio(sources=["upload", 'microphone'], type="filepath", label="Audio",
-                                                       elem_classes="media-upload")
-                                image_input = gr.Image(sources=["upload", 'webcam'], type="filepath", label="Image",
-                                                       elem_classes="media-upload")
-                                video_input = gr.Video(sources=["upload", 'webcam'], label="Video",
-                                                       elem_classes="media-upload")
-
                        with gr.Row():
-                            text_input = gr.Textbox(show_label=False,
-                                                    placeholder="Enter text or upload files and press Submit... (输入文本或者上传文件并点击提交)",
-                                                    scale=7)
-                            submit_btn_offline = gr.Button("Submit (提交)", variant="primary", scale=1)
-                            stop_btn_offline = gr.Button("Stop (停止)", visible=False, scale=1)
-                            clear_btn_offline = gr.Button("Clear (清空) ", scale=1)
-
                        def clear_history_offline():
                            return [], None, None, None, None
-
                        submit_event_offline = gr.on(
-                            triggers=[submit_btn_offline.click, text_input.submit],
                            fn=chat_predict,
-                            inputs=[text_input, audio_input, image_input, video_input, chatbot, system_prompt_textbox,
-                                    voice_choice, temperature, top_p, top_k, return_audio, enable_thinking],
-                            outputs=[text_input, audio_input, image_input, video_input, chatbot]
-                        )
-                        stop_btn_offline.click(fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
-                                               outputs=[submit_btn_offline, stop_btn_offline],
-                                               cancels=[submit_event_offline], queue=False)
                        clear_btn_offline.click(fn=clear_history_offline,
-                                                outputs=[chatbot, text_input, audio_input, image_input, video_input])
-
        gr.HTML("""
        <style>
        .media-upload { min-height: 160px; border: 2px dashed #ccc; border-radius: 8px; display: flex; align-items: center; justify-content: center; }
        .media-upload:hover { border-color: #666; }
        </style>
        """)
-
-    demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100,
-                                                                   ssr_mode=False,
-                                                                   share=args.share,
-                                                                   inbrowser=args.inbrowser,
-                                                                   # ssl_certfile="examples/offline_inference/qwen3_omni_moe/cert.pem",
-                                                                   # ssl_keyfile="examples/offline_inference/qwen3_omni_moe/key.pem",
-                                                                   # ssl_verify=False,
-                                                                   server_port=args.server_port,
-                                                                   server_name=args.server_name, )
 
 
 DEFAULT_CKPT_PATH = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
@@ -724,35 +945,51 @@ DEFAULT_CKPT_PATH = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
 
 def _get_args():
     parser = ArgumentParser()
-
    parser.add_argument('-c',
                        '--checkpoint-path',
                        type=str,
                        default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
-    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
-
-    parser.add_argument('--flash-attn2',
                        action='store_true',
-                        default=False,
-                        help='Enable flash_attention_2 when loading the model.')
    parser.add_argument('--use-transformers',
                        action='store_true',
                        default=False,
                        help='Use transformers for inference.')
-    parser.add_argument('--share',
-                        action='store_true',
-                        default=False,
-                        help='Create a publicly shareable link for the interface.')
-    parser.add_argument('--inbrowser',
-                        action='store_true',
-                        default=False,
-                        help='Automatically launch the interface in a new tab on the default browser.')
-    parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
-    parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Demo server name.')
-    parser.add_argument('--ui-language', type=str, choices=['en', 'zh'], default='zh',
                        help='Display language for the UI.')
-
    args = parser.parse_args()
    return args
@@ -761,5 +998,3 @@ if __name__ == "__main__":
    args = _get_args()
    model, processor = _load_model_processor(args)
    _launch_demo(args, model, processor)
-
-
import io
import os

+import torch
+
os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
from argparse import ArgumentParser

OSS_ACCESS_KEY_SECRET = os.environ['OSS_ACCESS_KEY_SECRET']
OSS_CONFIG_PATH = {}

+
class OSSReader:
+
    def __init__(self):
        # Initialize the OSS configuration
        self.bucket2object = {
+            bucket_name:
+            oss2.Bucket(oss2.Auth(OSS_ACCESS_KEY_ID, OSS_ACCESS_KEY_SECRET),
+                        endpoint, bucket_name),
        }
+        print(
+            f"Loaded OSS config from: {OSS_CONFIG_PATH}\nSupported buckets: {list(self.bucket2object.keys())}"
+        )
+
    def _parse_oss_path(self, oss_path):
        """Parse an OSS path; return the bucket name and the actual object key."""
        assert oss_path.startswith("oss://"), f"Invalid oss path {oss_path}"
        bucket_name, object_key = oss_path.split("oss://")[-1].split("/", 1)
        object_key = f"studio-temp/Qwen3-Omni-Demo/{object_key}"
        return bucket_name, object_key
+
+    def _retry_operation(self,
+                         func,
+                         *args,
+                         retries=OSS_RETRY,
+                         delay=OSS_RETRY,
+                         **kwargs):
        """Generic retry helper."""
        for _ in range(retries):
            try:

                if _ == retries - 1:
                    raise e
                time.sleep(delay)
+
    def get_public_url(self, oss_path):
        bucket_name, object_key = self._parse_oss_path(oss_path)
+        url = self._retry_operation(self.bucket2object[bucket_name].sign_url,
+                                    'GET',
+                                    object_key,
+                                    600,
+                                    slash_safe=True).replace(
+                                        'http://', 'https://')
        return url.replace("-internal", '')
+
    def file_exists(self, oss_path):
        """Check whether a file exists."""
        bucket_name, object_key = self._parse_oss_path(oss_path)
+        return self._retry_operation(
+            self.bucket2object[bucket_name].object_exists, object_key)
+
    def download_file(self, oss_path, local_path):
        """Download a file from OSS to a local path."""
        bucket_name, object_key = self._parse_oss_path(oss_path)
+        self._retry_operation(
+            self.bucket2object[bucket_name].get_object_to_file, object_key,
+            local_path)
+
    def upload_file(self, local_path, oss_path, overwrite=True):
        """Upload a local file to OSS."""
        bucket_name, object_key = self._parse_oss_path(oss_path)

        try:
            self._retry_operation(
                self.bucket2object[bucket_name].put_object_from_file,
+                object_key, local_path)
            return True
        except Exception as e:
            print(f"Upload failed: {str(e)}")
            return False
+
+    def upload_audio_from_array(self,
+                                data,
+                                sample_rate,
+                                oss_path,
+                                overwrite=True):
        """Encode the audio data as WAV and upload it to OSS."""
        bucket_name, object_key = self._parse_oss_path(oss_path)
+
        # Check whether the target file already exists (when overwrite=False)
        if not overwrite and self.file_exists(oss_path):
            print(f"File {oss_path} already exists, skip upload")
            return False
+
        try:
            # Build the WAV data in memory with BytesIO
            import wave
            from io import BytesIO
+
            byte_io = BytesIO()
            with wave.open(byte_io, 'wb') as wf:
                wf.setnchannels(1)  # mono

                data_int16 = np.clip(data, -1, 1) * 32767
                data_int16 = data_int16.astype(np.int16)
                wf.writeframes(data_int16.tobytes())
+
            # Upload to OSS
+            self._retry_operation(self.bucket2object[bucket_name].put_object,
+                                  object_key, byte_io.getvalue())
            return True
        except Exception as e:
            print(f"Upload failed: {str(e)}")
            return False
+
    def get_object(self, oss_path):
        """Fetch an object from OSS."""
        bucket_name, object_key = self._parse_oss_path(oss_path)
+        return self._retry_operation(
+            self.bucket2object[bucket_name].get_object, object_key)
+
    def read_text_file(self, oss_path):
        """Read a text file from OSS."""
        bucket_name, object_key = self._parse_oss_path(oss_path)
+        result = self._retry_operation(
+            self.bucket2object[bucket_name].get_object, object_key)
        return result.read().decode('utf-8')
+
    def read_audio_file(self, oss_path):
        """Read an audio file from OSS; return the audio data and sample rate."""
        bucket_name, object_key = self._parse_oss_path(oss_path)
+        result = self._retry_operation(
+            self.bucket2object[bucket_name].get_object, object_key)
        # ffmpeg command: read audio from stdin and emit float PCM
        command = [
            'ffmpeg',
+            '-i',
+            '-',  # input from a pipe
+            '-ar',
+            str(WAV_SAMPLE_RATE),  # output sample rate
+            '-ac',
+            '1',  # mono
+            '-f',
+            'f32le',  # output format
            '-'  # output to a pipe
        ]
        # Launch the ffmpeg subprocess
+        process = subprocess.Popen(command,
+                                   stdin=subprocess.PIPE,
+                                   stdout=subprocess.PIPE,
+                                   stderr=subprocess.PIPE)
        # Write the audio bytes and collect the output
        stdout_data, stderr_data = process.communicate(input=result.read())
        if process.returncode != 0:

        # Convert the PCM data to a numpy array
        wav_data = np.frombuffer(stdout_data, dtype=np.float32)
        return wav_data, WAV_SAMPLE_RATE
+
    def get_wav_duration_from_bin(self, oss_path):
        oss_bin_path = oss_path + ".ar16k.bin"
        bucket_name, object_key = self._parse_oss_path(oss_bin_path)
+        metadata = self._retry_operation(
+            self.bucket2object[bucket_name].get_object_meta, object_key)
        duration = float(metadata.headers['Content-Length']) / (16000 * 2)
        return duration
+
+    def read_wavdata_from_oss(self,
+                              oss_path,
+                              start=None,
+                              end=None,
+                              force_bin=False):
        bucket_name, object_key = self._parse_oss_path(oss_path)
        oss_bin_key = object_key + ".ar16k.bin"
        if start is None or end is None:
            if self.bucket2object[bucket_name].object_exists(oss_bin_key):
+                wav_data = self._retry_operation(
+                    self.bucket2object[bucket_name].get_object,
+                    oss_bin_key).read()
            elif not force_bin:
                wav_data, _ = self.read_audio_file(oss_path)
            else:

            if not (end_offset - start_offset) % 2:
                end_offset -= 1
            # Use a range request to fetch only the requested byte range
+            wav_data = self._retry_operation(
+                self.bucket2object[bucket_name].get_object,
+                oss_bin_key,
+                byte_range=(start_offset, end_offset),
+                headers={
+                    'x-oss-range-behavior': 'standard'
+                }).read()
        if not isinstance(wav_data, np.ndarray):
            wav_data = np.frombuffer(wav_data, np.int16).flatten() / 32768.0
        return wav_data.astype(np.float32)
+
    def _list_files_by_suffix(self, oss_dir, suffix):
        """Recursively find all files ending with the given suffix; return their OSS paths."""
        bucket_name, dir_key = self._parse_oss_path(oss_dir)
        file_list = []
+
        def _recursive_list(prefix):
+            for obj in oss2.ObjectIterator(self.bucket2object[bucket_name],
+                                           prefix=prefix,
+                                           delimiter='/'):
                if obj.is_prefix():  # a directory: recurse into it
                    _recursive_list(obj.key)
                elif obj.key.endswith(suffix):
                    file_list.append(f"oss://{bucket_name}/{obj.key}")
+
        _recursive_list(dir_key)
        return file_list
+
    def list_files_by_suffix(self, oss_dir, suffix):
+        return self._retry_operation(self._list_files_by_suffix, oss_dir,
+                                     suffix)
+
    def _list_files_by_prefix(self, oss_dir, file_prefix):
        """Recursively find all files whose names start with the given prefix; return their OSS paths."""
        bucket_name, dir_key = self._parse_oss_path(oss_dir)
        file_list = []
+
        def _recursive_list(prefix):
+            for obj in oss2.ObjectIterator(self.bucket2object[bucket_name],
+                                           prefix=prefix,
+                                           delimiter='/'):
                if obj.is_prefix():  # a directory: recurse into it
                    _recursive_list(obj.key)
                elif os.path.basename(obj.key).startswith(file_prefix):
                    file_list.append(f"oss://{bucket_name}/{obj.key}")
+
        _recursive_list(dir_key)
        return file_list
+
    def list_files_by_prefix(self, oss_dir, file_prefix):
+        return self._retry_operation(self._list_files_by_prefix, oss_dir,
+                                     file_prefix)


def encode_base64(base64_path):

        device_map = 'cpu'
    else:
        device_map = 'auto'
+
    model = OpenAI(
        # If no environment variable is configured, replace the next line with your Alibaba Cloud Model Studio API key: api_key="sk-xxx",
        api_key=API_KEY,
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
+
    return model, None

def _launch_demo(args, model, processor):
    # Voice settings
    VOICE_OPTIONS = {
+        "Cherry / 芊悦": "Cherry",
+        "Serena / 苏瑶": "Serena",
+        "Ethan / 晨煦": "Ethan",
+        "Chelsie / 千雪": "Chelsie",
+        "Momo / 茉兔": "Momo",
+        "Vivian / 十三": "Vivian",
+        "Moon / 月白": "Moon",
+        "Maia / 四月": "Maia",
+        "Kai / 凯": "Kai",
+        "Nofish / 不吃鱼": "Nofish",
+        "Bella / 萌宝": "Bella",
+        "Jennifer / 詹妮弗": "Jennifer",
+        "Ryan / 甜茶": "Ryan",
+        "Katerina / 卡捷琳娜": "Katerina",
+        "Aiden / 艾登": "Aiden",
+        "Bodega / 西班牙语-博德加": "Bodega",
+        "Alek / 俄语-阿列克": "Alek",
+        "Dolce / 意大利语-多尔切": "Dolce",
+        "Sohee / 韩语-素熙": "Sohee",
+        "Ono Anna / 日语-小野杏": "Ono Anna",
+        "Lenn / 德语-莱恩": "Lenn",
+        "Sonrisa / 西班牙语拉美-索尼莎": "Sonrisa",
+        "Emilien / 法语-埃米尔安": "Emilien",
+        "Andre / 葡萄牙语欧-安德雷": "Andre",
+        "Radio Gol / 葡萄牙语巴-拉迪奥·戈尔": "Radio Gol",
+        "Eldric Sage / 精品百人-沧明子": "Eldric Sage",
+        "Mia / 精品百人-乖小妹": "Mia",
+        "Mochi / 精品百人-沙小弥": "Mochi",
+        "Bellona / 精品百人-燕铮莺": "Bellona",
+        "Vincent / 精品百人-田叔": "Vincent",
+        "Bunny / 精品百人-萌小姬": "Bunny",
+        "Neil / 精品百人-阿闻": "Neil",
+        "Elias / 墨讲师": "Elias",
+        "Arthur / 精品百人-徐大爷": "Arthur",
+        "Nini / 精品百人-邻家妹妹": "Nini",
+        "Ebona / 精品百人-诡婆婆": "Ebona",
+        "Seren / 精品百人-小婉": "Seren",
+        "Pip / 精品百人-调皮小新": "Pip",
+        "Stella / 精品百人-美少女阿月": "Stella",
+        "Li / 南京-老李": "Li",
+        "Marcus / 陕西-秦川": "Marcus",
+        "Roy / 闽南-阿杰": "Roy",
+        "Peter / 天津-李彼得": "Peter",
+        "Eric / 四川-程川": "Eric",
+        "Rocky / 粤语-阿强": "Rocky",
+        "Kiki / 粤语-阿清": "Kiki",
+        "Sunny / 四川-晴儿": "Sunny",
+        "Jada / 上海-阿珍": "Jada",
+        "Dylan / 北京-晓东": "Dylan",
    }
    DEFAULT_VOICE = '芊悦 Cherry'
+
    default_system_prompt = ''
+
    language = args.ui_language
+
    def get_text(text: str, cn_text: str):
        if language == 'en':
            return text
        if language == 'zh':
            return cn_text
        return text
+
    def to_mp4(path):
        import subprocess
        if path and path.endswith(".webm"):
            mp4_path = path.replace(".webm", ".mp4")
+            subprocess.run(
+                [
+                    "ffmpeg",
+                    "-y",
+                    "-i",
+                    path,
+                    "-c:v",
+                    "libx264",  # use H.264
+                    "-preset",
+                    "ultrafast",  # fastest preset
+                    "-tune",
+                    "fastdecode",  # tuned for fast decoding (helps later processing)
+                    "-pix_fmt",
+                    "yuv420p",  # widely compatible pixel format
+                    "-c:a",
+                    "aac",  # audio codec
+                    "-b:a",
+                    "128k",  # optional: cap the audio bitrate for speed
+                    "-threads",
+                    "0",  # use all threads
+                    "-f",
+                    "mp4",
+                    mp4_path
+                ],
+                check=True,
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL)
            return mp4_path
        return path  # already mp4, or None
+
    def format_history(history: list, system_prompt: str):
        print(history)
        messages = []
        if system_prompt != "":
+            messages.append({
+                "role":
+                "system",
+                "content": [{
+                    "type": "text",
+                    "text": system_prompt
+                }]
+            })
+
        current_user_content = []
+
        for item in history:
            role = item['role']
            content = item['content']
+
            if role != "user":
                if current_user_content:
+                    messages.append({
+                        "role": "user",
+                        "content": current_user_content
+                    })
                    current_user_content = []
+
                if isinstance(content, str):
                    messages.append({
+                        "role":
+                        role,
+                        "content": [{
+                            "type": "text",
+                            "text": content
+                        }]
                    })
                else:
                    pass
                continue
+
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, (list, tuple)):
                for file_path in content:
                    mime_type = client_utils.get_mimetype(file_path)
                    media_type = None
+
                    if mime_type.startswith("image"):
                        media_type = "image_url"
                    elif mime_type.startswith("video"):

                        file_path = to_mp4(file_path)
                    elif mime_type.startswith("audio"):
                        media_type = "input_audio"
+
                    if media_type:
                        # base64_media = encode_base64(file_path)
                        import uuid

                            "type": "text",
                            "text": file_path
                        })
+
        if current_user_content:
            media_items = []
            text_items = []
+
            for item in current_user_content:
                if item["type"] == "text":
                    text_items.append(item)
                else:
                    media_items.append(item)
+
            messages.append({
                "role": "user",
                "content": media_items + text_items
            })
+
        return messages
+
+    def predict(messages,
+                voice_choice=DEFAULT_VOICE,
+                temperature=0.7,
+                top_p=0.8,
+                top_k=20,
+                return_audio=False,
                enable_thinking=False):
        # print('predict history: ', messages)
        if enable_thinking:
+            return_audio = False
        if return_audio:
            completion = model.chat.completions.create(
+                model="qwen3-omni-flash-2025-12-01",
                messages=messages,
                modalities=["text", "audio"],
+                audio={
+                    "voice": VOICE_OPTIONS[voice_choice],
+                    "format": "wav"
+                },
+                extra_body={
+                    'enable_thinking': False,
+                    "top_k": top_k
+                },
                stream_options={"include_usage": True},
                stream=True,
                temperature=temperature,

            )
        else:
            completion = model.chat.completions.create(
+                model="qwen3-omni-flash-2025-12-01",
                messages=messages,
                modalities=["text"],
+                extra_body={
+                    'enable_thinking': enable_thinking,
+                    "top_k": top_k
+                },
                stream_options={"include_usage": True},
                stream=True,
                temperature=temperature,

                    try:
                        audio_string += chunk.choices[0].delta.audio["data"]
                    except Exception as e:
+                        output_text += chunk.choices[0].delta.audio[
+                            "transcript"]
                        yield {"type": "text", "data": output_text}
                else:
                    delta = chunk.choices[0].delta
                    if enable_thinking:
+                        if hasattr(delta, "reasoning_content"
+                                   ) and delta.reasoning_content is not None:
                            if not is_answering:
+                                print(delta.reasoning_content,
+                                      end="",
+                                      flush=True)
                            reasoning_content += delta.reasoning_content
                            yield {"type": "text", "data": reasoning_content}
                        if hasattr(delta, "content") and delta.content:

                                reasoning_content += "\n\n</think>\n\n"
                                is_answering = True
                            answer_content += delta.content
+                            yield {
+                                "type": "text",
+                                "data": reasoning_content + answer_content
+                            }
                    else:
                        if hasattr(delta, "content") and delta.content:
                            output_text += chunk.choices[0].delta.content
                            yield {"type": "text", "data": output_text}
            else:
                print(chunk.usage)
+
        wav_bytes = base64.b64decode(audio_string)
        audio_np = np.frombuffer(wav_bytes, dtype=np.int16)
+
        if audio_string != "":
            wav_io = io.BytesIO()
            sf.write(wav_io, audio_np, samplerate=24000, format="WAV")

            audio_path = processing_utils.save_bytes_to_cache(
                wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
            yield {"type": "audio", "data": audio_path}
+
+    def media_predict(audio,
+                      video,
+                      history,
+                      system_prompt,
+                      voice_choice,
+                      temperature,
+                      top_p,
+                      top_k,
+                      return_audio=False,
                      enable_thinking=False):
        # First yield
        yield (

            gr.update(visible=False),  # submit_btn
            gr.update(visible=True),  # stop_btn
        )
+
        files = [audio, video]
+
        for f in files:
            if f:
+                history.append({"role": "user", "content": (f, )})
+
        yield (
            None,  # microphone
            None,  # webcam

            gr.update(visible=True),  # submit_btn
            gr.update(visible=False),  # stop_btn
        )
+
+        formatted_history = format_history(
+            history=history,
+            system_prompt=system_prompt,
+        )
+
        history.append({"role": "assistant", "content": ""})
+
+        for chunk in predict(formatted_history, voice_choice, temperature,
+                             top_p, top_k, return_audio, enable_thinking):
            print('chunk', chunk)
            if chunk["type"] == "text":
                history[-1]["content"] = chunk["data"]

                    "role": "assistant",
                    "content": gr.Audio(chunk["data"])
                })
+
        # Final yield
        yield (
            None,  # microphone

            gr.update(visible=True),  # submit_btn
            gr.update(visible=False),  # stop_btn
        )
+
+    def chat_predict(text,
+                     audio,
+                     image,
+                     video,
+                     history,
+                     system_prompt,
+                     voice_choice,
+                     temperature,
+                     top_p,
+                     top_k,
+                     return_audio=False,
+                     enable_thinking=False):
+
        # Process audio input
        if audio:
+            history.append({"role": "user", "content": (audio, )})
+
        # Process text input
        if text:
            history.append({"role": "user", "content": text})
+
        # Process image input
        if image:
+            history.append({"role": "user", "content": (image, )})
+
        # Process video input
        if video:
+            history.append({"role": "user", "content": (video, )})
+
        formatted_history = format_history(history=history,
                                           system_prompt=system_prompt)
+
        yield None, None, None, None, history
+
        history.append({"role": "assistant", "content": ""})
+        for chunk in predict(formatted_history, voice_choice, temperature,
+                             top_p, top_k, return_audio, enable_thinking):
            print('chat_predict chunk', chunk)
+
            if chunk["type"] == "text":
                history[-1]["content"] = chunk["data"]
+                yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history
            if chunk["type"] == "audio":
                history.append({
                    "role": "assistant",
                    "content": gr.Audio(chunk["data"])
                })
            yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history
+
    # --- CORRECTED UI LAYOUT ---
+    with gr.Blocks(
+            theme=gr.themes.Soft(font=[
+                gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"
+            ]),
+            css=".gradio-container {max-width: none !important;}") as demo:
        gr.Markdown("# Qwen3-Omni Demo")
        gr.Markdown(
+            "**Instructions**: Interact with the model through text, audio, images, or video. Use the tabs to switch between Online and Offline chat modes."
+        )
        gr.Markdown(
            "**使用说明**:1️⃣ 点击音频录制按钮,或摄像头-录制按钮 2️⃣ 输入音频或者视频 3️⃣ 点击提交并等待模型的回答")
+
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Parameters (参数)")
+                system_prompt_textbox = gr.Textbox(label="System Prompt",
+                                                   value=default_system_prompt,
+                                                   lines=4,
                                                   max_lines=8)
+                voice_choice = gr.Dropdown(label="Voice Choice",
+                                           choices=VOICE_OPTIONS,
+                                           value=DEFAULT_VOICE,
                                           visible=True)
+                return_audio = gr.Checkbox(label="Return Audio (返回语音)",
+                                           value=True,
+                                           interactive=True,
+                                           elem_classes="checkbox-large")
+                enable_thinking = gr.Checkbox(label="Enable Thinking (启用思维链)",
+                                              value=False,
+                                              interactive=True,
+                                              elem_classes="checkbox-large")
+                temperature = gr.Slider(label="Temperature",
+                                        minimum=0.1,
+                                        maximum=2.0,
+                                        value=0.6,
+                                        step=0.1)
+                top_p = gr.Slider(label="Top P",
+                                  minimum=0.05,
+                                  maximum=1.0,
+                                  value=0.95,
+                                  step=0.05)
+                top_k = gr.Slider(label="Top K",
+                                  minimum=1,
+                                  maximum=100,
+                                  value=20,
+                                  step=1)
+
            with gr.Column(scale=3):
                with gr.Tabs():
                    with gr.TabItem("Online"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Markdown("### Audio-Video Input (音视频输入)")
+                                microphone = gr.Audio(
+                                    sources=['microphone'],
+                                    type="filepath",
+                                    label="Record Audio (录制音频)")
+                                webcam = gr.Video(
+                                    sources=['webcam', "upload"],
+                                    label="Record/Upload Video (录制/上传视频)",
+                                    elem_classes="media-upload")
                                with gr.Row():
+                                    submit_btn_online = gr.Button(
+                                        "Submit (提交)",
+                                        variant="primary",
+                                        scale=2)
+                                    stop_btn_online = gr.Button("Stop (停止)",
+                                                                visible=False,
+                                                                scale=1)
+                                    clear_btn_online = gr.Button(
+                                        "Clear History (清除历史)")
                            with gr.Column(scale=2):
                                # FIX: Re-added type="messages"
+                                media_chatbot = gr.Chatbot(
+                                    label="Chat History (对话历史)",
+                                    type="messages",
+                                    height=650,
+                                    layout="panel",
+                                    bubble_full_width=False,
+                                    allow_tags=["think"],
+                                    render=False)
                                media_chatbot.render()
+
                        def clear_history_online():
                            return [], None, None
+
                        submit_event_online = submit_btn_online.click(
                            fn=media_predict,
+                            inputs=[
+                                microphone, webcam, media_chatbot,
+                                system_prompt_textbox, voice_choice,
+                                temperature, top_p, top_k, return_audio,
+                                enable_thinking
+                            ],
+                            outputs=[
+                                microphone, webcam, media_chatbot,
+                                submit_btn_online, stop_btn_online
+                            ])
+                        stop_btn_online.click(
+                            fn=lambda: (gr.update(visible=True),
+                                        gr.update(visible=False)),
+                            outputs=[submit_btn_online, stop_btn_online],
+                            cancels=[submit_event_online],
+                            queue=False)
+                        clear_btn_online.click(
+                            fn=clear_history_online,
+                            outputs=[media_chatbot, microphone, webcam])
+
                    with gr.TabItem("Offline"):
                        # FIX: Re-added type="messages"
+                        chatbot = gr.Chatbot(label="Chat History (对话历史)",
+                                             type="messages",
+                                             height=550,
+                                             layout="panel",
+                                             bubble_full_width=False,
+                                             allow_tags=["think"],
                                             render=False)
                        chatbot.render()
+
+                        with gr.Accordion(
+                                "📎 Click to upload multimodal files (点击上传多模态文件)",
+                                open=False):
                            with gr.Row():
+                                audio_input = gr.Audio(
+                                    sources=["upload", 'microphone'],
+                                    type="filepath",
+                                    label="Audio",
+                                    elem_classes="media-upload")
+                                image_input = gr.Image(
+                                    sources=["upload", 'webcam'],
+                                    type="filepath",
+                                    label="Image",
+                                    elem_classes="media-upload")
+                                video_input = gr.Video(
+                                    sources=["upload", 'webcam'],
+                                    label="Video",
+                                    elem_classes="media-upload")
+
                        with gr.Row():
+                            text_input = gr.Textbox(
+                                show_label=False,
+                                placeholder=
+                                "Enter text or upload files and press Submit... (输入文本或者上传文件并点击提交)",
+                                scale=7)
+                            submit_btn_offline = gr.Button("Submit (提交)",
+                                                           variant="primary",
+                                                           scale=1)
+                            stop_btn_offline = gr.Button("Stop (停止)",
+                                                         visible=False,
+                                                         scale=1)
+                            clear_btn_offline = gr.Button("Clear (清空) ",
+                                                          scale=1)
+
                        def clear_history_offline():
                            return [], None, None, None, None
+
                        submit_event_offline = gr.on(
+                            triggers=[
+                                submit_btn_offline.click, text_input.submit
+                            ],
                            fn=chat_predict,
+                            inputs=[
+                                text_input, audio_input, image_input,
+                                video_input, chatbot, system_prompt_textbox,
+                                voice_choice, temperature, top_p, top_k,
+                                return_audio, enable_thinking
+                            ],
+                            outputs=[
+                                text_input, audio_input, image_input,
+                                video_input, chatbot
+                            ])
+                        stop_btn_offline.click(
+                            fn=lambda: (gr.update(visible=True),
+                                        gr.update(visible=False)),
+                            outputs=[submit_btn_offline, stop_btn_offline],
+                            cancels=[submit_event_offline],
+                            queue=False)
                        clear_btn_offline.click(fn=clear_history_offline,
+                                                outputs=[
+                                                    chatbot, text_input,
+                                                    audio_input, image_input,
+                                                    video_input
+                                                ])
+
        gr.HTML("""
        <style>
        .media-upload { min-height: 160px; border: 2px dashed #ccc; border-radius: 8px; display: flex; align-items: center; justify-content: center; }
        .media-upload:hover { border-color: #666; }
        </style>
        """)
+
+    demo.queue(default_concurrency_limit=100, max_size=100).launch(
+        max_threads=100,
+        ssr_mode=False,
+        share=args.share,
+        inbrowser=args.inbrowser,
+        # ssl_certfile="examples/offline_inference/qwen3_omni_moe/cert.pem",
+        # ssl_keyfile="examples/offline_inference/qwen3_omni_moe/key.pem",
+        # ssl_verify=False,
+        server_port=args.server_port,
+        server_name=args.server_name,
+    )


DEFAULT_CKPT_PATH = "Qwen/Qwen3-Omni-30B-A3B-Instruct"

def _get_args():
    parser = ArgumentParser()
+
    parser.add_argument('-c',
                        '--checkpoint-path',
                        type=str,
                        default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
+    parser.add_argument('--cpu-only',
                        action='store_true',
+                        help='Run demo with CPU only')
+
+    parser.add_argument(
+        '--flash-attn2',
+        action='store_true',
+        default=False,
+        help='Enable flash_attention_2 when loading the model.')
    parser.add_argument('--use-transformers',
                        action='store_true',
                        default=False,
                        help='Use transformers for inference.')
+    parser.add_argument(
+        '--share',
+        action='store_true',
+        default=False,
+        help=
+        'Automatically create a publicly shareable link for the interface.')
+    parser.add_argument(
+        '--inbrowser',
+        action='store_true',
+        default=False,
+        help=
+        'Automatically launch the interface in a new tab on the default browser.'
+    )
+    parser.add_argument('--server-port',
+                        type=int,
+                        default=8905,
+                        help='Demo server port.')
+    parser.add_argument('--server-name',
+                        type=str,
+                        default='0.0.0.0',
+                        help='Demo server name.')
+    parser.add_argument('--ui-language',
+                        type=str,
+                        choices=['en', 'zh'],
+                        default='zh',
                        help='Display language for the UI.')
+
    args = parser.parse_args()
    return args

    args = _get_args()
    model, processor = _load_model_processor(args)
    _launch_demo(args, model, processor)
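
For reference, a minimal sketch of the call pattern this commit moves to: streaming text plus audio from the DashScope OpenAI-compatible endpoint with the updated model name, and decoding the base64 audio deltas into 24 kHz int16 PCM the way predict() in app.py does. This is an illustration under stated assumptions, not part of the commit: the environment variable name DASHSCOPE_API_KEY is assumed (app.py reads its key from an API_KEY constant defined outside this diff); the model name, base_url, voice "Cherry", and the audio/extra_body parameters are taken directly from the diff above.

# Hedged sketch: mirrors the predict() call in app.py. DASHSCOPE_API_KEY is an
# assumed variable name, not one used by app.py itself.
import base64
import os

import numpy as np
from openai import OpenAI

client = OpenAI(
    api_key=os.environ["DASHSCOPE_API_KEY"],  # assumption; app.py uses API_KEY
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
completion = client.chat.completions.create(
    model="qwen3-omni-flash-2025-12-01",  # model name from this commit
    messages=[{
        "role": "user",
        "content": [{"type": "text", "text": "Hello!"}],
    }],
    modalities=["text", "audio"],
    audio={"voice": "Cherry", "format": "wav"},
    extra_body={"enable_thinking": False, "top_k": 20},
    stream=True,
    stream_options={"include_usage": True},
)

audio_b64 = ""
for chunk in completion:
    if chunk.choices:
        delta = chunk.choices[0].delta
        # As in app.py, audio arrives base64-encoded in delta.audio["data"];
        # the text transcript rides along in delta.audio["transcript"].
        if hasattr(delta, "audio") and delta.audio:
            audio_b64 += delta.audio.get("data", "")

# app.py writes this buffer out as 24 kHz mono WAV via soundfile.
pcm = np.frombuffer(base64.b64decode(audio_b64), dtype=np.int16)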