Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| from pathlib import Path | |
| from typing import Iterable, List | |
| import gradio as gr | |
| import kagglehub | |
| from gradio_logsview.logsview import Log, LogsView, LogsViewRunner | |
| from huggingface_hub import HfApi | |
# Kaggle credentials: reuse an existing ~/.kaggle/kaggle.json if present,
# otherwise materialize one from the KAGGLE_JSON secret (if set).
KAGGLE_JSON = os.environ.get("KAGGLE_JSON")
KAGGLE_JSON_PATH = Path("~/.kaggle/kaggle.json").expanduser().resolve()

if KAGGLE_JSON_PATH.exists():
    print(f"Found existing kaggle.json file at {KAGGLE_JSON_PATH}")
elif KAGGLE_JSON:  # an empty secret would only produce an unusable, empty credentials file
    print(
        "KAGGLE_JSON is set as secret. Will be able to be authenticated when downloading files from Kaggle."
    )
    KAGGLE_JSON_PATH.parent.mkdir(parents=True, exist_ok=True)
    # The file holds an API key: create it owner-read/write only (0o600)
    # *before* writing the secret, instead of relying on the process umask.
    KAGGLE_JSON_PATH.touch(mode=0o600, exist_ok=True)
    KAGGLE_JSON_PATH.write_text(KAGGLE_JSON)
else:
    print(
        f"No kaggle.json file found at {KAGGLE_JSON_PATH}. You will not be able to download private/gated files from Kaggle."
    )
# Markdown rendered at the top of the UI. Authentication to Hugging Face
# happens through the "Sign in with HF" button, not a pasted token.
MARKDOWN_DESCRIPTION = """
# Kaggle-importer GUI

The fastest way to import a model from KaggleHub to the Hugging Face Hub 🔥

Specify a Kaggle handle and sign in with your Hugging Face account to import a model from KaggleHub to the Hugging Face Hub.

To find the Kaggle handle from the Kaggle web UI, click on the "download dropdown" and copy the handle from the code snippet.
Example: `"keras/gemma/keras/gemma_instruct_2b_en"`.
"""

# Only advertise private/gated access when credentials are actually on disk.
if KAGGLE_JSON_PATH.exists():
    MARKDOWN_DESCRIPTION += """
**Note**: a `kaggle.json` file exists in the home directory. This means the Space will be able to download **SOME** private/gated files from Kaggle.
To access other models, please duplicate this Space to a private Space and set the `KAGGLE_JSON` environment variable with the content of the `kaggle.json`
you've downloaded from your Kaggle user account.
"""
def _delete_local_cache(runner: LogsViewRunner, cache_path: str) -> Iterable[List[Log]]:
    """Log and remove the locally downloaded Kaggle files.

    Best-effort: filesystem errors are ignored so that cleanup never masks the
    outcome of the import itself.
    """
    yield runner.log(f"Deleting local cache from {cache_path}.")
    shutil.rmtree(cache_path, ignore_errors=True)


def import_model(
    kaggle_model: str, repo_name: str, token: gr.OAuthToken | None
) -> Iterable[List[Log]]:
    """Import a KaggleHub model into a Hugging Face model repo, streaming logs.

    Steps: create (or reuse an empty) HF repo, download the model with
    ``kagglehub``, upload the downloaded folder, then delete the local copy.
    If the download or the upload fails, the HF repo is deleted again.

    Args:
        kaggle_model: Kaggle model handle, e.g.
            ``keras/gemma/keras/gemma_instruct_2b_en``.
        repo_name: Target HF repo name. When empty, the last segment of the
            Kaggle handle is used.
        token: OAuth token of the signed-in user. It is not listed in the
            click ``inputs``; Gradio supplies it from the login session based
            on this annotation, so the annotation must stay.

    Yields:
        Lists of ``Log`` entries produced by ``LogsViewRunner`` for display in
        the ``LogsView`` component.

    Raises:
        gr.Error: if no Kaggle handle is given or the user is not signed in.
    """
    runner = LogsViewRunner()

    # Textbox values may carry stray whitespace or a trailing slash; without
    # normalizing, "owner/model/" would infer an empty repo name.
    kaggle_model = (kaggle_model or "").strip().strip("/")
    repo_name = (repo_name or "").strip()

    if not kaggle_model:
        yield runner.log("Kaggle model is required.", level="ERROR")
        raise gr.Error("Kaggle model is required.")
    if not repo_name:
        repo_name = kaggle_model.split("/")[-1]
    if not token:
        yield runner.log("You must sign in with HF before proceeding.", level="ERROR")
        raise gr.Error("Authentication is required.")

    api = HfApi(token=token.token)
    yield runner.log(f"Creating HF repo {repo_name}")
    repo_url = api.create_repo(repo_name, exist_ok=True)
    yield runner.log(f"Created HF repo: {repo_url}")
    repo_id = repo_url.repo_id

    # More than one file means the repo already has content we would mix into.
    # (The single tolerated file is presumably the auto-created .gitattributes.)
    model_info = api.model_info(repo_id)
    if len(model_info.siblings) > 1:
        yield runner.log(
            f"Model repo {repo_id} is not empty. Please delete it or set a different repo name.",
            level="ERROR",
        )
        return

    # NOTE(review): because of exist_ok=True, the delete_repo calls below also
    # remove a pre-existing *empty* repo the user pointed at, not only one we
    # just created.
    yield runner.log(f"Downloading model {kaggle_model} from Kaggle.")
    yield from runner.run_python(kagglehub.model_download, handle=kaggle_model)
    if runner.exit_code != 0:
        yield runner.log("Failed to download model from Kaggle.", level="ERROR")
        api.delete_repo(repo_id=repo_id)
        return
    # run_python only streams logs, so call again to obtain the path; the files
    # are already cached at this point.
    cache_path = kagglehub.model_download(kaggle_model)
    yield runner.log(f"Model successfully downloaded from Kaggle to {cache_path}.")

    yield runner.log(f"Uploading model to HF repo {repo_id}.")
    yield from runner.run_python(
        api.upload_folder, repo_id=repo_id, folder_path=cache_path
    )
    upload_failed = runner.exit_code != 0
    if upload_failed:
        yield runner.log("Failed to upload model to HF repo.", level="ERROR")
        api.delete_repo(repo_id=repo_id)
    else:
        yield runner.log(f"Model successfully uploaded to HF: {repo_url}.")

    # Free disk space whether or not the upload succeeded, so a failed import
    # does not leave the whole checkpoint on the Space's disk.
    yield from _delete_local_cache(runner, cache_path)
    if upload_failed:
        return
    yield runner.log("Done!")
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN_DESCRIPTION)

    # Inputs on one row: required Kaggle handle, optional target repo name,
    # and the HF sign-in button that provides the OAuth token.
    with gr.Row():
        kaggle_model = gr.Textbox(
            label="Kaggle Model*",
            placeholder="keras/codegemma/keras/code_gemma_7b_en",
            lines=1,
        )
        repo_name = gr.Textbox(
            label="Repo name",
            placeholder="Optional. Will infer from Kaggle Model if empty.",
            lines=1,
        )
        gr.LoginButton(min_width=250)

    button = gr.Button("Import", variant="primary")
    logs = LogsView(label="Terminal output")

    # Only the two textboxes are passed explicitly; the OAuth token parameter
    # of import_model is filled in by Gradio from the login session.
    button.click(
        outputs=[logs],
        inputs=[kaggle_model, repo_name],
        fn=import_model,
    )

# Run imports one at a time (concurrency limit of 1), then start the server.
demo.queue(default_concurrency_limit=1)
demo.launch()