ASureevaA committed on
Commit c14e744 · 1 Parent(s): a6352f4

Add application file

Files changed (2)
  1. app.py +887 -0
  2. requirements.txt +17 -0
app.py ADDED
@@ -0,0 +1,887 @@
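"""Gradio demo application for multimodal Hugging Face models.

Covers audio tasks (classification, zero-shot CLAP classification, speech
recognition, speech synthesis) and vision/multimodal tasks (object detection,
semantic segmentation, depth estimation, SAM point-prompt segmentation,
captioning, VQA, zero-shot image classification, and image retrieval).
"""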
import tempfile
from typing import List, Tuple

import gradio as gr
import soundfile as soundfile_module
import torch
import torch.nn.functional as torch_functional
from gtts import gTTS
from PIL import Image, ImageDraw
from transformers import (
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    SamModel,
    SamProcessor,
    VitsModel,
    pipeline,
)


MODEL_STORE = {}


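# Loaded pipelines and models are cached in MODEL_STORE so each one is
# instantiated only once per process and reused across Gradio callbacks.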
def get_audio_pipeline(model_key: str):
    if model_key in MODEL_STORE:
        return MODEL_STORE[model_key]

    if model_key == "whisper":
        audio_pipeline = pipeline(
            task="automatic-speech-recognition",
            model="distil-whisper/distil-small.en",
        )
    elif model_key == "wav2vec2":
        audio_pipeline = pipeline(
            task="automatic-speech-recognition",
            model="openai/whisper-small",
        )
    elif model_key == "audio_classifier":
        audio_pipeline = pipeline(
            task="audio-classification",
            model="MIT/ast-finetuned-audioset-10-10-0.4593",
        )
    elif model_key == "emotion_classifier":
        audio_pipeline = pipeline(
            task="audio-classification",
            model="superb/hubert-large-superb-er",
        )
    else:
        raise ValueError(f"Неизвестный тип аудио модели: {model_key}")

    MODEL_STORE[model_key] = audio_pipeline
    return audio_pipeline


def get_zero_shot_audio_pipeline():
    if "audio_zero_shot_clap" not in MODEL_STORE:
        zero_shot_pipeline = pipeline(
            task="zero-shot-audio-classification",
            model="laion/clap-htsat-unfused",
        )
        MODEL_STORE["audio_zero_shot_clap"] = zero_shot_pipeline
    return MODEL_STORE["audio_zero_shot_clap"]


def get_vision_pipeline(model_key: str):
    if model_key in MODEL_STORE:
        return MODEL_STORE[model_key]

    if model_key == "object_detection_conditional_detr":
        vision_pipeline = pipeline(
            task="object-detection",
            model="microsoft/conditional-detr-resnet-50",
        )
    elif model_key == "object_detection_yolos_small":
        vision_pipeline = pipeline(
            task="object-detection",
            model="hustvl/yolos-small",
        )
    elif model_key == "segmentation":
        vision_pipeline = pipeline(
            task="image-segmentation",
            model="nvidia/segformer-b0-finetuned-ade-512-512",
        )
    elif model_key == "depth_estimation":
        vision_pipeline = pipeline(
            task="depth-estimation",
            model="Intel/dpt-hybrid-midas",
        )
    elif model_key == "captioning_blip_base":
        vision_pipeline = pipeline(
            task="image-to-text",
            model="Salesforce/blip-image-captioning-base",
        )
    elif model_key == "captioning_blip_large":
        vision_pipeline = pipeline(
            task="image-to-text",
            model="Salesforce/blip-image-captioning-large",
        )
    elif model_key == "vqa_blip_base":
        vision_pipeline = pipeline(
            task="visual-question-answering",
            model="Salesforce/blip-vqa-base",
        )
    elif model_key == "vqa_vilt_b32":
        vision_pipeline = pipeline(
            task="visual-question-answering",
            model="dandelin/vilt-b32-finetuned-vqa",
        )
    else:
        raise ValueError(f"Неизвестный тип визуальной модели: {model_key}")

    MODEL_STORE[model_key] = vision_pipeline
    return vision_pipeline


def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
    model_store_key_model = f"clip_model_{clip_key}"
    model_store_key_processor = f"clip_processor_{clip_key}"

    if model_store_key_model not in MODEL_STORE or model_store_key_processor not in MODEL_STORE:
        if clip_key == "clip_large_patch14":
            clip_name = "openai/clip-vit-large-patch14"
        elif clip_key == "clip_base_patch32":
            clip_name = "openai/clip-vit-base-patch32"
        else:
            raise ValueError(f"Неизвестный вариант CLIP модели: {clip_key}")

        clip_model = CLIPModel.from_pretrained(clip_name)
        clip_processor = CLIPProcessor.from_pretrained(clip_name)

        MODEL_STORE[model_store_key_model] = clip_model
        MODEL_STORE[model_store_key_processor] = clip_processor

    clip_model = MODEL_STORE[model_store_key_model]
    clip_processor = MODEL_STORE[model_store_key_processor]
    return clip_model, clip_processor


def get_silero_tts_model():
    if "silero_tts_model" not in MODEL_STORE:
        silero_model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_tts",
            language="ru",
            speaker="ru_v3",
        )
        MODEL_STORE["silero_tts_model"] = silero_model
    return MODEL_STORE["silero_tts_model"]


def get_mms_tts_components() -> Tuple[VitsModel, AutoTokenizer]:
    if "mms_tts_model" not in MODEL_STORE or "mms_tts_tokenizer" not in MODEL_STORE:
        vits_model = VitsModel.from_pretrained("kakao-enterprise/vits-ljs")
        vits_tokenizer = AutoTokenizer.from_pretrained("kakao-enterprise/vits-ljs")
        MODEL_STORE["mms_tts_model"] = vits_model
        MODEL_STORE["mms_tts_tokenizer"] = vits_tokenizer

    vits_model = MODEL_STORE["mms_tts_model"]
    vits_tokenizer = MODEL_STORE["mms_tts_tokenizer"]
    return vits_model, vits_tokenizer


def get_sam_components() -> Tuple[SamModel, SamProcessor]:
    if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE:
        sam_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
        sam_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
        MODEL_STORE["sam_model"] = sam_model
        MODEL_STORE["sam_processor"] = sam_processor

    sam_model = MODEL_STORE["sam_model"]
    sam_processor = MODEL_STORE["sam_processor"]
    return sam_model, sam_processor


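# --- Audio tasks: classification, zero-shot CLAP, speech recognition ---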
def classify_audio_file(audio_path: str, model_key: str) -> str:
    """Run an audio-classification pipeline and format the top-5 predictions."""
    audio_classifier = get_audio_pipeline(model_key)
    prediction_list = audio_classifier(audio_path)

    result_lines = ["Топ-5 предсказаний:"]
    for prediction_index, prediction_item in enumerate(prediction_list[:5], start=1):
        label_value = prediction_item["label"]
        score_value = prediction_item["score"]
        result_lines.append(
            f"{prediction_index}. {label_value}: {score_value:.4f}"
        )

    return "\n".join(result_lines)


def classify_audio_zero_shot_clap(audio_path: str, label_texts: str) -> str:
    """Zero-shot audio classification with CLAP against comma-separated labels."""
    clap_pipeline = get_zero_shot_audio_pipeline()

    label_list = [
        label_item.strip()
        for label_item in label_texts.split(",")
        if label_item.strip()
    ]
    if not label_list:
        return "Не задано ни одной текстовой метки для zero-shot классификации."

    prediction_list = clap_pipeline(
        audio_path,
        candidate_labels=label_list,
    )

    result_lines = ["Zero-Shot Audio Classification (CLAP):"]
    for prediction_index, prediction_item in enumerate(prediction_list, start=1):
        label_value = prediction_item["label"]
        score_value = prediction_item["score"]
        result_lines.append(
            f"{prediction_index}. {label_value}: {score_value:.4f}"
        )

    return "\n".join(result_lines)


def recognize_speech(audio_path: str, model_key: str) -> str:
    """Transcribe speech with the selected ASR pipeline."""
    speech_pipeline = get_audio_pipeline(model_key)

    prediction_result = speech_pipeline(audio_path)

    return prediction_result["text"]


def synthesize_speech(text_value: str, model_key: str):
    """Synthesize speech with Silero, gTTS, or the VITS (vits-ljs) model."""
    if model_key == "silero":
        silero_model = get_silero_tts_model()

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
            silero_model.save_wav(
                text=text_value,
                speaker="aidar",
                sample_rate=48000,
                audio_path=file_object.name,
            )
            return file_object.name

    if model_key == "gtts":
        # gTTS produces MP3 data, so use an .mp3 suffix for the temporary file.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as file_object:
            text_to_speech_engine = gTTS(text=text_value, lang="ru")
            text_to_speech_engine.save(file_object.name)
            return file_object.name

    if model_key == "mms":
        vits_model, vits_tokenizer = get_mms_tts_components()
        tokenized_input = vits_tokenizer(text_value, return_tensors="pt")

        with torch.no_grad():
            waveform_tensor = vits_model(**tokenized_input).waveform

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
            waveform_array = waveform_tensor.numpy().squeeze()
            soundfile_module.write(
                file_object.name,
                waveform_array,
                vits_model.config.sampling_rate,
            )
            return file_object.name

    raise ValueError(f"Неизвестная TTS модель: {model_key}")


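# --- Computer vision and multimodal tasks ---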
def detect_objects_on_image(image_object, model_key: str):
    """Draw detected bounding boxes and labels on the input image."""
    detector_pipeline = get_vision_pipeline(model_key)
    detection_results = detector_pipeline(image_object)

    drawer_object = ImageDraw.Draw(image_object)
    for detection_item in detection_results:
        box_data = detection_item["box"]
        label_value = detection_item["label"]
        score_value = detection_item["score"]

        drawer_object.rectangle(
            [
                box_data["xmin"],
                box_data["ymin"],
                box_data["xmax"],
                box_data["ymax"],
            ],
            outline="red",
            width=3,
        )
        drawer_object.text(
            (box_data["xmin"], box_data["ymin"]),
            f"{label_value}: {score_value:.2f}",
            fill="red",
        )

    return image_object


def segment_image(image_object):
    """Semantic segmentation with SegFormer; returns the first predicted mask."""
    segmentation_pipeline = get_vision_pipeline("segmentation")
    segmentation_results = segmentation_pipeline(image_object)
    return segmentation_results[0]["mask"]


def estimate_image_depth(image_object):
    """Estimate depth with DPT and return a normalized grayscale depth map."""
    depth_pipeline = get_vision_pipeline("depth_estimation")
    depth_output = depth_pipeline(image_object)

    predicted_depth_tensor = depth_output["predicted_depth"]

    # The pipeline returns the raw depth map as [1, H, W] (or [H, W] in some
    # versions); bring it to the 4D [N, C, H, W] shape that interpolate expects.
    while predicted_depth_tensor.dim() < 4:
        predicted_depth_tensor = predicted_depth_tensor.unsqueeze(0)

    resized_depth_tensor = torch_functional.interpolate(
        predicted_depth_tensor,
        size=image_object.size[::-1],  # PIL size is (width, height); interpolate expects (H, W)
        mode="bicubic",
        align_corners=False,
    )

    depth_array = resized_depth_tensor.squeeze().cpu().numpy()
    max_value = float(depth_array.max())

    if max_value <= 0.0:
        return Image.new("L", image_object.size, color=0)

    normalized_depth_array = (depth_array * 255.0 / max_value).astype("uint8")
    depth_image = Image.fromarray(normalized_depth_array, mode="L")
    return depth_image


def generate_image_caption(image_object, model_key: str) -> str:
    """Generate an image caption with the selected BLIP model."""
    caption_pipeline = get_vision_pipeline(model_key)
    caption_result = caption_pipeline(image_object)
    return caption_result[0]["generated_text"]


def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
    """Answer a question about the image and report the model's confidence."""
    vqa_pipeline = get_vision_pipeline(model_key)
    vqa_result = vqa_pipeline(image_object, question_text)

    answer_text = vqa_result[0]["answer"]
    confidence_value = vqa_result[0]["score"]
    return f"{answer_text} (confidence: {confidence_value:.3f})"


def perform_zero_shot_classification(
    image_object,
    class_texts: str,
    clip_key: str,
) -> str:
    """Zero-shot image classification with CLIP over comma-separated class names."""
    clip_model, clip_processor = get_clip_components(clip_key)

    class_list = [
        class_name.strip()
        for class_name in class_texts.split(",")
        if class_name.strip()
    ]
    if not class_list:
        return "Не задано ни одного класса для классификации."

    input_batch = clip_processor(
        text=class_list,
        images=image_object,
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        clip_outputs = clip_model(**input_batch)
        logits_per_image = clip_outputs.logits_per_image
        probability_tensor = logits_per_image.softmax(dim=1)

    result_lines = ["Zero-Shot Classification Results:"]
    for class_index, class_name in enumerate(class_list):
        probability_value = probability_tensor[0][class_index].item()
        result_lines.append(f"{class_name}: {probability_value:.4f}")

    return "\n".join(result_lines)


def retrieve_best_image(
    image_list: List,
    query_text: str,
    clip_key: str,
):
    """Pick the gallery image whose CLIP embedding best matches the text query."""
    if not image_list or not query_text:
        return "Пожалуйста, загрузите изображения и введите запрос", None

    # gr.Gallery may pass items as (image, caption) tuples; keep only the images.
    image_list = [
        gallery_item[0] if isinstance(gallery_item, (tuple, list)) else gallery_item
        for gallery_item in image_list
    ]

    clip_model, clip_processor = get_clip_components(clip_key)

    image_inputs = clip_processor(
        images=image_list,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        image_features = clip_model.get_image_features(**image_inputs)
        image_features = image_features / image_features.norm(
            dim=-1,
            keepdim=True,
        )

    text_inputs = clip_processor(
        text=[query_text],
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        text_features = clip_model.get_text_features(**text_inputs)
        text_features = text_features / text_features.norm(
            dim=-1,
            keepdim=True,
        )

    similarity_tensor = image_features @ text_features.T
    best_index_tensor = similarity_tensor.argmax()
    best_index_value = best_index_tensor.item()
    best_score_value = similarity_tensor[best_index_value].item()

    description_text = (
        f"Лучшее изображение: #{best_index_value + 1} "
        f"(схожесть: {best_score_value:.4f})"
    )
    return description_text, image_list[best_index_value]


def segment_image_with_sam_points(
    image_object,
    point_coordinates_list: List[List[int]] | None,
) -> Image.Image:
    """Segment the image with SlimSAM using the given foreground point prompts."""
    if not point_coordinates_list:
        return Image.new("L", image_object.size, color=0)

    sam_model, sam_processor = get_sam_components()

    batched_points = [point_coordinates_list]
    batched_labels = [[1 for _ in point_coordinates_list]]

    sam_inputs = sam_processor(
        image_object,
        input_points=batched_points,
        input_labels=batched_labels,
        return_tensors="pt",
    )

    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs)

    post_processed_masks_list = sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(),
        sam_inputs["original_sizes"].cpu(),
        sam_inputs["reshaped_input_sizes"].cpu(),
    )

    # One tensor per input image; its shape is [num_masks, H, W] or
    # [point_batch, num_masks, H, W] depending on the prompt nesting,
    # so flatten the leading dimensions and take the first mask.
    batched_masks_tensor = post_processed_masks_list[0]
    if batched_masks_tensor.numel() == 0:
        return Image.new("L", image_object.size, color=0)

    flattened_masks_tensor = batched_masks_tensor.reshape(
        -1, *batched_masks_tensor.shape[-2:]
    )
    first_mask_tensor = flattened_masks_tensor[0]  # [H, W]
    mask_array = first_mask_tensor.cpu().numpy()

    mask_image = Image.fromarray((mask_array * 255.0).astype("uint8"), mode="L")
    return mask_image


def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
    """Parse point prompts written as "x,y; x,y; ..." into a list of [x, y] pairs."""
    if not coordinates_text.strip():
        return []

    point_list: List[List[int]] = []
    for raw_pair in coordinates_text.split(";"):
        cleaned_pair = raw_pair.strip()
        if not cleaned_pair:
            continue
        coordinate_parts = cleaned_pair.split(",")
        if len(coordinate_parts) != 2:
            continue
        try:
            x_value = int(coordinate_parts[0].strip())
            y_value = int(coordinate_parts[1].strip())
        except ValueError:
            continue
        point_list.append([x_value, y_value])

    return point_list


def segment_image_with_sam_points_ui(
    image_object,
    coordinates_text: str,
):
    """Gradio wrapper: parse the coordinates textbox and run SAM segmentation."""
    point_coordinates_list = parse_point_coordinates_text(coordinates_text)
    return segment_image_with_sam_points(image_object, point_coordinates_list)


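# --- Gradio interface ---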
def build_interface():
    """Assemble the tabbed Gradio Blocks interface for all tasks."""
    with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo_block:
        gr.Markdown("# Мультимодальные AI модели")
        gr.Markdown(
            "Демонстрация различных задач компьютерного зрения "
            "и обработки звука с использованием Hugging Face Transformers",
        )

        with gr.Tab("Классификация аудио"):
            gr.Markdown("## Audio Classification")
            with gr.Row():
                with gr.Column():
                    audio_input_component = gr.Audio(
                        label="Загрузите аудиофайл",
                        type="filepath",
                    )
                    audio_model_selector = gr.Dropdown(
                        choices=["audio_classifier", "emotion_classifier"],
                        label="Выберите модель",
                        value="audio_classifier",
                        info=(
                            "audio_classifier - общая классификация (AST), "
                            "emotion_classifier - эмоции в речи (HuBERT ER)"
                        ),
                    )
                    audio_classify_button = gr.Button("Классифицировать")
                with gr.Column():
                    audio_output_component = gr.Textbox(
                        label="Результаты классификации",
                        lines=10,
                    )

            audio_classify_button.click(
                fn=classify_audio_file,
                inputs=[audio_input_component, audio_model_selector],
                outputs=audio_output_component,
            )

        with gr.Tab("Zero-Shot аудио (CLAP)"):
            gr.Markdown("## Zero-Shot Audio Classification (CLAP)")
            with gr.Row():
                with gr.Column():
                    clap_audio_input_component = gr.Audio(
                        label="Загрузите аудиофайл",
                        type="filepath",
                    )
                    clap_label_texts_component = gr.Textbox(
                        label="Кандидатные метки (через запятую)",
                        placeholder="лай собаки, шум дождя, музыка, разговор",
                        lines=2,
                    )
                    clap_button = gr.Button("Классифицировать CLAP")
                with gr.Column():
                    clap_output_component = gr.Textbox(
                        label="Результаты zero-shot классификации",
                        lines=10,
                    )

            clap_button.click(
                fn=classify_audio_zero_shot_clap,
                inputs=[clap_audio_input_component, clap_label_texts_component],
                outputs=clap_output_component,
            )

        with gr.Tab("Распознавание речи"):
            gr.Markdown("## Automatic Speech Recognition (ASR)")
            with gr.Row():
                with gr.Column():
                    asr_audio_input_component = gr.Audio(
                        label="Загрузите аудио с речью",
                        type="filepath",
                    )
                    asr_model_selector = gr.Dropdown(
                        choices=["whisper", "wav2vec2"],
                        label="Выберите модель",
                        value="whisper",
                        info=(
                            "whisper - distil-whisper/distil-small.en (модель из курса, EN),\n"
                            "wav2vec2 - openai/whisper-small (альтернатива, мультиязычная)"
                        ),
                    )
                    asr_button = gr.Button("Транскрибировать")
                with gr.Column():
                    asr_output_component = gr.Textbox(
                        label="Транскрипция",
                        lines=5,
                    )

            asr_button.click(
                fn=recognize_speech,
                inputs=[asr_audio_input_component, asr_model_selector],
                outputs=asr_output_component,
            )

        with gr.Tab("Синтез речи"):
            gr.Markdown("## Text-to-Speech (TTS)")
            with gr.Row():
                with gr.Column():
                    tts_text_component = gr.Textbox(
                        label="Введите текст для синтеза",
                        placeholder="Введите текст на русском или английском языке...",
                        lines=3,
                    )
                    tts_model_selector = gr.Dropdown(
                        choices=["silero", "gtts", "mms"],
                        label="Выберите модель",
                        value="silero",
                        info=(
                            "silero - русскоязычный Silero TTS, "
                            "gtts - Google TTS (через gTTS), "
                            "mms - kakao-enterprise/vits-ljs (модель из курса, EN)"
                        ),
                    )
                    tts_button = gr.Button("Синтезировать речь")
                with gr.Column():
                    tts_audio_output_component = gr.Audio(
                        label="Синтезированная речь",
                    )

            # synthesize_speech takes both the text and the selected model key,
            # so both components must be wired as inputs.
            tts_button.click(
                fn=synthesize_speech,
                inputs=[tts_text_component, tts_model_selector],
                outputs=tts_audio_output_component,
            )

        with gr.Tab("Детекция объектов"):
            gr.Markdown("## Object Detection")
            with gr.Row():
                with gr.Column():
                    object_input_image = gr.Image(
                        label="Загрузите изображение",
                        type="pil",
                    )
                    object_model_selector = gr.Dropdown(
                        choices=[
                            "object_detection_conditional_detr",
                            "object_detection_yolos_small",
                        ],
                        label="Модель детекции",
                        value="object_detection_conditional_detr",
                        info=(
                            "object_detection_conditional_detr - microsoft/conditional-detr-resnet-50\n"
                            "object_detection_yolos_small - hustvl/yolos-small"
                        ),
                    )
                    object_detect_button = gr.Button("Обнаружить объекты")
                with gr.Column():
                    object_output_image = gr.Image(
                        label="Результат детекции",
                    )

            object_detect_button.click(
                fn=detect_objects_on_image,
                inputs=[object_input_image, object_model_selector],
                outputs=object_output_image,
            )

        with gr.Tab("Сегментация"):
            gr.Markdown("## Image Segmentation (SegFormer)")
            with gr.Row():
                with gr.Column():
                    segmentation_input_image = gr.Image(
                        label="Загрузите изображение",
                        type="pil",
                    )
                    segmentation_button = gr.Button("Сегментировать")
                with gr.Column():
                    segmentation_output_image = gr.Image(
                        label="Маска сегментации",
                    )

            segmentation_button.click(
                fn=segment_image,
                inputs=segmentation_input_image,
                outputs=segmentation_output_image,
            )

        with gr.Tab("Глубина (Depth Estimation)"):
            gr.Markdown("## Depth Estimation (DPT)")
            with gr.Row():
                with gr.Column():
                    depth_input_image = gr.Image(
                        label="Загрузите изображение",
                        type="pil",
                    )
                    depth_button = gr.Button("Оценить глубину")
                with gr.Column():
                    depth_output_image = gr.Image(
                        label="Карта глубины",
                    )

            depth_button.click(
                fn=estimate_image_depth,
                inputs=depth_input_image,
                outputs=depth_output_image,
            )

        with gr.Tab("Интерактивная сегментация (SAM)"):
            gr.Markdown("## Interactive Segmentation (SlimSAM)")
            gr.Markdown(
                "Укажите несколько точек в формате `x,y; x,y; ...`. "
                "Каждая точка считается foreground-подсказкой."
            )
            with gr.Row():
                with gr.Column():
                    sam_input_image = gr.Image(
                        label="Загрузите изображение",
                        type="pil",
                    )
                    sam_coordinates_text = gr.Textbox(
                        label="Координаты точек",
                        placeholder="100,150; 200,220",
                        lines=2,
                    )
                    sam_button = gr.Button("Сегментировать по точкам")
                with gr.Column():
                    sam_output_image = gr.Image(
                        label="Бинарная маска (SAM)",
                    )

            sam_button.click(
                fn=segment_image_with_sam_points_ui,
                inputs=[sam_input_image, sam_coordinates_text],
                outputs=sam_output_image,
            )

        with gr.Tab("Описание изображений"):
            gr.Markdown("## Image Captioning")
            with gr.Row():
                with gr.Column():
                    caption_input_image = gr.Image(
                        label="Загрузите изображение",
                        type="pil",
                    )
                    caption_model_selector = gr.Dropdown(
                        choices=[
                            "captioning_blip_base",
                            "captioning_blip_large",
                        ],
                        label="Модель captioning",
                        value="captioning_blip_base",
                        info=(
                            "captioning_blip_base - Salesforce/blip-image-captioning-base (курс)\n"
                            "captioning_blip_large - Salesforce/blip-image-captioning-large (альтернатива)"
                        ),
                    )
                    caption_button = gr.Button("Сгенерировать описание")
                with gr.Column():
                    caption_output_text = gr.Textbox(
                        label="Описание изображения",
                        lines=3,
                    )

            caption_button.click(
                fn=generate_image_caption,
                inputs=[caption_input_image, caption_model_selector],
                outputs=caption_output_text,
            )

        with gr.Tab("Визуальные вопросы"):
            gr.Markdown("## Visual Question Answering")
            with gr.Row():
                with gr.Column():
                    vqa_input_image = gr.Image(
                        label="Загрузите изображение",
                        type="pil",
                    )
                    vqa_question_text = gr.Textbox(
                        label="Вопрос об изображении",
                        placeholder="Что происходит на этом изображении?",
                        lines=2,
                    )
                    vqa_model_selector = gr.Dropdown(
                        choices=[
                            "vqa_blip_base",
                            "vqa_vilt_b32",
                        ],
                        label="Модель VQA",
                        value="vqa_blip_base",
                        info=(
                            "vqa_blip_base - Salesforce/blip-vqa-base (курс)\n"
                            "vqa_vilt_b32 - dandelin/vilt-b32-finetuned-vqa (альтернатива)"
                        ),
                    )
                    vqa_button = gr.Button("Ответить на вопрос")
                with gr.Column():
                    vqa_output_text = gr.Textbox(
                        label="Ответ",
                        lines=3,
                    )

            vqa_button.click(
                fn=answer_visual_question,
                inputs=[vqa_input_image, vqa_question_text, vqa_model_selector],
                outputs=vqa_output_text,
            )

        with gr.Tab("Zero-Shot классификация"):
            gr.Markdown("## Zero-Shot Image Classification")
            with gr.Row():
                with gr.Column():
                    zero_shot_input_image = gr.Image(
                        label="Загрузите изображение",
                        type="pil",
                    )
                    zero_shot_classes_text = gr.Textbox(
                        label="Классы для классификации (через запятую)",
                        placeholder="человек, машина, дерево, здание, животное",
                        lines=2,
                    )
                    clip_model_selector = gr.Dropdown(
                        choices=[
                            "clip_large_patch14",
                            "clip_base_patch32",
                        ],
                        label="CLIP модель",
                        value="clip_large_patch14",
                        info=(
                            "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
                            "clip_base_patch32 - openai/clip-vit-base-patch32 (альтернатива)"
                        ),
                    )
                    zero_shot_button = gr.Button("Классифицировать")
                with gr.Column():
                    zero_shot_output_text = gr.Textbox(
                        label="Результаты классификации",
                        lines=10,
                    )

            zero_shot_button.click(
                fn=perform_zero_shot_classification,
                inputs=[zero_shot_input_image, zero_shot_classes_text, clip_model_selector],
                outputs=zero_shot_output_text,
            )

        with gr.Tab("Поиск изображений"):
            gr.Markdown("## Image Retrieval")
            with gr.Row():
                with gr.Column():
                    retrieval_gallery = gr.Gallery(
                        label="Загрузите изображения для поиска",
                        type="pil",
                    )
                    retrieval_query_text = gr.Textbox(
                        label="Текстовый запрос",
                        placeholder="описание того, что вы ищете...",
                        lines=2,
                    )
                    retrieval_clip_selector = gr.Dropdown(
                        choices=[
                            "clip_large_patch14",
                            "clip_base_patch32",
                        ],
                        label="CLIP модель",
                        value="clip_large_patch14",
                        info=(
                            "clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
                            "clip_base_patch32 - openai/clip-vit-base-patch32 (альтернатива)"
                        ),
                    )
                    retrieval_button = gr.Button("Найти изображение")
                with gr.Column():
                    retrieval_output_text = gr.Textbox(
                        label="Результат поиска",
                    )
                    retrieval_output_image = gr.Image(
                        label="Найденное изображение",
                    )

            retrieval_button.click(
                fn=retrieve_best_image,
                inputs=[retrieval_gallery, retrieval_query_text, retrieval_clip_selector],
                outputs=[retrieval_output_text, retrieval_output_image],
            )

        gr.Markdown("---")
        gr.Markdown("### Задачи:")
        gr.Markdown(
            """
            - Аудио: классификация (supervised и zero-shot через CLAP), распознавание речи, синтез речи
            - Компьютерное зрение: детекция объектов, семантическая сегментация (SegFormer), оценка глубины (DPT), интерактивная сегментация по точкам (SlimSAM), генерация описаний изображений
            - Мультимодальные задачи: визуальные вопросы (VQA), zero-shot классификация изображений, поиск по изображениям по текстовому запросу
            """
        )
    return demo_block


if __name__ == "__main__":
    interface_block = build_interface()
    interface_block.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,17 @@
torch>=2.1.0
torchvision>=0.16.0
torchaudio>=2.1.0
numpy>=1.24.0

transformers>=4.41.0
accelerate>=0.30.0
datasets>=2.18.0
timm>=0.9.0

soundfile>=0.12.1
librosa>=0.10.0

gradio>=4.0.0

Pillow>=9.5.0
gTTS>=2.5.1
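
# Silero TTS is loaded via torch.hub (snakers4/silero-models), whose hub entry
# point relies on omegaconf; listed here so the Space installs it explicitly.
omegaconf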