import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import requests
from datetime import datetime
import os

HF_KEY = os.getenv('DATASET_KEY')
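# Assumed environment (not pinned anywhere in this file): the imports above
# require roughly
#   pip install streamlit pandas numpy sentence-transformers scikit-learn requests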
# Initialize session state variables
if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'initial_search_done' not in st.session_state:
    st.session_state['initial_search_done'] = False
if 'hf_token' not in st.session_state:
    st.session_state['hf_token'] = HF_KEY

def fetch_dataset_info_auth(dataset_id, hf_token):
    """Fetch dataset information with authentication"""
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        response = requests.get(info_url, headers=headers, timeout=30)
        if response.status_code == 200:
            return response.json()
    except Exception as e:
        st.warning(f"Error fetching dataset info: {e}")
    return None

def fetch_dataset_splits_auth(dataset_id, hf_token):
    """Fetch available splits for the dataset"""
    splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        response = requests.get(splits_url, headers=headers, timeout=30)
        if response.status_code == 200:
            return response.json().get('splits', [])
    except Exception as e:
        st.warning(f"Error fetching splits: {e}")
    return []
def fetch_parquet_urls_auth(dataset_id, config, split, hf_token):
    """Fetch Parquet file URLs for a specific split"""
    parquet_url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/{config}/{split}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        response = requests.get(parquet_url, headers=headers, timeout=30)
        if response.status_code == 200:
            return response.json()
    except Exception as e:
        st.warning(f"Error fetching parquet URLs: {e}")
    return []
def fetch_rows_auth(dataset_id, config, split, offset, length, hf_token):
    """Fetch rows with authentication"""
    url = f"https://datasets-server.huggingface.co/rows?dataset={dataset_id}&config={config}&split={split}&offset={offset}&length={length}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        response = requests.get(url, headers=headers, timeout=30)
        if response.status_code == 200:
            return response.json()
    except Exception as e:
        st.warning(f"Error fetching rows: {e}")
    return None
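# At the time of writing, the /rows endpoint responds with JSON shaped like
# {"features": [...], "rows": [{"row_idx": 0, "row": {...}, "truncated_cells": []}, ...]},
# which is why load_dataset() below unwraps each entry's 'row' field.
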
class ParquetVideoSearch:
    def __init__(self, hf_token):
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.dataset_id = "tomg-group-umd/cinepile"
        self.config = "v2"
        self.hf_token = hf_token
        self.load_dataset()

    def load_dataset(self):
        """Load initial dataset sample"""
        try:
            rows_data = fetch_rows_auth(
                self.dataset_id,
                self.config,
                "train",
                0,
                100,
                self.hf_token
            )
            if rows_data and 'rows' in rows_data:
                processed_rows = []
                for row_data in rows_data['rows']:
                    row = row_data.get('row', row_data)
                    processed_rows.append(row)
                self.dataset = pd.DataFrame(processed_rows)
                st.session_state['search_columns'] = [
                    col for col in self.dataset.columns
                    if not any(term in col.lower() for term in ['embed', 'vector', 'encoding'])
                ]
            else:
                self.dataset = self.load_example_data()
        except Exception as e:
            st.warning(f"Error loading dataset: {e}")
            self.dataset = self.load_example_data()
        self.prepare_features()
    def load_example_data(self):
        """Load example data as fallback"""
        return pd.DataFrame([{
            "video_id": "example",
            "title": "Example Video",
            "description": "Example video content",
            "duration": 120,
            "start_time": 0,
            "end_time": 120
        }])
    def prepare_features(self):
        """Prepare text features for search"""
        try:
            # Combine relevant text fields for search
            text_fields = ['title', 'description'] if 'title' in self.dataset.columns else ['description']
            combined_text = self.dataset[text_fields].fillna('').agg(' '.join, axis=1)
            self.text_embeds = self.text_model.encode(combined_text.tolist())
        except Exception as e:
            st.warning(f"Error preparing features: {e}")
            # Fallback: random vectors with the same dimensionality (384) as
            # all-MiniLM-L6-v2 embeddings, so downstream shapes stay consistent.
            self.text_embeds = np.random.randn(len(self.dataset), 384)
    def search(self, query, column=None, top_k=20):
        """Search using text embeddings and optional column filtering"""
        query_embedding = self.text_model.encode([query])[0]
        similarities = cosine_similarity([query_embedding], self.text_embeds)[0]

        # Column filtering: down-weight rows whose chosen column does not
        # contain the query as a literal substring. regex=False avoids
        # pattern errors when the query contains regex metacharacters.
        if column and column in self.dataset.columns and column != "All Fields":
            mask = self.dataset[column].astype(str).str.contains(query, case=False, regex=False)
            similarities[~mask] *= 0.5

        top_k = min(top_k, len(similarities))
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        results = []
        for idx in top_indices:
            result = {
                'relevance_score': float(similarities[idx]),
                **self.dataset.iloc[idx].to_dict()
            }
            results.append(result)
        return results
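# Minimal usage sketch outside of Streamlit (the token variable and the query
# string are illustrative assumptions, not part of the original app):
#   searcher = ParquetVideoSearch(HF_KEY)
#   hits = searcher.search("car chase", top_k=5)
#   print(hits[0]['relevance_score'], hits[0].get('title'))
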
def render_video_result(result):
    """Render a video result with enhanced display"""
    col1, col2 = st.columns([2, 1])
    with col1:
        if 'title' in result:
            st.markdown(f"**Title:** {result['title']}")
        st.markdown("**Description:**")
        st.write(result.get('description', 'No description available'))

        # Show timing information
        start_time = result.get('start_time', 0)
        end_time = result.get('end_time', result.get('duration', 0))
        st.markdown(f"**Time Range:** {start_time}s - {end_time}s")

        # Show additional metadata
        for key, value in result.items():
            if key not in ['title', 'description', 'start_time', 'end_time', 'duration',
                           'relevance_score', 'video_id', '_config', '_split']:
                st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")

    with col2:
        st.markdown(f"**Relevance Score:** {result['relevance_score']:.2%}")

        # Display video if URL is available
        video_url = None
        if 'video_url' in result:
            video_url = result['video_url']
        elif 'youtube_id' in result:
            video_url = f"https://youtube.com/watch?v={result['youtube_id']}&t={start_time}"
        if video_url:
            st.video(video_url)
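# Note: st.video accepts both direct media URLs and YouTube links; the &t=
# parameter built above is intended to request a start offset for YouTube,
# though whether the embedded player honors it is not guaranteed.
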
def main():
    st.title("🎥 Video Dataset Search")

    # Get HF token from secrets or user input
    if not st.session_state['hf_token']:
        st.session_state['hf_token'] = st.secrets.get("HF_TOKEN", None)
    if not st.session_state['hf_token']:
        hf_token = st.text_input("Enter your Hugging Face API token:", type="password")
        if hf_token:
            st.session_state['hf_token'] = hf_token
    if not st.session_state.get('hf_token'):
        st.warning("Please provide a Hugging Face API token to access the dataset.")
        return

    # Initialize search class
    search = ParquetVideoSearch(st.session_state['hf_token'])

    # Create tabs
    tab1, tab2 = st.tabs(["🔍 Video Search", "📊 Dataset Info"])
    # ---- Tab 1: Video Search ----
    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Enter your search query:")
        with col2:
            search_column = st.selectbox("Search in field:",
                                         ["All Fields"] + st.session_state['search_columns'])

        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Number of results:", 1, 100, 20)
        with col4:
            search_button = st.button("🔍 Search")

        if search_button and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column
            with st.spinner("Searching..."):
                results = search.search(query, selected_column, num_results)

            # Log only the top five results per query in the search history
            st.session_state['search_history'].append({
                'query': query,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'results': results[:5]
            })
            for i, result in enumerate(results, 1):
                with st.expander(
                    f"Result {i}: {result.get('title', result.get('description', 'No title'))[:100]}...",
                    expanded=(i == 1)
                ):
                    render_video_result(result)
    # ---- Tab 2: Dataset Info ----
    with tab2:
        st.subheader("Dataset Information")

        # Show available splits
        splits = fetch_dataset_splits_auth(search.dataset_id, st.session_state['hf_token'])
        if splits:
            st.write("### Available Splits")
            for split in splits:
                st.write(f"- {split['split']}: {split.get('num_rows', 'unknown')} rows")

        # Show dataset statistics
        st.write("### Dataset Statistics")
        st.write(f"- Loaded rows: {len(search.dataset)}")
        st.write(f"- Available columns: {', '.join(search.dataset.columns)}")

        # Show sample data
        st.write("### Sample Data")
        st.dataframe(search.dataset.head())
    # Sidebar
    with st.sidebar:
        st.subheader("⚙️ Search History")
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            st.rerun()  # st.experimental_rerun() is deprecated in recent Streamlit

        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {result.get('title', result.get('description', 'No title'))[:100]}...")
if __name__ == "__main__":
    main()
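# To run locally (the filename app.py is an assumption; Hugging Face Spaces
# launches it automatically):
#   streamlit run app.py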