Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .github/workflows/check_version.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: Check Version

# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
on: # Trigger the workflow on push or pull request, but only for the main branch
push:
branches: [main]
pull_request:
branches: [main]
types: [opened, reopened, ready_for_review, synchronize]

defaults:
run:
shell: bash

jobs:
DeepTensor_Version:
runs-on: ubuntu-latest

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35

steps:
- uses: actions/checkout@v4
# we don't need to clone recursively.
# As we only need to check if two values are same.
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.11"

- name: Check if python version matches in pyproject.toml and src/deeptensor/__version__.py
run: |
pip install toml
python scripts/check_version.py
65 changes: 65 additions & 0 deletions scripts/check_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

import ast
import os
import sys
from pathlib import Path

import toml


def get_pyproject_version(pyproject_path):
try:
with Path.open(pyproject_path) as file:
pyproject_data = toml.load(file)
return pyproject_data.get("project", {}).get("version")
except Exception as e:
print(f"Error reading {pyproject_path}: {e}") # noqa: T201
return None


def get_version_file_version(version_file_path):
try:
with Path.open(version_file_path, "r") as file:
file_content = file.read()
# Parse the file and extract version
tree = ast.parse(file_content, filename=version_file_path)
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if (
isinstance(target, ast.Name)
and target.id == "version"
and isinstance(node.value, ast.Constant)
): # For Python 3.8+
return node.value.value
print(f"Version not found in {version_file_path}") # noqa: T201
return None
except Exception as e:
print(f"Error reading {version_file_path}: {e}") # noqa: T201
return None


def main():
pyproject_path = "pyproject.toml"
version_file_path = os.path.join("src", "deeptensor", "__version__.py") # noqa: PTH118

pyproject_version = get_pyproject_version(pyproject_path)
version_file_version = get_version_file_version(version_file_path)

if pyproject_version is None or version_file_version is None:
print("Error: Unable to fetch version(s).") # noqa: T201
sys.exit(1)

if pyproject_version == version_file_version:
print("Version check passed!") # noqa: T201
sys.exit(0)
else:
print( # noqa: T201
f"Version mismatch: pyproject.toml ({pyproject_version}) != __version__.py ({version_file_version})"
)
sys.exit(1)


if __name__ == "__main__":
main()
4 changes: 3 additions & 1 deletion src/deeptensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from ._core import (
from .__version__ import version
from ._core import ( # type: ignore # noqa: PGH003
SGD,
AdaGrad,
Adam,
Expand Down Expand Up @@ -40,4 +41,5 @@
"__doc__",
"cross_entropy",
"mean_squared_error",
"version",
]
3 changes: 3 additions & 0 deletions src/deeptensor/__version__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from __future__ import annotations

version = "0.3.0"
Loading