diff --git a/Frontend/package-lock.json b/Frontend/package-lock.json
index 5ac2281..c6e36bd 100644
--- a/Frontend/package-lock.json
+++ b/Frontend/package-lock.json
@@ -8,7 +8,6 @@
"name": "my-react-app",
"version": "0.1.0",
"dependencies": {
- "@mui/material": "^7.1.1",
"@testing-library/dom": "^10.4.0",
"@testing-library/jest-dom": "^6.6.3",
"@testing-library/react": "^16.2.0",
@@ -2356,68 +2355,6 @@
"postcss-selector-parser": "^6.0.10"
}
},
- "node_modules/@emotion/cache": {
- "version": "11.14.0",
- "resolved": "https://registry.npmjs.org/@emotion/cache/-/cache-11.14.0.tgz",
- "integrity": "sha512-L/B1lc/TViYk4DcpGxtAVbx0ZyiKM5ktoIyafGkH6zg/tj+mA+NE//aPYKG0k8kCHSHVJrpLpcAlOBEXQ3SavA==",
- "license": "MIT",
- "dependencies": {
- "@emotion/memoize": "^0.9.0",
- "@emotion/sheet": "^1.4.0",
- "@emotion/utils": "^1.4.2",
- "@emotion/weak-memoize": "^0.4.0",
- "stylis": "4.2.0"
- }
- },
- "node_modules/@emotion/hash": {
- "version": "0.9.2",
- "resolved": "https://registry.npmjs.org/@emotion/hash/-/hash-0.9.2.tgz",
- "integrity": "sha512-MyqliTZGuOm3+5ZRSaaBGP3USLw6+EGykkwZns2EPC5g8jJ4z9OrdZY9apkl3+UP9+sdz76YYkwCKP5gh8iY3g==",
- "license": "MIT"
- },
- "node_modules/@emotion/memoize": {
- "version": "0.9.0",
- "resolved": "https://registry.npmjs.org/@emotion/memoize/-/memoize-0.9.0.tgz",
- "integrity": "sha512-30FAj7/EoJ5mwVPOWhAyCX+FPfMDrVecJAM+Iw9NRoSl4BBAQeqj4cApHHUXOVvIPgLVDsCFoz/hGD+5QQD1GQ==",
- "license": "MIT"
- },
- "node_modules/@emotion/serialize": {
- "version": "1.3.3",
- "resolved": "https://registry.npmjs.org/@emotion/serialize/-/serialize-1.3.3.tgz",
- "integrity": "sha512-EISGqt7sSNWHGI76hC7x1CksiXPahbxEOrC5RjmFRJTqLyEK9/9hZvBbiYn70dw4wuwMKiEMCUlR6ZXTSWQqxA==",
- "license": "MIT",
- "dependencies": {
- "@emotion/hash": "^0.9.2",
- "@emotion/memoize": "^0.9.0",
- "@emotion/unitless": "^0.10.0",
- "@emotion/utils": "^1.4.2",
- "csstype": "^3.0.2"
- }
- },
- "node_modules/@emotion/sheet": {
- "version": "1.4.0",
- "resolved": "https://registry.npmjs.org/@emotion/sheet/-/sheet-1.4.0.tgz",
- "integrity": "sha512-fTBW9/8r2w3dXWYM4HCB1Rdp8NLibOw2+XELH5m5+AkWiL/KqYX6dc0kKYlaYyKjrQ6ds33MCdMPEwgs2z1rqg==",
- "license": "MIT"
- },
- "node_modules/@emotion/unitless": {
- "version": "0.10.0",
- "resolved": "https://registry.npmjs.org/@emotion/unitless/-/unitless-0.10.0.tgz",
- "integrity": "sha512-dFoMUuQA20zvtVTuxZww6OHoJYgrzfKM1t52mVySDJnMSEa08ruEvdYQbhvyu6soU+NeLVd3yKfTfT0NeV6qGg==",
- "license": "MIT"
- },
- "node_modules/@emotion/utils": {
- "version": "1.4.2",
- "resolved": "https://registry.npmjs.org/@emotion/utils/-/utils-1.4.2.tgz",
- "integrity": "sha512-3vLclRofFziIa3J2wDh9jjbkUz9qk5Vi3IZ/FSTKViB0k+ef0fPV7dYrUIugbgupYDx7v9ud/SjrtEP8Y4xLoA==",
- "license": "MIT"
- },
- "node_modules/@emotion/weak-memoize": {
- "version": "0.4.0",
- "resolved": "https://registry.npmjs.org/@emotion/weak-memoize/-/weak-memoize-0.4.0.tgz",
- "integrity": "sha512-snKqtPW01tN0ui7yu9rGv69aJXr/a/Ywvl11sUjNtEcRc+ng/mQriFL0wLXMef74iHa/EkftbDzU9F8iFbH+zg==",
- "license": "MIT"
- },
"node_modules/@eslint-community/eslint-utils": {
"version": "4.7.0",
"resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz",
@@ -3030,213 +2967,6 @@
"integrity": "sha512-Vo+PSpZG2/fmgmiNzYK9qWRh8h/CHrwD0mo1h1DzL4yzHNSfWYujGTYsWGreD000gcgmZ7K4Ys6Tx9TxtsKdDw==",
"license": "MIT"
},
- "node_modules/@mui/core-downloads-tracker": {
- "version": "7.1.1",
- "resolved": "https://registry.npmjs.org/@mui/core-downloads-tracker/-/core-downloads-tracker-7.1.1.tgz",
- "integrity": "sha512-yBckQs4aQ8mqukLnPC6ivIRv6guhaXi8snVl00VtyojBbm+l6VbVhyTSZ68Abcx7Ah8B+GZhrB7BOli+e+9LkQ==",
- "license": "MIT",
- "funding": {
- "type": "opencollective",
- "url": "https://opencollective.com/mui-org"
- }
- },
- "node_modules/@mui/material": {
- "version": "7.1.1",
- "resolved": "https://registry.npmjs.org/@mui/material/-/material-7.1.1.tgz",
- "integrity": "sha512-mTpdmdZCaHCGOH3SrYM41+XKvNL0iQfM9KlYgpSjgadXx/fEKhhvOktxm8++Xw6FFeOHoOiV+lzOI8X1rsv71A==",
- "license": "MIT",
- "dependencies": {
- "@babel/runtime": "^7.27.1",
- "@mui/core-downloads-tracker": "^7.1.1",
- "@mui/system": "^7.1.1",
- "@mui/types": "^7.4.3",
- "@mui/utils": "^7.1.1",
- "@popperjs/core": "^2.11.8",
- "@types/react-transition-group": "^4.4.12",
- "clsx": "^2.1.1",
- "csstype": "^3.1.3",
- "prop-types": "^15.8.1",
- "react-is": "^19.1.0",
- "react-transition-group": "^4.4.5"
- },
- "engines": {
- "node": ">=14.0.0"
- },
- "funding": {
- "type": "opencollective",
- "url": "https://opencollective.com/mui-org"
- },
- "peerDependencies": {
- "@emotion/react": "^11.5.0",
- "@emotion/styled": "^11.3.0",
- "@mui/material-pigment-css": "^7.1.1",
- "@types/react": "^17.0.0 || ^18.0.0 || ^19.0.0",
- "react": "^17.0.0 || ^18.0.0 || ^19.0.0",
- "react-dom": "^17.0.0 || ^18.0.0 || ^19.0.0"
- },
- "peerDependenciesMeta": {
- "@emotion/react": {
- "optional": true
- },
- "@emotion/styled": {
- "optional": true
- },
- "@mui/material-pigment-css": {
- "optional": true
- },
- "@types/react": {
- "optional": true
- }
- }
- },
- "node_modules/@mui/private-theming": {
- "version": "7.1.1",
- "resolved": "https://registry.npmjs.org/@mui/private-theming/-/private-theming-7.1.1.tgz",
- "integrity": "sha512-M8NbLUx+armk2ZuaxBkkMk11ultnWmrPlN0Xe3jUEaBChg/mcxa5HWIWS1EE4DF36WRACaAHVAvyekWlDQf0PQ==",
- "license": "MIT",
- "dependencies": {
- "@babel/runtime": "^7.27.1",
- "@mui/utils": "^7.1.1",
- "prop-types": "^15.8.1"
- },
- "engines": {
- "node": ">=14.0.0"
- },
- "funding": {
- "type": "opencollective",
- "url": "https://opencollective.com/mui-org"
- },
- "peerDependencies": {
- "@types/react": "^17.0.0 || ^18.0.0 || ^19.0.0",
- "react": "^17.0.0 || ^18.0.0 || ^19.0.0"
- },
- "peerDependenciesMeta": {
- "@types/react": {
- "optional": true
- }
- }
- },
- "node_modules/@mui/styled-engine": {
- "version": "7.1.1",
- "resolved": "https://registry.npmjs.org/@mui/styled-engine/-/styled-engine-7.1.1.tgz",
- "integrity": "sha512-R2wpzmSN127j26HrCPYVQ53vvMcT5DaKLoWkrfwUYq3cYytL6TQrCH8JBH3z79B6g4nMZZVoaXrxO757AlShaw==",
- "license": "MIT",
- "dependencies": {
- "@babel/runtime": "^7.27.1",
- "@emotion/cache": "^11.13.5",
- "@emotion/serialize": "^1.3.3",
- "@emotion/sheet": "^1.4.0",
- "csstype": "^3.1.3",
- "prop-types": "^15.8.1"
- },
- "engines": {
- "node": ">=14.0.0"
- },
- "funding": {
- "type": "opencollective",
- "url": "https://opencollective.com/mui-org"
- },
- "peerDependencies": {
- "@emotion/react": "^11.4.1",
- "@emotion/styled": "^11.3.0",
- "react": "^17.0.0 || ^18.0.0 || ^19.0.0"
- },
- "peerDependenciesMeta": {
- "@emotion/react": {
- "optional": true
- },
- "@emotion/styled": {
- "optional": true
- }
- }
- },
- "node_modules/@mui/system": {
- "version": "7.1.1",
- "resolved": "https://registry.npmjs.org/@mui/system/-/system-7.1.1.tgz",
- "integrity": "sha512-Kj1uhiqnj4Zo7PDjAOghtXJtNABunWvhcRU0O7RQJ7WOxeynoH6wXPcilphV8QTFtkKaip8EiNJRiCD+B3eROA==",
- "license": "MIT",
- "dependencies": {
- "@babel/runtime": "^7.27.1",
- "@mui/private-theming": "^7.1.1",
- "@mui/styled-engine": "^7.1.1",
- "@mui/types": "^7.4.3",
- "@mui/utils": "^7.1.1",
- "clsx": "^2.1.1",
- "csstype": "^3.1.3",
- "prop-types": "^15.8.1"
- },
- "engines": {
- "node": ">=14.0.0"
- },
- "funding": {
- "type": "opencollective",
- "url": "https://opencollective.com/mui-org"
- },
- "peerDependencies": {
- "@emotion/react": "^11.5.0",
- "@emotion/styled": "^11.3.0",
- "@types/react": "^17.0.0 || ^18.0.0 || ^19.0.0",
- "react": "^17.0.0 || ^18.0.0 || ^19.0.0"
- },
- "peerDependenciesMeta": {
- "@emotion/react": {
- "optional": true
- },
- "@emotion/styled": {
- "optional": true
- },
- "@types/react": {
- "optional": true
- }
- }
- },
- "node_modules/@mui/types": {
- "version": "7.4.3",
- "resolved": "https://registry.npmjs.org/@mui/types/-/types-7.4.3.tgz",
- "integrity": "sha512-2UCEiK29vtiZTeLdS2d4GndBKacVyxGvReznGXGr+CzW/YhjIX+OHUdCIczZjzcRAgKBGmE9zCIgoV9FleuyRQ==",
- "license": "MIT",
- "dependencies": {
- "@babel/runtime": "^7.27.1"
- },
- "peerDependencies": {
- "@types/react": "^17.0.0 || ^18.0.0 || ^19.0.0"
- },
- "peerDependenciesMeta": {
- "@types/react": {
- "optional": true
- }
- }
- },
- "node_modules/@mui/utils": {
- "version": "7.1.1",
- "resolved": "https://registry.npmjs.org/@mui/utils/-/utils-7.1.1.tgz",
- "integrity": "sha512-BkOt2q7MBYl7pweY2JWwfrlahhp+uGLR8S+EhiyRaofeRYUWL2YKbSGQvN4hgSN1i8poN0PaUiii1kEMrchvzg==",
- "license": "MIT",
- "dependencies": {
- "@babel/runtime": "^7.27.1",
- "@mui/types": "^7.4.3",
- "@types/prop-types": "^15.7.14",
- "clsx": "^2.1.1",
- "prop-types": "^15.8.1",
- "react-is": "^19.1.0"
- },
- "engines": {
- "node": ">=14.0.0"
- },
- "funding": {
- "type": "opencollective",
- "url": "https://opencollective.com/mui-org"
- },
- "peerDependencies": {
- "@types/react": "^17.0.0 || ^18.0.0 || ^19.0.0",
- "react": "^17.0.0 || ^18.0.0 || ^19.0.0"
- },
- "peerDependenciesMeta": {
- "@types/react": {
- "optional": true
- }
- }
- },
"node_modules/@nicolo-ribaudo/eslint-scope-5-internals": {
"version": "5.1.1-v1",
"resolved": "https://registry.npmjs.org/@nicolo-ribaudo/eslint-scope-5-internals/-/eslint-scope-5-internals-5.1.1-v1.tgz",
@@ -3361,16 +3091,6 @@
}
}
},
- "node_modules/@popperjs/core": {
- "version": "2.11.8",
- "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.8.tgz",
- "integrity": "sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==",
- "license": "MIT",
- "funding": {
- "type": "opencollective",
- "url": "https://opencollective.com/popperjs"
- }
- },
"node_modules/@rollup/plugin-babel": {
"version": "5.3.1",
"resolved": "https://registry.npmjs.org/@rollup/plugin-babel/-/plugin-babel-5.3.1.tgz",
@@ -4087,12 +3807,6 @@
"integrity": "sha512-+68kP9yzs4LMp7VNh8gdzMSPZFL44MLGqiHWvttYJe+6qnuVr4Ek9wSBQoveqY/r+LwjCcU29kNVkidwim+kYA==",
"license": "MIT"
},
- "node_modules/@types/prop-types": {
- "version": "15.7.15",
- "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.15.tgz",
- "integrity": "sha512-F6bEyamV9jKGAFBEmlQnesRPGOQqS2+Uwi0Em15xenOxHaf2hv6L8YCVn3rPdPJOiJfPiCnLIRyvwVaqMY3MIw==",
- "license": "MIT"
- },
"node_modules/@types/q": {
"version": "1.5.8",
"resolved": "https://registry.npmjs.org/@types/q/-/q-1.5.8.tgz",
@@ -4116,20 +3830,12 @@
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.1.8.tgz",
"integrity": "sha512-AwAfQ2Wa5bCx9WP8nZL2uMZWod7J7/JSplxbTmBQ5ms6QpqNYm672H0Vu9ZVKVngQ+ii4R/byguVEUZQyeg44g==",
"license": "MIT",
+ "optional": true,
"peer": true,
"dependencies": {
"csstype": "^3.0.2"
}
},
- "node_modules/@types/react-transition-group": {
- "version": "4.4.12",
- "resolved": "https://registry.npmjs.org/@types/react-transition-group/-/react-transition-group-4.4.12.tgz",
- "integrity": "sha512-8TV6R3h2j7a91c+1DXdJi3Syo69zzIZbz7Lg5tORM5LEJG7X/E6a1V3drRyBRZq7/utz7A+c4OgYLiLcYGHG6w==",
- "license": "MIT",
- "peerDependencies": {
- "@types/react": "*"
- }
- },
"node_modules/@types/resolve": {
"version": "1.17.1",
"resolved": "https://registry.npmjs.org/@types/resolve/-/resolve-1.17.1.tgz",
@@ -5934,15 +5640,6 @@
"wrap-ansi": "^7.0.0"
}
},
- "node_modules/clsx": {
- "version": "2.1.1",
- "resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz",
- "integrity": "sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==",
- "license": "MIT",
- "engines": {
- "node": ">=6"
- }
- },
"node_modules/co": {
"version": "4.6.0",
"resolved": "https://registry.npmjs.org/co/-/co-4.6.0.tgz",
@@ -6677,7 +6374,9 @@
"version": "3.1.3",
"resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz",
"integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==",
- "license": "MIT"
+ "license": "MIT",
+ "optional": true,
+ "peer": true
},
"node_modules/damerau-levenshtein": {
"version": "1.0.8",
@@ -7005,16 +6704,6 @@
"utila": "~0.4"
}
},
- "node_modules/dom-helpers": {
- "version": "5.2.1",
- "resolved": "https://registry.npmjs.org/dom-helpers/-/dom-helpers-5.2.1.tgz",
- "integrity": "sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==",
- "license": "MIT",
- "dependencies": {
- "@babel/runtime": "^7.8.7",
- "csstype": "^3.0.2"
- }
- },
"node_modules/dom-serializer": {
"version": "1.4.1",
"resolved": "https://registry.npmjs.org/dom-serializer/-/dom-serializer-1.4.1.tgz",
@@ -14268,12 +13957,6 @@
"integrity": "sha512-SN/U6Ytxf1QGkw/9ve5Y+NxBbZM6Ht95tuXNMKs8EJyFa/Vy/+Co3stop3KBHARfn/giv+Lj1uUnTfOJ3moFEQ==",
"license": "MIT"
},
- "node_modules/react-is": {
- "version": "19.1.0",
- "resolved": "https://registry.npmjs.org/react-is/-/react-is-19.1.0.tgz",
- "integrity": "sha512-Oe56aUPnkHyyDxxkvqtd7KkdQP5uIUfHxd5XTb3wE9d/kRnZLmKbDB0GWk919tdQ+mxxPtG6EAs6RMT6i1qtHg==",
- "license": "MIT"
- },
"node_modules/react-refresh": {
"version": "0.11.0",
"resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz",
@@ -14394,22 +14077,6 @@
}
}
},
- "node_modules/react-transition-group": {
- "version": "4.4.5",
- "resolved": "https://registry.npmjs.org/react-transition-group/-/react-transition-group-4.4.5.tgz",
- "integrity": "sha512-pZcd1MCJoiKiBR2NRxeCRg13uCXbydPnmB4EOeRrY7480qNWO8IIgQG6zlDkm6uRMsURXPuKq0GWtiM59a5Q6g==",
- "license": "BSD-3-Clause",
- "dependencies": {
- "@babel/runtime": "^7.5.5",
- "dom-helpers": "^5.0.1",
- "loose-envify": "^1.4.0",
- "prop-types": "^15.6.2"
- },
- "peerDependencies": {
- "react": ">=16.6.0",
- "react-dom": ">=16.6.0"
- }
- },
"node_modules/read-cache": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz",
@@ -15999,12 +15666,6 @@
"postcss": "^8.2.15"
}
},
- "node_modules/stylis": {
- "version": "4.2.0",
- "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.2.0.tgz",
- "integrity": "sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==",
- "license": "MIT"
- },
"node_modules/sucrase": {
"version": "3.35.0",
"resolved": "https://registry.npmjs.org/sucrase/-/sucrase-3.35.0.tgz",
@@ -16776,9 +16437,9 @@
}
},
"node_modules/typescript": {
- "version": "5.8.3",
- "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.8.3.tgz",
- "integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==",
+ "version": "4.9.5",
+ "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz",
+ "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==",
"license": "Apache-2.0",
"peer": true,
"bin": {
@@ -16786,7 +16447,7 @@
"tsserver": "bin/tsserver"
},
"engines": {
- "node": ">=14.17"
+ "node": ">=4.2.0"
}
},
"node_modules/unbox-primitive": {
diff --git a/Frontend/src/App.js b/Frontend/src/App.js
index d7e6b33..b227e82 100644
--- a/Frontend/src/App.js
+++ b/Frontend/src/App.js
@@ -3,12 +3,11 @@ import React, { useState, useEffect } from 'react';
import { BrowserRouter as Router, Routes, Route, useNavigate } from 'react-router-dom';
import Navbar from './components/Navbar';
import LandingPage from './pages/AppLanding';
+import ProviderSelectionPage from './pages/ProviderSelectionPage';
import DetectHardwarePage from './pages/DetectHardwarePage';
import FinetuneSettings from './pages/FinetuningSettingsPage';
import Loading from './pages/Loading';
import TechnicalDetailsPage from './pages/TechnicalDetailsPage';
-import ListModels from './pages/ListModels';
-import './index.css';
import ListAllModels from "./pages/ListAllModels";
const RedirectToFastAPI = () => {
@@ -21,6 +20,7 @@ const RedirectToFastAPI = () => {
function App() {
const [finetuneSettings, setFinetuneSettings] = useState({
+ provider: 'huggingface', // Default provider
task: 'text-generation',
model_name: 'llama2-7b',
compute_specs: 'Standard GPU',
@@ -88,10 +88,10 @@ function App() {
} />
}
@@ -105,6 +105,15 @@ function App() {
/>
}
/>
+
+          {/* Provider selection route: the added JSX was garbled in this diff, so the
+              path and prop wiring below are reconstructed assumptions. */}
+          <Route
+            path="/finetune/provider"
+            element={
+              <ProviderSelectionPage
+                currentSettings={finetuneSettings}
+                updateSettings={(updates) =>
+                  setFinetuneSettings((prev) => ({ ...prev, ...updates }))
+                }
+              />
+            }
+          />
{/* }
diff --git a/Frontend/src/pages/AppLanding.jsx b/Frontend/src/pages/AppLanding.jsx
index 30ed26b..f181b8d 100644
--- a/Frontend/src/pages/AppLanding.jsx
+++ b/Frontend/src/pages/AppLanding.jsx
@@ -89,7 +89,7 @@ const LandingPage = ({ appName = "ModelForge" }) => {
-
+
Start Building Your AI
+
+
+
+ {formState.provider || defaultValues.provider || 'huggingface'}
+
+
diff --git a/Frontend/src/pages/ProviderSelectionPage.jsx b/Frontend/src/pages/ProviderSelectionPage.jsx
new file mode 100644
index 0000000..d084463
--- /dev/null
+++ b/Frontend/src/pages/ProviderSelectionPage.jsx
@@ -0,0 +1,251 @@
+import React, { useState, useEffect } from 'react';
+import { useNavigate } from 'react-router-dom';
+import { config } from '../services/api';
+
+const ProviderSelectionPage = ({ currentSettings, updateSettings }) => {
+ const navigate = useNavigate();
+ const [providers, setProviders] = useState([]);
+ const [loadingProviders, setLoadingProviders] = useState(true);
+ const [selectedProvider, setSelectedProvider] = useState(currentSettings.provider || 'huggingface');
+ const [hoveredProvider, setHoveredProvider] = useState(null);
+
+ useEffect(() => {
+ const fetchProviders = async () => {
+ try {
+ const response = await fetch(`${config.baseURL}/finetune/providers`);
+ if (!response.ok) throw new Error('Failed to fetch providers');
+
+ const data = await response.json();
+ console.log("Fetched providers:", data);
+ setProviders(data.providers || []);
+ setLoadingProviders(false);
+ } catch (err) {
+ console.error("Error fetching providers:", err);
+ setLoadingProviders(false);
+ }
+ };
+
+ fetchProviders();
+ }, []);
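+
+  // Expected response shape (illustrative; mirrors the backend /finetune/providers route):
+  //   { providers: [{ name: "huggingface", description: "...", available: true }, ...],
+  //     available: [...], default: "huggingface" }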
+
+ const handleProviderSelect = (providerName) => {
+ setSelectedProvider(providerName);
+ };
+
+ const handleContinue = () => {
+ // Update settings with selected provider
+ if (updateSettings) {
+ updateSettings({ provider: selectedProvider });
+ }
+
+ // Navigate to task selection (hardware detection page)
+ navigate('/finetune/detect');
+ };
+
+ const providerDetails = {
+ huggingface: {
+ logo: 'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg',
+ features: [
+ 'Standard HuggingFace transformers',
+ 'PEFT/LoRA fine-tuning',
+ '4-bit/8-bit quantization',
+ 'Maximum compatibility',
+ 'Established ecosystem'
+ ],
+ performance: 'Baseline speed & memory',
+ bestFor: 'General use, maximum compatibility'
+ },
+ unsloth: {
+ logo: 'https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png',
+ features: [
+ 'Optimized FastLanguageModel',
+ 'Enhanced LoRA/QLoRA',
+ 'Memory-efficient attention',
+ '~2x faster training',
+ '~30% less memory usage'
+ ],
+ performance: '2x faster, 30% less memory',
+ bestFor: 'Faster training, limited hardware',
+ installCommand: 'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"'
+ }
+ };
+
+  // NOTE: The JSX below was garbled in this diff (element tags stripped). The markup,
+  // most class names, and the Continue button label are reconstructed assumptions; the
+  // text content, expressions, and surviving className strings are kept from the original.
+  return (
+    <div className="min-h-screen bg-gray-900 text-white">
+      <div className="max-w-5xl mx-auto px-6 py-12">
+        {/* Header */}
+        <div className="text-center mb-10">
+          <h1 className="text-4xl font-bold">
+            Choose Your Fine-tuning Provider
+          </h1>
+          <p className="mt-2 text-gray-400">
+            Select the backend that best suits your needs
+          </p>
+        </div>
+
+        {/* Loading State */}
+        {loadingProviders && (
+          <div className="flex justify-center py-12">
+            <div className="h-10 w-10 animate-spin rounded-full border-b-2 border-orange-500" />
+          </div>
+        )}
+
+        {/* Provider Cards */}
+        {!loadingProviders && (
+          <div className="grid grid-cols-1 gap-6 md:grid-cols-2">
+            {providers.map((provider) => {
+              const isSelected = selectedProvider === provider.name;
+              const isAvailable = provider.available;
+              const details = providerDetails[provider.name] || {};
+
+              return (
+                <div
+                  key={provider.name}
+                  onClick={() => isAvailable && handleProviderSelect(provider.name)}
+                  onMouseEnter={() => setHoveredProvider(provider.name)}
+                  onMouseLeave={() => setHoveredProvider(null)}
+                  className={`
+                    relative rounded-xl p-6 border-2 transition-all duration-300 transform
+                    ${isAvailable ? 'cursor-pointer hover:scale-105' : 'cursor-not-allowed opacity-60'}
+                    ${isSelected && isAvailable
+                      ? 'border-orange-500 bg-gradient-to-br from-orange-500/20 to-transparent shadow-xl shadow-orange-500/20'
+                      : 'border-gray-700 bg-gray-800/50 hover:border-gray-600'
+                    }
+                  `}
+                >
+                  {/* Selected Badge */}
+                  {isSelected && isAvailable && (
+                    <span className="absolute top-3 right-3 rounded-full bg-orange-500 px-2 py-1 text-xs">
+                      Selected
+                    </span>
+                  )}
+
+                  {/* Not Installed Badge */}
+                  {!isAvailable && (
+                    <span className="absolute top-3 right-3 rounded-full bg-gray-600 px-2 py-1 text-xs">
+                      Not Installed
+                    </span>
+                  )}
+
+                  {/* Provider Header */}
+                  <div className="mb-4 flex items-center gap-4">
+                    <div className="h-12 w-12 flex-shrink-0">
+                      {details.logo ? (
+                        <img
+                          src={details.logo}
+                          alt={`${provider.name} logo`}
+                          className="h-full w-full object-contain"
+                          onError={(e) => {
+                            e.target.style.display = 'none';
+                            e.target.nextSibling.style.display = 'block';
+                          }}
+                        />
+                      ) : null}
+                      {/* Fallback shown when the logo fails to load */}
+                      <span style={{ display: 'none' }}>{provider.name}</span>
+                    </div>
+                    <div>
+                      <h2 className="text-xl font-semibold capitalize">{provider.name}</h2>
+                      <p className="text-sm text-gray-400">{provider.description}</p>
+                    </div>
+                  </div>
+
+                  {/* Performance Badge */}
+                  {details.performance && (
+                    <div className="mb-4 inline-block rounded-lg bg-gray-700/60 px-3 py-1 text-sm">
+                      ⚡ {details.performance}
+                    </div>
+                  )}
+
+                  {/* Features List */}
+                  {details.features && (
+                    <div className="mb-4">
+                      <h3 className="mb-2 text-sm font-semibold">Features:</h3>
+                      <ul className="space-y-1">
+                        {details.features.map((feature, idx) => (
+                          <li key={idx} className="flex items-start text-sm text-gray-300">
+                            <span className="mr-2">•</span>
+                            {feature}
+                          </li>
+                        ))}
+                      </ul>
+                    </div>
+                  )}
+
+                  {/* Best For */}
+                  {details.bestFor && (
+                    <p className="text-sm text-gray-400">
+                      <strong>Best for:</strong> {details.bestFor}
+                    </p>
+                  )}
+
+                  {/* Installation Instructions */}
+                  {!isAvailable && details.installCommand && (
+                    <div className="mt-4 rounded-lg bg-gray-900/80 p-3">
+                      <p className="mb-1 text-sm text-yellow-400">⚠️ Installation Required:</p>
+                      <code className="break-all text-xs">
+                        {details.installCommand}
+                      </code>
+                    </div>
+                  )}
+                </div>
+              );
+            })}
+          </div>
+        )}
+
+        {/* Continue Button */}
+        <div className="mt-10 flex justify-center">
+          <button
+            onClick={handleContinue}
+            className="rounded-lg bg-orange-500 px-8 py-3 font-semibold hover:bg-orange-600"
+          >
+            Continue
+          </button>
+        </div>
+
+        {/* Info Footer */}
+        <div className="mt-6 text-center">
+          <p className="text-sm text-gray-500">
+            You can change the provider later in the advanced settings
+          </p>
+        </div>
+      </div>
+    </div>
+  );
+};
+
+export default ProviderSelectionPage;
diff --git a/ModelForge/routers/finetuning_router.py b/ModelForge/routers/finetuning_router.py
index 4161b2f..6cd6eac 100644
--- a/ModelForge/routers/finetuning_router.py
+++ b/ModelForge/routers/finetuning_router.py
@@ -13,6 +13,8 @@
from ..utilities.finetuning.CausalLLMTuner import CausalLLMFinetuner
from ..utilities.finetuning.QuestionAnsweringTuner import QuestionAnsweringTuner
from ..utilities.finetuning.Seq2SeqLMTuner import Seq2SeqFinetuner
+from ..utilities.finetuning.provider_adapter import ProviderFinetuner
+from ..utilities.finetuning.providers import ProviderRegistry
from ..utilities.hardware_detection.model_validator import ModelValidator
from ..globals.globals_instance import global_manager
@@ -25,6 +27,12 @@
VALID_TASKS = ["text-generation", "summarization", "extractive-question-answering"]
VALID_TASKS_STR = "'text-generation', 'summarization', or 'extractive-question-answering'"
+# Valid providers (dynamically loaded from registry)
+def get_valid_providers():
+ """Get list of available provider names."""
+ providers = ProviderRegistry.list_available()
+ return [p["name"] for p in providers]
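+# Example (illustrative): with both backends installed this returns
+# ["huggingface", "unsloth"]; with only the default stack it returns ["huggingface"].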
+
## Pydantic Data Validator Classes
class TaskFormData(BaseModel):
task: str
@@ -53,6 +61,7 @@ def validate_repo_name(cls, repo_name):
class SettingsFormData(BaseModel):
task: str
model_name: str
+ provider: str = "huggingface" # Default to huggingface for backward compatibility
num_train_epochs: int
compute_specs: str
lora_r: int
@@ -96,6 +105,15 @@ def validate_model_name(cls, model_name):
if not model_name:
raise ValueError("Model name cannot be empty.")
return model_name
+ @field_validator("provider")
+ def validate_provider(cls, provider):
+ valid_providers = get_valid_providers()
+ if not valid_providers:
+ # If no providers available, default to huggingface
+ return "huggingface"
+ if provider not in valid_providers:
+ raise ValueError(f"Invalid provider. Must be one of {', '.join(valid_providers)}.")
+ return provider
@field_validator("num_train_epochs")
def validate_num_train_epochs(cls, num_train_epochs):
if num_train_epochs <= 0:
@@ -248,6 +266,31 @@ def validate_batch_size_with_compute_specs(self):
return self
+@router.get("/providers")
+async def list_providers(request: Request) -> JSONResponse:
+ """
+ List all available fine-tuning providers.
+
+ Returns provider information including availability status.
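+
+    Example response body (illustrative; values depend on installed backends):
+        {
+            "status_code": 200,
+            "providers": [
+                {"name": "huggingface", "description": "...", "available": true},
+                {"name": "unsloth", "description": "...", "available": false}
+            ],
+            "available": [{"name": "huggingface", "description": "..."}],
+            "default": "huggingface"
+        }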
+ """
+ try:
+ all_providers = ProviderRegistry.list_all()
+ available_providers = ProviderRegistry.list_available()
+
+ return JSONResponse({
+ "status_code": 200,
+ "providers": all_providers,
+ "available": available_providers,
+ "default": "huggingface"
+ })
+ except Exception as e:
+ print(f"Error listing providers: {e}")
+ raise HTTPException(
+ status_code=500,
+ detail="Error retrieving provider information"
+ )
+
+
@router.get("/detect")
async def detect_hardware_page(request: Request) -> JSONResponse:
global_manager.clear_settings_cache() # Clear the cache to ensure fresh detection
@@ -451,15 +494,20 @@ def finetuning_task(llm_tuner) -> None:
# Use the path returned from finetune (should be absolute)
model_path = os.path.abspath(path) if not os.path.isabs(path) else path
+    # Get the provider name (set on ProviderFinetuner instances; legacy tuners fall back to "huggingface")
+ provider_name = getattr(llm_tuner, 'provider_name', 'huggingface')
+
model_data = {
"model_name": global_manager.settings_builder.fine_tuned_name.split('/')[-1] if global_manager.settings_builder.fine_tuned_name else os.path.basename(model_path),
"base_model": global_manager.settings_builder.model_name,
"task": global_manager.settings_builder.task,
"description": f"Fine-tuned {global_manager.settings_builder.model_name} for {global_manager.settings_builder.task}" +
- (" (Custom Model)" if global_manager.settings_builder.is_custom_model else " (Recommended Model)"),
+ (" (Custom Model)" if global_manager.settings_builder.is_custom_model else " (Recommended Model)") +
+ f" using {provider_name}",
"creation_date": datetime.now().isoformat(),
"model_path": model_path,
- "is_custom_base_model": global_manager.settings_builder.is_custom_model
+ "is_custom_base_model": global_manager.settings_builder.is_custom_model,
+ "provider": provider_name
}
global_manager.db_manager.add_model(model_data)
@@ -501,6 +549,10 @@ async def start_finetuning_page(request: Request, background_task: BackgroundTas
print(f"Starting finetuning with CUSTOM MODEL: {global_manager.settings_builder.model_name}")
else:
print(f"Starting finetuning with RECOMMENDED MODEL: {global_manager.settings_builder.model_name}")
+
+ # Log provider selection
+ provider = global_manager.settings_builder.provider
+ print(f"Using provider: {provider}")
if not global_manager.settings_cache:
raise HTTPException(
@@ -528,29 +580,43 @@ async def start_finetuning_page(request: Request, background_task: BackgroundTas
global_manager.finetuning_status["status"] = "initializing"
global_manager.finetuning_status["message"] = "Starting finetuning process..."
- if global_manager.settings_builder.task == "text-generation":
- llm_tuner = CausalLLMFinetuner(
- model_name=global_manager.settings_builder.model_name,
- compute_specs=global_manager.settings_builder.compute_profile,
- pipeline_task="text-generation"
- )
- elif global_manager.settings_builder.task == "summarization":
- llm_tuner = Seq2SeqFinetuner(
- model_name=global_manager.settings_builder.model_name,
- compute_specs=global_manager.settings_builder.compute_profile,
- pipeline_task="summarization"
- )
- elif global_manager.settings_builder.task == "extractive-question-answering":
- llm_tuner = QuestionAnsweringTuner(
+
+    # Use the provider adapter only when a non-HuggingFace provider is selected;
+    # otherwise fall back to the legacy tuner classes for backward compatibility.
+ if provider and provider != "huggingface":
+ # Use provider adapter for non-HuggingFace providers
+ llm_tuner = ProviderFinetuner(
model_name=global_manager.settings_builder.model_name,
- compute_specs=global_manager.settings_builder.compute_profile,
- pipeline_task="question-answering"
+ task=global_manager.settings_builder.task,
+ provider=provider,
+ compute_specs=global_manager.settings_builder.compute_profile
)
else:
- raise HTTPException(
- status_code=400,
- detail=f"Invalid task. Must be one of {VALID_TASKS_STR}."
- )
+ # Legacy path - use existing tuner classes for HuggingFace
+ # This maintains backward compatibility
+ if global_manager.settings_builder.task == "text-generation":
+ llm_tuner = CausalLLMFinetuner(
+ model_name=global_manager.settings_builder.model_name,
+ compute_specs=global_manager.settings_builder.compute_profile,
+ pipeline_task="text-generation"
+ )
+ elif global_manager.settings_builder.task == "summarization":
+ llm_tuner = Seq2SeqFinetuner(
+ model_name=global_manager.settings_builder.model_name,
+ compute_specs=global_manager.settings_builder.compute_profile,
+ pipeline_task="summarization"
+ )
+ elif global_manager.settings_builder.task == "extractive-question-answering":
+ llm_tuner = QuestionAnsweringTuner(
+ model_name=global_manager.settings_builder.model_name,
+ compute_specs=global_manager.settings_builder.compute_profile,
+ pipeline_task="question-answering"
+ )
+ else:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid task. Must be one of {VALID_TASKS_STR}."
+ )
llm_tuner.set_settings(**global_manager.settings_builder.get_settings())
diff --git a/ModelForge/utilities/finetuning/provider_adapter.py b/ModelForge/utilities/finetuning/provider_adapter.py
new file mode 100644
index 0000000..c2157c7
--- /dev/null
+++ b/ModelForge/utilities/finetuning/provider_adapter.py
@@ -0,0 +1,205 @@
+"""
+Adapter for integrating provider-based finetuning with existing infrastructure.
+
+This module provides a bridge between the new provider system and the
+existing Finetuner-based workflow to maintain backward compatibility.
+"""
+
+import json
+import os
+from typing import Optional
+from datasets import Dataset
+
+from .providers import get_provider, ProviderRegistry
+from ...globals.globals_instance import global_manager
+
+
+class ProviderFinetuner:
+ """
+ Adapter class that uses the provider system while maintaining
+ compatibility with the existing finetuning workflow.
+ """
+
+ def __init__(
+ self,
+ model_name: str,
+ task: str,
+ provider: str = "huggingface",
+ compute_specs: str = "low_end"
+ ) -> None:
+ """
+ Initialize provider-based finetuner.
+
+ Args:
+ model_name: Model identifier
+ task: Task type (text-generation, summarization, extractive-question-answering)
+ provider: Provider name (huggingface, unsloth)
+ compute_specs: Hardware profile
+ """
+ self.model_name = model_name
+ self.task = task
+ self.provider_name = provider
+ self.compute_specs = compute_specs
+
+ # Get provider instance
+ self.provider = get_provider(
+ provider_name=provider,
+ model_name=model_name,
+ task=task,
+ compute_specs=compute_specs
+ )
+
+ # Settings that will be set by set_settings()
+ self.output_dir: Optional[str] = None
+ self.fine_tuned_name: Optional[str] = None
+ self.dataset_path: Optional[str] = None
+
+ # Map task to pipeline task string
+ self.pipeline_task = self._map_task_to_pipeline(task)
+
+ def _map_task_to_pipeline(self, task: str) -> str:
+ """Map task to pipeline task string."""
+ task_map = {
+ "text-generation": "text-generation",
+ "summarization": "summarization",
+ "extractive-question-answering": "question-answering",
+ }
+ return task_map.get(task, task)
+
+ def set_settings(self, **kwargs) -> None:
+ """
+ Set training settings for the provider.
+
+ Args:
+ **kwargs: Settings dictionary
+ """
+ # Generate output paths using existing logic
+ from .Finetuner import Finetuner
+
+ uid = Finetuner.gen_uuid()
+ safe_model_name = self.model_name.replace('/', '-').replace('\\', '-')
+
+ # Use FileManager default directories
+ default_dirs = global_manager.file_manager.return_default_dirs()
+ self.fine_tuned_name = f"{default_dirs['models']}/{safe_model_name}_{uid}"
+ self.output_dir = f"{default_dirs['model_checkpoints']}/{safe_model_name}_{uid}"
+
+ # Add output paths to settings
+ kwargs["output_dir"] = self.output_dir
+ kwargs["fine_tuned_name"] = self.fine_tuned_name
+
+ # Set provider settings
+ self.provider.set_settings(**kwargs)
+
+ def load_dataset(self, dataset_path: str) -> None:
+ """
+ Load and prepare the dataset.
+
+ Args:
+ dataset_path: Path to dataset file
+ """
+ self.dataset_path = dataset_path
+ self.provider.prepare_dataset(dataset_path)
+
+ def finetune(self) -> bool | str:
+ """
+ Execute the fine-tuning process.
+
+ Returns:
+ Path to saved model if successful, False otherwise
+ """
+ try:
+ # Load model
+ self.provider.load_model(**self.provider.settings)
+
+ # Ensure dataset is loaded
+ if self.provider.dataset is None:
+ if self.dataset_path:
+ self.load_dataset(self.dataset_path)
+ else:
+ raise ValueError("Dataset must be loaded before training")
+
+ # Train
+ model_path = self.provider.train(**self.provider.settings)
+
+ # Build config file for playground compatibility
+ config_file_result = self._build_config_file(
+ model_path,
+ self.pipeline_task
+ )
+
+ if not config_file_result:
+ print("Warning: Failed to create config file. Model may not work in playground.")
+
+ # Report finish
+ self._report_finish()
+
+ return model_path
+
+ except Exception as e:
+ print(f"Fine-tuning failed: {e}")
+ self._report_finish(error=True, message=str(e))
+ return False
+
+ def _build_config_file(self, config_dir: str, pipeline_task: str) -> bool:
+ """
+ Build configuration file for the fine-tuned model.
+
+ Args:
+ config_dir: Directory to save config
+ pipeline_task: Pipeline task string
+
+ Returns:
+ True if successful
+ """
+ # Determine model class based on provider and task
+ model_class_map = {
+ "huggingface": {
+ "text-generation": "AutoPeftModelForCausalLM",
+ "summarization": "AutoPeftModelForSeq2SeqLM",
+ "question-answering": "AutoPeftModelForCausalLM",
+ },
+ "unsloth": {
+ "text-generation": "AutoPeftModelForCausalLM",
+ "summarization": "AutoPeftModelForSeq2SeqLM",
+ "question-answering": "AutoPeftModelForCausalLM",
+ }
+ }
+
+ model_class = model_class_map.get(self.provider_name, {}).get(
+ pipeline_task,
+ "AutoPeftModelForCausalLM"
+ )
+
+ try:
+ config_path = os.path.join(config_dir, "modelforge_config.json")
+ with open(config_path, "w") as f:
+ config = {
+ "model_class": model_class,
+ "pipeline_task": pipeline_task,
+ "provider": self.provider_name,
+ }
+ json.dump(config, f, indent=4)
+ print(f"Configuration file saved to {config_path}")
+ return True
+ except Exception as e:
+ print(f"Error saving configuration file: {e}")
+ return False
+
+ def _report_finish(self, error: bool = False, message: Optional[str] = None) -> None:
+ """
+ Report completion of fine-tuning.
+
+ Args:
+ error: True if an error occurred
+ message: Error message if applicable
+ """
+ print("*" * 100)
+ if not error:
+ print(f"Model fine-tuned successfully using {self.provider_name}!")
+ print(f"Model saved to {self.fine_tuned_name}")
+ print("Try out your new model in our chat playground!")
+ else:
+ print("Model fine-tuning failed!")
+ print(f"Error: {message}")
+ print("*" * 100)
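+
+
+# Illustrative usage sketch (not part of the module; mirrors how the router drives the
+# adapter when a non-default provider is selected; model id and dataset path are examples):
+#
+#   tuner = ProviderFinetuner(
+#       model_name="meta-llama/Llama-2-7b-hf",
+#       task="text-generation",
+#       provider="unsloth",
+#       compute_specs="mid_range",
+#   )
+#   tuner.set_settings(num_train_epochs=1, lora_r=16, lora_alpha=32, use_4bit=True)
+#   tuner.load_dataset("/path/to/dataset.json")
+#   model_path = tuner.finetune()   # returns the save path, or False on failure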
diff --git a/ModelForge/utilities/finetuning/providers/__init__.py b/ModelForge/utilities/finetuning/providers/__init__.py
new file mode 100644
index 0000000..86496bb
--- /dev/null
+++ b/ModelForge/utilities/finetuning/providers/__init__.py
@@ -0,0 +1,23 @@
+"""
+Provider abstraction layer for fine-tuning backends.
+
+This module provides a unified interface for different fine-tuning providers
+(e.g., HuggingFace, Unsloth) to enable pluggable backend implementations.
+"""
+
+from .base_provider import FinetuningProvider
+from .huggingface_provider import HuggingFaceProvider
+from .unsloth_provider import UnslothProvider
+from .provider_registry import ProviderRegistry, get_provider
+
+# Register providers
+ProviderRegistry.register(HuggingFaceProvider)
+ProviderRegistry.register(UnslothProvider)
+
+__all__ = [
+ "FinetuningProvider",
+ "HuggingFaceProvider",
+ "UnslothProvider",
+ "ProviderRegistry",
+ "get_provider",
+]
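+
+# Illustrative sketch (assumption): once this package is imported, both providers are
+# registered and can be discovered or instantiated through the registry:
+#
+#   from ModelForge.utilities.finetuning.providers import ProviderRegistry, get_provider
+#   ProviderRegistry.list_all()   # -> [{"name": "huggingface", ...}, {"name": "unsloth", ...}]
+#   provider = get_provider("huggingface", model_name="gpt2", task="text-generation")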
diff --git a/ModelForge/utilities/finetuning/providers/base_provider.py b/ModelForge/utilities/finetuning/providers/base_provider.py
new file mode 100644
index 0000000..ec9975a
--- /dev/null
+++ b/ModelForge/utilities/finetuning/providers/base_provider.py
@@ -0,0 +1,173 @@
+"""
+Base provider interface for fine-tuning implementations.
+
+This module defines the abstract interface that all fine-tuning providers
+must implement to ensure consistent behavior across different backends.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Any, Tuple
+from datasets import Dataset
+
+
+class FinetuningProvider(ABC):
+ """
+ Abstract base class defining the interface for fine-tuning providers.
+
+ All fine-tuning providers (HuggingFace, Unsloth, etc.) must implement
+ this interface to ensure compatibility with the ModelForge pipeline.
+ """
+
+ def __init__(
+ self,
+ model_name: str,
+ task: str,
+ compute_specs: str = "low_end"
+ ) -> None:
+ """
+ Initialize the fine-tuning provider.
+
+ Args:
+ model_name: Name or path of the model to fine-tune
+ task: Task type (text-generation, summarization, extractive-question-answering)
+ compute_specs: Hardware profile (low_end, mid_range, high_end)
+ """
+ self.model_name = model_name
+ self.task = task
+ self.compute_specs = compute_specs
+ self.model = None
+ self.tokenizer = None
+ self.dataset: Optional[Dataset] = None
+ self.settings: Dict[str, Any] = {}
+
+ @abstractmethod
+ def load_model(self, **kwargs) -> Tuple[Any, Any]:
+ """
+ Load the model and tokenizer with provider-specific configurations.
+
+ Args:
+ **kwargs: Provider-specific model loading parameters
+
+ Returns:
+ Tuple of (model, tokenizer)
+
+ Raises:
+ Exception: If model loading fails
+ """
+ pass
+
+ @abstractmethod
+ def prepare_dataset(self, dataset_path: str, **kwargs) -> Dataset:
+ """
+ Load and prepare the dataset for training.
+
+ Args:
+ dataset_path: Path to the dataset file
+ **kwargs: Provider-specific dataset preparation parameters
+
+ Returns:
+ Prepared dataset ready for training
+
+ Raises:
+ Exception: If dataset loading or preparation fails
+ """
+ pass
+
+ @abstractmethod
+ def train(self, **kwargs) -> str:
+ """
+ Execute the fine-tuning process.
+
+ Args:
+ **kwargs: Provider-specific training parameters
+
+ Returns:
+ Path to the saved fine-tuned model
+
+ Raises:
+ Exception: If training fails
+ """
+ pass
+
+ @abstractmethod
+ def export_model(self, output_path: str, **kwargs) -> bool:
+ """
+ Export the fine-tuned model to the specified path.
+
+ Args:
+ output_path: Directory to save the exported model
+ **kwargs: Provider-specific export parameters
+
+ Returns:
+ True if export succeeds, False otherwise
+ """
+ pass
+
+ @abstractmethod
+ def get_supported_hyperparameters(self) -> List[str]:
+ """
+ Return a list of hyperparameters supported by this provider.
+
+ Returns:
+ List of hyperparameter names
+ """
+ pass
+
+ @abstractmethod
+ def validate_settings(self, settings: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Validate and sanitize provider-specific settings.
+
+ Args:
+ settings: Dictionary of hyperparameters and configurations
+
+ Returns:
+ Validated settings dictionary
+
+ Raises:
+ ValueError: If settings are invalid
+ """
+ pass
+
+ @classmethod
+ @abstractmethod
+ def get_provider_name(cls) -> str:
+ """
+ Return the canonical name of this provider.
+
+ Returns:
+ Provider name (e.g., "huggingface", "unsloth")
+ """
+ pass
+
+ @classmethod
+ @abstractmethod
+ def get_provider_description(cls) -> str:
+ """
+ Return a human-readable description of this provider.
+
+ Returns:
+ Provider description
+ """
+ pass
+
+ @classmethod
+ @abstractmethod
+ def is_available(cls) -> bool:
+ """
+ Check if this provider's dependencies are installed and available.
+
+ Returns:
+ True if provider can be used, False otherwise
+ """
+ pass
+
+ def set_settings(self, **kwargs) -> None:
+ """
+ Set training settings from keyword arguments.
+
+ Args:
+ **kwargs: Settings to apply
+ """
+ validated_settings = self.validate_settings(kwargs)
+ self.settings.update(validated_settings)
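+
+
+# Minimal sketch of how a new backend plugs in ("MyProvider" is hypothetical):
+#
+#   class MyProvider(FinetuningProvider):
+#       @classmethod
+#       def get_provider_name(cls) -> str:
+#           return "myprovider"
+#
+#       @classmethod
+#       def get_provider_description(cls) -> str:
+#           return "Example backend"
+#
+#       @classmethod
+#       def is_available(cls) -> bool:
+#           return True
+#
+#       # ...plus load_model, prepare_dataset, train, export_model,
+#       # get_supported_hyperparameters and validate_settings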
diff --git a/ModelForge/utilities/finetuning/providers/huggingface_provider.py b/ModelForge/utilities/finetuning/providers/huggingface_provider.py
new file mode 100644
index 0000000..e19a0f9
--- /dev/null
+++ b/ModelForge/utilities/finetuning/providers/huggingface_provider.py
@@ -0,0 +1,386 @@
+"""
+HuggingFace provider implementation for fine-tuning.
+
+This module provides HuggingFace Transformers-based fine-tuning,
+maintaining backward compatibility with existing ModelForge workflows.
+"""
+
+import os
+from typing import Dict, List, Any, Tuple, Optional
+import torch
+from datasets import Dataset, load_dataset
+from transformers import (
+ AutoModelForCausalLM,
+ AutoModelForSeq2SeqLM,
+ AutoTokenizer,
+ BitsAndBytesConfig,
+ TrainerCallback,
+)
+from peft import LoraConfig, TaskType, get_peft_model
+from trl import SFTTrainer, SFTConfig
+
+from .base_provider import FinetuningProvider
+
+
+class HuggingFaceProgressCallback(TrainerCallback):
+ """
+ Callback to update global finetuning status during HuggingFace training.
+ """
+
+ def __init__(self):
+ super().__init__()
+ from ....globals.globals_instance import global_manager
+ self.global_manager = global_manager
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ """Called when logging happens during training."""
+ if state.max_steps <= 0:
+ return
+
+ progress = min(95, int((state.global_step / state.max_steps) * 100))
+ self.global_manager.finetuning_status["progress"] = progress
+ self.global_manager.finetuning_status["message"] = (
+ f"Training step {state.global_step}/{state.max_steps}"
+ )
+
+ def on_train_end(self, args, state, control, **kwargs):
+ """Called at the end of training."""
+ self.global_manager.finetuning_status["progress"] = 100
+ self.global_manager.finetuning_status["message"] = "Training completed!"
+
+
+class HuggingFaceProvider(FinetuningProvider):
+ """
+ HuggingFace Transformers provider for fine-tuning.
+
+ Implements the FinetuningProvider interface using HuggingFace's
+ transformers, peft, and trl libraries.
+ """
+
+ # Task type mappings for PEFT
+ TASK_TYPE_MAP = {
+ "text-generation": TaskType.CAUSAL_LM,
+ "summarization": TaskType.SEQ_2_SEQ_LM,
+ "extractive-question-answering": TaskType.QUESTION_ANS,
+ }
+
+ # Model class mappings
+ MODEL_CLASS_MAP = {
+ "text-generation": (AutoModelForCausalLM, "AutoPeftModelForCausalLM"),
+ "summarization": (AutoModelForSeq2SeqLM, "AutoPeftModelForSeq2SeqLM"),
+ "extractive-question-answering": (AutoModelForCausalLM, "AutoPeftModelForCausalLM"),
+ }
+
+ def __init__(
+ self,
+ model_name: str,
+ task: str,
+ compute_specs: str = "low_end"
+ ) -> None:
+ """
+ Initialize HuggingFace provider.
+
+ Args:
+ model_name: HuggingFace model identifier
+ task: Fine-tuning task type
+ compute_specs: Hardware profile
+ """
+ super().__init__(model_name, task, compute_specs)
+ self.peft_task_type = self.TASK_TYPE_MAP.get(task)
+ self.model_class, self.peft_class_name = self.MODEL_CLASS_MAP.get(task, (None, None))
+ self.output_dir: Optional[str] = None
+ self.fine_tuned_name: Optional[str] = None
+
+ def load_model(self, **kwargs) -> Tuple[Any, Any]:
+ """
+ Load HuggingFace model and tokenizer with quantization.
+
+ Args:
+ **kwargs: Settings including quantization config
+
+ Returns:
+ Tuple of (model, tokenizer)
+ """
+ if self.model_class is None:
+ raise ValueError(f"Unsupported task type: {self.task}")
+
+ # Prepare quantization config
+ bits_n_bytes_config = None
+ use_4bit = kwargs.get("use_4bit", False)
+ use_8bit = kwargs.get("use_8bit", False)
+
+ if use_4bit:
+ compute_dtype = getattr(torch, kwargs.get("bnb_4bit_compute_dtype", "float16"))
+ bits_n_bytes_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type=kwargs.get("bnb_4bit_quant_type", "nf4"),
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=kwargs.get("use_nested_quant", False),
+ )
+ elif use_8bit:
+ bits_n_bytes_config = BitsAndBytesConfig(
+ load_in_8bit=True,
+ )
+
+ # Load model
+ device_map = kwargs.get("device_map", {"": 0})
+
+ if bits_n_bytes_config:
+ model = self.model_class.from_pretrained(
+ self.model_name,
+ quantization_config=bits_n_bytes_config,
+ device_map=device_map,
+ use_cache=False,
+ )
+ else:
+ model = self.model_class.from_pretrained(
+ self.model_name,
+ device_map=device_map,
+ use_cache=False,
+ )
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.model_name,
+ trust_remote_code=True
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "right"
+
+ self.model = model
+ self.tokenizer = tokenizer
+
+ return model, tokenizer
+
+ def prepare_dataset(self, dataset_path: str, **kwargs) -> Dataset:
+ """
+ Load and format dataset for HuggingFace training.
+
+ Args:
+ dataset_path: Path to dataset file
+ **kwargs: Additional dataset preparation parameters
+
+ Returns:
+ Formatted dataset
+ """
+ dataset = load_dataset("json", data_files=dataset_path, split="train")
+
+ # Format based on task type
+ if self.task == "text-generation":
+ dataset = dataset.rename_column("input", "prompt")
+ dataset = dataset.rename_column("output", "completion")
+ dataset = dataset.map(self._format_text_generation_example)
+ elif self.task == "summarization":
+ keys = dataset.column_names
+ dataset = dataset.map(lambda x: self._format_summarization_example(x, keys))
+ dataset = dataset.remove_columns(keys)
+ elif self.task == "extractive-question-answering":
+ keys = dataset.column_names
+ dataset = dataset.map(lambda x: self._format_qa_example(x, keys))
+ dataset = dataset.remove_columns(keys)
+
+ self.dataset = dataset
+ return dataset
+
+ def _format_text_generation_example(self, example: dict) -> Dict[str, str]:
+ """Format example for text generation."""
+ return {
+ "prompt": "USER:" + example.get("prompt", ""),
+ "completion": "ASSISTANT: " + example.get("completion", "") + "<|endoftext|>"
+ }
+
+ def _format_summarization_example(self, example: dict, keys: List[str]) -> Dict[str, str]:
+ """Format example for summarization."""
+ if len(keys) < 2:
+ keys = ["article", "summary"]
+ return {
+ "text": f'''
+ ["role": "system", "content": "You are a text summarization assistant."],
+ ["role": "user", "content": {example[keys[0]]}],
+ ["role": "assistant", "content": {example[keys[1]]}]
+ '''
+ }
+
+ def _format_qa_example(self, example: dict, keys: List[str]) -> Dict[str, str]:
+ """Format example for question answering."""
+ if len(keys) < 3:
+ keys = ["context", "question", "answer"]
+ return {
+ "text": f'''
+ ["role": "system", "content": "You are a question answering assistant."],
+ ["role": "user", "content": "Context: {example[keys[0]]}\nQuestion: {example[keys[1]]}"],
+ ["role": "assistant", "content": {example[keys[2]]}]
+ '''
+ }
+
+ def train(self, **kwargs) -> str:
+ """
+ Execute HuggingFace fine-tuning with PEFT/LoRA.
+
+ Args:
+ **kwargs: Training configuration
+
+ Returns:
+ Path to saved model
+ """
+ # Ensure model and dataset are loaded
+ if self.model is None or self.tokenizer is None:
+ self.load_model(**kwargs)
+
+ if self.dataset is None:
+ raise ValueError("Dataset must be prepared before training")
+
+ # Configure LoRA
+ lora_config = LoraConfig(
+ lora_alpha=kwargs.get("lora_alpha", 32),
+ lora_dropout=kwargs.get("lora_dropout", 0.1),
+ r=kwargs.get("lora_r", 16),
+ bias="none",
+ task_type=self.peft_task_type,
+ target_modules='all-linear',
+ )
+
+ # Apply PEFT
+ model = get_peft_model(self.model, lora_config)
+
+ # Configure training
+ training_args = SFTConfig(
+ output_dir=self.output_dir or "./model_checkpoints",
+ num_train_epochs=kwargs.get("num_train_epochs", 1),
+ per_device_train_batch_size=kwargs.get("per_device_train_batch_size", 1),
+ gradient_accumulation_steps=kwargs.get("gradient_accumulation_steps", 4),
+ optim=kwargs.get("optim", "paged_adamw_32bit"),
+ save_steps=kwargs.get("save_steps", 0),
+ logging_steps=kwargs.get("logging_steps", 25),
+ learning_rate=kwargs.get("learning_rate", 2e-4),
+ warmup_ratio=kwargs.get("warmup_ratio", 0.03),
+ weight_decay=kwargs.get("weight_decay", 0.001),
+ fp16=kwargs.get("fp16", False),
+ bf16=kwargs.get("bf16", False),
+ max_grad_norm=kwargs.get("max_grad_norm", 0.3),
+ max_steps=kwargs.get("max_steps", -1),
+ group_by_length=kwargs.get("group_by_length", True),
+ lr_scheduler_type=kwargs.get("lr_scheduler_type", "cosine"),
+ report_to="tensorboard",
+ logging_dir="./training_logs",
+ max_length=None,
+ )
+
+ # Create trainer
+ trainer = SFTTrainer(
+ model=model,
+ train_dataset=self.dataset,
+ args=training_args,
+ callbacks=[HuggingFaceProgressCallback()],
+ )
+
+ # Train
+ trainer.train()
+
+ # Save model
+ save_path = self.fine_tuned_name or self.output_dir
+ trainer.model.save_pretrained(save_path)
+
+ return save_path
+
+ def export_model(self, output_path: str, **kwargs) -> bool:
+ """
+ Export the fine-tuned model (already saved during training).
+
+ Args:
+ output_path: Path to export the model
+ **kwargs: Additional export parameters
+
+ Returns:
+ True if successful
+ """
+ # HuggingFace models are already saved during training
+ # This method exists for interface compatibility
+ return True
+
+ def get_supported_hyperparameters(self) -> List[str]:
+ """
+ Return list of supported hyperparameters.
+
+ Returns:
+ List of hyperparameter names
+ """
+ return [
+ "num_train_epochs",
+ "lora_r",
+ "lora_alpha",
+ "lora_dropout",
+ "use_4bit",
+ "use_8bit",
+ "bnb_4bit_compute_dtype",
+ "bnb_4bit_quant_type",
+ "use_nested_quant",
+ "fp16",
+ "bf16",
+ "per_device_train_batch_size",
+ "per_device_eval_batch_size",
+ "gradient_accumulation_steps",
+ "gradient_checkpointing",
+ "max_grad_norm",
+ "learning_rate",
+ "weight_decay",
+ "optim",
+ "lr_scheduler_type",
+ "max_steps",
+ "warmup_ratio",
+ "group_by_length",
+ "packing",
+ "device_map",
+ "max_seq_length",
+ ]
+
+ def validate_settings(self, settings: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Validate HuggingFace-specific settings.
+
+ Args:
+ settings: Settings dictionary
+
+ Returns:
+ Validated settings
+ """
+ # Basic validation - detailed validation happens in the router
+ validated = {}
+
+ for key, value in settings.items():
+ if key in self.get_supported_hyperparameters():
+ validated[key] = value
+
+ # Extract output paths if present
+ if "output_dir" in settings:
+ self.output_dir = settings["output_dir"]
+ if "fine_tuned_name" in settings:
+ self.fine_tuned_name = settings["fine_tuned_name"]
+
+ return validated
+
+ @classmethod
+ def get_provider_name(cls) -> str:
+ """Return provider name."""
+ return "huggingface"
+
+ @classmethod
+ def get_provider_description(cls) -> str:
+ """Return provider description."""
+ return "HuggingFace Transformers with PEFT/LoRA fine-tuning"
+
+ @classmethod
+ def is_available(cls) -> bool:
+ """
+ Check if HuggingFace dependencies are available.
+
+ Returns:
+ True if available
+ """
+ try:
+ import transformers
+ import peft
+ import trl
+ return True
+ except ImportError:
+ return False
diff --git a/ModelForge/utilities/finetuning/providers/provider_registry.py b/ModelForge/utilities/finetuning/providers/provider_registry.py
new file mode 100644
index 0000000..ec56333
--- /dev/null
+++ b/ModelForge/utilities/finetuning/providers/provider_registry.py
@@ -0,0 +1,142 @@
+"""
+Provider registry for managing and selecting fine-tuning providers.
+
+This module provides a centralized registry for all available fine-tuning
+providers, enabling runtime provider selection and discovery.
+"""
+
+from typing import Dict, Type, List, Optional
+from .base_provider import FinetuningProvider
+
+
+class ProviderRegistry:
+ """
+ Central registry for managing fine-tuning provider implementations.
+ """
+
+ _providers: Dict[str, Type[FinetuningProvider]] = {}
+
+ @classmethod
+ def register(cls, provider_class: Type[FinetuningProvider]) -> None:
+ """
+ Register a fine-tuning provider.
+
+ Args:
+ provider_class: Provider class to register
+
+ Raises:
+ ValueError: If provider is already registered
+ """
+ provider_name = provider_class.get_provider_name()
+
+ if provider_name in cls._providers:
+ raise ValueError(f"Provider '{provider_name}' is already registered")
+
+ cls._providers[provider_name] = provider_class
+
+ @classmethod
+ def get(cls, provider_name: str) -> Optional[Type[FinetuningProvider]]:
+ """
+ Get a provider class by name.
+
+ Args:
+ provider_name: Name of the provider
+
+ Returns:
+ Provider class or None if not found
+ """
+ return cls._providers.get(provider_name)
+
+ @classmethod
+ def list_available(cls) -> List[Dict[str, str]]:
+ """
+ List all available (installed) providers.
+
+ Returns:
+ List of dictionaries containing provider information
+ """
+ available = []
+ for name, provider_class in cls._providers.items():
+ if provider_class.is_available():
+ available.append({
+ "name": name,
+ "description": provider_class.get_provider_description(),
+ })
+ return available
+
+ @classmethod
+ def list_all(cls) -> List[Dict[str, str]]:
+ """
+ List all registered providers (including unavailable ones).
+
+ Returns:
+ List of dictionaries containing provider information
+ """
+ all_providers = []
+ for name, provider_class in cls._providers.items():
+ all_providers.append({
+ "name": name,
+ "description": provider_class.get_provider_description(),
+ "available": provider_class.is_available(),
+ })
+ return all_providers
+
+ @classmethod
+ def is_available(cls, provider_name: str) -> bool:
+ """
+ Check if a provider is available for use.
+
+ Args:
+ provider_name: Name of the provider
+
+ Returns:
+ True if provider exists and is available, False otherwise
+ """
+ provider_class = cls.get(provider_name)
+ if provider_class is None:
+ return False
+ return provider_class.is_available()
+
+
+def get_provider(
+ provider_name: str,
+ model_name: str,
+ task: str,
+ compute_specs: str = "low_end"
+) -> FinetuningProvider:
+ """
+ Factory function to instantiate a provider.
+
+ Args:
+ provider_name: Name of the provider to instantiate
+ model_name: Model name or path
+ task: Fine-tuning task
+ compute_specs: Hardware profile
+
+ Returns:
+ Instantiated provider
+
+ Raises:
+ ValueError: If provider not found or not available
+ """
+ provider_class = ProviderRegistry.get(provider_name)
+
+ if provider_class is None:
+ available = ProviderRegistry.list_available()
+ available_names = [p["name"] for p in available]
+ raise ValueError(
+ f"Provider '{provider_name}' not found. "
+ f"Available providers: {', '.join(available_names)}"
+ )
+
+ if not provider_class.is_available():
+ raise ValueError(
+ f"Provider '{provider_name}' is registered but not available. "
+ f"Please ensure all required dependencies are installed."
+ )
+
+ return provider_class(
+ model_name=model_name,
+ task=task,
+ compute_specs=compute_specs
+ )
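+
+
+# Illustrative usage (assumption): guard instantiation on availability to avoid the
+# ValueError raised for registered-but-uninstalled backends:
+#
+#   if ProviderRegistry.is_available("unsloth"):
+#       provider = get_provider("unsloth", model_name="gpt2", task="text-generation")
+#   else:
+#       provider = get_provider("huggingface", model_name="gpt2", task="text-generation")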
diff --git a/ModelForge/utilities/finetuning/providers/unsloth_provider.py b/ModelForge/utilities/finetuning/providers/unsloth_provider.py
new file mode 100644
index 0000000..2758e72
--- /dev/null
+++ b/ModelForge/utilities/finetuning/providers/unsloth_provider.py
@@ -0,0 +1,380 @@
+"""
+Unsloth AI provider implementation for fine-tuning.
+
+This module provides Unsloth-based fine-tuning with optimized
+memory usage and faster training speeds compared to standard HuggingFace.
+"""
+
+import os
+from typing import Dict, List, Any, Tuple, Optional
+from datasets import Dataset, load_dataset
+
+from .base_provider import FinetuningProvider
+
+
+class UnslothProvider(FinetuningProvider):
+ """
+ Unsloth AI provider for optimized fine-tuning.
+
+ Implements the FinetuningProvider interface using Unsloth's
+ optimized training infrastructure for faster and more memory-efficient
+ fine-tuning of large language models.
+ """
+
+ # Task type mappings for Unsloth
+ TASK_TYPE_MAP = {
+ "text-generation": "causal",
+ "summarization": "seq2seq",
+ "extractive-question-answering": "causal",
+ }
+
+ def __init__(
+ self,
+ model_name: str,
+ task: str,
+ compute_specs: str = "low_end"
+ ) -> None:
+ """
+ Initialize Unsloth provider.
+
+ Args:
+ model_name: Model identifier (HuggingFace format)
+ task: Fine-tuning task type
+ compute_specs: Hardware profile
+ """
+ super().__init__(model_name, task, compute_specs)
+ self.unsloth_task = self.TASK_TYPE_MAP.get(task)
+ self.output_dir: Optional[str] = None
+ self.fine_tuned_name: Optional[str] = None
+
+ def load_model(self, **kwargs) -> Tuple[Any, Any]:
+ """
+ Load model using Unsloth's FastLanguageModel.
+
+ Args:
+ **kwargs: Settings including max_seq_length, quantization, etc.
+
+ Returns:
+ Tuple of (model, tokenizer)
+ """
+ try:
+ from unsloth import FastLanguageModel
+ except ImportError:
+ raise ImportError(
+ "Unsloth is not installed. Install it with: "
+ "pip install unsloth"
+ )
+
+ # Determine quantization settings
+ load_in_4bit = kwargs.get("use_4bit", False) or kwargs.get("load_in_4bit", False)
+ load_in_8bit = kwargs.get("use_8bit", False) or kwargs.get("load_in_8bit", False)
+
+ # Unsloth-specific parameters
+ max_seq_length = kwargs.get("max_seq_length", 2048)
+ if max_seq_length == -1 or max_seq_length is None:
+ max_seq_length = 2048 # Unsloth default
+
+ dtype = None # Auto-detect
+ if kwargs.get("bf16", False):
+ import torch
+ dtype = torch.bfloat16
+ elif kwargs.get("fp16", False):
+ import torch
+ dtype = torch.float16
+
+ # Load model with Unsloth optimizations
+ model, tokenizer = FastLanguageModel.from_pretrained(
+ model_name=self.model_name,
+ max_seq_length=max_seq_length,
+ dtype=dtype,
+ load_in_4bit=load_in_4bit,
+ )
+
+ self.model = model
+ self.tokenizer = tokenizer
+
+ return model, tokenizer
+
+ def prepare_dataset(self, dataset_path: str, **kwargs) -> Dataset:
+ """
+ Load and format dataset for Unsloth training.
+
+ Args:
+ dataset_path: Path to dataset file
+ **kwargs: Additional dataset preparation parameters
+
+ Returns:
+ Formatted dataset
+ """
+ dataset = load_dataset("json", data_files=dataset_path, split="train")
+
+ # Format based on task type
+ if self.task == "text-generation":
+ dataset = dataset.rename_column("input", "prompt")
+ dataset = dataset.rename_column("output", "completion")
+ dataset = dataset.map(self._format_text_generation_example)
+ elif self.task == "summarization":
+ keys = dataset.column_names
+ dataset = dataset.map(lambda x: self._format_summarization_example(x, keys))
+ dataset = dataset.remove_columns(keys)
+ elif self.task == "extractive-question-answering":
+ keys = dataset.column_names
+ dataset = dataset.map(lambda x: self._format_qa_example(x, keys))
+ dataset = dataset.remove_columns(keys)
+
+ self.dataset = dataset
+ return dataset
+
+ def _format_text_generation_example(self, example: dict) -> Dict[str, str]:
+ """Format example for text generation with Unsloth."""
+ # Unsloth uses a more standard chat format
+ return {
+ "text": f"### User:\n{example.get('prompt', '')}\n\n### Assistant:\n{example.get('completion', '')}<|endoftext|>"
+ }
+
+ def _format_summarization_example(self, example: dict, keys: List[str]) -> Dict[str, str]:
+ """Format example for summarization with Unsloth."""
+ if len(keys) < 2:
+ keys = ["article", "summary"]
+ return {
+ "text": f"### Article:\n{example[keys[0]]}\n\n### Summary:\n{example[keys[1]]}<|endoftext|>"
+ }
+
+ def _format_qa_example(self, example: dict, keys: List[str]) -> Dict[str, str]:
+ """Format example for question answering with Unsloth."""
+ if len(keys) < 3:
+ keys = ["context", "question", "answer"]
+ return {
+ "text": f"### Context:\n{example[keys[0]]}\n\n### Question:\n{example[keys[1]]}\n\n### Answer:\n{example[keys[2]]}<|endoftext|>"
+ }
+
+ def train(self, **kwargs) -> str:
+ """
+ Execute Unsloth fine-tuning with optimized LoRA/QLoRA.
+
+ Args:
+ **kwargs: Training configuration
+
+ Returns:
+ Path to saved model
+ """
+ try:
+ from unsloth import FastLanguageModel
+ from trl import SFTTrainer
+ from transformers import TrainingArguments, TrainerCallback
+ except ImportError as e:
+ raise ImportError(f"Required package not available: {e}")
+
+ # Ensure model and dataset are loaded
+ if self.model is None or self.tokenizer is None:
+ self.load_model(**kwargs)
+
+ if self.dataset is None:
+ raise ValueError("Dataset must be prepared before training")
+
+ # Apply Unsloth PEFT (optimized LoRA)
+ model = FastLanguageModel.get_peft_model(
+ self.model,
+ r=kwargs.get("lora_r", 16),
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+ "gate_proj", "up_proj", "down_proj"],
+ lora_alpha=kwargs.get("lora_alpha", 32),
+ lora_dropout=kwargs.get("lora_dropout", 0.1),
+ bias="none",
+ use_gradient_checkpointing=kwargs.get("gradient_checkpointing", True),
+ random_state=3407,
+            use_rslora=kwargs.get("use_rslora", False),  # rank-stabilized LoRA (exposed as a supported hyperparameter)
+            loftq_config=None,  # LoftQ quantization not used
+ )
+
+ # Progress callback
+ class UnslothProgressCallback(TrainerCallback):
+ """Callback to update global finetuning status during Unsloth training."""
+
+ def __init__(self):
+ super().__init__()
+ from ....globals.globals_instance import global_manager
+ self.global_manager = global_manager
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ """Called when logging happens during training."""
+ if state.max_steps <= 0:
+ return
+
+ progress = min(95, int((state.global_step / state.max_steps) * 100))
+ self.global_manager.finetuning_status["progress"] = progress
+ self.global_manager.finetuning_status["message"] = (
+ f"Training step {state.global_step}/{state.max_steps} (Unsloth)"
+ )
+
+ def on_train_end(self, args, state, control, **kwargs):
+ """Called at the end of training."""
+ self.global_manager.finetuning_status["progress"] = 100
+ self.global_manager.finetuning_status["message"] = "Training completed with Unsloth!"
+
+ # Configure training arguments
+ training_args = TrainingArguments(
+ output_dir=self.output_dir or "./model_checkpoints",
+ num_train_epochs=kwargs.get("num_train_epochs", 1),
+ per_device_train_batch_size=kwargs.get("per_device_train_batch_size", 2),
+ gradient_accumulation_steps=kwargs.get("gradient_accumulation_steps", 4),
+            warmup_ratio=kwargs.get("warmup_ratio", 0.03),  # TrainingArguments accepts the ratio directly
+ learning_rate=kwargs.get("learning_rate", 2e-4),
+ fp16=kwargs.get("fp16", False),
+ bf16=kwargs.get("bf16", False),
+ logging_steps=kwargs.get("logging_steps", 1),
+ optim=kwargs.get("optim", "adamw_8bit"),
+ weight_decay=kwargs.get("weight_decay", 0.01),
+ lr_scheduler_type=kwargs.get("lr_scheduler_type", "linear"),
+ seed=3407,
+ save_steps=kwargs.get("save_steps", 0),
+ max_steps=kwargs.get("max_steps", -1),
+ report_to="tensorboard",
+ logging_dir="./training_logs",
+ )
+
+        # Resolve max_seq_length, treating -1/None as the Unsloth default
+        max_seq_length = kwargs.get("max_seq_length", 2048)
+        if max_seq_length is None or max_seq_length == -1:
+            max_seq_length = 2048
+
+        # Create SFT trainer with Unsloth optimizations
+        trainer = SFTTrainer(
+            model=model,
+            tokenizer=self.tokenizer,
+            train_dataset=self.dataset,
+            dataset_text_field="text",
+            max_seq_length=max_seq_length,
+            dataset_num_proc=2,
+            packing=kwargs.get("packing", False),
+            args=training_args,
+            callbacks=[UnslothProgressCallback()],
+        )
+
+ # Train with Unsloth optimizations
+ trainer.train()
+
+ # Save model
+        save_path = self.fine_tuned_name or self.output_dir or "./model_checkpoints"  # match the training output fallback
+ model.save_pretrained(save_path)
+ self.tokenizer.save_pretrained(save_path)
+
+ return save_path
+
+ def export_model(self, output_path: str, **kwargs) -> bool:
+ """
+ Export the fine-tuned Unsloth model.
+
+ Supports multiple export formats including HuggingFace format.
+ Note: GGUF export requires additional Unsloth methods that may not be
+ available in all versions. Currently defaults to HuggingFace format.
+
+ Args:
+ output_path: Path to export the model
+ **kwargs: Export parameters (format, quantization_method, etc.)
+
+ Returns:
+ True if successful
+ """
+ export_format = kwargs.get("export_format", "huggingface")
+
+ if export_format == "huggingface":
+ # Already saved in HuggingFace format during training
+ return True
+ else:
+ # Other export formats (e.g., GGUF) may be supported in future Unsloth versions
+ # For now, return False for unsupported formats
+ print(f"Warning: Export format '{export_format}' is not yet supported. Model saved in HuggingFace format.")
+ return False
+
+ def get_supported_hyperparameters(self) -> List[str]:
+ """
+ Return list of Unsloth-supported hyperparameters.
+
+ Returns:
+ List of hyperparameter names
+ """
+ return [
+ # Standard hyperparameters
+ "num_train_epochs",
+ "lora_r",
+ "lora_alpha",
+ "lora_dropout",
+ "use_4bit",
+ "load_in_4bit",
+ "fp16",
+ "bf16",
+ "per_device_train_batch_size",
+ "gradient_accumulation_steps",
+ "gradient_checkpointing",
+ "learning_rate",
+ "weight_decay",
+ "optim",
+ "lr_scheduler_type",
+ "max_steps",
+ "warmup_ratio",
+ "packing",
+ "max_seq_length",
+ # Unsloth-specific
+ "use_rslora",
+ "use_gradient_checkpointing",
+ ]
+
+ def validate_settings(self, settings: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Validate Unsloth-specific settings.
+
+ Args:
+ settings: Settings dictionary
+
+ Returns:
+ Validated settings
+ """
+ validated = {}
+
+ for key, value in settings.items():
+ if key in self.get_supported_hyperparameters():
+ validated[key] = value
+
+ # Extract output paths if present
+ if "output_dir" in settings:
+ self.output_dir = settings["output_dir"]
+ if "fine_tuned_name" in settings:
+ self.fine_tuned_name = settings["fine_tuned_name"]
+
+ # Unsloth-specific validations
+ if "max_seq_length" in validated:
+ if validated["max_seq_length"] == -1 or validated["max_seq_length"] is None:
+ validated["max_seq_length"] = 2048
+
+ # Unsloth works best with specific optimizers
+ if "optim" in validated:
+ # Map to Unsloth-compatible optimizers
+ optimizer_map = {
+ "paged_adamw_32bit": "adamw_8bit",
+ "paged_adamw_8bit": "adamw_8bit",
+ "adamw_torch": "adamw_torch",
+ "adamw_hf": "adamw_torch",
+ }
+ validated["optim"] = optimizer_map.get(validated["optim"], "adamw_8bit")
+
+ return validated
+
+ @classmethod
+ def get_provider_name(cls) -> str:
+ """Return provider name."""
+ return "unsloth"
+
+ @classmethod
+ def get_provider_description(cls) -> str:
+ """Return provider description."""
+ return "Unsloth AI - 2x faster fine-tuning with reduced memory usage"
+
+ @classmethod
+ def is_available(cls) -> bool:
+ """
+ Check if Unsloth dependencies are available.
+
+ Returns:
+ True if Unsloth is installed
+ """
+ try:
+ import unsloth
+ return True
+ except ImportError:
+ return False
diff --git a/ModelForge/utilities/finetuning/settings_builder.py b/ModelForge/utilities/finetuning/settings_builder.py
index 9eb6e87..85eac3d 100644
--- a/ModelForge/utilities/finetuning/settings_builder.py
+++ b/ModelForge/utilities/finetuning/settings_builder.py
@@ -12,6 +12,7 @@ def __init__(self, task, model_name, compute_profile) -> None:
self.dataset = None
self.compute_profile = compute_profile
self.is_custom_model = False
+ self.provider = "huggingface" # Default provider for backward compatibility
self.lora_r = 16
self.lora_alpha = 32
self.lora_dropout = 0.1
@@ -56,6 +57,8 @@ def set_settings(self, settings_dict) -> None:
for key, value in settings_dict.items():
if key == "dataset":
self.dataset = value
+ elif key == "provider":
+ self.provider = value
elif key == "max_seq_length":
if value == -1:
self.max_seq_length = None
@@ -99,6 +102,7 @@ def get_settings(self) -> Dict[str, Union[str, float]]:
return {
"task": self.task,
"model_name": self.model_name,
+ "provider": self.provider,
"num_train_epochs": self.num_train_epochs,
"compute_specs": self.compute_profile,
"lora_r": self.lora_r,
diff --git a/README.md b/README.md
index 9b95a96..6fbe4ba 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,8 @@
## 🚀 **Features**
- **GPU-Powered Finetuning**: Optimized for NVIDIA GPUs (even 4GB VRAM).
-- **One-Click Workflow**: Upload data → Pick task → Train → Test.
+- **Multiple Providers**: Choose between HuggingFace (standard) and Unsloth AI (2x faster, reduced memory) for fine-tuning.
+- **One-Click Workflow**: Upload data → Pick task → Select provider → Train → Test.
- **Hardware-Aware**: Auto-detects your GPU/CPU and recommends models.
- **React UI**: No CLI or notebooks—just a friendly interface.
@@ -79,6 +80,52 @@ To stop the application and free up resources, press `Ctrl+C` in the terminal ru
{"input": "Enter the poem topic here...", "output": "Roses are red..."}
```
+## 🔧 **Fine-tuning Providers**
+
+ModelForge supports multiple fine-tuning providers, allowing you to choose the best backend for your needs:
+
+### **HuggingFace (Default)**
+- **Description**: Standard fine-tuning using HuggingFace Transformers with PEFT/LoRA
+- **Installation**: Included by default
+- **Best for**: General use, maximum compatibility
+- **Features**: 4-bit/8-bit quantization, gradient checkpointing, standard LoRA
+
+### **Unsloth AI (Optional)**
+- **Description**: Optimized fine-tuning with 2x faster training and reduced memory usage
+- **Installation** (or install the optional extra shown after this list):
+ ```bash
+ pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
+ ```
+- **Best for**: Faster training, lower memory usage, larger models on limited hardware
+- **Features**: Optimized LoRA/QLoRA, memory-efficient attention, faster gradient computation
+- **Documentation**: [Unsloth Docs](https://docs.unsloth.ai/)
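+
+As an alternative to the command above, a source checkout can install the `unsloth` optional extra declared in `pyproject.toml` (a sketch, assuming it is run from the repository root):
+
+```bash
+pip install -e ".[unsloth]"
+```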
+
+### **Selecting a Provider**
+
+The provider can be selected when configuring your fine-tuning job:
+
+1. **Via UI**: Choose your provider from the dropdown in the settings page
+2. **Via API**: Include the `provider` field in your settings:
+ ```json
+ {
+ "provider": "unsloth",
+ "task": "text-generation",
+ "model_name": "meta-llama/Llama-2-7b-hf",
+ ...
+ }
+ ```
+
+**Note**: If a provider is not installed, ModelForge will fall back to HuggingFace automatically.
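+
+Under the hood, availability is a plain import check; the snippet below is a minimal sketch of the same fallback logic (the actual wiring in ModelForge's provider factory is more involved):
+
+```python
+# Choose Unsloth when it can be imported, otherwise fall back to HuggingFace.
+try:
+    import unsloth  # noqa: F401
+    provider = "unsloth"
+except ImportError:
+    provider = "huggingface"
+```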
+
+### **Performance Comparison**
+
+| Provider    | Training Speed | Memory Usage  | Compatibility |
+|-------------|----------------|---------------|---------------|
+| HuggingFace | 1x (baseline)  | 1x (baseline) | Excellent     |
+| Unsloth     | ~2x faster     | ~30% less     | Good          |
+
+*Performance metrics may vary based on model size, hardware, and configuration.*
+
## 🤝 **Contributing Model Recommendations**
ModelForge uses a modular configuration system for model recommendations. Contributors can easily add new recommended models by adding configuration files to the `model_configs/` directory. Each hardware profile (low_end, mid_range, high_end) has its own configuration file where you can specify primary and alternative models for different tasks.
diff --git a/pyproject.toml b/pyproject.toml
index 418baaf..6ae60af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,11 @@ dependencies = [
"python-multipart",
]
+[project.optional-dependencies]
+unsloth = [
+ "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git",
+]
+
[tool.setuptools.packages.find]
where = ["."]
include = ["ModelForge*"]